Transfer Learning with Fine-Tuning in NLP
Last Updated: 20 May, 2025
Natural Language Processing (NLP) has been transformed by models like BERT, which understand language context deeply by looking at the words both before and after a target word. BERT is pre-trained on vast amounts of general text, so adapting it to a specific task like sentiment analysis requires fine-tuning. This process customizes BERT’s knowledge to perform well on domain-specific data while saving time and computational effort compared to training a model from scratch.
Using Hugging Face’s transformers library, we will fine-tune a pre-trained BERT model for binary sentiment classification using transfer learning.
Why Do We Fine-Tune a Model like BERT?
Fine-tuning uses BERT’s pre-trained knowledge and adapts it to a target task by retraining on a smaller, labeled dataset. This process:
- Saves computational resources compared to training a model from scratch.
- Improves model performance by making it more task-specific.
- Enhances the model’s ability to generalize to unseen data.
Fine-Tuning BERT Model for Sentiment Analysis using Transfer Learning
1. Installing and Importing Required Libraries
First, we will install the Hugging Face transformers library, which provides pre-trained models and tokenizers.
!pip install transformers
We import PyTorch for tensor operations and model training, along with the following utilities:
- DataLoader and TensorDataset help load and batch data efficiently during training.
- torch.nn.functional provides functions like softmax for calculating prediction probabilities.
- AdamW is an optimizer well-suited for transformer models.
- BertTokenizer converts text into tokens that BERT can understand.
- BertForSequenceClassification loads the BERT model adapted for classification tasks.
Python
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
2. Loading the Pre-Trained BERT Model and Tokenizer
We load the bert-base-uncased model and its tokenizer. The tokenizer converts raw text into input IDs and attention masks, which BERT requires.
- BertTokenizer.from_pretrained(): Loads the tokenizer that prepares text for BERT input.
- BertForSequenceClassification.from_pretrained(): Loads BERT configured for classification tasks with two output labels (positive and negative).
Python
pretrained_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
model = BertForSequenceClassification.from_pretrained(pretrained_model_name,
                                                      num_labels=2)
Move the model to GPU if available:
Python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
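As an optional variant, transfer learning can also be done by freezing the pre-trained BERT encoder so that only the small classification head is trained. The minimal sketch below assumes the model loaded above; it trades some accuracy for a much cheaper training run and is not required for the rest of this tutorial.
Python
# Optional: freeze the pre-trained BERT encoder so only the classification
# head is updated during training. This is a minimal sketch of a lighter
# form of transfer learning; skip it to fine-tune all layers as in the
# rest of this tutorial.
for param in model.bert.parameters():
    param.requires_grad = False

# Count the parameters that remain trainable (the classifier head).
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable}")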
3. Preparing the Training Dataset
We create a small labeled dataset for sentiment analysis. Here:
- 1 represents positive sentiment
- 0 represents negative sentiment
Python
train_texts = [
    "I love this product, it's amazing!",                 # Positive
    "Absolutely fantastic experience, will buy again!",   # Positive
    "Worst purchase ever. Completely useless.",           # Negative
    "I hate this item, it doesn't work!",                 # Negative
    "The quality is top-notch, highly recommend!",        # Positive
    "Terrible service, never coming back.",               # Negative
    "This is the best thing I've ever bought!",           # Positive
    "Very disappointing. Waste of money.",                # Negative
    "Superb! Exceeded all my expectations.",              # Positive
    "Not worth the price at all.",                        # Negative
]
train_labels = torch.tensor([1, 1, 0, 0, 1, 0, 1, 0, 1, 0]).to(device)
4. Tokenizing the Dataset
The tokenizer processes text into fixed-length sequences, adding padding and truncation as needed.
- padding=True: Ensures all input sequences have the same length.
- truncation=True: Shortens sentences longer than max_length=128 tokens.
Python
encoded_train = tokenizer(train_texts,
                          padding=True,
                          truncation=True,
                          max_length=128,
                          return_tensors='pt')
train_input_ids = encoded_train['input_ids'].to(device)
train_attention_masks = encoded_train['attention_mask'].to(device)
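As a quick, purely illustrative check, we can inspect what the tokenizer produced: the tensors have shape (number of sentences, sequence length), and decoding one row shows the special [CLS], [SEP] and [PAD] tokens that were added.
Python
# Inspect the tokenized batch; the exact sequence length depends on the
# longest sentence, so only the shapes are printed here.
print(train_input_ids.shape)
print(train_attention_masks.shape)

# Decode the first example to see the added special tokens.
print(tokenizer.decode(train_input_ids[0]))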
5. Creating a DataLoader for Efficient Training
Data is wrapped in a TensorDataset and loaded into a DataLoader to enable mini-batch training, which improves training efficiency and stability.
- TensorDataset(): Combines input IDs, attention masks and labels into a dataset.
- DataLoader(): Loads data in mini-batches to improve efficiency.
Python
train_dataset = TensorDataset(train_input_ids, train_attention_masks, train_labels)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
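To confirm the batching works as expected, a single mini-batch can be pulled from the DataLoader; this is only a sanity check and assumes the batch_size=2 set above.
Python
# Fetch one mini-batch to verify the DataLoader output.
sample_input_ids, sample_masks, sample_labels = next(iter(train_loader))
print(sample_input_ids.shape)   # (batch_size, sequence_length)
print(sample_masks.shape)       # (batch_size, sequence_length)
print(sample_labels)            # labels for this mini-batch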
6. Defining the Optimizer
We use the AdamW optimizer which works well with transformer models.
Python
optimizer = AdamW(model.parameters(), lr=2e-5)
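A common optional refinement, not used in the training run below, is to apply weight decay to most weights while excluding bias and LayerNorm parameters. The sketch below would replace the simpler optimizer defined above if you choose to use it.
Python
# Optional alternative: exclude bias and LayerNorm weights from weight decay,
# a common convention when fine-tuning transformers.
no_decay = ["bias", "LayerNorm.weight"]
grouped_params = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = AdamW(grouped_params, lr=2e-5)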
7. Training the Model
The training loop iterates over batches, computing loss and gradients, updating model weights and tracking accuracy.
- optimizer.zero_grad(): Clears gradients before each batch.
- model(...): Runs forward pass and calculates loss.
- loss.backward(): Backpropagates the error.
- optimizer.step(): Updates model weights based on gradients.
- torch.argmax(F.softmax(...)): Determines predicted class.
Python
epochs = 5
model.train()

for epoch in range(epochs):
    total_loss = 0
    correct = 0
    total = 0

    for batch in train_loader:
        batch_input_ids, batch_attention_masks, batch_labels = batch

        optimizer.zero_grad()
        outputs = model(input_ids=batch_input_ids,
                        attention_mask=batch_attention_masks,
                        labels=batch_labels)
        loss = outputs.loss
        logits = outputs.logits

        total_loss += loss.item()
        loss.backward()
        optimizer.step()

        preds = torch.argmax(F.softmax(logits, dim=1), dim=1)
        correct += (preds == batch_labels).sum().item()
        total += batch_labels.size(0)

    avg_loss = total_loss / len(train_loader)
    accuracy = correct / total * 100
    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
Output:
The loop prints the average loss and training accuracy for each of the 5 epochs.
8. Saving and Loading the Fine-Tuned Model
We can save the model using the torch.save() function. The model’s state dictionary is saved and can be reloaded later for inference or further training.
Saving the model:
Python
torch.save(model.state_dict(), "fine_tuned_bert.pth")
Loading the fine-tuned model:
Python
model.load_state_dict(torch.load("fine_tuned_bert.pth"))
model.to(device)
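Alternatively, Hugging Face models can be saved with save_pretrained(), which stores the weights together with the model configuration so they can be reloaded without rebuilding the architecture by hand. The directory name fine_tuned_bert_dir below is just an example.
Python
# Save the model and tokenizer in Hugging Face format (example directory name).
model.save_pretrained("fine_tuned_bert_dir")
tokenizer.save_pretrained("fine_tuned_bert_dir")

# Reload them later for inference or further training.
model = BertForSequenceClassification.from_pretrained("fine_tuned_bert_dir")
tokenizer = BertTokenizer.from_pretrained("fine_tuned_bert_dir")
model.to(device)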
9. Creating Test Data for Evaluation
We prepare a test dataset and run the fine-tuned model to measure accuracy and make predictions on new text samples.
- torch.tensor([...]).to(device) converts the label list into a tensor and moves it to the computing device like CPU or GPU.
- tokenizer(test_texts, padding=True, truncation=True, max_length=128, return_tensors='pt') tokenizes the test texts.
- encoded_test['input_ids'] extracts token IDs representing each word or subword.
- encoded_test['attention_mask'] extracts attention masks indicating which tokens should be attended to (1) and which are padding (0).
Python
test_texts = [
    "This is a great product, I love it!",     # Positive
    "Horrible experience, I want a refund!",   # Negative
    "Highly recommended! Five stars.",         # Positive
    "Not worth it. I regret buying this.",     # Negative
]
test_labels = torch.tensor([1, 0, 1, 0]).to(device)

encoded_test = tokenizer(test_texts,
                         padding=True,
                         truncation=True,
                         max_length=128,
                         return_tensors='pt')

test_input_ids = encoded_test['input_ids'].to(device)
test_attention_masks = encoded_test['attention_mask'].to(device)
10. Making Predictions and Evaluating Performance
We set the model to evaluation mode using model.eval() to disable training-specific layers like dropout. Accuracy is calculated by comparing predicted labels to true labels and computing the percentage correct. Each test text and its predicted label are printed in a loop for review.
- torch.no_grad() disables gradient calculations for faster and more memory-efficient inference.
- The model processes inputs with model(input_ids=..., attention_mask=...) to produce output logits.
- torch.argmax(outputs.logits, dim=1) selects the class with the highest score as the prediction.
Python
model.eval()

with torch.no_grad():
    outputs = model(input_ids=test_input_ids,
                    attention_mask=test_attention_masks)
    predicted_labels = torch.argmax(outputs.logits, dim=1)

test_accuracy = (predicted_labels == test_labels).sum().item() / len(test_labels) * 100
print(f"\nTest Accuracy: {test_accuracy:.2f}%")

for text, label in zip(test_texts, predicted_labels):
    print(f'Text: {text}\nPredicted Label: {label.item()}\n')
Output:
The script prints the test accuracy and the predicted label for each test sentence. Here we can see that our model is working fine.
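To classify a brand-new sentence with the fine-tuned model, the tokenize-and-predict steps can be wrapped in a small helper. The function predict_sentiment below is a hypothetical convenience wrapper built on the model and tokenizer from this tutorial, not part of the original code.
Python
# Hypothetical helper: classify a single sentence with the fine-tuned model.
def predict_sentiment(text):
    encoded = tokenizer(text,
                        padding=True,
                        truncation=True,
                        max_length=128,
                        return_tensors='pt').to(device)
    with torch.no_grad():
        logits = model(**encoded).logits
    return "Positive" if torch.argmax(logits, dim=1).item() == 1 else "Negative"

print(predict_sentiment("The packaging was awful but the product itself is great."))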
You can download the source code from here: Transfer Learning with Fine-Tuning in NLP.