
How to Implement Transfer Learning in TensorFlow and What Pre-trained Models Are Available

February 21, 17:07


Transfer learning is a technique that transfers knowledge from a pre-trained model to a new task, which can significantly improve model performance and reduce training time. TensorFlow provides rich pre-trained models and convenient transfer learning tools.

Basic Concepts of Transfer Learning

What is Transfer Learning

Transfer learning refers to using a model pre-trained on a large dataset and transferring its learned feature extraction capabilities to a new, possibly smaller dataset. This approach is particularly suitable for:

  • Situations with small datasets
  • New tasks similar to pre-training tasks
  • Scenarios where good performance is needed quickly

Advantages of Transfer Learning

  • Reduce training time
  • Improve model performance
  • Reduce demand for large amounts of labeled data
  • Leverage existing research results

TensorFlow Hub Pre-trained Models

Loading Pre-trained Models with TensorFlow Hub

python
import tensorflow as tf
import tensorflow_hub as hub

# Load pre-trained model
model_url = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4"
pretrained_model = hub.KerasLayer(model_url, trainable=False)

# Build transfer learning model
model = tf.keras.Sequential([
    pretrained_model,
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

Common TensorFlow Hub Models

Image Classification Models

python
# MobileNet V2
mobilenet_v2 = hub.KerasLayer(
    "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4"
)

# EfficientNet
efficientnet = hub.KerasLayer(
    "https://tfhub.dev/google/efficientnet/b0/feature-vector/1"
)

# ResNet
resnet = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/resnet_50/feature_vector/1"
)

# Inception
inception = hub.KerasLayer(
    "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4"
)

Text Processing Models

python
# BERT (expects preprocessed token inputs; pair it with its matching
# preprocessing model from TF Hub before feeding raw text)
bert = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
)

# Universal Sentence Encoder (maps raw strings to 512-dim embeddings)
use = hub.KerasLayer(
    "https://tfhub.dev/google/universal-sentence-encoder/4"
)

# ELMo (a TF1-style Hub module; may require a signature argument in TF2)
elmo = hub.KerasLayer(
    "https://tfhub.dev/google/elmo/3"
)

Keras Applications Pre-trained Models

Using Keras Applications

python
from tensorflow.keras.applications import (
    VGG16, VGG19,
    ResNet50, ResNet101, ResNet152,
    InceptionV3, InceptionResNetV2,
    MobileNet, MobileNetV2,
    DenseNet121, DenseNet169, DenseNet201,
    EfficientNetB0, EfficientNetB1, EfficientNetB2,
    NASNetMobile, NASNetLarge
)

Basic Transfer Learning Workflow

python
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16

# Load pre-trained model (excluding top layer)
base_model = VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)

# Freeze pre-trained layers
base_model.trainable = False

# Add custom classification head
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')
])

# Compile model
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train model (train_dataset and val_dataset are assumed to be prepared tf.data pipelines)
model.fit(train_dataset, epochs=10, validation_data=val_dataset)

Fine-tuning a Pre-trained Model

python
# Unfreeze some layers for fine-tuning
base_model.trainable = True

# Freeze earlier layers, only fine-tune later layers
for layer in base_model.layers[:15]:
    layer.trainable = False

# Use a lower learning rate for fine-tuning
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Continue training
model.fit(train_dataset, epochs=5, validation_data=val_dataset)

Complete Transfer Learning Example

Image Classification Transfer Learning

python
import tensorflow as tf
from tensorflow.keras import layers, models, applications
import tensorflow_datasets as tfds

# Load dataset (cats_vs_dogs only ships a 'train' split, so carve out a test set)
(train_data, test_data), info = tfds.load(
    'cats_vs_dogs',
    split=['train[:80%]', 'train[80%:]'],
    with_info=True,
    as_supervised=True
)

# Data preprocessing
def preprocess(image, label):
    image = tf.image.resize(image, (224, 224))
    image = tf.keras.applications.resnet50.preprocess_input(image)
    return image, label

train_data = train_data.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
test_data = test_data.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)

# Load pre-trained model
base_model = applications.ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)

# Freeze pre-trained layers
base_model.trainable = False

# Build model
inputs = tf.keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation='sigmoid')(x)
model = models.Model(inputs, outputs)

# Compile model
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# Train model
history = model.fit(
    train_data,
    epochs=10,
    validation_data=test_data
)

# Fine-tune model
base_model.trainable = True
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='binary_crossentropy',
    metrics=['accuracy']
)
history_fine = model.fit(
    train_data,
    epochs=5,
    validation_data=test_data
)

Text Classification Transfer Learning

python
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers, models

# BERT expects tokenized inputs, so load its matching preprocessing model as well
preprocessor = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
)
bert_model = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4",
    trainable=False
)

# Build model
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
encoder_inputs = preprocessor(text_input)
bert_outputs = bert_model(encoder_inputs)
x = layers.Dense(128, activation='relu')(bert_outputs['pooled_output'])
x = layers.Dropout(0.5)(x)
output = layers.Dense(1, activation='sigmoid')(x)
model = models.Model(text_input, output)

# Compile model
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# Prepare data (a toy dataset for illustration)
train_texts = ["This is a positive sentence", "This is a negative sentence"]
train_labels = [1, 0]

# Train model
model.fit(
    tf.constant(train_texts),
    tf.constant(train_labels),
    epochs=10,
    batch_size=32
)

Advanced Transfer Learning Techniques

1. Feature Extraction

python
# Use the pre-trained model only as a feature extractor
# (pooling='avg' yields flat feature vectors instead of 4D feature maps)
base_model = applications.VGG16(weights='imagenet', include_top=False, pooling='avg')

# Extract features
def extract_features(images):
    features = base_model.predict(images)
    return features

# Train a simple classifier on the extracted features
train_features = extract_features(train_images)
classifier = tf.keras.Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(10, activation='softmax')
])
classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
classifier.fit(train_features, train_labels, epochs=10)

2. Progressive Unfreezing

python
# Gradually unfreeze layers
base_model = applications.ResNet50(weights='imagenet', include_top=False)
base_model.trainable = False

# Phase 1: only train the classification head
model = build_model(base_model)  # build_model assembles the head on top of the base
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(train_data, epochs=5)

# Phase 2: unfreeze the last 10 layers
base_model.trainable = True
for layer in base_model.layers[:-10]:
    layer.trainable = False
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy'
)
model.fit(train_data, epochs=5)

# Phase 3: unfreeze the last 20 layers
# (set every layer trainable first, so layers frozen in phase 2 are released)
for layer in base_model.layers:
    layer.trainable = True
for layer in base_model.layers[:-20]:
    layer.trainable = False
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6),
    loss='sparse_categorical_crossentropy'
)
model.fit(train_data, epochs=5)

3. Learning Rate Scheduling

python
# Use a learning rate scheduler
initial_learning_rate = 1e-3
decay_steps = 1000
decay_rate = 0.9

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps,
    decay_rate
)

optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')

4. Mixed Precision Training

python
from tensorflow.keras import mixed_precision

# Enable mixed precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Build model
base_model = applications.EfficientNetB0(weights='imagenet', include_top=False)
base_model.trainable = False

model = tf.keras.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    # Keep the output layer in float32 for numerical stability
    layers.Dense(10, activation='softmax', dtype='float32')
])

# Use a loss-scale optimizer to avoid float16 gradient underflow
optimizer = mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam()
)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')

5. Data Augmentation

python
# Add data augmentation
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
    layers.RandomContrast(0.1)
])

# Build model (base_model is a frozen pre-trained backbone, as above)
inputs = tf.keras.Input(shape=(224, 224, 3))
x = data_augmentation(inputs)
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(10, activation='softmax')(x)
model = models.Model(inputs, outputs)

Comparison of Common Pre-trained Models

| Model | Parameters | Features | Use Cases |
| --- | --- | --- | --- |
| VGG16 | 138M | Simple structure, easy to understand | Academic research, feature extraction |
| ResNet50 | 25M | Residual connections, deep network | General image classification |
| MobileNetV2 | 3.5M | Lightweight, suitable for mobile | Mobile apps, real-time inference |
| EfficientNetB0 | 5.3M | Efficient scaling strategy | Balance of performance and efficiency |
| InceptionV3 | 23M | Inception modules | Complex image classification |
| DenseNet121 | 8M | Dense connections | Medical image analysis |
| BERT | 110M | Transformer architecture | Natural language processing |
| GPT-2 | 117M–1.5B | Generative pre-training | Text generation |
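The parameter counts in the table can be sanity-checked directly in Keras, since `count_params()` depends only on the architecture. A minimal sketch (using `weights=None` so no weight download is needed):

```python
import tensorflow as tf

# Instantiate the architecture without downloading pre-trained weights;
# the parameter count is a property of the architecture itself
model = tf.keras.applications.MobileNetV2(weights=None)
print(f"MobileNetV2 parameters: {model.count_params() / 1e6:.1f}M")
```

The same check works for any model in `tf.keras.applications`; note that `include_top=False` removes the classification head and therefore lowers the count.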

Transfer Learning Best Practices

  1. Choose an appropriate pre-trained model: select a model based on task requirements
  2. Freeze layers sensibly: initially freeze all pre-trained layers, then gradually unfreeze
  3. Use a lower learning rate: fine-tune with a smaller learning rate
  4. Data augmentation: augment small datasets
  5. Monitor overfitting: use a validation set to track model performance
  6. Progressive unfreezing: unfreeze layers in stages
  7. Learning rate scheduling: use a learning-rate decay strategy
  8. Mixed precision training: accelerate the training process
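Several of these practices can be wired up through standard Keras callbacks. A minimal sketch (the monitored metric, patience values, and factors here are illustrative defaults, not prescriptions):

```python
import tensorflow as tf

# Stop training when the validation loss stops improving (monitor overfitting)
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)

# Reduce the learning rate when progress stalls (learning rate scheduling)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    min_lr=1e-6
)

# Pass the callbacks to fit(), e.g.:
# model.fit(train_data, validation_data=val_data, epochs=30,
#           callbacks=[early_stopping, reduce_lr])
```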

Transfer Learning Application Scenarios

1. Medical Image Diagnosis

python
# Use a pre-trained ImageNet model for medical image classification
base_model = applications.DenseNet121(weights='imagenet', include_top=False)
# Add a medical-image-specific classification head

2. Object Detection

python
# Use a pre-trained backbone for object detection
backbone = applications.ResNet50(weights='imagenet', include_top=False)
# Add a detection head (e.g., Faster R-CNN, YOLO, etc.)

3. Semantic Segmentation

python
# Use a pre-trained model for image segmentation
base_model = applications.MobileNetV2(weights='imagenet', include_top=False)
# Add a segmentation head (e.g., U-Net, DeepLabV3+, etc.)
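As a concrete illustration, a bare-bones fully-convolutional head can upsample the backbone's 7x7 output back to the input resolution. This is a hedged sketch, not a full U-Net or DeepLabV3+; `num_classes` and the filter sizes are illustrative, and `weights=None` keeps it download-free:

```python
import tensorflow as tf
from tensorflow.keras import layers

num_classes = 3  # illustrative number of segmentation classes

# MobileNetV2 backbone; for 224x224 input its output is a 7x7 feature map
base_model = tf.keras.applications.MobileNetV2(
    weights=None, include_top=False, input_shape=(224, 224, 3)
)

# Upsample 7x7 -> 224x224 with five stride-2 transposed convolutions
x = base_model.output
for filters in (256, 128, 64, 32, 16):
    x = layers.Conv2DTranspose(filters, 3, strides=2,
                               padding='same', activation='relu')(x)
outputs = layers.Conv2D(num_classes, 1, activation='softmax')(x)

model = tf.keras.Model(base_model.inputs, outputs)
print(model.output_shape)  # per-pixel class probabilities
```

Real segmentation models add skip connections from intermediate backbone layers to recover spatial detail, which this sketch omits.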

4. Text Classification

python
# Use a pre-trained BERT model for text classification
bert = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")
# Add a classification layer

5. Sentiment Analysis

python
# Use a pre-trained text model for sentiment analysis
use = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
# Add a sentiment classification layer

Summary

TensorFlow provides rich transfer learning tools and pre-trained models:

  • TensorFlow Hub: Provides numerous pre-trained models
  • Keras Applications: Built-in classic pre-trained models
  • Flexible fine-tuning strategies: Supports various fine-tuning methods
  • Wide range of application scenarios: Images, text, audio, etc.

Mastering transfer learning techniques will help you quickly build high-performance deep learning models.

Tags: Tensorflow