PyTorch Lightning

Gregor Lenz

26.7.2022

PyTorch Lightning

  • Builds on the 'low-level' PyTorch API rather than replacing it
     
  • Standardises the train / validation / test setup of neural networks
     
  • Provides many helpful built-in features (logging, checkpointing, multi-device training)
     
  • Reduces boilerplate code

Who might use it

  • Anyone who has previously trained neural networks with pure PyTorch
import torch
from torch import nn

device = 'cuda:0'
# the model's parameters must live on the same device as the data
model = nn.Sequential(nn.Linear(...), nn.Linear(...)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


for epoch in range(10):
  model.train()
  for data, y_true in training_dataloader:
    # move every mini-batch to the model's device by hand
    data = data.to(device)
    y_true = y_true.to(device)
    optimizer.zero_grad()
    
    y_hat = model(data)
    loss = criterion(y_hat, y_true)
    loss.backward()
    optimizer.step()
    
    # log the loss
    ...
  model.eval()
  # log the epoch loss, evaluate
  ...

Problem: many ways to do the same thing

  • Logging hyperparameters and scalar metrics
     
  • Saving models
     
  • Moving data between devices
     
  • Training on multiple devices
     
  • Repetitive code for training, validation, testing
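To make the repetition concrete, here is a minimal pure-PyTorch evaluation helper of the kind that tends to be rewritten in every project. It is only a sketch: the function name `evaluate`, the accuracy metric and the toy data in the usage lines are hypothetical, chosen for illustration.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model: nn.Module, dataloader: DataLoader, device: str = "cpu"):
  """Hypothetical helper: average loss and accuracy over a dataloader."""
  criterion = nn.CrossEntropyLoss(reduction="sum")
  model.eval()  # disable dropout, use running batch-norm statistics
  total_loss, correct, count = 0.0, 0, 0
  with torch.no_grad():  # gradients are not needed for evaluation
    for data, y_true in dataloader:
      # the same device moves as in the training loop, repeated by hand
      data, y_true = data.to(device), y_true.to(device)
      y_hat = model(data)
      total_loss += criterion(y_hat, y_true).item()
      correct += (y_hat.argmax(dim=1) == y_true).sum().item()
      count += y_true.numel()
  return total_loss / count, correct / count


# Usage with random toy data (hypothetical: 20 features, 3 classes)
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(64, 20), torch.randint(0, 3, (64,)))
model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
val_loss, val_acc = evaluate(model, DataLoader(dataset, batch_size=16))
print(f"val_loss={val_loss:.3f} val_acc={val_acc:.2f}")
```

A test loop would look almost identical, which is exactly the kind of duplication Lightning aims to remove.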

Lightning makes a few assumptions

  • There is at least a training loop over the data, optionally also validation and test loops
     
  • If we specify a device, all computations run on that device
     
  • Statistics are accumulated per epoch
     
  • For every iteration:
    • gradients are set to zero
    • loss is backpropagated
    • optimizer takes a step
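Taken together, these assumptions are what allow a framework to own the loop. The toy `fit` function below is a hypothetical illustration of that idea in plain PyTorch, not Lightning's actual implementation: the user only supplies a `training_step` that turns a batch into a loss, and the loop zeroes the gradients, backpropagates and steps the optimizer on every iteration, keeps everything on one device and averages the loss per epoch.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def fit(model, training_step, optimizer, dataloader, epochs=1, device="cpu"):
  """Hypothetical mini-trainer; returns the mean training loss per epoch."""
  model.to(device)  # one device for all computations
  epoch_losses = []
  for _ in range(epochs):
    model.train()
    running, n_batches = 0.0, 0
    for batch in dataloader:  # cycle through the training data
      batch = [t.to(device) for t in batch]
      optimizer.zero_grad()  # every iteration: zero grads,
      loss = training_step(model, batch)
      loss.backward()  # backpropagate the loss,
      optimizer.step()  # and take an optimizer step
      running += loss.item()
      n_batches += 1
    epoch_losses.append(running / n_batches)  # statistics per epoch
  return epoch_losses


def training_step(model, batch):
  # user code only describes how a mini-batch becomes a loss
  data, y_true = batch
  return nn.functional.cross_entropy(model(data), y_true)


torch.manual_seed(0)
dataset = TensorDataset(torch.randn(128, 20), torch.randint(0, 3, (128,)))
model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
losses = fit(model, training_step, optimizer,
             DataLoader(dataset, batch_size=32), epochs=5)
print(losses)
```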
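import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl


class MyLightningModel(pl.LightningModule):
  def __init__(self):
    super().__init__()
    # define model here (hypothetical sizes, for illustration only)
    self.model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
  
  def forward(self, data: torch.Tensor):
    # specify how data is fed to the model
    return self.model(data)
  
  def training_step(self, batch, batch_idx):
    # unpack a mini-batch, feed data, calculate losses
    data, y_true = batch
    loss = F.cross_entropy(self(data), y_true)
    self.log("train_loss", loss)
    return loss  # Lightning runs zero_grad / backward / step
  
  def validation_step(self, batch, batch_idx):
    # unpack a mini-batch, feed data, calculate losses,
    # calculate prediction accuracy
    data, y_true = batch
    y_hat = self(data)
    self.log("val_loss", F.cross_entropy(y_hat, y_true))
    self.log("val_acc", (y_hat.argmax(dim=1) == y_true).float().mean())
  
  def configure_optimizers(self):
    # specify optimizer here
    return torch.optim.Adam(self.parameters(), lr=1e-3)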


Live coding
