Open-Source AI Cookbook
Advanced GRPO Fine-tuning for Mathematical Reasoning with Multi-Reward Training
Authored by: Behrooz Azarkhalili
This notebook demonstrates advanced GRPO (Group Relative Policy Optimization) fine-tuning for mathematical reasoning using a multi-reward training system. We’ll fine-tune a model on the GSM8K dataset with four specialized reward functions.
Key Features:
- 4 Reward Functions: Format compliance, approximate matching, answer correctness, and number extraction
- Memory Efficient: 4-bit quantization + LoRA for consumer GPUs
- Interactive Monitoring: Real-time training metrics with trackio dashboard
- Structured Output: Enforces step-by-step reasoning format
The model learns to generate structured mathematical solutions with clear reasoning steps and accurate numerical answers.
Installation and Setup
Install the required packages for GRPO training with memory-efficient techniques.
# Install required packages for GRPO mathematical reasoning training
!pip install transformers datasets trl bitsandbytes peft accelerate trackio
GPU Environment Detection
Verify GPU availability and display hardware specifications for optimal training configuration.
import torch
# Verify CUDA availability and display GPU specifications
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    # Display current GPU details for training optimization
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU name: {torch.cuda.get_device_name()}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    # Provide guidance for enabling GPU in Colab
    print("⚠️ No GPU available. This notebook requires a GPU for efficient training.")
    print("In Colab: Runtime → Change runtime type → Hardware accelerator → GPU")
Core Library Imports
Import essential libraries for GRPO training, model configuration, and experiment tracking.
import trackio # Experiment tracking dashboard
import re # Regex patterns for reward functions
# GRPO training components
from trl import GRPOConfig, GRPOTrainer
# Model and tokenization
from transformers import (
    AutoModelForCausalLM,  # Causal language model loading
    AutoTokenizer,         # Text tokenization
    BitsAndBytesConfig,    # Quantization configuration
)
# Parameter-efficient fine-tuning
from peft import LoraConfig, get_peft_model, TaskType
# Dataset handling
from datasets import load_dataset
# Logging configuration
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Suppress httpx request logs that appear during trackio usage
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("gradio_client").setLevel(logging.WARNING)
Model Selection and Configuration
Choose a compact but capable model suitable for mathematical reasoning with memory constraints.
# Select model optimized for instruction-following and reasoning
model_name = "Qwen/Qwen2.5-3B-Instruct" # 3B parameter model balances capability and memory usage
max_seq_length = 2048 # Token limit for mathematical problems (reduce if OOM)
print(f"Loading model: {model_name}")
print(f"Max sequence length: {max_seq_length}")
# Configure 4-bit quantization for ~75% memory reduction
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # Enable 4-bit precision (vs 16-bit default)
    bnb_4bit_quant_type="nf4",             # NormalFloat4: optimal for neural network weights
    bnb_4bit_compute_dtype=torch.float16,  # Use FP16 for forward/backward passes
    bnb_4bit_use_double_quant=True,        # Further quantize quantization constants
)
print("✅ 4-bit quantization configured")
print(" Memory reduction: ~75% vs FP16")
# Load model with quantization and automatic device mapping
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,  # Apply 4-bit quantization
    device_map="auto",               # Auto-distribute across available GPUs/CPU
    trust_remote_code=True,          # Allow custom model code execution
    torch_dtype=torch.float16,       # Use FP16 for non-quantized operations
)
# Load corresponding tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) # Allow custom tokenizer code
# Ensure tokenizer has proper padding token for batch processing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(f"✅ Model loaded successfully!")
print(f"📊 Model parameters: ~{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
print(f"🧮 Quantized parameters: ~{sum(p.numel() for p in model.parameters() if hasattr(p, 'quant_type')) / 1e6:.1f}M")
LoRA Configuration
Apply Low-Rank Adaptation to train only ~0.1% of parameters while maintaining performance.
# Configure LoRA for mathematical reasoning adaptation
lora_config = LoraConfig(
    r=16,                                 # Rank: adaptation capacity (16 works well for reasoning tasks)
    lora_alpha=32,                        # Scaling factor (typically 2x rank)
    target_modules=["q_proj", "v_proj"],  # Focus on attention query/value for reasoning
    lora_dropout=0.1,                     # Regularization to prevent overfitting
    bias="none",                          # Skip bias adaptation for simplicity
    task_type=TaskType.CAUSAL_LM,         # Causal language modeling task
)
print("🔧 Applying LoRA adaptation to model...")
# Apply LoRA configuration to create trainable adapter
model = get_peft_model(model, lora_config)
# Display parameter efficiency
print("📊 LoRA Training Parameters Summary:")
model.print_trainable_parameters() # Shows trainable vs total parameters
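To see where the "~0.1% of parameters" figure comes from: each LoRA adapter adds roughly r × (in_features + out_features) parameters per targeted projection. The back-of-the-envelope cross-check below assumes the standard Qwen2-style grouped-query attention layout and is only an estimate; the numbers printed by print_trainable_parameters above are authoritative.
# Rough cross-check of the printed numbers (q_proj + v_proj adapters only)
cfg = model.config
r = lora_config.r
head_dim = cfg.hidden_size // cfg.num_attention_heads
kv_dim = cfg.num_key_value_heads * head_dim  # v_proj output dimension under grouped-query attention
per_layer = r * (cfg.hidden_size + cfg.hidden_size) + r * (cfg.hidden_size + kv_dim)
print(f"Estimated LoRA parameters: ~{per_layer * cfg.num_hidden_layers / 1e6:.2f}M")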
GSM8K Dataset Setup
Configure the GSM8K mathematical reasoning dataset with structured output format for step-by-step solutions.
# Define structured output format for mathematical reasoning
reasoning_start = "<start_working_out>" # Begin reasoning section
reasoning_end = "<end_working_out>" # End reasoning section
solution_start = "<SOLUTION>" # Begin final answer
solution_end = "</SOLUTION>" # End final answer
# System prompt that teaches the model our desired reasoning structure
system_prompt = f"""You are a mathematical reasoning assistant.
When given a math problem:
1. Show your step-by-step work between {reasoning_start} and {reasoning_end}
2. Provide your final numerical answer between {solution_start} and {solution_end}
3. Be precise and show all calculation steps clearly."""
print("✅ Format tokens and system prompt defined")
print(f" Reasoning format: {reasoning_start} ... {reasoning_end}")
print(f" Solution format: {solution_start} ... {solution_end}")
# Dataset processing utilities
def extract_hash_answer(text):
    """Extract numerical answer from GSM8K format (#### marker)"""
    if "####" not in text:
        return None
    # GSM8K uses format: "Explanation... #### 42"
    return text.split("####")[1].strip()

def process_dataset_example(example):
    """Convert GSM8K example to conversation format for GRPO training"""
    question = example["question"]
    answer = extract_hash_answer(example["answer"])
    # Create conversation with system prompt for structured reasoning
    prompt = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    return {
        "prompt": prompt,  # Input conversation
        "answer": answer,  # Ground truth for reward functions
    }
print("✅ Dataset processing functions defined")
# Load and preprocess GSM8K training dataset
print("🔄 Loading GSM8K mathematical reasoning dataset...")
dataset = load_dataset("openai/gsm8k", "main", split="train")
# Apply conversation formatting to all examples
dataset = dataset.map(process_dataset_example)
print(f"✅ Dataset loaded and processed!")
print(f"📊 Training examples: {len(dataset):,}")
print(f"🎯 Sample question: {dataset[0]['prompt'][1]['content']}...")
print(f"🎯 Sample answer: {dataset[0]['answer']}")
# Show structure of first example for verification
print(f"\n📋 Example structure:")
print(f" Prompt: {len(dataset[0]['prompt'])} messages (system + user)")
print(f" Answer: {dataset[0]['answer']} (ground truth for rewards)")
Multi-Reward System Design
Implement four complementary reward functions to evaluate different aspects of mathematical reasoning:
- Exact Format Matching: Perfect structure compliance
- Approximate Matching: Partial credit for format elements
- Answer Correctness: Mathematical accuracy with graduated scoring
- Number Extraction: Ability to parse and output numerical results
# Compiled regex patterns for efficient reward computation
match_format = re.compile(
    rf"^[\s]{{0,}}"                             # Optional whitespace at start
    rf"{reasoning_start}.+?{reasoning_end}.*?"  # Reasoning section (non-greedy)
    rf"{solution_start}(.+?){solution_end}"     # Solution section with capture group
    rf"[\s]{{0,}}$",                            # Optional whitespace at end
    flags=re.MULTILINE | re.DOTALL,             # Multi-line matching with . matching newlines
)
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",  # Extract numbers from solution section
    flags=re.MULTILINE | re.DOTALL,         # Flexible pattern matching
)
# Reward Function 1: Exact Format Compliance
def match_format_exactly(completions, **kwargs):
    """
    High reward (3.0) for perfect format adherence
    Ensures model learns the complete structured output pattern
    """
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        # Check if response matches complete format pattern
        score = 3.0 if match_format.search(response) is not None else 0.0
        scores.append(score)
    return scores
# Reward Function 2: Partial Format Credit
def match_format_approximately(completions, **kwargs):
    """
    Graduated scoring for format elements
    Encourages learning individual components even if not perfect
    """
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 0
        # Award +0.5 for correct token count, -0.5 for wrong count
        score += 0.5 if response.count(reasoning_start) == 1 else -0.5
        score += 0.5 if response.count(reasoning_end) == 1 else -0.5
        score += 0.5 if response.count(solution_start) == 1 else -0.5
        score += 0.5 if response.count(solution_end) == 1 else -0.5
        scores.append(score)
    return scores
# Reward Function 3: Mathematical Accuracy
def check_answer_correctness(prompts, completions, answer, **kwargs):
    """
    Graduated scoring for mathematical accuracy:
    - 3.0: Exact match
    - 1.5: Within 10% (close answer)
    - 0.5: Within 20% (reasonable attempt)
    - -0.5: Wrong answer (penalty for incorrect math)
    """
    responses = [completion[0]["content"] for completion in completions]
    # Extract answers using format pattern
    extracted_responses = [
        guess.group(1) if (guess := match_format.search(r)) is not None else None for r in responses
    ]
    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:  # No extractable answer
            scores.append(0)
            continue
        # Exact string match gets full points
        if guess.strip() == true_answer.strip():
            scores.append(3.0)
        else:
            # Try numerical comparison for partial credit
            try:
                ratio = float(guess) / float(true_answer)
                if 0.9 <= ratio <= 1.1:  # Within 10%
                    scores.append(1.5)
                elif 0.8 <= ratio <= 1.2:  # Within 20%
                    scores.append(0.5)
                else:  # Wrong answer
                    scores.append(-0.5)
            except (ValueError, ZeroDivisionError):
                scores.append(-0.5)  # Invalid numerical format
    return scores
# Reward Function 4: Number Extraction Ability
def check_numbers_extraction(prompts, completions, answer, **kwargs):
    """
    Tests the model's ability to extract numerical values from solution sections
    Complementary to exact format matching - focuses on parsing capability
    """
    responses = [completion[0]["content"] for completion in completions]
    # Extract numbers from solution sections using number pattern
    extracted_responses = [
        guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in responses
    ]
    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:  # No extractable number
            scores.append(0)
            continue
        try:
            # Simple numerical equality check
            true_val = float(true_answer.strip())
            guess_val = float(guess.strip())
            # Binary scoring: correct (1.5) or incorrect (0)
            scores.append(1.5 if guess_val == true_val else 0.0)
        except (ValueError, TypeError):
            scores.append(0)  # Invalid number format
    return scores
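Before training, it can be useful to sanity-check the reward functions on a hand-written completion. The snippet below is an optional check (not part of the training pipeline): it scores a single well-formatted answer against a ground truth of "72" using the four functions defined above.
# Optional sanity check of the four reward functions on one mock completion
sample_completion = (
    f"{reasoning_start}48 clips in April, 48 / 2 = 24 in May, 48 + 24 = 72.{reasoning_end}"
    f"{solution_start}72{solution_end}"
)
mock_completions = [[{"role": "assistant", "content": sample_completion}]]  # GRPO completion structure
mock_prompts = [None]  # Unused by these reward functions
mock_answers = ["72"]  # Ground-truth answer
print("Exact format      :", match_format_exactly(mock_completions))
print("Approx. format    :", match_format_approximately(mock_completions))
print("Answer correctness:", check_answer_correctness(mock_prompts, mock_completions, mock_answers))
print("Number extraction :", check_numbers_extraction(mock_prompts, mock_completions, mock_answers))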
GRPO Training Setup
Configure training parameters optimized for mathematical reasoning with memory constraints.
# Configure GRPO training parameters for mathematical reasoning
training_args = GRPOConfig(
    # Learning parameters optimized for reasoning tasks
    learning_rate=5e-6,  # Conservative LR to prevent destabilizing reasoning
    # Memory-efficient batch configuration
    per_device_train_batch_size=2,  # Small batch for GPU memory constraints
    gradient_accumulation_steps=8,  # Effective batch size = 2 * 8 = 16
    # Sequence length limits for mathematical problems
    max_prompt_length=1024,      # Sufficient for complex word problems
    max_completion_length=1024,  # Room for detailed step-by-step reasoning
    # Training duration and monitoring
    max_steps=10,     # Short demo run (increase to 500+ for production)
    logging_steps=1,  # Log metrics every step for close monitoring
    # Stability and output configuration
    output_dir="./trl_grpo_outputs",
    max_grad_norm=0.1,    # Aggressive gradient clipping for stable training
    report_to="trackio",  # Use trackio for experiment tracking (instead of wandb/tensorboard)
)
# Create unique run name with timestamp to ensure fresh tracking
import datetime
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_name = f"qwen2.5-3b-gsm8k-grpo-{timestamp}"
# Initialize trackio experiment tracking with unique run name
trackio.init(
    project="GRPO-Mathematical-Reasoning",  # Project name for organization
    name=run_name,                          # Unique run identifier with timestamp
    config={
        # Model and dataset configuration
        "model_name": "Qwen/Qwen2.5-3B-Instruct",
        "dataset": "GSM8K",
        "technique": "GRPO + LoRA + 4-bit",
        # Training hyperparameters
        "learning_rate": training_args.learning_rate,
        "batch_size": training_args.per_device_train_batch_size,
        "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
        "effective_batch_size": training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
        "max_steps": training_args.max_steps,
        # LoRA configuration
        "lora_r": 16,
        "lora_alpha": 32,
        # GRPO-specific settings
        "num_generations": training_args.num_generations,  # Default: 8 generations per prompt
        "max_prompt_length": training_args.max_prompt_length,
        "max_completion_length": training_args.max_completion_length,
        # Reward system
        "num_reward_functions": 4,
    },
)
print("🎯 GRPO Configuration Summary:")
print(f" Learning rate: {training_args.learning_rate}")
print(
    f" Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}"
)
print(f" Training steps: {training_args.max_steps}")
print(f" Generations per prompt: {training_args.num_generations}")
print(f"✅ Trackio experiment tracking initialized")
print(f"📊 Run name: {run_name}")
Trainer Initialization with Trackio Integration
Set up the GRPO trainer with our multi-reward system and experiment tracking.
# Initialize GRPO trainer with multi-reward system
trainer = GRPOTrainer(
    model=model,                     # LoRA-adapted quantized model
    reward_funcs=[                   # Four complementary reward functions
        match_format_exactly,        # Perfect structure compliance
        match_format_approximately,  # Partial format credit
        check_answer_correctness,    # Mathematical accuracy
        check_numbers_extraction,    # Number parsing ability
    ],
    args=training_args,              # Training configuration
    train_dataset=dataset,           # Processed GSM8K dataset
)
print("✅ GRPO Trainer initialized successfully!")
print(f"📊 Training dataset: {len(dataset):,} examples")
print(f"🎯 Reward functions: {len(trainer.reward_funcs)} active")
print(f"📈 Trackio integration: Enabled")
print(f"🔄 Ready for training with {training_args.num_generations} generations per step")
Begin GRPO Training
Start the training process with real-time reward monitoring. Watch for gradual improvement in both format compliance and mathematical accuracy.
# Execute GRPO training with multi-reward optimization
print("🚀 Starting GRPO training...")
print("📊 Monitor metrics: reward scores, KL divergence, policy gradients")
print("🔍 Trackio will log: losses, rewards, learning rate, gradients")
# Run the training process
trainer.train()
# Complete the trackio experiment
trackio.finish()
print("✅ Training completed successfully!")
print(f"💾 Model saved to: {training_args.output_dir}")
Experiment Dashboard
Launch the interactive trackio dashboard to analyze training progress, reward evolution, and model performance metrics.
# Launch interactive trackio dashboard for experiment analysis
# View training curves, reward progression, loss evolution, and hyperparameter effects
trackio.show(project="GRPO-Mathematical-Reasoning")
# Alternative: Launch from command line with: trackio show --project "GRPO-Mathematical-Reasoning"
Model Evaluation and Testing
Test the trained model’s mathematical reasoning capability with structured output validation.
# Define model testing function with optimized generation parameters
def test_model(question, max_length=512):
    """
    Test the trained model on mathematical questions

    Args:
        question (str): Mathematical problem to solve
        max_length (int): Maximum tokens to generate

    Returns:
        str: Model's structured response with reasoning and solution
    """
    # Format input using conversation template
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    # Apply chat template and tokenize
    text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # Add assistant prompt
        tokenize=False,              # Return string, not tokens
    )
    # Tokenize and move to the model's device
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    print(f"🤔 Processing: {question}")
    # Generate response with reasoning-oriented sampling parameters
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=0.7,         # Balance creativity and consistency
            do_sample=True,          # Enable sampling for varied reasoning paths
            top_p=0.9,               # Nucleus sampling for quality
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,  # Reduce repetitive reasoning steps
        )
    # Decode only the newly generated tokens (skip the prompt)
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    return generated_text
# Test model on GSM8K problem
gsm8k_question = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
expected_answer = "72"
# Generate response
gsm8k_response = test_model(gsm8k_question, max_length=768)
print(f"Question: {gsm8k_question}")
print(f"Model Response:\n{gsm8k_response}")
# Validate format compliance
has_reasoning = reasoning_start in gsm8k_response and reasoning_end in gsm8k_response
has_solution = solution_start in gsm8k_response and solution_end in gsm8k_response
print(f"\nFormat Check:")
print(f"Reasoning section: {has_reasoning}")
print(f"Solution section: {has_solution}")
# Check answer accuracy if solution section exists
if has_solution:
    try:
        solution_text = gsm8k_response.split(solution_start)[1].split(solution_end)[0].strip()
        extracted_number = "".join(filter(str.isdigit, solution_text))
        expected_number = "".join(filter(str.isdigit, expected_answer))
        is_correct = extracted_number == expected_number
        print(f"Extracted: {solution_text}")
        print(f"Expected: {expected_answer}")
        print(f"Correct: {is_correct}")
    except (IndexError, ValueError):
        print("Could not extract solution")
Clean Up Resources
Free GPU memory and clear cached tensors for optimal resource management.
from pathlib import Path
def remove_trackio_project(project_name):
    """Remove a trackio project by deleting its database file"""
    cache_dir = Path.home() / ".cache" / "huggingface" / "trackio"
    db_file = cache_dir / f"{project_name}.db"
    if db_file.exists():
        db_file.unlink()
        print(f"Removed trackio project: {project_name}")
    else:
        print(f"Project not found: {project_name}")
# Clean up trackio experiment database to free storage space
# WARNING: This permanently deletes all experiment logs and metrics
remove_trackio_project("GRPO-Mathematical-Reasoning")
# Free GPU memory and clear Python garbage collection
import gc
torch.cuda.empty_cache() # Clear PyTorch CUDA memory cache
gc.collect() # Run Python garbage collector
print("✅ GPU memory cache cleared")
print("✅ Python garbage collection completed")
print("🧹 Resources freed for other processes")
References
Papers and Research
- GRPO Algorithm: Group Relative Policy Optimization - The original GRPO paper introducing group-based relative policy optimization
- GSM8K Dataset: Training Verifiers to Solve Math Word Problems - Cobbe et al., OpenAI
- LoRA: Low-Rank Adaptation of Large Language Models - Hu et al., Microsoft
- QLoRA: Efficient Finetuning of Quantized LLMs - Dettmers et al., 4-bit quantization for efficient training
Libraries and Frameworks
- TRL (Transformers Reinforcement Learning): HuggingFace TRL - Official library for RLHF and advanced training techniques
- Transformers: HuggingFace Transformers - State-of-the-art NLP library
- PEFT: Parameter-Efficient Fine-Tuning - Efficient adaptation methods
- BitsAndBytes: 8-bit & 4-bit Quantization - Memory-efficient training
Models Used
- Qwen2.5-3B-Instruct: Qwen Model Series - Alibaba’s instruction-tuned language model
- Alternative Models: Gemma-2B, DialoGPT, GPT-2 (configurable in the notebook)
Datasets
- GSM8K: OpenAI GSM8K - Grade School Math 8K problems dataset
- Format: Mathematical word problems requiring multi-step reasoning and numerical answers
Key Concepts
- Reinforcement Learning from Human Feedback (RLHF): Training language models using reward signals
- Group Relative Policy Optimization: Advanced RL technique comparing responses in groups rather than absolute scoring
- Structured Generation: Teaching models to follow specific output formats with reasoning sections
- Multi-Reward Training: Using multiple reward functions for comprehensive evaluation