
Transfer Learning with TensorFlow — Standing on the Shoulders of Giants

📍 Part of: TensorFlow & Keras → Topic 8 of 10
Learn Transfer Learning with TensorFlow and Keras.
⚙️ Intermediate — basic ML / AI knowledge assumed
In this tutorial, you'll learn
  • Transfer learning allows you to achieve professional-grade AI accuracy on standard consumer hardware.
  • Freezing the base model prevents 'catastrophic forgetting' of general visual features like edges and shapes.
  • MobileNetV2 is an excellent, lightweight starting point for mobile and web-based vision applications.
Quick Answer
  • Transfer learning reuses weights from a model trained on millions of images (ImageNet) as a starting point for your task
  • include_top=False removes the original classification head — you attach your own Dense output for your classes
  • base_model.trainable = False freezes all pre-learned weights during feature extraction phase
  • GlobalAveragePooling2D is preferred over Flatten — fewer parameters, lower overfitting risk, same spatial coverage
  • Fine-tuning: unfreeze the last N layers of the base and retrain with a very low learning rate (1e-5, not 1e-3)
  • Biggest mistake: not freezing the base model — large gradients from your random head will destroy the pre-trained weights
Production Incident: Fine-Tuning Too Early Destroyed a Week of Training
A team unfroze MobileNetV2's entire base model on epoch 1 with a learning rate of 1e-3. After 50 epochs, validation accuracy was 51% on their 4-class problem, worse than training the same architecture from random initialization.
Symptom: Training loss decreased steadily, but validation accuracy plateaued at 51% from epoch 5 onward. The model appeared to be learning but was not generalizing.
Assumption: The team believed that unfreezing everything from the start would let the model adapt faster to their medical imaging domain.
Root cause: The randomly initialized Dense head produced large, unstable gradients in the early epochs. With no frozen base, those gradients propagated through all 154 MobileNetV2 layers and catastrophically overwrote the pre-trained ImageNet weights. By epoch 5, the base was producing essentially random feature maps, no better than training from scratch, but without an architecture-appropriate initialization.
Fix: Use a two-phase approach. (1) Set base_model.trainable = False and train only the head for 10–20 epochs, until the head loss stabilizes below 0.5. (2) Then unfreeze only the last 30–50 layers of the base and retrain with lr=1e-5 (not 1e-3). The low learning rate prevents catastrophic forgetting of general features.
Key Lesson
  • Never unfreeze the base model until the custom head has stabilized: head loss should be below 0.5 before fine-tuning begins
  • The fine-tuning learning rate must be 10x–100x lower than the initial training rate; use 1e-5 for Adam
  • Unfreeze incrementally from the top of the base: the last 20–50 layers, not all 154
Production Debug Guide
Common failures during feature extraction and fine-tuning phases
  • Symptom: Validation accuracy does not improve beyond random chance after 20 epochs.
    Fix: Check that the base model is correctly frozen: print([l.trainable for l in base_model.layers[:5]]). All should be False. Also verify preprocessing: MobileNetV2 requires preprocess_input(), not raw division by 255.
  • Symptom: Training accuracy is high, but fine-tuning causes an accuracy regression.
    Fix: The learning rate is too high for fine-tuning. Reduce it to 1e-5 or lower, and recompile the model after unfreezing: model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), ...). Forgetting to recompile after unfreezing is a common silent failure.
  • Symptom: Out-of-memory errors when using larger base models (ResNet50, EfficientNetB7).
    Fix: Use gradient checkpointing or reduce the batch size to 16 or 8. For MobileNetV2, input_shape=(96, 96, 3) instead of (224, 224, 3) reduces feature-map memory roughly 5x with a modest accuracy trade-off.
  • Symptom: The model performs well on clean photos but poorly on real-world production images.
    Fix: Add strong augmentation: RandomBrightness, RandomContrast, RandomZoom. Your production distribution differs from your training distribution. Consider collecting 50–100 hard examples per class from production and adding them to the training set.
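The roughly-5x figure for shrinking the input is simple arithmetic: activation memory scales with the spatial area of the feature maps, and (224/96)² ≈ 5.4, so every intermediate layer shrinks by the same factor. A quick sketch (the 112x112x32 shape is MobileNetV2's first conv output for a 224x224 input; the byte counts assume float32):

```python
def activation_bytes(height, width, channels, dtype_bytes=4):
    """Memory for one feature map, assuming float32 by default."""
    return height * width * channels * dtype_bytes

# MobileNetV2's stem conv halves spatial resolution and outputs 32 channels.
big = activation_bytes(112, 112, 32)   # from a 224x224 input
small = activation_bytes(48, 48, 32)   # from a 96x96 input

print(f"{big} B vs {small} B -> {big / small:.2f}x")  # ratio is (224/96)^2 ~= 5.44
```

The same ratio applies to every layer down the network, which is why the savings are so consistent.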

Training a deep neural network from scratch requires two things most developers don't have: millions of labeled images and weeks of GPU time. Transfer Learning is the industry workaround. By using pre-trained models from 'TensorFlow Hub' or 'Keras Applications,' you can leverage patterns learned by Google or Microsoft to solve your specific problems.

In this guide, we'll demonstrate how to 'freeze' the base of a massive model (MobileNetV2), swap out its 'head' for our own classification task, and fine-tune it for near-perfect accuracy with just a few hundred images. At TheCodeForge, we utilize this strategy to deploy state-of-the-art vision systems without the overhead of massive data collection.

1. Loading a Pre-trained Base Model

Most of the work in a vision model happens in the early layers that detect edges and textures. We load these layers but set include_top=False to remove the final classification layer, since we want to predict our own classes, not the original 1,000 categories from ImageNet.

Crucially, we freeze the weights. If we didn't, the initial large errors from our randomly initialized new layers would 'pollute' the refined weights of the pre-trained model.

load_pretrained.py · PYTHON
import tensorflow as tf

# io.thecodeforge: Standard Transfer Learning Base Initialization
# Load MobileNetV2 optimized for 160x160 color images
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3),
    include_top=False,
    weights='imagenet'
)

# Freeze the base - we don't want to break the pre-learned patterns yet
base_model.trainable = False

print(f"Trainable layers: {sum(1 for l in base_model.layers if l.trainable)}")
print(f"Frozen layers: {sum(1 for l in base_model.layers if not l.trainable)}")
base_model.summary()
▶ Output
Trainable layers: 0
Frozen layers: 154
Total params: 2,257,984 | Trainable params: 0
Mental Model
Feature Extraction vs. Fine-Tuning — Two Distinct Phases
Transfer learning is a two-phase process. Phase 1 trains only your new head against frozen pre-trained weights. Phase 2 unfreezes the upper base layers for domain-specific adaptation.
  • Phase 1 (Feature Extraction): base frozen, head only — fast, safe, use lr=1e-3
  • Phase 2 (Fine-Tuning): unfreeze top 20–50 layers, retrain with lr=1e-5
  • Never combine both phases — always let Phase 1 stabilize first
  • The boundary: start fine-tuning once the head's val_loss stops improving
  • Each pre-trained model has its own required preprocessing — use the model's own preprocess_input()
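The phase boundary above can be expressed as a tiny check over the validation-loss history: the head has stabilized when val_loss has been flat for a few consecutive epochs. A sketch with hypothetical numbers (the patience and min_delta thresholds are illustrative choices, not a Keras API):

```python
def ready_to_finetune(val_losses, patience=3, min_delta=1e-3):
    """Phase boundary check: val_loss is considered flat when the best loss
    in the last `patience` epochs improves on the earlier best by < min_delta."""
    if len(val_losses) <= patience:
        return False
    recent_best = min(val_losses[-patience:])
    earlier_best = min(val_losses[:-patience])
    return earlier_best - recent_best < min_delta

still_improving = [1.2, 0.9, 0.7, 0.5, 0.4]
plateaued = [1.2, 0.6, 0.30, 0.30, 0.30, 0.30]

print(ready_to_finetune(still_improving))  # False: keep training the head
print(ready_to_finetune(plateaued))        # True: safe to start Phase 2
```

In practice the same idea is what tf.keras.callbacks.EarlyStopping implements; this sketch just makes the decision rule explicit.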
📊 Production Insight
The order matters: freeze first, let head stabilize, then unfreeze incrementally.
Skipping Phase 1 and fine-tuning from epoch 1 is the single most common transfer learning mistake that wastes GPU budget.
For the preprocessing requirement per model, consult tf.keras.applications docs — MobileNetV2 needs preprocess_input(), not /255.
🎯 Key Takeaway
include_top=False + base_model.trainable=False is the correct starting configuration — always.
Trainable params should be zero for the base and non-zero only for your head.
Preprocessing is model-specific — MobileNetV2 expects [-1, 1], not [0, 1].
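The [-1, 1] requirement is easy to verify by hand: mobilenet_v2.preprocess_input scales pixels from [0, 255] to [-1, 1], which is equivalent to x / 127.5 - 1. A minimal sketch of that arithmetic (pure Python, no TensorFlow needed), next to the common /255 mistake:

```python
def mobilenet_v2_scale(pixel):
    """Equivalent arithmetic of mobilenet_v2.preprocess_input for one pixel:
    maps [0, 255] to [-1, 1]."""
    return pixel / 127.5 - 1.0

print([mobilenet_v2_scale(p) for p in (0, 127.5, 255)])  # [-1.0, 0.0, 1.0]
print([p / 255.0 for p in (0, 127.5, 255)])              # [0.0, 0.5, 1.0]  (wrong range)
```

A model trained on [-1, 1] inputs but fed [0, 1] values sees every pixel shifted into the top half of its expected range, which is exactly the silent accuracy plateau described in the debug guide.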

2. Adding a Custom Head

Now we 'attach' our own layers to the top of the pre-trained base. This new 'head' will learn to interpret the complex features extracted by MobileNet to classify our specific images. This stage is often called 'Feature Extraction' because we treat the base model as a fixed mathematical transformation of the pixels.

custom_head.py · PYTHON
# io.thecodeforge: Attaching the Classification Head

# Preprocessing baked in — MobileNetV2 requires inputs scaled to [-1, 1]
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

model = tf.keras.Sequential([
    tf.keras.layers.Lambda(preprocess_input, input_shape=(160, 160, 3)),
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.2), # Standard Forge practice for regularization
    tf.keras.layers.Dense(1, activation='sigmoid') # Binary classifier (e.g., Cat vs Dog)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# Phase 1: Train head only
history_phase1 = model.fit(train_dataset, epochs=20, validation_data=val_dataset)
▶ Output
Epoch 20/20: loss: 0.22 - accuracy: 0.91 - val_loss: 0.19 - val_accuracy: 0.93
🔥 Why GlobalAveragePooling?
This layer converts the 2D spatial features into a 1D vector. It's more computationally efficient than a 'Flatten' layer and significantly reduces the number of parameters, which is a key defense against overfitting when working with small datasets.
📊 Production Insight
GlobalAveragePooling2D is strictly preferred over Flatten for transfer learning heads.
A (5, 5, 1280) MobileNetV2 output: Flatten gives 32,000 Dense inputs, GAP gives 1,280 — 25x fewer parameters.
Lambda layers for preprocessing make the preprocessing part of the SavedModel — no serving-side preprocessing drift.
🎯 Key Takeaway
Bake preprocessing inside the model with a Lambda or Rescaling layer.
GlobalAveragePooling2D over Flatten — always for transfer learning heads.
Phase 1 training should reach val_accuracy > 0.85 before you consider fine-tuning.
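The 25x figure is worth checking by hand. For MobileNetV2 with a 160x160 input, the base output is (5, 5, 1280): after Flatten the first Dense layer sees 5*5*1280 = 32,000 inputs, while after GlobalAveragePooling2D it sees only 1,280. A quick parameter count (pure arithmetic, no TensorFlow):

```python
def dense_params(inputs, units):
    """Weights plus biases for a Dense layer."""
    return inputs * units + units

h, w, c = 5, 5, 1280          # MobileNetV2 output shape for a 160x160 input
units = 1                     # the binary sigmoid head from the code above

flatten_inputs = h * w * c    # 32000 inputs after Flatten
gap_inputs = c                # 1280 inputs after GlobalAveragePooling2D

print(dense_params(flatten_inputs, units))  # 32001
print(dense_params(gap_inputs, units))      # 1281
print(flatten_inputs // gap_inputs)         # 25x fewer inputs
```

With a larger head (say 256 units), the gap widens to millions of parameters versus a few hundred thousand, which is where the overfitting defense really shows up.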

3. Implementation: Java Model Inference Service

Once your Transfer Learning model is trained and exported as a SavedModel, it can be integrated into a high-concurrency Java backend using the TensorFlow Java API.

io/thecodeforge/ml/VisionService.java · JAVA
package io.thecodeforge.ml;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class VisionService {
    private SavedModelBundle model;

    /**
     * io.thecodeforge: Loading and serving pre-trained artifacts
     */
    public void initModel(String modelDir) {
        this.model = SavedModelBundle.load(modelDir, "serve");
    }

    public float predict(float[][][][] imageTensorData) {
        // Legacy libtensorflow (TF 1.x Java) API: tensors are native resources
        // and must be closed, including the result tensor.
        try (Tensor<Float> input = Tensor.create(imageTensorData);
             Tensor<Float> result = model.session().runner()
                 .feed("serving_default_input_1", input)
                 .fetch("StatefulPartitionedCall")
                 .run().get(0).expect(Float.class)) {

            float[][] matrix = new float[1][1];
            result.copyTo(matrix);
            return matrix[0][0];
        }
    }
}
▶ Output
// Compiled for Forge-Backend Runtime
📊 Production Insight
The input key 'serving_default_input_1' must be verified with: saved_model_cli show --dir model_dir --all.
The serving signature name varies by how the model was saved — inspect before deploying to Java.
For the full serialization guide, see tensorflow-save-load-model.
🎯 Key Takeaway
Java inference from a Python-trained model requires matching the exact serving signature keys.
Always inspect the model signature before writing Java feeding code.
SavedModel is the only cross-language portable format — H5 is Python-only.

4. Audit Logging: Experiment Metadata

In a professional pipeline, we track which 'Base Model' and 'Weights' were used. This SQL schema ensures full lineage for every model deployed to production.

io/thecodeforge/db/transfer_audit.sql · SQL
-- io.thecodeforge: ML Experiment Tracking
INSERT INTO io.thecodeforge.experiments (
    model_id,
    base_architecture,
    pretrained_weights,
    frozen_layers_count,
    final_accuracy,
    created_at
) VALUES (
    'FORGE-V2-FINETUNED',
    'MobileNetV2',
    'ImageNet',
    154,
    0.982,
    CURRENT_TIMESTAMP
);
📊 Production Insight
Record fine_tuning_start_epoch and learning_rate_phase2 — two models with identical final accuracy may have very different robustness profiles depending on how aggressively they were fine-tuned.
For automated tracking of these fields, see experiment-tracking-mlflow.
🎯 Key Takeaway
Transfer learning lineage needs more metadata than from-scratch training — record which layers were frozen and for how long.
learning_rate_phase2 is as important as final_accuracy for debugging production regressions.
This SQL schema is the floor; MLflow automates the ceiling.
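The same lineage record can be exercised locally. Below is a minimal sketch using Python's built-in sqlite3 as a stand-in for the production database; the column names mirror the schema above, but sqlite has no schemas, so the table is simply named experiments:

```python
import sqlite3

conn = sqlite3.connect(":memory:")  # throwaway in-memory DB for the sketch
conn.execute("""
    CREATE TABLE experiments (
        model_id TEXT,
        base_architecture TEXT,
        pretrained_weights TEXT,
        frozen_layers_count INTEGER,
        final_accuracy REAL,
        created_at TEXT DEFAULT CURRENT_TIMESTAMP
    )
""")

# Same values as the SQL example above.
conn.execute(
    "INSERT INTO experiments (model_id, base_architecture, pretrained_weights,"
    " frozen_layers_count, final_accuracy) VALUES (?, ?, ?, ?, ?)",
    ("FORGE-V2-FINETUNED", "MobileNetV2", "ImageNet", 154, 0.982),
)

row = conn.execute(
    "SELECT base_architecture, frozen_layers_count FROM experiments"
).fetchone()
print(row)  # ('MobileNetV2', 154)
```

Adding the fine-tuning fields from the Production Insight (fine_tuning_start_epoch, learning_rate_phase2) is just two more columns in the same table.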

5. Deployment: The Inference Container

We wrap the inference engine in a Docker container to handle dependency isolation, specifically ensuring the correct version of the TensorFlow runtime is present.

Dockerfile · DOCKERFILE
# io.thecodeforge: High-Performance Vision Inference
FROM tensorflow/tensorflow:2.14.0

WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY saved_model/ /app/model/
COPY inference_api.py .

EXPOSE 8080
CMD ["python", "inference_api.py"]
▶ Output
Successfully built image thecodeforge/vision-api:latest
📊 Production Insight
For inference-only deployments, the CPU-only TF image is sufficient and 4x smaller than the GPU variant.
If your inference latency target is under 50ms per image, consider TFLite quantization instead — see tensorflow-lite-mobile for the full conversion workflow.
🎯 Key Takeaway
Use the CPU-only TF image for inference unless you have a hard <50ms latency requirement.
For mobile or edge deployments, convert to TFLite after transfer learning — the TFLite guide covers the exact conversion workflow.
🗂 Training from Scratch vs. Transfer Learning
When each approach is appropriate
| Feature        | Training from Scratch      | Transfer Learning                         |
|----------------|----------------------------|-------------------------------------------|
| Data required  | Massive (10k+ images)      | Small (hundreds of images)                |
| Compute time   | Days / weeks               | Minutes / hours                           |
| Accuracy       | High (if data exists)      | Extremely high (starts with 'knowledge')  |
| Complexity     | High (architecture design) | Low (using proven models)                 |
| Use case       | Niche/unique data domains  | General objects, faces, cars, etc.        |

🎯 Key Takeaways

  • Transfer learning allows you to achieve professional-grade AI accuracy on standard consumer hardware.
  • Freezing the base model prevents 'catastrophic forgetting' of general visual features like edges and shapes.
  • MobileNetV2 is an excellent, lightweight starting point for mobile and web-based vision applications.
  • Fine-tuning is an optional optimization step that unfreezes the final layers of the base model for domain-specific accuracy.
  • Always package your vision services in Docker to ensure the C++ backend for TensorFlow remains consistent across deployments.

⚠ Common Mistakes to Avoid

    Not freezing the base model before training the head
    Symptom

    Training loss decreases but the model converges to near-random accuracy on validation data — the base weights have been corrupted by large head gradients

    Fix

    Set base_model.trainable = False before the first compile. Verify with: print(sum(1 for l in base_model.layers if l.trainable)) — must be 0. Only unfreeze for fine-tuning after the head has stabilized.

    Not using the correct preprocessing function for the base model
    Symptom

    Validation accuracy plateaus at 5–15% even though the architecture is correct — the model has never seen inputs in this range during training

    Fix

    Each Keras application has its own preprocess_input. MobileNetV2: tf.keras.applications.mobilenet_v2.preprocess_input(). ResNet50: tf.keras.applications.resnet50.preprocess_input(). Bake it into the model as a Lambda layer — never as external preprocessing.

    Fine-tuning too early or with too high a learning rate
    Symptom

    Model performance regresses sharply after unfreezing — val_accuracy drops from 93% to 60% within 3 epochs of fine-tuning

    Fix

    Only fine-tune after Phase 1 head training has stabilized. Use lr=1e-5 (not the original 1e-3) when fine-tuning. Unfreeze only the top 20–50 layers of the base, not all of them.

    Using a base model input shape incompatible with your image size
    Symptom

    Keras raises a ValueError about negative output dimensions, or the base model's final feature map collapses to 1x1, leaving GlobalAveragePooling nothing meaningful to summarize

    Fix

    Input images must be at least 32x32 for MobileNetV2; other architectures have larger minimums, so check the model's entry in the tf.keras.applications docs. If your images are smaller, resize with tf.image.resize() before feeding, or choose a base architecture designed for small inputs.
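The "unfreeze only the top 20–50 layers" fix, combined with the Batch Normalization caveat from the interview section, reduces to one short loop. The sketch below uses lightweight stand-in layer objects so the selection logic can be followed and tested without TensorFlow; with a real model you would iterate base_model.layers the same way and check isinstance(layer, tf.keras.layers.BatchNormalization):

```python
class FakeLayer:
    """Stand-in for a Keras layer: a name, a BN flag, and a trainable flag."""
    def __init__(self, name, is_batchnorm=False):
        self.name = name
        self.is_batchnorm = is_batchnorm
        self.trainable = False  # the whole base starts frozen (Phase 1)

def unfreeze_top(layers, n):
    """Phase 2: unfreeze the last n layers, but keep BatchNorm frozen so its
    running mean/variance statistics are not corrupted by the new domain."""
    for layer in layers[-n:]:
        layer.trainable = not layer.is_batchnorm

# Toy 10-layer 'base' where every third layer is a BatchNorm.
base = [FakeLayer(f"block_{i}", is_batchnorm=(i % 3 == 2)) for i in range(10)]
unfreeze_top(base, 4)

print([l.name for l in base if l.trainable])
# ['block_6', 'block_7', 'block_9']  (block_8 is BatchNorm and stays frozen)
```

After running the equivalent loop on a real model, recompile with the 1e-5 learning rate; the trainable flags only take effect at the next compile.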

Interview Questions on This Topic

  • Q (Senior): What is the 'Vanishing Gradient' problem and how does Transfer Learning help avoid it during early training phases?
    Vanishing gradients occur when error signals diminish exponentially as they propagate backward through deep networks — layers close to the input receive near-zero gradient updates and stop learning. Transfer learning sidesteps this in Phase 1 by freezing the base model entirely. Only the shallow custom head receives gradient updates, so there is no deep chain of multiplication to cause vanishing. In Phase 2 fine-tuning, pre-trained weights provide a well-conditioned starting point — the magnitude of activations is already in a healthy range, so gradients propagate more cleanly than they would from random initialization.
  • Q (Senior): Describe the 'Feature Extraction' vs 'Fine-tuning' stages. At what point do you reduce the learning rate when fine-tuning?
    Feature Extraction: the base model is frozen (trainable=False), and only the custom head layers are updated. The head learns to map the base model's fixed feature representations to your target classes. Fine-Tuning: after the head stabilizes (val_loss is flat for 3–5 epochs), unfreeze the last 20–50 layers of the base and retrain with a learning rate 10x–100x lower than the initial rate. The reduced learning rate prevents catastrophic forgetting of general features like edges and textures. Rule: if Phase 1 used lr=1e-3, fine-tune with lr=1e-5. Recompile the model after unfreezing — the optimizer state must be reset for the new trainable configuration.
  • Q (Junior): Why do we remove the 'top' (fully connected) layer of a pre-trained model when applying it to a new classification task?
    The 'top' of a pre-trained model (e.g., MobileNetV2 trained on ImageNet) contains Dense layers designed to classify into the original 1,000 ImageNet categories. These layers have weights tuned specifically for those 1,000 classes and an output dimension of 1,000. Your task almost certainly has different classes and a different number of outputs. By setting include_top=False, you remove these task-specific layers and keep only the convolutional backbone — the general-purpose feature extractor. You then attach your own Dense output with the correct number of units and activation function for your specific problem.
  • Q (Senior): What is 'Domain Adaptation' and how does it relate to the effectiveness of ImageNet weights on medical imaging data?
    Domain Adaptation is the challenge of applying a model trained on one data distribution to a significantly different target distribution. ImageNet weights encode natural image statistics: color photos of everyday objects with natural lighting. Medical images (X-rays, MRI, pathology slides) have fundamentally different statistics — grayscale, different textures, different frequency content. This means ImageNet features are less transferable than for natural photo tasks. In practice, fine-tuning with even 500–1000 labeled medical images often still outperforms training from scratch, but the accuracy gain from transfer learning is smaller than for natural image tasks. Convert MRI slices to 3-channel by repeating the grayscale channel to match the RGB input expectation.
  • Q (Senior): How do you handle the bottleneck of 'Internal Covariate Shift' when unfreezing layers that contain Batch Normalization?
    Batch Normalization layers contain running mean and variance statistics computed during the original training. When you unfreeze these layers, TF by default sets them to training mode, which updates the running statistics with your new data's distribution. This can destabilize the feature maps, especially if your domain differs from ImageNet. The fix: keep BN layers frozen even when unfreezing the surrounding Conv layers. In Keras: for layer in base_model.layers: if isinstance(layer, tf.keras.layers.BatchNormalization): layer.trainable = False. This preserves the original normalization statistics while allowing the conv weights to adapt.

Frequently Asked Questions

What is 'Fine-tuning' and how does it differ from 'Feature Extraction'?

Feature Extraction is keeping the pre-trained base completely frozen and only training the new head. Fine-tuning is unfreezing the last few layers of the base model and training them with a very low learning rate to adapt the high-level features to your specific data.

Why do we remove the 'top' layer of a pre-trained model?

The 'top' layer of models like MobileNet was designed to classify 1,000 specific categories from the ImageNet competition. Since your project likely has different categories (e.g., 'Defective' vs 'Functional' parts), we replace that layer with one that matches your specific output count.

What is the 'ImageNet' dataset and why is it so important for transfer learning?

ImageNet is a massive database of over 14 million hand-annotated images. Models trained on it have essentially 'seen' almost everything in the natural world, making them the perfect 'general experts' to build upon.

Can I use Transfer Learning for text or audio?

Absolutely. You can use pre-trained models like BERT (via Hugging Face Transformers) for text or YAMNet for audio. The principle remains the same: leverage a model that already understands the fundamental 'language' of the data. See the hugging-face-transformers guide for the NLP version of this workflow.

Should I always use transfer learning instead of training from scratch?

For visual tasks with fewer than 50,000 images: yes, almost always. Transfer learning will outperform training from scratch in both accuracy and training time. Exceptions: highly specialized domains where ImageNet statistics are completely irrelevant (e.g., astronomical imaging, radar), or when you have millions of labeled domain-specific images that justify architecture search from scratch.

Naren, Founder & Author

Developer and founder of TheCodeForge. I built this site because I was tired of tutorials that explain what to type without explaining why it works. Every article here is written to make concepts actually click.
