How to Implement Transfer Learning in TensorFlow and What Pre-trained Models Are Available
Transfer learning is a technique that transfers knowledge from a pre-trained model to a new task, which can significantly improve model performance and reduce training time. TensorFlow offers a rich collection of pre-trained models and convenient transfer learning tools.
Basic Concepts of Transfer Learning
What is Transfer Learning
Transfer learning refers to using a model pre-trained on a large dataset and transferring its learned feature extraction capabilities to a new, possibly smaller dataset. This approach is particularly suitable for:
- Situations with small datasets
- New tasks similar to pre-training tasks
- Scenarios where good performance is needed quickly
Advantages of Transfer Learning
- Shorter training time
- Better model performance, especially on small datasets
- Less need for large amounts of labeled data
- The ability to build directly on existing research results
TensorFlow Hub Pre-trained Models
Loading Pre-trained Models with TensorFlow Hub
```python
import tensorflow as tf
import tensorflow_hub as hub

# Load a pre-trained feature extractor from TensorFlow Hub
model_url = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4"
pretrained_model = hub.KerasLayer(model_url, trainable=False)

# Build a transfer learning model on top of the frozen feature vector
model = tf.keras.Sequential([
    pretrained_model,
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
```
Common TensorFlow Hub Models
Image Classification Models
```python
# MobileNet V2
mobilenet_v2 = hub.KerasLayer(
    "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4"
)

# EfficientNet
efficientnet = hub.KerasLayer(
    "https://tfhub.dev/google/efficientnet/b0/feature-vector/1"
)

# ResNet
resnet = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/resnet_50/feature_vector/1"
)

# Inception
inception = hub.KerasLayer(
    "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4"
)
```
Text Processing Models
```python
# BERT (expects tokenized inputs from a matching preprocessing model)
bert = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
)

# Universal Sentence Encoder (accepts raw strings directly)
use = hub.KerasLayer(
    "https://tfhub.dev/google/universal-sentence-encoder/4"
)

# ELMo (TF1 Hub format; may require hub.load or compatibility mode
# rather than hub.KerasLayer in TF2)
elmo = hub.KerasLayer(
    "https://tfhub.dev/google/elmo/3"
)
```
Keras Applications Pre-trained Models
Using Keras Applications
```python
from tensorflow.keras.applications import (
    VGG16, VGG19,
    ResNet50, ResNet101, ResNet152,
    InceptionV3, InceptionResNetV2,
    MobileNet, MobileNetV2,
    DenseNet121, DenseNet169, DenseNet201,
    EfficientNetB0, EfficientNetB1, EfficientNetB2,
    NASNetMobile, NASNetLarge
)
```
Basic Transfer Learning Workflow
```python
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16

# Load the pre-trained model (excluding the classification top)
base_model = VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)

# Freeze the pre-trained layers
base_model.trainable = False

# Add a custom classification head
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model (train_dataset / val_dataset are tf.data pipelines; see below)
model.fit(train_dataset, epochs=10, validation_data=val_dataset)
```
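The `train_dataset` and `val_dataset` above are assumed to be batched `tf.data` pipelines. As one possible sketch, they could be built from a directory of images (the `data/train` and `data/val` paths are hypothetical):

```python
import tensorflow as tf

# Hypothetical layout: data/train/<class_name>/*.jpg and data/val/<class_name>/*.jpg
train_dataset = tf.keras.utils.image_dataset_from_directory(
    "data/train",            # assumed path
    image_size=(224, 224),
    batch_size=32,
    label_mode="int"         # integer labels match sparse_categorical_crossentropy
)
val_dataset = tf.keras.utils.image_dataset_from_directory(
    "data/val",
    image_size=(224, 224),
    batch_size=32,
    label_mode="int"
)

# Apply the preprocessing VGG16 expects, and prefetch for throughput
preprocess = tf.keras.applications.vgg16.preprocess_input
train_dataset = train_dataset.map(lambda x, y: (preprocess(x), y)).prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.map(lambda x, y: (preprocess(x), y)).prefetch(tf.data.AUTOTUNE)
```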
Fine-tuning Pre-trained Model
```python
# Unfreeze the base model for fine-tuning
base_model.trainable = True

# Keep the earlier layers frozen; only fine-tune the later ones
for layer in base_model.layers[:15]:
    layer.trainable = False

# Recompile with a lower learning rate for fine-tuning
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Continue training
model.fit(train_dataset, epochs=5, validation_data=val_dataset)
```
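Before launching a fine-tuning run, it is worth double-checking which layers will actually be updated. A small sanity-check sketch:

```python
# Inspect the freeze state of each layer in the base model
for i, layer in enumerate(base_model.layers):
    print(i, layer.name, layer.trainable)

# Compare trainable vs. total parameter counts
trainable = sum(tf.keras.backend.count_params(w) for w in model.trainable_weights)
total = model.count_params()
print(f"Trainable parameters: {trainable:,} / {total:,}")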
Complete Transfer Learning Example
Image Classification Transfer Learning
```python
import tensorflow as tf
from tensorflow.keras import layers, models, applications
import tensorflow_datasets as tfds

# Load the dataset (cats_vs_dogs ships only a 'train' split,
# so carve a held-out test set from it)
(train_data, test_data), info = tfds.load(
    'cats_vs_dogs',
    split=['train[:80%]', 'train[80%:]'],
    with_info=True,
    as_supervised=True
)

# Data preprocessing
def preprocess(image, label):
    image = tf.image.resize(image, (224, 224))
    image = tf.keras.applications.resnet50.preprocess_input(image)
    return image, label

train_data = train_data.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
test_data = test_data.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)

# Load the pre-trained model
base_model = applications.ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)

# Freeze the pre-trained layers
base_model.trainable = False

# Build the model (training=False keeps BatchNorm layers in inference mode)
inputs = tf.keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation='sigmoid')(x)
model = models.Model(inputs, outputs)

# Compile the model
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# Train the classification head
history = model.fit(
    train_data,
    epochs=10,
    validation_data=test_data
)

# Fine-tune the whole network at a much lower learning rate
base_model.trainable = True
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='binary_crossentropy',
    metrics=['accuracy']
)
history_fine = model.fit(
    train_data,
    epochs=5,
    validation_data=test_data
)
```
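After fine-tuning, the model can be saved and reloaded for inference. A minimal sketch using the native `.keras` format available in recent TensorFlow versions (the filename is an assumption):

```python
# Save the fine-tuned model
model.save("cats_vs_dogs_resnet50.keras")   # hypothetical filename

# Reload and run a prediction on one batch from the test set
restored = tf.keras.models.load_model("cats_vs_dogs_resnet50.keras")
for images, labels in test_data.take(1):
    probs = restored.predict(images)
    # Threshold the sigmoid outputs at 0.5 to get class predictions
    print((probs[:5] > 0.5).astype(int).ravel(), labels[:5].numpy())
```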
Text Classification Transfer Learning
```python
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers, models

# BERT expects tokenized inputs, so pair the encoder
# with its matching preprocessing model
preprocessor = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
)
bert_encoder = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4",
    trainable=False
)

# Build the model
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
encoder_inputs = preprocessor(text_input)
bert_outputs = bert_encoder(encoder_inputs)
x = layers.Dense(128, activation='relu')(bert_outputs['pooled_output'])
x = layers.Dropout(0.5)(x)
output = layers.Dense(1, activation='sigmoid')(x)
model = models.Model(text_input, output)

# Compile the model
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# Toy data (replace with a real labeled dataset)
train_texts = ["This is a positive sentence", "This is a negative sentence"]
train_labels = [1, 0]

# Train the model
model.fit(
    tf.constant(train_texts),
    tf.constant(train_labels),
    epochs=10,
    batch_size=32
)
```
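Once trained, the classifier takes raw strings directly, since preprocessing is part of the graph. A small inference sketch (the example sentences are illustrative):

```python
# Score new sentences; outputs are sigmoid probabilities of the positive class
new_texts = tf.constant(["What a wonderful movie", "Terrible, would not recommend"])
probs = model.predict(new_texts)
for text, p in zip(new_texts.numpy(), probs.ravel()):
    label = "positive" if p > 0.5 else "negative"
    print(text.decode(), "->", label, f"({p:.2f})")
```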
Advanced Transfer Learning Techniques
1. Feature Extraction
```python
# Use the pre-trained model purely as a feature extractor
# (pooling='avg' yields a flat feature vector per image)
base_model = applications.VGG16(weights='imagenet', include_top=False, pooling='avg')

# Extract features (train_images: an array of preprocessed 224x224x3 images)
def extract_features(images):
    return base_model.predict(images)

train_features = extract_features(train_images)

# Train a simple classifier on the extracted features
classifier = tf.keras.Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(10, activation='softmax')
])
classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
classifier.fit(train_features, train_labels, epochs=10)
```
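Since feature extraction is the expensive step, a common pattern is to cache the extracted features to disk once and then iterate freely on the cheap classifier. A sketch, assuming `train_features` and `train_labels` from above (the file paths are assumptions):

```python
import numpy as np

# Cache the expensive forward pass once
np.save("train_features.npy", train_features)
np.save("train_labels.npy", np.asarray(train_labels))

# Later runs skip the base model entirely
train_features = np.load("train_features.npy")
train_labels = np.load("train_labels.npy")
classifier.fit(train_features, train_labels, epochs=10)
```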
2. Progressive Unfreezing
```python
# Gradually unfreeze layers in phases
base_model = applications.ResNet50(weights='imagenet', include_top=False)
base_model.trainable = False

# Phase 1: train only the classification head (build_model is an assumed
# helper that stacks pooling and a classifier on top of the frozen base)
model = build_model(base_model)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(train_data, epochs=5)

# Phase 2: unfreeze the last 10 layers
base_model.trainable = True
for layer in base_model.layers[:-10]:
    layer.trainable = False
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy'
)
model.fit(train_data, epochs=5)

# Phase 3: unfreeze the last 20 layers (reset everything to trainable first,
# otherwise the layers frozen in Phase 2 would stay frozen)
base_model.trainable = True
for layer in base_model.layers[:-20]:
    layer.trainable = False
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6),
    loss='sparse_categorical_crossentropy'
)
model.fit(train_data, epochs=5)
```
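The reset-then-freeze pattern is easy to get wrong, so it may be worth wrapping in a helper. A sketch (it assumes the `model` and `base_model` from above are in scope):

```python
def unfreeze_last_n(base_model, model, n, learning_rate):
    """Reset all layers to trainable, then freeze all but the last n.

    Resetting first avoids the subtle bug where layers frozen in an
    earlier phase stay frozen even after entering the trainable window.
    """
    for layer in base_model.layers:
        layer.trainable = True
    for layer in base_model.layers[:-n]:
        layer.trainable = False
    # Recompile so the new trainable set takes effect
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='sparse_categorical_crossentropy'
    )

unfreeze_last_n(base_model, model, 10, 1e-5)
model.fit(train_data, epochs=5)
unfreeze_last_n(base_model, model, 20, 1e-6)
model.fit(train_data, epochs=5)
```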
3. Learning Rate Scheduling
```python
# Use an exponentially decaying learning rate
initial_learning_rate = 1e-3
decay_steps = 1000
decay_rate = 0.9

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps, decay_rate
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')
```
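Exponential decay is not the only option; cosine decay is a common alternative that anneals the learning rate smoothly toward zero over a fixed number of steps. A minimal sketch (the step count is an assumption):

```python
# Cosine annealing over a hypothetical 10,000 training steps
cosine_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-3,
    decay_steps=10000
)
optimizer = tf.keras.optimizers.Adam(learning_rate=cosine_schedule)
```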
4. Mixed Precision Training
```python
from tensorflow.keras import mixed_precision

# Enable mixed precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Build the model
base_model = applications.EfficientNetB0(weights='imagenet', include_top=False)
base_model.trainable = False

model = tf.keras.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    # Keep the output layer in float32 so the softmax and loss stay numerically stable
    layers.Dense(10, activation='softmax', dtype='float32')
])

# Wrap the optimizer in a loss-scale optimizer to avoid float16 gradient underflow
optimizer = mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam()
)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')
```
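A quick sketch for verifying the policy took effect: hidden layers should compute in float16 while keeping float32 variables, and the explicit float32 output layer protects the loss:

```python
# Confirm the global policy and per-layer dtypes
print(mixed_precision.global_policy().name)         # mixed_float16
hidden = model.layers[1]                             # GlobalAveragePooling2D
print(hidden.compute_dtype, hidden.variable_dtype)   # float16 float32
print(model.layers[-1].compute_dtype)                # float32 (output layer)
```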
5. Data Augmentation
```python
# Add data augmentation as preprocessing layers (active only during training)
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
    layers.RandomContrast(0.1)
])

# Build the model with augmentation in front of the frozen base
inputs = tf.keras.Input(shape=(224, 224, 3))
x = data_augmentation(inputs)
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(10, activation='softmax')(x)
model = models.Model(inputs, outputs)
```
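It can be useful to eyeball what the augmentation pipeline actually produces. A sketch, assuming `train_data` yields batches of un-preprocessed images in the 0-255 range:

```python
import matplotlib.pyplot as plt

# Augmentation layers are active only when training=True, so force it here
for images, _ in train_data.take(1):
    plt.figure(figsize=(8, 8))
    for i in range(9):
        augmented = data_augmentation(images[:1], training=True)
        plt.subplot(3, 3, i + 1)
        plt.imshow(tf.cast(augmented[0], tf.uint8))
        plt.axis('off')
    plt.show()
```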
Comparison of Common Pre-trained Models
| Model | Parameters (approx.) | Key Features | Typical Use Cases |
|---|---|---|---|
| VGG16 | 138M | Simple structure, easy to understand | Academic research, feature extraction |
| ResNet50 | 25M | Residual connections, deep network | General image classification |
| MobileNetV2 | 3.5M | Lightweight, suitable for mobile | Mobile apps, real-time inference |
| EfficientNetB0 | 5.3M | Efficient scaling strategy | Balance performance and efficiency |
| InceptionV3 | 23M | Inception modules | Complex image classification |
| DenseNet121 | 8M | Dense connections | Medical image analysis |
| BERT | 110M | Transformer architecture | Natural language processing |
| GPT-2 | 117M-1.5B | Generative pre-training | Text generation |
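The parameter counts for the image models above can be verified locally. A quick sketch (instantiating with weights=None counts parameters without downloading any weights):

```python
from tensorflow.keras import applications

# Compare parameter counts of a few Keras Applications models
for name, ctor in [("VGG16", applications.VGG16),
                   ("ResNet50", applications.ResNet50),
                   ("MobileNetV2", applications.MobileNetV2)]:
    m = ctor(weights=None, include_top=True)
    print(f"{name}: {m.count_params():,} parameters")
```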
Transfer Learning Best Practices
- Choose an appropriate pre-trained model: match the model to the task domain, input size, and compute budget
- Freeze layers sensibly: start with all pre-trained layers frozen, then unfreeze gradually
- Use a lower learning rate: fine-tune with a learning rate one or two orders of magnitude smaller than usual
- Augment data: apply data augmentation, especially on small datasets
- Monitor overfitting: track validation metrics and stop training when they degrade (see the callbacks sketch after this list)
- Unfreeze progressively: open up deeper layers in stages rather than all at once
- Schedule the learning rate: use a decay strategy during longer fine-tuning runs
- Use mixed precision training: accelerate training on supported GPUs
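Several of these practices can be wired up with standard Keras callbacks. A hedged sketch (the checkpoint filename and dataset names are assumptions):

```python
callbacks = [
    # Stop when validation loss stops improving (monitor overfitting)
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=3, restore_best_weights=True
    ),
    # Halve the learning rate when progress stalls (learning rate scheduling)
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, patience=2
    ),
    # Keep the best checkpoint seen so far
    tf.keras.callbacks.ModelCheckpoint(
        'best_model.keras', monitor='val_loss', save_best_only=True
    ),
]

model.fit(train_dataset, validation_data=val_dataset, epochs=30, callbacks=callbacks)
```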
Transfer Learning Application Scenarios
1. Medical Image Diagnosis
```python
# Use a model pre-trained on ImageNet for medical image classification
base_model = applications.DenseNet121(weights='imagenet', include_top=False)
# Add a task-specific classification head (see the sketch below)
```
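One possible completion of that head, where `num_classes` is a placeholder for the number of diagnostic categories:

```python
num_classes = 5  # hypothetical number of diagnostic categories

inputs = tf.keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.3)(x)
outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
medical_model = tf.keras.Model(inputs, outputs)
```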
2. Object Detection
```python
# Use a pre-trained backbone for object detection
backbone = applications.ResNet50(weights='imagenet', include_top=False)
# Add a detection head (e.g., Faster R-CNN or YOLO-style heads)
```
3. Semantic Segmentation
```python
# Use a pre-trained encoder for image segmentation
base_model = applications.MobileNetV2(weights='imagenet', include_top=False)
# Add a segmentation decoder (e.g., U-Net or DeepLabV3+ style)
```
4. Text Classification
```python
# Use a pre-trained BERT encoder for text classification
bert = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")
# Add a classification layer on top of the pooled output
```
5. Sentiment Analysis
```python
# Use a pre-trained sentence encoder for sentiment analysis
use = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
# Add a sentiment classification layer on top of the embeddings
```
Summary
TensorFlow provides a rich set of transfer learning tools and pre-trained models:
- TensorFlow Hub: Provides numerous pre-trained models
- Keras Applications: Built-in classic pre-trained models
- Flexible fine-tuning strategies: Supports various fine-tuning methods
- Wide range of application scenarios: Images, text, audio, etc.
Mastering transfer learning techniques will help you quickly build high-performance deep learning models.