Why is zero_grad() Called in PyTorch?


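When training neural networks in PyTorch, you need to call zero_grad() before backpropagating the loss. By default, PyTorch accumulates gradients: each call to loss.backward() adds the new gradients to whatever is already stored in each parameter's .grad attribute instead of overwriting it. Calling zero_grad() at the start of every training iteration clears those stored gradients, so each optimizer step uses only the gradients computed from the current batch.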
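PyTorch does not overwrite a parameter's .grad on each backward pass; it adds to it. Here is a minimal sketch of our own (not part of the article's training code) using a single parameter w and the toy loss \( L(w) = (w - 3)^2 \). Calling backward() twice without clearing doubles the stored gradient, and zero_grad() resets it. We pass set_to_none=False so the gradient stays a zero tensor we can print; with the default set_to_none=True in PyTorch 2.x, .grad becomes None instead.

```python
import torch

# A single trainable parameter and a toy quadratic loss L(w) = (w - 3)^2
w = torch.tensor(1.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# First backward pass: dL/dw = 2 * (w - 3) = -4 at w = 1
loss = (w - 3) ** 2
loss.backward()
first_grad = w.grad.item()

# Second backward pass without zeroing: the new gradient is ADDED to w.grad
loss = (w - 3) ** 2
loss.backward()
accumulated_grad = w.grad.item()

# Clear the stored gradient; set_to_none=False keeps .grad as a zero tensor
optimizer.zero_grad(set_to_none=False)
cleared_grad = w.grad.item()

print(first_grad, accumulated_grad, cleared_grad)  # -4.0 -8.0 0.0
```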

Gradients in Neural Networks

Gradients are the partial derivatives of the loss function with respect to each parameter. The gradient points in the direction in which the loss increases most steeply, and its magnitude tells us how quickly the loss changes, so stepping in the opposite direction reduces the loss. The gradient of the loss function \( L(w) \) with respect to a parameter \( w \) is expressed as:

\[ \nabla_w L(w) = \frac{\partial L}{\partial w} \]

In deep learning, gradients guide the adjustment of weights and biases to optimize the model by minimizing the loss function.
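As a quick check of this definition, here is a sketch that lets autograd compute \( \partial L / \partial w \) for a hypothetical one-parameter model with prediction \( wx \) and loss \( L(w) = (wx - t)^2 \), then compares it with the derivative worked out by hand, \( 2(wx - t)x \). The values x = 2, t = 1 and w = 1.5 are arbitrary choices for illustration.

```python
import torch

# Hypothetical one-parameter model: prediction = w * x, squared-error loss
x, t = 2.0, 1.0
w = torch.tensor(1.5, requires_grad=True)

loss = (w * x - t) ** 2      # L(w) = (1.5*2 - 1)^2 = 4
loss.backward()              # autograd stores dL/dw in w.grad

# Analytic derivative: dL/dw = 2 * (w*x - t) * x = 2 * 2 * 2 = 8
analytic = 2 * (1.5 * x - t) * x
print(loss.item(), w.grad.item(), analytic)  # 4.0 8.0 8.0
```

The positive gradient says that increasing w would increase the loss, so a gradient descent step will decrease w.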

Backpropagation and Gradient Descent

Backpropagation calculates gradients across network layers by applying the chain rule. For instance, if we have layers \( x \rightarrow h \rightarrow y \), the gradient of the loss with respect to \( x \) can be computed as:

\[ \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial x} \]
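Autograd applies exactly this chain rule. In the sketch below, the layer functions \( h = 3x \), \( y = h^2 \) and the loss \( L = (y - 10)^2 \) are hypothetical choices that keep the hand calculation easy; the product of the three local derivatives matches what PyTorch stores in x.grad.

```python
import torch

# Toy chain x -> h -> y -> L (functions chosen only for easy hand-checking)
x = torch.tensor(2.0, requires_grad=True)
h = 3 * x             # dh/dx = 3
y = h ** 2            # dy/dh = 2h
L = (y - 10) ** 2     # dL/dy = 2(y - 10)
L.backward()

# Manual chain rule at x = 2, where h = 6 and y = 36
dL_dy = 2 * (36.0 - 10.0)   # 52
dy_dh = 2 * 6.0             # 12
dh_dx = 3.0
manual = dL_dy * dy_dh * dh_dx   # 52 * 12 * 3 = 1872

print(x.grad.item(), manual)  # 1872.0 1872.0
```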

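In PyTorch, calling loss.backward() performs these chain-rule computations and stores each parameter's gradient in its .grad attribute. We then use gradient descent to update the weights. The update rule in gradient descent is: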

\[ w := w - \eta \nabla_w L(w) \]

where \( \eta \) is the learning rate, controlling the size of each update step. This ensures that the model learns by adjusting weights in the direction that reduces the loss.
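Here is the update rule written out by hand, as a sketch for the toy loss \( L(w) = (w - 3)^2 \) with \( \eta = 0.1 \) (both are illustrative assumptions). Because backward() adds to w.grad, the loop has to clear the gradient itself after every step, which is the job optimizer.zero_grad() does in a normal training loop.

```python
import torch

# Minimise L(w) = (w - 3)^2 with the update w := w - eta * dL/dw
eta = 0.1                      # learning rate
w = torch.tensor(0.0, requires_grad=True)

for step in range(100):
    loss = (w - 3) ** 2
    loss.backward()            # computes dL/dw and ADDS it to w.grad
    with torch.no_grad():      # the update itself must not be tracked by autograd
        w -= eta * w.grad      # gradient descent step
    w.grad.zero_()             # manual reset, the role optimizer.zero_grad() plays

print(round(w.item(), 4))      # approaches the minimiser w = 3
```

With \( \eta = 0.1 \), each step shrinks the distance to the minimum by a factor of \( 1 - 2\eta = 0.8 \), so 100 steps bring w essentially to 3.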

Training with and without zero_grad()
import torch
import torch.nn as nn
import torch.optim as optim

# Ensuring reproducibility across CPU and GPU
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)


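# Setup model and data
model_with_zero_grad = nn.Linear(2, 1)
model_without_zero_grad = nn.Linear(2, 1)
# Copy the first model's weights so both runs start from the same point
# and zero_grad() is the only difference between them
model_without_zero_grad.load_state_dict(model_with_zero_grad.state_dict())
data, target = torch.tensor([1.0, 2.0]), torch.tensor([1.0])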

# Hyperparameters
epochs = 20
learning_rate = 0.1

# Loss function and optimizers
criterion = nn.MSELoss()
optimizer_with_zero_grad = optim.SGD(model_with_zero_grad.parameters(), lr=learning_rate)
optimizer_without_zero_grad = optim.SGD(model_without_zero_grad.parameters(), lr=learning_rate)

# Lists to store loss values
losses_with_zero_grad = []
losses_without_zero_grad = []

# Training loop with zero_grad
for epoch in range(epochs):
    optimizer_with_zero_grad.zero_grad()
    output = model_with_zero_grad(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer_with_zero_grad.step()
    losses_with_zero_grad.append(loss.item())  # Store loss
    print(f"Epoch {epoch+1} (With zero_grad): Loss = {loss.item()}")

# Training loop without zero_grad
for epoch in range(epochs):
    output = model_without_zero_grad(data)
    loss = criterion(output, target)
    loss.backward()  # No zero_grad, so gradients accumulate
    optimizer_without_zero_grad.step()
    losses_without_zero_grad.append(loss.item())  # Store loss
    print(f"Epoch {epoch+1} (Without zero_grad): Loss = {loss.item()}")

print(f"losses_with_zero_grad: {losses_with_zero_grad}")
print(f"losses_without_zero_grad: {losses_without_zero_grad}")
            

Plotting Losses

Now that we have printed the loss values for each epoch, let’s take a step further and visualize these values. By plotting the loss over time, we’ll be able to see the effect of calling (or not calling) zero_grad() on our model’s training process.

When training a machine learning model, the loss should typically decrease over epochs as the model learns and improves. If zero_grad() is used correctly, we expect to see a smooth and generally decreasing loss curve, indicating steady progress. Without zero_grad(), however, each call to loss.backward() adds the new gradients to those left over from previous iterations, so optimizer.step() applies the sum of every gradient computed so far. The weights keep moving in directions that were only appropriate in earlier iterations, overshoot the minimum, and the loss can oscillate or even increase instead of settling.

By comparing these two plots, you’ll get a clear visual understanding of how essential it is to reset gradients in each iteration of training. Let’s see these effects on a plot to deepen our understanding.
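Here is a self-contained sketch that repeats both training runs from the previous section and draws the two loss curves. It starts both models from the same weights, uses Matplotlib's non-interactive Agg backend, and saves the figure to zero_grad_losses.png (a file name chosen for illustration); in a notebook you can drop the matplotlib.use call and call plt.show() instead. The helper train() is our own hypothetical name.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script also runs headless
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(42)

def train(model, use_zero_grad, epochs=20, lr=0.1):
    """Train on a single sample and return the loss recorded at each epoch."""
    data, target = torch.tensor([1.0, 2.0]), torch.tensor([1.0])
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    losses = []
    for _ in range(epochs):
        if use_zero_grad:
            optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

# Both runs start from identical weights so only zero_grad() differs
model_a = nn.Linear(2, 1)
model_b = nn.Linear(2, 1)
model_b.load_state_dict(model_a.state_dict())

losses_with = train(model_a, use_zero_grad=True)
losses_without = train(model_b, use_zero_grad=False)

epoch_axis = range(1, len(losses_with) + 1)
plt.plot(epoch_axis, losses_with, marker="o", label="With zero_grad()")
plt.plot(epoch_axis, losses_without, marker="x", label="Without zero_grad()")
plt.xlabel("Epoch")
plt.ylabel("MSE loss")
plt.title("Training loss with and without zero_grad()")
plt.legend()
plt.savefig("zero_grad_losses.png")
```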

The plot demonstrates the importance of zero_grad(). With it, the model’s loss decreases steadily, while without it, the loss is erratic due to accumulated gradients.
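One way to understand the erratic curve: if .grad is never cleared, each step applies the running sum of all gradients so far, which is the same update as SGD with momentum=1.0 (and the default dampening=0), that is, momentum with no decay at all. The sketch below is our own illustration rather than part of the original article; it checks this equivalence numerically on the same single-sample problem.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
data, target = torch.tensor([1.0, 2.0]), torch.tensor([1.0])
criterion = nn.MSELoss()

model_no_zero = nn.Linear(2, 1)
model_momentum = nn.Linear(2, 1)
model_momentum.load_state_dict(model_no_zero.state_dict())

# (a) Plain SGD that never calls zero_grad(): .grad holds the sum of all gradients
opt_a = optim.SGD(model_no_zero.parameters(), lr=0.1)
# (b) SGD with momentum=1.0, zeroing every step: the momentum buffer holds the same sum
opt_b = optim.SGD(model_momentum.parameters(), lr=0.1, momentum=1.0)

for _ in range(20):
    criterion(model_no_zero(data), target).backward()
    opt_a.step()

    opt_b.zero_grad()
    criterion(model_momentum(data), target).backward()
    opt_b.step()

same = all(
    torch.allclose(p, q)
    for p, q in zip(model_no_zero.parameters(), model_momentum.parameters())
)
print(same)  # True
```

A momentum coefficient of 1 never forgets old gradients, so the parameters keep overshooting the minimum instead of settling.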

Monitoring Loss

Monitoring the loss during training is crucial to understanding how well a model is learning. In PyTorch, there are several effective ways to keep track of loss values, helping you adjust training settings if necessary. Here are some commonly used methods:

1. Printing Loss Values

The simplest way to monitor loss is to print it out at regular intervals within the training loop. This gives a quick view of how the loss is changing over epochs. For example:

Print Loss in Training Loop
import torch
import torch.nn as nn
import torch.optim as optim

# Dummy data
data = torch.randn(10, 3)
target = torch.randn(10, 1)

# Model, loss function, and optimizer
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop with loss printing
num_epochs = 5
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(data)
    loss = criterion(outputs, target)
    loss.backward()
    optimizer.step()

    # Print loss for each epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
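The example above takes one full-batch step per epoch. With mini-batches, zero_grad() is called once per batch, and a single loss.item() is only a noisy snapshot. A common pattern, sketched here with a hypothetical random dataset of 100 samples, is to average the batch losses over each epoch and print only every few epochs (print_every is our own parameter name).

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical dummy dataset: 100 samples with 3 features each
dataset = TensorDataset(torch.randn(100, 3), torch.randn(100, 1))
loader = DataLoader(dataset, batch_size=10, shuffle=True)

model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

num_epochs = 20
print_every = 5          # print only every 5th epoch to keep the output short
epoch_losses = []

for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, targets in loader:
        optimizer.zero_grad()                        # reset once per mini-batch
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)  # undo the per-batch mean
    epoch_loss = running_loss / len(dataset)          # mean loss over the epoch
    epoch_losses.append(epoch_loss)
    if (epoch + 1) % print_every == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Mean loss: {epoch_loss:.4f}")
```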
    

2. Using TensorBoard

For more advanced monitoring, PyTorch integrates with TensorBoard, allowing you to visualize loss values over time. This is especially helpful for larger models, as it gives a more detailed and interactive view:

Logging to TensorBoard
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.optim as optim

# Dummy data
data = torch.randn(10, 3)
target = torch.randn(10, 1)

# Model, loss function, and optimizer
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# TensorBoard setup
writer = SummaryWriter(log_dir='./runs')

# Training loop with TensorBoard logging
num_epochs = 5
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(data)
    loss = criterion(outputs, target)
    loss.backward()
    optimizer.step()

    # Log loss to TensorBoard
    writer.add_scalar('Training Loss', loss.item(), epoch)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

writer.close()
    

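To view the results in TensorBoard, make sure the tensorboard package is installed (for example with pip install tensorboard), run tensorboard --logdir=runs in your terminal, and open the URL it prints (by default http://localhost:6006).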

3. Saving Loss to List or CSV for Analysis

If you prefer to analyze the loss later or create custom plots, you can save loss values to a list or a CSV file. This approach is also useful for tracking other metrics alongside loss.

Saving Loss to CSV
import csv
import torch
import torch.nn as nn
import torch.optim as optim

# Dummy data
data = torch.randn(10, 3)
target = torch.randn(10, 1)

# Model, loss function, and optimizer
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Initialize CSV file
with open("loss_values.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Epoch", "Loss"])

# Training loop with CSV logging
num_epochs = 5
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(data)
    loss = criterion(outputs, target)
    loss.backward()
    optimizer.step()

    # Save loss to CSV
    with open("loss_values.csv", "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([epoch+1, loss.item()])

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
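If you prefer the list approach, you can collect the losses in memory, write them in one go after training, and read the file back with csv.DictReader when you want to analyse it. This sketch reuses the loss_values.csv file name from the example above; the "best epoch" lookup is just one illustrative analysis.

```python
import csv
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
data, target = torch.randn(10, 3), torch.randn(10, 1)
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 1. Collect losses in a list during training
losses = []
for epoch in range(5):
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# 2. Write the whole list to CSV once training has finished
with open("loss_values.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Epoch", "Loss"])
    writer.writerows((epoch + 1, value) for epoch, value in enumerate(losses))

# 3. Read the file back later for analysis
with open("loss_values.csv", newline="") as f:
    rows = list(csv.DictReader(f))

loaded = [float(row["Loss"]) for row in rows]
best = min(rows, key=lambda row: float(row["Loss"]))
print(f"Best epoch: {best['Epoch']} (loss {float(best['Loss']):.4f})")
```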
    

4. Real-time Plotting with Matplotlib

To see loss values update in real time, you can use Matplotlib's interactive mode. This works best in an environment with an interactive Matplotlib backend, such as a desktop plotting window or an interactive notebook backend:

Real-time Loss Plotting with Matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# Dummy data
data = torch.randn(10, 3)
target = torch.randn(10, 1)

# Model, loss function, and optimizer
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Real-time plotting setup
loss_values = []
plt.ion()  # Enable interactive mode

num_epochs = 30
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(data)
    loss = criterion(outputs, target)
    loss.backward()
    optimizer.step()

    # Update loss values and plot
    loss_values.append(loss.item())
    plt.clf()  # Clear figure
    plt.plot(loss_values, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.pause(0.1)  # Pause to update plot

plt.ioff()  # Disable interactive mode
plt.show()
    

Each of these methods provides different ways to monitor and visualize loss values, helping you to better understand your model’s learning process and make any necessary adjustments. Using zero_grad() correctly in each of these examples ensures that your loss values reflect accurate and stable training behavior.

Key Takeaways:

  • Call zero_grad() before each loss.backward(), unless you are deliberately accumulating gradients
  • Use optimizer.zero_grad() rather than zeroing gradients manually
  • Consider gradient accumulation for advanced use cases
  • Monitor your loss values to ensure proper gradient handling
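Gradient accumulation is the one case where skipping zero_grad() is deliberate, so here is a short sketch. Four micro-batches of 2 samples (hypothetical random data) are backpropagated before a single optimizer step. Dividing each micro-batch loss by accumulation_steps makes the accumulated gradient equal to the full-batch gradient, because nn.MSELoss averages over each batch. In a real loop you would call step() and zero_grad() every accumulation_steps batches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
data, target = torch.randn(8, 3), torch.randn(8, 1)
criterion = nn.MSELoss()          # mean over the elements in each batch

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Reference: gradient of the loss on the full batch of 8 samples
optimizer.zero_grad()
criterion(model(data), target).backward()
full_batch_grad = model.weight.grad.clone()

# Gradient accumulation: 4 micro-batches of 2 samples, then one optimizer step
accumulation_steps = 4
optimizer.zero_grad()                         # start from clean gradients
for inputs, targets in zip(data.chunk(accumulation_steps), target.chunk(accumulation_steps)):
    loss = criterion(model(inputs), targets) / accumulation_steps  # scale so the sum is a mean
    loss.backward()                           # deliberately accumulate, no zero_grad() here
accumulated_grad = model.weight.grad.clone()
optimizer.step()                              # one update for all 4 micro-batches
optimizer.zero_grad()                         # reset before the next accumulation cycle

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))  # True
```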

Summary

In summary, zero_grad() is necessary to reset gradients in PyTorch. Without it, gradients accumulate across iterations, leading to incorrect updates and unstable training. By calling zero_grad() before each backpropagation, we keep training on track, ensuring that each update is based solely on the current data.

Next time you build a training loop, remember that zero_grad() is crucial for preventing accumulated gradients and for keeping your model’s training consistent and effective.

For further reading on PyTorch, go to the Deep Learning Frameworks page.

Have fun and happy researching!

Senior Advisor, Data Science

Suf is a senior advisor in data science with deep expertise in Natural Language Processing, Complex Networks, and Anomaly Detection. Formerly a postdoctoral research fellow, he applied advanced physics techniques to tackle real-world, data-heavy industry challenges. Before that, he was a particle physicist at the ATLAS Experiment of the Large Hadron Collider. Now, he’s focused on bringing more fun and curiosity to the world of science and research online.
