# ct_test / app.py
# Copied from the Hugging Face Space file page: last commit d2b2b78
# ("Update app.py") by Cbphcr, marked verified.
import os
import time
import json
import shutil
import zipfile
import gradio as gr
from eval_exp import evaluate
from datetime import datetime
from apscheduler.schedulers.background import BackgroundScheduler
def load_splits(splits_dir="chinatravel/evaluation/default_splits"):
    """Return the names of the evaluation splits found in *splits_dir*.

    Every ``<name>.txt`` file in the directory contributes one split named
    ``<name>``; all other entries are ignored.

    Args:
        splits_dir: Directory to scan. Defaults to
            ``chinatravel/evaluation/default_splits`` (relative to the
            current working directory).

    Returns:
        The split names as a sorted list, so the UI shows them in a stable
        order (``os.listdir`` returns entries in arbitrary order).
    """
    suffix = ".txt"
    # Slice off only the trailing suffix; str.replace would also strip
    # ".txt" occurring elsewhere in the name.
    return sorted(
        filename[: -len(suffix)]
        for filename in os.listdir(splits_dir)
        if filename.endswith(suffix)
    )
SPLITS_LIST = load_splits()

# Working directories, resolved to absolute paths against the current
# working directory at import time.
SUBMIT_DIR = os.path.abspath("submissions")
OUTPUT_DIR = os.path.abspath("outputs")

# Start every launch from empty directories: first remove any leftovers
# from a previous run, then recreate both directories.
for _workdir in (SUBMIT_DIR, OUTPUT_DIR):
    shutil.rmtree(_workdir, ignore_errors=True)
for _workdir in (SUBMIT_DIR, OUTPUT_DIR):
    os.makedirs(_workdir, exist_ok=True)
del _workdir
def clean_old_outputs(folder, keep_hours=24):
    """Delete regular files in *folder* not modified within *keep_hours*.

    Only the top level of *folder* is examined; subdirectories are left
    untouched. Intended to run as a periodic background job, so a missing
    folder or a file vanishing mid-scan is tolerated instead of raising.

    Args:
        folder: Directory to prune.
        keep_hours: Files whose modification time is newer than this many
            hours ago are kept.
    """
    cutoff = time.time() - keep_hours * 3600
    try:
        with os.scandir(folder) as it:
            entries = list(it)
    except FileNotFoundError:
        # Nothing to clean if the directory does not exist (yet).
        return
    for entry in entries:
        try:
            # is_file()/stat() follow symlinks, like os.path.isfile/getmtime.
            if entry.is_file() and entry.stat().st_mtime < cutoff:
                os.remove(entry.path)
        except FileNotFoundError:
            # Removed by someone else between listing and stat/remove.
            continue
# Every 12 hours, prune result files from OUTPUT_DIR using
# clean_old_outputs' default retention window.
scheduler = BackgroundScheduler()
scheduler.add_job(clean_old_outputs, "interval", args=[OUTPUT_DIR], hours=12)
scheduler.start()
class Arguments:
    """Plain attribute container passed as ``args`` to ``evaluate``.

    Presumably stands in for the namespace an argparse-based evaluation
    script would receive; confirm the expected attributes against eval_exp.
    """

    def __init__(self, splits, result_dir):
        # splits: split name chosen in the UI; result_dir: directory holding
        # the extracted prediction files.
        self.splits, self.result_dir = splits, result_dir
def _archive_path(zip_file):
    """Return the filesystem path of an uploaded archive.

    Depending on the Gradio version / ``gr.File(type=...)`` setting, the
    upload is either a plain path string or a tempfile-like object exposing
    ``.name``; both are accepted.
    """
    return getattr(zip_file, "name", zip_file)


def _resolve_result_dir(extract_root, archive_path):
    """Pick the directory that holds the extracted prediction files.

    Prefers ``<extract_root>/<archive stem>`` (a zip containing one folder
    named like the zip). Otherwise, if the archive has exactly one
    top-level entry and it is a folder, that folder is used; failing that,
    the files are assumed to have been zipped flat into *extract_root*.
    """
    stem = os.path.splitext(os.path.basename(archive_path))[0]
    expected = os.path.join(extract_root, stem)
    if os.path.isdir(expected):
        return expected
    with os.scandir(extract_root) as it:
        # macOS Finder adds a "__MACOSX" metadata folder to zips; ignore it.
        entries = [e for e in it if e.name != "__MACOSX"]
    if len(entries) == 1 and entries[0].is_dir():
        return entries[0].path
    return extract_root


