2026-03-23: Training Pipeline Improvements

Changes

Progress Bars (trainer.py)

Added tqdm progress bars to train_one_epoch() and evaluate() batch loops. Previously there was no per-batch output, making it impossible to tell if training was running or hung. Progress bars show batch count, ETA, throughput, and running loss/accuracy.

Multi-GPU Support (trainer.py)

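Added automatic DataParallel wrapping when more than one GPU is detected. A _unwrap() helper returns the underlying model so that freeze_backbone(), unfreeze_backbone(), and state_dict() saves operate on the model itself rather than on the wrapper; this also keeps DataParallel's "module." prefix out of saved checkpoint keys. No config changes needed.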

Learning Rate Warmup (trainer.py, config.py)

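Added a training.warmup_steps config option backed by a LinearLR scheduler. Warmup starts at 1% of the target LR and ramps linearly to the full LR over the configured number of steps. A fresh warmup runs at the start of each training phase. Warmup matters most at larger batch sizes, where applying the full learning rate from the first step is more likely to destabilize training.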
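As a rough illustration of the ramp described above, the multiplier a linear warmup applies to the target LR can be written in closed form. The sketch below is plain Python so it runs without PyTorch. `warmup_factor` and `target_lr` are hypothetical names for illustration, not code from trainer.py, and `start_factor=0.01` corresponds to the 1% starting point.

```python
def warmup_factor(step: int, warmup_steps: int, start_factor: float = 0.01) -> float:
    """LR multiplier after `step` scheduler steps of a linear warmup."""
    if warmup_steps <= 0:
        return 1.0  # warmup disabled: use the target LR immediately
    # Fraction of the warmup completed, clamped so the factor stays at 1.0 afterwards.
    progress = min(step, warmup_steps) / warmup_steps
    return start_factor + (1.0 - start_factor) * progress


target_lr = 1e-3  # hypothetical target learning rate
# Effective LR at the start, midpoint, end, and after the warmup window.
schedule = [target_lr * warmup_factor(s, warmup_steps=100) for s in (0, 50, 100, 200)]
```

Because a fresh warmup runs per phase, a sketch like this would restart from step 0 whenever a new phase begins.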

DataLoader Optimizations (trainer.py)

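Added persistent_workers=True (keeps worker processes alive between epochs instead of respawning them) and prefetch_factor=2 (each worker pre-loads two batches ahead of the training loop). Both options only take effect when num_workers is greater than 0.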
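One way to keep these options safe when someone sets the worker count to 0 for debugging is to build the DataLoader keyword arguments conditionally, since PyTorch rejects both options without worker processes. This is a minimal sketch; `loader_kwargs` is a hypothetical helper, not the code in trainer.py.

```python
def loader_kwargs(batch_size: int, num_workers: int, prefetch_factor: int = 2) -> dict:
    """Build keyword arguments intended for torch.utils.data.DataLoader."""
    kwargs = {"batch_size": batch_size, "num_workers": num_workers}
    if num_workers > 0:
        # Both options require worker processes; DataLoader raises
        # ValueError if they are set while num_workers == 0.
        kwargs["persistent_workers"] = True
        kwargs["prefetch_factor"] = prefetch_factor
    return kwargs


# Intended use, assuming a `dataset` object exists:
#   loader = DataLoader(dataset, shuffle=True, **loader_kwargs(256, 4))
stable = loader_kwargs(batch_size=256, num_workers=4)
debug = loader_kwargs(batch_size=32, num_workers=0)
```

Note that the number of batches held in memory grows with num_workers × prefetch_factor, which is worth keeping in mind on a RAM-constrained machine.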

Label File Support and Taxonomy Bug Fix (trainer.py)

Fixed bug where disease_taxonomy.json was loaded as a raw dict, producing 5 “labels” from dict keys instead of the 1,151 diseases. Changed to json.load(f)["diseases"]. Added label_file config option and created labels/common25.txt with the 25 production diseases matching the deployed model.
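The failure mode is easy to reproduce: iterating a dict yields its keys, so treating the parsed file as the label list gives one "label" per top-level key. The sketch below uses an invented stand-in for disease_taxonomy.json. Only the "diseases" key comes from the fix described above; the other four keys and the disease names are placeholders, and the real file's layout may differ.

```python
import io
import json

# Hypothetical stand-in for disease_taxonomy.json: five top-level keys,
# one of which ("diseases") holds the actual disease list.
taxonomy_text = json.dumps({
    "version": 1,
    "source": "example",
    "categories": [],
    "synonyms": {},
    "diseases": ["acne", "eczema", "psoriasis"],
})

with io.StringIO(taxonomy_text) as f:
    # Bug: the raw dict is used as the label list, so the "labels"
    # are the five top-level keys rather than the diseases.
    buggy_labels = list(json.load(f))

with io.StringIO(taxonomy_text) as f:
    # Fix: index into the "diseases" entry.
    labels = json.load(f)["diseases"]

num_classes = len(labels)  # the class count follows from the label list
```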

“Other” Class (dataset.py, config.py)

Ported the “other” class feature from algo_python. When dataset.other_class: true, an “other” class is appended to the label list. Samples with non-matching diagnoses are labeled as “other” and downsampled to dataset.other_ratio (default 10%). The trainer also now saves labels.txt alongside the model.
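The labeling and downsampling step might look roughly like the sketch below. It is not the ported algo_python code. In particular, the entry above does not say what other_ratio is measured against, so the sketch assumes "other" samples are capped at other_ratio times the number of samples with a matching diagnosis. `assign_labels`, the `seed` parameter, and the sample data are hypothetical.

```python
import random


def assign_labels(samples, labels, other_class=True, other_ratio=0.10, seed=0):
    """Return (sample_id, label) pairs from (sample_id, diagnosis) tuples."""
    known = set(labels)
    matched = [(sid, dx) for sid, dx in samples if dx in known]
    if not other_class:
        return matched  # non-matching samples are simply dropped
    others = [(sid, "other") for sid, dx in samples if dx not in known]
    # ASSUMPTION: cap "other" at other_ratio x the number of matched samples.
    keep = min(len(others), int(len(matched) * other_ratio))
    rng = random.Random(seed)  # seeded so the subsample is reproducible
    return matched + rng.sample(others, keep)


labels = ["acne", "eczema"]
samples = [(i, "acne") for i in range(20)] + [(100 + i, "rosacea") for i in range(50)]

labelled = assign_labels(samples, labels)
class_names = labels + ["other"]  # "other" is appended to the label list
```

With 20 matching samples and the default ratio, this keeps 2 of the 50 non-matching samples.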

Removed Unused num_classes (config.py, gen2a_port.yaml)

Removed model.num_classes from config schema since the trainer derives it from the label list.

Config Tuning

Tested various batch_size / num_workers / prefetch combinations on a 2-GPU box (Titan X Pascal + 1080 Ti, 32GB RAM, 12 cores):

| Config | Throughput | Notes |
| --- | --- | --- |
| batch=32, workers=4 | ~64 samples/sec | Original baseline |
| batch=512, workers=10, prefetch=4 | ~711 samples/sec | Best peak, but caused RAM swapping |
| batch=512, workers=6, prefetch=2 | ~34 samples/sec | Severe swap thrashing |
| batch=256, workers=4, prefetch=2 | ~175 samples/sec | Stable, no swapping |

Settled on batch=256, workers=4 as the reliable config for the 32GB box.

Files Changed

  • services/training/src/ddtrain/training/trainer.py - Progress bars, multi-GPU, warmup, DataLoader opts, label fixes
  • services/training/src/ddtrain/datasets/dataset.py - “Other” class support
  • services/training/src/ddtrain/config.py - warmup_steps, other_class, other_ratio, removed num_classes
  • services/training/configs/gen2a_port.yaml - Updated config
  • services/training/labels/common25.txt - Production disease list
  • docs/content/services/training.mdx - Multi-GPU, labels, other class docs
  • docs/content/services/training/dataset.mdx - Disease label pipeline docs