# PyTorch ML
## Project context
This is a PyTorch-based ML project — model training, fine-tuning, and inference. We use Hugging Face for transformers and tokenizers, PyTorch Lightning for training orchestration when scale demands it, and Weights & Biases for experiment tracking.
## Stack

- Python 3.12+
- PyTorch 2.4+ (CUDA / MPS / CPU)
- Hugging Face `transformers`, `datasets`, `accelerate`, `peft`
- PyTorch Lightning (optional, for multi-GPU)
- `wandb` for experiment tracking
- `bitsandbytes` for 4-bit / 8-bit quantization
- `uv` for env management
- `safetensors` for model serialization (never `.pt` / `.bin`)
## Folder structure

```
src/
  data/
    dataset.py    — torch.utils.data.Dataset implementations
    collate.py    — collation functions
    tokenize.py
  model/
    model.py      — model classes (or HF AutoModel wrappers)
    config.py
  training/
    train.py      — main training entrypoint
    optimizer.py  — optimizer + scheduler factories
    callbacks.py
  inference/
    predict.py
    serve.py      — FastAPI / Modal / Replicate wrapper
configs/
  base.yaml
  experiments/<name>.yaml
checkpoints/      — saved as safetensors, .gitignored
```
## Reproducibility

```python
import random

import numpy as np
import torch


def set_seeds(seed: int) -> None:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # disable for full determinism
```
- Set seeds at the top of every training run
- Pin all dependencies (`uv lock` committed)
- Log the git SHA, dataset hash, and full config to W&B (see the sketch below)
- Save checkpoints as `safetensors`, not pickle-based formats
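A minimal sketch of that provenance logging, assuming a `cfg` dict already loaded from the YAML config and a single-file dataset; the project name and paths are placeholders:

```python
import hashlib
import subprocess
from pathlib import Path

import wandb


def file_sha256(path: str) -> str:
    """Hash the dataset file so the exact data version is recorded with the run."""
    return hashlib.sha256(Path(path).read_bytes()).hexdigest()


git_sha = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()

run = wandb.init(
    project="my-project",  # placeholder project name
    config={
        **cfg,  # assumed: full config dict loaded from configs/*.yaml
        "git_sha": git_sha,
        "dataset_sha256": file_sha256("data/train.jsonl"),  # placeholder path
    },
)
```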
## Datasets

- Use Hugging Face `datasets` for tabular and text — its memory-mapping is much faster than rolling your own
- For custom data, subclass `torch.utils.data.Dataset` with `__len__` and `__getitem__`
- Move heavy preprocessing into `dataset.map(...)` so it caches automatically
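A minimal sketch of both approaches, assuming a text-classification-style dataset with `text` and `label` fields and placeholder model/file names:

```python
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder model


class TextClassificationDataset(Dataset):
    """Custom Dataset: only __len__ and __getitem__ are required."""

    def __init__(self, records: list[dict], max_length: int = 512):
        self.records = records
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        row = self.records[idx]
        enc = tokenizer(row["text"], truncation=True, max_length=self.max_length)
        return {**enc, "labels": row["label"]}


# HF datasets: heavy preprocessing inside .map() is cached to disk automatically,
# so re-running the script skips the tokenization step.
ds = load_dataset("json", data_files="data/train.jsonl", split="train")  # placeholder path
ds = ds.map(lambda batch: tokenizer(batch["text"], truncation=True), batched=True)
```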
## DataLoader

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,            # ~CPU count, but tune
    pin_memory=True,          # if using CUDA
    persistent_workers=True,  # avoid worker startup overhead each epoch
    collate_fn=my_collate,
)
```
- `pin_memory=True` for CUDA, `False` for MPS / CPU
- Use `persistent_workers=True` unless you re-create the dataloader between epochs
- Set `num_workers=0` when debugging (so tracebacks point at the right place)
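The `collate_fn` passed above lives in `src/data/collate.py`. A minimal padding collate, assuming batches of dicts with `input_ids`, integer `labels`, and pad token id 0:

```python
import torch


def my_collate(batch: list[dict]) -> dict[str, torch.Tensor]:
    """Pad variable-length sequences to the longest example in the batch."""
    max_len = max(len(item["input_ids"]) for item in batch)
    input_ids, attention_mask, labels = [], [], []
    for item in batch:
        pad = max_len - len(item["input_ids"])
        input_ids.append(list(item["input_ids"]) + [0] * pad)  # 0 = assumed pad token id
        attention_mask.append([1] * len(item["input_ids"]) + [0] * pad)
        labels.append(item["labels"])
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
    }
```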
## Training loop

- Use `accelerate` or `lightning` instead of writing your own DDP boilerplate (see the sketch after this list)
- Mixed precision: `torch.amp.autocast("cuda")` (CUDA) or `torch.amp.autocast("mps")` (Apple Silicon); `torch.cuda.amp.autocast` is deprecated in PyTorch 2.4+
- Gradient accumulation when the batch doesn't fit; `accelerator.accumulate(model)` handles it cleanly
- Always `optimizer.zero_grad(set_to_none=True)` — already the default in PyTorch 2.x, but being explicit documents the intent
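A sketch of the inner loop with `accelerate`, assuming `model`, `optimizer`, `scheduler`, and `loader` are already built; the precision and accumulation settings are illustrative:

```python
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=4)
model, optimizer, loader, scheduler = accelerator.prepare(model, optimizer, loader, scheduler)

model.train()
for batch in loader:
    with accelerator.accumulate(model):      # gradient sync + real step only at accumulation boundaries
        outputs = model(**batch)
        accelerator.backward(outputs.loss)   # replaces loss.backward(); handles scaling
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)
```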
## Checkpointing

- Save model + optimizer + scheduler + step + epoch + RNG state
- Use `safetensors.torch.save_file` for the model weights — not `torch.save(model.state_dict())`
- Keep the last N checkpoints + the best by validation metric; rotate the rest
- Save the config alongside the checkpoint so it's reproducible without the codebase
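A sketch of a checkpoint writer following those rules; the weights go to `safetensors`, everything else to a separate training-state file, and the names and paths are placeholders:

```python
import random
from pathlib import Path

import numpy as np
import torch
from safetensors.torch import save_file


def save_checkpoint(model, optimizer, scheduler, step, epoch, cfg, out_dir: str) -> None:
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    save_file(model.state_dict(), str(out / "model.safetensors"))  # weights only, no pickle
    torch.save(
        {
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "step": step,
            "epoch": epoch,
            "rng": {
                "python": random.getstate(),
                "numpy": np.random.get_state(),
                "torch": torch.get_rng_state(),
                "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
            },
            "config": cfg,  # config travels with the checkpoint
        },
        out / "trainer_state.pt",
    )
```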
## LoRA / fine-tuning

- Use `peft` for parameter-efficient fine-tuning
- Set rank deliberately (`r=8` is a common sweet spot for 7B-class models; higher for smaller models)
- Save only the adapter weights — much smaller than full checkpoints
- For QLoRA: load the base model in 4-bit via `bitsandbytes`, then attach LoRA adapters
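A sketch of the QLoRA path; the base model name and `target_modules` are placeholders that depend on the architecture:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base = AutoModelForCausalLM.from_pretrained(
    "some-org/some-7b-model",  # placeholder checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)

lora_config = LoraConfig(
    r=8,                                  # common sweet spot for 7B-class models
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # architecture-dependent
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()

# After training: save only the adapter weights, not the full base model.
model.save_pretrained("checkpoints/adapter")
```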
## Inference

- Use `model.eval()` and `with torch.no_grad():` (or `torch.inference_mode()` — slightly faster)
- For batch inference, batch up requests; for streaming, use `transformers`' `TextIteratorStreamer`
- Quantize for inference if memory is tight: `bitsandbytes`, `gguf`, or `torch.compile`'s 8-bit path
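A sketch of streaming generation, assuming `model` and `tokenizer` are already loaded; the prompt and generation settings are illustrative:

```python
from threading import Thread

import torch
from transformers import TextIteratorStreamer

model.eval()
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer("Explain LoRA in one sentence.", return_tensors="pt").to(model.device)


def _generate() -> None:
    # inference_mode is thread-local, so enter it inside the worker thread
    with torch.inference_mode():
        model.generate(**inputs, streamer=streamer, max_new_tokens=128)


thread = Thread(target=_generate)
thread.start()
for text_chunk in streamer:  # yields decoded text pieces as they are produced
    print(text_chunk, end="", flush=True)
thread.join()
```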
## Patterns to avoid

- `pickle` / `torch.save(..., 'model.pt')` — use `safetensors`
- Hand-rolling DDP — use `accelerate` or `lightning`
- Hard-coded paths in scripts — use a config file (Hydra / YAML)
- Forgetting `.eval()` at inference — dropout and batchnorm will give wrong results
- `.cuda()` everywhere — use `model.to(device)` and pass `device` from config
- Silent device mismatches (`RuntimeError: Expected all tensors to be on the same device`) — set up a `to_device` helper
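A hypothetical `to_device` helper for the last two points; it moves nested batch structures in one call instead of scattering `.cuda()`:

```python
import torch


def to_device(obj, device: torch.device):
    """Recursively move tensors inside dicts / lists / tuples to `device`."""
    if torch.is_tensor(obj):
        return obj.to(device, non_blocking=True)
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(v, device) for v in obj)
    return obj


# device comes from config in practice; this fallback is just for the sketch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = to_device(batch, device)  # `batch` = one item from the DataLoader
```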
## Logging & experiment tracking

- W&B: log loss every step, metrics every epoch, hyperparameters at start
- Save the config as a W&B artifact
- Use W&B's media logging for sample inputs/outputs
- Tag runs (`baseline`, `ablation-x`) for filtering
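A sketch of that cadence; `cfg`, `loader`, `training_step`, and the metric names are placeholders:

```python
import wandb

run = wandb.init(project="my-project", tags=["baseline"], config=cfg)  # placeholders

for step, batch in enumerate(loader):
    loss = training_step(batch)                 # hypothetical helper returning a float
    wandb.log({"train/loss": loss}, step=step)  # every step

# once per epoch
wandb.log({"val/accuracy": val_accuracy, "epoch": epoch})

# sample inputs/outputs via media logging
table = wandb.Table(columns=["prompt", "completion"], data=[["<prompt>", "<model output>"]])
wandb.log({"samples": table})
```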
## Testing
- Smoke tests with a 2-batch tiny dataset to catch shape errors fast
- Unit-test datasets and collation functions
- For models, test forward pass shapes; for losses, test gradient flow
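A pytest-style sketch of the tiny-batch smoke test; `build_tiny_model`, `build_tiny_dataset`, and `my_collate` are hypothetical factories from this repo's `src/` layout:

```python
import torch
from torch.utils.data import DataLoader


def test_tiny_batch_smoke():
    model = build_tiny_model()                   # hypothetical: small config, runs on CPU
    dataset = build_tiny_dataset(num_samples=8)  # hypothetical: 2 batches of 4
    loader = DataLoader(dataset, batch_size=4, collate_fn=my_collate)

    batch = next(iter(loader))
    outputs = model(**batch)

    # shape check: one row of logits per example
    assert outputs.logits.shape[0] == 4

    # gradient-flow check: the loss actually reaches the trainable parameters
    outputs.loss.backward()
    assert any(p.grad is not None for p in model.parameters() if p.requires_grad)
```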
## Tooling

- `uv venv && uv sync`
- `python -m src.training.train --config configs/experiments/foo.yaml`
- `pytest`
- `ruff check && ruff format`
- `nvidia-smi` / `nvtop` to watch GPU
## AI behavioral rules

- Always set seeds at the top of any training script
- Never use `pickle` / raw `torch.save` for model weights — use `safetensors`
- Always wrap inference in `torch.inference_mode()` and `model.eval()`
- Prefer `accelerate` over hand-rolled distributed code
- Log experiments to W&B by default; never silent training runs
- Verify shapes via a tiny-batch smoke test before launching long training
- Don't add new dependencies without surfacing the GPU/memory implications
- Run `pytest` (smoke tests) and `ruff check` before declaring a task done