
PyTorch Cat Vs Stack Explained


Concatenating joins a sequence of tensors along an existing axis. The PyTorch function for concatenation is cat(). Stacking joins a sequence of tensors along a new axis. The PyTorch function for stacking is stack().

This tutorial will go through the two PyTorch functions with code examples.


PyTorch Cat

We can use the PyTorch cat() function to concatenate a sequence of tensors along an existing dimension. The tensors must have the same shape in every dimension except the concatenating dimension, or be a one-dimensional empty tensor of size (0,).
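As a quick illustration of the shape rule, here is a small sketch. The tensors a and b are hypothetical and chosen only for this example: they are 2-D tensors that differ only in their number of rows, so they can be concatenated along dim 0 but not along dim 1.

```python
import torch

# Two 2-D tensors that differ only in dimension 0 (2 rows vs 3 rows)
a = torch.zeros(2, 4)
b = torch.ones(3, 4)

# Concatenating along dim 0 works: the other dimension (4 columns) matches
ab = torch.cat((a, b), dim=0)
print(ab.shape)  # torch.Size([5, 4])

# Concatenating along dim 1 fails: dimension 0 differs (2 vs 3)
mismatch_error = None
try:
    torch.cat((a, b), dim=1)
except RuntimeError as err:
    mismatch_error = err
    print("RuntimeError:", err)
```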

Syntax

torch.cat(tensors, dim=0, *, out=None)

Parameters

  • tensors (sequence of Tensors): Required. Any Python sequence of tensors of the same type. Non-empty tensors must have the same shape except in the concatenating dimension.
  • dim (int): Optional. The dimension to concatenate the tensors over. Defaults to 0.

Keyword Arguments

  • out (Tensor): Optional. The output tensor.
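The out keyword is seldom needed in everyday code. As a minimal sketch, assuming you already know the shape and dtype of the result, you can preallocate a tensor (called result here purely for illustration) and let cat() write into it. torch.tensor() infers int64 for Python integers, so the buffer uses that dtype.

```python
import torch

x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30])

# Preallocate an output tensor with the expected shape (4 + 3 = 7) and dtype
result = torch.empty(7, dtype=torch.int64)

# cat() writes the concatenated values into `result` instead of allocating a new tensor
torch.cat((x, y), dim=0, out=result)
print(result)  # tensor([ 2,  3,  4,  5,  4, 10, 30])
```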

Example

Let’s look at an example where we concatenate three tensors into one tensor using cat(). First, we have to import the PyTorch library and then use the tensor() function to create the tensors:

import torch

x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30])
z = torch.tensor([7, 22, 4, 8, 3, 6])

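Next, we concatenate the tensors along dimension 0, the only axis a one-dimensional tensor has. The tensors have different lengths, which is allowed because dimension 0 is the concatenating dimension.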

xyz = torch.cat((x, y, z), dim=0)

print(xyz)
print(xyz.shape)

Let’s run the code to see the result:

tensor([ 2,  3,  4,  5,  4, 10, 30,  7, 22,  4,  8,  3,  6])
torch.Size([13])
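The example above only has one axis to work with. As a rough sketch of what the dim argument does with two-dimensional inputs (the 2×2 matrices a and b are made up for illustration), dim=0 appends rows and dim=1 appends columns:

```python
import torch

# Two 2x2 matrices
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

# dim=0 appends the rows of b below a: result is 4x2
rows = torch.cat((a, b), dim=0)
# dim=1 appends the columns of b to the right of a: result is 2x4
cols = torch.cat((a, b), dim=1)

print(rows)
print(rows.shape)  # torch.Size([4, 2])
print(cols)
print(cols.shape)  # torch.Size([2, 4])
```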

PyTorch Stack

We can use the PyTorch stack() function to concatenate a sequence of tensors along a new dimension. The tensors must have the same shape.

Syntax

torch.stack(tensors, dim=0, *, out=None)

Parameters

  • tensors (sequence of Tensors): Required. Python sequence of tensors of the same size.
  • dim (int): Optional. The index at which to insert the new dimension. It must be between 0 and the number of dimensions of the input tensors, inclusive. Defaults to 0.

Keyword Arguments

  • out (Tensor): Optional. The output tensor.
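Here is a small sketch of the dim range and the equal-size requirement for stack(). It uses two of the one-dimensional tensors from the example below plus a shorter tensor w, added only for illustration. Since the exact exception class for an out-of-range dim is an implementation detail, the sketch catches it broadly.

```python
import torch

x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30, 40])

# For 1-D inputs the new axis can go at dim 0 or dim 1 (the upper bound is
# inclusive); the negative values -2 and -1 count from the end
print(torch.stack((x, y), dim=0).shape)   # torch.Size([2, 4])
print(torch.stack((x, y), dim=1).shape)   # torch.Size([4, 2])
print(torch.stack((x, y), dim=-1).shape)  # torch.Size([4, 2])

errors = []

# dim=2 is out of range for 1-D inputs
try:
    torch.stack((x, y), dim=2)
except (IndexError, RuntimeError) as err:
    errors.append(err)
    print(type(err).__name__, err)

# Tensors of different sizes (4 vs 3 elements) cannot be stacked
w = torch.tensor([4, 10, 30])
try:
    torch.stack((x, w), dim=0)
except RuntimeError as err:
    errors.append(err)
    print(type(err).__name__, err)
```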

Example

Let’s look at an example where we stack three tensors into one tensor using stack(). First, we have to import the PyTorch library and then use the tensor() function to create the tensors:

import torch

x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30, 40])
z = torch.tensor([8, 7, 16, 14])

In the above code, the tensors x, y, and z are one-dimensional, each having four elements. Next, we will stack the tensors along dim=0 and dim=1.

# Stack the tensors along a new dimension 0
stacked_0 = torch.stack((x, y, z), dim=0)

# Stack the tensors along a new dimension 1
stacked_1 = torch.stack((x, y, z), dim=1)

# Result with the new axis at dimension 0
print(stacked_0)
# Shape of the stacked tensor
print(stacked_0.shape)

# Result with the new axis at dimension 1
print(stacked_1)
# Shape of the stacked tensor
print(stacked_1.shape)

Let’s run the code to get the result:

tensor([[ 2,  3,  4,  5],
        [ 4, 10, 30, 40],
        [ 8,  7, 16, 14]])
torch.Size([3, 4])
tensor([[ 2,  4,  8],
        [ 3, 10,  7],
        [ 4, 30, 16],
        [ 5, 40, 14]])
torch.Size([4, 3])

Both stacked tensors are two-dimensional. Because the input tensors are one-dimensional, the new axis can go at dimension 0 or dimension 1.

With dim=0, each input tensor becomes a row, giving us a 3×4 matrix. With dim=1, each input tensor becomes a column, giving us a 4×3 matrix, which is the transpose of the dim=0 result.
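One way to check the row and column description is to compare the two results directly. This sketch recreates the tensors from the example and uses torch.equal() to confirm that, for one-dimensional inputs, the dim=1 result is the transpose of the dim=0 result:

```python
import torch

x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30, 40])
z = torch.tensor([8, 7, 16, 14])

stacked_0 = torch.stack((x, y, z), dim=0)  # shape (3, 4): inputs are rows
stacked_1 = torch.stack((x, y, z), dim=1)  # shape (4, 3): inputs are columns

# For 1-D inputs, the dim=1 result equals the transpose of the dim=0 result
print(torch.equal(stacked_1, stacked_0.T))  # True
# Column 0 of stacked_1 is the first input tensor
print(torch.equal(stacked_1[:, 0], x))  # True
```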

PyTorch Cat Vs Stack

The two PyTorch functions offer similar functionality but differ in how they join tensors. The cat() function joins tensors along an existing dimension, so the result has the same number of dimensions as the inputs. The stack() function joins tensors along a new dimension not present in the individual tensors, so the result has one more dimension than the inputs.
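The difference is easiest to see in the output shapes. In this sketch (the 2×3 tensors a and b are hypothetical), cat() adds up the sizes along an existing dimension, while stack() creates a new dimension whose size equals the number of inputs:

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 3)

# cat() keeps the number of dimensions: sizes add up along dim 0 (2 + 2 = 4)
cat_result = torch.cat((a, b), dim=0)
print(cat_result.shape)    # torch.Size([4, 3])

# stack() adds a new leading dimension of size 2 (one slot per input)
stack_result = torch.stack((a, b), dim=0)
print(stack_result.shape)  # torch.Size([2, 2, 3])
```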

We can reproduce the result of the stack() function using the cat() function. To do so, we apply the unsqueeze operation to each tensor, which inserts a dimension of size one, and then pass the unsqueezed tensors to cat(). Let’s look at the result with the tensors from the previous example:

import torch

x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30, 40])
z = torch.tensor([8, 7, 16, 14])

xyz = torch.cat((x.unsqueeze(0), y.unsqueeze(0), z.unsqueeze(0)), dim=0)

print(xyz)
print(xyz.shape)

The unsqueeze(0) operation adds a new dimension of length one at position 0, turning each four-element tensor into a 1×4 tensor, and then we concatenate along that first axis. Let’s run the code to get the result:

tensor([[ 2,  3,  4,  5],
        [ 4, 10, 30, 40],
        [ 8,  7, 16, 14]])
torch.Size([3, 4])

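Therefore torch.stack((A, B), dim=0) is equivalent to torch.cat((A.unsqueeze(0), B.unsqueeze(0)), dim=0). More generally, stacking along any dimension d is equivalent to unsqueezing each tensor at d and concatenating along d.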
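To check the general form of this equivalence, the sketch below defines a hypothetical helper, stack_via_cat(), which is not part of PyTorch, and compares it with torch.stack() for every valid dim of the one-dimensional tensors from the example, including the negative ones:

```python
import torch


def stack_via_cat(tensors, dim=0):
    """Hypothetical helper: reproduce torch.stack with unsqueeze + torch.cat."""
    return torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)


x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30, 40])
z = torch.tensor([8, 7, 16, 14])

# Compare both approaches for every valid dim of 1-D inputs
results = {}
for d in (0, 1, -1, -2):
    results[d] = torch.equal(stack_via_cat((x, y, z), dim=d),
                             torch.stack((x, y, z), dim=d))
    print(d, results[d])  # True for each d
```

The helper also handles negative dims because unsqueeze() and stack() both interpret a negative dim relative to the output, which has one more dimension than the inputs.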

Summary

Congratulations on reading to the end of this tutorial. We have gone through how to concatenate tensors using both cat() and stack() and explained the differences between the two functions.

For further reading on PyTorch, go to the article: How to Convert NumPy Array to PyTorch Tensor.

To learn more about Python for data science and machine learning, go to the online courses page on Python for the most comprehensive courses available.

Have fun and happy researching!