# 2026-03-23: Training Pipeline Improvements

## Changes
### Progress Bars (trainer.py)

Added tqdm progress bars to the batch loops in `train_one_epoch()` and `evaluate()`. Previously there was no per-batch output, making it impossible to tell whether training was running or hung. The progress bars show batch count, ETA, throughput, and running loss/accuracy.
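A minimal sketch of the pattern, assuming tqdm is installed; `step_fn` is a stand-in for the real forward/backward step and the function shape here is illustrative, not the trainer's actual signature:

```python
from tqdm import tqdm

def train_one_epoch(loader, step_fn):
    """Illustrative batch loop with a tqdm progress bar.

    `step_fn(batch)` is assumed to return the batch loss as a float."""
    running_loss = 0.0
    pbar = tqdm(loader, desc="train", unit="batch")
    for i, batch in enumerate(pbar, start=1):
        loss = step_fn(batch)
        running_loss += loss
        # The postfix shows the running mean loss next to tqdm's
        # built-in batch count, ETA, and throughput readout.
        pbar.set_postfix(loss=f"{running_loss / i:.4f}")
    return running_loss / i
```

Wrapping the loader in `tqdm(...)` is all that is needed for the count/ETA/throughput display; `set_postfix` adds the running metrics.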
### Multi-GPU Support (trainer.py)

Added automatic `DataParallel` wrapping when multiple GPUs are detected. A `_unwrap()` helper accesses the underlying model for `freeze_backbone()`, `unfreeze_backbone()`, and `state_dict()` saves. No config changes needed.
### Learning Rate Warmup (trainer.py, config.py)

Added a `training.warmup_steps` config option backed by a `LinearLR` scheduler. Warmup starts at 1% of the target LR and ramps up linearly over the configured number of steps. A fresh warmup runs for each training phase. This is essential at larger batch sizes, where starting at the full learning rate can destabilize early training.
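The described schedule maps directly onto PyTorch's built-in `LinearLR`; a sketch, with the optimizer choice and the hard-coded values standing in for the real config:

```python
import torch
from torch.optim.lr_scheduler import LinearLR

model = torch.nn.Linear(4, 2)
target_lr = 1e-3
warmup_steps = 100  # would come from training.warmup_steps

optimizer = torch.optim.SGD(model.parameters(), lr=target_lr)
# start_factor=0.01: the first step uses 1% of the target LR, then the
# factor ramps linearly to end_factor over total_iters steps. Creating
# a fresh scheduler at the start of each phase restarts the warmup.
scheduler = LinearLR(optimizer, start_factor=0.01, end_factor=1.0,
                     total_iters=warmup_steps)
```

Note that `scheduler.step()` must be called per batch (not per epoch) for `warmup_steps` to mean optimizer steps.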
### DataLoader Optimizations (trainer.py)

Added `persistent_workers=True` (avoids respawning workers between epochs) and `prefetch_factor=2` (pre-loads batches).
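In DataLoader terms, a sketch with illustrative batch size and worker count (the tuned values are in the Config Tuning section below):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(dataset, batch_size=16, num_workers=2):
    """Build a DataLoader with both optimizations applied."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,   # persistent_workers requires num_workers > 0
        persistent_workers=True,   # keep worker processes alive across epochs
        prefetch_factor=2,         # batches each worker loads ahead of the GPU
    )
```

`prefetch_factor` is per worker, so total pre-loaded batches is `num_workers * prefetch_factor`; raising either increases RAM pressure, which is what the tuning table below measures.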
### Label File Support and Taxonomy Bug Fix (trainer.py)

Fixed a bug where `disease_taxonomy.json` was loaded as a raw dict, so iterating it produced 5 “labels” (the top-level dict keys) instead of the 1,151 diseases. Changed the loader to read `json.load(f)["diseases"]`. Also added a `label_file` config option and created `labels/common25.txt` with the 25 production diseases matching the deployed model.
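The fix in miniature, assuming the taxonomy file has a top-level `"diseases"` array (the exact schema of `disease_taxonomy.json` is inferred from the fix):

```python
import json

def load_labels(taxonomy_path):
    """Load the disease label list from the taxonomy file.

    Iterating the raw dict (the old code) yields only its top-level
    keys; the disease list lives under the "diseases" key."""
    with open(taxonomy_path) as f:
        return json.load(f)["diseases"]
```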
### “Other” Class (dataset.py, config.py)

Ported the “other” class feature from `algo_python`. When `dataset.other_class: true`, an “other” class is appended to the label list. Samples with non-matching diagnoses are labeled as “other” and downsampled to `dataset.other_ratio` (default 10%). The trainer also now saves `labels.txt` alongside the model.
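A sketch of the relabel-and-downsample logic; the function name and sample representation are illustrative, and the ratio is interpreted here as a fraction of the matched samples (an assumption — the real code may define it differently):

```python
import random

def apply_other_class(samples, labels, other_ratio=0.10, seed=0):
    """Relabel unknown diagnoses as "other" and downsample them.

    `samples` is a list of (item, diagnosis) pairs. Diagnoses not in
    `labels` become "other"; the "other" pool is then sampled down to
    other_ratio of the matched-sample count."""
    known = set(labels)
    matched = [(x, d) for x, d in samples if d in known]
    other = [(x, "other") for x, d in samples if d not in known]
    keep = min(len(other), int(len(matched) * other_ratio))
    other = random.Random(seed).sample(other, keep)
    return matched + other, labels + ["other"]
```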
### Removed Unused num_classes (config.py, gen2a_port.yaml)

Removed `model.num_classes` from the config schema, since the trainer derives it from the label list.
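The derivation is a one-liner; `build_classifier_head` is a hypothetical helper showing the idea that the output width always comes from the loaded labels, so a stale config value can never disagree with them:

```python
import torch.nn as nn

def build_classifier_head(labels, in_features=512):
    """Size the classifier head from the label list, not the config."""
    return nn.Linear(in_features, len(labels))
```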
## Config Tuning

Tested various `batch_size` / `num_workers` / `prefetch_factor` combinations on a 2-GPU box (Titan X Pascal + 1080 Ti, 32GB RAM, 12 cores):
| Config | Throughput | Notes |
|---|---|---|
| batch=32, workers=4 | ~64 samples/sec | Original baseline |
| batch=512, workers=10, prefetch=4 | ~711 samples/sec | Best peak, but caused RAM swapping |
| batch=512, workers=6, prefetch=2 | ~34 samples/sec | Severe swap thrashing |
| batch=256, workers=4, prefetch=2 | ~175 samples/sec | Stable, no swapping |
Settled on `batch=256, workers=4, prefetch=2` as the reliable config for the 32GB box: roughly 2.7x the baseline throughput with no swapping.
## Files Changed

- `services/training/src/ddtrain/training/trainer.py` - Progress bars, multi-GPU, warmup, DataLoader opts, label fixes
- `services/training/src/ddtrain/datasets/dataset.py` - “Other” class support
- `services/training/src/ddtrain/config.py` - warmup_steps, other_class, other_ratio, removed num_classes
- `services/training/configs/gen2a_port.yaml` - Updated config
- `services/training/labels/common25.txt` - Production disease list
- `docs/content/services/training.mdx` - Multi-GPU, labels, other class docs
- `docs/content/services/training/dataset.mdx` - Disease label pipeline docs