How to Deploy Your First ML Model with Flask or FastAPI (Beginner)
- Wrap your trained model in an HTTP endpoint so other applications can request predictions over the network.
- Flask is simpler and more familiar; FastAPI is faster, auto-documents, and validates inputs natively with Pydantic.
- Save models with joblib (sklearn) or model.save() (TensorFlow). Load the model at application startup, never inside the request handler — one load per process, reused across all requests. Reloading on every call destroys latency.
- Save model, scaler, encoder, and metadata together as a single versioned artifact. A model without its preprocessing pipeline produces silently wrong predictions; deploying the model file alone without versioning the pipeline alongside it is the biggest deployment mistake.
- Always validate inputs before prediction: check structure, types, null values, NaN, feature count, and domain-valid ranges.
- Add input validation, error handling, and a /health endpoint before exposing any endpoint publicly.
- Always write a consistency test that verifies API predictions match notebook predictions on identical input.
Production Debug Guide

Common signals when your model API misbehaves — and exactly where to look.

Symptom: Model file not found at runtime despite existing in the project directory. Print the directory the process actually runs from:

```python
import os
print('cwd:', os.getcwd())
print('script dir:', os.path.dirname(os.path.abspath(__file__)))
print(os.listdir(os.path.dirname(os.path.abspath(__file__))))
```

Symptom: Prediction latency spikes to 5+ seconds after the API has been idle for several minutes. Time the endpoint and check whether the container was recently restarted:

```shell
curl -w '\ntotal_time: %{time_total}s\n' http://localhost:8000/health
docker stats --no-stream  # Check if the container was recently restarted
```

Symptom: Input validation passes but model.predict() throws a ValueError about array shape. Inspect the array you actually pass to the model:

```python
import numpy as np
arr = np.array(data['features'])  # data, n_features come from your handler scope
print('shape:', arr.shape, 'dtype:', arr.dtype)
assert arr.shape == (1, n_features), f'Expected (1, {n_features}), got {arr.shape}'
```

Production Incident

One team placed the joblib.load() call inside the Flask route function, assuming the operating system's file cache would make subsequent loads fast. They tested with a single curl command, saw a 2-second response (the initial cold load on their fast SSD), and shipped it.

More debugging rules of thumb:

- On unhandled exceptions, wrap model.predict() and log the full traceback server-side with logger.error(exc_info=True). Return a structured JSON error response to the client with a generic message and the appropriate HTTP status code. Never expose Python stack traces to external clients — they leak internal paths, library versions, and implementation details.
- If your model is not safe for concurrent calls, serialize access with a threading.Lock() around model.predict().
- For memory growth in TensorFlow, call tf.keras.backend.clear_session() periodically or after batches. For PyTorch, ensure tensors are moved off GPU and references are deleted. For scikit-learn, check whether your preprocessing creates large intermediate numpy arrays that are not freed. Monitor with docker stats or a /metrics endpoint.
- To find slow steps, wrap time.perf_counter() around each one: JSON parsing, preprocessing, model.predict(), and response serialization. The bottleneck is almost always in a place you did not expect.

Training a model is half the work. The other half — the half that actually delivers business value — is making it accessible to other systems. Deployment wraps your model in an HTTP API that accepts input data, runs inference, and returns predictions as JSON. Without deployment, your model exists only in your notebook, visible to nobody except you.
Flask and FastAPI are the two dominant Python frameworks for this task. Flask is the established standard — simple, well-documented, and familiar to most Python developers. FastAPI is the modern alternative — faster, with automatic input validation via Pydantic, auto-generated OpenAPI documentation, and async support built in. Both get the job done. The deployment patterns are identical; only the framework boilerplate differs.
The common misconception is that deployment is a DevOps concern that comes after modeling is finished. In practice, deployment constraints — latency budgets, input formats, memory limits, concurrency requirements — should influence your model choices during training. A model that takes 5 seconds per prediction is unusable for real-time APIs regardless of its accuracy. A model that requires 8 GB of RAM cannot run in a 2 GB container. These constraints are not afterthoughts. They are requirements.
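As a quick sanity check before committing to a model, you can measure single-row prediction latency against your budget. This is a sketch: `predict_stub` is a hypothetical stand-in for your real `model.predict`, and the 100 ms budget is illustrative.

```python
import time
import numpy as np

def predict_stub(x):
    # Hypothetical stand-in for model.predict — swap in your real model
    return (x.sum(axis=1) > 0).astype(int)

def measure_latency_ms(predict_fn, n_features=3, n_trials=100):
    """Median wall-clock latency of one single-row prediction, in ms."""
    x = np.random.randn(1, n_features)
    times = []
    for _ in range(n_trials):
        start = time.perf_counter()
        predict_fn(x)
        times.append((time.perf_counter() - start) * 1000)
    return sorted(times)[len(times) // 2]

latency = measure_latency_ms(predict_stub)
# Compare against your latency budget before committing to the model
assert latency < 100, f"Prediction takes {latency:.1f} ms, too slow for a real-time API"
```

Run this with your actual fitted model before you write any API code; if the number already blows the budget on your laptop, it will not get better inside a resource-limited container.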
Saving Your Model for Deployment
Before building an API, you need to save your trained model in a format that loads quickly and reliably outside of your notebook. This is not just saving the model — it is saving everything the model needs to produce correct predictions: the preprocessing pipeline, the feature names, the framework versions, and any metadata that future-you (or a teammate) will need to debug a production issue at 2 AM.
The serialization format depends on your framework. Scikit-learn uses joblib. TensorFlow uses SavedModel format. PyTorch uses torch.save(). Each has different trade-offs for loading speed, file size, and cross-version portability. But the principle is the same across all of them: the model and its preprocessing pipeline are a unit. Ship them together or watch predictions silently break.
```python
import joblib
import json
import sys
from pathlib import Path
from datetime import datetime


def save_sklearn_model(
    model,
    scaler,
    feature_names,
    output_dir='model_artifacts',
    model_version='1.0.0'
):
    """Save sklearn model with its complete preprocessing pipeline.

    The model, scaler, and metadata are saved as a single versioned
    artifact directory. All three files must travel together — a model
    without its scaler produces garbage predictions on raw input.

    Args:
        model: fitted sklearn estimator
        scaler: fitted sklearn transformer (StandardScaler, etc.)
        feature_names: list of feature name strings in training order
        output_dir: directory to save artifacts
        model_version: semantic version string for tracking
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Save the trained model
    joblib.dump(model, output_path / 'model.joblib')

    # Save the fitted scaler — this MUST travel with the model.
    # Without it, the API receives raw features and the model
    # expects scaled features. Predictions are silently wrong.
    joblib.dump(scaler, output_path / 'scaler.joblib')

    # Save metadata for debugging and validation at load time
    import sklearn
    metadata = {
        'model_type': type(model).__name__,
        'model_version': model_version,
        'feature_names': feature_names,
        'n_features': len(feature_names),
        'sklearn_version': sklearn.__version__,
        'python_version': sys.version,
        'saved_at': datetime.utcnow().isoformat(),
        'training_notes': 'Add any relevant notes about training data or parameters here'
    }
    with open(output_path / 'metadata.json', 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"Model artifacts saved to {output_path}/")
    for file in sorted(output_path.iterdir()):
        size_kb = file.stat().st_size / 1024
        print(f"  {file.name}: {size_kb:.1f} KB")


def save_tensorflow_model(model, output_dir='model_artifacts/tf_model'):
    """Save TensorFlow model in SavedModel format.

    SavedModel is more portable across TF versions than HDF5 (.h5)
    and includes the computation graph, making it suitable for
    TensorFlow Serving and TFLite conversion.
    """
    model.save(output_dir)
    print(f"TensorFlow SavedModel saved to {output_dir}/")


# --- Example usage ---
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import numpy as np

# Simulate training workflow
np.random.seed(42)
X = np.random.randn(1000, 3)
y = (X[:, 0] + X[:, 1] > 0).astype(int)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Fit preprocessing and model
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train_scaled, y_train)

# Evaluate before saving
X_test_scaled = scaler.transform(X_test)
accuracy = model.score(X_test_scaled, y_test)
print(f"Test accuracy: {accuracy:.3f}")

# Save everything together
save_sklearn_model(
    model=model,
    scaler=scaler,
    feature_names=['age', 'income', 'credit_score'],
    model_version='1.0.0'
)
```
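On the serving side, the artifact can be reloaded and sanity-checked before the first request. This is a sketch: `load_metadata` is a hypothetical helper that assumes the metadata.json layout written by save_sklearn_model() above, and the version-drift warning catches the common case where the serving image runs a different sklearn than training did.

```python
import json
import tempfile
from pathlib import Path

def load_metadata(model_dir):
    """Load artifact metadata and warn on sklearn version drift.

    Pair this with joblib.load() for model.joblib and scaler.joblib —
    all three files travel together.
    """
    meta = json.loads((Path(model_dir) / 'metadata.json').read_text())
    try:
        import sklearn
        saved = meta.get('sklearn_version')
        if saved and saved != sklearn.__version__:
            print(f"WARNING: model saved with sklearn {saved}, "
                  f"running {sklearn.__version__} — predictions may differ")
    except ImportError:
        pass  # version check skipped when sklearn is unavailable
    return meta

# Demonstration with a throwaway artifact directory
with tempfile.TemporaryDirectory() as d:
    (Path(d) / 'metadata.json').write_text(
        json.dumps({'n_features': 3, 'sklearn_version': '1.4.0'})
    )
    meta = load_metadata(d)
    print('model expects', meta['n_features'], 'features')
```

A version mismatch does not always break loading, which is exactly why a loud warning at startup is worth the three extra lines.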
- Scikit-learn: use joblib.dump() — it is the sklearn standard, handles numpy arrays efficiently, and compresses well. Pin the sklearn version in requirements.txt.
- TensorFlow: use model.save() for SavedModel format — it is portable, includes the computation graph, and works with TensorFlow Serving for high-throughput production deployment.
- PyTorch: save the state dict (model.state_dict()) paired with the model class definition. Lighter than saving the full model object and avoids pickle compatibility issues across PyTorch versions.

Flask: Simple ML API
Flask is the simplest way to turn a model into an HTTP API. You define routes, load the model at application startup, and return JSON predictions. The entire pattern fits in under 60 lines of code.
Flask is synchronous by default — each request blocks a worker process until the prediction completes and the response is sent. This is perfectly fine for ML inference, which is CPU-bound and does not benefit from async I/O anyway. The bottleneck is model.predict(), not network I/O. For handling concurrent requests, you add more worker processes with gunicorn rather than adding async complexity.
```python
from flask import Flask, request, jsonify
import joblib
import numpy as np
import json
import logging
from pathlib import Path

# Configure structured logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(name)s: %(message)s'
)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# ---------------------------------------------------------------
# LOAD MODEL AT MODULE LEVEL — not inside a route function.
# This code runs once when the process starts.
# The loaded objects persist in memory across all requests.
# ---------------------------------------------------------------
MODEL_DIR = Path(__file__).parent / 'model_artifacts'

try:
    model = joblib.load(MODEL_DIR / 'model.joblib')
    scaler = joblib.load(MODEL_DIR / 'scaler.joblib')
    with open(MODEL_DIR / 'metadata.json') as f:
        metadata = json.load(f)
    logger.info(
        f"Model loaded: {metadata['model_type']}, "
        f"{metadata['n_features']} features, "
        f"sklearn {metadata.get('sklearn_version', 'unknown')}"
    )
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise  # Fail fast — do not start serving with no model


@app.route('/health', methods=['GET'])
def health():
    """Health check for load balancers and container orchestrators.

    Returns 200 only when the model is loaded and ready.
    Kubernetes liveness probes and Docker HEALTHCHECK use this.
    """
    return jsonify({
        'status': 'healthy',
        'model_type': metadata['model_type'],
        'model_version': metadata.get('model_version', 'unknown'),
        'n_features': metadata['n_features']
    })


@app.route('/predict', methods=['POST'])
def predict():
    """Prediction endpoint.

    Expects JSON body: {"features": [value1, value2, ...]}
    Returns: {"prediction": [...], "probabilities": [[...]], "model_version": "..."}
    """
    try:
        data = request.get_json(force=True)

        # --- Input validation ---
        if data is None:
            return jsonify({'error': 'Request body must be valid JSON'}), 400
        if 'features' not in data:
            return jsonify({'error': 'Missing "features" key in request body'}), 400

        features_raw = data['features']
        if not isinstance(features_raw, list):
            return jsonify({'error': '"features" must be a list of numbers'}), 400

        # Check for null/None values
        if any(v is None for v in features_raw):
            return jsonify({'error': '"features" contains null values'}), 400

        features = np.array(features_raw, dtype=np.float64)

        # Check for NaN or Inf
        if not np.isfinite(features).all():
            return jsonify({'error': '"features" contains NaN or Infinity values'}), 400

        # Validate feature count
        expected = metadata['n_features']
        if features.ndim == 1:
            features = features.reshape(1, -1)
        if features.shape[1] != expected:
            return jsonify({
                'error': f'Expected {expected} features, got {features.shape[1]}',
                'expected_features': metadata.get('feature_names', [])
            }), 400

        # --- Preprocess and predict ---
        features_scaled = scaler.transform(features)
        prediction = model.predict(features_scaled)

        response = {
            'prediction': prediction.tolist(),
            'model_version': metadata.get('model_version', 'unknown')
        }
        # Include probabilities if the model supports them
        if hasattr(model, 'predict_proba'):
            probabilities = model.predict_proba(features_scaled)
            response['probabilities'] = probabilities.tolist()

        return jsonify(response)

    except ValueError as e:
        logger.warning(f"Validation error: {e}")
        return jsonify({'error': f'Invalid input: {str(e)}'}), 400
    except Exception as e:
        logger.error(f"Prediction failed: {e}", exc_info=True)
        return jsonify({'error': 'Internal server error'}), 500


if __name__ == '__main__':
    # Development only — never use this in production
    app.run(host='0.0.0.0', port=5000, debug=False)
```
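With the server running, the endpoint can be exercised from the command line. The feature values here are the illustrative age/income/credit_score triple used throughout this guide; adjust them to your own schema.

```shell
# Start the dev server, then in another terminal:
curl -X POST http://localhost:5000/predict \
  -H 'Content-Type: application/json' \
  -d '{"features": [35, 75000, 720]}'
# The response is JSON containing "prediction", "model_version",
# and "probabilities" when the model supports predict_proba.
```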
- 1. Request arrives: Flask parses the JSON body and calls your route function.
- 2. Input validation: Check for missing keys, wrong types, null values, NaN, and wrong feature count. Reject bad input before it reaches the model.
- 3. Preprocessing: Apply the same scaler.transform() used during training. This must be the exact same fitted scaler object — not a new one.
- 4. Prediction: model.predict() runs inference. This is the expensive step — everything else is microseconds.
- 5. Response: Serialize the numpy output to JSON with .tolist() and return with the appropriate HTTP status code.
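To serve the Flask app above with multiple worker processes, a gunicorn invocation looks like this. This is a sketch that assumes the module is named app.py and exposes the `app` object; the worker count and timeout are illustrative.

```shell
# 4 worker processes, each loading its own copy of the model at startup.
# --timeout 120 gives a worker time to load a large model before gunicorn kills it.
gunicorn --workers 4 --bind 0.0.0.0:5000 --timeout 120 app:app
```

Each worker is a separate process with its own copy of the model in memory, so four workers means four model loads at startup and roughly four times the RAM.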
The Flask development server (app.run()) is single-threaded and explicitly not designed for production use. It handles one request at a time and has no process management, no graceful shutdown, and no worker recycling. Never call app.run() in production; if you see app.run() in a Dockerfile CMD, that is a bug. Use a production WSGI server such as gunicorn, with multiple worker processes to handle concurrent calls to model.predict().

FastAPI: Modern ML API
FastAPI is the modern alternative to Flask with three advantages that matter specifically for ML APIs: Pydantic models provide automatic input validation (you define the expected structure once, and every request is validated before your code runs), OpenAPI documentation is generated automatically at /docs, and the async-native ASGI architecture handles concurrent health checks and monitoring requests without blocking prediction workers.
For the prediction logic itself — load model, preprocess, predict, return JSON — the code is nearly identical to Flask. The difference is in the boilerplate that FastAPI eliminates: input validation, error responses for malformed input, and API documentation are all handled by the framework rather than written by hand.
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, field_validator
import joblib
import numpy as np
import json
import logging
import math
from pathlib import Path
from typing import List, Optional

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(name)s: %(message)s'
)
logger = logging.getLogger(__name__)

app = FastAPI(
    title='ML Prediction API',
    description='Scikit-learn model serving endpoint with automatic input validation',
    version='1.0.0'
)

# ---------------------------------------------------------------
# LOAD MODEL AT STARTUP — same pattern as Flask
# ---------------------------------------------------------------
MODEL_DIR = Path(__file__).parent / 'model_artifacts'

try:
    model = joblib.load(MODEL_DIR / 'model.joblib')
    scaler = joblib.load(MODEL_DIR / 'scaler.joblib')
    with open(MODEL_DIR / 'metadata.json') as f:
        metadata = json.load(f)
    logger.info(
        f"Model loaded: {metadata['model_type']}, "
        f"{metadata['n_features']} features"
    )
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise

# ---------------------------------------------------------------
# PYDANTIC MODELS — define input/output schemas once.
# FastAPI validates every request against these automatically.
# Invalid requests get a clear 422 error before your code runs.
# ---------------------------------------------------------------
class PredictionRequest(BaseModel):
    """Input schema for prediction requests."""
    features: List[float] = Field(
        ...,
        description='List of numeric feature values in training order',
        min_length=1
    )

    @field_validator('features')
    @classmethod
    def validate_features(cls, v):
        """Reject NaN, Infinity, and wrong feature count."""
        for i, val in enumerate(v):
            if math.isnan(val) or math.isinf(val):
                raise ValueError(
                    f'Feature at index {i} is {val} — must be a finite number'
                )
        expected = metadata['n_features']
        if len(v) != expected:
            raise ValueError(
                f'Expected {expected} features, got {len(v)}. '
                f'Expected order: {metadata.get("feature_names", [])}'
            )
        return v


class PredictionResponse(BaseModel):
    """Output schema for prediction responses."""
    prediction: List[int]
    probabilities: Optional[List[List[float]]] = None
    model_version: str


class HealthResponse(BaseModel):
    """Output schema for health check."""
    status: str
    model_type: str
    model_version: str
    n_features: int


# ---------------------------------------------------------------
# ENDPOINTS
# ---------------------------------------------------------------
@app.get('/health', response_model=HealthResponse)
def health():
    """Health check for load balancers and orchestrators."""
    return HealthResponse(
        status='healthy',
        model_type=metadata['model_type'],
        model_version=metadata.get('model_version', 'unknown'),
        n_features=metadata['n_features']
    )


@app.post('/predict', response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Run model inference on the provided features.

    Input validation (type checking, NaN detection, feature count)
    is handled automatically by the Pydantic model above. By the time
    this function runs, the input is guaranteed valid.
    """
    try:
        features = np.array(request.features).reshape(1, -1)
        features_scaled = scaler.transform(features)
        prediction = model.predict(features_scaled)

        response = PredictionResponse(
            prediction=prediction.tolist(),
            model_version=metadata.get('model_version', 'unknown')
        )
        if hasattr(model, 'predict_proba'):
            probabilities = model.predict_proba(features_scaled)
            response.probabilities = probabilities.tolist()
        return response

    except Exception as e:
        logger.error(f"Prediction failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail='Internal server error')


# ---------------------------------------------------------------
# Run with: uvicorn fastapi_app:app --host 0.0.0.0 --port 8000 --workers 4
# Auto-generated docs available at: http://localhost:8000/docs
# ---------------------------------------------------------------
```
- Pydantic validates type, length, and custom constraints automatically on every request.
- Invalid requests never reach model.predict() — they are rejected at the framework level with a descriptive error message.
- The PredictionRequest class replaces 15–20 lines of manual validation code you would write in Flask.
- Custom validators (@field_validator) catch domain-specific issues like NaN values and wrong feature counts.
- The response_model parameter validates output structure too — catching serialization bugs before they reach the client.
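You can see this validation behavior without starting a server, because Pydantic models are plain Python classes. A minimal sketch, assuming Pydantic v2 (matching the field_validator syntax above); `MiniRequest` is a hypothetical schema for illustration, not the article's PredictionRequest.

```python
import math
from typing import List
from pydantic import BaseModel, Field, ValidationError, field_validator

class MiniRequest(BaseModel):
    # Hypothetical schema for illustration
    features: List[float] = Field(..., min_length=1)

    @field_validator('features')
    @classmethod
    def must_be_finite(cls, v):
        if any(math.isnan(x) or math.isinf(x) for x in v):
            raise ValueError('features must be finite numbers')
        return v

# Valid input parses cleanly
ok = MiniRequest(features=[35.0, 75000.0, 720.0])

# Invalid input raises ValidationError before any handler code runs
try:
    MiniRequest(features=[1.0, float('nan')])
    raised = False
except ValidationError:
    raised = True
print('rejected NaN:', raised)
```

Inside FastAPI, that same ValidationError becomes a 422 response with a field-level error message, with no try/except in your handler.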
Input Validation: The Most Important Step
A model that silently accepts bad input and returns a confident prediction is more dangerous than one that crashes. A crash is visible, immediate, and fixable. A wrong prediction based on garbage input feeds silently into downstream business decisions — pricing, credit approvals, medical recommendations — without anyone noticing until the damage is done.
Input validation is the guard that prevents this. Every prediction request should be checked for structural validity (correct JSON, expected keys), type validity (numbers are actually numbers), value validity (no NaN, no Infinity, no negative ages), and dimensional validity (correct number of features in the correct order).
```python
import numpy as np
from typing import List, Optional


class InputValidator:
    """Validate prediction inputs before they reach the model.

    Every check that fails here is a prediction error prevented.
    A model will happily predict on a credit score of 99999 or an
    age of -5. The validator catches these before the model sees them.
    """

    def __init__(self, metadata: dict, feature_ranges: Optional[dict] = None):
        self.n_features = metadata['n_features']
        self.feature_names = metadata.get('feature_names', [])
        self.feature_ranges = feature_ranges or {}

    def validate(self, features: List[float]) -> np.ndarray:
        """Validate and convert input features.

        Returns a numpy array ready for preprocessing.
        Raises ValueError with a clear message if validation fails.
        """
        # Check type
        if not isinstance(features, list):
            raise ValueError(
                f'"features" must be a list, got {type(features).__name__}'
            )

        # Check length
        if len(features) != self.n_features:
            raise ValueError(
                f'Expected {self.n_features} features, got {len(features)}. '
                f'Expected: {self.feature_names}'
            )

        # Check for null/None values
        for i, val in enumerate(features):
            if val is None:
                name = self.feature_names[i] if i < len(self.feature_names) else f'index {i}'
                raise ValueError(f'Feature "{name}" is null')

        # Convert to numpy and check for NaN/Inf
        arr = np.array(features, dtype=np.float64)
        if not np.isfinite(arr).all():
            bad_indices = np.where(~np.isfinite(arr))[0]
            bad_names = [
                self.feature_names[i] if i < len(self.feature_names) else f'index {i}'
                for i in bad_indices
            ]
            raise ValueError(
                f'Non-finite values in features: {bad_names}. '
                f'Values must be finite numbers (no NaN, no Infinity).'
            )

        # Check feature ranges (domain validation)
        for feature_name, (min_val, max_val) in self.feature_ranges.items():
            if feature_name in self.feature_names:
                idx = self.feature_names.index(feature_name)
                val = arr[idx]
                if val < min_val or val > max_val:
                    raise ValueError(
                        f'Feature "{feature_name}" value {val} is outside '
                        f'expected range [{min_val}, {max_val}]. '
                        f'This may indicate a data pipeline error.'
                    )

        return arr.reshape(1, -1)


# --- Example usage ---
validator = InputValidator(
    metadata={'n_features': 3, 'feature_names': ['age', 'income', 'credit_score']},
    feature_ranges={
        'age': (0, 120),
        'income': (0, 10_000_000),
        'credit_score': (300, 850)
    }
)

# These will raise clear ValueError messages:
# validator.validate([25, 50000])                # Wrong count
# validator.validate([25, 50000, None])          # Null value
# validator.validate([25, 50000, float('nan')])  # NaN value
# validator.validate([-5, 50000, 720])           # Age out of range
# validator.validate([25, 50000, 8500])          # Credit score out of range (typo)

# This will pass:
result = validator.validate([35, 75000, 720])
print(f'Validated input shape: {result.shape}')  # (1, 3)
```
Reject invalid input before it ever reaches model.predict().

Dockerizing Your ML API
Docker packages your model, application code, Python dependencies, and runtime environment into a single portable container. This guarantees the API runs identically on your laptop, a colleague's machine, a CI/CD pipeline, and production servers. No more 'works on my machine' — the container is the machine.
The Dockerfile must include your model artifact files, Python dependencies pinned to exact versions, and the correct entry point command for your production server (gunicorn or uvicorn — never app.run() or uvicorn --reload).
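A matching requirements.txt might look like the following. The version numbers are illustrative only; pin whatever `pip freeze` reports in your actual training environment, so the container matches training exactly.

```
fastapi==0.110.0
uvicorn[standard]==0.29.0
scikit-learn==1.4.2
numpy==1.26.4
joblib==1.3.2
pydantic==2.6.4
```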
```dockerfile
# Dockerfile for ML API deployment
# Uses a slim Python base image to minimize attack surface and image size
FROM python:3.11-slim

# Set working directory inside the container
WORKDIR /app

# Install system dependencies (curl is needed for the HEALTHCHECK below)
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies first — separate layer for Docker cache.
# When you change only application code, this layer is cached.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app/ ./app/

# Copy model artifacts
# For models >500MB, consider downloading from S3/GCS at startup instead
COPY model_artifacts/ ./model_artifacts/

# Create a non-root user for security
RUN useradd --create-home appuser
USER appuser

# Health check — container orchestrators use this to determine readiness
HEALTHCHECK --interval=30s --timeout=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Expose the port (documentation — does not actually publish the port)
EXPOSE 8000

# Run with production ASGI server
# --workers 2: spawn 2 processes (adjust based on container CPU allocation)
# --timeout-keep-alive 120: keep idle client connections open for up to 120s
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2", "--timeout-keep-alive", "120"]
```
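Building and smoke-testing the image looks like this; the image tag is a hypothetical example.

```shell
docker build -t ml-api:1.0.0 .
docker run --rm -p 8000:8000 ml-api:1.0.0
# In another terminal, verify the container reports healthy:
curl http://localhost:8000/health
```

If the health check fails here, fix it now; a broken /health endpoint means orchestrators will kill and restart your container in a loop in production.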
Testing Your ML API
ML APIs need two distinct types of tests, and most teams only write the first type.
The first type is standard API tests: does the endpoint return the correct status codes? Does it reject invalid input with 400? Does it return properly structured JSON? These catch HTTP-level bugs.
The second type — and the more important one — is model consistency tests: does the endpoint return the exact same predictions as running the model directly in a notebook for identical input? This catches the silent bugs: a preprocessing step that was applied differently, a scaler that was re-fitted instead of loaded, features received in the wrong order. These bugs produce valid HTTP 200 responses with wrong predictions, and they are invisible without explicit consistency testing.
```python
import pytest
import numpy as np

# FastAPI test client (for Flask, use app.test_client())
from fastapi.testclient import TestClient
from app.main import app, model, scaler, metadata

client = TestClient(app)


class TestHealthEndpoint:
    """Verify the health check works before testing predictions."""

    def test_health_returns_200(self):
        response = client.get('/health')
        assert response.status_code == 200
        data = response.json()
        assert data['status'] == 'healthy'
        assert 'model_type' in data
        assert data['n_features'] == metadata['n_features']


class TestInputValidation:
    """Verify that invalid inputs are rejected with clear errors."""

    def test_empty_body_returns_422(self):
        response = client.post('/predict', json={})
        assert response.status_code == 422  # Pydantic validation error

    def test_missing_features_key_returns_422(self):
        response = client.post('/predict', json={'data': [1, 2, 3]})
        assert response.status_code == 422

    def test_wrong_feature_count_returns_422(self):
        response = client.post('/predict', json={'features': [1.0, 2.0]})
        assert response.status_code == 422

    def test_nan_features_returns_422(self):
        n = metadata['n_features']
        features = [float('nan')] + [1.0] * (n - 1)
        response = client.post('/predict', json={'features': features})
        assert response.status_code == 422

    def test_null_features_returns_422(self):
        n = metadata['n_features']
        features = [None] + [1.0] * (n - 1)
        response = client.post('/predict', json={'features': features})
        assert response.status_code == 422

    def test_string_features_returns_422(self):
        response = client.post('/predict', json={'features': ['a', 'b', 'c']})
        assert response.status_code == 422


class TestPredictEndpoint:
    """Verify that valid inputs produce correct responses."""

    def test_valid_input_returns_200_with_prediction(self):
        n = metadata['n_features']
        features = [35.0, 75000.0, 720.0][:n]
        response = client.post('/predict', json={'features': features})
        assert response.status_code == 200
        data = response.json()
        assert 'prediction' in data
        assert 'model_version' in data
        assert isinstance(data['prediction'], list)

    def test_probabilities_sum_to_one(self):
        n = metadata['n_features']
        features = [35.0, 75000.0, 720.0][:n]
        response = client.post('/predict', json={'features': features})
        data = response.json()
        if 'probabilities' in data and data['probabilities']:
            prob_sum = sum(data['probabilities'][0])
            assert abs(prob_sum - 1.0) < 0.01, f'Probabilities sum to {prob_sum}'


class TestPredictionConsistency:
    """THE MOST IMPORTANT TEST CLASS.

    Verify API predictions match direct model predictions on identical
    input. This catches preprocessing mismatches that produce valid
    HTTP 200 responses with wrong predictions — the most dangerous
    class of deployment bugs.
    """

    @pytest.fixture
    def sample_inputs(self):
        """Multiple test cases to cover different input ranges."""
        n = metadata['n_features']
        return [
            [35.0, 75000.0, 720.0][:n],
            [22.0, 30000.0, 580.0][:n],
            [65.0, 150000.0, 800.0][:n],
            [0.0, 0.0, 300.0][:n],  # Edge case: minimum values
        ]

    def test_api_matches_direct_model(self, sample_inputs):
        """Run identical input through both API and raw model."""
        for features in sample_inputs:
            # API prediction
            response = client.post('/predict', json={'features': features})
            assert response.status_code == 200
            api_prediction = response.json()['prediction']

            # Direct model prediction with same preprocessing
            raw_input = np.array(features).reshape(1, -1)
            scaled_input = scaler.transform(raw_input)
            model_prediction = model.predict(scaled_input).tolist()

            assert api_prediction == model_prediction, (
                f'Consistency check failed for input {features}: '
                f'API returned {api_prediction}, '
                f'direct model returned {model_prediction}'
            )
```
- Write a test that runs identical input through both the API endpoint and the raw model with the same preprocessing.
- Compare predictions exactly — for classification, predictions should match perfectly. For regression, use np.allclose() with a tight tolerance.
- Run this test in CI/CD on every commit. If it fails, the preprocessing pipeline in the API has diverged from training.
- Test with multiple input ranges: typical values, edge cases (zeros, minimum values), and boundary values. A mismatch on edge cases often reveals normalization bugs.
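The exact-vs-tolerance distinction can be captured in a small helper. This is a sketch; `predictions_match` is a hypothetical function, not part of the article's test suite, and the tolerances are illustrative.

```python
import numpy as np

def predictions_match(api_out, direct_out, task='classification'):
    """Compare API predictions against direct model predictions.

    Classification labels must match exactly. Regression outputs are
    compared with a tight floating-point tolerance, since JSON
    round-tripping can introduce tiny float differences.
    """
    if task == 'classification':
        return list(api_out) == list(direct_out)
    return np.allclose(api_out, direct_out, rtol=1e-9, atol=1e-12)

# Labels must be identical
assert predictions_match([1, 0, 1], [1, 0, 1])
# Regression tolerates only float round-trip noise, nothing larger
assert predictions_match([102.4, 87.1], [102.4 + 1e-13, 87.1], task='regression')
assert not predictions_match([102.4], [103.0], task='regression')
```

Keep the tolerance tight: a loose tolerance will happily mask a real preprocessing divergence, which is exactly the bug this test exists to catch.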
Monitoring and Logging in Production
Once deployed, you need visibility into what the API is doing. A model that was accurate during evaluation can degrade silently in production as the input data distribution shifts. Without monitoring, you discover this from customer complaints weeks later.
Log every prediction request with a timestamp, input summary (not the full input — that may contain PII), output class, latency, and any errors. Aggregate these logs into metrics that surface problems early: latency percentiles, error rates, and prediction class distribution over time.
```python
import logging
import time
import json
from collections import Counter, deque
from datetime import datetime, timezone
from typing import Any

logger = logging.getLogger('ml_api')


class PredictionMonitor:
    """Track prediction metrics for production monitoring.

    Collects latency, error rates, and prediction distribution.
    Expose via a /metrics endpoint for Prometheus scraping or a
    /stats endpoint for manual inspection.

    In production, replace this with proper observability tools
    (Prometheus + Grafana, Datadog, or OpenTelemetry). This class
    demonstrates the metrics you should be collecting regardless
    of the tool you use.
    """

    def __init__(self, max_latency_samples: int = 10000):
        self.total_requests = 0
        self.error_count = 0
        self.latency_ms: deque = deque(maxlen=max_latency_samples)
        self.prediction_distribution = Counter()
        self._start_time = datetime.now(timezone.utc)

    def log_request(
        self,
        features_summary: dict,
        prediction: Any,
        latency_ms: float,
        status: str = 'success'
    ):
        """Log a single prediction request.

        Args:
            features_summary: non-PII summary of input
                (e.g., feature count, value ranges)
            prediction: the model's output
            latency_ms: wall-clock inference time in milliseconds
            status: 'success' or 'error'
        """
        self.total_requests += 1
        self.latency_ms.append(latency_ms)
        if status == 'error':
            self.error_count += 1
        if prediction is not None:
            pred_key = str(prediction)
            self.prediction_distribution[pred_key] += 1

        # Structured JSON log — ingestible by log aggregation systems
        logger.info(json.dumps({
            'event': 'prediction',
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'status': status,
            'latency_ms': round(latency_ms, 2),
            'prediction': prediction,
            'feature_count': features_summary.get('count', 0),
        }))

    def get_stats(self) -> dict:
        """Return summary statistics for the /stats endpoint."""
        if not self.latency_ms:
            return {
                'message': 'No requests recorded yet',
                'uptime_seconds': (
                    datetime.now(timezone.utc) - self._start_time
                ).total_seconds()
            }

        sorted_latencies = sorted(self.latency_ms)
        n = len(sorted_latencies)
        return {
            'total_requests': self.total_requests,
            'error_count': self.error_count,
            'error_rate_pct': round(
                self.error_count / max(self.total_requests, 1) * 100, 2
            ),
            'latency_p50_ms': round(sorted_latencies[n // 2], 2),
            'latency_p95_ms': round(sorted_latencies[int(n * 0.95)], 2),
            'latency_p99_ms': round(sorted_latencies[int(n * 0.99)], 2),
            'prediction_distribution': dict(
                self.prediction_distribution.most_common(20)
            ),
            'uptime_seconds': round(
                (datetime.now(timezone.utc) - self._start_time).total_seconds()
            )
        }


# Singleton instance — import and use across the application
monitor = PredictionMonitor()


# --- Example integration with FastAPI ---
# @app.post('/predict')
# def predict(request: PredictionRequest):
#     start = time.perf_counter()
#     try:
#         prediction = model.predict(preprocessed)
#         latency = (time.perf_counter() - start) * 1000
#         monitor.log_request(
#             features_summary={'count': len(request.features)},
#             prediction=prediction.tolist(),
#             latency_ms=latency
#         )
#         return ...
#     except Exception as e:
#         latency = (time.perf_counter() - start) * 1000
#         monitor.log_request(
#             features_summary={'count': len(request.features)},
#             prediction=None,
#             latency_ms=latency,
#             status='error'
#         )
#         raise
#
# @app.get('/stats')
# def stats():
#     return monitor.get_stats()
```
Flask vs FastAPI: Choosing Your Framework
Both frameworks serve ML models effectively. The model serving logic — load at startup, preprocess, predict, return JSON — is identical in both. The choice depends on your team's existing experience, whether you want automatic input validation, and whether you need auto-generated API documentation.
If you already know Flask and your API is straightforward, use Flask. If you are starting a new project and want the framework to handle input validation and documentation for you, use FastAPI. Do not over-think this decision — you can always migrate later, and the deployment patterns taught in this guide apply to both.
- If your team knows Flask: use Flask. The deployment concepts in this guide apply identically.
- If you are starting fresh with no framework preference: use FastAPI. Pydantic validation and auto-generated /docs are genuinely useful features for ML APIs.
- If you need async for non-ML reasons (WebSocket support, streaming responses): FastAPI is the clear choice.
- If you are integrating into an existing Flask application: stay with Flask. Do not introduce a second framework.
model.predict() dominates request time regardless of whether Flask or FastAPI handles the HTTP layer: the model, not the framework, is the bottleneck.

| Aspect | Flask | FastAPI |
|---|---|---|
| Learning Curve | Gentle — most Python developers already know Flask | Moderate — Pydantic models and async/await add new concepts |
| Input Validation | Manual — you write all validation logic by hand | Automatic — Pydantic models validate type, length, and custom rules on every request |
| API Documentation | Manual or via flask-swagger extension | Automatic — /docs (Swagger UI) and /redoc generated from Pydantic models |
| Performance | Synchronous WSGI — solid for CPU-bound ML inference | Async ASGI — 2–3x throughput for I/O-mixed workloads, same for pure CPU inference |
| ML Ecosystem | Huge — most existing tutorials and examples use Flask | Growing rapidly — becoming the default for new ML API projects |
| Production Server | gunicorn (WSGI): gunicorn -w 4 app:app | uvicorn (ASGI): uvicorn app:app --workers 4 |
| Error Responses | Manual JSON error formatting | Automatic 422 responses for validation failures with field-level error details |
| Best For | Simple APIs, existing Flask codebases, teams that know Flask | New projects, teams that want auto-validation and auto-documentation |
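To make the "manual validation" row concrete, here is a hedged sketch of the checks a Flask handler must perform by hand before calling the model — the same checks Pydantic runs automatically in FastAPI. `N_FEATURES` and `FEATURE_RANGE` are illustrative assumptions; replace them with your model's real metadata.

```python
# Manual input validation, Flask-style: every check written by hand.
# N_FEATURES and FEATURE_RANGE are assumptions for illustration.
import math

N_FEATURES = 4
FEATURE_RANGE = (-1e6, 1e6)  # assumption: replace with domain-valid bounds


def validate_payload(payload):
    """Return (ok, error_message); ok is False on the first failed check."""
    if not isinstance(payload, dict):
        return False, "Request body must be a JSON object"
    features = payload.get("features")
    if not isinstance(features, list):
        return False, "'features' must be a list"
    if len(features) != N_FEATURES:
        return False, f"Expected {N_FEATURES} features, got {len(features)}"
    for i, value in enumerate(features):
        # bool is a subclass of int, so exclude it explicitly
        if not isinstance(value, (int, float)) or isinstance(value, bool):
            return False, f"Feature {i} must be numeric"
        if math.isnan(value) or math.isinf(value):
            return False, f"Feature {i} is NaN or infinite"
        lo, hi = FEATURE_RANGE
        if not (lo <= value <= hi):
            return False, f"Feature {i} is out of the valid range"
    return True, None
```

In a Flask route you would call this first and return a 400 with the error message when `ok` is False; FastAPI generates the equivalent checks (and a 422 response) from the Pydantic model definition.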
🎯 Key Takeaways
- Load the model at application startup, never inside the request handler — one load per process, reused across all requests.
- Save model, scaler, encoder, and metadata together as a single versioned artifact. A model without its preprocessing pipeline produces silently wrong predictions.
- Always validate inputs before prediction: check structure, types, null values, NaN, feature count, and domain-valid ranges.
- Use gunicorn (Flask) or uvicorn (FastAPI) with multiple workers for production. Never deploy with the built-in development server.
- Write consistency tests that compare API predictions to notebook predictions on identical input — this single test class catches more production bugs than all other tests combined.
- Add a /health endpoint that confirms the model is loaded. Container orchestrators and load balancers depend on it to route traffic correctly.
- Log every prediction with timestamp, latency, and output class. Monitor prediction distribution over time — sudden shifts are the earliest signal of data drift.
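The /health endpoint from the takeaways can be sketched in a framework-agnostic way. The route decorator (Flask's `@app.route` or FastAPI's `@app.get`) is omitted, and `MODEL_VERSION` is an illustrative assumption; in your actual route, pair `"ok"` with HTTP 200 and `"unavailable"` with 503, since orchestrators key off the status code.

```python
# Framework-agnostic /health payload sketch.
# MODEL_VERSION is an assumption; read it from your artifact metadata.
from datetime import datetime, timezone

MODEL_VERSION = "1.0.0"


def health_status(model):
    """Return the dict a /health route should serialize to JSON."""
    loaded = model is not None
    return {
        "status": "ok" if loaded else "unavailable",
        "model_loaded": loaded,
        "model_version": MODEL_VERSION if loaded else None,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
```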
Interview Questions on This Topic
- Q: How would you deploy a scikit-learn model as a production API? (Mid-level)
- Q: What is the difference between WSGI and ASGI, and why does it matter for ML APIs? (Senior)
- Q: Your model API returns correct HTTP 200 responses but the predictions are wrong compared to the notebook. How do you debug this? (Mid-level)
- Q: How do you handle model updates in a deployed API without downtime? (Senior)
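The "correct 200s, wrong predictions" scenario is exactly what a consistency test catches. Here is a hedged sketch; `api_predict` and `notebook_predict` are stand-ins for your real call paths (e.g. an HTTP request through the test client versus `model.predict` on the raw pipeline), and the tolerance is an assumption.

```python
# Consistency test sketch: the API's full preprocess + predict path must
# match the notebook's on identical input rows. The two callables are
# illustrative stand-ins for your real API and notebook prediction paths.

def assert_consistent(api_predict, notebook_predict, sample_rows, tol=1e-9):
    """Fail loudly on the first row where the two paths disagree."""
    for i, row in enumerate(sample_rows):
        api_out = api_predict(row)
        nb_out = notebook_predict(row)
        assert abs(api_out - nb_out) <= tol, (
            f"Row {i}: API predicted {api_out}, notebook predicted {nb_out}. "
            "Check that the API applies the same scaler/encoder as the notebook."
        )
```

A disagreement here almost always means the API is missing (or duplicating) a preprocessing step that the notebook applied before training.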
Frequently Asked Questions
Should I use Flask or FastAPI for my first ML API?
If you already know Flask, use Flask — the deployment concepts and patterns matter far more than the framework choice, and they are identical in both. If you are starting fresh with no framework preference, use FastAPI — automatic input validation with Pydantic eliminates the most common source of production bugs (malformed input reaching the model), and the auto-generated /docs endpoint gives you free, always-in-sync API documentation.
Both frameworks deploy ML models equally well. The model serving logic (load, preprocess, predict, return JSON) is identical in both. The difference is in how much boilerplate the framework handles for you versus how much you write by hand.
How do I handle large model files (over 1 GB) in deployment?
Do not bake them into Docker images — a 3 GB Docker image is slow to pull (60+ seconds), slow to schedule, and slow to roll back. Instead, store model files in cloud object storage (S3, GCS, Azure Blob) and download them at container startup using an entrypoint script.
Cache the downloaded model to a local directory (or a persistent volume) so subsequent container restarts do not re-download. Implement a version check at startup: compare the locally cached version against the latest version in storage, and download only when a new version is available.
For very large models (5+ GB), consider using TensorFlow Serving, Triton Inference Server, or a dedicated model serving platform that handles model loading, caching, and versioning as first-class concerns.
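The download-and-cache logic above can be sketched as follows. `fetch_fn` is an injected stand-in for the real object-storage call (with boto3 it would be `s3_client.download_file(bucket, key, dest_path)`); the `VERSION` file and artifact name are illustrative assumptions.

```python
# Startup model caching sketch: download only when the cached version is
# stale. fetch_fn stands in for a real cloud-storage download call.
import os


def ensure_model(cache_dir, remote_version, fetch_fn):
    """Return the local path to an up-to-date model file.

    fetch_fn(dest_path) must write the model artifact to dest_path.
    """
    os.makedirs(cache_dir, exist_ok=True)
    model_path = os.path.join(cache_dir, "model.joblib")    # assumed name
    version_path = os.path.join(cache_dir, "VERSION")

    cached_version = None
    if os.path.exists(version_path):
        with open(version_path) as f:
            cached_version = f.read().strip()

    # Re-download only on a version mismatch or a missing artifact
    if cached_version != remote_version or not os.path.exists(model_path):
        fetch_fn(model_path)
        with open(version_path, "w") as f:
            f.write(remote_version)
    return model_path
```

Point `cache_dir` at a persistent volume and container restarts become cache hits; publishing a new version in storage triggers exactly one download per instance.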
Can I serve multiple models from one API?
Yes. The simplest approach: define separate routes for each model (/predict/churn, /predict/fraud) and load all models at startup, storing them in a dictionary keyed by model name. This works well for 2–5 small models that fit comfortably in memory.
For more models or larger models, implement a model registry pattern: maintain a configuration file that maps model names to their artifact paths, load models on first request (lazy loading), and evict least-recently-used models when memory exceeds a threshold.
The critical trade-off is memory. Each loaded scikit-learn model consumes RAM proportional to its size. A random forest with 1,000 trees on 100 features might consume 500 MB. Three such models in one process = 1.5 GB. Plan your container memory limits accordingly.
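The registry pattern with lazy loading and LRU eviction can be sketched like this. `load_fn` is a stand-in for `joblib.load` on the artifact path, and `max_loaded` is an illustrative stand-in for a real memory-based threshold.

```python
# Model registry sketch: lazy loading on first request, LRU eviction when
# too many models are resident. load_fn stands in for joblib.load.
from collections import OrderedDict


class ModelRegistry:
    def __init__(self, artifact_paths, load_fn, max_loaded=2):
        self._paths = dict(artifact_paths)   # model name -> artifact path
        self._load_fn = load_fn
        self._max_loaded = max_loaded
        self._loaded = OrderedDict()         # name -> model, in LRU order

    def get(self, name):
        if name not in self._paths:
            raise KeyError(f"Unknown model: {name}")
        if name in self._loaded:
            self._loaded.move_to_end(name)   # mark as recently used
            return self._loaded[name]
        model = self._load_fn(self._paths[name])  # lazy load on first request
        self._loaded[name] = model
        if len(self._loaded) > self._max_loaded:
            self._loaded.popitem(last=False)      # evict least-recently-used
        return model
```

A /predict/{model_name} route then reduces to `registry.get(model_name).predict(...)`; the first request for an evicted model pays the reload cost, so size `max_loaded` to keep your hottest models resident.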
How do I secure my ML API?
Layer your security. First, use HTTPS everywhere — never serve ML APIs over plain HTTP, even for internal tools. Second, add API key authentication: require an X-API-Key header on every request, validated against a stored set of keys. Rotate keys periodically.
Third, rate-limit requests to prevent abuse and runaway costs — 100 requests per minute per API key is a reasonable default for most ML APIs. Fourth, never expose model internals (feature weights, architecture details, training data statistics) in API responses — return only predictions and confidence scores.
For sensitive models in regulated industries, add input logging for audit trails (with PII redaction), implement IP allowlisting, and consider deploying behind a VPN or API gateway rather than exposing directly to the internet.
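Two of the layers above, API-key checking and per-key rate limiting, can be sketched in plain Python. The in-memory key set and the 100-per-minute limit are illustrative assumptions; production systems typically push both concerns to an API gateway or a shared store like Redis so limits hold across workers.

```python
# API-key check plus a per-key sliding-window rate limiter (sketch).
# VALID_KEYS and the limits are assumptions; real keys belong in secret
# storage, and multi-worker deployments need a shared store.
import hmac
import time
from collections import defaultdict, deque

VALID_KEYS = {"demo-key-123"}       # assumption: load from secret storage
RATE_LIMIT = 100                     # requests allowed per window
WINDOW_SECONDS = 60

_request_log = defaultdict(deque)    # api_key -> recent request timestamps


def check_api_key(provided_key):
    """Constant-time comparison against each stored key."""
    return any(hmac.compare_digest(provided_key, k) for k in VALID_KEYS)


def allow_request(api_key, now=None):
    """Return True if this key is still under the rate limit."""
    now = time.monotonic() if now is None else now
    window = _request_log[api_key]
    while window and now - window[0] > WINDOW_SECONDS:
        window.popleft()             # drop timestamps outside the window
    if len(window) >= RATE_LIMIT:
        return False
    window.append(now)
    return True
```

In the route handler, reject a failed `check_api_key` with 401 and a failed `allow_request` with 429 before any model work happens, so abusive traffic never touches the expensive predict path.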
Developer and founder of TheCodeForge. I built this site because I was tired of tutorials that explain what to type without explaining why it works. Every article here is written to make concepts actually click.