LLM 25-Day Course - Day 22: SFT (Supervised Fine-Tuning) in Practice

Day 22: SFT (Supervised Fine-Tuning) in Practice

Now we combine everything we have learned about LoRA, QLoRA, and dataset preparation to execute actual fine-tuning. Using Hugging Face’s trl library and its SFTTrainer, you do not need to write the complex training loop yourself.

Preparing the Training Environment

# pip install trl peft transformers datasets bitsandbytes wandb

import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTConfig, SFTTrainer

# 4-bit quantization configuration
# QLoRA-style setup: weights stored as 4-bit NF4, double quantization to
# compress the quantization constants, and bfloat16 as the compute dtype
# for the de-quantized matmuls.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load model and tokenizer
# device_map="auto" lets accelerate spread layers over available devices.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): Llama tokenizers typically ship without a dedicated pad token,
# so EOS is reused for padding; right-side padding is the usual choice for
# causal-LM fine-tuning so the loss-bearing tokens stay left-aligned.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Loading and Formatting the Dataset

# Load dataset
# Alpaca instruction-tuning data; the split slice keeps only the first
# 1000 rows so a full run finishes quickly.
dataset = load_dataset("tatsu-lab/alpaca", split="train[:1000]")  # 1000 samples for quick experimentation

# Prompt formatting function
# Prompt formatting function
def format_instruction(example):
    """Convert one Alpaca-format record into a chat-formatted text sample.

    Args:
        example: dict with "instruction", "output", and an optional "input"
            field (may be absent, None, or an empty string).

    Returns:
        dict with a single "text" key holding the chat-templated string,
        so `dataset.map` adds a "text" column.
    """
    # Robustness: `example.get("input", "")` still returns None when the key
    # exists with a None value, which would crash on .strip(). Coalesce first.
    if (example.get("input") or "").strip():
        user_message = f"{example['instruction']}\n\nInput: {example['input']}"
    else:
        user_message = example["instruction"]

    messages = [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": example["output"]},
    ]
    # Apply the model's chat_template so training text matches the format
    # the model sees at inference time.
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"text": text}

# Transform dataset
# map() applies format_instruction row-by-row, adding the "text" column
# that the trainer consumes; original columns are kept alongside it.
formatted_dataset = dataset.map(format_instruction)
print(f"Transformation complete: {len(formatted_dataset)} samples")
print(f"Sample:\n{formatted_dataset[0]['text'][:300]}")

Running Training with SFTTrainer

# LoRA configuration
# Rank-8 adapters on all four attention projections; effective update
# scale is lora_alpha / r = 2.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# Training configuration
training_args = TrainingArguments(
    output_dir="./sft_output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,     # Effective batch size = 4 * 4 = 16
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_steps=100,
    save_total_limit=3,                # Keep only the 3 most recent checkpoints
    bf16=True,                         # bfloat16 training
    report_to="wandb",                 # wandb logging (optional)
    gradient_checkpointing=True,       # Memory savings
)

# Create SFTTrainer and run training.
# NOTE: `processing_class` exists only in trl >= 0.12, and in those versions
# `max_seq_length` is no longer an SFTTrainer keyword (it moved to SFTConfig)
# -- passing both, as the original did, raises a TypeError on every trl
# release. Configure sequence length on the args object instead.
trainer = SFTTrainer(
    model=model,
    args=training_args,                 # set max_seq_length on SFTConfig, not here
    train_dataset=formatted_dataset,    # must contain the "text" column built above
    peft_config=lora_config,            # LoRA adapters are injected automatically
    processing_class=tokenizer,
)

# Start training; returns a TrainOutput with global_step / training_loss.
train_result = trainer.train()

# Print results
print("Training complete!")
print(f"  Total steps: {train_result.global_step}")
print(f"  Training loss: {train_result.training_loss:.4f}")

# Save model: writes the LoRA adapter weights + config, not the full base model.
trainer.save_model("./sft_final")
print("Model saved: ./sft_final")

Training Curve Analysis and Checkpoint Management

# Analyze training logs without wandb.
# (The original imported `json` here but never used it; removed.)
# Trainer keeps an in-memory log history: one dict per logging event, where
# training-loss entries (emitted every `logging_steps` steps) carry both
# "step" and "loss" keys.
log_history = trainer.state.log_history

# Collect (step, loss) pairs; non-training entries without a "loss" key
# (e.g. the final summary record) are skipped.
train_losses = [
    (entry["step"], entry["loss"])
    for entry in log_history
    if "loss" in entry
]

# Check training curve: show the last 10 logged points.
print("Step | Loss")
print("-" * 20)
for step, loss in train_losses[-10:]:  # Last 10 entries
    print(f"{step:5d} | {loss:.4f}")

# Resume training from checkpoint (if training was interrupted)
# trainer.train(resume_from_checkpoint="./sft_output/checkpoint-200")

# Inference test with the final model.
# Defect fix: the original tokenized a raw string, but an Instruct model
# should receive its chat template -- the same format used for the training
# samples above. add_generation_prompt=True appends the assistant header so
# generation starts at the assistant turn.
model.eval()
messages = [{"role": "user", "content": "What is Python?"}]
prompt_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)
with torch.no_grad():
    output = model.generate(prompt_ids, max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Today’s Exercises

  1. In the code above, change learning_rate to 1e-5, 1e-4, and 5e-4, compare the training loss curves, and analyze which learning rate is most stable.
  2. Combine per_device_train_batch_size and gradient_accumulation_steps to find the optimal settings that fit your GPU memory while maintaining the same effective batch size (16).
  3. After training is complete, load a checkpoint and generate answers to 5 test questions, then compare quality with the pre-fine-tuned model.

Was this article helpful?