# Training Service

The Training Service is the ML training pipeline for DermaDetect's skin disease classification model. It produces PyTorch models from the curated historical dataset.
## Overview

**Location:** `services/training/`

**Technology Stack:**
- PyTorch + torchvision for model training
- Polars for data loading (Parquet manifests)
- albumentations for image augmentation
- TensorBoard for experiment tracking
- Pydantic for configuration
## Quick Start

### 1. Sync the dataset

The curated dataset lives on S3. Sync it locally before training:

```bash
cd services/training

# Sync manifest files + images (~150GB first time, incremental after)
uv run python -c "
from ddtrain.datasets.sync import sync_dataset
sync_dataset()
"
```

This downloads to `~/.cache/dermadetect/v1/` by default.
### 2. Run training

```bash
uv run python -m ddtrain.training.trainer \
    --config configs/gen2a_port.yaml
```

Or programmatically:
```python
from ddtrain.config import ExperimentConfig
from ddtrain.datasets.dataset import DermaDetectDataset
from ddtrain.datasets.features import MetadataEncoder
from ddtrain.datasets.transforms import get_train_transforms, get_eval_transforms
from ddtrain.models.gen2a import Gen2AModel
from ddtrain.training.trainer import train

config = ExperimentConfig(name="my_experiment")

# Load feature schema and build encoder
encoder = MetadataEncoder("~/.cache/dermadetect/v1/feature_schema.json")

# Define which diseases to classify
label_list = ["acne vulgaris", "contact dermatitis", "eczema uns", ...]

# Build datasets
train_ds = DermaDetectDataset(
    dataset_dir="~/.cache/dermadetect/v1",
    split="train",
    encoder=encoder,
    label_list=label_list,
    transform=get_train_transforms(224),
)
val_ds = DermaDetectDataset(
    dataset_dir="~/.cache/dermadetect/v1",
    split="val",
    encoder=encoder,
    label_list=label_list,
    transform=get_eval_transforms(224),
)

# Build model and train
model = Gen2AModel(
    num_metadata_features=encoder.dim,
    num_classes=len(label_list),
)
train(model, config, train_ds, val_ds)
```

### 3. Monitor training
**Weights & Biases** (primary; enabled by default):

```bash
# First time: authenticate
wandb login

# Runs are tracked at https://wandb.ai/<entity>/dermadetect
# Each run logs: loss, accuracy, learning rate, phase, epoch time
# Best model is saved as a W&B artifact
```

To disable W&B, set `wandb.enabled: false` in the config YAML, or export `WANDB_MODE=disabled`.
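For example, a config override might look like the following (only the `wandb.enabled` key is documented here; the exact YAML nesting is an assumption about `configs/gen2a_port.yaml`):

```yaml
wandb:
  enabled: false
```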
**TensorBoard** (also logged simultaneously):

```bash
tensorboard --logdir outputs/my_experiment/tb_logs
```

## Architecture
### Gen2A Model (legacy port)

The Gen2A model is a multimodal classifier that combines image features with patient metadata:
```
Image (224x224) ──► ResNet50 (pretrained) ──► Dense(256) ──┐
                                                           ├──► Dense(64) ──► Sigmoid(num_classes)
Metadata (~120 features) ──► MLP(256) ─────────────────────┘
```

Two-phase training:

- Phase 1 (3 epochs): backbone frozen, train only the heads at `base_lr`
- Phase 2 (15 epochs): backbone unfrozen, fine-tune everything at `finetune_lr`
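The phase switch can be sketched as toggling `requires_grad` on the backbone and rebuilding the optimizer (illustrative only; `Gen2AModel`'s actual attribute names and the trainer's optimizer setup may differ, and the `ToyModel` below is hypothetical):

```python
import torch
from torch import nn

def set_backbone_trainable(model: nn.Module, trainable: bool) -> None:
    # Toggle requires_grad on every backbone parameter.
    for p in model.backbone.parameters():
        p.requires_grad = trainable

class ToyModel(nn.Module):
    # Hypothetical stand-in for Gen2AModel: a "backbone" plus a "head".
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)

model = ToyModel()

# Phase 1: freeze the backbone; the optimizer sees only head parameters.
set_backbone_trainable(model, False)
phase1_opt = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3  # base_lr
)

# Phase 2: unfreeze everything and rebuild the optimizer at a lower rate.
set_backbone_trainable(model, True)
phase2_opt = torch.optim.Adam(model.parameters(), lr=1e-4)  # finetune_lr
```

Rebuilding the optimizer at the phase boundary avoids carrying stale momentum state for parameters that were frozen during phase 1.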
### Multi-GPU Support

The trainer automatically detects multiple GPUs and wraps the model with `DataParallel`. No configuration is needed; just have multiple CUDA devices available.
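The detection amounts to something like this minimal sketch (the function name is illustrative, not the trainer's actual API):

```python
import torch
from torch import nn

def maybe_data_parallel(model: nn.Module) -> nn.Module:
    # Wrap with DataParallel only when more than one CUDA device is visible;
    # on a single GPU or on CPU the model is returned unchanged.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    return model

model = maybe_data_parallel(nn.Linear(8, 2))
```

Note that a `DataParallel`-wrapped model keeps its weights under `model.module`, which matters when saving or loading checkpoints.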
### Label Selection

Models are trained on a curated subset of diseases specified by a label file (one disease per line). The production model uses 25 diseases defined in `labels/common25.txt`.

The dataset uses multi-hot encoding: each sample gets a label vector where `label[i] = 1.0` for each matching disease. This supports multi-label classification (a sample can have multiple diagnoses).
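The multi-hot construction can be illustrated in plain Python (the diagnosis strings below are examples; the real encoding lives in `src/ddtrain/datasets/dataset.py`):

```python
def multi_hot(diagnoses, label_list):
    # label[i] = 1.0 when the sample carries the i-th disease, else 0.0.
    present = set(diagnoses)
    return [1.0 if label in present else 0.0 for label in label_list]

labels = ["acne vulgaris", "contact dermatitis", "eczema uns"]
vec = multi_hot(["eczema uns", "acne vulgaris"], labels)
# vec == [1.0, 0.0, 1.0]
```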
### "Other" Class

When `dataset.other_class: true`, an additional "other" class is appended to the label list. Samples whose diagnosis doesn't match any disease in the label file are labeled as "other" instead of getting an all-zeros vector.

The `dataset.other_ratio` parameter (default 0.1) controls what fraction of the training set consists of "other" samples. Non-matching samples are downsampled to maintain this ratio, with diverse sampling across diseases. This prevents the "other" class from overwhelming training while still teaching the model to distinguish "none of the above" from low-confidence predictions.

This feature was ported from the original `algo_python` system's `-O` / `--other_class` flag.
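The downsampling arithmetic implied by `dataset.other_ratio` can be sketched as follows (assuming the ratio is the "other" share of the final training set; the trainer's exact sampling logic may differ):

```python
def n_other_to_keep(n_matched: int, other_ratio: float = 0.1) -> int:
    # Solve other / (matched + other) = ratio for the number of
    # "other" samples to retain alongside n_matched matched samples.
    return int(other_ratio * n_matched / (1.0 - other_ratio))

n_other_to_keep(90_000)  # keeps 10,000 "other" samples: 10% of 100,000 total
```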
## Dataset

The training data is a static historical snapshot from the original DermaDetect system:
| Metric | Count |
|---|---|
| Total images | 369,786 |
| Total cases | 143,449 |
| Total patients | 89,806 |
| Unique diagnoses | 1,151 |
| Vendors | Maccabi, Yeledoctor |
Splits are made by patient (not by case or image) to prevent data leakage, with an 80/10/10 train/val/test split.

Data location: `s3://dermadetect-ml-datasets/v1/`
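A common way to get stable patient-level splits is a deterministic hash of the patient ID. This is only an illustration of the idea; the actual split assignments ship with the dataset manifest:

```python
import hashlib

def patient_split(patient_id: str) -> str:
    # Hash the patient ID so every case and image of a patient lands in
    # the same split, regardless of processing order or reruns.
    bucket = int(hashlib.sha256(patient_id.encode()).hexdigest(), 16) % 100
    if bucket < 80:
        return "train"
    if bucket < 90:
        return "val"
    return "test"
```

Hashing the patient (rather than shuffling images) guarantees that no patient's images straddle train and test, which is the leakage the 80/10/10 patient split guards against.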
## Key Files
| File | Purpose |
|---|---|
| `scripts/run_etl.py` | ETL: SQL dump → Parquet manifest (run once) |
| `scripts/copy_images.py` | Copy images from backup bucket to curated bucket (run once) |
| `src/ddtrain/datasets/sync.py` | S3 → local filesystem mirror |
| `src/ddtrain/datasets/manifest.py` | Load manifest + splits from Parquet |
| `src/ddtrain/datasets/features.py` | Metadata encoding (raw values → tensors) |
| `src/ddtrain/datasets/dataset.py` | PyTorch `Dataset` |
| `src/ddtrain/datasets/transforms.py` | Image augmentation transforms |
| `src/ddtrain/models/gen2a.py` | Gen2A model (ResNet50 + MLP) |
| `src/ddtrain/training/trainer.py` | Two-phase training loop |
| `src/ddtrain/config.py` | Pydantic config schema |
| `configs/gen2a_port.yaml` | Default experiment config |
## Running Tests

```bash
cd services/training
uv run python -m pytest tests/ -v
```

Tests include:
- Feature encoding (continuous, categorical, boolean 3-state, multi-select)
- Manifest loading and patient-level split validation
- Model forward pass, freeze/unfreeze, architecture dimensions
- Real-data validation against ETL output (if available locally)