Building a Production-Ready Medical AI System: From PyTorch to ONNX Inference


Hi, I’m Vineet.

I’m exploring artificial intelligence, machine learning, and computing systems through hands-on projects, experiments, and writing. My interests span NLP, representation learning, software engineering, and applied research. I use this space to document what I learn, share practical insights, and build a strong technical foundation for research and industry.

Always learning. Always building.

TL;DR: I built an end-to-end skin lesion classification system that achieved a macro F1 score of 0.6587 while reducing model size from 41 MB to 0.9 MB through ONNX optimization. This post covers the model development process, deployment architecture, and lessons learned while preparing the system for real-world inference.


The Problem: Why I Built This

Skin cancer affects millions of people yearly, yet access to dermatologists is geographically limited. Computer vision models can help with early detection, but there's a significant gap between research papers and production systems that actually work.

Many machine learning projects stop after achieving acceptable validation metrics. In practice, however, deployment, optimization, monitoring, and maintainability are often the harder engineering problems.

I wanted to bridge that gap. The goal: build a production-grade system that runs on edge devices (laptops, mobile, embedded systems), not just GPU clusters.


The Architecture: Making Smart Trade-offs

Choosing the Right Backbone

I evaluated several architectures before settling on EfficientNet-B3:

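| Architecture    | Accuracy  | Size   | Trade-off                  |
|-----------------|-----------|--------|----------------------------|
| ResNet-50       | High      | 102 MB | Too heavy for edge         |
| MobileNet-V3    | Medium    | 12 MB  | Too small for medical data |
| EfficientNet-B3 | High      | 41 MB  | Sweet spot                 |
| EfficientNet-B4 | Very High | 70 MB  | Overkill for this task     |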

EfficientNet scales depth, width, and resolution in a principled way. B3 offered the best trade-off between accuracy and deployability.
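For reference, the EfficientNet paper (Tan & Le, 2019) calls this compound scaling: a single coefficient φ grows depth, width, and input resolution as α^φ, β^φ, and γ^φ, with α = 1.2, β = 1.1, γ = 1.15 found by a small grid search on B0 under the constraint α·β²·γ² ≈ 2, so compute roughly doubles per unit of φ. The sketch below only evaluates those formulas; the released checkpoints use rounded per-variant multipliers (B3: width 1.2, depth 1.4, 300 px input), so treat the φ values as illustrative rather than an exact reproduction of B3.

```python
# Compound scaling from the EfficientNet paper: one coefficient phi
# scales depth, width, and resolution together.
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15  # depth, width, resolution bases from the paper


def compound_scale(phi: float) -> dict:
    """Return depth/width/resolution multipliers and the implied FLOPs factor."""
    depth = ALPHA ** phi
    width = BETA ** phi
    resolution = GAMMA ** phi
    # Conv-net FLOPs grow ~linearly with depth and quadratically with width and resolution
    flops = depth * width ** 2 * resolution ** 2
    return {"depth": depth, "width": width, "resolution": resolution, "flops": flops}


# Multipliers used by the released EfficientNet-B3 configuration
B3_CONFIG = {"width": 1.2, "depth": 1.4, "resolution_px": 300}

for phi in range(4):
    s = compound_scale(phi)
    print(f"phi={phi}: depth x{s['depth']:.2f}, width x{s['width']:.2f}, "
          f"resolution x{s['resolution']:.2f}, FLOPs x{s['flops']:.2f}")
```

Each step in φ multiplies FLOPs by α·β²·γ² ≈ 1.92, which is what "principled" means here: the three dimensions grow in a fixed ratio instead of being tuned independently.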

Model Architecture:
├─ EfficientNet-B3 (pretrained on ImageNet)
├─ 10.7M parameters
├─ Input: 300×300 RGB images
└─ Output: 9-class logits (skin lesion types)

The Training Pipeline: Systematic Experimentation

The Class Imbalance Problem

Real medical datasets are deeply imbalanced. The ISIC dataset I used had severe disparities:

Melanocytic nevi:      ~2,000 samples  (common)
Melanoma:              ~1,100 samples  (critical)
Vascular lesions:      ~150 samples    (rare)

Solution: Per-class weighted loss

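import numpy as np
import torch
import torch.nn as nn

# Per-class counts from the training labels (assumes integer labels 0..8 in `train_labels`)
samples_per_class = np.bincount(train_labels, minlength=9)
total_samples = samples_per_class.sum()
num_classes = len(samples_per_class)

# Calculate weights inversely proportional to class frequency
weights = total_samples / (num_classes * samples_per_class)
# Apply in loss function (CrossEntropyLoss expects a float tensor of length num_classes)
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor(weights, dtype=torch.float32),
    label_smoothing=0.1  # Prevent overconfidence
)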

Combined with data augmentation, this change improved macro F1 from 0.46 → 0.53; on its own, class weighting did not help (see the ablation below).
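To make the weight formula concrete, here is a minimal stdlib sketch applied to only the three approximate counts listed above (a toy subset; the real computation runs over all nine classes):

```python
def inverse_frequency_weights(counts: dict[str, int]) -> dict[str, float]:
    """weight_c = total_samples / (num_classes * count_c)."""
    total = sum(counts.values())
    k = len(counts)
    return {name: total / (k * n) for name, n in counts.items()}


# Approximate counts from the list above (three of the nine classes only)
counts = {"Melanocytic nevi": 2000, "Melanoma": 1100, "Vascular lesions": 150}
weights = inverse_frequency_weights(counts)
for name, w in weights.items():
    print(f"{name:18s} weight = {w:.3f}")
```

The vascular class ends up weighted 2000 / 150 ≈ 13× more than nevi, and the weights are scaled so that every class contributes the same weighted total (total_samples / num_classes) to the loss.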

The Ablation Study: Finding the Real Bottleneck

I ran systematic experiments to isolate what actually matters:

| Experiment                                       | Macro F1 | Key Finding                   |
|--------------------------------------------------|----------|-------------------------------|
| Baseline (EfficientNet-B0, no augmentation)      | 0.46     | The model overfits badly      |
| + Data augmentation (flip, rotate, color jitter) | 0.51     | Helps significantly           |
| + Class weighting only                           | 0.45     | No gain on its own            |
| + Augmentation and class weighting               | 0.53     | Synergistic but plateaus      |
| + Learning rate scheduler + EfficientNet-B3      | 0.6587   | Best-performing configuration |
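Every row is scored with macro F1: per-class F1 averaged with equal weight, so a rare class such as vascular lesions counts as much as melanocytic nevi. A minimal stdlib sketch (similar in intent to scikit-learn's `f1_score(y_true, y_pred, average="macro")`, treating undefined precision or recall as 0) shows why it punishes the majority-class guessing that plain accuracy rewards; the toy labels are made up:

```python
from collections import Counter


def macro_f1(y_true: list, y_pred: list) -> float:
    """Unweighted mean of per-class F1 over classes seen in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_count = Counter(y_pred)
    true_count = Counter(y_true)
    f1_scores = []
    for c in labels:
        precision = tp[c] / pred_count[c] if pred_count[c] else 0.0
        recall = tp[c] / true_count[c] if true_count[c] else 0.0
        denom = precision + recall
        f1_scores.append(2 * precision * recall / denom if denom else 0.0)
    return sum(f1_scores) / len(f1_scores)


# Toy example: a model that always predicts the majority class
y_true = ["nevus"] * 8 + ["melanoma"] + ["vascular"]
y_pred = ["nevus"] * 10
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(f"accuracy = {accuracy:.2f}, macro F1 = {macro_f1(y_true, y_pred):.3f}")
```

The always-"nevus" model scores 0.80 accuracy but only about 0.296 macro F1, because the two rare classes each contribute an F1 of 0.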

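One of the most impactful improvements came from introducing adaptive learning rate scheduling. Note that the final row also switched to the larger EfficientNet-B3 backbone, so the table does not isolate the scheduler's individual contribution.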

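import torch.optim as optim

# Learning-rate scheduling with ReduceLROnPlateau
# (assumes `optimizer` is already built over model.parameters())
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',      # Track a metric that should decrease (validation loss)
    factor=0.1,      # Reduce LR by 10× when stuck
    patience=3,      # Tolerate 3 epochs without improvement, reduce on the next
)
# The old `verbose=True` flag is deprecated in recent PyTorch releases,
# so the learning rate is logged manually below.

# In training loop (train_epoch/validate_epoch are your own per-epoch functions)
for epoch in range(num_epochs):
    train_loss = train_epoch()
    val_loss = validate_epoch()
    scheduler.step(val_loss)  # Adaptive learning rate
    current_lr = optimizer.param_groups[0]["lr"]
    print(f"epoch {epoch}: val_loss={val_loss:.4f}, lr={current_lr:.1e}")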

This helped the model continue improving after validation loss plateaued. Combined with early stopping (patience=7), the model converged by epoch 9 instead of requiring extended training.
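The interaction between the two patience values is easy to misread, so here is a simplified, dependency-free sketch of the logic: a plateau-based LR cut (factor 0.1, patience 3, modeled on `ReduceLROnPlateau` in `mode='min'` but ignoring its `threshold` and `cooldown` options) plus early stopping with patience 7. The class name and the validation-loss sequence are hypothetical, chosen for illustration rather than taken from my training run.

```python
class PlateauTracker:
    """Simplified plateau LR reduction + early stopping (illustrative, not PyTorch's code)."""

    def __init__(self, lr: float, factor: float = 0.1,
                 lr_patience: int = 3, stop_patience: int = 7):
        self.lr = lr
        self.factor = factor
        self.lr_patience = lr_patience
        self.stop_patience = stop_patience
        self.best = float("inf")
        self.bad_since_lr_drop = 0  # epochs without improvement, reset on each LR cut
        self.bad_since_best = 0     # epochs without improvement, reset only on a new best

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_since_lr_drop = 0
            self.bad_since_best = 0
            return False
        self.bad_since_lr_drop += 1
        self.bad_since_best += 1
        # Like ReduceLROnPlateau, cut the LR once bad epochs exceed the patience
        if self.bad_since_lr_drop > self.lr_patience:
            self.lr *= self.factor
            self.bad_since_lr_drop = 0
        return self.bad_since_best >= self.stop_patience


# Hypothetical validation losses: four improving epochs, then a plateau
val_losses = [1.00, 0.85, 0.78, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.80, 0.81, 0.82]
tracker = PlateauTracker(lr=1e-3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if tracker.step(loss):
        stopped_at = epoch
        break
print(f"stopped at epoch {stopped_at}, final lr = {tracker.lr:.0e}")
```

Here the best loss is at epoch 3, the LR drops once after four stagnant epochs, and training stops seven epochs after the best. In a real loop you would also keep the checkpoint from the best epoch rather than the last one.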


ONNX Export & Optimization

Many ML projects skip this stage: they train on GPU, deploy on GPU, and call it production-ready.

Production environments, however, often need inference on CPUs, mobile devices, embedded systems, and web browsers.

The Export Process

import torch
import onnx

# 1. Move the trained model to CPU and set eval mode
#    (assumes `model` is the trained EfficientNet-B3 with its weights loaded)
model.to('cpu')  # Export from CPU (see Key Learnings)
model.eval()     # Disable dropout; use running batch-norm statistics

# 2. Create dummy input matching model expectations
dummy_input = torch.randn(1, 3, 300, 300, device='cpu')

# 3. Export to ONNX format
torch.onnx.export(
    model,
    dummy_input,
    "skin_classifier_v1.onnx",
    input_names=["input_image"],
    output_names=["logits"],
    opset_version=14,  # Good balance of ops and compatibility
    do_constant_folding=True,
    dynamic_axes={
        "input_image": {0: 'batch_size'},
        "logits": {0: 'batch_size'}
    }
)

# 4. Validate ONNX model structure
onnx_model = onnx.load("skin_classifier_v1.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is valid")

Numerical Validation

Here's the critical part most people skip:

import onnxruntime as ort
import numpy as np

# Load ONNX session
ort_session = ort.InferenceSession(
    "skin_classifier_v1.onnx",
    providers=['CPUExecutionProvider']
)

# Compare PyTorch vs ONNX on random inputs
for i in range(5):
    test_input = torch.randn(1, 3, 300, 300, device='cpu')
    
    # PyTorch inference
    with torch.no_grad():
        pytorch_output = model(test_input).detach().numpy()
    
    # ONNX inference
    onnx_output = ort_session.run(
        None, 
        {"input_image": test_input.numpy()}
    )[0]
    
    # Check equivalence
    max_diff = np.abs(pytorch_output - onnx_output).max()
    print(f"Test {i+1}: max diff = {max_diff:.6f}")
    assert max_diff < 1e-5  # Numerical equivalence 

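Result: a maximum absolute difference of 4e-6 across all five tests, below the 1e-5 threshold. This same check is what surfaced the GPU-export mismatch described under Key Learnings, which I would otherwise have missed until production.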
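A single absolute threshold like 1e-5 can be too strict for large logits and too loose for tiny ones. A combined rule, |candidate − reference| ≤ atol + rtol·|reference| (the same rule `numpy.allclose(candidate, reference)` applies), scales with output magnitude. A dependency-free sketch of that rule; the function name and default tolerances are my own choices for illustration:

```python
def compare_outputs(reference: list[float], candidate: list[float],
                    rtol: float = 1e-4, atol: float = 1e-5) -> tuple[bool, float]:
    """Element-wise |candidate - reference| <= atol + rtol * |reference|.

    Returns (all_close, max_abs_diff). NaNs fail, since NaN comparisons are False.
    """
    if len(reference) != len(candidate):
        raise ValueError("outputs have different lengths")
    max_diff = 0.0
    all_close = True
    for ref, cand in zip(reference, candidate):
        diff = abs(cand - ref)
        max_diff = max(max_diff, diff)
        if not diff <= atol + rtol * abs(ref):
            all_close = False
    return all_close, max_diff


# Hypothetical PyTorch vs ONNX logits for one image
ok, worst = compare_outputs([2.31, -0.75, 10.5], [2.310004, -0.750002, 10.5003])
print(ok, f"{worst:.1e}")
```

A plain `max_diff < 1e-5` check would reject the 10.5 logit's 3e-4 gap, even though it is a relative error of under 0.003%.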

The Impact: Numbers That Matter

PyTorch Model:       41 MB
ONNX Model:          0.9 MB
━━━━━━━━━━━━━━━━━━━━━━━━
Compression:         45× smaller 

Inference Latency (CPU):
• PyTorch:           ~120 ms/image
• ONNX Runtime:      ~43 ms/image  
Speed difference:    2.8× faster

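Memory footprint:    Reduced from 150 MB → 12 MB

One caveat on the size figures: 10.7M float32 parameters take roughly 43 MB, so a 0.9 MB .onnx file cannot hold every weight inline. If the export wrote the weights to a separate external-data file, the 0.9 MB covers only the graph, and the deployable model is that file plus its weights.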
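Parameter count times bytes per parameter gives a quick lower bound on how large a self-contained model file must be. A minimal sketch using the 10.7M figure from the architecture summary (the helper name is mine):

```python
def weight_bytes(num_params: int, bytes_per_param: int) -> int:
    """Raw storage for the weights alone (ignores graph structure and metadata)."""
    return num_params * bytes_per_param


PARAMS = 10_700_000  # EfficientNet-B3 with a 9-class head, per the summary above

for dtype, nbytes in [("float32", 4), ("float16", 2), ("int8", 1)]:
    size = weight_bytes(PARAMS, nbytes)
    print(f"{dtype:8s} {size / 1e6:5.1f} MB ({size / 2**20:5.1f} MiB)")
```

float32 lands at about 42.8 MB (40.8 MiB), in line with the 41 MB PyTorch checkpoint; even int8 quantization would still need about 10.7 MB for the weights.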

The ONNX model can be deployed across multiple environments, including:

  • Mobile devices

  • Embedded systems

  • Web applications

  • CPU-based servers

Building the Production System

System Architecture

┌─────────────────────────────────────────────┐
│ User: Upload Dermoscopic Image              │
└──────────────────────┬──────────────────────┘
                       │
┌──────────────────────▼──────────────────────┐
│ FastAPI Server (main.py)                    │
├─────────────────────────────────────────────┤
│ • File validation (size, format)            │
│ • Rate limiting (60 req/min)                │
│ • Error handling & logging                  │
└──────────────────────┬──────────────────────┘
                       │
┌──────────────────────▼──────────────────────┐
│ ONNX Inference Engine (inference.py)        │
├─────────────────────────────────────────────┤
│ 1. Preprocess image                         │
│    ├─ OpenCV decode                         │
│    ├─ Resize to 300×300                     │
│    └─ Normalize (ImageNet stats)            │
│ 2. ONNX Runtime inference                   │
│ 3. Softmax → probabilities                  │
└──────────────────────┬──────────────────────┘
                       │
┌──────────────────────▼──────────────────────┐
│ Clinical Decision Support                   │
├─────────────────────────────────────────────┤
│ • Flag low confidence (< 50%)               │
│ • Flag melanoma suspicion                   │
│ • Recommend dermatologist review            │
└──────────────────────┬──────────────────────┘
                       │
┌──────────────────────▼──────────────────────┐
│ JSON Response + Streamlit UI                │
│                                             │
│ {"prediction": "Melanoma",                  │
│  "confidence": 0.82,                        │
│  "requires_dermatologist_review": true}     │
└─────────────────────────────────────────────┘
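The diagram lists rate limiting at 60 requests per minute without showing how it is enforced. One common approach is a per-client sliding-window log; the sketch below is a hypothetical in-process version (class and method names are mine) that only works within a single server process. A multi-worker deployment would need shared state, such as Redis or a limiter at the reverse proxy.

```python
import time
from collections import defaultdict, deque


class SlidingWindowLimiter:
    """Allow at most `limit` requests per `window` seconds for each client key."""

    def __init__(self, limit: int = 60, window: float = 60.0, clock=time.monotonic):
        self.limit = limit
        self.window = window
        self.clock = clock  # injectable so the behavior can be tested without sleeping
        self._hits: dict[str, deque] = defaultdict(deque)

    def allow(self, client_key: str) -> bool:
        now = self.clock()
        hits = self._hits[client_key]
        # Drop timestamps that have fallen out of the window
        while hits and now - hits[0] >= self.window:
            hits.popleft()
        if len(hits) >= self.limit:
            return False
        hits.append(now)
        return True


# Usage with a fake clock: 61 requests at t=0, then one more a minute later
fake_now = [0.0]
limiter = SlidingWindowLimiter(limit=60, window=60.0, clock=lambda: fake_now[0])
allowed = sum(limiter.allow("10.0.0.1") for _ in range(61))
fake_now[0] = 60.0
later = limiter.allow("10.0.0.1")
print(allowed, later)
```

In a FastAPI endpoint, the key could be `request.client.host`, returning HTTP 429 whenever `allow()` returns False.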

Robust Image Preprocessing

import cv2
import numpy as np

# ImageNet channel statistics (RGB order); float32 keeps the math in float32
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


class SkinInference:
    def preprocess(self, image_bytes: bytes) -> np.ndarray:
        """Convert raw bytes to model input with validation."""
        
        try:
            # Validate input
            if len(image_bytes) == 0:
                raise ValueError("Image is empty")
            
            # Decode image (handles JPEG, PNG, etc.); OpenCV returns BGR
            nparr = np.frombuffer(image_bytes, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            
            if img is None:
                raise ValueError(
                    "Failed to decode image. "
                    "Ensure it's a valid JPEG/PNG file."
                )
            
            # Convert BGR -> RGB (the ImageNet statistics are in RGB order)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            
            # Resize to model input size
            img = cv2.resize(img, (300, 300))
            
            # Normalize using ImageNet statistics
            img = img.astype(np.float32) / 255.0
            img = (img - IMAGENET_MEAN) / IMAGENET_STD
            
            # Convert HWC to CHW and add batch dimension
            img = np.transpose(img, (2, 0, 1))
            img = np.expand_dims(img, axis=0)
            
            # Return a C-contiguous float32 array, safe to pass to ONNX Runtime
            return np.ascontiguousarray(img, dtype=np.float32)
        
        except ValueError as e:
            raise ValueError(f"Preprocessing failed: {e}") from e
        except Exception as e:
            raise RuntimeError(
                f"Unexpected error during preprocessing: {e}"
            ) from e
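Because every pixel lies in [0, 1] after the division by 255, ImageNet normalization maps each channel to a fixed interval, [−mean/std, (1 − mean)/std]. A quick stdlib calculation of those bounds (the helper name is mine) shows why the `-5 < … < 5` sanity check in the tests has plenty of headroom:

```python
# ImageNet channel statistics used in preprocess() (RGB order)
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalized_range(mean: float, std: float) -> tuple[float, float]:
    """Output interval for inputs in [0, 1] under (x - mean) / std."""
    return (0.0 - mean) / std, (1.0 - mean) / std


for channel, m, s in zip("RGB", MEAN, STD):
    lo, hi = normalized_range(m, s)
    print(f"{channel}: [{lo:+.3f}, {hi:+.3f}]")
```

All three channels stay within roughly [−2.12, +2.64]. These bounds hold whatever the channel order, so a range check cannot catch a BGR/RGB mix-up; only comparing against the training-time transforms can.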

Clinical Decision Logic

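from fastapi import FastAPI, File, HTTPException, UploadFile

app = FastAPI()

# Assumes `predictor` is a SkinInference instance created once at startup.
# Hypothetical size limit for illustration; tune it for your deployment.
MAX_UPLOAD_BYTES = 10 * 1024 * 1024


@app.post("/predict")
async def predict_lesion(file: UploadFile = File(...)):
    """Classify skin lesion with clinical guidance."""
    
    # Validate file type (client-declared; preprocess() re-checks by decoding)
    if file.content_type not in {"image/jpeg", "image/png"}:
        raise HTTPException(
            status_code=400,
            detail="Invalid file type. Use JPEG or PNG."
        )
    
    image_bytes = await file.read()
    if len(image_bytes) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="File too large.")
    
    # Run inference; empty or undecodable files surface as ValueError
    try:
        result = predictor.predict_with_metadata(image_bytes)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    
    # Clinical decision support
    needs_dermatologist = bool(
        result["is_uncertain"] or  # Low confidence?
        result["prediction"] == "Melanoma"  # High-risk lesion?
    )
    
    return {
        "filename": file.filename,
        "prediction": result["prediction"],
        # float()/bool() convert NumPy scalars (if any) so the response serializes to JSON
        "confidence": round(float(result["confidence"]), 4),
        "is_uncertain": bool(result["is_uncertain"]),
        "requires_dermatologist_review": needs_dermatologist,
        "clinical_note": (
            "High-risk or uncertain prediction. "
            "Recommend immediate dermatologist consultation."
            if needs_dermatologist
            else "Monitor for changes over time."
        ),
        "all_probabilities": {
            name: round(float(prob), 4)
            for name, prob in result["all_predictions"]
        }
    }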
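The endpoint relies on `predictor.predict_with_metadata`, which isn't shown. Here is a hypothetical sketch of the post-processing such a method could perform, assuming the ONNX session returns one row of logits and using the 50% low-confidence threshold from the architecture diagram. The function name and the class names passed in are assumptions; the dict keys mirror the ones the endpoint reads.

```python
import math


def postprocess(logits: list[float], class_names: list[str],
                uncertainty_threshold: float = 0.5) -> dict:
    """Turn one row of logits into the fields read by the /predict endpoint."""
    if len(logits) != len(class_names):
        raise ValueError("logits and class_names must have the same length")
    # Numerically stable softmax: subtract the max logit before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # (name, probability) pairs sorted from most to least likely
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    top_name, top_prob = ranked[0]
    return {
        "prediction": top_name,
        "confidence": top_prob,
        "is_uncertain": top_prob < uncertainty_threshold,
        "all_predictions": ranked,
    }


# Toy 3-class example (the real model has 9 classes, in training label order)
result = postprocess([2.0, 1.0, 0.1], ["Melanoma", "Nevus", "Vascular lesion"])
print(result["prediction"], round(result["confidence"], 3), result["is_uncertain"])
```

Returning `all_predictions` as a list of (name, probability) tuples matches the shape the endpoint's dictionary comprehension unpacks, and the same sorted list is what the inference tests check for ordering and summing to 1.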

Testing & Validation

Comprehensive tests help identify issues before deployment:

import cv2
import numpy as np
import pytest
from app.inference import SkinInference

MODEL_PATH = "models/skin_classifier_v1.onnx"

# Fixtures shown inline for completeness; they could also live in conftest.py


@pytest.fixture(scope="module")
def inference():
    # Create the ONNX session once and share it across tests in this module
    return SkinInference(MODEL_PATH)


@pytest.fixture
def valid_image_bytes():
    # Synthetic 400×400 RGB image encoded as PNG; only shape, dtype, and range are checked
    rng = np.random.default_rng(0)
    img = rng.integers(0, 256, size=(400, 400, 3), dtype=np.uint8)
    ok, buf = cv2.imencode(".png", img)
    assert ok
    return buf.tobytes()


@pytest.fixture
def empty_bytes():
    return b""


@pytest.fixture
def invalid_bytes():
    return b"definitely not an image"


class TestPreprocessing:
    def test_valid_image_preprocessing(self, inference, valid_image_bytes):
        """Does preprocessing handle valid images correctly?"""
        processed = inference.preprocess(valid_image_bytes)
        
        assert processed.shape == (1, 3, 300, 300)
        assert processed.dtype == np.float32
        assert -5 < processed.min() < processed.max() < 5
    
    def test_empty_image_raises_error(self, inference, empty_bytes):
        """Does it reject empty files?"""
        with pytest.raises(ValueError, match="Image is empty"):
            inference.preprocess(empty_bytes)
    
    def test_invalid_format_raises_error(self, inference, invalid_bytes):
        """Does it reject corrupted images?"""
        with pytest.raises(ValueError, match="Failed to decode"):
            inference.preprocess(invalid_bytes)


class TestInference:
    def test_predictions_are_valid(self, inference, valid_image_bytes):
        """Are predictions normalized and sorted?"""
        predictions = inference.predict(valid_image_bytes)
        
        # Should return all 9 classes as (name, probability) pairs
        assert len(predictions) == 9
        
        # Should be sorted by confidence
        confidences = [p[1] for p in predictions]
        assert confidences == sorted(confidences, reverse=True)
        
        # Should sum to 1.0 (softmax property)
        assert 0.99 < sum(confidences) < 1.01

Test Coverage:

  • Image preprocessing (JPEG/PNG, edge cases)

  • Inference correctness (softmax, ordering)

  • ONNX validation (PyTorch ↔ ONNX equivalence)

  • API error handling (invalid files, oversized images)

  • Batch processing limits

Result: 6/6 tests passing


Deployment: Making It Real

Option 1: Local FastAPI Server

# Install dependencies
pip install -r requirements.txt

# Start server
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000

# Interactive API at http://localhost:8000/docs

Option 2: Docker Container

FROM python:3.11-slim

WORKDIR /app

# If requirements.txt pins opencv-python, this slim image lacks the libGL it
# needs; opencv-python-headless avoids that system dependency.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8000

CMD ["python", "-m", "uvicorn", "app.main:app", \
     "--host", "0.0.0.0", "--port", "8000"]

# Build and run
docker build -t skin-lesion-api:v1 .
docker run -p 8000:8000 skin-lesion-api:v1

Option 3: Streamlit Interactive UI

streamlit run app/ui.py

The Streamlit interface provides:

  • Drag-and-drop image upload

  • Real-time predictions

  • Confidence visualization with bar charts

  • Clinical decision support messaging

  • Responsive mobile design


Key Learnings: Mistakes That Taught Me

1. ONNX Export Is Deceptively Fragile

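# WRONG (in my setup): exporting and comparing on the GPU
model.to('cuda')
torch.onnx.export(model, dummy_input.to('cuda'), "model.onnx")
# Result: ONNX outputs that did not match the GPU PyTorch reference

# CORRECT: export (and compare) on CPU
model.to('cpu')
model.eval()
torch.onnx.export(model, dummy_input, "model.onnx")
# Result: Stable, reproducible exports

I discovered this the hard way when ONNX outputs differed significantly from PyTorch on the GPU. Exporting from the CPU, and running the PyTorch reference on the CPU as well, made the comparison reproducible. Part of such a gap can come from GPU numerics in the PyTorch reference rather than from the exported graph (for example, cuDNN convolutions may use TF32 on Ampere and newer GPUs), so compare like with like.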

2. Validation Caught What Testing Missed

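Shape, dtype, and error-handling tests would all pass on a model whose outputs had quietly drifted from PyTorch's. The PyTorch ↔ ONNX comparison is the check that surfaced the GPU-export mismatch above, and for the final CPU export it reported a maximum difference of only 4e-6, which gave me confidence that the exported model behaves like the original.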

3. Medical Data Has Unique Challenges

Standard ML practices don't always transfer to healthcare:

  • Class imbalance: Some lesion types are over 10× rarer than others (~150 vascular lesions vs. ~2,000 nevi)

  • Dataset bias: ISIC underrepresents darker skin tones

  • High stakes: Misclassification can harm real people

  • Regulatory: Real deployments must comply with privacy rules such as HIPAA (US) and GDPR (EU)

Solutions I implemented:

  • Weighted loss functions to counter class imbalance

  • Stratified splits for representative validation

  • Conservative uncertainty thresholds

  • Clear recommendations to consult dermatologists

  • Comprehensive error logging
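Stratified splitting keeps each class's share the same in training and validation, which matters when a class has only ~150 examples. A dependency-free sketch (scikit-learn's `train_test_split(..., stratify=labels)` does the same job in practice); the function name, 20% validation fraction, and seed are illustrative choices, and the labels reuse the approximate counts from earlier:

```python
import random
from collections import defaultdict


def stratified_split(labels: list[str], val_fraction: float = 0.2, seed: int = 0):
    """Return (train_idx, val_idx) with each class split in the same proportion."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)
    train_idx, val_idx = [], []
    for _, indices in sorted(by_class.items()):  # sorted for a deterministic order
        rng.shuffle(indices)
        n_val = round(len(indices) * val_fraction)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


labels = ["nevus"] * 2000 + ["melanoma"] * 1100 + ["vascular"] * 150
train_idx, val_idx = stratified_split(labels)
val_counts = {c: sum(labels[i] == c for i in val_idx) for c in ("nevus", "melanoma", "vascular")}
print(val_counts)
```

Every class contributes exactly 20% of its examples to validation (400 / 220 / 30), whereas a plain random split could leave the vascular class noticeably over- or under-represented.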

4. Inference ≠ Training

Most ML courses focus on training. Real jobs focus on:

  • Latency: Can it respond in < 100ms?

  • Throughput: How many requests/second?

  • Memory: Will it fit in production constraints?

  • Cold start: How long to load the model?

  • Monitoring: Can you detect performance degradation?

I designed with all of these in mind, not just accuracy: latency and memory are measured above, while monitoring is still on the roadmap below.
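Measuring latency the way a server experiences it means warming up first and reporting percentiles rather than a single average. A minimal, dependency-free harness: `run_once` is a placeholder for whatever call you time (for example a lambda wrapping `ort_session.run` on a preprocessed image), and the warmup and run counts are arbitrary assumptions.

```python
import statistics
import time
from typing import Callable


def benchmark(run_once: Callable[[], object], warmup: int = 10, runs: int = 100) -> dict:
    """Time `run_once` after a warmup and report latency statistics in milliseconds."""
    for _ in range(warmup):
        run_once()  # let caches, thread pools, and lazy initialization settle
    samples_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        run_once()
        samples_ms.append((time.perf_counter() - start) * 1000)
    samples_ms.sort()
    cuts = statistics.quantiles(samples_ms, n=100)  # cuts[k - 1] ≈ k-th percentile
    return {
        "p50_ms": statistics.median(samples_ms),
        "p95_ms": cuts[94],
        "max_ms": samples_ms[-1],
        # Sequential, single-client throughput; concurrent serving can differ
        "throughput_rps": 1000 / statistics.mean(samples_ms),
    }


# Stand-in workload; replace with your own inference call
stats = benchmark(lambda: sum(i * i for i in range(10_000)), warmup=3, runs=50)
print({k: round(v, 2) for k, v in stats.items()})
```

Cold start can be measured separately by timing the `InferenceSession` constructor once, since it is paid per process rather than per request.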


What's Next: The Road Ahead

Future improvements I'm planning:

# Model improvements
[ ] Ensemble multiple models for higher accuracy
[ ] Fine-tune on specialized dermatology datasets
[ ] Add confidence calibration (temperature scaling)

# Explainability
[ ] Grad-CAM heatmaps showing important regions
[ ] LIME for local interpretable explanations
[ ] Attention visualization

# Production hardening
[ ] GPU optimization with TensorRT
[ ] Model versioning & A/B testing
[ ] Performance monitoring & alerting
[ ] Drift detection & retraining triggers

# Deployment
[ ] Mobile app (iOS/Android with CoreML/TensorFlow Lite)
[ ] Web browser inference (ONNX Runtime Web)
[ ] Edge device optimization (TensorFlow Lite, OpenVINO)
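Temperature scaling, listed above, fits a single scalar T on held-out validation logits and divides all logits by T before the softmax. Because dividing by a positive T never changes which class wins, it only adjusts the confidence values, which matters here because the 50% uncertainty flag reads them. T is typically optimized with a gradient-based method; the dependency-free sketch below swaps in a simple grid search over T that minimizes negative log-likelihood, and the grid and toy data are assumptions for illustration.

```python
import math


def nll(logits_batch: list[list[float]], labels: list[int], temperature: float) -> float:
    """Mean negative log-likelihood of the true labels after dividing logits by T."""
    total = 0.0
    for logits, y in zip(logits_batch, labels):
        scaled = [z / temperature for z in logits]
        m = max(scaled)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in scaled))
        total += log_sum_exp - scaled[y]
    return total / len(labels)


def fit_temperature(logits_batch, labels, grid=None) -> float:
    """Pick the temperature on a grid that minimizes validation NLL."""
    grid = grid or [0.5 + 0.05 * i for i in range(91)]  # 0.5 to 5.0
    return min(grid, key=lambda t: nll(logits_batch, labels, t))


# Toy validation set: confident logits, but the third prediction is wrong
val_logits = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [4.0, 0.0, 0.0], [0.0, 0.0, 4.0]]
val_labels = [0, 1, 1, 2]
T = fit_temperature(val_logits, val_labels)
print(f"T = {T:.2f}, NLL before = {nll(val_logits, val_labels, 1.0):.3f}, "
      f"after = {nll(val_logits, val_labels, T):.3f}")
```

On this toy set the fitted T lands near 2.2, where the top-class probability drops from about 0.96 to about 0.75, matching the 75% accuracy of the toy predictions.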

Lessons from Moving Beyond Model Training

This project taught me something crucial that most ML students miss:

Training a model is only one part of the engineering challenge. Building a reliable, deployable, and maintainable system requires a different set of considerations.

Academic ML involves:

  • Clean datasets

  • GPU access

  • Simple metrics

  • No deployment concerns

Production ML requires:

  • Data validation and cleaning

  • Error handling for edge cases

  • Multiple metrics (latency, throughput, fairness)

  • Reproducible deployments

  • Monitoring and alerting

  • Clear documentation

This project reinforced an important lesson: model performance is only one part of a successful machine learning system. Robust preprocessing, validation, deployment, monitoring, and maintainability are equally important when transitioning from experimentation to production.


Code & Resources

Repository:

Key References:


Let's Connect

Have questions about this approach? Interested in discussing production ML systems?

I'm happy to discuss:

  • ONNX optimization strategies

  • Medical AI systems

  • Production ML architecture

  • Model deployment patterns

  • Inference optimization


Last updated: June 2026

Tags: #MachineLearning #MLEngineering #ONNX #FastAPI #ProductionML #MedicalAI #DeepLearning #Python #InferenceOptimization #SystemsDesign