def handle_submission(zip_file, dataset_choice):
    """Evaluate an uploaded prediction archive, streaming progress to the UI.

    This is a generator used as a Gradio event handler. Every yielded value
    is a 5-tuple matching the outputs wired to the submit button:
    ``(status message, schema %, commonsense %, logic %, result file path)``.

    Args:
        zip_file: The uploaded zip archive (a path string or an object with
            ``.name``), or ``None`` when nothing was uploaded.
        dataset_choice: Split name to evaluate; forwarded to ``evaluate`` as
            ``args.splits``.
    """
    if zip_file is None:
        yield "❌ 请上传 zip 文件!", 0, 0, 0, None
        return

    zip_path = _archive_path(zip_file)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(OUTPUT_DIR, f"result_main_{timestamp}.json")

    try:
        # Extraction runs inside the try so a corrupt or non-zip upload is
        # reported to the user instead of escaping the generator.
        shutil.rmtree(SUBMIT_DIR, ignore_errors=True)
        os.makedirs(SUBMIT_DIR, exist_ok=True)
        print(f"正在解压缩 {zip_path} 到 {SUBMIT_DIR}...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(SUBMIT_DIR)
        print(f"Submission dir: {SUBMIT_DIR}")

        result_dir = _resolve_result_dir(SUBMIT_DIR, zip_path)
        print(f"Unzipped directory: {result_dir}")
        args = Arguments(splits=dataset_choice, result_dir=result_dir)

        yield "🚀 开始测评...", 0, 0, 0, None
        result = {}
        # evaluate yields dicts carrying a "stage" and a "progress" value;
        # the "final" stage additionally carries the aggregated "result".
        for progress in evaluate(args, result):
            stage = progress.get("stage", "")
            value = progress.get("progress", 0)
            if stage == "schema":
                yield "Schema 阶段测评中...", value, 0, 0, None
            elif stage == "commonsense":
                yield "Commonsense 阶段测评中...", 100, value, 0, None
            elif stage == "logic":
                yield "Logic 阶段测评中...", 100, 100, value, None
            elif stage == "final":
                result.update(progress.get("result", {}))
                yield "测评完成,正在保存结果...", 100, 100, 100, None

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=4)
        # Delivering the path through the yield is what updates the download
        # component for this session; no global component state is mutated.
        yield "✅ 测评完成!", 100, 100, 100, output_path
    except Exception as e:  # UI boundary: report the failure to the user.
        import traceback

        traceback.print_exc()
        yield f"❌ 测评异常:{e}", 0, 0, 0, None
# Pre-select "validation" when that split exists; otherwise fall back to the
# first available split so the radio never starts on a value outside its
# choices.
if "validation" in SPLITS_LIST:
    _default_split = "validation"
else:
    _default_split = SPLITS_LIST[0] if SPLITS_LIST else None

with gr.Blocks() as demo:
    gr.Markdown("# 📊 ChinaTravel 模型测评")
    with gr.Row():
        zip_input = gr.File(label="上传模型预测 zip 文件", file_types=[".zip"])
        dataset_choice = gr.Radio(
            SPLITS_LIST, label="选择评估数据集", value=_default_split
        )
    submit_btn = gr.Button("开始测评")
    output_msg = gr.Markdown()
    # Download slot for the saved result JSON; handle_submission's final
    # yield supplies the file path.
    result_file = gr.File(label="结果文件下载")
    # One read-only progress bar (0-100) per evaluation stage.
    schema_progress = gr.Slider(
        label="Schema 阶段进度", minimum=0, maximum=100, value=0, interactive=False
    )
    commonsense_progress = gr.Slider(
        label="Commonsense 阶段进度", minimum=0, maximum=100, value=0, interactive=False
    )
    logic_progress = gr.Slider(
        label="Logic 阶段进度", minimum=0, maximum=100, value=0, interactive=False
    )
    # Output order must match the 5-tuples yielded by handle_submission.
    submit_btn.click(
        handle_submission,
        inputs=[zip_input, dataset_choice],
        outputs=[
            output_msg,
            schema_progress,
            commonsense_progress,
            logic_progress,
            result_file,
        ],
    )

demo.launch(debug=True)