sarosavo committed (verified)
Commit b6e8e5f · 1 Parent(s): ef5a054

upload training script and reward server script

reward_server/launch_reward.sh ADDED
@@ -0,0 +1,38 @@
+ set -x
+
+ MODEL_PATH=$1
+ ANSWER_PATH=$2
+ METRIC=$3
+ PORT=8800
+
+ export VLLM_ENGINE_ITERATION_TIMEOUT_S=60
+ nohup vllm serve ${MODEL_PATH} \
+     --trust-remote-code \
+     --served-model-name server_model \
+     --max-num-seqs 256 \
+     --max-model-len 4096 \
+     --port 8000 \
+     > vllm_server.log &
+
+ # sleep 60
+
+ if [[ "${METRIC}" == "prob" ]]; then
+     nohup python model_server.py \
+         --tokenizer_path ${MODEL_PATH} \
+         --answer_path ${ANSWER_PATH} \
+         --normalize_reward \
+         --port ${PORT} \
+         --prob_reward \
+         --vllm_url "http://localhost:8000/v1" \
+         --vllm_model server_model \
+         > reward_server.log &
+ else
+     nohup python model_server.py \
+         --tokenizer_path ${MODEL_PATH} \
+         --answer_path ${ANSWER_PATH} \
+         --normalize_reward \
+         --port ${PORT} \
+         --vllm_url "http://localhost:8000/v1" \
+         --vllm_model server_model \
+         > reward_server.log &
+ fi
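
Once launch_reward.sh has brought up both servers, the /get_reward endpoint can be exercised directly. The sketch below is a minimal, hypothetical client (not part of this commit): it assumes the reward server is reachable on localhost at the script's default PORT=8800, and that each query is a ChatML-formatted rollout whose user turn matches a "query" entry in the --answer_path file, which is what model_server.py expects when it parses the <|im_start|> sections. The rollout text itself is illustrative only.

import requests  # hypothetical client, not part of this repo

# Illustrative ChatML rollout; the user turn should match a question in the
# answer file so the server can look up its reference label.
rollout = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\nWhat is 2 + 2?<|im_end|>\n"
    "<|im_start|>assistant\nThe answer is 4.<|im_end|>"
)

resp = requests.post(
    "http://localhost:8800/get_reward",  # PORT=8800 from launch_reward.sh
    json={"query": [rollout]},           # the server reads data.get("query")
    timeout=600,
)
print(resp.json())  # {"rewards": [...]}; with --normalize_reward, a single query yields [0.0]
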
reward_server/model_server.py ADDED
@@ -0,0 +1,252 @@
+ import argparse
+ import re
+ import torch
+ import uvicorn
+ from fastapi import FastAPI, Request
+ from fastapi.responses import JSONResponse
+ from transformers import AutoTokenizer
+ import asyncio
+ from collections import defaultdict
+ import json
+ from openai import AsyncOpenAI
+ import time
+ import math
+ # Set OpenAI's API key and API base to use vLLM's API server.
+
+ # for free-form answers, including multiple-choice
+ PROMPT_critic_updated = '''
+ Given a problem, determine whether the final answer in the provided (incomplete) solution process matches the reference answer.
+ The reference answer may be one single option character (e.g., A, B, C, D), a numerical value, an expression, or a list of answers if multiple questions are involved.
+ **The reference answer may be in Chinese or another language, but your evaluation should be language-agnostic.**
+
+ Your task:
+ - Compare the final output of the solution process with the reference answer.
+ - If they **match exactly**, output **YES**.
+ - If they **do not match**, output **NO**.
+ - If the solution process is unclear, incomplete, or ambiguous, assume it is incorrect and output **NO**.
+
+ Your output must be strictly **'YES'** or **'NO'**, with no additional words, punctuation, or explanation.
+
+ ---
+
+ **Question:**
+ {question}
+
+ **Solution Process (Final Step Only):**
+ {response}
+
+ **Reference Answer:**
+ {reference}
+
+ **Output:**
+ '''
+
+
+
+ def parse_im_sections(text):
+     # Match all sections between <|im_start|> and <|im_end|>
+     sections = re.findall(r"<\|im_start\|>(.*?)<\|im_end\|>", text, re.DOTALL)
+     parsed = {}
+     for section in sections:
+         try:
+             # Split the role from the content
+             role, content = section.split("\n", 1)
+             parsed[role.strip()] = content.strip()
+         except ValueError:
+             print(f"Skipping malformed section: {section}")
+     return parsed
+
+ def extract_last_non_empty_line(text, role="assistant"):
+     # Extract the last non-empty line of the given role's content
+     pattern = fr"<\|im_start\|>{role}(.*?)(?:<\|im_start\|>|<\|endoftext\|>|<\|eot_id\|>|$)"
+     match = re.search(pattern, text, re.DOTALL)
+     if match:
+         content = match.group(1).strip()
+         # Get the last non-empty line
+         lines = [line for line in content.splitlines() if line.strip()]
+         if lines:
+             last_non_empty_line = lines[-1]
+         else:
+             return ""
+         return last_non_empty_line
+     return ""
+
+
+ def reward_normalization(rewards):
+     if len(rewards) == 1:
+         return [0.0]
+     rewards = torch.tensor(rewards, dtype=torch.float64)
+     if rewards.std() == 0:
+         normalized_rewards = torch.zeros_like(rewards)
+     else:
+         normalized_rewards = (rewards - rewards.mean()) / rewards.std()
+
+     return normalized_rewards.tolist()
+
+
+ def strip_sequence(text, pad_token, eos_token):
+     pad_token_escaped = re.escape(pad_token)
+     eos_token_escaped = re.escape(eos_token)
+
+     pattern = f"^({eos_token_escaped}|{pad_token_escaped})+"
+     text = re.sub(pattern, "", text)
+
+     pattern = f"({eos_token_escaped}|{pad_token_escaped})+$"
+     text = re.sub(pattern, "", text)
+     return text
+
+
+ def group_reward_normalization(rewards, n_samples_per_prompt=4):
+     rewards = torch.tensor(rewards, dtype=torch.float64)
+     rewards = rewards.reshape(-1, n_samples_per_prompt)
+
+     mean = rewards.mean(dim=-1, keepdim=True)
+     std = rewards.std(dim=-1, keepdim=True)
+
+     normalized_rewards = torch.where(std == 0, torch.zeros_like(rewards), (rewards - mean) / std)
+
+     return normalized_rewards.flatten().tolist()
+
+
+ class RewardModelProxy:
+     def __init__(self, args):
+         self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)
+         self.normalize_reward = args.normalize_reward
+         self.group_normalize_reward = args.group_normalize_reward
+         self.qa_dict = defaultdict(str)
+         self.load_dict(args.answer_path)
+         self.temperature = 0
+         self.stop = [self.tokenizer.eos_token, "<|im_end|>"]
+         self.max_tokens = 1
+         self.prob_reward = args.prob_reward
+         self.log_path = args.log_path
+         self.vllm_model = args.vllm_model
+
+     def load_dict(self, path):
+         # Initialize self.qa_dict from the answer file
+         with open(path, "r", encoding="utf-8") as file:
+             data = json.load(file)
+             for unit in data:
+                 question = unit["query"][1]["content"]
+                 label = unit["label"]
+                 self.qa_dict[question] = label
+
+         if self.qa_dict:
+             sample_question, sample_label = next(iter(self.qa_dict.items()))
+             print("Sample Question:", sample_question)
+             print("Sample Label:", sample_label)
+         else:
+             print("qa_dict is empty.")
+
+
+     async def process_sample(self, query):
+         query = strip_sequence(query, self.tokenizer.pad_token, self.tokenizer.eos_token) + self.tokenizer.eos_token
+         question = parse_im_sections(query)["user"]
+         answer = extract_last_non_empty_line(query, role="assistant")
+         if not answer.strip():
+             return 0.0
+         else:
+             prompt_question = PROMPT_critic_updated.format(question=question, reference=self.qa_dict[question], response=answer)
+             return await self.get_reward_from_vllm(prompt_question)
+
+     async def get_reward_from_vllm(self, query):
+         """Retrieve the model judgment reward (with probability analysis)"""
+         max_retries = 10
+         delay = 10
+         for attempt in range(max_retries):
+             try:
+                 response = await client.chat.completions.create(
+                     model=self.vllm_model,
+                     messages=[
+                         {"role": "system", "content": "You are a helpful assistant."},
+                         {"role": "user", "content": query},
+                     ],
+                     temperature=self.temperature,
+                     max_tokens=self.max_tokens,
+                     stop=self.stop,
+                     logprobs=True,
+                     top_logprobs=10,  # Get top-10 token probabilities
+                 )
+                 return self.calculate_reward_from_logprobs(response)
+
+             except Exception as e:
+                 print(f"Attempt {attempt+1} failed: {str(e)}, retrying in {delay} seconds...")
+                 await asyncio.sleep(delay)
+         print(f"Failed after {max_retries} retries, query content: {query[:200]}...")
+         return 0.0  # Return baseline value on failure
+
+     def calculate_reward_from_logprobs(self, response):
+         """Calculate the reward from the judge's YES/NO log probabilities"""
+         # Extract probabilities of all candidate tokens
+         logprobs = response.choices[0].logprobs.content[0].top_logprobs
+         token_probs = {token.token: math.exp(token.logprob) for token in logprobs}
+
+         # Combine probabilities of YES/NO variants (case-insensitive)
+         yes_prob = sum(prob for token, prob in token_probs.items() if token.lower().strip() == "yes")
+         no_prob = sum(prob for token, prob in token_probs.items() if token.lower().strip() == "no")
+         total = yes_prob + no_prob
+         if total == 0:
+             return 0.0  # Return baseline value when there is no valid judgment
+         if self.prob_reward:
+             print(yes_prob / total)
+             return yes_prob / total  # Normalized probability
+         return 1.0 if yes_prob > no_prob else 0.0  # Hard judgment mode
+
+     async def get_reward(self, queries):
+         print("Processing queries[0]: {}".format(queries[0]))
+         tasks = [self.process_sample(query) for query in queries]
+         scores = await asyncio.gather(*tasks)
+         print("Generated scores: {}".format(scores))
+         if self.log_path:
+             with open(self.log_path, 'a', encoding='utf-8') as f:
+                 unit = {
+                     "query_list": queries if isinstance(queries, list) else [],
+                     "hard_score_list": scores if isinstance(scores, list) else []
+                 }
+                 json.dump(unit, f, ensure_ascii=False)
+                 f.write('\n')
+         if self.normalize_reward:
+             return reward_normalization(scores)
+         elif self.group_normalize_reward:
+             return group_reward_normalization(scores)
+         else:
+             return scores
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     # Reward Model
+     parser.add_argument("--tokenizer_path", type=str, default=None)
+     parser.add_argument("--answer_path", type=str, default=None)
+     parser.add_argument("--prob_reward", action="store_true", default=False)
+     parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable reward normalization")
+     parser.add_argument("--group_normalize_reward", action="store_true", default=False, help="Enable group reward normalization")
+     parser.add_argument("--port", type=int, default=5000, help="Port number for the server")
+     parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server")
+     parser.add_argument("--log_path", type=str, default=None)
+     parser.add_argument("--vllm_url", type=str, default=None)
+     parser.add_argument("--vllm_model", type=str, default=None)
+     args = parser.parse_args()
+     openai_api_key = "EMPTY"
+     openai_api_base = args.vllm_url
+
+     client = AsyncOpenAI(
+         api_key=openai_api_key,
+         base_url=openai_api_base,
+     )
+
+     # Server setup
+     reward_model = RewardModelProxy(args)
+     app = FastAPI()
+
+
+     @app.post("/get_reward")
+     async def get_reward(request: Request):
+         data = await request.json()
+         queries = data.get("query")
+         rewards = await reward_model.get_reward(queries)
+         result = {"rewards": rewards}
+         print(f"Sent JSON response: {result}")
+         return JSONResponse(result)
+
+     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
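
For readers skimming the scoring logic, the sketch below reproduces the YES/NO log-probability aggregation that calculate_reward_from_logprobs performs, using made-up top-logprob values rather than a live vLLM response; it only illustrates how the soft (--prob_reward) and hard reward modes differ.

import math

# Made-up top-logprobs for the judge's first generated token (illustrative only;
# in the server these come from the OpenAI-compatible vLLM response).
top_logprobs = {"YES": -0.11, "yes": -3.2, "NO": -2.6, "The": -6.0}
token_probs = {tok: math.exp(lp) for tok, lp in top_logprobs.items()}

# Combine YES/NO variants case-insensitively, as the server does.
yes_prob = sum(p for tok, p in token_probs.items() if tok.lower().strip() == "yes")
no_prob = sum(p for tok, p in token_probs.items() if tok.lower().strip() == "no")
total = yes_prob + no_prob

soft_reward = yes_prob / total if total > 0 else 0.0  # --prob_reward mode: value in [0, 1]
hard_reward = 1.0 if yes_prob > no_prob else 0.0      # default mode: binary judgment
print(f"soft={soft_reward:.3f} hard={hard_reward}")
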
train.sh ADDED
@@ -0,0 +1,36 @@
+ set -x
+
+ EXPERIMENT_NAME=$1   # for example, "sft_reward_training"
+ PRETRAIN_PATH=$2     # path_to_Qwen2.5-7B-Instruct
+ TRAIN_DATA_PATH=$3   # path_to_training_data (https://huggingface.co/datasets/sarosavo/Master-RM)
+
+ working_dir=$(pwd)
+ LOG_PATH=${working_dir}/${EXPERIMENT_NAME}/train.log
+ SAVE_PATH=${working_dir}/${EXPERIMENT_NAME}/checkpoint
+ mkdir -p ${SAVE_PATH}
+
+
+
+ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+
+ deepspeed --module openrlhf.cli.train_sft \
+     --max_len 4096 \
+     --dataset $TRAIN_DATA_PATH \
+     --input_key query \
+     --output_key output \
+     --apply_chat_template \
+     --train_batch_size 128 \
+     --micro_train_batch_size 4 \
+     --pretrain $PRETRAIN_PATH \
+     --save_path $SAVE_PATH \
+     --save_steps -1 \
+     --logging_steps 1 \
+     --eval_steps -1 \
+     --zero_stage 3 \
+     --max_epochs 1 \
+     --bf16 \
+     --flash_attn \
+     --learning_rate 5e-6 \
+     --packing_samples \
+     2>&1 | tee ${LOG_PATH}
+
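
As a rough guide to the data these scripts expect: train.sh reads the SFT dataset through --input_key query and --output_key output with --apply_chat_template, while model_server.py's load_dict() reads unit["query"][1]["content"] (the user question) and unit["label"] (the reference answer) from the --answer_path file. The records below are hypothetical sketches of those assumed shapes only; consult the Master-RM dataset itself for the authoritative schema.

# Hypothetical record shapes inferred from train.sh flags and model_server.py;
# field contents are illustrative only.

sft_record = {            # consumed by openrlhf.cli.train_sft
    "query": [            # --input_key query, rendered via --apply_chat_template
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
    ],
    "output": "YES",      # --output_key output (assumed judgment-style target)
}

answer_record = {         # one entry in the --answer_path JSON list
    "query": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
    ],
    "label": "4",         # reference answer used by the judging prompt
}

# load_dict() keys its lookup table on the user turn:
assert answer_record["query"][1]["content"] == "What is 2 + 2?"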