---
license: apache-2.0
base_model: meta-llama/Llama-3.2-3B
tags:
- reinforcement-learning
- rlhf
- ppo
- gsm8k
- math
- reasoning
- verl
- actor
datasets:
- gsm8k
library_name: transformers
---

# Llama 3.2 3B - GSM8K PPO Fine-tuned (Actor)

This model is the **actor** (policy) checkpoint of Llama 3.2 3B fine-tuned on GSM8K using **PPO (Proximal Policy Optimization)** with the **veRL** framework.

## Model Description

- **Base Model**: meta-llama/Llama-3.2-3B
- **Training Framework**: veRL (Volcano Engine Reinforcement Learning)
- **Training Method**: PPO (Proximal Policy Optimization)
- **Dataset**: GSM8K (Grade School Math 8K)
- **Task**: Mathematical reasoning and problem-solving
- **Checkpoint Step**: 467
- **Evaluation Score**: 0.467
- **Model Type**: actor

## Training Details

This checkpoint was trained with PPO on the GSM8K dataset to improve mathematical reasoning. PPO updates the model to maximize a reward computed on its sampled solutions, with the aim of producing more accurate step-by-step answers to math word problems.
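
In PPO, the "actor" is the policy being trained, while a separate critic estimates values; this is why the checkpoint is labeled as the actor. As a textbook illustration of the actor update, the sketch below computes the clipped surrogate loss in plain Python. It is not veRL's implementation: `ppo_clipped_loss` is a hypothetical name, per-token log-probabilities and advantages are assumed to be given, and `clip_eps=0.2` is an assumed value.

```python
import math
from typing import Sequence


def ppo_clipped_loss(
    logp_new: Sequence[float],
    logp_old: Sequence[float],
    advantages: Sequence[float],
    clip_eps: float = 0.2,  # assumed clipping range
) -> float:
    """Negative PPO clipped surrogate objective, averaged over tokens (hypothetical helper).

    logp_new:   log-probabilities of the sampled tokens under the current policy
    logp_old:   log-probabilities of the same tokens under the policy that sampled them
    advantages: per-token advantage estimates (e.g. derived from a critic)
    """
    if not advantages or not (len(logp_new) == len(logp_old) == len(advantages)):
        raise ValueError("inputs must be non-empty and of equal length")
    total = 0.0
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(lp_new - lp_old)  # pi_new(token) / pi_old(token)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Take the pessimistic (smaller) of the unclipped and clipped terms
        total += min(ratio * adv, clipped * adv)
    # Negate so that minimizing this loss maximizes the surrogate objective
    return -total / len(advantages)
```

Clipping the probability ratio to [1 - ε, 1 + ε] and taking the minimum of the two terms removes the incentive to move the policy far from the one that generated the samples within a single update.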

The checkpoint was selected automatically as the best-scoring of several checkpoints evaluated at different training steps.
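
Checkpoint selection of this kind reduces to evaluating each saved checkpoint and keeping the highest-scoring one. A minimal sketch, assuming the evaluation scores are already collected in a dict keyed by training step; `select_best_checkpoint` is a hypothetical helper, not a veRL API, and breaking ties toward the earlier step is an arbitrary choice:

```python
from typing import Dict, Tuple


def select_best_checkpoint(scores_by_step: Dict[int, float]) -> Tuple[int, float]:
    """Return (step, score) of the highest-scoring checkpoint (hypothetical helper).

    Ties are broken in favor of the earlier training step.
    """
    if not scores_by_step:
        raise ValueError("no checkpoints to compare")
    # Sort key: highest score first, then lowest step number
    best_step = min(scores_by_step, key=lambda step: (-scores_by_step[step], step))
    return best_step, scores_by_step[best_step]
```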

## Usage

### Basic Inference

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "samhitha2601/llama3.2-3b-ppo",
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("samhitha2601/llama3.2-3b-ppo")

# Example GSM8K problem
prompt = """Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?

Answer: Let's solve this step by step:"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.7,
    do_sample=True,
    top_p=0.9
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
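
To check the sample output, compare its final number with the reference answer: Janet has 16 - 3 - 4 = 9 eggs left to sell, and 9 × $2 = $18. The helper below takes the last number in the text as the model's answer. This is a common but imperfect heuristic (it is fooled if the model keeps generating after answering), and `extract_final_number` is a hypothetical name, not part of `transformers`.

```python
import re
from typing import Optional

# Matches integers and decimals, allowing thousands separators such as "1,234"
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_final_number(text: str) -> Optional[float]:
    """Return the last number in `text` as a float, or None if there is none."""
    matches = _NUMBER.findall(text)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))


# Reference answer for the duck-egg prompt: (16 - 3 - 4) eggs * $2 = $18
expected = 18.0
completion = "She has 16 - 3 - 4 = 9 eggs left, so she makes 9 * 2 = $18."
print(extract_final_number(completion) == expected)  # True
```

Pass it only the generated continuation, for example `tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)`, because the full decoded string also contains the numbers from the prompt.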

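### Chat Format

`apply_chat_template` requires the tokenizer to define a chat template. This checkpoint was fine-tuned from the base Llama 3.2 3B rather than the Instruct variant, so confirm that `tokenizer.chat_template` is set before relying on this format.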

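```python
# Reuses `model` and `tokenizer` loaded in the Basic Inference example
messages = [
    {"role": "user", "content": "Solve this math problem: If a train travels 60 miles per hour for 2.5 hours, how far does it travel?"}
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_dict=True,  # also return the attention mask for generate()
    return_tensors="pt"
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=512)
# Decode only the newly generated tokens, skipping the echoed prompt
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```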

## Performance

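This model was trained with PPO to maximize reward on GSM8K problems, with the aim of improving the following (the only quantitative result reported for this checkpoint is the evaluation score of 0.467 listed under Model Description):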
- Step-by-step reasoning
- Arithmetic accuracy
- Problem decomposition
- Solution clarity
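
The card does not specify the reward function. Purely as a hypothetical illustration of a rule-based GSM8K reward: GSM8K reference solutions end with `#### <final answer>`, and the sketch below assumes the model is asked to end its solution the same way, scoring 1.0 when the normalized final answers match exactly and 0.0 otherwise. Averaging this reward over a dataset gives exact-match accuracy. All function names here are hypothetical.

```python
from typing import Optional, Sequence


def final_answer(text: str) -> Optional[str]:
    """Return the normalized answer after the last '####' marker, or None."""
    if "####" not in text:
        return None
    tail = text.rsplit("####", 1)[1].split()
    if not tail:
        return None
    # Drop thousands separators, dollar signs, and trailing periods
    answer = tail[0].replace(",", "").replace("$", "").rstrip(".")
    return answer or None


def exact_match_reward(completion: str, reference: str) -> float:
    """1.0 if the completion's final answer equals the reference's, else 0.0."""
    predicted = final_answer(completion)
    return 1.0 if predicted is not None and predicted == final_answer(reference) else 0.0


def accuracy(completions: Sequence[str], references: Sequence[str]) -> float:
    """Mean exact-match reward over paired completions and references."""
    if len(completions) != len(references) or not references:
        raise ValueError("need equally many, non-empty completions and references")
    rewards = [exact_match_reward(c, r) for c, r in zip(completions, references)]
    return sum(rewards) / len(rewards)
```

A binary exact-match reward is easy to verify but gives no partial credit, so a policy only receives a positive signal once its final answer is fully correct.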

## Training Infrastructure

- **Framework**: veRL (Volcano Engine Reinforcement Learning)
- **Algorithm**: PPO
- **Distributed Training**: FSDP (Fully Sharded Data Parallel)
- **World Size**: 2 (2 GPUs/ranks)

## Limitations

- Primarily optimized for GSM8K-style math problems
- May not generalize well to other domains without fine-tuning
- Trained only on grade-school-level word problems, so reasoning beyond GSM8K's complexity may be limited
- Still susceptible to arithmetic errors on complex calculations

## Citation

If you use this model, please cite:

```bibtex
@misc{llama32-gsm8k-ppo,
  title={Llama 3.2 3B Fine-tuned on GSM8K with PPO},
  author={Your Name},
  year={2025},
  howpublished={\url{https://huggingface.co/samhitha2601/llama3.2-3b-ppo}},
}
```

## Acknowledgments

- **Base Model**: Meta AI (Llama 3.2)
- **Dataset**: GSM8K by OpenAI
- **Training Framework**: veRL
- **Training Method**: PPO (Proximal Policy Optimization)