Spaces:
Paused
Paused
best-of-N exceed token issue
Browse files
app.py
CHANGED
|
@@ -161,7 +161,10 @@ def plan2align_translate_text(text, session_id, model, tokenizer, device, src_la
|
|
| 161 |
reward_model_type=reward_model_type,
|
| 162 |
session_id=session_id
|
| 163 |
)
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
| 165 |
return result, score
|
| 166 |
|
| 167 |
def evaluate_candidates(source, candidates, language, session_id):
|
|
@@ -178,21 +181,25 @@ def original_translation(text, src_language, target_language, session_id):
|
|
| 178 |
return "", 0
|
| 179 |
|
| 180 |
def best_of_n_translation(text, src_language, target_language, n, session_id):
|
| 181 |
-
if not check_token_length(text,
|
| 182 |
-
return "Warning: Input text
|
| 183 |
candidates = []
|
| 184 |
for i in range(n):
|
| 185 |
cand_list = basic_translate(text, src_language, target_language)
|
| 186 |
if cand_list:
|
| 187 |
candidates.append(cand_list[0])
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
return best, score
|
| 192 |
|
| 193 |
def mpc_translation(text, src_language, target_language, iterations, session_id):
|
| 194 |
-
if not check_token_length(text,
|
| 195 |
-
return "Warning: Input text
|
| 196 |
current_trans = ""
|
| 197 |
best_score = None
|
| 198 |
for i in range(iterations):
|
|
@@ -201,11 +208,17 @@ def mpc_translation(text, src_language, target_language, iterations, session_id)
|
|
| 201 |
else:
|
| 202 |
cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
|
| 203 |
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
return current_trans, best_score
|
| 210 |
|
| 211 |
# ---------- Gradio function ----------
|
|
@@ -240,8 +253,7 @@ def process_text(text, src_language, target_language, max_iterations_value, thre
|
|
| 240 |
)
|
| 241 |
plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
|
| 242 |
if "Best-of-N" in translation_methods:
|
| 243 |
-
best_candidate, best_score = best_of_n_translation(text, src_language, target_language,
|
| 244 |
-
max_iterations_value, session_id)
|
| 245 |
best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
|
| 246 |
if "MPC" in translation_methods:
|
| 247 |
mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
|
|
|
|
| 161 |
reward_model_type=reward_model_type,
|
| 162 |
session_id=session_id
|
| 163 |
)
|
| 164 |
+
try:
|
| 165 |
+
_, score = evaluate_candidates(text, [result], task_language, session_id)
|
| 166 |
+
except:
|
| 167 |
+
score = 0
|
| 168 |
return result, score
|
| 169 |
|
| 170 |
def evaluate_candidates(source, candidates, language, session_id):
|
|
|
|
| 181 |
return "", 0
|
| 182 |
|
| 183 |
def best_of_n_translation(text, src_language, target_language, n, session_id):
|
| 184 |
+
if not check_token_length(text, 4096):
|
| 185 |
+
return "Warning: Input text too long.", 0
|
| 186 |
candidates = []
|
| 187 |
for i in range(n):
|
| 188 |
cand_list = basic_translate(text, src_language, target_language)
|
| 189 |
if cand_list:
|
| 190 |
candidates.append(cand_list[0])
|
| 191 |
+
try:
|
| 192 |
+
best, score = evaluate_candidates(text, candidates, target_language, session_id)
|
| 193 |
+
print("best_of_n evaluate_candidates results:")
|
| 194 |
+
print(best, score)
|
| 195 |
+
except:
|
| 196 |
+
print("evaluate_candidates fail")
|
| 197 |
+
return "Warning: Input text too long.", 0
|
| 198 |
return best, score
|
| 199 |
|
| 200 |
def mpc_translation(text, src_language, target_language, iterations, session_id):
|
| 201 |
+
if not check_token_length(text, 4096):
|
| 202 |
+
return "Warning: Input text too long.", 0
|
| 203 |
current_trans = ""
|
| 204 |
best_score = None
|
| 205 |
for i in range(iterations):
|
|
|
|
| 208 |
else:
|
| 209 |
cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
|
| 210 |
|
| 211 |
+
try:
|
| 212 |
+
best, score = evaluate_candidates(text, cand_list, target_language, session_id)
|
| 213 |
+
print("mpc evaluate_candidates results:")
|
| 214 |
+
print(best, score)
|
| 215 |
+
current_trans = best
|
| 216 |
+
best_score = score
|
| 217 |
+
except:
|
| 218 |
+
print("evaluate_candidates fail")
|
| 219 |
+
current_trans = cand_list[0]
|
| 220 |
+
best_score = 0
|
| 221 |
+
|
| 222 |
return current_trans, best_score
|
| 223 |
|
| 224 |
# ---------- Gradio function ----------
|
|
|
|
| 253 |
)
|
| 254 |
plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
|
| 255 |
if "Best-of-N" in translation_methods:
|
| 256 |
+
best_candidate, best_score = best_of_n_translation(text, src_language, target_language, max_iterations_value, session_id)
|
|
|
|
| 257 |
best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
|
| 258 |
if "MPC" in translation_methods:
|
| 259 |
mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
|