The Problem With Raw PyTorch
Training a PyTorch model from scratch means writing hundreds of lines of boilerplate: the training loop, validation loop, optimizer step, gradient zeroing, device management, mixed-precision gradient scaling, distributed training setup, checkpoint saving, and metric logging. Every project re-implements this code, and each re-implementation is another chance to introduce subtle bugs.
PyTorch Lightning separates your research code (what your model does) from engineering code (how training runs). Write the model; Lightning handles the rest.
LightningModule: The Core Pattern
from typing import Optional

import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
class TextClassifier(pl.LightningModule):
    """Sequence classifier: token embedding -> Transformer encoder -> linear head.

    The prediction is read from the encoder output at sequence position 0, so
    inputs are expected to carry a summary ("CLS") token in that position.

    NOTE(review): no positional encoding and no padding mask are applied here,
    so the encoder cannot see token order and padded positions take part in
    attention. Add both if word order or variable-length batches matter.

    Args:
        vocab_size: Number of token ids accepted by the embedding table.
        hidden_dim: Model width; must be divisible by ``nhead``.
        num_classes: Number of output logits.
        lr: AdamW learning rate.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        scheduler_t_max: Cosine-annealing period passed to ``CosineAnnealingLR``.
            ``None`` means "use the Trainer's ``max_epochs``" (falling back to
            100 when that is not a positive number), so the learning-rate
            schedule ends together with training instead of rising again.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int,
        num_classes: int,
        lr: float = 1e-3,
        nhead: int = 8,
        num_layers: int = 4,
        scheduler_t_max: Optional[int] = None,
    ):
        super().__init__()
        # Stores every __init__ argument in self.hparams (and in checkpoints).
        self.save_hyperparameters()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_dim, nhead=nhead, batch_first=True),
            num_layers=num_layers,
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map batch-first token ids ``x`` (batch, seq) to logits (batch, num_classes)."""
        x = self.embedding(x)
        x = self.transformer(x)
        # Position 0 is treated as the CLS/summary token.
        return self.classifier(x[:, 0])

    def training_step(self, batch, batch_idx):
        """Compute, log and return cross-entropy loss for one ``(inputs, targets)`` batch."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one ``(inputs, targets)`` batch."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=-1) == y).float().mean()
        # sync_dist=True reduces the epoch-level metrics across DDP ranks, so
        # ModelCheckpoint / EarlyStopping monitor the global val_loss rather
        # than each process's local shard. It is a no-op on a single device.
        self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=True, sync_dist=True)

    def configure_optimizers(self):
        """Return AdamW plus a cosine-annealing LR scheduler."""
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=0.01)
        t_max = self.hparams.scheduler_t_max
        if t_max is None:
            # Tie the schedule length to the run length; Lightning reports
            # max_epochs as -1 when training is bounded by max_steps instead.
            max_epochs = self.trainer.max_epochs
            t_max = max_epochs if max_epochs is not None and max_epochs > 0 else 100
        scheduler = CosineAnnealingLR(optimizer, T_max=t_max)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
Trainer: Everything Else
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
# All engineering concerns (devices, precision, clipping, accumulation,
# checkpointing, logging) are expressed as Trainer configuration.
trainer_callbacks = [
    # Keep the three checkpoints with the lowest validation loss.
    ModelCheckpoint(monitor="val_loss", save_top_k=3, mode="min"),
    # Stop after 10 validation checks without val_loss improvement.
    EarlyStopping(monitor="val_loss", patience=10),
    # Record the learning rate so the scheduler can be inspected in the logger.
    LearningRateMonitor(),
]
trainer = Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=4,                    # multi-GPU: Lightning selects DDP on its own
    precision="bf16-mixed",       # bfloat16 autocast; needs bf16-capable GPUs
    gradient_clip_val=1.0,
    accumulate_grad_batches=4,    # optimizer steps every 4 batches
    callbacks=trainer_callbacks,
    logger=WandbLogger(project="my-project"),
)
model = TextClassifier(vocab_size=30000, hidden_dim=256, num_classes=5)
# NOTE: train_dataloader / val_dataloader are not defined in this snippet;
# build them yourself or pass datamodule=TextDataModule(...) instead.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
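Going from 1 GPU to 4 GPUs with DDP is a one-argument change (`devices=4`). Lightning selects the DDP strategy and inserts a `DistributedSampler` for you, so there is no `DistributedDataParallel` wrapping and no manual sampler change. Metrics logged at epoch level should still pass `sync_dist=True` so every rank sees the same value.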
LightningDataModule for Reproducible Data
class TextDataModule(pl.LightningDataModule):
    """Build the train/val ``TextDataset`` splits and wrap them in DataLoaders.

    Args:
        data_dir: Directory handed to ``TextDataset`` for both splits.
        batch_size: Batch size used by both loaders (per process under DDP).
        num_workers: DataLoader worker processes for each loader.
    """

    def __init__(self, data_dir: str, batch_size: int = 32, num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: str):
        """Create both dataset splits (Lightning invokes this on every process)."""
        self.train_dataset = TextDataset(self.data_dir, split="train")
        self.val_dataset = TextDataset(self.data_dir, split="val")

    def _make_loader(self, dataset, shuffle: bool):
        """Shared DataLoader construction for both splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Reuse worker processes across epochs instead of respawning them.
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        # Under DDP, Lightning replaces this with a shuffling DistributedSampler.
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)
Lightning vs Raw PyTorch
Lightning adds some framework overhead compared to a perfectly optimized raw PyTorch loop. The project tracks speed parity against plain PyTorch, and the gap is usually small, so this matters only in extreme performance scenarios. For research and production ML, Lightning's reproducibility and reduced bug surface are worth the tradeoff.
Resources: PyTorch Lightning docs, GitHub.