
Training Service

The Training Service is the ML training pipeline for DermaDetect’s skin disease classification model. It produces PyTorch models from the curated historical dataset.

Overview

Location: services/training/

Technology Stack:

  • PyTorch + torchvision for model training
  • Polars for data loading (Parquet manifests)
  • albumentations for image augmentation
  • TensorBoard for experiment tracking
  • Pydantic for configuration

Quick Start

1. Sync the dataset

The curated dataset lives on S3. Sync it locally before training:

cd services/training

# Sync manifest files + images (~150GB first time, incremental after)
uv run python -c "
from ddtrain.datasets.sync import sync_dataset
sync_dataset()
"

This downloads to ~/.cache/dermadetect/v1/ by default.

2. Run training

uv run python -m ddtrain.training.trainer \
    --config configs/gen2a_port.yaml

Or programmatically:

from ddtrain.config import ExperimentConfig
from ddtrain.datasets.dataset import DermaDetectDataset
from ddtrain.datasets.features import MetadataEncoder
from ddtrain.datasets.transforms import get_train_transforms, get_eval_transforms
from ddtrain.models.gen2a import Gen2AModel
from ddtrain.training.trainer import train

config = ExperimentConfig(name="my_experiment")

# Load feature schema and build encoder
encoder = MetadataEncoder("~/.cache/dermadetect/v1/feature_schema.json")

# Define which diseases to classify
label_list = ["acne vulgaris", "contact dermatitis", "eczema uns", ...]

# Build datasets
train_ds = DermaDetectDataset(
    dataset_dir="~/.cache/dermadetect/v1",
    split="train",
    encoder=encoder,
    label_list=label_list,
    transform=get_train_transforms(224),
)
val_ds = DermaDetectDataset(
    dataset_dir="~/.cache/dermadetect/v1",
    split="val",
    encoder=encoder,
    label_list=label_list,
    transform=get_eval_transforms(224),
)

# Build model and train
model = Gen2AModel(
    num_metadata_features=encoder.dim,
    num_classes=len(label_list),
)
train(model, config, train_ds, val_ds)

3. Monitor training

Weights & Biases (primary β€” enabled by default):

# First time: authenticate
wandb login

# Runs are tracked at https://wandb.ai/<entity>/dermadetect
# Each run logs: loss, accuracy, learning rate, phase, epoch time
# Best model is saved as a W&B artifact

To disable W&B, set wandb.enabled: false in the config YAML, or set the environment variable WANDB_MODE=disabled.

TensorBoard (the same metrics are logged simultaneously):

tensorboard --logdir outputs/my_experiment/tb_logs

Architecture

Gen2A Model (legacy port)

The Gen2A model is a multimodal classifier that combines image features with patient metadata:

Image (224x224) ─────────→ ResNet50 (pretrained) ──→ Dense(256) ──→ ┐
                                                                    β”œβ”€β”€β†’ Dense(64) ──→ Sigmoid(num_classes)
Metadata (~120 features) ─→ MLP(256) ──────────────────────────────→ β”˜

Two-phase training:

  1. Phase 1 (3 epochs): Backbone frozen, train only the heads at base_lr
  2. Phase 2 (15 epochs): Backbone unfrozen, fine-tune everything at finetune_lr
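The two-phase schedule can be expressed with standard PyTorch freeze/unfreeze mechanics. A minimal sketch, using a toy model rather than Gen2AModel (whose internal attribute names are not shown here):

```python
import torch
import torch.nn as nn

# Toy stand-in: a freezable "backbone" followed by a trainable head.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
backbone, head = model[0], model[2]

# Phase 1: freeze the backbone; the optimizer sees only head parameters, at base_lr.
for p in backbone.parameters():
    p.requires_grad = False
phase1_opt = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)

# Phase 2: unfreeze everything and fine-tune all parameters at a lower finetune_lr.
for p in backbone.parameters():
    p.requires_grad = True
phase2_opt = torch.optim.Adam(model.parameters(), lr=1e-5)
```

The key detail is that the phase-1 optimizer is constructed from only the parameters with requires_grad=True, so backbone weights receive no updates until phase 2.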

Multi-GPU Support

The trainer automatically detects multiple GPUs and wraps the model with DataParallel. No configuration needed β€” just have multiple CUDA devices available.
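The detection logic amounts to a few lines of PyTorch. A hedged sketch (the trainer's actual helper is not shown in this doc; maybe_parallelize is a hypothetical name):

```python
import torch
import torch.nn as nn


def maybe_parallelize(model: nn.Module) -> nn.Module:
    """Wrap the model in DataParallel only when >1 CUDA device is visible.

    On CPU-only or single-GPU machines the model is returned unchanged,
    so the same code path works everywhere with no configuration.
    """
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    return model


model = maybe_parallelize(nn.Linear(4, 2))
```

Note that DataParallel splits each batch across devices, so the effective per-GPU batch size is the configured batch size divided by the device count.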

Label Selection

Models are trained on a curated subset of diseases specified by a label file (one disease per line). The production model uses 25 diseases defined in labels/common25.txt.

The dataset uses multi-hot encoding: each sample gets a label vector where label[i] = 1.0 for each matching disease. This supports multi-label classification (a sample can have multiple diagnoses).
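The encoding scheme above can be illustrated with a small hypothetical helper (not the actual dataset code):

```python
def multi_hot(diagnoses: set[str], label_list: list[str]) -> list[float]:
    """Encode a sample's diagnoses as a multi-hot vector over label_list."""
    labels = [0.0] * len(label_list)
    for i, disease in enumerate(label_list):
        if disease in diagnoses:
            labels[i] = 1.0
    return labels


label_list = ["acne vulgaris", "contact dermatitis", "eczema uns"]

# A case with two diagnoses yields two 1.0 entries:
multi_hot({"acne vulgaris", "eczema uns"}, label_list)  # β†’ [1.0, 0.0, 1.0]
```

Because each position is independent, the model's sigmoid outputs can be trained per class (e.g. with binary cross-entropy) rather than forcing a single softmax choice.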

β€œOther” Class

When dataset.other_class: true, an additional β€œother” class is appended to the label list. Samples whose diagnosis doesn’t match any disease in the label file are labeled as β€œother” instead of getting an all-zeros vector.

The dataset.other_ratio parameter (default 0.1) controls what fraction of the training set consists of β€œother” samples. Non-matching samples are downsampled to maintain this ratio, with diverse sampling across diseases. This prevents the β€œother” class from overwhelming training while still teaching the model to distinguish β€œnone of the above” from low-confidence predictions.
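The ratio arithmetic works out as follows: to make "other" samples a fraction r of the final set, keep n_other = n_match Β· r / (1 βˆ’ r) of them. A simplified sketch, using uniform random sampling where the real implementation samples diversely across diseases (downsample_other is a hypothetical helper):

```python
import random


def downsample_other(matching: list, non_matching: list,
                     other_ratio: float = 0.1, seed: int = 0) -> list:
    """Keep just enough 'other' samples that they make up other_ratio of the
    final set: n_other / (n_match + n_other) = other_ratio."""
    n_other = round(len(matching) * other_ratio / (1 - other_ratio))
    rng = random.Random(seed)
    kept = rng.sample(non_matching, min(n_other, len(non_matching)))
    return matching + kept


# 900 matching samples + 100 sampled "other" samples β†’ 10% "other"
train = downsample_other(list(range(900)), list(range(1000, 1500)))
```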

This feature was ported from the original algo_python system’s -O / --other_class flag.

Dataset

The training data is a static historical snapshot from the original DermaDetect system:

Metric             Count
Total images       369,786
Total cases        143,449
Total patients     89,806
Unique diagnoses   1,151
Vendors            Maccabi, Yeledoctor

Splits are made at the patient level (not by case or image) to prevent data leakage, with an 80/10/10 train/val/test split.
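One common way to implement such a split is to hash the patient ID into a bucket, so the assignment is deterministic and every case and image of a patient lands in the same split. This is a sketch of the general technique, not necessarily how manifest.py does it:

```python
import hashlib


def split_for_patient(patient_id: str) -> str:
    """Deterministically assign a patient to train/val/test (80/10/10).

    Hashing the patient ID (not the case or image ID) guarantees that all
    records of one patient share a split, preventing leakage.
    """
    bucket = int(hashlib.sha256(patient_id.encode()).hexdigest(), 16) % 100
    if bucket < 80:
        return "train"
    if bucket < 90:
        return "val"
    return "test"
```

Two images of the same patient always agree: split_for_patient("patient-42") returns the same value on every call and every machine.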

Data location: s3://dermadetect-ml-datasets/v1/

Key Files

File                                  Purpose
scripts/run_etl.py                    ETL: SQL dump β†’ Parquet manifest (run once)
scripts/copy_images.py                Copy images from backup bucket to curated bucket (run once)
src/ddtrain/datasets/sync.py          S3 β†’ local filesystem mirror
src/ddtrain/datasets/manifest.py      Load manifest + splits from Parquet
src/ddtrain/datasets/features.py      Metadata encoding (raw values β†’ tensors)
src/ddtrain/datasets/dataset.py       PyTorch Dataset
src/ddtrain/datasets/transforms.py    Image augmentation transforms
src/ddtrain/models/gen2a.py           Gen2A model (ResNet50 + MLP)
src/ddtrain/training/trainer.py       Two-phase training loop
src/ddtrain/config.py                 Pydantic config schema
configs/gen2a_port.yaml               Default experiment config

Running Tests

cd services/training
uv run python -m pytest tests/ -v

Tests include:

  • Feature encoding (continuous, categorical, boolean 3-state, multi-select)
  • Manifest loading and patient-level split validation
  • Model forward pass, freeze/unfreeze, architecture dimensions
  • Real-data validation against ETL output (if available locally)