Small Language Models, Big Impact: Fine-Tuning DistilGPT-2 for Medical Queries

Nilesh Dwivedi, Last Updated: 23 Nov, 2024

Language models have transformed how we interact with data, enabling applications like chatbots, sentiment analysis, and even automated content generation. However, most discussions revolve around large-scale models like GPT-3 or GPT-4, which require significant computational resources and vast datasets. While these models are powerful, they are not always practical for domain-specific tasks or deployment in resource-constrained environments. This is where small language models come into play.

This blog walks you through training a small language model on the Symptoms and Disease Dataset from Hugging Face, focusing on creating a tailored model for predicting diseases based on symptoms.


Learning Objectives

  • Understand how small language models balance efficiency and performance.
  • Learn to fine-tune pre-trained models for domain-specific tasks.
  • Develop skills to preprocess and manage datasets effectively.
  • Master training loops and validation techniques for model evaluation.
  • Adapt and test small models for practical, real-world use cases.

What is a Small Language Model?

A small language model refers to a scaled-down version of large models, optimized to balance performance and efficiency. Examples include DistilGPT-2, ALBERT, and DistilBERT.

These models:

  • Require fewer computational resources.
  • Can be fine-tuned on smaller, domain-specific datasets.
  • Are ideal for applications that prioritize speed and efficiency over handling extensive general-purpose queries.

Why Use a Small Language Model?

  • Efficiency: They run faster and can be fine-tuned on a single GPU or, more slowly, even on a powerful CPU.
  • Domain-Specific Training: Easier to adapt for specialized tasks, such as medical diagnosis or customer service.
  • Cost-Effective Deployment: Require less memory and processing power for real-time applications.
  • Explainability: Smaller architectures are often easier to debug and interpret.
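As a rough illustration of the memory point, storing the weights alone takes about (number of parameters) × (bytes per parameter). The sketch below assumes the approximate parameter counts from the Hugging Face model cards (about 82M for DistilGPT-2 and 124M for the smallest GPT-2) and ignores activations, optimizer state, and framework overhead; `weights_memory_mb` is a hypothetical helper.

```python
# Rough weight-memory estimate: parameters x bytes per parameter.
# Parameter counts are approximate figures from the Hugging Face model cards (assumption).
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}


def weights_memory_mb(num_params: int, dtype: str = "fp32") -> float:
    """Return the memory needed to store the weights, in megabytes (10**6 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 10**6


if __name__ == "__main__":
    for name, params in [("distilgpt2", 82_000_000), ("gpt2", 124_000_000)]:
        for dtype in ("fp32", "fp16"):
            print(f"{name:<10} {dtype}: {weights_memory_mb(params, dtype):.0f} MB")
```

Training needs several times this amount, because the gradients and Adam's two running moment estimates are each the same size as the weights.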

In this tutorial, we will demonstrate how to fine-tune a small language model, specifically DistilGPT-2, to handle a medical task: predicting diseases based on symptoms using the Symptoms and Disease Dataset from Hugging Face. By the end, you’ll understand how small language models can be applied effectively to solve real-world problems in a focused manner.

Overview of the Dataset: Symptoms and Diseases

The Symptoms and Disease Dataset pairs medical instructions, such as questions about a disease's symptoms, with responses that link the disease to its symptoms. This makes it well-suited for training models to answer medical queries about diseases and their symptoms.

Dataset Highlights

  • Input: A symptom-related question or instruction.
  • Output: A response naming the disease and listing its associated symptoms.

Example Entries:

Instruction: What are the symptoms of hypertensive disease?
Output: The following are the symptoms of hypertensive disease: pain chest, shortness of breath, dizziness, asthenia, fall, syncope, vertigo, sweating increased, palpitation, nausea, angina pectoris, pressure chest

Instruction: What are the symptoms of diabetes?
Output: The following are the symptoms of diabetes: polyuria, polydypsia, shortness of breath, pain chest, asthenia, nausea, orthopnea, rale, sweating increased, unresponsiveness, mental status changes, vertigo, vomiting, labored breathing

This structured dataset enables a small language model to learn relationships between symptoms and diseases effectively.
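The output strings in the examples follow a regular pattern: "The following are the symptoms of <disease>: <symptom>, <symptom>, ...". Assuming every record follows that pattern (only the two rows above were inspected), a small hypothetical helper, `parse_answer`, can split an answer into the disease name and a list of symptoms, which is handy for checking the data or scoring generated answers later.

```python
import re

# Hypothetical helper. Assumes answers look like
# "The following are the symptoms of <disease>: <s1>, <s2>, ..."
ANSWER_PATTERN = re.compile(r"^The following are the symptoms of (.+?):\s*(.+)$")


def parse_answer(answer: str):
    """Split an answer into (disease, [symptoms]); return None if the pattern does not match."""
    match = ANSWER_PATTERN.match(answer.strip())
    if match is None:
        return None
    disease, symptom_text = match.groups()
    symptoms = [s.strip() for s in symptom_text.split(",") if s.strip()]
    return disease, symptoms


example = ("The following are the symptoms of diabetes: polyuria, polydypsia, "
           "shortness of breath, pain chest")
print(parse_answer(example))
# ('diabetes', ['polyuria', 'polydypsia', 'shortness of breath', 'pain chest'])
```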

Building a Small Language Model with DistilGPT-2

This guide provides a practical demonstration of training a small language model using DistilGPT-2 for predicting diseases based on symptoms. Below is the step-by-step explanation of the code with implementation details.

Let’s dive into the steps.

Step 1: Install the Required Libraries

Ensure you have the necessary libraries installed:

!pip install torch torchtext transformers sentencepiece pandas tqdm datasets
  • torch: Core library for deep learning in Python, used for model training.
  • torchtext: Provides data processing utilities for natural language processing (NLP); the code in this tutorial does not use it.
  • transformers: Hugging Face library for using pre-trained language models like GPT-2.
  • sentencepiece: A subword tokenization library required by some Hugging Face tokenizers; GPT-2's byte-level BPE tokenizer does not need it.
  • pandas: For handling tabular data.
  • tqdm: Adds progress bars to loops.
  • datasets: Hugging Face library for downloading and loading datasets from the Hub, such as the Symptoms and Disease Dataset.

Step 2: Import the Necessary Libraries

The following libraries are imported to set up the environment for training a small language model:

from datasets import load_dataset
import pandas as pd
from tqdm import tqdm
import time
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import torch.nn as nn
import torch.optim as optim
# This Dataset is PyTorch's base class, which LanguageDataset extends in Step 6
from torch.utils.data import Dataset, DataLoader, random_split

Step 3: Load and Explore the Dataset

We’ll use the Symptoms and Disease Dataset from Hugging Face and convert it into a format suitable for training.

# Load the dataset
dataset = load_dataset("prognosis/symptoms_disease_v1")

dataset

# Convert to a pandas dataframe
updated_data = [{'Input': item['instruction'], 'Disease': item['output']} for item in dataset['train']]
df = pd.DataFrame(updated_data)

df.head(5)
  • Input: Represents the symptom description or medical query.
  • Disease: The dataset's output text (in the examples above, a sentence naming the disease and listing its symptoms).

Step 4: Select the Device for Model Training

if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    # Apple Silicon GPU via the Metal Performance Shaders backend
    device = torch.device('mps')
else:
    # CPU fallback (works, but training is much slower)
    device = torch.device('cpu')

Device Selection:

  • Checks if an NVIDIA GPU is available via torch.cuda.is_available().
  • If a GPU is present, the device is set to cuda, enabling GPU acceleration.
  • If no CUDA GPU is available but the machine supports Apple's Metal Performance Shaders (MPS) backend (e.g., an M1/M2 chip), torch.backends.mps.is_available() returns True and the device is set to 'mps'. The explicit check matters because creating torch.device('mps') does not by itself verify that the backend is usable.
  • If neither CUDA nor MPS is available, it defaults to the CPU. Note: CPU is much slower for deep learning tasks.
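The fallback order can be isolated in a small pure function, which makes it testable without any particular hardware. `pick_device_name` is a hypothetical helper; in practice the two flags would come from torch.cuda.is_available() and torch.backends.mps.is_available(), and the returned string would be passed to torch.device.

```python
def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device string: CUDA first, then Apple MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# With PyTorch installed, the real flags would come from
# torch.cuda.is_available() and torch.backends.mps.is_available().
print(pick_device_name(False, True))  # mps
```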

Step 5: Load the Tokenizer and Pre-trained Model

# The tokenizer turns text into token IDs (and vice versa)
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')

# DistilGPT-2 with a language-modeling head, moved to the selected device
model = GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)

model

Tokenizer

The GPT2Tokenizer from Hugging Face is loaded using from_pretrained('distilgpt2'). This tokenizer:

  • Converts input text into numerical tokens for the model to process.
  • Converts model outputs back into human-readable text.
  • Ensures the tokenization logic matches the pre-trained DistilGPT-2 model.
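To make "converts input text into numerical tokens" concrete, here is a toy sketch of the byte-pair-encoding (BPE) idea that GPT-2's tokenizer is built on: a word starts as single characters, and learned merge rules are applied in priority order. The merge list and vocabulary below are invented for illustration; the real tokenizer operates on bytes and ships a learned vocabulary of 50,257 tokens.

```python
def bpe_encode(word, merges):
    """Apply BPE merges (highest priority first) to a word, starting from characters."""
    ranks = {pair: rank for rank, pair in enumerate(merges)}
    symbols = list(word)
    while len(symbols) > 1:
        pairs = [(symbols[i], symbols[i + 1]) for i in range(len(symbols) - 1)]
        # Pick the adjacent pair with the best (lowest) merge rank
        best = min(pairs, key=lambda pair: ranks.get(pair, float("inf")))
        if best not in ranks:
            break  # no applicable merges left
        merged, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


# Toy merge rules and vocabulary, invented for this example
MERGES = [("p", "o"), ("po", "x"), ("c", "h"), ("e", "n")]
VOCAB = {tok: idx for idx, tok in enumerate(["pox", "ch", "i", "c", "k", "en"])}
ID_TO_TOKEN = {idx: tok for tok, idx in VOCAB.items()}

tokens = bpe_encode("chicken", MERGES)
ids = [VOCAB[t] for t in tokens]          # text -> token IDs
print(tokens, ids)                        # ['ch', 'i', 'c', 'k', 'en'] [1, 2, 3, 4, 5]
print("".join(ID_TO_TOKEN[i] for i in ids))  # token IDs -> text: chicken
```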

Model

The DistilGPT-2 language model is loaded with GPT2LMHeadModel.from_pretrained('distilgpt2'). DistilGPT-2 is a distilled version of GPT-2 with 6 transformer layers instead of 12 (about 82M parameters versus 124M for the smallest GPT-2), designed for language tasks like text generation. The model is moved to the selected hardware device (GPU, MPS, or CPU) for efficient computation.


Step 6: Dataset Preparation and Custom Dataset Class Definition

The LanguageDataset class is designed to:

  • Simplify the ingestion of data from a pandas DataFrame.
  • Tokenize and encode the data in a format compatible with the model.
  • Ensure efficient data preparation for training loops.
# Dataset Prep
class LanguageDataset(Dataset):
    """
    An extension of PyTorch's Dataset to:
      - Make the training loop cleaner
      - Make ingestion from pandas DataFrames easier
    """
    def __init__(self, df, tokenizer):
        self.labels = df.columns  # ['Input', 'Disease']
        self.data = df.to_dict(orient='records')
        self.tokenizer = tokenizer
        # Character-based length estimate (note: __getitem__ pads/truncates to 128 tokens)
        self.max_length = self.fittest_max_length(df)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx][self.labels[0]]
        y = self.data[idx][self.labels[1]]
        # Join question and answer into a single training string
        text = f"{x} | {y}"
        # Pad/truncate every example to 128 tokens so batches have a uniform shape
        tokens = self.tokenizer.encode_plus(text, return_tensors='pt', max_length=128,
                                            padding='max_length', truncation=True)
        return tokens

    def fittest_max_length(self, df):
        """
        Smallest power of two that is at least the length (in characters)
        of the longest entry in either column. A tight max length speeds up training.
        """
        max_length = max(len(max(df[self.labels[0]], key=len)), len(max(df[self.labels[1]], key=len)))
        x = 2
        while x < max_length:
            x = x * 2
        return x

# GPT-2 has no padding token; reuse the end-of-sequence token so padding='max_length' works
tokenizer.pad_token = tokenizer.eos_token

# Cast the Hugging Face data set as the LanguageDataset defined above
data_sample = LanguageDataset(df, tokenizer)

Key Benefits

  • Modular Design: The custom dataset class makes the training loop cleaner and modular.
  • Tokenization Efficiency: Handles tokenization, padding, and truncation seamlessly.
  • Uniform Length: Pads or truncates every example to 128 tokens, so all sequences in a batch share one shape that fits within the model's 1,024-token context.

This step defines and initializes a custom PyTorch Dataset to handle the tokenization and formatting of a text-based dataset, preparing it for training with DistilGPT-2. It simplifies ingestion, ensures consistency in input size, and is tailored for efficient processing by the model.
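The three ideas in LanguageDataset (joining each pair into one training string, the power-of-two length heuristic, and fixed-length padding with truncation) can be sketched without a tokenizer. The function names below are hypothetical, the token IDs are made up, and the pad ID 50256 is assumed to be GPT-2's end-of-sequence token, which this tutorial reuses for padding.

```python
PAD_ID = 50256  # assumed: GPT-2's end-of-sequence token ID, reused as padding


def format_example(question: str, answer: str) -> str:
    """Join a question and its answer with " | ", as LanguageDataset.__getitem__ does."""
    return f"{question} | {answer}"


def smallest_power_of_two_at_least(n: int) -> int:
    """Same loop as fittest_max_length: start at 2 and double until x >= n."""
    x = 2
    while x < n:
        x *= 2
    return x


def pad_or_truncate(token_ids, max_length, pad_id=PAD_ID):
    """Mimic padding='max_length', truncation=True; returns (ids, attention_mask)."""
    ids = list(token_ids[:max_length])
    mask = [1] * len(ids)              # 1 = real token
    padding = max_length - len(ids)
    return ids + [pad_id] * padding, mask + [0] * padding  # 0 = padding


print(format_example("What are the symptoms of diabetes?", "polyuria, ..."))
print(smallest_power_of_two_at_least(100))   # 128
print(pad_or_truncate([10, 11, 12], 5))      # ([10, 11, 12, 50256, 50256], [1, 1, 1, 0, 0])
```

Note that the loop returns the power of two that is at least the longest length (so 128 stays 128), and that it measures characters rather than tokens.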


Step 7: Split the Dataset into Training and Validation Sets

train_size = int(0.8 * len(data_sample))
valid_size = len(data_sample) - train_size
train_data, valid_data = random_split(data_sample, [train_size, valid_size])

Divides the dataset into two subsets:

  • Training Set (80%): Used to train the model by optimizing its parameters.
  • Validation Set (20%): Used to evaluate the model’s performance after each epoch without updating parameters.
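Under the hood, a random split amounts to shuffling the example indices and cutting the list at train_size. This pure-Python sketch (the `split_indices` name is hypothetical) mirrors the tutorial's 80/20 sizing and adds a fixed seed so the split is reproducible; with PyTorch itself, the analogous option is passing generator=torch.Generator().manual_seed(42) to random_split.

```python
import random


def split_indices(n: int, train_fraction: float = 0.8, seed: int = 42):
    """Shuffle indices 0..n-1 and split them into train/validation lists."""
    train_size = int(train_fraction * n)   # same rounding as the tutorial
    indices = list(range(n))
    random.Random(seed).shuffle(indices)   # fixed seed -> reproducible split
    return indices[:train_size], indices[train_size:]


train_idx, valid_idx = split_indices(10)
print(len(train_idx), len(valid_idx))  # 8 2
```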

Step 8: Create Data Loaders

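# The batch size must be defined before the loaders are built
# (the training parameters below reuse the same value)
BATCH_SIZE = 8

# Make the iterators
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE)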

DataLoaders feed data in manageable batches during training and validation.

train_loader:

  • Feeds data from the training set in batches.
  • shuffle=True: Reshuffles the training data at the start of every epoch so batches are not seen in a fixed order, which reduces order-dependent bias in the updates and usually helps generalization.

valid_loader:

  • Feeds data from the validation set in batches.
  • No shuffling: Ensures consistent evaluation.
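Conceptually, a DataLoader with shuffle=True reorders the indices each epoch and slices them into batches; with the default drop_last=False, the final batch may be smaller. The generator below is a simplified, hypothetical stand-in (no collation, no worker processes) that shows the batch arithmetic: with a batch size of 8, an epoch has ceil(n / 8) batches.

```python
import math
import random


def iterate_batches(items, batch_size, shuffle=False, seed=None):
    """Yield lists of up to batch_size items, optionally in shuffled order."""
    order = list(range(len(items)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [items[i] for i in order[start:start + batch_size]]


examples = list(range(20))
batches = list(iterate_batches(examples, batch_size=8, shuffle=True, seed=0))
print([len(b) for b in batches])      # [8, 8, 4]
print(math.ceil(len(examples) / 8))   # 3 batches per epoch
```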
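Step 9: Define Training Parameters

# Set the number of epochs
num_epochs = 2
# Batch size (also used by the data loaders)
BATCH_SIZE = 8
# Values recorded in the results table
batch_size = BATCH_SIZE
model_name = 'distilgpt2'
gpu = 0

# GPT-2's tokenizer has no padding token by default, so reuse the end-of-sequence token.
# This must run before pad_token_id is read below (it would otherwise be None).
tokenizer.pad_token = tokenizer.eos_token

# Not called in the training loop: GPT2LMHeadModel computes its own loss when labels are passed
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# Init a results dataframe
results = pd.DataFrame(columns=['epoch', 'transformer', 'batch_size', 'gpu',
                                'training_loss', 'validation_loss', 'epoch_duration_sec'])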

Epochs and Batch Size:

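  • Sets the number of epochs (complete passes through the training data) to 2.
  • Defines the batch size (8), the number of examples processed per optimizer step.

Model and GPU Tracking:

  • Records the model name (distilgpt2) and a GPU index label (gpu = 0) so each row in the results table identifies the run configuration.

Padding Token:

  • GPT-2 ships without a padding token, so the end-of-sequence token is reused for padding.

Loss Function:

  • Defines a CrossEntropyLoss that ignores padding tokens. The training loop, however, uses the loss that GPT2LMHeadModel computes internally (outputs.loss) when labels are passed, so this criterion is not actually called.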

Optimizer:

  • Configures the Adam optimizer with a learning rate of 5e-4 for weight updates.

Results Logging:

  • Initializes a DataFrame to store metrics like epoch duration, training loss, and validation loss.

This step sets up the key parameters, components, and tracking mechanisms required for the training process. It ensures the training loop is configured with appropriate values and prepares a structure for logging the results.
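For intuition about what lr=5e-4 controls, here is the standard Adam update for a single scalar parameter, using PyTorch's default betas (0.9, 0.999) and eps=1e-8. It is a didactic sketch (`adam_step` is a hypothetical name); torch.optim.Adam applies the same rule element-wise to every tensor and supports extra options, such as weight decay and AMSGrad, that are not shown.

```python
import math


def adam_step(param, grad, m, v, t, lr=5e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter; returns (param, m, v)."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction (t starts at 1)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = adam_step(0.0, grad=1.0, m=0.0, v=0.0, t=1)
print(p)  # close to -5e-4: the first step moves the parameter by roughly lr
```

Because of the bias correction, the first step moves each parameter by about lr regardless of the gradient's scale, which is why the learning rate is usually chosen carefully when fine-tuning a pre-trained model.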

Step 10: Training and Validation Loop

# The training loop
for epoch in range(num_epochs):
    start_time = time.time()  # Start the timer for the epoch

    # Training
    ## model.train() enables training-time behavior such as dropout
    model.train()
    epoch_training_loss = 0
    train_iterator = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs} Batch Size: {batch_size}, Transformer: {model_name}")
    for batch in train_iterator:
        optimizer.zero_grad()
        # Each item was encoded with shape (1, 128); drop the extra dimension
        inputs = batch['input_ids'].squeeze(1).to(device)
        attention_mask = batch['attention_mask'].squeeze(1).to(device)
        # Use the inputs as labels, but mark padded positions with -100
        # so the model's internal cross-entropy loss ignores them
        targets = inputs.clone()
        targets[attention_mask == 0] = -100
        outputs = model(input_ids=inputs, attention_mask=attention_mask, labels=targets)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        train_iterator.set_postfix({'Training Loss': loss.item()})
        epoch_training_loss += loss.item()
    avg_epoch_training_loss = epoch_training_loss / len(train_loader)

    # Validation: model.eval() disables dropout; torch.no_grad() skips gradient tracking
    model.eval()
    epoch_validation_loss = 0
    valid_iterator = tqdm(valid_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}")
    with torch.no_grad():
        for batch in valid_iterator:
            inputs = batch['input_ids'].squeeze(1).to(device)
            attention_mask = batch['attention_mask'].squeeze(1).to(device)
            targets = inputs.clone()
            targets[attention_mask == 0] = -100
            outputs = model(input_ids=inputs, attention_mask=attention_mask, labels=targets)
            loss = outputs.loss
            valid_iterator.set_postfix({'Validation Loss': loss.item()})
            epoch_validation_loss += loss.item()  # Convert tensor to scalar

    avg_epoch_validation_loss = epoch_validation_loss / len(valid_loader)

    end_time = time.time()  # End the timer for the epoch
    epoch_duration_sec = end_time - start_time  # Calculate the duration in seconds

    new_row = {'transformer': model_name,
               'batch_size': batch_size,
               'gpu': gpu,
               'epoch': epoch+1,
               'training_loss': avg_epoch_training_loss,
               'validation_loss': avg_epoch_validation_loss,
               'epoch_duration_sec': epoch_duration_sec}  # Add epoch_duration to the dataframe

    results.loc[len(results)] = new_row
    print(f"Epoch: {epoch+1}, Validation Loss: {avg_epoch_validation_loss}")

Epoch Timer:

  • Starts a timer at the beginning of each epoch to calculate its duration.

Training Phase:

  • Sets the model to training mode using model.train(), which enables training-time behavior such as dropout.
  • Iterates over batches from the train_loader:
    • Zeroes out gradients: optimizer.zero_grad().
    • Performs forward pass: Computes outputs by feeding inputs to the model.
    • Calculates loss: Measures how far predictions are from the targets.
    • Backpropagation: Computes gradients using loss.backward().
    • Optimizer step: Adjusts model weights to minimize the loss.

Validation Phase:

  • Sets the model to evaluation mode using model.eval(), which disables dropout; no weights change because the optimizer is never stepped.
  • Iterates over batches from the valid_loader:
    • Computes validation loss inside torch.no_grad(), which skips gradient tracking to save memory and time.
    • Tracks total validation loss to compute the average for the epoch.

Performance Logging:

  • Average Losses:
    • Computes the average training and validation losses for the epoch.
  • Result Tracking:
    • Logs the epoch number, average losses, model name, batch size, GPU label, and epoch duration in the results DataFrame.

Progress Display:

  • Uses tqdm to show real-time progress for both training and validation with metrics like loss for easy monitoring.

This step defines the core training and validation loop for the model, handling the forward pass, backpropagation, weight updates, and validation to evaluate model performance.
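When labels are passed, GPT2LMHeadModel shifts them internally so the logits at position t are scored against the token at position t + 1, and any label equal to -100 is left out of the average. The pure-Python sketch below reproduces that computation on toy numbers; `mask_padding` and `causal_lm_loss` are hypothetical names, and in the tutorial the equivalent work happens inside outputs.loss.

```python
import math

IGNORE_INDEX = -100  # label value skipped by the loss (PyTorch's CrossEntropyLoss default)


def mask_padding(input_ids, attention_mask):
    """Copy input_ids as labels, replacing padded positions (mask == 0) with IGNORE_INDEX."""
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, attention_mask)]


def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy: logits[t] is scored against labels[t + 1]."""
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        if target == IGNORE_INDEX:
            continue
        row = logits[t]
        top = max(row)
        log_sum_exp = top + math.log(sum(math.exp(x - top) for x in row))
        total += log_sum_exp - row[target]  # -log softmax(row)[target]
        count += 1
    return total / count if count else 0.0  # this sketch returns 0.0 if every label is ignored


labels = mask_padding([2, 1, 0, 0], [1, 1, 0, 0])
print(labels)  # [2, 1, -100, -100]
# Uniform logits over a 3-token vocabulary: only one position is scored, loss = ln(3)
print(causal_lm_loss([[0.0, 0.0, 0.0]] * 4, labels))
```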


Step 11: Model Testing and Response Validation

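# Make sure dropout is disabled for generation
model.eval()

# Define the input string.
# Training examples were formatted as "question | answer", so ending the prompt
# with " |" may help the model continue with an answer in that format.
input_str = "What are the symptoms of Chicken pox?"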

# Encode the input string with padding and attention mask
encoded_input = tokenizer.encode_plus(
    input_str,
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=50  # Adjust max_length as needed
)

# Move tensors to the appropriate device
input_ids = encoded_input['input_ids'].to(device)
attention_mask = encoded_input['attention_mask'].to(device)

# Set the pad_token_id to the tokenizer's eos_token_id
pad_token_id = tokenizer.eos_token_id

# Generate the output
output = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_length=50,  # Total length in tokens, prompt included
    num_return_sequences=1,
    do_sample=True,
    top_k=8,
    top_p=0.95,
    temperature=0.5,
    repetition_penalty=1.2,
    pad_token_id=pad_token_id
)

# Decode and print the output
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded_output)
  • Input Query: A specific question is defined, e.g., “What are the symptoms of Chicken pox?”.
  • Tokenization: Converts the query into numerical tokens with appropriate padding and truncation.
  • Generate Response: The fine-tuned model processes the tokens to produce a response using controlled sampling parameters like top_k, temperature, and max_length.
  • Decode Output: Converts the model’s tokenized output back into human-readable text.
  • Validate Output: Tests if the model generates a coherent and relevant response to the input query, assessing its qualitative performance.
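The sampling parameters in the generate() call can be illustrated on a toy list of logits: temperature rescales the logits before the softmax, top_k keeps the k most likely tokens, and top_p then keeps the smallest set whose cumulative probability reaches p. This is a simplified sketch of that order of operations (Hugging Face applies these as separate logits processors, and the repetition_penalty step is omitted here); the function names are hypothetical.

```python
import math
import random


def filtered_distribution(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {token_index: probability} after temperature, top-k and top-p filtering."""
    scaled = [x / temperature for x in logits]

    def softmax(indices):
        top = max(scaled[i] for i in indices)
        exps = {i: math.exp(scaled[i] - top) for i in indices}
        total = sum(exps.values())
        return {i: e / total for i, e in exps.items()}

    # Sort token indices from most to least likely
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    probs = softmax(order)       # renormalize over the top-k survivors
    kept, cumulative = [], 0.0
    for i in order:              # top-p: stop once the cumulative mass reaches top_p
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return softmax(kept)


def sample_token(dist, rng):
    """Draw one token index from a {index: probability} distribution."""
    indices = list(dist)
    return rng.choices(indices, weights=[dist[i] for i in indices], k=1)[0]


toy_logits = [2.0, 1.0, 0.0, -1.0]
# Same settings as the tutorial's generate() call
dist = filtered_distribution(toy_logits, temperature=0.5, top_k=8, top_p=0.95)
print(dist)
print(sample_token(dist, random.Random(0)))
```

With these settings only the two most likely toy tokens survive, which shows how a low temperature combined with top_p narrows the choices.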

This step qualitatively tests the model’s performance by providing a sample query and evaluating its generated response. It helps validate the model’s ability to produce relevant and meaningful outputs.
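The check above is qualitative. One way to add a rough number, assuming generated answers keep the dataset's "...: symptom, symptom, ..." format, is to compare the generated symptom list with the reference list from the dataset. `symptom_overlap` is a hypothetical helper and the generated string below is made up; exact string matching is crude (it misses synonyms and reworded symptoms), so treat the score as a sanity check, not a measure of clinical accuracy.

```python
def symptom_overlap(generated: str, reference_symptoms):
    """Precision/recall of comma-separated symptoms in `generated` against a reference list."""
    # Keep only the text after the first colon, if there is one
    _, _, tail = generated.partition(":")
    text = tail if tail else generated
    predicted = {s.strip().lower() for s in text.split(",") if s.strip()}
    reference = {s.strip().lower() for s in reference_symptoms}
    if not predicted or not reference:
        return 0.0, 0.0
    hits = len(predicted & reference)
    return hits / len(predicted), hits / len(reference)


generated = "The following are the symptoms of hypertensive disease: pain chest, dizziness, headache"
reference = ["pain chest", "shortness of breath", "dizziness", "asthenia"]
precision, recall = symptom_overlap(generated, reference)
print(f"precision={precision:.2f} recall={recall:.2f}")  # precision=0.67 recall=0.50
```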


Comparing DistilGPT-2 Before and After Fine-Tuning

Fine-tuning DistilGPT-2, a compact version of GPT-2, tailors the model to specific tasks, enhancing its performance in targeted applications. Here’s a comparison of DistilGPT-2’s capabilities before and after fine-tuning:

Task Performance

  • Pre-Fine-Tuning: DistilGPT-2, pre-trained on general text data, generates fluent text across a broad range of topics. However, it lacks depth in specialized domains such as medical diagnostics.
  • Post-Fine-Tuning: After fine-tuning on a domain-specific dataset such as the Symptoms and Disease Dataset, the model learns that dataset's vocabulary and answer format. For instance, it can respond to symptom-related queries with answers that list symptoms associated with a named disease.

Response Accuracy

  • Pre-Fine-Tuning: The model’s responses are general and may not align precisely with specialized queries, leading to less accurate or relevant outputs in niche areas.
  • Post-Fine-Tuning: Fine-tuning enhances the model’s understanding of domain-specific terminology and relationships, resulting in more precise and contextually appropriate responses.

Adaptability

  • Pre-Fine-Tuning: While versatile, the model’s general training limits its effectiveness in specialized tasks without additional adaptation.
  • Post-Fine-Tuning: The model becomes specialized, performing noticeably better in the fine-tuned domain, but it may lose some general-purpose ability outside that area.

Efficiency

  • Pre-Fine-Tuning: DistilGPT-2 is already optimized for efficiency, offering faster inference times and lower computational requirements compared to larger models like GPT-3.
  • Post-Fine-Tuning: Fine-tuning maintains this efficiency while enhancing performance in the targeted domain, making it suitable for deployment in resource-constrained environments.

Practical Application

  • Pre-Fine-Tuning: The model serves well for general-purpose text generation but may not meet the accuracy demands of specialized applications.
  • Post-Fine-Tuning: It becomes a useful tool for specific tasks, such as medical query answering, producing answers that reflect the content of the fine-tuned dataset.

Pre-Fine-Tuning Output of the Query

from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load a fresh copy of the pre-trained DistilGPT-2 tokenizer and model.
# Separate names keep the fine-tuned `model` and `tokenizer` from being overwritten.
base_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
base_model = GPT2LMHeadModel.from_pretrained("distilgpt2")

# Set the padding token to the end-of-sequence token (common practice for GPT-2-based models)
base_tokenizer.pad_token = base_tokenizer.eos_token

# Define the input query
input_query = "What are the symptoms of Chicken pox?"

# Tokenize the input query
input_tokens = base_tokenizer.encode_plus(
    input_query,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=50  # Adjust max_length if needed
)

# Generate response using the pre-trained model
output_tokens = base_model.generate(
    input_ids=input_tokens["input_ids"],
    attention_mask=input_tokens["attention_mask"],
    max_length=50,  # Total length in tokens, prompt included
    num_return_sequences=1,
    do_sample=True,  # Sampling adds randomness for diverse outputs
    top_k=8,  # Keep top 8 most probable tokens at each step
    top_p=0.95,  # Consider tokens with a cumulative probability of 0.95
    temperature=0.7,  # Adjust temperature for response diversity
    repetition_penalty=1.2,  # Penalize repetitive token generations
    pad_token_id=base_tokenizer.pad_token_id  # Handle padding gracefully
)

# Decode the generated output to human-readable text
decoded_output = base_tokenizer.decode(output_tokens[0], skip_special_tokens=True)

# Print the results
print("Pre-Fine-Tuning Response:")
print(decoded_output)

The response from the base DistilGPT-2 model, before fine-tuning, highlights its general-purpose nature. While it’s coherent and grammatically correct, it lacks specific, accurate information about the symptoms of chickenpox. This behavior is expected because the pre-trained model hasn’t been exposed to domain-specific knowledge about diseases or symptoms.

Post-Fine-Tuning Output of the Query


How Post-Fine-Tuning Responses Have Improved

Once fine-tuned on the Symptoms and Disease Dataset, the model can:

  • Learn Specific Relationships: Capture the mapping between diseases and their listed symptoms.
  • Generate Targeted Responses: Provide relevant, domain-specific details when queried, in the format of the training data.

In summary, fine-tuning DistilGPT-2 transforms it from a general-purpose language model into a specialized tool, enhancing its performance and accuracy in specific domains while retaining its inherent efficiency.

Conclusion

Small language models, such as DistilGPT-2, are a powerful and efficient alternative to large-scale models for domain-specific tasks. Through this tutorial, we demonstrated how to fine-tune DistilGPT-2 using the Symptoms and Disease Dataset, focusing on building a lightweight yet effective model for medical query answering. The process involved data preparation, training, validation, and response generation, showcasing the practical applications of small models in real-world scenarios.

The success of this approach lies in its balance between computational efficiency and performance, making small language models an excellent choice for resource-constrained environments or specialized use cases.

Key Takeaways

  • Small models like DistilGPT-2 are efficient, resource-friendly, and practical for domain-specific tasks.
  • Fine-tuning allows small models to specialize in focused applications like medical query answering.
  • A structured workflow ensures smooth implementation, from dataset preparation to response validation.
  • Small models are cost-effective and scalable for various real-world applications.
  • Inference testing helps check that the model generates relevant, coherent outputs before deployment.

Frequently Asked Questions

Q1. What is a small language model?

A. A small language model, like DistilGPT-2, is a compact version of large models designed to balance performance and efficiency. It requires fewer computational resources, making it ideal for resource-constrained environments and domain-specific tasks.

Q2. Why use a small language model instead of a large one like GPT-3?

A. Small models are faster, cost-effective, and easier to fine-tune on specific datasets. They are ideal when large-scale general-purpose capabilities are unnecessary, such as in applications requiring domain-specific expertise.

Q3. What is fine-tuning, and why is it important?

A. Fine-tuning is the process of adapting a pre-trained model to a specific task or domain by training it on a curated dataset. It improves the model’s performance for specialized tasks, such as predicting diseases from symptoms.

The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.

My name is Nilesh Dwivedi, and I'm excited to join this vibrant community of bloggers and readers. I'm currently in my first year of BTech, specializing in Data Science and Artificial Intelligence at IIIT Dharwad. I'm passionate about technology and data science and looking forward to writing more blogs.
