How to Visualize Machine Learning Results (Matplotlib & Seaborn)
- Matplotlib is the foundation — every chart in Python builds on its figure/axes model
- Seaborn wraps Matplotlib with statistical defaults and far less boilerplate code
- Confusion matrices, ROC curves, and residual plots reveal model flaws numbers hide
- Use fig.savefig() at 300 DPI — screen-resolution plots break in reports and slides
- Production rule: never present raw accuracy alone — always pair with precision, recall, or error distribution
- Biggest mistake: choosing the wrong chart type for the data relationship you want to communicate
- Always call plt.close(fig) after saving — open figures leak memory and crash long-running pipelines
Production Debug Guide
When your charts do not reveal what you expect — or when they reveal something you did not anticipate.
Symptom: the confusion matrix shows all predictions landing in one class. Check class balance and the raw prediction counts:
print(f'Positive predictions: {y_pred.sum()} / {len(y_pred)}')
print(df['target'].value_counts(normalize=True))
Symptom: the learning curve shows the training score much higher than the validation score. The model is overfitting — compute the curve and inspect the gap:
from sklearn.model_selection import learning_curve
train_sizes, train_scores, val_scores = learning_curve(
    model, X, y, cv=5, train_sizes=np.linspace(0.1, 1.0, 10)
)
Symptom: the feature importance plot shows one dominant feature at 95%+. Suspect data leakage — check feature-target correlations, then retrain without the feature and compare:
print(df.corrwith(df['target']).abs().sort_values(ascending=False).head(10))
# Retrain without the suspicious feature and compare performance
model_no_leak = model.fit(X.drop(columns=['suspicious_feature']), y)
Save before you show: calling fig.savefig() after plt.show() saves a blank file with no error message, because plt.show() destroys the figure in most backends. Use fig.savefig('name.png', dpi=300, bbox_inches='tight') — the bbox_inches parameter prevents label clipping. Set figsize explicitly in plt.subplots() rather than relying on notebook defaults.
Model metrics like accuracy and F1-score tell you the score. Visualizations tell you why. A confusion matrix shows exactly which classes your model confuses. A residual plot reveals systematic prediction errors that RMSE averages away. A learning curve tells you whether collecting more data will help or whether you need a fundamentally different model. These are not decorative — they are diagnostic tools.
Matplotlib provides the rendering engine. Seaborn provides statistical awareness on top of it. You need both: Matplotlib for full control over publication-quality figures, and Seaborn for rapid exploratory analysis with sensible defaults. They are not competitors — Seaborn is literally built on Matplotlib, and every Seaborn plot returns a Matplotlib axes object you can customize further.
The common mistake is treating visualization as an afterthought — something you do after the model is trained and shipped. In production, a well-designed diagnostic dashboard catches model degradation weeks before aggregate metrics move. The charts you build during evaluation become your monitoring tools after deployment. Skip them, and you are flying blind.
Matplotlib Fundamentals: Figure and Axes
Every Matplotlib chart lives inside a Figure that contains one or more Axes. The Figure is the canvas — it controls overall dimensions, background color, and file output. The Axes is the actual plot area with its own x-axis, y-axis, title, and data layers.
Understanding this hierarchy prevents 90% of the layout confusion beginners hit. When you call plt.plot(), Matplotlib implicitly creates a Figure and Axes behind the scenes. This works for quick exploration but falls apart the moment you need multiple subplots, consistent sizing, or saved files. The object-oriented interface — fig, ax = plt.subplots() — gives you explicit handles to both objects and should be your default for anything beyond throwaway exploration.
import matplotlib.pyplot as plt
import numpy as np

# --- Method 1: pyplot interface (quick exploration only) ---
# Implicitly creates a Figure and Axes. Fine for throwaway cells.
plt.plot([1, 2, 3], [4, 5, 6])
plt.title('Simple Line Plot')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.show()

# --- Method 2: object-oriented interface (production standard) ---
# Explicitly creates Figure and Axes. Use this for everything you save.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot([1, 2, 3], [4, 5, 6], marker='o', linewidth=2, label='Series A')
ax.set_title('Production-Ready Line Plot', fontsize=14, fontweight='bold')
ax.set_xlabel('X Axis')
ax.set_ylabel('Y Axis')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig('plot.png', dpi=300, bbox_inches='tight')
plt.close(fig)  # Free memory — critical in loops and pipelines

# --- Multi-panel figure: the pattern you will use most ---
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
np.random.seed(42)
data = np.random.randn(200)

# Panel 1: Distribution
axes[0, 0].hist(data, bins=30, edgecolor='black', alpha=0.7, color='steelblue')
axes[0, 0].set_title('Distribution')
axes[0, 0].set_xlabel('Value')
axes[0, 0].set_ylabel('Frequency')

# Panel 2: Sequential scatter
axes[0, 1].scatter(np.arange(len(data)), data, alpha=0.4, s=12, color='coral')
axes[0, 1].axhline(y=0, color='black', linestyle='--', alpha=0.3)
axes[0, 1].set_title('Sequential Scatter')
axes[0, 1].set_xlabel('Index')

# Panel 3: Box plot
axes[1, 0].boxplot(data, vert=True, patch_artist=True,
                   boxprops=dict(facecolor='lightblue'))
axes[1, 0].set_title('Box Plot')

# Panel 4: Cumulative sum
axes[1, 1].plot(np.cumsum(data), color='seagreen', linewidth=1.5)
axes[1, 1].set_title('Cumulative Sum')
axes[1, 1].set_xlabel('Index')

fig.suptitle('Exploratory Data Summary', fontsize=16, fontweight='bold')
fig.tight_layout()
fig.savefig('dashboard.png', dpi=300, bbox_inches='tight')
plt.close(fig)
- Figure = the full canvas. Controls overall size (figsize), background, DPI, and file saving.
- Axes = one plot area. Has its own x-axis, y-axis, title, legend, and data layers. A Figure can hold many Axes.
- fig, ax = plt.subplots() creates one Figure with one Axes. This is your starting point for every chart.
- fig, axes = plt.subplots(2, 3) creates a 2×3 grid. Access individual plots with axes[row, col].
- Always use the object-oriented interface (ax.plot, ax.set_title) for anything you save or present. The pyplot interface (plt.plot, plt.title) operates on an implicit 'current axes' that causes bugs in multi-panel figures.
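The "current axes" pitfall is easy to demonstrate. A minimal sketch (using the headless Agg backend so it runs anywhere): plt.title() silently lands on whichever axes Matplotlib considers current, here the last one created, not the panel you probably meant.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so this runs without a display
import matplotlib.pyplot as plt

# Build a 2-panel figure the object-oriented way
fig, axes = plt.subplots(1, 2)
axes[0].plot([1, 2, 3])
axes[1].plot([3, 2, 1])

# plt.title() targets the implicit "current axes": the LAST axes created
plt.title('Which panel gets this?')
assert axes[1].get_title() == 'Which panel gets this?'
assert axes[0].get_title() == ''  # the first panel was silently skipped

# The object-oriented call is unambiguous
axes[0].set_title('First panel')
assert axes[0].get_title() == 'First panel'
plt.close(fig)
```

This is exactly the class of bug the object-oriented interface eliminates: the pyplot call succeeded with no warning, it just titled the wrong panel.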
- If you call plt.show() before fig.savefig(), you save a blank file with no error message. Call fig.savefig() first, plt.show() second — or skip plt.show() entirely in automated pipelines.
- In long-running pipelines, do not call plt.show(). Use fig.savefig() and plt.close(fig) to render and release memory. Open figures accumulate and will eventually crash long-running processes.
- fig, ax = plt.subplots() is your starting point for every chart — no exceptions for production code.
Seaborn for Statistical Visualization
Seaborn builds on Matplotlib with high-level functions that understand DataFrames natively. Pass column names directly, and Seaborn handles grouping, aggregation, statistical estimation, and legend creation automatically. Where Matplotlib requires 20 lines for a grouped bar chart with confidence intervals, Seaborn does it in 3.
The key insight is that Seaborn is not a replacement for Matplotlib — it is an accelerator for the statistical plotting patterns you use most often. Every Seaborn function returns a Matplotlib axes object, so you can always drop down to Matplotlib for fine-grained customization after Seaborn does the heavy lifting.
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Set Seaborn theme once at the top of your notebook or script
sns.set_theme(style='whitegrid', palette='muted', font_scale=1.1)

# Generate example data
np.random.seed(42)
df = pd.DataFrame({
    'feature_a': np.random.randn(200),
    'feature_b': np.random.randn(200) * 2 + 1,
    'category': np.random.choice(['Class A', 'Class B', 'Class C'], 200),
    'target': np.random.choice([0, 1], 200)
})

# --- Distribution plots: understand feature spread ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
sns.histplot(data=df, x='feature_a', hue='category', kde=True, ax=axes[0])
axes[0].set_title('Feature A Distribution by Category')
sns.boxplot(data=df, x='category', y='feature_b', ax=axes[1])
axes[1].set_title('Feature B Spread by Category')
fig.tight_layout()
fig.savefig('distributions.png', dpi=300, bbox_inches='tight')
plt.close(fig)

# --- Correlation heatmap: find feature relationships ---
fig, ax = plt.subplots(figsize=(8, 6))
numeric_df = df.select_dtypes(include=[np.number])
corr_matrix = numeric_df.corr()
sns.heatmap(
    corr_matrix, annot=True, fmt='.2f', cmap='RdBu_r',
    center=0, vmin=-1, vmax=1, ax=ax, linewidths=0.5, square=True
)
ax.set_title('Feature Correlation Matrix')
fig.tight_layout()
fig.savefig('correlation.png', dpi=300, bbox_inches='tight')
plt.close(fig)

# --- Pair plot: explore all pairwise relationships at once ---
# Useful for small feature sets (<10 features). Slow for large ones.
pair = sns.pairplot(
    df, hue='category', diag_kind='kde',
    plot_kws={'alpha': 0.4, 's': 15}
)
pair.figure.suptitle('Pairwise Feature Relationships', y=1.02)
pair.savefig('pairplot.png', dpi=150, bbox_inches='tight')
plt.close('all')

# --- Seaborn + Matplotlib customization: the practical pattern ---
fig, ax = plt.subplots(figsize=(10, 6))
sns.violinplot(data=df, x='category', y='feature_a', ax=ax, inner='quartile')
# Drop down to Matplotlib for fine-tuning
ax.set_title('Feature A Violin Plot', fontsize=14, fontweight='bold')
ax.set_xlabel('Category', fontsize=12)
ax.set_ylabel('Feature A Value', fontsize=12)
ax.axhline(y=0, color='red', linestyle='--', alpha=0.5, label='Zero baseline')
ax.legend()
fig.tight_layout()
fig.savefig('violin_customized.png', dpi=300, bbox_inches='tight')
plt.close(fig)
- Seaborn excels at: grouped plots, statistical overlays (confidence intervals, KDE curves), automatic legend handling, DataFrame-native column references.
- Matplotlib excels at: precise axis control, custom annotations and arrows, multi-panel layouts with unequal sizing, publication-quality formatting.
- You can always access the underlying Matplotlib axes from any Seaborn plot: ax = sns.histplot(...); ax.set_xlim(0, 100).
- Rule of thumb: prototype in Seaborn, polish in Matplotlib. Start fast, refine as needed.
Call sns.set_theme() once at the very top of your notebook or script, and document the style choice.
Confusion Matrix: Where Your Model Gets Confused
The confusion matrix is the single most important diagnostic chart for classification models. It shows exactly which classes your model confuses with which — information that a scalar metric like accuracy or F1 compresses into a single number and loses.
A model with 95% accuracy might be completely failing on one class. In a fraud detection system where only 2% of transactions are fraudulent, a model that predicts 'not fraud' for every single input achieves 98% accuracy while catching zero fraud. Only the confusion matrix reveals this. Always plot it. Always.
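The fraud example above takes only a few lines to reproduce. A minimal sketch with a hypothetical dataset of 1000 transactions, 2% fraudulent: accuracy looks excellent while recall on the fraud class is zero, and only the confusion matrix makes the failure visible.

```python
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, confusion_matrix

# Hypothetical fraud data: 20 fraudulent out of 1000 transactions (2%)
y_true = np.array([1] * 20 + [0] * 980)
# A useless model that predicts 'not fraud' for every single input
y_pred = np.zeros(1000, dtype=int)

print(accuracy_score(y_true, y_pred))  # 0.98 — looks great on paper
print(recall_score(y_true, y_pred))    # 0.0 — catches zero fraud
print(confusion_matrix(y_true, y_pred))
# [[980   0]
#  [ 20   0]]   <- the entire fraud row lands in the wrong column
```

The second row of the matrix is the whole story: all 20 fraud cases were predicted as class 0, which no scalar accuracy number would ever reveal.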
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(y_true, y_pred, labels=None, title='Confusion Matrix'):
    """Production-grade confusion matrix with both counts and percentages.

    Displays two panels side by side:
    - Left: raw counts (useful for understanding volume)
    - Right: row-normalized percentages (useful for understanding recall per class)

    Args:
        y_true: ground truth labels
        y_pred: predicted labels
        labels: list of class names for axis labels
        title: figure title

    Returns:
        Matplotlib Figure object (caller saves and closes).
    """
    cm = confusion_matrix(y_true, y_pred)
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: raw counts
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=labels, yticklabels=labels, ax=axes[0], linewidths=0.5
    )
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('Actual')
    axes[0].set_title(f'{title} (Counts)')

    # Right panel: row-normalized percentages (each row sums to 100%)
    sns.heatmap(
        cm_percent, annot=True, fmt='.1f', cmap='Blues',
        xticklabels=labels, yticklabels=labels, ax=axes[1],
        linewidths=0.5, vmin=0, vmax=100
    )
    axes[1].set_xlabel('Predicted')
    axes[1].set_ylabel('Actual')
    axes[1].set_title(f'{title} (Row %, i.e., Recall)')

    fig.tight_layout()
    return fig

# Example usage
np.random.seed(42)
y_true = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2] * 20)
y_pred = np.array([0, 0, 1, 1, 1, 0, 2, 2, 2] * 20)
labels = ['Cat', 'Dog', 'Bird']

fig = plot_confusion_matrix(y_true, y_pred, labels=labels, title='Animal Classifier')
fig.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.close(fig)
ROC and Precision-Recall Curves
ROC curves plot the true positive rate against the false positive rate across all possible classification thresholds. They answer the question: as I lower the threshold to catch more positives, how many false positives do I accept?
Precision-Recall curves are more informative for imbalanced datasets because they focus exclusively on the positive class. On a dataset where only 1% of samples are positive, ROC can show an impressive AUC of 0.95 while the model's precision at useful recall levels is actually terrible. Precision-Recall curves expose this directly.
Both curves let you visualize the tradeoff space and choose the optimal threshold for your specific business requirements — something a single F1 score cannot do.
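As a sketch of that threshold selection (the data here is synthetic, and maximizing F1 is just one possible business rule), you can compute F1 at every threshold returned by the Precision-Recall curve and pick the best:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Synthetic, overlapping scores; in practice use model.predict_proba(X)[:, 1]
np.random.seed(0)
y_true = np.random.choice([0, 1], size=500, p=[0.9, 0.1])
y_proba = np.clip(y_true * 0.4 + np.random.rand(500) * 0.6, 0, 1)

precision, recall, thresholds = precision_recall_curve(y_true, y_proba)

# precision/recall have len(thresholds) + 1 entries; drop the last to align
f1 = 2 * precision[:-1] * recall[:-1] / (precision[:-1] + recall[:-1] + 1e-12)
best = np.argmax(f1)
print(f'Best threshold: {thresholds[best]:.3f}  F1 there: {f1[best]:.3f}')
```

Swap the F1 criterion for whatever your application needs, e.g. "maximum recall subject to precision >= 0.9", by filtering the same arrays.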
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (
    roc_curve, auc, precision_recall_curve, average_precision_score
)

def plot_roc_and_pr(y_true, y_proba, title='Model Evaluation'):
    """Plot ROC and Precision-Recall curves side by side.

    Both curves visualize model performance across all possible
    classification thresholds. Together they give a complete picture
    that no single metric can provide.

    Args:
        y_true: ground truth binary labels (0 or 1)
        y_proba: predicted probabilities for the positive class
        title: figure title prefix

    Returns:
        Matplotlib Figure object.
    """
    # Compute ROC curve
    fpr, tpr, roc_thresholds = roc_curve(y_true, y_proba)
    roc_auc = auc(fpr, tpr)

    # Compute Precision-Recall curve
    precision, recall, pr_thresholds = precision_recall_curve(y_true, y_proba)
    avg_precision = average_precision_score(y_true, y_proba)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # --- ROC Curve ---
    axes[0].plot(fpr, tpr, linewidth=2, label=f'Model (AUC = {roc_auc:.3f})')
    axes[0].plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Random (AUC = 0.5)')
    axes[0].fill_between(fpr, tpr, alpha=0.1)
    axes[0].set_xlabel('False Positive Rate')
    axes[0].set_ylabel('True Positive Rate (Recall)')
    axes[0].set_title(f'{title} — ROC Curve')
    axes[0].legend(loc='lower right')
    axes[0].grid(True, alpha=0.3)
    axes[0].set_xlim([-0.02, 1.02])
    axes[0].set_ylim([-0.02, 1.02])

    # --- Precision-Recall Curve ---
    axes[1].plot(
        recall, precision, linewidth=2, color='orange',
        label=f'Model (AP = {avg_precision:.3f})'
    )
    baseline = y_true.sum() / len(y_true)
    axes[1].axhline(
        y=baseline, color='k', linestyle='--', alpha=0.5,
        label=f'Random baseline = {baseline:.3f}'
    )
    axes[1].fill_between(recall, precision, alpha=0.1, color='orange')
    axes[1].set_xlabel('Recall')
    axes[1].set_ylabel('Precision')
    axes[1].set_title(f'{title} — Precision-Recall Curve')
    axes[1].legend(loc='lower left')
    axes[1].grid(True, alpha=0.3)
    axes[1].set_xlim([-0.02, 1.02])
    axes[1].set_ylim([0, 1.05])

    fig.tight_layout()
    return fig

# Example: imbalanced fraud detection scenario
np.random.seed(42)
y_true = np.random.choice([0, 1], size=500, p=[0.95, 0.05])
y_proba = np.clip(y_true * 0.6 + np.random.randn(500) * 0.2, 0, 1)

fig = plot_roc_and_pr(y_true, y_proba, title='Fraud Detection')
fig.savefig('roc_pr_curves.png', dpi=300, bbox_inches='tight')
plt.close(fig)
Residual Plots for Regression Models
Residual plots reveal systematic errors in regression models that aggregate metrics like RMSE and MAE completely hide. RMSE tells you the average error magnitude. Residual plots tell you whether those errors are random (acceptable) or structured (a sign your model is missing something).
If residuals show a pattern — a curve, a fan shape, clusters — your model is not capturing a relationship in the data. No amount of hyperparameter tuning will fix this. You need different features, a different transformation, or a different model family. The residual plot is the chart that tells you which.
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import stats
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression

def plot_regression_diagnostics(y_true, y_pred, title='Regression Diagnostics'):
    """Four-panel diagnostic plot for regression models.

    Panels:
        1. Predicted vs Actual — overall fit quality
        2. Residuals vs Predicted — detect non-linearity, heteroscedasticity
        3. Residual Distribution — check normality assumption
        4. Q-Q Plot — sensitive normality check at distribution tails

    Args:
        y_true: actual target values (numpy array)
        y_pred: predicted target values (numpy array)
        title: overall figure title

    Returns:
        Matplotlib Figure object.
    """
    residuals = y_true - y_pred
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Panel 1: Predicted vs Actual
    axes[0, 0].scatter(y_true, y_pred, alpha=0.4, s=15, color='steelblue')
    min_val = min(y_true.min(), y_pred.min())
    max_val = max(y_true.max(), y_pred.max())
    axes[0, 0].plot(
        [min_val, max_val], [min_val, max_val], 'r--',
        linewidth=2, label='Perfect prediction'
    )
    axes[0, 0].set_xlabel('Actual')
    axes[0, 0].set_ylabel('Predicted')
    axes[0, 0].set_title('Predicted vs Actual')
    axes[0, 0].legend()

    # Panel 2: Residuals vs Predicted (the most important panel)
    axes[0, 1].scatter(y_pred, residuals, alpha=0.4, s=15, color='coral')
    axes[0, 1].axhline(y=0, color='r', linestyle='--', linewidth=2)
    axes[0, 1].set_xlabel('Predicted Value')
    axes[0, 1].set_ylabel('Residual (Actual - Predicted)')
    axes[0, 1].set_title('Residuals vs Predicted')

    # Panel 3: Residual Distribution
    sns.histplot(residuals, kde=True, ax=axes[1, 0], bins=30, color='steelblue')
    axes[1, 0].axvline(x=0, color='r', linestyle='--')
    axes[1, 0].set_xlabel('Residual')
    axes[1, 0].set_title(f'Residual Distribution (mean={residuals.mean():.2f})')

    # Panel 4: Q-Q plot (normality check — deviations at tails matter most)
    stats.probplot(residuals, dist='norm', plot=axes[1, 1])
    axes[1, 1].set_title('Q-Q Plot (Normality Check)')

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()
    return fig

# Example
X, y = make_regression(n_samples=300, n_features=3, noise=15, random_state=42)
model = LinearRegression().fit(X, y)
y_pred = model.predict(X)

fig = plot_regression_diagnostics(y, y_pred, title='Linear Regression Diagnostics')
fig.savefig('residual_plots.png', dpi=300, bbox_inches='tight')
plt.close(fig)
- Residuals vs Predicted: random scatter centered on zero. No fan shape, no curve, no clusters.
- Residual Distribution: approximately normal, centered at zero. Skew or heavy tails indicate the model handles some value ranges worse than others.
- Q-Q Plot: points follow the diagonal line closely. Deviations at the tails mean the model produces more extreme errors than a normal distribution predicts.
- If you see any pattern in the residual plot, your model is missing a signal. Add features, apply transformations, or switch model families.
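As an illustration of the "apply transformations" fix, here is a sketch with hypothetical multiplicative-noise data: fitting on the raw target leaves a fan-shaped (heteroscedastic) residual spread, while fitting on log(y) flattens it to roughly constant variance.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical multiplicative-noise data: error grows with the target,
# producing the classic fan shape in a residual plot
rng = np.random.default_rng(42)
X = rng.uniform(1, 10, size=(500, 1))
y = np.exp(0.4 * X[:, 0] + rng.normal(0, 0.2, 500))

# Fit on the raw target: residual spread grows with the prediction
raw_model = LinearRegression().fit(X, y)
resid_raw = y - raw_model.predict(X)

# Fit on log(y): the fan shape collapses to roughly constant spread
log_model = LinearRegression().fit(X, np.log(y))
resid_log = np.log(y) - log_model.predict(X)

# Compare residual spread in the low vs high half of the feature range
low, high = X[:, 0] < 5.5, X[:, 0] >= 5.5
print(resid_raw[high].std() / resid_raw[low].std())  # >> 1: heteroscedastic
print(resid_log[high].std() / resid_log[low].std())  # ~ 1: spread stabilized
```

The same before/after comparison is exactly what the Residuals vs Predicted panel shows visually; the ratio here is just a numeric stand-in for "is the fan gone?".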
Feature Importance Visualization
Feature importance plots show which inputs drive your model's predictions. For tree-based models, importance is built in via impurity reduction. For any model, permutation importance provides a model-agnostic alternative by measuring how much accuracy drops when each feature's values are randomly shuffled.
Visualization makes these rankings immediately interpretable to non-technical stakeholders who need to understand why the model makes the decisions it does — not just what it predicts. A horizontal bar chart sorted by importance is the universal format that everyone from data scientists to product managers can read.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.datasets import make_classification

def plot_feature_importance(model, feature_names, X_test, y_test, top_n=15):
    """Plot built-in and permutation importance side by side.

    Built-in importance (Gini) is fast but biased toward high-cardinality
    features. Permutation importance is slower but model-agnostic and
    unbiased. Showing both highlights discrepancies worth investigating.

    Args:
        model: fitted sklearn estimator
        feature_names: list of feature name strings
        X_test: test features for permutation importance
        y_test: test labels for permutation importance
        top_n: number of top features to display

    Returns:
        Matplotlib Figure object.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, max(6, top_n * 0.4)))

    # Left panel: built-in importance (tree-based models only)
    if hasattr(model, 'feature_importances_'):
        importances = model.feature_importances_
        indices = np.argsort(importances)[::-1][:top_n]
        axes[0].barh(
            [feature_names[i] for i in indices][::-1],
            importances[indices][::-1],
            color='steelblue', edgecolor='black', alpha=0.8
        )
        axes[0].set_xlabel('Gini Importance (Impurity Reduction)')
        axes[0].set_title('Built-in Feature Importance')
    else:
        axes[0].text(
            0.5, 0.5, 'Not available\n(model has no feature_importances_)',
            ha='center', va='center', fontsize=12, transform=axes[0].transAxes
        )
        axes[0].set_title('Built-in Feature Importance (N/A)')

    # Right panel: permutation importance (model-agnostic)
    perm_result = permutation_importance(
        model, X_test, y_test, n_repeats=10, random_state=42, n_jobs=-1
    )
    perm_mean = perm_result.importances_mean
    perm_std = perm_result.importances_std
    indices = np.argsort(perm_mean)[::-1][:top_n]
    axes[1].barh(
        [feature_names[i] for i in indices][::-1],
        perm_mean[indices][::-1],
        xerr=perm_std[indices][::-1],
        color='coral', edgecolor='black', alpha=0.8
    )
    axes[1].set_xlabel('Mean Accuracy Decrease When Shuffled')
    axes[1].set_title('Permutation Importance')

    fig.suptitle('Feature Importance Comparison', fontsize=14, fontweight='bold')
    fig.tight_layout()
    return fig

# Example
X, y = make_classification(
    n_samples=1000, n_features=10, n_informative=5, random_state=42
)
feature_names = [f'feature_{i}' for i in range(10)]
model = RandomForestClassifier(n_estimators=100, random_state=42).fit(X, y)

fig = plot_feature_importance(model, feature_names, X, y)
fig.savefig('feature_importance.png', dpi=300, bbox_inches='tight')
plt.close(fig)
Learning Curves: Diagnosing Bias and Variance
Learning curves plot model performance against training set size. They answer the most fundamental question in model improvement: should I get more data, or should I change the model?
The gap between the training score and validation score at each data size reveals whether your model suffers from high bias (underfitting — both curves are low) or high variance (overfitting — training is high, validation is low). This is not an academic distinction. It directly determines whether spending three weeks collecting more data will help or be completely wasted effort.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import learning_curve
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

def plot_learning_curve(estimator, X, y, title='Learning Curve', cv=5,
                        scoring='accuracy'):
    """Plot learning curve showing the bias-variance tradeoff.

    The gap between training and validation curves tells you exactly
    what to fix: more data, more regularization, or a different model.

    Args:
        estimator: unfitted sklearn estimator (will be cloned internally)
        X: feature matrix
        y: target vector
        title: plot title
        cv: number of cross-validation folds
        scoring: sklearn scoring metric name

    Returns:
        Matplotlib Figure object.
    """
    train_sizes, train_scores, val_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=-1,
        train_sizes=np.linspace(0.1, 1.0, 10), scoring=scoring
    )

    train_mean = train_scores.mean(axis=1)
    train_std = train_scores.std(axis=1)
    val_mean = val_scores.mean(axis=1)
    val_std = val_scores.std(axis=1)

    fig, ax = plt.subplots(figsize=(10, 6))

    # Confidence bands
    ax.fill_between(
        train_sizes, train_mean - train_std, train_mean + train_std,
        alpha=0.1, color='blue'
    )
    ax.fill_between(
        train_sizes, val_mean - val_std, val_mean + val_std,
        alpha=0.1, color='orange'
    )

    # Mean curves
    ax.plot(train_sizes, train_mean, 'o-', color='blue', linewidth=2,
            label='Training Score')
    ax.plot(train_sizes, val_mean, 'o-', color='orange', linewidth=2,
            label='Validation Score')

    ax.set_xlabel('Training Set Size')
    ax.set_ylabel(scoring.capitalize())
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)

    # Annotate the final gap between curves
    final_gap = train_mean[-1] - val_mean[-1]
    ax.annotate(
        f'Gap: {final_gap:.3f}',
        xy=(train_sizes[-1], (train_mean[-1] + val_mean[-1]) / 2),
        fontsize=11, fontweight='bold', color='red', ha='right'
    )

    fig.tight_layout()
    return fig

# Example
X, y = make_classification(
    n_samples=1000, n_features=20, n_informative=10, random_state=42
)
model = RandomForestClassifier(n_estimators=100, random_state=42)

fig = plot_learning_curve(model, X, y, title='Random Forest Learning Curve')
fig.savefig('learning_curve.png', dpi=300, bbox_inches='tight')
plt.close(fig)
- Large gap (training high, validation low) = high variance (overfitting). Fix with: more data, stronger regularization, fewer features, simpler model.
- Both curves low and converging together = high bias (underfitting). Fix with: more features, more complex model, less regularization. More data will NOT help here.
- Both curves high and converging together = good fit. Model is well-calibrated for this data volume.
- Validation curve still rising at the right edge = more data will help. Collecting additional training examples is a productive investment.
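The four readings above can be wrapped into a rough helper. This is a sketch with illustrative thresholds (the 0.1 gap and 0.75 score cutoffs are assumptions for the demo, not universal rules; calibrate them for your metric and domain):

```python
import numpy as np

def diagnose_learning_curve(train_mean, val_mean, gap_threshold=0.1,
                            low_score=0.75):
    """Heuristic read of a learning curve (illustrative thresholds only).

    train_mean, val_mean: per-training-size scores from sklearn's
    learning_curve, already averaged over CV folds.
    """
    gap = train_mean[-1] - val_mean[-1]
    if gap > gap_threshold:
        return 'high variance: more data, regularization, or a simpler model'
    if val_mean[-1] < low_score:
        return 'high bias: more features or a more complex model'
    if len(val_mean) > 1 and val_mean[-1] - val_mean[-2] > 0.01:
        return 'still improving: more data should help'
    return 'good fit for this data volume'

# Overfitting shape: training near-perfect, validation lagging far behind
print(diagnose_learning_curve(np.array([1.0, 1.0, 0.99]),
                              np.array([0.70, 0.78, 0.82])))
# -> high variance: more data, regularization, or a simpler model
```

Feed it the `train_mean` and `val_mean` arrays already computed inside `plot_learning_curve` to turn the chart's visual diagnosis into a logged string for pipelines.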
Saving and Formatting for Production
Charts in notebooks are for exploration. Charts in reports, dashboards, presentations, and papers require consistent formatting, appropriate resolution, and accessible color choices. The gap between a notebook plot and a production-ready figure is not aesthetics — it is legibility, accessibility, and reproducibility.
A chart that looks fine on your 4K monitor becomes an unreadable blur when projected onto a conference room screen or embedded in a PDF at print resolution. This section covers the production formatting pipeline that ensures your figures survive every medium they encounter.
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

def apply_production_style():
    """Apply a consistent, publication-quality style globally.

    Call this once at the top of your notebook or script.
    Overrides Matplotlib defaults with production-safe values.
    """
    mpl.rcParams.update({
        # Typography
        'font.size': 12,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 16,
        # Figure defaults
        'figure.figsize': (10, 6),
        'figure.dpi': 100,        # Screen display DPI
        'savefig.dpi': 300,       # Saved file DPI
        'savefig.bbox': 'tight',  # Prevent label clipping
        'savefig.pad_inches': 0.1,
        # Grid and spines
        'axes.grid': True,
        'grid.alpha': 0.3,
        'axes.spines.top': False,    # Remove top spine
        'axes.spines.right': False,  # Remove right spine
        # Lines and markers
        'lines.linewidth': 2,
        'lines.markersize': 6,
    })
    print("Production style applied.")

def save_publication(fig, filename, formats=None):
    """Save figure in multiple formats for different use cases.

    Args:
        fig: Matplotlib Figure object
        filename: base filename without extension
        formats: list of format strings. Defaults to PNG + SVG.
    """
    if formats is None:
        formats = ['png', 'svg']
    for fmt in formats:
        filepath = f"{filename}.{fmt}"
        fig.savefig(filepath, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Saved: {filepath}")

# --- Usage ---
apply_production_style()

fig, ax = plt.subplots()
colors = ['#2563eb', '#16a34a', '#dc2626']  # Blue, green, red — distinguishable
ax.bar(
    ['Model A', 'Model B', 'Model C'], [0.89, 0.92, 0.87],
    color=colors, edgecolor='black', alpha=0.9
)
ax.set_ylabel('Accuracy')
ax.set_title('Model Comparison — Q1 2026')
ax.set_ylim(0.80, 0.95)

# Add value labels on bars
for i, v in enumerate([0.89, 0.92, 0.87]):
    ax.text(i, v + 0.003, f'{v:.2f}', ha='center', fontweight='bold')

save_publication(fig, 'model_comparison')
plt.close(fig)
- Use colorblind-safe palettes: sns.color_palette('colorblind') or the 'muted' palette. Avoid pure red/green combinations as the only differentiator.
- Add patterns (hatching), markers, or line styles to distinguish series — not just color. ax.bar(..., hatch='//') adds visual texture.
- Never use the 'jet' or 'rainbow' colormap for continuous data — they introduce perceptual artifacts. Use 'viridis', 'plasma', or 'cividis' instead.
- Add direct value labels on bars and direct labels on lines instead of relying on a distant legend that requires color matching.
- Test your charts in grayscale. If they still communicate the message, they are accessible.
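Putting the checklist together, a sketch of an accessible bar chart: a colorblind-safe palette, hatching as a redundant visual channel, and direct value labels (the model names and scores are made up for the demo):

```python
import matplotlib
matplotlib.use('Agg')  # headless backend for scripted use
import matplotlib.pyplot as plt
import seaborn as sns

models = ['Model A', 'Model B', 'Model C']
scores = [0.89, 0.92, 0.87]

fig, ax = plt.subplots(figsize=(8, 5))
palette = sns.color_palette('colorblind', n_colors=3)  # colorblind-safe colors
hatches = ['//', 'xx', '..']  # redundant channel: texture, not just color

bars = ax.bar(models, scores, color=palette, edgecolor='black')
for bar, hatch, score in zip(bars, hatches, scores):
    bar.set_hatch(hatch)
    # Direct value labels remove the need for color matching via a legend
    ax.text(bar.get_x() + bar.get_width() / 2, score + 0.003,
            f'{score:.2f}', ha='center', fontweight='bold')

ax.set_ylabel('Accuracy')
ax.set_ylim(0.80, 0.95)
fig.savefig('accessible_comparison.png', dpi=300, bbox_inches='tight')
plt.close(fig)
```

Rendered in grayscale, the hatching and labels still distinguish the bars, which is exactly the test the last bullet asks for.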
| Aspect | Matplotlib | Seaborn |
|---|---|---|
| Learning Curve | Steeper — more code required for statistical plots | Gentler — sensible defaults and fewer lines for common charts |
| Control Level | Full pixel-level control over every element | Less granular control, but faster to prototype |
| DataFrame Awareness | None — requires manual extraction of arrays from DataFrames | Native — pass column names directly via data= parameter |
| Statistical Plots | Manual — compute confidence intervals, KDE, regressions yourself | Built-in — automatic confidence intervals, KDE, regression lines |
| Multi-Panel Layouts | Excellent — full control over grid spacing and sizing | Limited — pairplot and FacetGrid handle specific patterns only |
| Customization | Unlimited — every element is individually addressable | Good via Matplotlib axes access, but some Seaborn elements resist customization |
| Production Formatting | Full control via rcParams and style sheets | Inherits Matplotlib settings, adds its own theme layer via set_theme() |
| Best For | Final figures, custom annotations, publication-quality output | Exploratory analysis, statistical summaries, rapid prototyping |
🎯 Key Takeaways
- Every Matplotlib chart starts with fig, ax = plt.subplots() — use the object-oriented interface, always.
- Seaborn handles DataFrame grouping and statistical estimation automatically — use it for rapid exploration, then drop down to Matplotlib for polish.
- Confusion matrices reveal class-level failures that accuracy hides — always show both raw counts and row-normalized percentages.
- ROC curves work for balanced data; Precision-Recall curves are essential for imbalanced data. Plot both.
- Residual plots diagnose regression model errors that RMSE averages away — check for patterns, not just magnitude.
- Learning curves tell you whether to invest in more data or a different model — read the gap between training and validation curves.
- Save at 300 DPI with fig.savefig() and always call plt.close(fig) afterward to prevent memory leaks in pipelines.
- Use perceptually uniform colormaps (viridis, plasma, cividis) — never use jet or rainbow for continuous data.
Interview Questions on This Topic
- Q: Why is a Precision-Recall curve more informative than an ROC curve for imbalanced classification problems? (Mid-level)
- Q: Your residual plot shows a U-shaped pattern. What does this tell you about your regression model, and what would you do about it? (Mid-level)
- Q: How would you present model evaluation results to a non-technical stakeholder who needs to decide whether to deploy the model? (Junior)
- Q: Explain the difference between built-in feature importance and permutation importance. When would they disagree, and which would you trust? (Senior)
Frequently Asked Questions
Should I use Matplotlib or Seaborn?
Use both — they are not alternatives. Seaborn is built on top of Matplotlib, and every Seaborn plot returns a Matplotlib axes object. Use Seaborn for quick statistical plots during exploration: histograms with KDE overlays, grouped boxplots, correlation heatmaps, pair plots. Use Matplotlib for final presentation control: precise axis formatting, custom annotations, multi-panel layouts with unequal sizing, publication-quality output. The practical pattern is: prototype in Seaborn for speed, then customize with Matplotlib methods for polish.
How do I choose the right chart type for my data?
Match the chart to the relationship you want to communicate. Distribution of a single variable: histogram or KDE plot. Comparison across categories: bar chart or boxplot. Correlation between two numeric variables: scatter plot (with alpha transparency for large datasets). Trend over time or ordered sequence: line chart. Matrix of values: heatmap. For ML diagnostics specifically: confusion matrix for classification evaluation, residual plot for regression evaluation, learning curve for bias-variance diagnosis, feature importance bar chart for model interpretability, ROC or PR curve for threshold selection.
Why do my saved plots look different from what I see in the notebook?
Notebook display and file saving use different rendering backends and resolutions. The notebook renders at screen resolution (72–96 DPI) using the inline backend, while savefig uses the DPI value you specify. Additionally, the notebook may auto-adjust figure size to fit the cell width. Always use fig.savefig('name.png', dpi=300, bbox_inches='tight') with an explicit figsize in plt.subplots() to get consistent, predictable output. Test by opening the saved file directly — not by comparing to the notebook display. And always save before calling plt.show(), which destroys the figure in most backends.
How many charts should I include in a model evaluation report?
For classification: confusion matrix, ROC or Precision-Recall curve (or both for imbalanced data), and feature importance. For regression: predicted vs actual scatter, residual plot (four-panel diagnostic), and feature importance. That is 3 charts per model, each answering a specific question about model quality. Add learning curves only if actively diagnosing overfitting or underfitting. Add prediction probability distribution plots for production monitoring. Every chart must answer a specific question — if you cannot state the question the chart answers, remove it. Stakeholders need insight, not decoration.
How do I make my charts accessible to colorblind viewers?
Three rules cover most cases. First, use colorblind-safe palettes: Seaborn's 'colorblind' palette, or perceptually uniform colormaps like 'viridis' and 'cividis'. Avoid red-green as the sole differentiator — the most common color vision deficiency affects red-green perception. Second, add redundant visual channels: different line styles (solid, dashed, dotted), different markers (circle, square, triangle), or hatching patterns on bars. This way color is not the only signal. Third, add direct labels — annotate bars with their values, label lines directly instead of using a distant legend that requires color matching. Test your final figure in grayscale: if it still communicates the message, it is accessible.
Developer and founder of TheCodeForge. I built this site because I was tired of tutorials that explain what to type without explaining why it works. Every article here is written to make concepts actually click.