What is the difference between torch.mm, torch.matmul and torch.mul?

by | Data Science, Machine Learning, PyTorch

PyTorch provides a variety of tensor operations, and understanding the differences between torch.mm, torch.matmul, and torch.mul is essential when working with tensor computations. Although they might look similar, these functions serve different purposes and operate under distinct rules based on the tensor dimensions.

1. torch.mm – Matrix Multiplication

torch.mm performs matrix multiplication between two 2-dimensional tensors (matrices). It requires both tensors to have exactly two dimensions, with the shape requirement that the number of columns in the first tensor must match the number of rows in the second tensor.

Example with Two 2-Dimensional Tensors

import torch

# Define two matrices
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

# Perform matrix multiplication
result = torch.mm(a, b)
print(result)
# Output:
# tensor([[19, 22],
#         [43, 50]])

In this example, torch.mm(a, b) multiplies matrices a and b, producing a new 2D matrix. If either a or b is not a 2D matrix, torch.mm will raise an error.
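To make the shape rule concrete, here is a minimal sketch that recomputes torch.mm from its definition, result[i, j] = sum over k of a[i, k] * b[k, j], and then checks that inputs with mismatched inner dimensions are rejected. The helper naive_mm is a hypothetical name written only for illustration; it is not part of PyTorch. The triple loop is far slower than torch.mm and exists only to show what is being computed.

```python
import torch

def naive_mm(a, b):
    # Hypothetical helper for illustration: textbook triple loop, not a PyTorch API.
    n, m = a.shape
    m2, p = b.shape
    if m != m2:
        raise ValueError(f"inner dimensions differ: {m} vs {m2}")
    out = torch.zeros(n, p, dtype=a.dtype)
    for i in range(n):
        for j in range(p):
            for k in range(m):
                # Accumulate the dot product of row i of a and column j of b
                out[i, j] += a[i, k] * b[k, j]
    return out

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
print(torch.equal(naive_mm(a, b), torch.mm(a, b)))  # True

# (2 x 3) times (2 x 3): the inner dimensions 3 and 2 do not match
x = torch.ones(2, 3)
error_message = None
try:
    torch.mm(x, x)
except RuntimeError as e:
    error_message = str(e)
    print("Error:", error_message)
```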

Example with 1-Dimensional Tensor

import torch

# Define a 2D matrix and a 1D tensor
a = torch.tensor([[1, 2], [3, 4]])  # 2D matrix
b = torch.tensor([5, 6])            # 1D tensor

# Attempt to perform matrix multiplication
try:
    result = torch.mm(a, b)
except RuntimeError as e:
    print("Error:", e)
# Expected output (the exact wording may vary between PyTorch versions):
# Error: mat2 must be a matrix

In this example, calling torch.mm(a, b) with a as a 2D tensor and b as a 1D tensor raises a RuntimeError (mat2 must be a matrix). This is because torch.mm accepts only 2D tensors and does not broadcast its inputs.
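If what you actually want is a matrix-vector product, one possible workaround (a sketch, not the only option) is to reshape the 1D tensor into a (2, 1) column with unsqueeze, call torch.mm, and squeeze the extra dimension away afterwards. PyTorch also provides torch.mv for matrix-vector products directly. Float tensors are used here for simplicity; the values are the same as in the example above.

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])  # 2D matrix, shape (2, 2)
b = torch.tensor([5., 6.])              # 1D tensor, shape (2,)

# Turn b into a (2, 1) column matrix so torch.mm accepts it,
# then drop the trailing dimension of the (2, 1) result.
via_mm = torch.mm(a, b.unsqueeze(1)).squeeze(1)

# torch.mv multiplies a matrix by a 1D vector directly.
via_mv = torch.mv(a, b)

print(via_mm)  # tensor([17., 39.])
print(via_mv)  # tensor([17., 39.])
```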

2. torch.matmul – General Matrix Multiplication

torch.matmul is more flexible than torch.mm: it accepts 1D, 2D, and higher-dimensional tensors, and its behavior depends on the number of dimensions of each input. Any extra leading (batch) dimensions are broadcast following NumPy-style rules, so torch.matmul can handle vector-matrix, matrix-vector, matrix-matrix, and batched matrix-matrix multiplication.

Note: For readers familiar with NumPy, PyTorch also allows the @ operator as an equivalent to torch.matmul. For example, a @ b can be used in place of torch.matmul(a, b), offering a more concise syntax that works the same way as in NumPy.

Example – Matrix Multiplication:

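# Redefine the 2x2 matrices (b was reassigned to a 1D tensor in the previous example)
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

# 2D matrix multiplication (same as torch.mm)
result = torch.matmul(a, b)
print(result)

# Using the @ operator
result_alt = a @ b
print(result_alt)
# Output:
# tensor([[19, 22],
#         [43, 50]])
# tensor([[19, 22],
#         [43, 50]])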

Example – Vector-Matrix Multiplication:

# Define a vector and a matrix
v = torch.tensor([1, 2])
result = torch.matmul(v, a)
print(result)
# Output:
# tensor([ 7, 10])

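In this example, torch.matmul(v, a) performs vector-matrix multiplication. Since v is 1D and a is 2D, torch.matmul treats v as a row vector: it temporarily prepends a dimension of size 1, giving shape (1, 2), performs the matrix multiplication, and then removes the added dimension, so the result is a 1D tensor of shape (2,).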
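The same idea applies when the vector is the second argument and when both inputs are vectors. Following the rules described in the torch.matmul documentation (a 1D first argument gets a leading dimension of size 1, a 1D second argument gets a trailing dimension of size 1, and two 1D inputs give a dot product), this sketch prints the result and shape for each case. Float tensors are used here for simplicity.

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
v = torch.tensor([1., 2.])

# 1D @ 2D: v acts as a (1, 2) row; the added dimension is removed afterwards.
row_result = torch.matmul(v, a)   # values [7., 10.], shape (2,)

# 2D @ 1D: v acts as a (2, 1) column; the added dimension is removed afterwards.
col_result = torch.matmul(a, v)   # values [5., 11.], shape (2,)

# 1D @ 1D: dot product, returned as a 0-dimensional tensor.
dot_result = torch.matmul(v, v)   # value 5., shape ()

for name, t in [("v @ a", row_result), ("a @ v", col_result), ("v @ v", dot_result)]:
    print(name, t, tuple(t.shape))
```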

Example – Batched Matrix Multiplication:

# Define batched matrices
a_batch = torch.randn(2, 3, 4)
b_batch = torch.randn(2, 4, 5)
result = torch.matmul(a_batch, b_batch)
print(result.shape)
# Output:
# torch.Size([2, 3, 5])

With 3D tensors, torch.matmul performs batched matrix multiplication. Here, the first dimension is the batch dimension: each of the two (3, 4) matrices in a_batch is multiplied by the corresponding (4, 5) matrix in b_batch, giving two (3, 5) results.
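One way to check that each batch entry really is an independent 2D product is to compare the batched result against a loop of torch.mm calls over the slices. The sketch also shows batch-dimension broadcasting: a single 2D matrix, or a batch dimension of size 1, is reused across the other operand's batch. The seed and the extra shapes are arbitrary choices for illustration, and the comparison uses a small absolute tolerance because different kernels may round floating-point sums slightly differently.

```python
import torch

torch.manual_seed(0)  # fixed seed so the random inputs are reproducible
a_batch = torch.randn(2, 3, 4)
b_batch = torch.randn(2, 4, 5)
result = torch.matmul(a_batch, b_batch)

# Each batch entry equals an ordinary 2D product of the matching slices.
per_slice = torch.stack([torch.mm(a_batch[i], b_batch[i]) for i in range(2)])
print(torch.allclose(result, per_slice, atol=1e-6))  # True

# A single (4, 5) matrix is broadcast across both batch entries.
shared = torch.randn(4, 5)
print(torch.matmul(a_batch, shared).shape)  # torch.Size([2, 3, 5])

# A batch dimension of size 1 broadcasts against a batch of size 6.
x = torch.randn(1, 3, 4)
y = torch.randn(6, 4, 5)
print(torch.matmul(x, y).shape)  # torch.Size([6, 3, 5])
```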

3. torch.mul – Element-wise Multiplication

torch.mul performs element-wise (Hadamard) multiplication, which means each element in the first tensor is multiplied by the corresponding element in the second tensor. Unlike torch.mm and torch.matmul, torch.mul does not follow matrix multiplication rules; it only requires that the tensors have compatible shapes for broadcasting.

Example:

# Define two tensors
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

# Perform element-wise multiplication
result = torch.mul(a, b)
print(result)
# Output:
# tensor([[ 5, 12],
#         [21, 32]])

In this example, each element of a is multiplied by the corresponding element in b. If the tensors have different shapes, PyTorch broadcasts them when the shapes are compatible and raises a RuntimeError otherwise.
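A few more torch.mul cases, sketched with the same 2x2 tensor a: the * operator as shorthand, a scalar factor, a 1D tensor that broadcasts across rows, and a shape that cannot be broadcast. The scale values are arbitrary and chosen only so the broadcast is easy to read.

```python
import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

# The * operator performs the same element-wise multiplication as torch.mul.
print(torch.equal(a * b, torch.mul(a, b)))  # True

# Multiplying by a Python scalar scales every element.
print(torch.mul(a, 10))  # values [[10, 20], [30, 40]]

# A (2,) tensor broadcasts across the rows of a (2, 2) tensor:
# column 0 is multiplied by 10 and column 1 by 100.
scale = torch.tensor([10, 100])
print(torch.mul(a, scale))  # values [[10, 200], [30, 400]]

# Shapes (2, 2) and (3,) cannot be broadcast together.
broadcast_error = None
try:
    torch.mul(a, torch.tensor([1, 2, 3]))
except RuntimeError as e:
    broadcast_error = str(e)
    print("Error:", broadcast_error)
```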

Summary

Here is a quick comparison of these three functions:

  • torch.mm: Only for 2D matrices; follows strict matrix multiplication rules and does not broadcast.
  • torch.matmul: Handles 1D, 2D, and higher-dimensional tensors and broadcasts batch dimensions, covering vector-matrix, matrix-vector, matrix-matrix, and batched products. The @ operator is equivalent.
  • torch.mul: Element-wise multiplication that does not follow matrix multiplication rules; supports broadcasting for compatible shapes.
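As a final side-by-side check, this sketch applies all three functions to the same pair of 2x2 tensors and then passes a 3D batch to torch.matmul and torch.mm. The stacked batch is an illustrative assumption, not one of the earlier examples.

```python
import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

mm_result = torch.mm(a, b)          # matrix product
matmul_result = torch.matmul(a, b)  # same as torch.mm for two 2D inputs
mul_result = torch.mul(a, b)        # element-wise product

print(mm_result.tolist())      # [[19, 22], [43, 50]]
print(matmul_result.tolist())  # [[19, 22], [43, 50]]
print(mul_result.tolist())     # [[5, 12], [21, 32]]

# Only torch.matmul accepts a 3D batch; torch.mm requires 2D inputs.
batch = torch.stack([a, a])           # shape (2, 2, 2)
print(torch.matmul(batch, b).shape)   # torch.Size([2, 2, 2])
mm_error = None
try:
    torch.mm(batch, b)
except RuntimeError as e:
    mm_error = str(e)
    print("torch.mm error:", mm_error)
```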

Knowing the differences between these functions helps in selecting the appropriate function based on the specific tensor shapes and the operation you want to perform.

For further reading on PyTorch, go to the Deep Learning Frameworks page.

Have fun and happy researching!

Senior Advisor, Data Science

Suf is a senior advisor in data science with deep expertise in Natural Language Processing, Complex Networks, and Anomaly Detection. Formerly a postdoctoral research fellow, he applied advanced physics techniques to tackle real-world, data-heavy industry challenges. Before that, he was a particle physicist at the ATLAS Experiment of the Large Hadron Collider. Now, he’s focused on bringing more fun and curiosity to the world of science and research online.
