Understanding gather() in PyTorch: A Beginner-Friendly Guide

by | Machine Learning, Programming, PyTorch

Introduction

gather() in PyTorch allows us to select specific elements from a tensor based on an index tensor. If you’re new to tensors, think of them as multi-dimensional arrays (like nested lists). This post simplifies gather() with examples and visuals to help you feel confident using it!

What is gather()?

In PyTorch, gather() is a function that lets you pick elements from a tensor according to specified indices along a chosen dimension. It is most useful when each row (or column) needs its own selection, which plain slicing cannot express.
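The simplest case is a 1D tensor, where gather() just reads the listed positions. The snippet below is a minimal sketch; the names values and positions are made up for illustration.

```python
import torch

# A 1-D source tensor and the positions we want to read from it.
values = torch.tensor([10, 20, 30, 40])
positions = torch.tensor([3, 0, 0])  # indices may repeat

picked = torch.gather(values, dim=0, index=positions)
print(picked)  # tensor([40, 10, 10])
```

The output has one element per entry in positions, so it can be shorter or longer than the source.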

Syntax of gather()

The basic syntax for gather() in PyTorch is as follows:

Syntax
torch.gather(input, dim, index)

Here’s a breakdown of each parameter:

  • input (Tensor): The source tensor from which values are gathered.
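  • dim (int): The dimension along which to index. For a 2D input, dim=0 means each index value picks a row, so you move down a column; dim=1 means each index value picks a column, so you move along a row.
  • index (LongTensor): A tensor of int64 indices specifying which elements to select from input. It must have the same number of dimensions as input, and its size in every dimension other than dim must not exceed input’s. The output has the same shape as index.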

This syntax allows you to extract specific values from an input tensor along a given axis, based on the positions defined by the index tensor.
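A quick way to get a feel for the parameters is to print the shapes involved. The following is a small sketch with an arbitrary 3x4 tensor (source is an illustrative name, not part of the API).

```python
import torch

source = torch.arange(12).reshape(3, 4)          # shape (3, 4), values 0..11
index = torch.tensor([[0, 3], [1, 1], [2, 0]])   # shape (3, 2), dtype int64

out = torch.gather(source, dim=1, index=index)

print(index.dtype)  # torch.int64
print(out.shape)    # torch.Size([3, 2])
print(out)
# tensor([[ 0,  3],
#         [ 5,  5],
#         [10,  8]])
```

Note that the result takes its shape from index, not from source.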

How the dim Argument Works

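The dim argument specifies the axis along which gather() will operate. For a 2D input, dim=0 treats each index value as a row number, so output[i][j] = input[index[i][j]][j]. With dim=1, each index value is a column number, so output[i][j] = input[i][index[i][j]].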
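To make the role of dim concrete, here is a pure-Python sketch of what gather() does for a 2D input. gather_2d is a hypothetical helper written for this post, not a PyTorch function, and it works on nested lists rather than tensors.

```python
def gather_2d(source, dim, index):
    """Pure-Python reference for 2-D gather on nested lists (illustrative only)."""
    rows, cols = len(index), len(index[0])
    out = [[None] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            if dim == 0:
                # The index value chooses the row; the column j stays fixed.
                out[i][j] = source[index[i][j]][j]
            elif dim == 1:
                # The index value chooses the column; the row i stays fixed.
                out[i][j] = source[i][index[i][j]]
            else:
                raise ValueError("dim must be 0 or 1 for a 2-D input")
    return out


source = [[10, 11, 12], [13, 14, 15], [16, 17, 18]]

print(gather_2d(source, 1, [[2, 0], [1, 2], [0, 1]]))
# [[12, 10], [14, 15], [16, 17]]
print(gather_2d(source, 0, [[1, 0, 2], [0, 2, 1], [2, 1, 0]]))
# [[13, 11, 18], [10, 17, 15], [16, 14, 12]]
```

The only difference between the two branches is which coordinate is replaced by the index value.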

Row-wise Selection with dim=1

Let’s start with a basic example where we have a 2D tensor and want to gather specific elements along each row.

Step 1: Set up the Data

Suppose we have the following tensor:

input = [[ 10, 11, 12 ],
[ 13, 14, 15 ],
[ 16, 17, 18 ]]

And we want to gather specific elements using this index tensor:

index = [[ 2, 0 ],
[ 1, 2 ],
[ 0, 1 ]]

Step 2: Use gather()

We apply gather() with dim=1 to specify row-wise selection:

Gathering Rows with dim=1
import torch

input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[2, 0], [1, 2], [0, 1]])

result = torch.gather(input, dim=1, index=index)
print(result)
# Output:
# tensor([[12, 10],
#         [14, 15],
#         [16, 17]])
    

Explanation

Here’s how gather() works in this example:

  • index[0]: For the first row, gather elements at positions 2 and 0, which gives [12, 10].
  • index[1]: For the second row, gather elements at positions 1 and 2, which gives [14, 15].
  • index[2]: For the third row, gather elements at positions 0 and 1, which gives [16, 17].
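The row-by-row reading above can be reproduced with an explicit loop. This is only a sketch for checking your understanding; in real code the single gather() call is the better choice.

```python
import torch

input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[2, 0], [1, 2], [0, 1]])

# For each row r, read the positions listed in index[r].
rows = [input[r][index[r]] for r in range(input.size(0))]
manual = torch.stack(rows)

result = torch.gather(input, dim=1, index=index)
print(manual.tolist())               # [[12, 10], [14, 15], [16, 17]]
print(torch.equal(manual, result))   # True
```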

Common Pitfalls and Errors

Here are a few errors you might encounter with gather():

  • Index Out of Bounds: Ensure your index tensor only contains valid positions for the chosen dimension in input.
  • Dimension Mismatch: The index tensor must have the same number of dimensions as input, and in every dimension other than dim it cannot be larger than input. Double-check the shapes if you run into mismatched dimension errors.
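Both kinds of mistake are easy to trigger on purpose. The sketch below assumes a CPU tensor and uses a hypothetical helper, try_gather, that catches the exception and returns None; the exact error text varies between PyTorch versions, so it is printed rather than quoted here.

```python
import torch

input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])


def try_gather(index):
    """Return the gathered tensor, or None if PyTorch rejects the index."""
    try:
        return torch.gather(input, dim=1, index=index)
    except (RuntimeError, IndexError) as err:
        print(f"gather failed: {err}")
        return None


# Position 3 does not exist in a row of length 3.
out_of_bounds = try_gather(torch.tensor([[3], [0], [1]]))

# A 1-D index cannot be used with a 2-D input.
wrong_ndim = try_gather(torch.tensor([2, 0, 1]))

# Valid: a 2-D index whose values all lie in the range 0..2.
ok = try_gather(torch.tensor([[2], [0], [1]]))
print(ok)
# tensor([[12],
#         [13],
#         [17]])
```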

Column-wise Selection with dim=0

Next, let’s try gathering elements down each column.

Set up Data and Index

We’ll reuse the original input tensor with a new index tensor. This time each index value is a row number:

index = [[ 1, 0, 2 ],
[ 0, 2, 1 ],
[ 2, 1, 0 ]]

Apply gather() with dim=0

Gathering Along Columns with dim=0
import torch

input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[1, 0, 2], [0, 2, 1], [2, 1, 0]])

result = torch.gather(input, dim=0, index=index)
print(result)
# Output:
# tensor([[13, 11, 18],
#         [10, 17, 15],
#         [16, 14, 12]])
    

Explanation

With dim=0, each value in index is a row number, and the column stays fixed: the element at result[i][j] comes from row index[i][j], column j of input. Here’s how it works for each column:

  • index[:, 0]: In the first column, gather() takes elements from row 1 (13), row 0 (10), and row 2 (16) to form [13, 10, 16].
  • index[:, 1]: In the second column, it takes elements from row 0 (11), row 2 (17), and row 1 (14) to form [11, 17, 14].
  • index[:, 2]: In the third column, it takes elements from row 2 (18), row 1 (15), and row 0 (12) to form [18, 15, 12].

Each of these lists becomes one column of the result, which is why the output reads [13, 11, 18] across its first row.
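The column-by-column reading can be checked directly. The loop below is a sketch that rebuilds each column with ordinary tensor indexing and compares it with the matching column of the gather() result.

```python
import torch

input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[1, 0, 2], [0, 2, 1], [2, 1, 0]])

result = torch.gather(input, dim=0, index=index)

columns = []
for j in range(input.size(1)):
    # Rows are chosen by index[:, j]; the column j stays fixed.
    column = input[index[:, j], j]
    columns.append(column.tolist())
    print(j, column.tolist(), torch.equal(column, result[:, j]))
# 0 [13, 10, 16] True
# 1 [11, 17, 14] True
# 2 [18, 15, 12] True
```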

The Book Shelf Analogy for gather()

Imagine a library with a bookshelf, where each shelf is like a row in a PyTorch tensor, and each book on the shelf represents an individual element. The books are organized by position, so the first book is position 0, the second is position 1, and so on.

Suppose you’re given a list of specific book positions to pick from each shelf or from each position across all shelves. Here’s how the different dim values would change your approach to gathering books:

Gathering Across Each Shelf (dim=1)

With dim=1, you’ll go shelf by shelf and pick the books specified by the positions in the index list for each shelf:

Shelves (Input):
Shelf 1: ["Math", "Physics", "Chemistry"]
Shelf 2: ["Biology", "Geology", "Astronomy"]
Shelf 3: ["History", "Art", "Philosophy"]

Index = [
    [2, 0],  # Positions for Shelf 1
    [1, 2],  # Positions for Shelf 2
    [0, 1]   # Positions for Shelf 3
]

  • Shelf 1 has books [“Math”, “Physics”, “Chemistry”]. The index list tells you to pick the books at positions 2 and 0, so you take “Chemistry” and “Math”.
  • Shelf 2 has books [“Biology”, “Geology”, “Astronomy”]. You’re told to pick the books at positions 1 and 2, so you take “Geology” and “Astronomy”.
  • Shelf 3 has books [“History”, “Art”, “Philosophy”]. You pick the books at positions 0 and 1, so you take “History” and “Art”.

Output:
[["Chemistry", "Math"],
 ["Geology", "Astronomy"],
 ["History", "Art"]]
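PyTorch tensors hold numbers rather than strings, so the shelf analogy cannot be run through torch.gather() as written. The sketch below mimics the dim=1 behaviour with plain Python lists instead; shelves and picked are illustrative names.

```python
shelves = [
    ["Math", "Physics", "Chemistry"],
    ["Biology", "Geology", "Astronomy"],
    ["History", "Art", "Philosophy"],
]
index = [[2, 0], [1, 2], [0, 1]]

# dim=1: stay on one shelf and read the listed positions from it.
picked = [
    [shelf[p] for p in positions]
    for shelf, positions in zip(shelves, index)
]

for row in picked:
    print(row)
# ['Chemistry', 'Math']
# ['Geology', 'Astronomy']
# ['History', 'Art']
```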

Gathering Down Each Position (dim=0)

With dim=0, each value in index is a shelf number (0 for Shelf 1, 1 for Shelf 2, 2 for Shelf 3), and the column it sits in fixes the book position. In other words, output[i][j] is the book at position j on shelf index[i][j]:

Shelves (Input):
Shelf 1: ["Math", "Physics", "Chemistry"]
Shelf 2: ["Biology", "Geology", "Astronomy"]
Shelf 3: ["History", "Art", "Philosophy"]

Index = [
    [1, 0, 2],  # Shelf numbers for output row 0, one per book position
    [0, 2, 1],  # Shelf numbers for output row 1
    [2, 1, 0]   # Shelf numbers for output row 2
]

Here’s how it works for each column:

  • First Column: Using index[:, 0] (the first column of index), you stay at position 0 and visit shelves 1, 0, and 2:
    • Shelf 2 (index 1), Position 0: “Biology”
    • Shelf 1 (index 0), Position 0: “Math”
    • Shelf 3 (index 2), Position 0: “History”
    Result for the first column: ["Biology", "Math", "History"]
  • Second Column: Using index[:, 1] (the second column of index), you stay at position 1 and visit shelves 0, 2, and 1:
    • Shelf 1 (index 0), Position 1: “Physics”
    • Shelf 3 (index 2), Position 1: “Art”
    • Shelf 2 (index 1), Position 1: “Geology”
    Result for the second column: ["Physics", "Art", "Geology"]
  • Third Column: Using index[:, 2] (the third column of index), you stay at position 2 and visit shelves 2, 1, and 0:
    • Shelf 3 (index 2), Position 2: “Philosophy”
    • Shelf 2 (index 1), Position 2: “Astronomy”
    • Shelf 1 (index 0), Position 2: “Chemistry”
    Result for the third column: ["Philosophy", "Astronomy", "Chemistry"]

In this case, setting dim=0 lets you keep the book position fixed and choose which shelf to take it from. Placing the three column results side by side produces a new tensor:

Output:
[["Biology", "Physics", "Philosophy"],
 ["Math", "Art", "Astronomy"],
 ["History", "Geology", "Chemistry"]]
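The dim=0 rule, output[i][j] = input[index[i][j]][j], can also be applied to the shelves with plain Python lists (tensors cannot hold strings). This is a sketch for checking the analogy by hand; shelves and picked are illustrative names, and the printed rows can be compared with the Output above.

```python
shelves = [
    ["Math", "Physics", "Chemistry"],
    ["Biology", "Geology", "Astronomy"],
    ["History", "Art", "Philosophy"],
]
index = [[1, 0, 2], [0, 2, 1], [2, 1, 0]]

# dim=0: the index value chooses the shelf (row); the position j stays fixed.
picked = [
    [shelves[index[i][j]][j] for j in range(len(index[i]))]
    for i in range(len(index))
]

for row in picked:
    print(row)
```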

Summary

PyTorch’s gather() is a powerful tool for data extraction:

  • Use dim=1 to pick column positions within each row (row-wise extraction), and dim=0 to pick row positions within each column (column-wise extraction).
  • Check for errors like mismatched dimensions and out-of-bounds indices to avoid common mistakes.

With practice, gather() can handle selective data extraction in PyTorch efficiently.

Congratulations on reading to the end of this tutorial! For further reading on PyTorch, go to the Deep Learning Frameworks page.

Have fun and happy researching!

Suf is a senior advisor in data science with deep expertise in Natural Language Processing, Complex Networks, and Anomaly Detection. Formerly a postdoctoral research fellow, he applied advanced physics techniques to tackle real-world, data-heavy industry challenges. Before that, he was a particle physicist at the ATLAS Experiment of the Large Hadron Collider. Now, he’s focused on bringing more fun and curiosity to the world of science and research online.