PyTorch Lightning
Gregor Lenz
26.7.2022
PyTorch Lightning
- Extends the 'low-level' PyTorch functionality
- Standardises neural network train / test setup
- Provides many helpful built-in functions
- Reduces boilerplate code
Who might use it
- Anyone who has previously trained neural networks with pure PyTorch
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(...), nn.Linear(...))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
device = 'cuda:0'
model = model.to(device)
for epoch in range(10):
    model.train()
    for data, y_true in training_dataloader:
        data = data.to(device)
        y_true = y_true.to(device)
        optimizer.zero_grad()
        y_hat = model(data)
        loss = criterion(y_hat, y_true)
        loss.backward()
        optimizer.step()
        # log the loss
        ...
    model.eval()
    # log the epoch loss, evaluate
    ...
Problem: many ways to do the same thing
- Logging parameters, scalars
- Saving models
- Moving data between devices
- Training on multiple devices
- Repetitive code for training, validation, testing (see the sketch below)
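A rough sketch of how Lightning standardises these points; the argument values are placeholders and the exact flags may differ between Lightning versions:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# logging a scalar from inside a LightningModule is a single call:
#     self.log("train_loss", loss)
# model saving, device transfers and multi-device training become Trainer arguments:
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",   # Lightning moves the model and batches to the device
    devices=2,           # the same code runs on one or several GPUs
    callbacks=[ModelCheckpoint(monitor="val_loss")],  # keeps the best checkpoint automatically
)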
Lightning makes a few assumptions
- There is at least a training dataset that we cycle through in a loop, potentially also validation and test data
- If we specify a device, all computations run on that device
- Statistics are accumulated per epoch
- For every iteration (see the sketch below):
  - gradients are set to zero
  - loss is backpropagated
  - optimizer takes a step
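Concretely, the Trainer runs a loop roughly like the sketch below on your behalf; this is a simplified illustration rather than the actual implementation, and move_to_device is a hypothetical helper standing in for Lightning's automatic transfers:
for epoch in range(max_epochs):
    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        batch = move_to_device(batch, device)         # hypothetical helper; Lightning handles this transfer
        optimizer.zero_grad()                         # gradients are set to zero
        loss = model.training_step(batch, batch_idx)  # your code: unpack the batch, forward pass, calculate the loss
        loss.backward()                               # loss is backpropagated
        optimizer.step()                              # optimizer takes a step
    # validation_step runs on the validation data and statistics are aggregated per epoch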
import torch
import pytorch_lightning as pl

class MyLightningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # define model here...

    def forward(self, data: torch.Tensor):
        # specify how data is fed to the model
        ...

    def training_step(self, batch, batch_idx):
        # unpack a mini-batch, feed data, calculate losses
        ...

    def validation_step(self, batch, batch_idx):
        # unpack a mini-batch, feed data, calculate losses,
        # calculate prediction accuracy
        ...

    def configure_optimizers(self):
        # specify optimizer here
        ...
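A minimal usage sketch to go with the skeleton above; train_dataloader and val_dataloader stand in for ordinary torch.utils.data.DataLoader objects, and the Trainer arguments are only examples:
model = MyLightningModel()
trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)
trainer.fit(model, train_dataloader, val_dataloader)  # the Trainer runs the training and validation loops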
Live coding