Introduction
gather() in PyTorch allows us to select specific elements from a tensor based on an index tensor. If you’re new to tensors, think of them as multi-dimensional arrays (like nested lists). This post simplifies gather() with examples and visuals to help you feel confident using it!
What is gather()?
In PyTorch, gather() is a function that lets you pick elements from a tensor according to specified indices along a chosen dimension. This function is extremely helpful when you need to pull specific values from large or complex datasets.
Syntax of gather()
The basic syntax for gather() in PyTorch is as follows:
torch.gather(input, dim, index)
Here’s a breakdown of each parameter:
- input (Tensor): The source tensor from which values are gathered.
- dim (int): The dimension along which to index. With dim=0, each index value selects a row, so you move down a column; with dim=1, each index value selects a column, so you move along a row.
- index (LongTensor): A tensor of int64 indices specifying which elements to select from input. It must have the same number of dimensions as input, and it can be no larger than input in any dimension other than dim. The output has the same shape as index.
This syntax allows you to extract specific values from an input tensor along a given axis, based on the positions defined by the index tensor.
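As a small illustration of the shape rules, the sketch below gathers with an index that has fewer rows than the source tensor. The tensor values and the name source are made up for this example, and it assumes a recent PyTorch release running on CPU.

```python
import torch

# Illustrative 3x4 source tensor (values are arbitrary).
source = torch.arange(12).reshape(3, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# A 2x3 index: 2-D like `source`, int64 by default, and only 2 rows.
index = torch.tensor([[3, 0, 1],
                      [2, 2, 0]])

out = torch.gather(source, dim=1, index=index)

print(out.shape)  # torch.Size([2, 3]), the shape of `index`
print(out)
# tensor([[3, 0, 1],
#         [6, 6, 4]])
```

Only the first two rows of source are read here, because index has only two rows.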
How the dim Argument Works
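The dim argument specifies the axis along which gather() indexes. If dim=0, each index value is a row number, so elements are gathered down each column. If dim=1, each index value is a column position, so elements are gathered along each row. For a 2D tensor, this means out[i][j] = input[index[i][j]][j] when dim=0 and out[i][j] = input[i][index[i][j]] when dim=1.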
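One way to internalize the role of dim is to write the lookup out by hand. The helper below, gather_2d, is a hypothetical plain-Python reference for 2D nested lists only; it is a sketch of the indexing rule, not how PyTorch implements gather().

```python
def gather_2d(rows, dim, index):
    """Plain-Python sketch of gather() for 2-D nested lists (illustrative helper)."""
    if dim not in (0, 1):
        raise ValueError("this sketch only handles dim=0 or dim=1")
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            if dim == 0:
                out_row.append(rows[k][j])  # k picks the row, column j stays fixed
            else:
                out_row.append(rows[i][k])  # k picks the column, row i stays fixed
        out.append(out_row)
    return out


data = [[10, 11, 12], [13, 14, 15], [16, 17, 18]]

by_row = gather_2d(data, 1, [[2, 0], [1, 2], [0, 1]])
print(by_row)
# [[12, 10], [14, 15], [16, 17]]

by_col = gather_2d(data, 0, [[1, 0, 2], [0, 2, 1], [2, 1, 0]])
print(by_col)
# [[13, 11, 18], [10, 17, 15], [16, 14, 12]]
```

In both cases the output has the shape of index; dim only decides which coordinate the index value replaces.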
Row-wise Selection with dim=1
Let’s start with a basic example where we have a 2D tensor and want to gather specific elements along each row.
Step 1: Set up the Data
Suppose we have the following tensor:
input = [[ 10, 11, 12 ],
[ 13, 14, 15 ],
[ 16, 17, 18 ]]
And we want to gather specific elements using this index tensor:
index = [[ 2, 0 ],
[ 1, 2 ],
[ 0, 1 ]]
Step 2: Use gather()
We apply gather() with dim=1 to specify row-wise selection:
import torch
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[2, 0], [1, 2], [0, 1]])
result = torch.gather(input, dim=1, index=index)
print(result)
# Output:
# tensor([[12, 10],
# [14, 15],
# [16, 17]])
Explanation
Here’s how gather() works in this example:
- index[0]: For the first row, gather elements at positions 2 and 0, which gives [12, 10].
- index[1]: For the second row, gather elements at positions 1 and 2, which gives [14, 15].
- index[2]: For the third row, gather elements at positions 0 and 1, which gives [16, 17].
Common Pitfalls and Errors
Here are a few errors you might encounter with gather():
- Index Out of Bounds: Ensure your index tensor only contains valid positions for the chosen dimension in input.
- Dimension Mismatch: The index tensor must have the same number of dimensions as input, and in every dimension other than dim it can be no larger than input. Double-check the shapes if you run into mismatched dimension errors.
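To see these failures in practice, the sketch below feeds a few deliberately bad index tensors to gather() and reports whether each call raised. The helper try_gather and the name source are made up for this example. It assumes CPU tensors and a recent PyTorch release; exact error messages vary between versions, so only the exception type is printed.

```python
import torch

source = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])


def try_gather(label, index):
    """Run gather along dim=1 and report whether it raised (illustrative helper)."""
    try:
        torch.gather(source, dim=1, index=index)
        print(f"{label}: ok")
        return True
    except (RuntimeError, IndexError) as err:
        print(f"{label}: {type(err).__name__}")
        return False


# A valid index: every value is a position that exists in a row of length 3.
valid = try_gather("valid index", torch.tensor([[2, 0], [1, 2], [0, 1]]))

# Position 3 does not exist in a row of length 3.
out_of_bounds = try_gather("out of bounds", torch.tensor([[3, 0], [1, 2], [0, 1]]))

# A 1-D index does not match the 2-D input.
wrong_ndim = try_gather("wrong number of dims", torch.tensor([2, 0, 1]))

# Four index rows, but the input has only three rows (dimension 0 is not dim).
too_many_rows = try_gather("too many rows", torch.tensor([[0], [1], [2], [0]]))
```

Printing index.shape and input.shape side by side before the call is usually the quickest way to spot which rule was broken.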
Column-wise Selection with dim=0
Next, let’s try gathering elements down each column.
Set up Data and Index
We’ll reuse our original input tensor and a new index tensor whose values pick the row to read from in each column:
index = [[ 1, 0, 2 ],
[ 0, 2, 1 ],
[ 2, 1, 0 ]]
Apply gather() with dim=0
import torch
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[1, 0, 2], [0, 2, 1], [2, 1, 0]])
result = torch.gather(input, dim=0, index=index)
print(result)
# Output:
# tensor([[13, 11, 18],
# [10, 17, 15],
# [16, 14, 12]])
Explanation
With dim=0, each value in index is a row number: for output position (i, j), gather() takes the element of input found in row index[i][j] of column j. Here’s how it works for each column:
- index[:, 0]: In the first column, gather() takes elements from row 1 (13), row 0 (10), and row 2 (16) to form [13, 10, 16].
- index[:, 1]: In the second column, it takes elements from row 0 (11), row 2 (17), and row 1 (14) to form [11, 17, 14].
- index[:, 2]: In the third column, it takes elements from row 2 (18), row 1 (15), and row 0 (12) to form [18, 15, 12].
The result is a new tensor whose columns are these three lists: [[13, 11, 18], [10, 17, 15], [16, 14, 12]].
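The column-by-column reading can be checked directly. The sketch below reuses the post’s variable names (including input, which shadows the Python built-in) and compares each column of the gather() result with a manual lookup that uses ordinary tensor indexing.

```python
import torch

input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[1, 0, 2], [0, 2, 1], [2, 1, 0]])
result = torch.gather(input, dim=0, index=index)

for j in range(input.shape[1]):
    rows = index[:, j]       # row numbers requested for column j
    picked = input[rows, j]  # elements of column j taken from those rows
    assert torch.equal(picked, result[:, j])
    print(f"column {j}: rows {rows.tolist()} -> {picked.tolist()}")
# column 0: rows [1, 0, 2] -> [13, 10, 16]
# column 1: rows [0, 2, 1] -> [11, 17, 14]
# column 2: rows [2, 1, 0] -> [18, 15, 12]
```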
The Book Shelf Analogy for gather()
Imagine a library with a bookshelf, where each shelf is like a row in a PyTorch tensor, and each book on the shelf represents an individual element. The books are organized by position, so the first book is position 0, the second is position 1, and so on.
Suppose you’re given a list of specific book positions to pick from each shelf or from each position across all shelves. Here’s how the different dim values would change your approach to gathering books:
Gathering Across Each Shelf (dim=1)
With dim=1, you’ll go shelf by shelf and pick the books specified by the positions in the index list for each shelf:
Shelves (Input):
Shelf 1: ["Math", "Physics", "Chemistry"]
Shelf 2: ["Biology", "Geology", "Astronomy"]
Shelf 3: ["History", "Art", "Philosophy"]
Index = [
[2, 0], # Positions for Shelf 1
[1, 2], # Positions for Shelf 2
[0, 1] # Positions for Shelf 3
]
- Shelf 1 has books [“Math”, “Physics”, “Chemistry”]. The index list tells you to pick the book at positions 2 and 0, so you take “Chemistry” and “Math”.
- Shelf 2 has books [“Biology”, “Geology”, “Astronomy”]. You’re told to pick books at positions 1 and 2, so you take “Geology” and “Astronomy”.
- Shelf 3 has books [“History”, “Art”, “Philosophy”]. You pick the books at positions 0 and 1, so you take “History” and “Art”.
[
[“Chemistry”, “Math”],
[“Geology”, “Astronomy”],
[“History”, “Art”]
]
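PyTorch tensors cannot hold strings, so the shelf walk for dim=1 can only be replayed in plain Python. The sketch below is an illustration of the same rule, with the variable names shelves and picked chosen for this example.

```python
shelves = [
    ["Math", "Physics", "Chemistry"],     # Shelf 1
    ["Biology", "Geology", "Astronomy"],  # Shelf 2
    ["History", "Art", "Philosophy"],     # Shelf 3
]
index = [[2, 0], [1, 2], [0, 1]]

# dim=1: stay on one shelf and pick the book at each listed position.
picked = [[shelf[pos] for pos in positions]
          for shelf, positions in zip(shelves, index)]

for row in picked:
    print(row)
# ['Chemistry', 'Math']
# ['Geology', 'Astronomy']
# ['History', 'Art']
```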
Gathering Down Each Position (dim=0)
With dim=0, the index values name shelves (rows) rather than book positions: for each book position (column), gather() takes the book from the shelf listed in the matching column of the index tensor:
Shelves (Input):
Shelf 1: ["Math", "Physics", "Chemistry"]
Shelf 2: ["Biology", "Geology", "Astronomy"]
Shelf 3: ["History", "Art", "Philosophy"]
Index = [
[1, 0, 2], # Shelf to use at each book position, for output row 1
[0, 2, 1], # Shelf to use at each book position, for output row 2
[2, 1, 0] # Shelf to use at each book position, for output row 3
]
Each index value names a shelf (0 is Shelf 1, 1 is Shelf 2, 2 is Shelf 3), and the column it sits in fixes the book position. Here’s how it works for each column:
- First Column: Using index[:, 0] (the first column of index), you stay at position 0 and take the book from shelves 1, 0, and 2 in turn:
  - Shelf 2 (index 1), Position 0: “Biology”
  - Shelf 1 (index 0), Position 0: “Math”
  - Shelf 3 (index 2), Position 0: “History”
  This gives ["Biology", "Math", "History"].
- Second Column: Using index[:, 1] (the second column of index), you stay at position 1 and take the book from shelves 0, 2, and 1 in turn:
  - Shelf 1 (index 0), Position 1: “Physics”
  - Shelf 3 (index 2), Position 1: “Art”
  - Shelf 2 (index 1), Position 1: “Geology”
  This gives ["Physics", "Art", "Geology"].
- Third Column: Using index[:, 2] (the third column of index), you stay at position 2 and take the book from shelves 2, 1, and 0 in turn:
  - Shelf 3 (index 2), Position 2: “Philosophy”
  - Shelf 2 (index 1), Position 2: “Astronomy”
  - Shelf 1 (index 0), Position 2: “Chemistry”
  This gives ["Philosophy", "Astronomy", "Chemistry"].
In this case, setting dim=0 lets the index choose which shelf supplies the book at each position. Each list above becomes one column of the new tensor:
[["Biology", "Physics", "Philosophy"],
["Math", "Art", "Astronomy"],
["History", "Geology", "Chemistry"]]
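Because tensors cannot store strings, the dim=0 rule out[i][j] = input[index[i][j]][j] can be replayed on the shelves with plain Python lists. This is a sketch of the rule itself, with shelves and picked as illustrative names; each index value selects the shelf while the column fixes the book position.

```python
shelves = [
    ["Math", "Physics", "Chemistry"],     # Shelf 1 (row 0)
    ["Biology", "Geology", "Astronomy"],  # Shelf 2 (row 1)
    ["History", "Art", "Philosophy"],     # Shelf 3 (row 2)
]
index = [[1, 0, 2], [0, 2, 1], [2, 1, 0]]

# dim=0: keep the book position fixed and let the index value choose the shelf.
picked = [[shelves[shelf_no][pos] for pos, shelf_no in enumerate(index_row)]
          for index_row in index]

for row in picked:
    print(row)
# ['Biology', 'Physics', 'Philosophy']
# ['Math', 'Art', 'Astronomy']
# ['History', 'Geology', 'Chemistry']
```

This lines up with the numeric dim=0 example: 13, 11, and 18 sit in the same cells of input as “Biology”, “Physics”, and “Philosophy” do on the shelves.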
Summary
PyTorch’s gather() is a powerful tool for data extraction:
- Use dim=1 for row-wise extraction, and dim=0 for column-wise.
- Check for errors like mismatched dimensions and out-of-bounds indices to avoid common mistakes.
With practice, gather() can handle selective data extraction in PyTorch efficiently.
Congratulations on reading to the end of this tutorial! For further reading on PyTorch, go to the Deep Learning Frameworks page.
Have fun and happy researching!
Suf is a senior advisor in data science with deep expertise in Natural Language Processing, Complex Networks, and Anomaly Detection. Formerly a postdoctoral research fellow, he applied advanced physics techniques to tackle real-world, data-heavy industry challenges. Before that, he was a particle physicist at the ATLAS Experiment of the Large Hadron Collider. Now, he’s focused on bringing more fun and curiosity to the world of science and research online.
