
Fine Tuning Large Language Model (LLM)

Last Updated : 10 Dec, 2024

Large Language Models (LLMs) have dramatically transformed natural language processing (NLP), excelling in tasks like text generation, translation, summarization, and question-answering. However, these models may not always be ideal for specific domains or tasks.

To address this, fine-tuning is performed. Fine-tuning customizes pre-trained LLMs to better suit specialized applications by refining the model on smaller, task-specific datasets. This allows the model to enhance its performance while retaining its broad language proficiency.

Fine-Tuning in Large Language Models (LLMs)

Fine-tuning is the process of taking a pre-trained model and adapting it to a specific task by training it further on a smaller, domain-specific dataset. It is a form of transfer learning that refines the model’s capabilities and improves its accuracy on specialized tasks, typically with far less data and compute than training a model from scratch.

Fine-tuning allows us to:

  • Steer the model towards performing optimally on particular tasks.
  • Ensure model outputs align with expected results for real-world applications.
  • Help reduce model hallucinations and improve output relevance and honesty.

How is fine-tuning performed?

The general fine-tuning process can be broken down into the following steps:

  1. Select Base Model: Choose a pre-trained model based on your task and compute budget.
  2. Choose Fine-Tuning Method: Select the most appropriate method (e.g., supervised, instruction-based, PEFT) based on the task and dataset.
  3. Prepare Dataset: Structure your data for task-specific training, ensuring the format matches the model's requirements.
  4. Training: Use frameworks like TensorFlow, PyTorch, or high-level libraries like Transformers to fine-tune the model.
  5. Evaluate and Iterate: Test the model, refine it as necessary, and re-train to improve performance.

Implementation: Fine Tuning Large Language Model using DialogSum Dataset

Let us fine-tune a model using the PEFT LoRA method. We will use the flan-t5-base model and the DialogSum dataset.

  • Flan-T5 is the instruction fine-tuned version of T5 released by Google.
  • DialogSum is a large-scale dialogue summarization dataset, consisting of 13,460 dialogues (plus 100 holdout dialogues for topic generation) with corresponding manually labeled summaries and topics.

Step 1: Install Necessary Libraries

The following commands install the required libraries for the task, including Hugging Face Transformers, Datasets, and PEFT (Parameter-Efficient Fine-Tuning). These libraries enable model loading, training, and fine-tuning.

!pip install datasets
!pip install transformers
!pip install evaluate
!pip install accelerate -U
!pip install transformers[torch]
!pip install peft

Step 2: Set Up Environment

Configure the device for computation, using GPU if available. Import all necessary libraries for dataset handling, model loading, tokenization, and evaluation.

Python
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, GenerationConfig
import evaluate
import pandas as pd
import numpy as np

Step 3: Load Dataset

Load the Hugging Face dataset for dialogue summarization. In this example, we use the "knkarthick/dialogsum" dataset.

Python
huggingface_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(huggingface_dataset_name)


Step 4: Load Pre-trained Model and Tokenizer

Use a pre-trained Flan-T5 model (google/flan-t5-base) for sequence-to-sequence learning and initialize its tokenizer.

Python
model_name = "google/flan-t5-base"
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Step 5: Check Trainable Parameters

Define a function to calculate and print the percentage of trainable parameters in the model.

Python
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(base_model))

Output:

trainable model parameters: 247577856
all model parameters: 247577856
percentage of trainable model parameters: 100.00%

Step 6: Perform Baseline Inference

Test the pre-trained model on a sample from the test set to evaluate its performance before fine-tuning.

Python
i = 20
dialogue = dataset['test'][i]['dialogue']
summary = dataset['test'][i]['summary']

prompt = f"Summarize the following dialogue  {dialogue}  Summary:"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = tokenizer.decode(base_model.generate(input_ids, max_new_tokens=200)[0], skip_special_tokens=True)

print(f"Input Prompt : {prompt}")
print("--------------------------------------------------------------------")
print("Human evaluated summary ---->")
print(summary)
print("---------------------------------------------------------------------")
print("Baseline model generated summary : ---->")
print(output)

Output:

Input Prompt : Summarize the following dialogue  #Person1#: What's wrong with you? Why are you scratching so much?
#Person2#: I feel itchy! I can't stand it anymore! I think I may be coming down with something. I feel lightheaded and weak.
#Person1#: Let me have a look. Whoa! Get away from me!
#Person2#: What's wrong?
#Person1#: I think you have chicken pox! You are contagious! Get away! Don't breathe on me!
#Person2#: Maybe it's just a rash or an allergy! We can't be sure until I see a doctor.
#Person1#: Well in the meantime you are a biohazard! I didn't get it when I was a kid and I've heard that you can even die if you get it as an adult!
#Person2#: Are you serious? You always blow things out of proportion. In any case, I think I'll go take an oatmeal bath. Summary:
--------------------------------------------------------------------
Human evaluated summary ---->
#Person1# thinks #Person2# has chicken pox and warns #Person2# about the possible hazards but #Person2# thinks it will be fine.
---------------------------------------------------------------------
Baseline model generated summary : ---->
Person1 is scratching so much that he can't stand it anymore.

Step 7: Tokenize Dataset

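Tokenize the dataset to prepare it for training. The function wraps each dialogue in a summarization prompt and generates input and label IDs, truncating or padding them to a fixed length. To keep the demonstration fast, the final filter keeps only every 100th example of each split.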

Python
def tokenize_function(example):
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    example['labels'] = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors="pt").input_ids
    return example

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary'])
tokenized_datasets = tokenized_datasets.filter(lambda example, index: index % 100 == 0, with_indices=True)
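
One detail the tokenization above leaves open: the labels are padded with the tokenizer's pad token, and those padded positions count toward the loss unless they are masked. A common convention is to replace pad IDs in the labels with -100, the default `ignore_index` of PyTorch's cross-entropy loss. The sketch below shows the idea on plain lists; `mask_padding` is a hypothetical helper, and the pad ID of 0 is only illustrative (in practice it would be read from `tokenizer.pad_token_id`).

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def mask_padding(label_ids, pad_token_id, ignore_index=IGNORE_INDEX):
    """Replace padding positions in a label sequence so the loss skips them."""
    return [ignore_index if token == pad_token_id else token for token in label_ids]


# Illustrative label sequence: three real tokens followed by two pad tokens.
padded_labels = [5, 6, 1, 0, 0]
print(mask_padding(padded_labels, pad_token_id=0))
```

Applied inside `tokenize_function`, this would be run on every label sequence before returning the example. It changes the loss values reported during training, so results would not match the outputs shown in this article exactly.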

Step 8: Apply PEFT with LoRA Configuration

Use PEFT (Parameter-Efficient Fine-Tuning) to reduce training time and resource usage. LoRA freezes the base model's weights and trains only small low-rank adapter matrices added to selected layers.

Python
from peft import LoraConfig, get_peft_model, TaskType

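lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=32,  # rank of the low-rank update; 32 matches the parameter counts printed below (r=8 would give 884,736)
    lora_alpha=32,  # scaling factor applied to the low-rank update
    lora_dropout=0.1,  # dropout applied to the adapter inputs during training
)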

peft_model_train = get_peft_model(base_model, lora_config)
print(print_number_of_trainable_model_parameters(peft_model_train))

Output:

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.41%
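
The trainable count printed above can be reproduced by hand, which is a useful check on the LoRA rank. The sketch below assumes the flan-t5-base shape (hidden size 768, 12 encoder and 12 decoder layers) and assumes that PEFT adapts the query and value projections of every attention block by default for T5; neither assumption is stated in this article. Under those assumptions, the printed 3,538,944 corresponds to rank 32, while rank 8 would give 884,736.

```python
def lora_params(d_in, d_out, rank):
    """Parameters in one LoRA adapter: A is rank x d_in, B is d_out x rank."""
    return rank * d_in + d_out * rank


# Assumed flan-t5-base shape (not stated in the article).
D_MODEL = 768
ENCODER_LAYERS = 12
DECODER_LAYERS = 12

# Each encoder layer has self-attention; each decoder layer has
# self-attention and cross-attention.
attention_blocks = ENCODER_LAYERS + 2 * DECODER_LAYERS
# Assumed default targets: the q and v projections in every attention block.
adapted_modules = attention_blocks * 2

BASE_PARAMS = 247_577_856  # total reported for the base model in Step 5


def trainable_lora_params(rank):
    """Total LoRA parameters across all adapted q and v projections."""
    return adapted_modules * lora_params(D_MODEL, D_MODEL, rank)


for rank in (8, 32):
    trainable = trainable_lora_params(rank)
    total = BASE_PARAMS + trainable
    print(f"r={rank}: trainable={trainable}, all={total}, "
          f"share={100 * trainable / total:.2f}%")
```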

Step 9: Define Training Arguments

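Set up training configurations, including the learning rate and the number of epochs. Setting auto_find_batch_size=True lets the Trainer reduce the batch size automatically until a batch fits in memory.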

Python
output_dir = "./peft-dialogue-summary-training"

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3,
    num_train_epochs=5,
)

Step 10: Train the Model

Use Hugging Face Trainer API to train the PEFT-enabled model.

Python
peft_trainer = Trainer(
    model=peft_model_train,
    args=peft_training_args,
    train_dataset=tokenized_datasets["train"],
)

peft_trainer.train()

Output:

TrainOutput(global_step=160, training_loss=3.586883544921875, 
metrics={'train_runtime': 150.5997,
'train_samples_per_second': 4.15,
'train_steps_per_second': 1.062,
'total_flos': 434768117760000.0,
'train_loss': 3.586883544921875, 'epoch': 5.0})
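
The step count in this output can be sanity-checked with a little arithmetic. The sketch assumes the DialogSum training split has 12,460 rows (a figure not given in this article) and an effective batch size of 4 on a single device, which is inferred from the output rather than set anywhere in the code. With every 100th row kept, that gives 125 training examples, 32 steps per epoch, and 160 steps over 5 epochs.

```python
import math


def kept_examples(n_rows, keep_every):
    """Rows surviving a filter that keeps indices 0, keep_every, 2 * keep_every, ..."""
    return math.ceil(n_rows / keep_every)


def total_steps(n_examples, batch_size, epochs):
    """Optimizer steps when the last partial batch of each epoch is kept."""
    return math.ceil(n_examples / batch_size) * epochs


TRAIN_ROWS = 12_460  # assumed size of the DialogSum training split
BATCH_SIZE = 4       # inferred from the reported step count, not set in the article
EPOCHS = 5

n_train = kept_examples(TRAIN_ROWS, 100)
print("training examples:", n_train)
print("total steps:", total_steps(n_train, BATCH_SIZE, EPOCHS))
print("samples per second:", round(n_train * EPOCHS / 150.5997, 2))
```

The last line divides the samples seen (125 examples over 5 epochs) by the reported runtime, which agrees with the reported `train_samples_per_second`.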

Step 11: Save the Fine-Tuned Model

Save the trained PEFT model and tokenizer for future use.

Python
peft_model_path = "./peft-dialogue-summary-checkpoint-local"
peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)

Output:

('./peft-dialogue-summary-checkpoint-local/tokenizer_config.json',
'./peft-dialogue-summary-checkpoint-local/special_tokens_map.json',
'./peft-dialogue-summary-checkpoint-local/spiece.model',
'./peft-dialogue-summary-checkpoint-local/added_tokens.json',
'./peft-dialogue-summary-checkpoint-local/tokenizer.json')

Step 12: Load and Test Fine-Tuned Model

Load the fine-tuned model and test its performance on the same input prompt.

Python
from peft import PeftModel

peft_model_base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
peft_model = PeftModel.from_pretrained(peft_model_base, peft_model_path, is_trainable=False)

peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

print(f"Input Prompt : {prompt}")
print("--------------------------------------------------------------------")
print("Human evaluated summary ---->")
print(summary)
print("---------------------------------------------------------------------")
print("Baseline model generated summary : ---->")
print(output)
print("---------------------------------------------------------------------")
print("Peft model generated summary : ---->")
print(peft_model_text_output)

Output:

Input Prompt : Summarize the following dialogue  #Person1#: What's wrong with you? Why are you scratching so much?
#Person2#: I feel itchy! I can't stand it anymore! I think I may be coming down with something. I feel lightheaded and weak.
#Person1#: Let me have a look. Whoa! Get away from me!
#Person2#: What's wrong?
#Person1#: I think you have chicken pox! You are contagious! Get away! Don't breathe on me!
#Person2#: Maybe it's just a rash or an allergy! We can't be sure until I see a doctor.
#Person1#: Well in the meantime you are a biohazard! I didn't get it when I was a kid and I've heard that you can even die if you get it as an adult!
#Person2#: Are you serious? You always blow things out of proportion. In any case, I think I'll go take an oatmeal bath. Summary:
--------------------------------------------------------------------
Human evaluated summary ---->
#Person1# thinks #Person2# has chicken pox and warns #Person2# about the possible hazards but #Person2# thinks it will be fine.
---------------------------------------------------------------------
Baseline model generated summary : ---->
Person1 is scratching so much that he can't stand it anymore.
---------------------------------------------------------------------
Peft model generated summary : ---->
#Person2# is scratching so much. #Person2# thinks he has chicken pox.
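
Reading the two summaries side by side only hints at the difference, and the general process calls for evaluation. As a rough sketch, a simplified ROUGE-1 F1 (unigram overlap on lowercased, whitespace-split tokens) can put a number on each summary. This is not the `evaluate` library's ROUGE implementation, and `rouge1_f1` is a hypothetical helper; a score on one example says little about overall quality.

```python
from collections import Counter


def rouge1_f1(reference, candidate):
    """Simplified ROUGE-1 F1: unigram overlap between two whitespace-tokenized strings."""
    ref_counts = Counter(reference.lower().split())
    cand_counts = Counter(candidate.lower().split())
    # Clipped overlap: each word counts at most as often as it appears in both texts.
    overlap = sum((ref_counts & cand_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


reference = ("#Person1# thinks #Person2# has chicken pox and warns #Person2# about "
             "the possible hazards but #Person2# thinks it will be fine.")
baseline = "Person1 is scratching so much that he can't stand it anymore."
peft = "#Person2# is scratching so much. #Person2# thinks he has chicken pox."

print("baseline ROUGE-1 F1:", round(rouge1_f1(reference, baseline), 3))
print("PEFT ROUGE-1 F1:", round(rouge1_f1(reference, peft), 3))
```

For a real comparison, the score would be averaged over the whole test split rather than computed on a single dialogue.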

Complete Code


Python
# Step 1: Install Necessary Libraries
!pip install datasets
!pip install transformers
!pip install evaluate
!pip install accelerate -U
!pip install transformers[torch]
!pip install peft

# Step 2: Import Libraries
import torch
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, GenerationConfig
import pandas as pd
import numpy as np
from peft import LoraConfig, get_peft_model, TaskType, PeftModel

# Step 3: Configure Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Step 4: Load Dataset
huggingface_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(huggingface_dataset_name)

# Step 5: Load Pre-trained Model and Tokenizer
model_name = "google/flan-t5-base"
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Step 6: Define Function to Count Trainable Parameters
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(base_model))

# Step 7: Perform Baseline Inference
i = 20
dialogue = dataset['test'][i]['dialogue']
summary = dataset['test'][i]['summary']

prompt = f"Summarize the following dialogue  {dialogue}  Summary:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = tokenizer.decode(base_model.generate(input_ids, max_new_tokens=200)[0], skip_special_tokens=True)

print(f"Input Prompt : {prompt}")
print("--------------------------------------------------------------------")
print("Human evaluated summary ---->")
print(summary)
print("---------------------------------------------------------------------")
print("Baseline model generated summary : ---->")
print(output)

# Step 8: Tokenize Dataset
def tokenize_function(example):
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    example['labels'] = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors="pt").input_ids
    return example

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary'])
tokenized_datasets = tokenized_datasets.filter(lambda example, index: index % 100 == 0, with_indices=True)

# Step 9: Apply PEFT with LoRA Configuration
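lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=32,  # rank of the low-rank update; 32 matches the parameter counts reported for this run
    lora_alpha=32,  # scaling factor applied to the low-rank update
    lora_dropout=0.1,  # dropout applied to the adapter inputs during training
)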

peft_model_train = get_peft_model(base_model, lora_config)
print(print_number_of_trainable_model_parameters(peft_model_train))

# Step 10: Define Training Arguments
output_dir = "./peft-dialogue-summary-training"
peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3,
    num_train_epochs=5,
)

# Step 11: Train the Model
peft_trainer = Trainer(
    model=peft_model_train,
    args=peft_training_args,
    train_dataset=tokenized_datasets["train"],
)
peft_trainer.train()

# Step 12: Save the Fine-Tuned Model
peft_model_path = "./peft-dialogue-summary-checkpoint-local"
peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)

# Step 13: Load and Test Fine-Tuned Model
peft_model_base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
peft_model = PeftModel.from_pretrained(peft_model_base, peft_model_path, is_trainable=False)

peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

print(f"Input Prompt : {prompt}")
print("--------------------------------------------------------------------")
print("Human evaluated summary ---->")
print(summary)
print("---------------------------------------------------------------------")
print("Baseline model generated summary : ---->")
print(output)
print("---------------------------------------------------------------------")
print("Peft model generated summary : ---->")
print(peft_model_text_output)

Types of Fine Tuning Methods


1. Supervised Fine-Tuning

Supervised fine-tuning involves further training a pre-trained model using a task-specific dataset with labeled input-output pairs. This process allows the model to learn how to map inputs to outputs based on the given dataset.

Process:

  1. Use a pre-trained model.
  2. Prepare a dataset with input-output pairs as expected by the model.
  3. Adjust the pre-trained weights during fine-tuning to adapt the model to the new task.

Supervised fine-tuning is ideal for tasks such as sentiment analysis, text classification, and named entity recognition where labeled datasets are available.

2. Instruction Fine-Tuning

Instruction fine-tuning augments input-output examples with detailed instructions in the prompt template. This allows the model to generalize better to new tasks, especially those involving natural language instructions.

Process:

  • Use a pre-trained model.
  • Prepare a dataset in the form of instruction-response pairs.
  • Train the model on these pairs using the same procedure as any other supervised neural network training.

Instruction fine-tuning is commonly used in building chatbots, question answering systems, and other tasks that require natural language interaction.
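
To make the instruction-response format concrete, here is a minimal sketch of turning a raw example into a prompt and target. It reuses the summarization template from the implementation above; the helper name `format_instruction_example` and the `answer_prefix` parameter are hypothetical, not part of any library.

```python
def format_instruction_example(instruction, context, response, answer_prefix="Summary: "):
    """Wrap raw text in an instruction template and pair it with the expected response."""
    prompt = f"{instruction}\n\n{context}\n\n{answer_prefix}"
    return {"prompt": prompt, "response": response}


example = format_instruction_example(
    instruction="Summarize the following conversation.",
    context="#Person1#: Are you free tonight?\n#Person2#: Yes, let's get dinner.",
    response="#Person1# and #Person2# agree to have dinner tonight.",
)
print(example["prompt"])
print(example["response"])
```

The prompt is what the model sees as input; the response is what it is trained to generate. Varying the instruction across tasks (summarize, translate, classify) is what lets one dataset cover many tasks.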

3. Parameter-Efficient Fine-Tuning (PEFT)

Full fine-tuning updates every weight and is resource-intensive. PEFT methods make more efficient use of memory and computation by training only a small number of parameters, either a subset of the existing weights or a small set of newly added ones, which significantly reduces the memory required for training.

PEFT Methods:

  1. Selective Method: Freeze most layers of the model and only fine-tune specific layers.
  2. Reparameterization Method (LoRA): Use low-rank matrices to reparameterize model weights, freezing the original weights and adding small, trainable parameters.
    Example: A 512x64 weight matrix has 32,768 parameters, all of which are updated in full fine-tuning. With LoRA at rank 8, the two low-rank matrices (512x8 and 8x64) hold only 4,096 + 512 = 4,608 trainable parameters.
  3. Additive Method: Add new layers to the encoder or decoder side of the model and train these for the specific task.
  4. Soft Prompting: Train only the new tokens added to the model prompt, keeping other tokens and weights frozen.

PEFT is useful when working with large models that exceed memory limits, reducing both training costs and resource requirements.
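
The LoRA reparameterization can be sketched in a few lines of plain Python. The frozen weight W is left untouched, and the effective weight is W + (alpha / r) * B A, where B and A are the small trainable matrices. Following the usual LoRA initialization, B starts at zero so the adapted model initially behaves exactly like the base model. The function names and the tiny 4x3 shapes are illustrative assumptions; the final lines redo the 512x64 count from the example above.

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_effective_weight(W, B, A, alpha, rank):
    """Return W + (alpha / rank) * (B @ A) without modifying W."""
    scale = alpha / rank
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]


def lora_parameter_count(d_out, d_in, rank):
    """Trainable parameters: B is d_out x rank, A is rank x d_in."""
    return d_out * rank + rank * d_in


random.seed(0)
d_out, d_in, rank = 4, 3, 2  # tiny shapes for illustration only
W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]    # frozen
A = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(rank)]  # trainable
B = [[0.0] * rank for _ in range(d_out)]                                 # trainable, zero init

# With B at zero the update vanishes, so the effective weight equals W.
print(lora_effective_weight(W, B, A, alpha=32, rank=rank) == W)

print("full fine-tuning:", 512 * 64)
print("LoRA, rank 8:", lora_parameter_count(512, 64, 8))
```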

4. Reinforcement Learning from Human Feedback (RLHF)

RLHF aligns a fine-tuned model's output to human preferences using reinforcement learning. This method refines model behavior after the initial fine-tuning phase.

Process:

  1. Prepare Dataset: Generate prompt-completion pairs and rank them based on human evaluators' alignment criteria.
  2. Train Reward Model: Build a reward model that scores completions based on human feedback.
  3. Update Model: Use reinforcement learning, typically the PPO algorithm, to update the model weights based on the reward model.

RLHF is ideal for tasks where human-like outputs are necessary, such as generating text that aligns with user expectations or ethical guidelines.
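
The reward-model step can be illustrated with the pairwise ranking loss that is commonly used for it: given the scores a reward model assigns to a preferred and a rejected completion, the loss is -log(sigmoid(preferred - rejected)). This is one common formulation, shown as a sketch rather than the method of any specific system; `pairwise_reward_loss` is a hypothetical helper and the scores are made-up numbers.

```python
import math


def pairwise_reward_loss(score_preferred, score_rejected):
    """-log(sigmoid(margin)) for margin = preferred - rejected, computed stably."""
    margin = score_preferred - score_rejected
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    # For negative margins, rewrite to avoid overflow in exp().
    return -margin + math.log1p(math.exp(margin))


# The loss is small when the preferred completion scores higher,
# and large when the ranking is inverted.
print(pairwise_reward_loss(2.0, 0.0))
print(pairwise_reward_loss(0.0, 0.0))
print(pairwise_reward_loss(0.0, 2.0))
```

Minimizing this loss over many ranked pairs pushes the reward model to score human-preferred completions higher, and those scores then serve as the reward signal in the reinforcement learning step.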

Prompt Engineering vs RAG vs Fine-Tuning

Let us explore the difference between prompt engineering, RAG, and fine-tuning.

| Criteria | Prompt Engineering | RAG | Fine-Tuning |
| --- | --- | --- | --- |
| Purpose | Writing an effective prompt that elicits the best possible output for a given task. | Retrieving relevant information for a given prompt from an external database. | Training and adapting a model for a specific task. |
| Model | Model weights are not updated; the focus is on building an effective prompt. | Model weights are not updated; the focus is on building context for a given prompt. | Model weights are updated. |
| Complexity | No technical knowledge required. | Less complex than fine-tuning; requires only skills related to vector databases and retrieval mechanisms. | Technical knowledge required. |
| Compute Cost | Very low; only the cost of API calls. | Cost-effective compared to fine-tuning. | May need specialized hardware to train the model, depending on model size and dataset size. |
| Knowledge | The model does not learn new data. | The prompt is supplied with new data in the form of context. | The model learns new data. |
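
The difference between the first two columns can be seen in how the prompt is built. In the sketch below, a toy retriever picks the document sharing the most words with the question and places it in the prompt as context; the model weights are untouched in both cases. Real RAG systems typically use embeddings and a vector database instead of word overlap, and `retrieve` and `build_rag_prompt` are hypothetical helpers written for this illustration.

```python
def retrieve(query, documents, top_k=1):
    """Rank documents by the number of lowercased words they share with the query."""
    query_words = set(query.lower().split())
    ranked = sorted(
        documents,
        key=lambda doc: len(query_words & set(doc.lower().split())),
        reverse=True,
    )
    return ranked[:top_k]


def build_rag_prompt(query, documents, top_k=1):
    """Assemble a prompt that places the retrieved text ahead of the question."""
    context = "\n".join(retrieve(query, documents, top_k))
    return f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"


documents = [
    "LoRA freezes the base weights and trains low-rank adapters.",
    "DialogSum is a dialogue summarization dataset.",
]
print(build_rag_prompt("What does LoRA train?", documents))
```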

Benefits of Fine Tuning LLMs

Fine-tuning offers several advantages:

  • Increased Performance: Fine-tuned models adapt to new data, leading to more accurate and reliable outputs.
  • Efficiency: Fine-tuning saves computational costs by adapting pre-trained models rather than training a model from scratch.
  • Domain Adaptation: LLMs can be tailored to specific industries like medical, legal, or financial domains by focusing on relevant terminology and structures.
  • Better Generalization: Models fine-tuned on task-specific data generalize better to the unique patterns and structures of the task.

When to use fine-tuning?

When building an LLM application, the first step is to select a pre-trained (foundation) model suited to the use case. Once the base model is selected, try prompt engineering first: it is a quick way to evaluate how well the base model performs on the task.

If prompt engineering does not reach an acceptable level of performance, proceed with fine-tuning. Fine-tuning is appropriate when the model needs to specialize in a particular task or set of tasks and a labeled, unbiased, diverse dataset is available. It is also advisable for domain-specific adaptation, such as learning medical, legal, or financial language.

Conclusion

In this article, we reviewed the main fine-tuning methods, the benefits of fine-tuning, how it compares with prompt engineering and RAG, and how fine-tuning is generally performed. We also walked through a Python implementation of LoRA training.

