|
import os |
|
import time |
|
import json |
|
import shutil |
|
import zipfile |
|
import gradio as gr |
|
from eval_exp import evaluate |
|
from datetime import datetime |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
|
|
|
|
def load_splits(splits_dir="chinatravel/evaluation/default_splits"):
    """Return the names of the available evaluation splits.

    Every ``<name>.txt`` file directly inside *splits_dir* defines one split.
    The returned names have only the trailing ``.txt`` suffix removed. They are
    sorted so the UI lists them in a stable order; ``os.listdir`` order is
    arbitrary.

    Args:
        splits_dir: Directory containing the split files. Defaults to the
            bundled ChinaTravel split directory, relative to the CWD.

    Returns:
        Sorted list of split names.

    Raises:
        FileNotFoundError: If *splits_dir* does not exist.
    """
    suffix = ".txt"
    # Slice off only the suffix; str.replace would also strip ".txt"
    # occurring elsewhere in the name.
    return sorted(
        filename[: -len(suffix)]
        for filename in os.listdir(splits_dir)
        if filename.endswith(suffix)
    )
|
|
|
|
|
# Split names offered in the UI, discovered once at import time.
SPLITS_LIST = load_splits()

# Absolute working directories: extracted submissions and result JSON files.
SUBMIT_DIR, OUTPUT_DIR = (
    os.path.abspath(name) for name in ("submissions", "outputs")
)

# Wipe leftovers from a previous run, then recreate both directories empty.
for _workdir in (SUBMIT_DIR, OUTPUT_DIR):
    shutil.rmtree(_workdir, ignore_errors=True)
for _workdir in (SUBMIT_DIR, OUTPUT_DIR):
    os.makedirs(_workdir, exist_ok=True)
del _workdir
|
|
|
|
|
def clean_old_outputs(folder, keep_hours=24):
    """Delete regular files in *folder* last modified more than *keep_hours* ago.

    Sub-directories are left untouched. The function is run as a background
    job, so a missing *folder*, or a file that vanishes mid-sweep, is skipped
    rather than aborting the whole sweep.

    Args:
        folder: Directory to sweep.
        keep_hours: Files modified within this many hours are kept.
    """
    cutoff = time.time() - keep_hours * 3600
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        # Nothing to clean if the folder itself is gone.
        return
    for fname in entries:
        fpath = os.path.join(folder, fname)
        try:
            if os.path.isfile(fpath) and os.path.getmtime(fpath) < cutoff:
                os.remove(fpath)
        except FileNotFoundError:
            # Removed between listdir() and here; nothing left to delete.
            continue
|
|
|
|
|
# Background sweeper: every 12 hours, remove result files older than
# clean_old_outputs' default retention window (keep_hours=24).
scheduler = BackgroundScheduler()
scheduler.add_job(clean_old_outputs, "interval", hours=12, args=[OUTPUT_DIR])
scheduler.start()
|
|
|
|
|
class Arguments:
    """Lightweight argparse-style namespace handed to ``evaluate``.

    NOTE(review): presumably mirrors the attribute names of eval_exp's
    command-line arguments; confirm against eval_exp.evaluate.
    """

    def __init__(self, splits, result_dir):
        # Name of the evaluation split selected in the UI.
        self.splits = splits
        # Directory holding the extracted model predictions.
        self.result_dir = result_dir

    def __repr__(self):
        # Readable form for logs and tracebacks.
        return (
            f"{type(self).__name__}("
            f"splits={self.splits!r}, result_dir={self.result_dir!r})"
        )
|
|
|
|
|
def _resolve_upload_path(zip_file):
    """Return the filesystem path of a Gradio file upload.

    Some Gradio versions pass a tempfile-like object exposing ``.name``;
    others (``gr.File`` with ``type="filepath"``) pass a plain ``str`` path.
    """
    return zip_file if isinstance(zip_file, str) else zip_file.name


def _locate_result_dir(extract_root, zip_path):
    """Pick the directory that holds the extracted predictions.

    Prefers ``<extract_root>/<zip file stem>``, the layout the archive is
    expected to have. Falls back to the archive's single top-level directory
    (ignoring macOS ``__MACOSX`` metadata), and finally to *extract_root*
    itself, for archives whose files were zipped without a wrapping folder.
    """
    stem = os.path.splitext(os.path.basename(zip_path))[0]
    candidate = os.path.join(extract_root, stem)
    if os.path.isdir(candidate):
        return candidate
    entries = [e for e in os.listdir(extract_root) if e != "__MACOSX"]
    if len(entries) == 1:
        only = os.path.join(extract_root, entries[0])
        if os.path.isdir(only):
            return only
    return extract_root


