Day 22: SFT (Supervised Fine-Tuning) in Practice
Now we combine everything we have learned about LoRA, QLoRA, and dataset preparation to execute actual fine-tuning. Using Hugging Face’s trl library and its SFTTrainer, you do not need to write the complex training loop yourself.
Preparing the Training Environment
# pip install trl peft transformers datasets bitsandbytes wandb
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTConfig, SFTTrainer
# 4-bit quantization configuration: NF4 weights with double quantization,
# using bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load model and tokenizer
model_name = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # let the library decide device placement
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Fall back to EOS for padding only when the tokenizer ships without a pad
# token, so a checkpoint's own pad token is never overwritten.
# NOTE(review): when pad == EOS, a collator that masks labels by pad_token_id
# would also mask genuine EOS tokens -- confirm how the installed TRL version
# builds labels, and prefer a dedicated pad token if it does.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
Loading and Formatting the Dataset
# Load dataset: only the first 1000 training rows, for quick experimentation
dataset = load_dataset(
    "tatsu-lab/alpaca",
    split="train[:1000]",
)
# Prompt formatting function
def format_instruction(example):
    """Convert one Alpaca-format record into chat-template text.

    Relies on the module-level ``tokenizer`` to render the conversation.

    Args:
        example: Mapping with "instruction" and "output" keys and an
            optional "input" key, which may be missing, ``None`` or blank.

    Returns:
        A dict with a single "text" key holding the user/assistant exchange
        rendered by the tokenizer's chat template as a string (not token ids).
    """
    # ``.get("input", "")`` alone still yields None when the key is present
    # with a null value, so coalesce to "" before calling ``.strip()``.
    extra_input = example.get("input") or ""
    if extra_input.strip():
        user_message = f"{example['instruction']}\n\nInput: {extra_input}"
    else:
        user_message = example["instruction"]

    messages = [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": example["output"]},
    ]
    # Apply the model's chat_template; tokenize=False keeps the result as text.
    # NOTE(review): the rendered text may already start with the BOS token --
    # confirm the trainer does not prepend a second one when tokenizing.
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"text": text}
# Transform dataset: run the formatter over every record
formatted_dataset = dataset.map(format_instruction)

sample_count = len(formatted_dataset)
print(f"Transformation complete: {sample_count} samples")

# Preview only the first 300 characters of the first formatted sample
sample_preview = formatted_dataset[0]["text"][:300]
print(f"Sample:\n{sample_preview}")
Running Training with SFTTrainer
# LoRA configuration
# Names of the modules that LoRA adapters are attached to
lora_target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=lora_target_modules,
    task_type=TaskType.CAUSAL_LM,
)
# Training configuration
# SFT-specific options (text column, maximum sequence length) are set on
# SFTConfig rather than passed to SFTTrainer directly: TRL releases that
# accept ``processing_class`` no longer take ``max_seq_length`` as a trainer
# keyword argument.
# NOTE(review): TRL releases before 0.16 name the length option
# ``max_seq_length`` instead of ``max_length`` -- confirm against the
# installed version.
training_args = SFTConfig(
    output_dir="./sft_output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch size = 4 * 4 = 16
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_steps=100,
    save_total_limit=3,  # Keep only the 3 most recent checkpoints
    bf16=True,  # bfloat16 training
    report_to="wandb",  # wandb logging (optional)
    gradient_checkpointing=True,  # Memory savings
    dataset_text_field="text",  # column produced by format_instruction
    max_length=512,  # truncate tokenized samples to 512 tokens
)

# Create SFTTrainer; passing peft_config lets the trainer attach the LoRA
# adapters to the quantized base model.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_dataset,
    peft_config=lora_config,
    processing_class=tokenizer,
)
# Start training
train_result = trainer.train()

# Print results: one summary line per metric
summary_lines = [
    "Training complete!",
    f" Total steps: {train_result.global_step}",
    f" Training loss: {train_result.training_loss:.4f}",
]
print("\n".join(summary_lines))

# Save model
final_model_dir = "./sft_final"
trainer.save_model(final_model_dir)
print(f"Model saved: {final_model_dir}")
Training Curve Analysis and Checkpoint Management
# Analyze training logs without wandb
import json

# Extract training curve data from trainer logs: keep only the entries that
# carry a "loss" value, as (step, loss) pairs.
log_history = trainer.state.log_history
train_losses = []
for entry in log_history:
    if "loss" not in entry:
        continue
    train_losses.append((entry["step"], entry["loss"]))

# Check training curve
print("Step | Loss")
print("-" * 20)
recent_losses = train_losses[-10:]  # Last 10 entries
for step, loss in recent_losses:
    print(f"{step:5d} | {loss:.4f}")

# Resume training from checkpoint (if training was interrupted)
# trainer.train(resume_from_checkpoint="./sft_output/checkpoint-200")
# Inference test with the final model
model.eval()

# Build the prompt with the same chat template that format_instruction uses
# for the training data; add_generation_prompt=True ends the prompt where the
# assistant's reply is expected to begin.
test_messages = [{"role": "user", "content": "What is Python?"}]
test_prompt = tokenizer.apply_chat_template(
    test_messages,
    tokenize=False,
    add_generation_prompt=True,
)
# The rendered template is expected to carry its own special tokens, so the
# tokenizer must not add them a second time.
test_input = tokenizer(
    test_prompt,
    return_tensors="pt",
    add_special_tokens=False,
).to(model.device)

with torch.no_grad():  # no gradients needed for generation
    output = model.generate(**test_input, max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))
Today’s Exercises
- In the code above, change `learning_rate` to 1e-5, 1e-4, and 5e-4, compare the training loss curves, and analyze which learning rate is most stable.
- Combine `per_device_train_batch_size` and `gradient_accumulation_steps` to find the optimal settings that fit your GPU memory while maintaining the same effective batch size (16).
- After training is complete, load a checkpoint and generate answers to 5 test questions, then compare quality with the pre-fine-tuned model.