Introduction
gather() in PyTorch allows us to select specific elements from a tensor based on an index tensor. If you’re new to tensors, think of them as multi-dimensional arrays (like nested lists). This post explains gather() with worked examples and an analogy to help you feel confident using it!
What is gather()?
In PyTorch, gather() is a function that lets you pick elements from a tensor according to specified indices along a chosen dimension. This is helpful whenever you need to pull specific values out of a larger tensor, for example one value per row.
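Here is a minimal sketch of that "one value per row" idea, using a hypothetical 2×3 tensor of per-class scores and one (also hypothetical) label per row. Because index must have as many dimensions as the input, the labels are first reshaped into a column with unsqueeze(1).

```python
import torch

# Hypothetical scores for 2 samples over 3 classes (one row per sample).
scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2]])
# Hypothetical class label for each sample.
labels = torch.tensor([1, 0])

# index needs the same number of dimensions as scores, so reshape labels to (2, 1).
picked = torch.gather(scores, dim=1, index=labels.unsqueeze(1))
print(picked.squeeze(1))  # tensor([0.7000, 0.5000])
```

Each row contributes exactly the score at its label's position, which is why the output has one column.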
Syntax of gather()
The basic syntax for gather() in PyTorch is as follows:
torch.gather(input, dim, index)
Here’s a breakdown of each parameter:
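- input (Tensor): The source tensor from which values are gathered.
- dim (int): The dimension along which to index. With dim=0, the values in index are row numbers, so elements are gathered down each column; with dim=1, the values in index are column numbers, so elements are gathered along each row.
- index (LongTensor): A tensor of indices specifying which elements to select from input. It must have the same number of dimensions as input, and index.size(d) <= input.size(d) for every dimension d other than dim. The output has the same shape as index.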
This syntax allows you to extract specific values from an input tensor along a given axis, based on the positions defined by the index tensor.
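To make the shape rule concrete, here is a small sketch on the same 3×3 tensor used in the examples below. The index tensors are made up for illustration: one is narrower than the input, the other is longer than the input along dim, and in both cases the output takes the shape of index.

```python
import torch

x = torch.tensor([[10, 11, 12],
                  [13, 14, 15],
                  [16, 17, 18]])

# One column position per row: index has shape (3, 1), so the output does too.
idx = torch.tensor([[2], [0], [1]])
picked = torch.gather(x, dim=1, index=idx)
print(picked.shape)       # torch.Size([3, 1])
print(picked.squeeze(1))  # tensor([12, 13, 17])

# Along dim itself, index may be longer than input, and positions may repeat.
repeat = torch.gather(x, 1, torch.tensor([[0, 0, 0, 0]]))
print(repeat)             # tensor([[10, 10, 10, 10]])
```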
How the dim Argument Works
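The dim argument specifies the axis along which gather() indexes. For a 2D tensor, dim=0 means each value in index is a row number (the column stays the same), while dim=1 means each value in index is a column number (the row stays the same):
- dim=0: out[i][j] = input[index[i][j]][j]
- dim=1: out[i][j] = input[i][index[i][j]]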
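The two formulas can be checked with a plain-Python loop. gather_2d below is a hypothetical helper written only for this illustration (it is not part of PyTorch and handles 2D nested lists only); the sketch compares it with torch.gather for both values of dim, using the data from the row-wise example.

```python
import torch

def gather_2d(inp, dim, index):
    """Hypothetical loop-based gather for 2D nested lists (dim must be 0 or 1)."""
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            # dim=0: k replaces the row number; dim=1: k replaces the column number.
            out_row.append(inp[k][j] if dim == 0 else inp[i][k])
        out.append(out_row)
    return out  # same shape as index

inp = [[10, 11, 12], [13, 14, 15], [16, 17, 18]]
index = [[2, 0], [1, 2], [0, 1]]

for dim in (0, 1):
    mine = gather_2d(inp, dim, index)
    ref = torch.gather(torch.tensor(inp), dim, torch.tensor(index)).tolist()
    print(dim, mine == ref)  # prints True for both dims
```

The only difference between the two branches is which coordinate the index value replaces; the loop over (i, j) always walks the shape of index.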
Row-wise Selection with dim=1
Let’s start with a basic example where we have a 2D tensor and want to gather specific elements along each row.
Step 1: Set up the Data
Suppose we have the following tensor:
input = [[ 10, 11, 12 ],
         [ 13, 14, 15 ],
         [ 16, 17, 18 ]]
And we want to gather specific elements using this index tensor:
index = [[ 2, 0 ],
         [ 1, 2 ],
         [ 0, 1 ]]
Step 2: Use gather()
We apply gather() with dim=1 to specify row-wise selection:
import torch
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[2, 0], [1, 2], [0, 1]])
result = torch.gather(input, dim=1, index=index)
print(result)
# Output:
# tensor([[12, 10],
#         [14, 15],
#         [16, 17]])
Explanation
Here’s how gather() works in this example:
- index[0]: For the first row, gather elements at positions 2 and 0, which gives [12, 10].
- index[1]: For the second row, gather elements at positions 1 and 2, which gives [14, 15].
- index[2]: For the third row, gather elements at positions 0 and 1, which gives [16, 17].
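Another way to see what dim=1 does: for a 2D tensor it matches ordinary advanced indexing, where the row number comes from a column of row indices built with torch.arange and the column number comes from index. The sketch below reuses the tensors from Step 2 (renamed x to avoid shadowing Python's built-in input) and checks that both approaches agree.

```python
import torch

x = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[2, 0], [1, 2], [0, 1]])

# Row numbers as a (3, 1) column; broadcasting pairs row i with every entry of index[i].
rows = torch.arange(x.size(0)).unsqueeze(1)
via_indexing = x[rows, index]
via_gather = torch.gather(x, dim=1, index=index)

print(via_indexing)
print(torch.equal(via_indexing, via_gather))  # True
```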
Common Pitfalls and Errors
Here are a few errors you might encounter with gather():
- Index Out of Bounds: Ensure your index tensor only contains valid positions for the chosen dimension in input, that is, values from 0 to input.size(dim) - 1.
- Dimension Mismatch: The index tensor must have the same number of dimensions as input, and it may not be larger than input along any dimension other than dim (along dim itself, its size is free). Double-check the shapes if you run into size-mismatch errors.
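To see both pitfalls in action, the sketch below triggers each error on CPU and prints the first line of PyTorch's message. try_gather is a hypothetical helper, not a PyTorch function, and the exact wording (and possibly the exception type) can differ between PyTorch versions. The last lines show a quick bounds check you could run before calling gather().

```python
import torch

x = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])

def try_gather(dim, index):
    """Hypothetical helper: call torch.gather and report the error instead of crashing."""
    try:
        return torch.gather(x, dim, index)
    except (RuntimeError, IndexError) as err:  # exception type may vary by version
        print(f"{type(err).__name__}: {str(err).splitlines()[0]}")
        return None

# Pitfall 1: 3 is not a valid column position when x.size(1) == 3.
out_of_bounds = try_gather(1, torch.tensor([[3, 0], [1, 2], [0, 1]]))

# Pitfall 2: index has 4 rows but x has only 3, and dim 0 is not the gather dimension.
too_many_rows = try_gather(1, torch.tensor([[0], [0], [0], [0]]))

# A cheap pre-check for pitfall 1 on a valid index.
index = torch.tensor([[2, 0], [1, 2], [0, 1]])
print(bool(((index >= 0) & (index < x.size(1))).all()))  # True
```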
Column-wise Selection with dim=0
Next, let’s gather elements down each column, so that the values in index pick rows.
Set up Data and Index
We’ll reuse our original input tensor and a new index tensor for row selection:
index = [[ 1, 0, 2 ],
         [ 0, 2, 1 ],
         [ 2, 1, 0 ]]
Apply gather() with dim=0
import torch
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[1, 0, 2], [0, 2, 1], [2, 1, 0]])
result = torch.gather(input, dim=0, index=index)
print(result)
# Output:
# tensor([[13, 11, 18],
#         [10, 17, 15],
#         [16, 14, 12]])
Explanation
With dim=0, each value in index is a row number: gather() keeps the column fixed and takes the element from the row that index specifies. Here’s how it works for each column:
- index[:, 0]: In the first column, gather() takes elements from row 1 (13), row 0 (10), and row 2 (16) to form [13, 10, 16].
- index[:, 1]: In the second column, it takes elements from row 0 (11), row 2 (17), and row 1 (14) to form [11, 17, 14].
- index[:, 2]: In the third column, it takes elements from row 2 (18), row 1 (15), and row 0 (12) to form [18, 15, 12].
The result is a new tensor whose columns are [13, 10, 16], [11, 17, 14], and [18, 15, 12], each filled with elements taken from different rows of input according to index.
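One way to relate the two examples: gathering with dim=0 gives the same result as transposing, gathering with dim=1, and transposing back, because the transpose swaps the roles of rows and columns. The sketch below checks this on the tensors from this section; it illustrates the indexing rule rather than a technique you need in practice.

```python
import torch

x = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[1, 0, 2], [0, 2, 1], [2, 1, 0]])

down_columns = torch.gather(x, dim=0, index=index)
# Transposing turns columns into rows, so a dim=0 gather becomes a dim=1 gather.
via_transpose = torch.gather(x.T, dim=1, index=index.T).T

print(down_columns)
print(torch.equal(down_columns, via_transpose))  # True
```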
The Book Shelf Analogy for gather()
Imagine a library with a bookshelf, where each shelf is like a row in a PyTorch tensor, and each book on the shelf represents an individual element. The books are organized by position, so the first book is position 0, the second is position 1, and so on.
Suppose you’re given a list of specific book positions to pick from each shelf, or a list of shelves to pick from for each book position. Here’s how the different dim values change your approach to gathering books:
Gathering Across Each Shelf (dim=1)
With dim=1, you’ll go shelf by shelf and pick the books specified by the positions in the index list for each shelf:
Shelves (Input):
Shelf 1: ["Math", "Physics", "Chemistry"]
Shelf 2: ["Biology", "Geology", "Astronomy"]
Shelf 3: ["History", "Art", "Philosophy"]

Index = [[2, 0],   # Positions for Shelf 1
         [1, 2],   # Positions for Shelf 2
         [0, 1]]   # Positions for Shelf 3
- Shelf 1 has books [“Math”, “Physics”, “Chemistry”]. The index list tells you to pick the book at positions 2 and 0, so you take “Chemistry” and “Math”.
- Shelf 2 has books [“Biology”, “Geology”, “Astronomy”]. You’re told to pick books at positions 1 and 2, so you take “Geology” and “Astronomy”.
- Shelf 3 has books [“History”, “Art”, “Philosophy”]. You pick the books at positions 0 and 1, so you take “History” and “Art”.
The gathered result is:
[[“Chemistry”, “Math”],
 [“Geology”, “Astronomy”],
 [“History”, “Art”]]
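PyTorch tensors can’t hold strings, so to try the bookshelf analogy in code, one option (an assumption of this sketch, not a PyTorch feature) is to give every book an integer ID, gather the IDs, and map them back to titles afterwards:

```python
import torch

shelves = [["Math", "Physics", "Chemistry"],
           ["Biology", "Geology", "Astronomy"],
           ["History", "Art", "Philosophy"]]
titles = [book for shelf in shelves for book in shelf]  # flat list: ID -> title

# Book IDs 0..8 laid out in the same 3x3 shape as the shelves.
book_ids = torch.arange(9).reshape(3, 3)
index = torch.tensor([[2, 0], [1, 2], [0, 1]])  # positions to pick on each shelf

picked_ids = torch.gather(book_ids, dim=1, index=index)
picked = [[titles[i] for i in row] for row in picked_ids.tolist()]
print(picked)
# [['Chemistry', 'Math'], ['Geology', 'Astronomy'], ['History', 'Art']]
```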
Gathering Down Each Position (dim=0)
With dim=0, gather() works column by column: the book position stays fixed, and each value in index tells you which shelf (row) to take that position’s book from. Shelves are numbered from 0 just like positions, so index value 0 means Shelf 1, 1 means Shelf 2, and 2 means Shelf 3.
Shelves (Input):
Shelf 1: ["Math", "Physics", "Chemistry"]
Shelf 2: ["Biology", "Geology", "Astronomy"]
Shelf 3: ["History", "Art", "Philosophy"]

Index = [[1, 0, 2],
         [0, 2, 1],
         [2, 1, 0]]   # Each column lists shelf numbers for one book position
Here’s how it works for each column:
- First Column (position 0): Using index[:, 0] = [1, 0, 2], you take the book at position 0 from Shelf 2, Shelf 1, and Shelf 3:
  - Shelf 2, Position 0: “Biology”
  - Shelf 1, Position 0: “Math”
  - Shelf 3, Position 0: “History”
  ["Biology", "Math", "History"]
- Second Column (position 1): Using index[:, 1] = [0, 2, 1], you take the book at position 1 from Shelf 1, Shelf 3, and Shelf 2:
  - Shelf 1, Position 1: “Physics”
  - Shelf 3, Position 1: “Art”
  - Shelf 2, Position 1: “Geology”
  ["Physics", "Art", "Geology"]
- Third Column (position 2): Using index[:, 2] = [2, 1, 0], you take the book at position 2 from Shelf 3, Shelf 2, and Shelf 1:
  - Shelf 3, Position 2: “Philosophy”
  - Shelf 2, Position 2: “Astronomy”
  - Shelf 1, Position 2: “Chemistry”
  ["Philosophy", "Astronomy", "Chemistry"]
In this case, setting dim=0 lets you collect the books for each position from different shelves. Placing each gathered list back into its column produces a new tensor:
[[“Biology”, “Physics”, “Philosophy”],
 [“Math”, “Art”, “Astronomy”],
 [“History”, “Geology”, “Chemistry”]]
This is exactly the numeric dim=0 result from earlier, with 10 to 18 replaced by the nine books in shelf order.
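Since PyTorch tensors can’t store strings, this sketch gathers integer book IDs with dim=0 and maps them back to titles; the ID scheme is an assumption made purely for illustration. Each printed row can be compared with the dim=0 result above.

```python
import torch

shelves = [["Math", "Physics", "Chemistry"],
           ["Biology", "Geology", "Astronomy"],
           ["History", "Art", "Philosophy"]]
titles = [book for shelf in shelves for book in shelf]  # flat list: ID -> title

book_ids = torch.arange(9).reshape(3, 3)
# With dim=0, every value is a shelf number (row); the column is the book position.
index = torch.tensor([[1, 0, 2], [0, 2, 1], [2, 1, 0]])

picked_ids = torch.gather(book_ids, dim=0, index=index)
picked = [[titles[i] for i in row] for row in picked_ids.tolist()]
for row in picked:
    print(row)
# ['Biology', 'Physics', 'Philosophy']
# ['Math', 'Art', 'Astronomy']
# ['History', 'Geology', 'Chemistry']
```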
Summary
PyTorch’s gather() is a powerful tool for data extraction:
- Use dim=1 to pick columns within each row, and dim=0 to pick rows within each column.
- The output always has the same shape as index.
- Check for mismatched dimensions and out-of-bounds indices to avoid common mistakes.
With practice, gather() can handle selective data extraction in PyTorch efficiently.
Congratulations on reading to the end of this tutorial! For further reading on PyTorch, go to the Deep Learning Frameworks page.
Have fun and happy researching!
Suf is a senior advisor in data science with deep expertise in Natural Language Processing, Complex Networks, and Anomaly Detection. Formerly a postdoctoral research fellow, he applied advanced physics techniques to tackle real-world, data-heavy industry challenges. Before that, he was a particle physicist at the ATLAS Experiment of the Large Hadron Collider. Now, he’s focused on bringing more fun and curiosity to the world of science and research online.