def handle_submission(zip_file, dataset_choice):
    """Evaluate an uploaded prediction archive, streaming progress to the UI.

    Generator used as a Gradio event handler. Each yielded tuple maps to the
    outputs ``(message, schema progress, commonsense progress, logic progress,
    result file)``.

    Args:
        zip_file: The uploaded ``.zip``, as a path string or an object with
            ``.name``, or ``None`` when nothing was uploaded.
        dataset_choice: Name of the evaluation split to score against.

    Yields:
        5-tuples of status text, three stage progress values (0-100), and the
        path of the saved result JSON (``None`` until evaluation finishes).
        Any failure is reported as an error message rather than raised.
    """
    if zip_file is None:
        yield "❌ 请上传 zip 文件!", 0, 0, 0, None
        return

    zip_path = _resolve_upload_path(zip_file)

    try:
        # Start from an empty extraction dir so files from a previous
        # submission cannot leak into this evaluation. Extraction sits inside
        # the try so a corrupt archive is reported in the UI instead of
        # crashing the event.
        shutil.rmtree(SUBMIT_DIR, ignore_errors=True)
        os.makedirs(SUBMIT_DIR, exist_ok=True)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            print(f"正在解压缩 {zip_path} 到 {SUBMIT_DIR}...")
            zip_ref.extractall(SUBMIT_DIR)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        print(f"Submission dir: {SUBMIT_DIR}")
        unzipped_dir = _locate_result_dir(SUBMIT_DIR, zip_path)
        print(f"Unzipped directory: {unzipped_dir}")
        output_path = os.path.join(OUTPUT_DIR, f"result_main_{timestamp}.json")
        args = Arguments(splits=dataset_choice, result_dir=unzipped_dir)

        yield "🚀 开始测评...", 0, 0, 0, None

        result = {}
        for progress in evaluate(args, result):
            stage = progress.get("stage", "")
            progress_value = progress.get("progress", 0)

            # Stages run in order; earlier stages are shown as complete (100).
            if stage == "schema":
                yield "Schema 阶段测评中...", progress_value, 0, 0, None
            elif stage == "commonsense":
                yield "Commonsense 阶段测评中...", 100, progress_value, 0, None
            elif stage == "logic":
                yield "Logic 阶段测评中...", 100, 100, progress_value, None
            elif stage == "final":
                result.update(progress.get("result", {}))
                yield "测评完成,正在保存结果...", 100, 100, 100, None

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=4)

        # The yielded path is what updates the download component. Mutating
        # the shared component object here would not update this session's UI.
        yield "✅ 测评完成!", 100, 100, 100, output_path

    except Exception as e:
        # UI boundary: log the full traceback, show a concise error message.
        import traceback

        traceback.print_exc()
        yield f"❌ 测评异常:{e}", 0, 0, 0, None
|
|
|
|
|
# Pre-select "validation" when that split exists. Otherwise fall back to the
# first discovered split (or no selection), so the Radio never starts on a
# value that is not among its choices.
if "validation" in SPLITS_LIST:
    DEFAULT_SPLIT = "validation"
else:
    DEFAULT_SPLIT = SPLITS_LIST[0] if SPLITS_LIST else None

with gr.Blocks() as demo:
    gr.Markdown("# 📊 ChinaTravel 模型测评")

    with gr.Row():
        zip_input = gr.File(label="上传模型预测 zip 文件", file_types=[".zip"])
        dataset_choice = gr.Radio(
            SPLITS_LIST, label="选择评估数据集", value=DEFAULT_SPLIT
        )

    submit_btn = gr.Button("开始测评")

    output_msg = gr.Markdown()
    result_file = gr.File(label="结果文件下载")

    # Read-only sliders showing the progress of each evaluation stage.
    schema_progress = gr.Slider(
        label="Schema 阶段进度", minimum=0, maximum=100, value=0, interactive=False
    )
    commonsense_progress = gr.Slider(
        label="Commonsense 阶段进度", minimum=0, maximum=100, value=0, interactive=False
    )
    logic_progress = gr.Slider(
        label="Logic 阶段进度", minimum=0, maximum=100, value=0, interactive=False
    )

    # handle_submission is a generator; each yielded 5-tuple updates these
    # outputs in order.
    submit_btn.click(
        handle_submission,
        inputs=[zip_input, dataset_choice],
        outputs=[
            output_msg,
            schema_progress,
            commonsense_progress,
            logic_progress,
            result_file,
        ],
    )

# Generator handlers require the request queue. Gradio 3.x raises without it;
# Gradio 4+ enables it by default, so the explicit call is harmless there.
demo.queue()
demo.launch(debug=True)
|
|