
ML System Design

ML Systems vs Traditional Software

Building ML systems requires all the discipline of traditional software engineering plus additional complexity around data, training, and model behavior.

Traditional Software:
┌──────────┐     ┌──────────┐     ┌──────────┐
│   Code   │────▶│  Build   │────▶│  Deploy  │
└──────────┘     └──────────┘     └──────────┘
Deterministic. Same input → same output.

ML Systems:
┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐
│   Data   │─▶│ Features │─▶│  Train   │─▶│ Evaluate │─▶│  Deploy  │
└──────────┘  └──────────┘  └──────────┘  └──────────┘  └──────────┘
     │                           │             │             │
     ▼                           ▼             ▼             ▼
Data quality    Hyperparameter   Model quality   Model decay
Data drift      tuning           Bias/Fairness   Monitoring
Data versioning Experiment       A/B testing     Retraining
                tracking

Probabilistic. Same input → similar (not identical) output.
Model behavior changes with data. Degrades over time.

What Makes ML Systems Hard

| Challenge | Description |
|-----------|-------------|
| Data dependencies | Models are only as good as their data; bad data breaks models silently |
| Feedback loops | Model predictions influence future data, creating cycles |
| Training-serving skew | Differences between how features are computed in training vs serving |
| Silent failures | A model can return plausible but wrong predictions without crashing |
| Reproducibility | Results depend on data, code, hyperparameters, random seeds, and hardware |
| Model decay | Real-world data distributions shift, degrading model performance over time |
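Training-serving skew is worth a concrete illustration. In this hedged sketch (the feature name and both implementations are invented for the example), two teams implement the same "days since signup" feature and disagree on whether the signup day counts, so serving silently feeds the model inputs shifted by one from what it was trained on:

```python
from datetime import date

def days_since_signup_training(signup: date, today: date) -> int:
    # Training pipeline counts the signup day itself...
    return (today - signup).days + 1

def days_since_signup_serving(signup: date, today: date) -> int:
    # ...while the serving reimplementation does not
    return (today - signup).days

signup, today = date(2024, 1, 1), date(2024, 1, 31)
print(days_since_signup_training(signup, today))  # 31
print(days_since_signup_serving(signup, today))   # 30
```

Neither version crashes or logs an error; the model simply sees a feature distribution it never saw in training. Centralizing feature logic, as feature stores do, is the standard defense.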

Feature Stores

A feature store is a centralized repository for storing, managing, and serving ML features. It solves the critical problem of training-serving skew — ensuring the same feature computation logic is used in both training and inference.

Without Feature Store:
┌──────────┐                     ┌──────────┐
│ Training │  Feature logic A    │ Serving  │  Feature logic B
│ Pipeline │  (Python script)    │ Service  │  (Java service)
└──────────┘                     └──────────┘
⚠ Different implementations → different results!

With Feature Store:
┌──────────┐     ┌──────────────┐     ┌──────────┐
│ Training │◀───▶│              │◀───▶│ Serving  │
│ Pipeline │     │   Feature    │     │ Service  │
└──────────┘     │    Store     │     └──────────┘
                 │              │
┌──────────┐     │  - Offline   │     ┌──────────┐
│ Feature  │────▶│    store     │────▶│  Model   │
│ Pipeline │     │  - Online    │     │ Training │
└──────────┘     │    store     │     └──────────┘
                 │  - Registry  │
                 └──────────────┘
Same features for training and serving.

Feature Store Architecture

┌───────────────────────────────────────────────────────┐
│                     Feature Store                     │
│                                                       │
│  ┌─────────────────┐     ┌─────────────────────────┐  │
│  │    Feature      │     │      Offline Store      │  │
│  │    Registry     │     │  (Historical features)  │  │
│  │                 │     │  - Data warehouse       │  │
│  │  - Definitions  │     │  - Parquet files        │  │
│  │  - Metadata     │     │  - Used for training    │  │
│  │  - Lineage      │     └─────────────────────────┘  │
│  │  - Versioning   │                                  │
│  └─────────────────┘     ┌─────────────────────────┐  │
│                          │      Online Store       │  │
│                          │  (Low-latency serving)  │  │
│                          │  - Redis / DynamoDB     │  │
│                          │  - Latest feature values│  │
│                          │  - Used for inference   │  │
│                          └─────────────────────────┘  │
└───────────────────────────────────────────────────────┘
# Using Feast (open-source feature store)
from datetime import timedelta

from feast import Entity, FeatureStore, FeatureView, Field, FileSource
from feast.types import Float64, Int64

# Define a data source
driver_stats_source = FileSource(
    path="data/driver_stats.parquet",
    timestamp_field="event_timestamp",
)

# Define an entity (the primary key)
driver = Entity(
    name="driver_id",
    description="Driver identifier",
)

# Define a feature view
driver_stats_fv = FeatureView(
    name="driver_hourly_stats",
    entities=[driver],
    ttl=timedelta(days=1),
    schema=[
        Field(name="conv_rate", dtype=Float64),
        Field(name="acc_rate", dtype=Float64),
        Field(name="avg_daily_trips", dtype=Int64),
    ],
    source=driver_stats_source,
)

# Initialize the feature store
store = FeatureStore(repo_path="feature_repo/")

# Get training data (offline store)
training_df = store.get_historical_features(
    entity_df=entity_df,  # DataFrame with driver_id + timestamps
    features=[
        "driver_hourly_stats:conv_rate",
        "driver_hourly_stats:acc_rate",
        "driver_hourly_stats:avg_daily_trips",
    ],
).to_df()

# Get real-time features for serving (online store)
online_features = store.get_online_features(
    features=[
        "driver_hourly_stats:conv_rate",
        "driver_hourly_stats:acc_rate",
    ],
    entity_rows=[{"driver_id": 1001}],
).to_dict()

print(online_features)
# {'driver_id': [1001], 'conv_rate': [0.87], 'acc_rate': [0.93]}

Model Serving

Model serving is how you make predictions available to applications. The two primary patterns are batch inference and real-time inference.

Batch vs Real-Time Serving

Batch Inference:
┌──────────┐    ┌──────────┐    ┌───────────────────┐
│ All data │───▶│  Model   │───▶│    Predictions    │
│ (daily)  │    │          │    │ stored in DB/cache│
└──────────┘    └──────────┘    └───────────────────┘
Run periodically (hourly/daily). Pre-compute predictions.

Real-Time Inference:
┌──────────┐    ┌──────────┐    ┌──────────┐
│   API    │───▶│  Model   │───▶│ Response │
│ Request  │    │  Server  │    │ (< 100ms)│
└──────────┘    └──────────┘    └──────────┘
On-demand. Compute predictions per request.
| Aspect | Batch Inference | Real-Time Inference |
|--------|-----------------|---------------------|
| Latency | Minutes to hours | Milliseconds |
| Throughput | High (bulk processing) | Lower (per-request) |
| Freshness | Predictions may be stale | Always up-to-date |
| Cost | Cheaper (bulk compute) | More expensive (always-on) |
| Complexity | Simpler infrastructure | Requires model servers, load balancing |
| Use cases | Recommendations, reports, email campaigns | Search ranking, fraud detection, chatbots |
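A minimal batch-inference sketch makes the trade-off concrete (all names here are hypothetical, and `score` stands in for a real `model.predict`): a scheduled job scores every user up front, and the serving path becomes a plain key-value lookup.

```python
def score(features: list) -> int:
    # Stand-in for model.predict on one row
    return 1 if sum(features) > 1.0 else 0

def run_batch_job(user_features: dict) -> dict:
    # Nightly job: pre-compute a prediction for every known user
    return {user_id: score(f) for user_id, f in user_features.items()}

# "Database" of pre-computed predictions, refreshed on a schedule
predictions = run_batch_job({"u1": [0.9, 0.4], "u2": [0.1, 0.2]})

def serve(user_id: str):
    # Serving is now a lookup -- fast and cheap, but stale until
    # the next batch run, and None for users the job never saw
    return predictions.get(user_id)

print(serve("u1"))  # 1
print(serve("u2"))  # 0
```

The `None` case is the batch pattern's weak spot: users created after the last run have no prediction, which is one reason latency-sensitive or cold-start-heavy use cases push toward real-time serving.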

Real-Time Serving Architectures

# FastAPI model serving
from typing import List

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI(title="ML Model API")

# Load model at startup
model = None

@app.on_event("startup")
async def load_model():
    global model
    model = joblib.load("models/fraud_detector_v2.pkl")

class PredictionRequest(BaseModel):
    features: List[float]
    user_id: str

class PredictionResponse(BaseModel):
    prediction: int
    probability: float
    model_version: str

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    features = np.array(request.features).reshape(1, -1)
    prediction = int(model.predict(features)[0])
    probability = float(model.predict_proba(features)[0].max())
    return PredictionResponse(
        prediction=prediction,
        probability=probability,
        model_version="v2.1.0",
    )

@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": model is not None}

A/B Testing for ML

A/B testing ML models is more complex than testing UI changes because model performance depends on many factors and may take longer to measure.

ML A/B Testing Architecture:
┌────────────────────────────────────────────────┐
│                 Traffic Router                 │
│            (based on user_id hash)             │
└──────────┬────────────────────┬────────────────┘
           │ 50%                │ 50%
           ▼                    ▼
┌──────────────────┐   ┌──────────────────┐
│     Model A      │   │     Model B      │
│    (Control)     │   │   (Challenger)   │
│      v2.0        │   │      v2.1        │
└────────┬─────────┘   └────────┬─────────┘
         │                      │
         ▼                      ▼
┌──────────────────────────────────────────┐
│            Metrics Collection            │
│                                          │
│  - Prediction latency                    │
│  - Business metrics (CTR, revenue)       │
│  - Model quality (accuracy, F1)          │
│  - Error rates                           │
│  - User engagement                       │
└──────────────────┬───────────────────────┘
                   ▼
┌──────────────────────────────────────────┐
│           Statistical Analysis           │
│                                          │
│  - Statistical significance test         │
│  - Confidence intervals                  │
│  - Effect size calculation               │
│  - Guard rail metrics check              │
└──────────────────────────────────────────┘
import hashlib
from dataclasses import dataclass

import numpy as np
from scipy import stats

@dataclass
class Experiment:
    name: str
    control_model: str
    treatment_model: str
    traffic_split: float = 0.5  # 50/50

class ABTestRouter:
    def __init__(self, experiment: Experiment):
        self.experiment = experiment
        self.control_metrics = []
        self.treatment_metrics = []

    def get_variant(self, user_id: str) -> str:
        """Deterministic assignment based on user ID."""
        hash_val = int(
            hashlib.sha256(
                f"{self.experiment.name}:{user_id}".encode()
            ).hexdigest(),
            16,
        )
        if (hash_val % 100) / 100 < self.experiment.traffic_split:
            return "control"
        return "treatment"

    def get_model(self, user_id: str) -> str:
        variant = self.get_variant(user_id)
        if variant == "control":
            return self.experiment.control_model
        return self.experiment.treatment_model

    def record_metric(self, user_id: str, metric_value: float):
        variant = self.get_variant(user_id)
        if variant == "control":
            self.control_metrics.append(metric_value)
        else:
            self.treatment_metrics.append(metric_value)

    def analyze_results(self) -> dict:
        """Run statistical significance test."""
        control = np.array(self.control_metrics)
        treatment = np.array(self.treatment_metrics)

        # Two-sample t-test
        t_stat, p_value = stats.ttest_ind(control, treatment)

        # Effect size (Cohen's d)
        pooled_std = np.sqrt(
            (control.std() ** 2 + treatment.std() ** 2) / 2
        )
        cohens_d = (treatment.mean() - control.mean()) / pooled_std

        return {
            "control_mean": float(control.mean()),
            "treatment_mean": float(treatment.mean()),
            "relative_improvement": float(
                (treatment.mean() - control.mean())
                / control.mean() * 100
            ),
            "p_value": float(p_value),
            "is_significant": p_value < 0.05,
            "cohens_d": float(cohens_d),
            "control_n": len(control),
            "treatment_n": len(treatment),
        }

# Usage
experiment = Experiment(
    name="fraud_model_v2.1",
    control_model="models/fraud_v2.0",
    treatment_model="models/fraud_v2.1",
    traffic_split=0.5,
)
router = ABTestRouter(experiment)

# In your API endpoint (load_and_predict is an application-specific helper)
model_path = router.get_model(user_id="user_123")
prediction = load_and_predict(model_path, features)

# Record business metric later
router.record_metric("user_123", metric_value=1.0)

Model Monitoring and Drift Detection

Models degrade over time as the real world changes. Monitoring detects these problems before they impact users.

Types of Drift

Data Drift (Feature Drift):
The distribution of input features changes.

Training Data            Production Data
┌──────────────┐         ┌──────────────┐
│     ___      │         │        ___   │
│    /   \     │         │       /   \  │
│   /     \    │    →    │      /     \ │
│  /       \   │         │     /       \│
│ /         \  │         │    /         │
└──────────────┘         └──────────────┘
Distribution shifted to the right.

Concept Drift:
The relationship between features and target changes.
Before: Income > $80K → Low fraud risk
After:  Income > $80K → Same fraud risk as others
(Fraud patterns evolved)

Prediction Drift:
The distribution of model predictions changes.
Before: 5% flagged as fraud
After:  15% flagged as fraud
(Model behavior changed)
| Drift Type | What Changes | Detection Method |
|------------|--------------|------------------|
| Data drift | Input feature distributions | KS test, PSI, JS divergence |
| Concept drift | Feature-target relationship | Monitor prediction accuracy over time |
| Prediction drift | Output distribution | Compare prediction histograms |
| Label drift | Target variable distribution | Monitor ground truth labels |
import numpy as np
from scipy import stats
from dataclasses import dataclass
from datetime import datetime
from typing import List

@dataclass
class DriftReport:
    feature_name: str
    drift_score: float
    is_drifted: bool
    method: str
    timestamp: datetime

class ModelMonitor:
    def __init__(self, reference_data: dict, threshold=0.05):
        """
        reference_data: baseline feature distributions
            from training data
        threshold: p-value threshold for drift detection
        """
        self.reference = reference_data
        self.threshold = threshold
        self.alerts = []

    def check_data_drift(
        self,
        current_data: dict,
        method: str = "ks",
    ) -> List[DriftReport]:
        """Check each feature for distribution drift."""
        reports = []
        for feature_name in self.reference:
            ref = np.array(self.reference[feature_name])
            curr = np.array(current_data[feature_name])

            if method == "ks":
                # Kolmogorov-Smirnov test
                statistic, p_value = stats.ks_2samp(ref, curr)
                drift_score = statistic
                is_drifted = p_value < self.threshold
            elif method == "psi":
                # Population Stability Index
                drift_score = self._calculate_psi(ref, curr)
                is_drifted = drift_score > 0.2
            else:
                raise ValueError(f"Unknown drift method: {method}")

            report = DriftReport(
                feature_name=feature_name,
                drift_score=drift_score,
                is_drifted=is_drifted,
                method=method,
                timestamp=datetime.now(),
            )
            reports.append(report)
            if is_drifted:
                self.alerts.append(report)
        return reports

    def check_prediction_drift(
        self,
        reference_preds: np.ndarray,
        current_preds: np.ndarray,
    ) -> DriftReport:
        """Check if model predictions have shifted."""
        statistic, p_value = stats.ks_2samp(
            reference_preds, current_preds
        )
        return DriftReport(
            feature_name="predictions",
            drift_score=statistic,
            is_drifted=p_value < self.threshold,
            method="ks",
            timestamp=datetime.now(),
        )

    def check_performance_degradation(
        self,
        recent_accuracy: float,
        baseline_accuracy: float,
        tolerance: float = 0.05,
    ) -> bool:
        """Alert if accuracy drops below tolerance."""
        degradation = baseline_accuracy - recent_accuracy
        if degradation > tolerance:
            self.alerts.append(
                f"Performance degraded by {degradation:.2%}"
            )
            return True
        return False

    def _calculate_psi(self, reference, current, bins=10) -> float:
        """Population Stability Index."""
        # Bin both samples with the SAME edges (derived from the
        # reference), otherwise the comparison is meaningless
        edges = np.histogram_bin_edges(reference, bins=bins)
        ref_percents = np.histogram(reference, bins=edges)[0] / len(reference)
        curr_percents = np.histogram(current, bins=edges)[0] / len(current)
        # Avoid log(0)
        ref_percents = np.clip(ref_percents, 1e-6, None)
        curr_percents = np.clip(curr_percents, 1e-6, None)
        psi = np.sum(
            (curr_percents - ref_percents)
            * np.log(curr_percents / ref_percents)
        )
        return psi

# Usage (training_* arrays and todays_data come from your data pipeline)
monitor = ModelMonitor(
    reference_data={
        "age": training_ages,
        "income": training_incomes,
        "transaction_amount": training_amounts,
    }
)

# Check daily
drift_reports = monitor.check_data_drift(
    current_data=todays_data,
    method="ks",
)
for report in drift_reports:
    if report.is_drifted:
        print(
            f"DRIFT DETECTED: {report.feature_name} "
            f"(score: {report.drift_score:.4f})"
        )

MLOps Pipeline

MLOps applies DevOps principles to ML systems. A mature MLOps pipeline automates the entire lifecycle from data ingestion to model monitoring.

MLOps Pipeline Architecture:
┌──────────────────────────────────────────────────────────────┐
│                                                              │
│  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────────┐  │
│  │   Data   │  │ Feature  │  │  Model   │  │  Experiment  │  │
│  │ Pipeline │─▶│  Store   │─▶│ Training │─▶│   Tracking   │  │
│  └──────────┘  └──────────┘  └──────────┘  └──────┬───────┘  │
│       │                                           │          │
│       │        ┌──────────────────────────────────┘          │
│       │        │                                             │
│       ▼        ▼                                             │
│  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────────┐  │
│  │   Data   │  │  Model   │  │  Model   │  │    Model     │  │
│  │ Version. │  │ Registry │─▶│ Serving  │─▶│  Monitoring  │  │
│  └──────────┘  └──────────┘  └──────────┘  └──────┬───────┘  │
│                     ▲                             │          │
│                     │      Trigger retrain        │          │
│                     │      if drift detected      │          │
│                     └─────────────────────────────┘          │
└──────────────────────────────────────────────────────────────┘
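The feedback edge in the diagram ("trigger retrain if drift detected") reduces to a small control loop. A hedged sketch with placeholder functions: in practice `check_drift` would wrap a statistical test such as the PSI computed by the monitor above, and triggering would enqueue a training pipeline run rather than return a string.

```python
PSI_THRESHOLD = 0.2  # common rule of thumb; tune per feature

def check_drift(drift_score: float, threshold: float = PSI_THRESHOLD) -> bool:
    # Placeholder: compare a drift statistic against its threshold
    return drift_score > threshold

def maybe_trigger_retrain(drift_score: float) -> str:
    # Placeholder: a real system would kick off a pipeline run here
    if check_drift(drift_score):
        return "retraining triggered"
    return "model healthy"

print(maybe_trigger_retrain(0.35))  # retraining triggered
print(maybe_trigger_retrain(0.05))  # model healthy
```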

MLOps Maturity Levels

| Level | Name | Description |
|-------|------|-------------|
| 0 | Manual | Manual data prep, training in notebooks, manual deployment |
| 1 | ML Pipeline | Automated training pipeline, manual deployment |
| 2 | CI/CD for ML | Automated training, testing, and deployment |
| 3 | Full MLOps | Automated retraining, monitoring, drift detection, A/B testing |
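Level 2 ("CI/CD for ML") implies tests that can fail a model build, not just a code build. A hedged sketch of one such quality gate (the baseline number, tolerance, and function names are invented for illustration):

```python
BASELINE_ACCURACY = 0.91   # accuracy of the current production model
TOLERANCE = 0.01           # allow at most a 1-point regression

def quality_gate(candidate_accuracy: float) -> None:
    # Run in CI after training; a failed assert blocks deployment
    assert candidate_accuracy >= BASELINE_ACCURACY - TOLERANCE, (
        f"candidate accuracy {candidate_accuracy:.3f} regresses the "
        f"{BASELINE_ACCURACY:.3f} baseline; blocking deploy"
    )

quality_gate(0.92)  # passes silently; 0.85 would raise and fail the build
```

Alongside accuracy, mature pipelines gate on latency, model size, and fairness metrics in the same way.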

Experiment Tracking

import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score

# Set up MLflow experiment
mlflow.set_experiment("fraud-detection-v2")

with mlflow.start_run(run_name="rf_100_trees"):
    # Log parameters
    params = {
        "n_estimators": 100,
        "max_depth": 10,
        "min_samples_split": 5,
        "class_weight": "balanced",
    }
    mlflow.log_params(params)

    # Train model
    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)

    # Evaluate
    y_pred = model.predict(X_test)
    metrics = {
        "accuracy": accuracy_score(y_test, y_pred),
        "f1_score": f1_score(y_test, y_pred),
        "train_samples": len(X_train),
        "test_samples": len(X_test),
    }
    mlflow.log_metrics(metrics)

    # Log model
    mlflow.sklearn.log_model(
        model,
        "model",
        registered_model_name="fraud-detector",
    )

    # Log artifacts (plots, data samples, etc.)
    mlflow.log_artifact("confusion_matrix.png")

    print(f"Run ID: {mlflow.active_run().info.run_id}")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"F1: {metrics['f1_score']:.4f}")

# Model registry -- promote to production
client = mlflow.tracking.MlflowClient()
client.transition_model_version_stage(
    name="fraud-detector",
    version=3,
    stage="Production",
)

Scaling ML Systems

Scaling Strategies

Model Optimization:
┌────────────────────┬─────────┬──────────────┐
│ Technique          │ Speedup │ Trade-off    │
├────────────────────┼─────────┼──────────────┤
│ Quantization       │ 2-4x    │ Slight       │
│ (FP32 → INT8)      │         │ accuracy drop│
├────────────────────┼─────────┼──────────────┤
│ Pruning            │ 2-10x   │ Minimal if   │
│ (Remove weights)   │         │ done right   │
├────────────────────┼─────────┼──────────────┤
│ Distillation       │ 5-50x   │ Moderate     │
│ (Smaller model)    │         │ accuracy drop│
├────────────────────┼─────────┼──────────────┤
│ Batching           │ 2-8x    │ Added latency│
│ (Group requests)   │         │ per request  │
├────────────────────┼─────────┼──────────────┤
│ Caching            │ 100x+   │ Stale results│
│ (Cache predictions)│         │ possible     │
└────────────────────┴─────────┴──────────────┘
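The quantization row can be made concrete with toy arithmetic. This sketch shows only the FP32 → INT8 mapping and its rounding error; real deployments use framework tooling (e.g. ONNX Runtime or TensorRT quantizers) rather than hand-rolled code like this:

```python
def quantize(weights: list, bits: int = 8):
    # Symmetric quantization: map floats onto the int range via one scale
    qmax = 2 ** (bits - 1) - 1  # 127 for int8
    scale = max(abs(w) for w in weights) / qmax
    quantized = [round(w / scale) for w in weights]
    return quantized, scale

def dequantize(quantized: list, scale: float) -> list:
    return [q * scale for q in quantized]

weights = [0.52, -1.27, 0.003]
q, scale = quantize(weights)
restored = dequantize(q, scale)
# Each restored weight is within half a quantization step of the
# original -- that rounding error is the "slight accuracy drop"
```

Each weight now fits in one byte instead of four, which is where the memory and bandwidth savings behind the 2-4x speedup come from.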
# Model serving with batching and caching
import asyncio
import hashlib
import json

from fastapi import FastAPI

app = FastAPI()

class BatchInferenceService:
    def __init__(self, model, batch_size=32, max_wait_ms=50):
        self.model = model
        self.batch_size = batch_size
        self.max_wait_ms = max_wait_ms
        self.pending = []
        self.cache = {}

    def _cache_key(self, features: list) -> str:
        return hashlib.md5(
            json.dumps(features, sort_keys=True).encode()
        ).hexdigest()

    async def predict(self, features: list) -> dict:
        # Check cache first
        key = self._cache_key(features)
        if key in self.cache:
            return self.cache[key]

        # Add to batch queue
        future = asyncio.Future()
        self.pending.append((features, future))

        # Process batch if full or after timeout
        if len(self.pending) >= self.batch_size:
            await self._process_batch()
        else:
            asyncio.get_event_loop().call_later(
                self.max_wait_ms / 1000,
                lambda: asyncio.ensure_future(self._process_batch()),
            )
        return await future

    async def _process_batch(self):
        if not self.pending:
            return
        batch = self.pending[:self.batch_size]
        self.pending = self.pending[self.batch_size:]

        features_batch = [f for f, _ in batch]
        futures = [fut for _, fut in batch]

        # Batch inference
        predictions = self.model.predict(features_batch)

        # Return results and cache
        for features, future, pred in zip(
            features_batch, futures, predictions
        ):
            result = {"prediction": pred.tolist()}
            key = self._cache_key(features)
            self.cache[key] = result
            future.set_result(result)

# model and PredictionRequest as defined in the earlier serving example
service = BatchInferenceService(model, batch_size=32)

@app.post("/predict")
async def predict(request: PredictionRequest):
    return await service.predict(request.features)

ML System Design Checklist

When designing an ML system, work through these considerations systematically.

| Phase | Questions to Answer |
|-------|---------------------|
| Problem Framing | Is ML the right approach? What metric defines success? |
| Data | What data do you need? How will you label it? How much? |
| Features | What signals predict the target? How to compute them at serving time? |
| Model | Start simple (logistic regression), add complexity only if needed |
| Training | How often to retrain? How to handle class imbalance? |
| Evaluation | What offline metrics? What online metrics? What is the baseline? |
| Serving | Batch or real-time? What latency SLA? What throughput? |
| Monitoring | How to detect drift? How to measure production accuracy? |
| Feedback Loop | How to collect labels for production predictions? |
| Ethics | Bias, fairness, explainability, privacy requirements? |

Summary

| Concept | Key Takeaway |
|---------|--------------|
| Feature Stores | Centralized feature management eliminates training-serving skew |
| Model Serving | Choose batch or real-time based on latency requirements |
| A/B Testing | Measure model impact with statistical rigor |
| Monitoring | Detect data drift, concept drift, and performance degradation |
| MLOps | Apply CI/CD principles to ML: version data, track experiments, automate |
| Scaling | Use quantization, batching, caching, and distillation |