PyTorch: Part 4 - The Training Loop

Putting It All Together

In the previous parts, we learned about tensors, autograd, and neural networks. Now it’s time to train our network!

Training means adjusting the network’s weights to reduce prediction errors. The process has 4 steps that repeat:

  1. Forward pass: Make predictions
  2. Calculate loss: Measure how wrong we are
  3. Backward pass: Calculate gradients
  4. Update weights: Adjust weights to reduce loss

Let’s go through each step.


Step 1: Forward Pass

The forward pass is simple. We pass data through our network to get predictions.

import torch
import torch.nn as nn

# Create a simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()

# Forward pass
X = torch.randn(5, 10)  # 5 samples, 10 features
predictions = model(X)
print(predictions.shape)  # [5, 2]

Step 2: Calculate Loss

Loss measures how wrong our predictions are. Lower loss = better predictions.

Common Loss Functions

CrossEntropyLoss (for classification):

criterion = nn.CrossEntropyLoss()

# predictions: shape [batch_size, num_classes]
# targets: shape [batch_size] (class indices)
predictions = torch.randn(5, 2)  # 5 samples, 2 classes
targets = torch.tensor([0, 1, 0, 1, 0])  # Correct classes

loss = criterion(predictions, targets)
print(f"Loss: {loss.item()}")

MSELoss (for regression):

criterion = nn.MSELoss()

predictions = torch.randn(5, 1)  # 5 predictions
targets = torch.randn(5, 1)      # 5 target values

loss = criterion(predictions, targets)
print(f"Loss: {loss.item()}")

Which loss to use?

  • Classification (choosing a category): Use CrossEntropyLoss
  • Regression (predicting a number): Use MSELoss
  • Binary classification (yes/no): Use BCEWithLogitsLoss (or BCELoss after a sigmoid)
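
The classification and regression losses each get an example above, but the binary case doesn't. Here's a minimal sketch using BCEWithLogitsLoss, which folds the sigmoid into the loss for numerical stability, so the model can output raw logits:

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# predictions: raw logits, shape [batch_size, 1]
# targets: labels as floats (0.0 or 1.0), same shape
logits = torch.randn(5, 1)
targets = torch.tensor([[1.], [0.], [1.], [0.], [1.]])

loss = criterion(logits, targets)
print(f"Loss: {loss.item()}")
```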

Step 3: Backward Pass

The backward pass calculates gradients. This tells us how to adjust each weight.

# Calculate gradients
loss.backward()

# Now each parameter has a .grad attribute
for name, param in model.named_parameters():
    print(f"{name} grad shape: {param.grad.shape}")

Remember to reset gradients! Gradients accumulate by default. Always call optimizer.zero_grad() before backward().
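
A tiny illustration of that accumulation, using a single scalar weight instead of a full model:

```python
import torch

# One weight with a simple expression: gradients from repeated
# backward() calls add up unless we reset them.
w = torch.tensor(1.0, requires_grad=True)

(w * 2).backward()
first = w.grad.item()   # 2.0

(w * 2).backward()
second = w.grad.item()  # 4.0: accumulated, not replaced!

w.grad.zero_()          # what optimizer.zero_grad() does for every parameter
```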


Step 4: Update Weights

The optimizer updates weights using the gradients. The most common optimizer is Adam.

import torch.optim as optim

# Create optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Update weights
optimizer.step()

Common Optimizers

  • optim.SGD: Classic choice; add momentum for better results
  • optim.Adam: Great default choice; works well in most cases
  • optim.AdamW: Adam with decoupled weight decay; often performs best
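
Each optimizer above is one constructor call; the learning rates, momentum, and weight decay values below are just illustrative starting points:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)  # any model works; a single layer keeps it short

# SGD with momentum: the classic choice
sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Adam: a solid default
adam = optim.Adam(model.parameters(), lr=0.001)

# AdamW: decoupled weight decay, often the best performer
adamw = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
```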

The Complete Training Loop

Now let’s put it all together:

import torch
import torch.nn as nn
import torch.optim as optim

# 1. CREATE MODEL
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()

# 2. DEFINE LOSS AND OPTIMIZER
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 3. CREATE FAKE DATA
X = torch.randn(100, 10)          # 100 samples, 10 features
y = torch.randint(0, 2, (100,))   # 100 labels (0 or 1)

# 4. TRAINING LOOP
num_epochs = 5
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X)
    loss = criterion(outputs, y)

    # Backward pass
    optimizer.zero_grad()   # Reset gradients
    loss.backward()         # Calculate gradients
    optimizer.step()        # Update weights

    # Print progress
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print("\nTraining complete!")

Output:

Epoch [1/5], Loss: 0.7234
Epoch [2/5], Loss: 0.6891
Epoch [3/5], Loss: 0.6612
Epoch [4/5], Loss: 0.6389
Epoch [5/5], Loss: 0.6201

Training complete!

The loss is going down! That means our model is learning.


Training with Batches

For large datasets, we don’t pass all data at once. We use batches.

from torch.utils.data import TensorDataset, DataLoader

# Create dataset
X = torch.randn(1000, 10)
y = torch.randint(0, 2, (1000,))
dataset = TensorDataset(X, y)

# Create dataloader
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training with batches
model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(3):
    total_loss = 0
    for batch_X, batch_y in dataloader:
        # Forward
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')

Output:

Epoch 1, Avg Loss: 0.6543
Epoch 2, Avg Loss: 0.5234
Epoch 3, Avg Loss: 0.4567

DataLoader Parameters

  • batch_size: Number of samples per batch
  • shuffle=True: Randomize order each epoch (prevents memorization)
  • num_workers: Number of parallel data-loading workers (useful for large datasets)
  • drop_last=True: Drop the incomplete last batch
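
A quick sketch of drop_last in action: with 100 samples and a batch size of 32, the last batch would only have 4 samples, so drop_last=True discards it and every batch the model sees is full-sized.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
dataset = TensorDataset(X, y)

# 100 = 3 * 32 + 4, so the incomplete 4-sample batch is dropped
loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

print(len(loader))  # 3 batches instead of 4
for batch_X, batch_y in loader:
    print(batch_X.shape)  # torch.Size([32, 10]) every time
```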

Making Predictions

After training, we can make predictions:

# Switch to evaluation mode
model.eval()

# Disable gradient tracking
with torch.no_grad():
    # New input
    x_new = torch.randn(1, 10)

    # Get prediction
    output = model(x_new)
    probabilities = torch.softmax(output, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1)

    print(f"Raw output: {output}")
    print(f"Probabilities: {probabilities}")
    print(f"Predicted class: {predicted_class.item()}")

Output:

Raw output: tensor([[-0.3421,  0.5678]])
Probabilities: tensor([[0.2987, 0.7013]])
Predicted class: 1

Always use model.eval() and torch.no_grad() for predictions! model.eval() switches layers like dropout and batch norm to inference behavior, while torch.no_grad() skips gradient tracking to save memory and computation.


Saving and Loading Models

Save the Model

# Save model parameters
torch.save(model.state_dict(), 'model.pth')
print("Model saved!")

Load the Model

# Create a new model instance
model = SimpleNet()

# Load saved parameters
model.load_state_dict(torch.load('model.pth', weights_only=True))
model.eval()
print("Model loaded!")
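
If you plan to resume training later rather than just run inference, you can also save the optimizer state alongside the weights. A sketch (the filename and dictionary keys here are just illustrative choices):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Save a checkpoint: model weights plus optimizer state in one file
torch.save({
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "epoch": 5,
}, "checkpoint.pth")

# Later: restore both and continue training where you left off
checkpoint = torch.load("checkpoint.pth", weights_only=True)
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
start_epoch = checkpoint["epoch"]
```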

Common Mistakes to Avoid

  1. Forgetting optimizer.zero_grad(): Gradients accumulate and mess up training

  2. Wrong mode: Use model.train() for training, model.eval() for inference

  3. Forgetting torch.no_grad(): Wastes memory during inference

  4. Data on wrong device: Make sure model and data are on the same device (CPU or GPU)

  5. Not shuffling data: Always shuffle training data to prevent memorization
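
Mistake 4 is worth a sketch: the standard pattern is to pick a device once, then move both the model and every batch of data to it before the forward pass.

```python
import torch
import torch.nn as nn

# Pick the GPU if available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 2).to(device)  # move model parameters to the device

X = torch.randn(5, 10).to(device)    # move data to the SAME device
outputs = model(X)                   # works: everything lives together

print(outputs.device)
```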


Quick Reference

# The training loop pattern
model.train()
for epoch in range(num_epochs):
    for batch_X, batch_y in dataloader:
        # Forward
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# The inference pattern
model.eval()
with torch.no_grad():
    predictions = model(x_new)

Key Takeaways

  • Training has 4 steps: forward, loss, backward, update
  • Always call optimizer.zero_grad() before loss.backward()
  • Use DataLoader to handle batching and shuffling
  • Use model.train() for training, model.eval() for inference
  • Use torch.no_grad() during inference to save memory
  • Save models with torch.save(model.state_dict(), 'model.pth')

What’s Next?

Congratulations! You now understand the fundamentals of PyTorch:

  1. Tensors: The data structure
  2. Autograd: Automatic gradients
  3. Neural Networks: Building models with nn.Module
  4. Training Loop: Putting it all together

From here, you can explore:

  • Convolutional Neural Networks (CNNs) for images
  • Recurrent Neural Networks (RNNs) for sequences
  • Transfer learning with pretrained models
  • More advanced architectures

Keep building and experimenting. The best way to learn is by doing!


