---
license: apache-2.0
base_model: meta-llama/Llama-3.2-3B
tags:
- reinforcement-learning
- rlhf
- ppo
- gsm8k
- math
- reasoning
- verl
- actor
datasets:
- gsm8k
library_name: transformers
---
# Llama 3.2 3B - GSM8K PPO Fine-tuned (Actor)
This model is an **actor** checkpoint of Llama 3.2 3B, fine-tuned on GSM8K with **PPO (Proximal Policy Optimization)** using the **veRL** framework.
## Model Description
- **Base Model**: meta-llama/Llama-3.2-3B
- **Training Framework**: veRL (Volcano Engine Reinforcement Learning)
- **Training Method**: PPO (Proximal Policy Optimization)
- **Dataset**: GSM8K (Grade School Math 8K)
- **Task**: Mathematical reasoning and problem-solving
- **Checkpoint Step**: 467
- **Evaluation Score**: 0.467
- **Model Type**: actor
## Training Details
This checkpoint was trained using PPO on the GSM8K dataset to improve mathematical reasoning capabilities. The model was optimized using reward-based learning to generate more accurate step-by-step solutions to math word problems.
The checkpoint was selected automatically via best-of-n evaluation across multiple training steps: the step with the highest evaluation score was kept rather than simply the last one.
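The selection logic amounts to keeping the checkpoint with the highest evaluation score. A minimal sketch of that idea (the scores for earlier steps below are hypothetical placeholders; only step 467 is reported on this card):

```python
# Best-of-n checkpoint selection: keep the step with the highest eval score.
eval_scores = {
    100: 0.35,   # hypothetical placeholder
    250: 0.41,   # hypothetical placeholder
    467: 0.467,  # score reported for this checkpoint
}

best_step = max(eval_scores, key=eval_scores.get)
print(f"Selected checkpoint: step {best_step} (score {eval_scores[best_step]:.3f})")
```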
## Usage
### Basic Inference
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the PPO-trained actor checkpoint
model = AutoModelForCausalLM.from_pretrained(
    "samhitha2601/llama3.2-3b-ppo",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("samhitha2601/llama3.2-3b-ppo")

# Example GSM8K problem
prompt = """Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Answer: Let's solve this step by step:"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.7,
    do_sample=True,
    top_p=0.9,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
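GSM8K-style completions normally end with the final numeric answer, so a common way to consume the generation is to pull out the last number. A minimal sketch, assuming the model follows that convention (the `extract_final_number` helper is illustrative, not part of this repository):

```python
import re

def extract_final_number(text: str):
    """Return the last number appearing in the text, or None if there is none."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return numbers[-1] if numbers else None

completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(extract_final_number(completion))  # the reference answer for the Janet problem is 18
```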
### Chat Format (if using an Instruct variant)
Note that the base `meta-llama/Llama-3.2-3B` tokenizer does not include a chat template, so the snippet below applies only if the checkpoint was derived from an Instruct variant.
```python
messages = [
    {"role": "user", "content": "Solve this math problem: If a train travels 60 miles per hour for 2.5 hours, how far does it travel?"}
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
## Performance
This model was trained with PPO to maximize reward on GSM8K problems, showing improvements in:
- Step-by-step reasoning
- Arithmetic accuracy
- Problem decomposition
- Solution clarity
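These claims can be spot-checked with a simple accuracy measurement on the GSM8K test split. A minimal sketch, assuming greedy decoding, the Hugging Face `gsm8k` dataset (whose gold answers end in `#### <number>`), and the `extract_final_number` helper defined above:

```python
from datasets import load_dataset

gsm8k_test = load_dataset("gsm8k", "main", split="test")

correct = 0
sample = gsm8k_test.select(range(100))  # small sample for a quick check
for example in sample:
    prompt = f"Question: {example['question']}\nAnswer: Let's solve this step by step:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    prediction = extract_final_number(tokenizer.decode(outputs[0], skip_special_tokens=True))
    gold = example["answer"].split("####")[-1].strip().replace(",", "")
    correct += int(prediction == gold)

print(f"Accuracy on sample: {correct / len(sample):.3f}")
```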
## Training Infrastructure
- **Framework**: veRL (Volcano Engine Reinforcement Learning)
- **Algorithm**: PPO
- **Distributed Training**: FSDP (Fully Sharded Data Parallel)
- **World Size**: 2 (2 GPUs/ranks)
## Limitations
- Primarily optimized for GSM8K-style math problems
- May not generalize well to other domains without fine-tuning
- Mathematical reasoning is limited to the complexity seen in GSM8K
- Still susceptible to arithmetic errors on complex calculations
## Citation
If you use this model, please cite:
```bibtex
@misc{llama32-gsm8k-ppo,
title={Llama 3.2 3B Fine-tuned on GSM8K with PPO},
author={Your Name},
year={2025},
howpublished={\url{https://huggingface.co/samhitha2601/llama3.2-3b-ppo}},
}
```
## Acknowledgments
- **Base Model**: Meta AI (Llama 3.2)
- **Dataset**: GSM8K by OpenAI
- **Training Framework**: veRL
- **Training Method**: PPO (Proximal Policy Optimization)