ML System Design
ML Systems vs Traditional Software
Building ML systems requires all the discipline of traditional software engineering plus additional complexity around data, training, and model behavior.
Traditional Software:

```
┌──────────┐     ┌──────────┐     ┌──────────┐
│   Code   │────▶│  Build   │────▶│  Deploy  │
└──────────┘     └──────────┘     └──────────┘
```

Deterministic. Same input → same output.
ML Systems:

```
┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐
│   Data   │─▶│ Features │─▶│  Train   │─▶│ Evaluate │─▶│  Deploy  │
└──────────┘  └──────────┘  └──────────┘  └──────────┘  └──────────┘
      │                          │              │              │
      ▼                          ▼              ▼              ▼
 Data quality            Hyperparameter   Model quality   Model decay
 Data drift              tuning           Bias/Fairness   Monitoring
 Data versioning         Experiment       A/B testing     Retraining
                         tracking
```
Probabilistic. Same input → similar (not identical) output. Model behavior changes with the data, and performance degrades over time.

What Makes ML Systems Hard
| Challenge | Description |
|---|---|
| Data dependencies | Models are only as good as their data; bad data breaks models silently |
| Feedback loops | Model predictions influence future data, creating cycles |
| Training-serving skew | Differences between how features are computed in training vs serving |
| Silent failures | A model can return plausible but wrong predictions without crashing |
| Reproducibility | Results depend on data, code, hyperparameters, random seeds, and hardware |
| Model decay | Real-world data distributions shift, degrading model performance over time |
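The training-serving skew row deserves a concrete illustration. Below is a sketch (the function names and the 30-day window are invented for illustration) of how two independently written implementations of the "same" feature can silently disagree:

```python
# Training pipeline: average spend over the full 30-day window,
# treating missing days as zero spend.
def avg_daily_spend_training(daily_totals: list, window: int = 30) -> float:
    padded = daily_totals + [0.0] * (window - len(daily_totals))
    return sum(padded) / window

# Serving path (reimplemented separately): average over only the
# days that have data -- a subtly different definition.
def avg_daily_spend_serving(daily_totals: list) -> float:
    return sum(daily_totals) / len(daily_totals)

spend = [100.0] * 10  # user active for 10 of the last 30 days

print(round(avg_daily_spend_training(spend), 2))  # 33.33
print(avg_daily_spend_serving(spend))             # 100.0 -- 3x skew, no crash
```

Both functions are "correct" in isolation; the model trained on the first definition quietly receives the second at inference time. This is exactly the failure mode feature stores (next section) are designed to prevent.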
Feature Stores
A feature store is a centralized repository for storing, managing, and serving ML features. It solves the critical problem of training-serving skew — ensuring the same feature computation logic is used in both training and inference.
Without Feature Store:

```
┌──────────┐                       ┌──────────┐
│ Training │  Feature logic A      │ Serving  │  Feature logic B
│ Pipeline │  (Python script)      │ Service  │  (Java service)
└──────────┘                       └──────────┘

⚠ Different implementations → different results!
```
With Feature Store:

```
┌──────────┐     ┌──────────────┐     ┌──────────┐
│ Training │◀───▶│              │◀───▶│ Serving  │
│ Pipeline │     │   Feature    │     │ Service  │
└──────────┘     │    Store     │     └──────────┘
                 │              │
┌──────────┐     │ - Offline    │     ┌──────────┐
│ Feature  │────▶│   store      │────▶│  Model   │
│ Pipeline │     │ - Online     │     │ Training │
└──────────┘     │   store      │     └──────────┘
                 │ - Registry   │
                 └──────────────┘
```

Same features for training and serving.

Feature Store Architecture
```
┌─────────────────────────────────────────────────────────┐
│                     Feature Store                       │
│                                                         │
│  ┌─────────────────┐   ┌─────────────────────────┐      │
│  │    Feature      │   │      Offline Store      │      │
│  │    Registry     │   │  (Historical features)  │      │
│  │                 │   │  - Data warehouse       │      │
│  │  - Definitions  │   │  - Parquet files        │      │
│  │  - Metadata     │   │  - Used for training    │      │
│  │  - Lineage      │   └─────────────────────────┘      │
│  │  - Versioning   │                                    │
│  └─────────────────┘   ┌─────────────────────────┐      │
│                        │      Online Store       │      │
│                        │  (Low-latency serving)  │      │
│                        │  - Redis / DynamoDB     │      │
│                        │  - Latest feature values│      │
│                        │  - Used for inference   │      │
│                        └─────────────────────────┘      │
└─────────────────────────────────────────────────────────┘
```

```python
# Using Feast (open-source feature store)
from datetime import timedelta

from feast import Entity, FeatureStore, FeatureView, Field, FileSource
from feast.types import Float64, Int64

# Define a data source
driver_stats_source = FileSource(
    path="data/driver_stats.parquet",
    timestamp_field="event_timestamp",
)

# Define an entity (the primary key)
driver = Entity(
    name="driver_id",
    description="Driver identifier",
)

# Define a feature view
driver_stats_fv = FeatureView(
    name="driver_hourly_stats",
    entities=[driver],
    ttl=timedelta(days=1),
    schema=[
        Field(name="conv_rate", dtype=Float64),
        Field(name="acc_rate", dtype=Float64),
        Field(name="avg_daily_trips", dtype=Int64),
    ],
    source=driver_stats_source,
)

# Initialize the feature store
store = FeatureStore(repo_path="feature_repo/")

# Get training data (offline store)
training_df = store.get_historical_features(
    entity_df=entity_df,  # DataFrame with driver_id + timestamps
    features=[
        "driver_hourly_stats:conv_rate",
        "driver_hourly_stats:acc_rate",
        "driver_hourly_stats:avg_daily_trips",
    ],
).to_df()

# Get real-time features for serving (online store)
online_features = store.get_online_features(
    features=[
        "driver_hourly_stats:conv_rate",
        "driver_hourly_stats:acc_rate",
    ],
    entity_rows=[{"driver_id": 1001}],
).to_dict()

print(online_features)
# {'driver_id': [1001], 'conv_rate': [0.87], 'acc_rate': [0.93]}
```

Model Serving
Model serving is how you make predictions available to applications. The two primary patterns are batch inference and real-time inference.
Batch vs Real-Time Serving
Batch Inference:

```
┌──────────┐     ┌──────────┐     ┌───────────────────┐
│ All data │────▶│  Model   │────▶│    Predictions    │
│ (daily)  │     │          │     │ stored in DB/cache│
└──────────┘     └──────────┘     └───────────────────┘
```

Run periodically (hourly/daily). Pre-compute predictions.
Real-Time Inference:

```
┌──────────┐     ┌──────────┐     ┌──────────┐
│   API    │────▶│  Model   │────▶│ Response │
│ Request  │     │  Server  │     │ (< 100ms)│
└──────────┘     └──────────┘     └──────────┘
```

On-demand. Compute predictions per request.

| Aspect | Batch Inference | Real-Time Inference |
|---|---|---|
| Latency | Minutes to hours | Milliseconds |
| Throughput | High (bulk processing) | Lower (per-request) |
| Freshness | Predictions may be stale | Always up-to-date |
| Cost | Cheaper (bulk compute) | More expensive (always-on) |
| Complexity | Simpler infrastructure | Requires model servers, load balancing |
| Use cases | Recommendations, reports, email campaigns | Search ranking, fraud detection, chatbots |
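To make the batch column concrete, here is a minimal sketch of the pattern: a periodic job scores every user up front, and serving becomes a cache lookup. `ToyModel`, `run_batch_job`, and the in-memory dict are illustrative stand-ins for a real model and key-value store:

```python
from datetime import datetime, timezone
from typing import Optional

# Predictions are pre-computed here; request time is a pure lookup.
prediction_cache: dict = {}

def run_batch_job(model, user_features: dict) -> None:
    """Hypothetical nightly job: score all users in bulk."""
    scored_at = datetime.now(timezone.utc).isoformat()
    for user_id, features in user_features.items():
        prediction_cache[user_id] = {
            "score": model.predict(features),
            "scored_at": scored_at,  # lets callers judge staleness
        }

def serve(user_id: str) -> Optional[dict]:
    # Millisecond lookup; no model server in the request path.
    return prediction_cache.get(user_id)

class ToyModel:
    def predict(self, features):
        return sum(features) / len(features)

run_batch_job(ToyModel(), {"u1": [0.25, 0.75], "u2": [0.9, 0.7]})
print(serve("u1")["score"])  # 0.5
```

The trade-off from the table is visible directly: `scored_at` exposes staleness, and a user scored after the nightly run simply misses the cache.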
Real-Time Serving Architectures
```python
# FastAPI model serving
from typing import List

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI(title="ML Model API")

# Load model at startup
model = None

@app.on_event("startup")
async def load_model():
    global model
    model = joblib.load("models/fraud_detector_v2.pkl")

class PredictionRequest(BaseModel):
    features: List[float]
    user_id: str

class PredictionResponse(BaseModel):
    prediction: int
    probability: float
    model_version: str

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    if model is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded"
        )

    features = np.array(request.features).reshape(1, -1)
    prediction = int(model.predict(features)[0])
    probability = float(
        model.predict_proba(features)[0].max()
    )

    return PredictionResponse(
        prediction=prediction,
        probability=probability,
        model_version="v2.1.0"
    )

@app.get("/health")
async def health():
    return {
        "status": "healthy",
        "model_loaded": model is not None
    }
```

The same service in Go, using the TensorFlow Go bindings:

```go
package main

import (
    "encoding/json"
    "log"
    "net/http"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

type PredictionRequest struct {
    Features []float32 `json:"features"`
    UserID   string    `json:"user_id"`
}

type PredictionResponse struct {
    Prediction   int     `json:"prediction"`
    Probability  float32 `json:"probability"`
    ModelVersion string  `json:"model_version"`
}

var model *tf.SavedModel

func loadModel() error {
    var err error
    model, err = tf.LoadSavedModel(
        "models/fraud_detector",
        []string{"serve"},
        nil,
    )
    return err
}

func predictHandler(w http.ResponseWriter, r *http.Request) {
    var req PredictionRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, "Invalid request", http.StatusBadRequest)
        return
    }

    // Create input tensor
    tensor, err := tf.NewTensor(
        [][]float32{req.Features},
    )
    if err != nil {
        http.Error(w, "Tensor error", http.StatusInternalServerError)
        return
    }

    // Run inference
    result, err := model.Session.Run(
        map[tf.Output]*tf.Tensor{
            model.Graph.Operation("input").Output(0): tensor,
        },
        []tf.Output{
            model.Graph.Operation("output").Output(0),
        },
        nil,
    )
    if err != nil {
        http.Error(w, "Inference error", http.StatusInternalServerError)
        return
    }

    probs := result[0].Value().([][]float32)[0]
    prediction := 0
    if probs[1] > probs[0] {
        prediction = 1
    }

    resp := PredictionResponse{
        Prediction:   prediction,
        Probability:  probs[prediction],
        ModelVersion: "v2.1.0",
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(resp)
}

func main() {
    if err := loadModel(); err != nil {
        log.Fatal("Failed to load model:", err)
    }
    http.HandleFunc("/predict", predictHandler)
    http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
        w.Write([]byte(`{"status":"healthy"}`))
    })
    log.Println("Model server starting on :8080")
    log.Fatal(http.ListenAndServe(":8080", nil))
}
```

A/B Testing for ML
A/B testing ML models is more complex than testing UI changes because model performance depends on many factors and may take longer to measure.
ML A/B Testing Architecture:
```
┌────────────────────────────────────────────────┐
│                 Traffic Router                 │
│            (based on user_id hash)             │
└──────────┬────────────────────┬────────────────┘
           │ 50%                │ 50%
           ▼                    ▼
┌──────────────────┐   ┌──────────────────┐
│     Model A      │   │     Model B      │
│    (Control)     │   │   (Challenger)   │
│      v2.0        │   │      v2.1        │
└────────┬─────────┘   └────────┬─────────┘
         │                      │
         ▼                      ▼
┌──────────────────────────────────────────┐
│            Metrics Collection            │
│                                          │
│  - Prediction latency                    │
│  - Business metrics (CTR, revenue)       │
│  - Model quality (accuracy, F1)          │
│  - Error rates                           │
│  - User engagement                       │
└──────────────────┬───────────────────────┘
                   │
                   ▼
┌──────────────────────────────────────────┐
│           Statistical Analysis           │
│                                          │
│  - Statistical significance test         │
│  - Confidence intervals                  │
│  - Effect size calculation               │
│  - Guard rail metrics check              │
└──────────────────────────────────────────┘
```

```python
import hashlib
from dataclasses import dataclass

import numpy as np
from scipy import stats

@dataclass
class Experiment:
    name: str
    control_model: str
    treatment_model: str
    traffic_split: float = 0.5  # 50/50

class ABTestRouter:
    def __init__(self, experiment: Experiment):
        self.experiment = experiment
        self.control_metrics = []
        self.treatment_metrics = []

    def get_variant(self, user_id: str) -> str:
        """Deterministic assignment based on user ID."""
        hash_val = int(
            hashlib.sha256(
                f"{self.experiment.name}:{user_id}".encode()
            ).hexdigest(),
            16
        )
        if (hash_val % 100) / 100 < self.experiment.traffic_split:
            return "control"
        return "treatment"

    def get_model(self, user_id: str) -> str:
        variant = self.get_variant(user_id)
        if variant == "control":
            return self.experiment.control_model
        return self.experiment.treatment_model

    def record_metric(
        self, user_id: str, metric_value: float
    ):
        variant = self.get_variant(user_id)
        if variant == "control":
            self.control_metrics.append(metric_value)
        else:
            self.treatment_metrics.append(metric_value)

    def analyze_results(self) -> dict:
        """Run statistical significance test."""
        control = np.array(self.control_metrics)
        treatment = np.array(self.treatment_metrics)

        # Two-sample t-test
        t_stat, p_value = stats.ttest_ind(
            control, treatment
        )

        # Effect size (Cohen's d)
        pooled_std = np.sqrt(
            (control.std()**2 + treatment.std()**2) / 2
        )
        cohens_d = (
            (treatment.mean() - control.mean()) / pooled_std
        )

        return {
            "control_mean": float(control.mean()),
            "treatment_mean": float(treatment.mean()),
            "relative_improvement": float(
                (treatment.mean() - control.mean())
                / control.mean() * 100
            ),
            "p_value": float(p_value),
            "is_significant": p_value < 0.05,
            "cohens_d": float(cohens_d),
            "control_n": len(control),
            "treatment_n": len(treatment),
        }

# Usage
experiment = Experiment(
    name="fraud_model_v2.1",
    control_model="models/fraud_v2.0",
    treatment_model="models/fraud_v2.1",
    traffic_split=0.5
)

router = ABTestRouter(experiment)

# In your API endpoint
model_path = router.get_model(user_id="user_123")
prediction = load_and_predict(model_path, features)

# Record business metric later
router.record_metric("user_123", metric_value=1.0)
```

Model Monitoring and Drift Detection
Models degrade over time as the real world changes. Monitoring detects these problems before they impact users.
Types of Drift
Data Drift (Feature Drift): The distribution of input features changes.
```
  Training Data            Production Data
┌──────────────┐         ┌──────────────┐
│     ___      │         │       ___    │
│    /   \     │         │      /   \   │
│   /     \    │    →    │     /     \  │
│  /       \   │         │    /       \ │
│ /         \  │         │   /         \│
└──────────────┘         └──────────────┘
```

Distribution shifted to the right.
Concept Drift: The relationship between features and target changes.
```
Before: Income > $80K → Low fraud risk
After:  Income > $80K → Same fraud risk as others
        (fraud patterns evolved)
```
Prediction Drift: The distribution of model predictions changes.
```
Before:  5% flagged as fraud
After:  15% flagged as fraud
        (model behavior changed)
```

| Drift Type | What Changes | Detection Method |
|---|---|---|
| Data drift | Input feature distributions | KS test, PSI, JS divergence |
| Concept drift | Feature-target relationship | Monitor prediction accuracy over time |
| Prediction drift | Output distribution | Compare prediction histograms |
| Label drift | Target variable distribution | Monitor ground truth labels |
```python
from dataclasses import dataclass
from datetime import datetime
from typing import List

import numpy as np
from scipy import stats

@dataclass
class DriftReport:
    feature_name: str
    drift_score: float
    is_drifted: bool
    method: str
    timestamp: datetime

class ModelMonitor:
    def __init__(self, reference_data: dict, threshold=0.05):
        """
        reference_data: baseline feature distributions
                        from training data
        threshold: p-value threshold for drift detection
        """
        self.reference = reference_data
        self.threshold = threshold
        self.alerts = []

    def check_data_drift(
        self, current_data: dict, method: str = "ks"
    ) -> List[DriftReport]:
        """Check each feature for distribution drift."""
        reports = []

        for feature_name in self.reference:
            ref = np.array(self.reference[feature_name])
            curr = np.array(current_data[feature_name])

            if method == "ks":
                # Kolmogorov-Smirnov test
                statistic, p_value = stats.ks_2samp(ref, curr)
                drift_score = statistic
                is_drifted = p_value < self.threshold
            elif method == "psi":
                # Population Stability Index
                drift_score = self._calculate_psi(ref, curr)
                is_drifted = drift_score > 0.2
            else:
                raise ValueError(f"Unknown method: {method}")

            report = DriftReport(
                feature_name=feature_name,
                drift_score=drift_score,
                is_drifted=is_drifted,
                method=method,
                timestamp=datetime.now()
            )
            reports.append(report)

            if is_drifted:
                self.alerts.append(report)

        return reports

    def check_prediction_drift(
        self,
        reference_preds: np.ndarray,
        current_preds: np.ndarray
    ) -> DriftReport:
        """Check if model predictions have shifted."""
        statistic, p_value = stats.ks_2samp(
            reference_preds, current_preds
        )
        return DriftReport(
            feature_name="predictions",
            drift_score=statistic,
            is_drifted=p_value < self.threshold,
            method="ks",
            timestamp=datetime.now()
        )

    def check_performance_degradation(
        self,
        recent_accuracy: float,
        baseline_accuracy: float,
        tolerance: float = 0.05
    ) -> bool:
        """Alert if accuracy drops below tolerance."""
        degradation = baseline_accuracy - recent_accuracy
        if degradation > tolerance:
            self.alerts.append(
                f"Performance degraded by "
                f"{degradation:.2%}"
            )
            return True
        return False

    def _calculate_psi(
        self, reference, current, bins=10
    ) -> float:
        """Population Stability Index."""
        ref_percents = np.histogram(
            reference, bins=bins
        )[0] / len(reference)
        curr_percents = np.histogram(
            current, bins=bins
        )[0] / len(current)

        # Avoid log(0)
        ref_percents = np.clip(ref_percents, 1e-6, None)
        curr_percents = np.clip(curr_percents, 1e-6, None)

        psi = np.sum(
            (curr_percents - ref_percents)
            * np.log(curr_percents / ref_percents)
        )
        return float(psi)

# Usage
monitor = ModelMonitor(
    reference_data={
        "age": training_ages,
        "income": training_incomes,
        "transaction_amount": training_amounts,
    }
)

# Check daily
drift_reports = monitor.check_data_drift(
    current_data=todays_data,
    method="ks"
)

for report in drift_reports:
    if report.is_drifted:
        print(
            f"DRIFT DETECTED: {report.feature_name} "
            f"(score: {report.drift_score:.4f})"
        )
```

MLOps Pipeline
MLOps applies DevOps principles to ML systems. A mature MLOps pipeline automates the entire lifecycle from data ingestion to model monitoring.
MLOps Pipeline Architecture:
```
┌──────────────────────────────────────────────────────────────────┐
│                                                                  │
│  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────────┐      │
│  │   Data   │  │ Feature  │  │  Model   │  │  Experiment  │      │
│  │ Pipeline │─▶│  Store   │─▶│ Training │─▶│   Tracking   │      │
│  └──────────┘  └──────────┘  └──────────┘  └──────┬───────┘      │
│                                                   │              │
│                ┌──────────────────────────────────┘              │
│                │                                                 │
│                ▼                                                 │
│  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────────┐      │
│  │   Data   │  │  Model   │  │  Model   │  │    Model     │      │
│  │ Version. │  │ Registry │─▶│ Serving  │─▶│  Monitoring  │      │
│  └──────────┘  └──────────┘  └──────────┘  └──────┬───────┘      │
│                                                   │              │
│            Trigger retrain if drift detected      │              │
│      ◀────────────────────────────────────────────┘              │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘
```

MLOps Maturity Levels
| Level | Name | Description |
|---|---|---|
| 0 | Manual | Manual data prep, training in notebooks, manual deployment |
| 1 | ML Pipeline | Automated training pipeline, manual deployment |
| 2 | CI/CD for ML | Automated training, testing, and deployment |
| 3 | Full MLOps | Automated retraining, monitoring, drift detection, A/B testing |
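Level 3's automated retraining usually reduces to a small decision function evaluated on a schedule. A minimal sketch, with made-up thresholds (`max_accuracy_drop`, `max_staleness_days`) that a real pipeline would tune to its own domain:

```python
from dataclasses import dataclass

@dataclass
class PipelineState:
    baseline_accuracy: float
    drift_detected: bool
    days_since_training: int

def should_retrain(state: PipelineState,
                   recent_accuracy: float,
                   max_accuracy_drop: float = 0.05,
                   max_staleness_days: int = 30) -> bool:
    # Retrain on drift, on measurable accuracy decay,
    # or simply because the model has grown stale.
    if state.drift_detected:
        return True
    if state.baseline_accuracy - recent_accuracy > max_accuracy_drop:
        return True
    return state.days_since_training > max_staleness_days

state = PipelineState(baseline_accuracy=0.92,
                      drift_detected=False,
                      days_since_training=10)
print(should_retrain(state, recent_accuracy=0.84))  # True: 8-point accuracy drop
print(should_retrain(state, recent_accuracy=0.91))  # False
```

In a real Level 3 pipeline the `True` branch would kick off the training DAG and register the candidate model for A/B evaluation rather than deploying it directly.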
Experiment Tracking
```python
import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score

# Set up MLflow experiment
mlflow.set_experiment("fraud-detection-v2")

with mlflow.start_run(run_name="rf_100_trees"):
    # Log parameters
    params = {
        "n_estimators": 100,
        "max_depth": 10,
        "min_samples_split": 5,
        "class_weight": "balanced"
    }
    mlflow.log_params(params)

    # Train model
    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)

    # Evaluate
    y_pred = model.predict(X_test)
    metrics = {
        "accuracy": accuracy_score(y_test, y_pred),
        "f1_score": f1_score(y_test, y_pred),
        "train_samples": len(X_train),
        "test_samples": len(X_test),
    }
    mlflow.log_metrics(metrics)

    # Log model
    mlflow.sklearn.log_model(
        model,
        "model",
        registered_model_name="fraud-detector"
    )

    # Log artifacts (plots, data samples, etc.)
    mlflow.log_artifact("confusion_matrix.png")

    print(f"Run ID: {mlflow.active_run().info.run_id}")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"F1: {metrics['f1_score']:.4f}")

# Model registry -- promote to production
client = mlflow.tracking.MlflowClient()
client.transition_model_version_stage(
    name="fraud-detector",
    version=3,
    stage="Production"
)
```

Scaling ML Systems
Scaling Strategies
Model Optimization:

| Technique | Speedup | Trade-off |
|---|---|---|
| Quantization (FP32 → INT8) | 2-4x | Slight accuracy drop |
| Pruning (remove weights) | 2-10x | Minimal if done right |
| Distillation (smaller model) | 5-50x | Moderate accuracy drop |
| Batching (group requests) | 2-8x | Added latency per request |
| Caching (cache predictions) | 100x+ | Stale results possible |

```python
# Model serving with batching and caching
import asyncio
import hashlib
import json

from fastapi import FastAPI

app = FastAPI()

class BatchInferenceService:
    def __init__(self, model, batch_size=32, max_wait_ms=50):
        self.model = model
        self.batch_size = batch_size
        self.max_wait_ms = max_wait_ms
        self.pending = []
        self.cache = {}

    def _cache_key(self, features: list) -> str:
        return hashlib.md5(
            json.dumps(features, sort_keys=True).encode()
        ).hexdigest()

    async def predict(self, features: list) -> dict:
        # Check cache first
        key = self._cache_key(features)
        if key in self.cache:
            return self.cache[key]

        # Add to batch queue
        future = asyncio.get_running_loop().create_future()
        self.pending.append((features, future))

        # Process the batch when full, or after a short wait
        if len(self.pending) >= self.batch_size:
            await self._process_batch()
        else:
            asyncio.get_running_loop().call_later(
                self.max_wait_ms / 1000,
                lambda: asyncio.ensure_future(
                    self._process_batch()
                )
            )

        return await future

    async def _process_batch(self):
        if not self.pending:
            return

        batch = self.pending[:self.batch_size]
        self.pending = self.pending[self.batch_size:]

        features_batch = [f for f, _ in batch]
        futures = [fut for _, fut in batch]

        # Batch inference
        predictions = self.model.predict(features_batch)

        # Return results and cache them
        for features, future, pred in zip(
            features_batch, futures, predictions
        ):
            result = {"prediction": pred.tolist()}
            key = self._cache_key(features)
            self.cache[key] = result
            future.set_result(result)

service = BatchInferenceService(model, batch_size=32)

@app.post("/predict")
async def predict(request: PredictionRequest):
    return await service.predict(request.features)
```

ML System Design Checklist
When designing an ML system, work through these considerations systematically.
| Phase | Questions to Answer |
|---|---|
| Problem Framing | Is ML the right approach? What metric defines success? |
| Data | What data do you need? How will you label it? How much? |
| Features | What signals predict the target? How to compute them at serving time? |
| Model | Start simple (logistic regression), add complexity only if needed |
| Training | How often to retrain? How to handle class imbalance? |
| Evaluation | What offline metrics? What online metrics? What is the baseline? |
| Serving | Batch or real-time? What latency SLA? What throughput? |
| Monitoring | How to detect drift? How to measure production accuracy? |
| Feedback Loop | How to collect labels for production predictions? |
| Ethics | Bias, fairness, explainability, privacy requirements? |
Summary
| Concept | Key Takeaway |
|---|---|
| Feature Stores | Centralized feature management eliminates training-serving skew |
| Model Serving | Choose batch or real-time based on latency requirements |
| A/B Testing | Measure model impact with statistical rigor |
| Monitoring | Detect data drift, concept drift, and performance degradation |
| MLOps | Apply CI/CD principles to ML: version data, track experiments, automate |
| Scaling | Use quantization, batching, caching, and distillation |