Utilizzando torch.index_select, torch.gather e torch.take

In alcune situazioni, dovrai eseguire un’indicizzazione/selezione avanzata con Pytorch, ad esempio rispondere alla domanda: “come posso selezionare elementi dal Tensore A seguendo gli indici specificati nel Tensore B?”

In questo post presenteremo i tre metodi più comuni per tali attività, vale a dire torcia.index_select, torcia.raccolta E torcia.prendere. Li spiegheremo tutti in dettaglio e li confronteremo tra loro.

Foto di Serie J SU Unsplash

Certo, una delle motivazioni per questo post è stata che ho dimenticato come e quando utilizzare quale funzione, finendo per cercare su Google, navigare Overflow dello stack e la documentazione ufficiale, a mio avviso, relativamente breve e poco utile. Pertanto, come accennato, approfondiamo qui queste funzioni: motiviamo quando usarle e quali, forniamo esempi in 2 e 3D e mostriamo graficamente la selezione risultante.

Spero che questo post porti chiarezza su tali funzioni ed elimini la necessità di ulteriori esplorazioni: grazie per aver letto!

E ora, senza ulteriori indugi, analizziamo le funzioni una per una. Per tutti, iniziamo prima con un esempio 2D e visualizziamo la selezione risultante, quindi passiamo a un esempio un po’ più complesso in 3D. Inoltre, reimplementiamo l’operazione eseguita in Python semplice: prima puoi considerare lo pseudocodice come un’altra fonte di informazioni su cosa fanno queste funzioni. Alla fine, riassumiamo le funzioni e le loro differenze in una tabella.

torcia.index_select seleziona gli elementi lungo una dimensione, mantenendo invariati gli altri. Cioè: mantieni tutti gli elementi di tutte le altre dimensioni, ma scegli gli elementi nelle dimensioni di destinazione seguendo il tensore dell’indice. Dimostriamolo con un esempio 2D, in cui selezioniamo lungo la dimensione 1:

num_picks = 2

values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# (len_dim_0, num_picks)
picked = torch.index_select(values, 1, indices)

Il tensore risultante ha forma (len_dim_0, num_picks): per ogni elemento lungo la dimensione 0, abbiamo selezionato lo stesso elemento dalla dimensione 1. Visualizziamo questo:

Fonte: towardsdatascience.com

Lascia un commento

Il tuo indirizzo email non sarà pubblicato. I campi obbligatori sono contrassegnati *