Understanding transfer learning in deep learning
David Weedmark | 2022-11-17 | 8 min read
For data science teams working with too little data, or with more data than they have the time or resources to process, transfer learning can represent a significant shortcut in machine learning model development. Most often associated with deep learning and neural networks, it's also being used in place of traditional machine learning model training techniques as a way of accelerating development.
What is transfer learning?
Transfer learning is a machine learning technique that reuses a completed model that was developed for one task as the starting point for a new model to accomplish a new task. The knowledge used by the first model is thus transferred to the second model. The phrase “transfer learning” comes from human psychology, wherein a person who knows how to play the piano, for example, can more easily learn to play a violin, compared to someone with no experience at all.
Subsets of transfer learning
The use of transfer learning depends on three factors: what needs to be transferred, how it should be transferred and when it should be transferred. Because the source and target data in a transfer learning setting can differ in their domains, their tasks or both, there are three subsets of transfer learning.
- Inductive transfer learning: The source and target tasks are different, regardless of any similarities between the source and target domains.
- Transductive transfer learning: The tasks are the same, but the feature spaces or the marginal probability distributions of the source and target domains are different.
- Unsupervised transfer learning: No labeled data is available for training in either the source or the target domain.
Why is transfer learning used?
Transfer learning can speed up progress and improve performance when training a new model, so it’s primarily used whenever time is a factor or the resources required for training a model are large. For these reasons, you’ll often see transfer learning used in deep learning projects, including neural networks for solving natural language processing (NLP) or computer vision (CV) tasks. Transfer learning is also used when concept drift could become a problem or when multi-task learning is required from the second model.
Another time transfer learning is used is when the available training data is insufficient. In these situations, the weights from the pre-trained model can be used to initialize the weights of the new model.
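As a loose illustration of that idea, here is a minimal Keras sketch in which two models share the same architecture and every layer except the output head of the new model is initialized from the pre-trained one. The layer sizes, class counts and checkpoint path are hypothetical placeholders, not part of any particular library workflow.

```python
import tensorflow as tf

def build_network(num_outputs):
    """Shared architecture for the source and target tasks (hypothetical sizes)."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(num_outputs, activation="softmax"),
    ])

# Model already trained on the data-rich source task.
source_model = build_network(num_outputs=10)
# source_model.load_weights("source_task.weights.h5")  # hypothetical checkpoint

# New model for the data-poor target task: initialize every layer except the
# output head with the source model's weights instead of random values.
target_model = build_network(num_outputs=3)
for src_layer, tgt_layer in zip(source_model.layers[:-1], target_model.layers[:-1]):
    tgt_layer.set_weights(src_layer.get_weights())
```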
Transfer learning is only successful when the features the first model learned on its original task are general enough to transfer to the second task. For the same reason, the dataset used in the second round of training needs to be similar to the dataset used in the first.
Transfer learning in machine learning vs. deep learning
Transfer learning is a growing trend in deep learning models, and it is increasingly used as an alternative to traditional machine learning training techniques. Traditional machine learning models are usually designed to perform specific tasks and are trained on datasets tailored to the model's needs. Once they are trained and tested, they are put into environments that mirror the training environment to fulfill the purpose they were trained for. This can be described as an isolated approach to machine learning.
Transfer learning differs in that this isolation is deliberately broken: the goal is to leverage knowledge from pre-trained models to train subsequent models. The result is a daisy chain of learning, often leading to faster and more efficient model development. To use a human analogy, consider high school math and science teachers educating future scientists, physicians, dentists and data scientists.
In the case of deep transfer learning, used for complex models in areas like natural language processing (NLP) and computer vision, the task required of a model may be singular, but the training dataset may be insufficient: there may be too little data, or the data may be poorly labeled. When acquiring or preparing a quality dataset would take too much time, deep transfer learning can be faster and more efficient.
Transfer learning approaches
The use of transfer learning can be accomplished using a few different approaches:
Train a model on similar domains
This approach involves training a model on a similar domain. Suppose, for example, you need to solve Problem A, but you don't have enough data. Problem B is similar to Problem A, and you have plenty of data for that problem. You can therefore train a model on Problem B, then use that successful model to bootstrap a new model for Problem A.
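Below is a minimal sketch of that bootstrapping, assuming a simple Keras setup; the shapes, class counts and the commented-out datasets for Problems A and B are hypothetical.

```python
import tensorflow as tf

# Model for Problem B, where data is plentiful (shapes and class counts are
# illustrative; x_b, y_b, x_a, y_a are hypothetical datasets).
model_b = tf.keras.Sequential([
    tf.keras.Input(shape=(32,)),
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(5, activation="softmax"),  # Problem B's 5 classes
])
model_b.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
# model_b.fit(x_b, y_b, epochs=20)  # train on the large Problem B dataset

# Bootstrap Problem A's model by reusing everything except Problem B's head.
# The reused layers are shared objects, so Problem A starts from B's weights.
model_a = tf.keras.Sequential(
    [tf.keras.Input(shape=(32,))]
    + model_b.layers[:-1]
    + [tf.keras.layers.Dense(2, activation="softmax")]  # Problem A's 2 classes
)
model_a.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
# model_a.fit(x_a, y_a, epochs=5)  # train on the small Problem A dataset
```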
Feature extraction
Another transfer learning approach involves feature extraction. In this case, a data science team trains a deep neural network to act as an automatic feature extractor. New data is then passed through this pre-trained model, and the representations it produces are exported into a new model.
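For instance, a frozen ImageNet-trained network can serve as the extractor, with a small new classifier trained on its exported representations. The sketch below uses ResNet50 from keras.applications; the dataset is random placeholder data, so the numbers are purely illustrative.

```python
import numpy as np
import tensorflow as tf

# Pre-trained ImageNet model used purely as a feature extractor.
# include_top=False drops its original classification head; pooling="avg"
# makes it output one 2048-dimensional vector per image.
extractor = tf.keras.applications.ResNet50(
    weights="imagenet", include_top=False, pooling="avg"
)
extractor.trainable = False

# Placeholder dataset: 100 random 224x224 RGB images with 3 classes.
images = np.random.randint(0, 256, size=(100, 224, 224, 3)).astype("float32")
labels = np.random.randint(0, 3, size=100)

# Export the pre-trained representations once...
features = extractor.predict(tf.keras.applications.resnet50.preprocess_input(images))

# ...then train a small new model on those exported features.
classifier = tf.keras.Sequential([
    tf.keras.Input(shape=(2048,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
classifier.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
                   metrics=["accuracy"])
classifier.fit(features, labels, epochs=5)
```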
Develop pre-trained models
Finally, this approach involves developing pre-trained models with transfer learning in mind. Organizations with a rich background in developing models may often have a library of models to be used. When a new action is required or a problem needs to be solved, a pre-trained model can be taken, tuned to address the problem at hand, and then used to train a new model.
Transfer learning process
Regardless of the approach taken, the process outlined below is a high-level overview of how transfer learning is developed and implemented; a Keras sketch of these steps follows the list:
- Obtain a pre-trained model: These may be available from your organization’s own library of models or from another repository, such as those at PyTorch Hub or TensorFlow Hub.
- Freeze layers: This is required to prevent the weights in the model from being re-initialized. If they are re-initialized, the model will lose all of the learning it came with.
- Train new layers: New layers need to be added to the model to turn its features into new predictions on the new dataset.
- Improve model with fine-tuning: Not always required, fine-tuning the base model can improve model performance. This involves unfreezing portions of the base model and then training it again at a low learning rate on a new dataset.
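Here is a minimal sketch of those four steps in Keras, using MobileNetV2 as the pre-trained base; the class count, input size, number of unfrozen layers and the commented-out datasets are assumptions for illustration.

```python
import tensorflow as tf

# 1. Obtain a pre-trained model (here from keras.applications; TensorFlow Hub
#    or PyTorch Hub offer similar model repositories).
base_model = tf.keras.applications.MobileNetV2(
    weights="imagenet", include_top=False, input_shape=(160, 160, 3)
)

# 2. Freeze layers so the pre-trained weights are not overwritten.
base_model.trainable = False

# 3. Train new layers that map the base's features to the new dataset
#    (num_classes, train_ds and val_ds are hypothetical placeholders).
num_classes = 5
inputs = tf.keras.Input(shape=(160, 160, 3))
x = base_model(inputs, training=False)  # keep batch-norm statistics frozen
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])
# model.fit(train_ds, validation_data=val_ds, epochs=10)

# 4. Optional fine-tuning: unfreeze the top of the base model and retrain
#    at a low learning rate so the pre-trained weights shift only slightly.
base_model.trainable = True
for layer in base_model.layers[:-30]:   # keep the earliest layers frozen
    layer.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])
# model.fit(train_ds, validation_data=val_ds, epochs=5)
```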
Transfer learning in action
There are many applications of transfer learning in use today, with many more on the horizon. One common scenario is predictive modeling that uses image data as input. Examples of pre-trained models for image-based problems include the Oxford VGG Model, the Google Inception Model and the Microsoft ResNet Model.
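If you are working in Keras, ImageNet-trained versions of these architectures are available out of the box; the snippet below simply loads them as reusable bases (the variable names are arbitrary).

```python
import tensorflow as tf

# ImageNet-trained versions of the architectures mentioned above, loaded as
# reusable bases (include_top=False drops their original classification heads).
vgg = tf.keras.applications.VGG16(weights="imagenet", include_top=False)
inception = tf.keras.applications.InceptionV3(weights="imagenet", include_top=False)
resnet = tf.keras.applications.ResNet50(weights="imagenet", include_top=False)
```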
Another popular area for transfer learning is language data, including natural language processing using text. Examples of pre-trained models for language-based problems include Google's word2vec Model and Stanford's GloVe Model.
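As a rough sketch of how such embeddings are transferred, the snippet below loads GloVe vectors into a frozen Keras Embedding layer. The tiny vocabulary and the file path (assumed to be one of the files from Stanford's GloVe download, already on disk) are assumptions for illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical vocabulary for the new task; in practice this comes from a
# tokenizer fitted on your own text data.
word_index = {"movie": 1, "great": 2, "boring": 3}
embedding_dim = 100

# Load pre-trained GloVe vectors ("glove.6B.100d.txt" is assumed to be on disk).
glove_vectors = {}
with open("glove.6B.100d.txt", encoding="utf-8") as f:
    for line in f:
        word, *values = line.split()
        glove_vectors[word] = np.asarray(values, dtype="float32")

# Build an embedding matrix for this vocabulary (index 0 is reserved for padding).
embedding_matrix = np.zeros((len(word_index) + 1, embedding_dim))
for word, i in word_index.items():
    if word in glove_vectors:
        embedding_matrix[i] = glove_vectors[word]

# A frozen Embedding layer initialized with the GloVe vectors transfers that
# pre-trained knowledge into whatever text model is built on top of it.
embedding_layer = tf.keras.layers.Embedding(
    input_dim=len(word_index) + 1,
    output_dim=embedding_dim,
    embeddings_initializer=tf.keras.initializers.Constant(embedding_matrix),
    trainable=False,
)
```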
Developing and implementing transfer learning
While transfer learning can reduce processing requirements and cut down on development time, this means very little if algorithms, pre-trained models and datasets aren't easily accessible, well-documented and governed to the standards model-driven organizations require.
If you are ready to get started with transfer learning, see our hands-on tutorial for transfer learning using Keras and HuggingFace.
David Weedmark is a published author who has worked as a project manager, software developer and network security consultant.