A Training Pipeline with PyTorch Lightning and Hydra

Andrey Lukyanenko

DS TechLead, MTS AI

Reasons for writing a pipeline

  • Writing everything from scratch takes time and is error-prone
  • You end up writing the same pieces of code repeatedly anyway
  • Better understanding of how things work
  • Standardization among the team
    def configure_optimizers(self):
        """Build the optimizer and LR scheduler described in ``self.cfg``.

        The optimizer class is resolved with ``load_obj`` from
        ``cfg.optimizer.class_name`` and instantiated on
        ``self.model.parameters()`` with ``cfg.optimizer.params``. The scheduler
        class is resolved from ``cfg.scheduler.class_name`` and wraps that
        optimizer with ``cfg.scheduler.params``.

        Returns:
            A 2-tuple ``([optimizer], [scheduler_config])``. The scheduler config
            dict holds the scheduler itself, its ``'interval'`` (taken from
            ``cfg.scheduler.step``) and its ``'monitor'`` (taken from
            ``cfg.scheduler.monitor``).
        """
        opt_cfg = self.cfg.optimizer
        optimizer_cls = load_obj(opt_cfg.class_name)
        optimizer = optimizer_cls(self.model.parameters(), **opt_cfg.params)

        sched_cfg = self.cfg.scheduler
        scheduler_cls = load_obj(sched_cfg.class_name)
        scheduler = scheduler_cls(optimizer, **sched_cfg.params)

        scheduler_config = {
            'scheduler': scheduler,
            'interval': sched_cfg.step,
            'monitor': sched_cfg.monitor,
        }
        return [optimizer], [scheduler_config]
$ python train.py
$ python train.py optimizer=sgd
$ python train.py model=efficientnet_model
$ python train.py model.encoder.params.arch=resnet34
$ python train.py datamodule.fold_n=0,1,2 -m
@hydra.main(config_path='conf', config_name='config')
def run_model(cfg: DictConfig) -> None:
    """Hydra entry point: print the composed config and launch training.

    Hydra composes ``cfg`` from the ``config`` entry in the ``conf`` directory
    plus any command-line overrides.

    Args:
        cfg: the composed Hydra/OmegaConf configuration.
    """
    # ``DictConfig.pretty()`` was deprecated in OmegaConf 2.0 and removed in
    # 2.1; ``OmegaConf.to_yaml`` is the supported replacement and produces the
    # same YAML dump.
    from omegaconf import OmegaConf

    # Relative to the current working directory; note Hydra may have switched
    # the cwd to its run directory (version-dependent — confirm).
    os.makedirs('logs', exist_ok=True)
    print(OmegaConf.to_yaml(cfg))
    if cfg.general.log_code:
        save_useful_info()
    run(cfg)


if __name__ == '__main__':
    run_model()
    def training_step(
            self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> Union[int, Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]]:
        """Run one training batch; return the loss and outputs for epoch end.

        Args:
            batch: dict with ``'image'`` and ``'target'`` tensors. If it also
                has a ``'shuffled_target'`` entry (presumably a mixup-style
                batch — confirm against the collate/dataset code), the triple
                ``(target, shuffled_target, batch.get('lam'))`` is passed to
                ``self.loss`` instead of ``target``.
            batch_idx: index of the batch (not used here).

        Returns:
            Dict with the differentiable ``'loss'``, identical ``'log'`` and
            ``'progress_bar'`` dicts, detached ``'logits'`` and ``'target'``
            (consumed by ``training_epoch_end``), and the batch metric under
            ``f'train_{cfg.training.metric}'``.
        """
        image = batch['image']
        logits = self(image)

        target = batch['target']
        shuffled_target = batch.get('shuffled_target')
        lam = batch.get('lam')

        if shuffled_target is not None:
            # Mixed-target batch: the loss is reshaped to a 1-element tensor.
            loss = self.loss(logits, (target, shuffled_target, lam)).view(1)
        else:
            loss = self.loss(logits, target)

        score = self.metric(logits.argmax(1), target)
        # Logged values are display-only, so keep them off the autograd graph.
        logs = {'train_loss': loss.detach(), f'train_{self.cfg.training.metric}': score}
        return {
            'loss': loss,  # must stay attached: it is the value backpropagated
            'log': logs,
            'progress_bar': logs,
            # Detached so the outputs retained until epoch end do not keep each
            # batch's autograd graph (and its activations) alive.
            'logits': logits.detach(),
            'target': target.detach(),
            f'train_{self.cfg.training.metric}': score,
        }

    def training_epoch_end(self, outputs):
        """Aggregate per-batch training outputs into epoch-level logs.

        Args:
            outputs: list of the dicts returned by ``training_step``; each must
                contain ``'loss'``, ``'target'`` and ``'logits'`` tensors.

        Returns:
            ``{'log': logs, 'progress_bar': logs}`` where ``logs`` maps
            ``'train_loss'`` to the mean of the per-batch losses and
            ``f'train_{cfg.training.metric}'`` to the metric recomputed over all
            of the epoch's predictions at once.
        """
        # training_step returns a shape-[1] loss for mixed-target batches
        # (``.view(1)``) but the raw ``self.loss`` result otherwise. torch.stack
        # rejects that shape mismatch, so flatten each loss to 1-D and
        # concatenate. Detach: the average is only logged.
        avg_loss = torch.cat([x['loss'].detach().reshape(-1) for x in outputs]).mean()
        y_true = torch.cat([x['target'] for x in outputs])
        y_pred = torch.cat([x['logits'] for x in outputs])
        score = self.metric(y_pred.argmax(1), y_true)

        logs = {'train_loss': avg_loss, f'train_{self.cfg.training.metric}': score}
        return {'log': logs, 'progress_bar': logs}

Contacts