import gradio as gr
import os
import random
import uuid
import csv
from datetime import datetime
from pathlib import Path
from PIL import Image  # PIL is used to resize images before they are served
from huggingface_hub import CommitScheduler, snapshot_download

# --- 1. Configuration ---
DATASET_REPO_ID = "Emilyxml/moveit"
DATA_FOLDER = "data"
LOG_FOLDER = Path("logs")
LOG_FOLDER.mkdir(parents=True, exist_ok=True)
TOKEN = os.environ.get("HF_TOKEN")

# --- 2. Download the data automatically ---
if not os.path.exists(DATA_FOLDER) or not os.listdir(DATA_FOLDER):
    try:
        print("🚀 Downloading data from the dataset...")
        snapshot_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_FOLDER,
            token=TOKEN,
            allow_patterns=["*.jpg", "*.png", "*.jpeg", "*.webp", "*.txt"],
        )
        print("✅ Data download complete!")
    except Exception as e:
        print(f"⚠️ Download failed: {e}")

# --- 3. Start the sync scheduler ---
# Periodically commits the contents of LOG_FOLDER to the "logs" folder of the dataset repo.
scheduler = CommitScheduler(
    repo_id=DATASET_REPO_ID,
    repo_type="dataset",
    folder_path=LOG_FOLDER,
    path_in_repo="logs",
    every=1,  # minutes between commits
    token=TOKEN,
)


# --- 4. Data loading ---
def load_data():
    """Group the files in DATA_FOLDER by their 5-character filename prefix."""
    groups = {}
    if not os.path.exists(DATA_FOLDER):
        return {}, []

    for filename in os.listdir(DATA_FOLDER):
        if filename.startswith('.'):
            continue  # skip hidden files and folders
        file_path = os.path.join(DATA_FOLDER, filename)
        prefix = filename[:5]
        if prefix not in groups:
            groups[prefix] = {"origin": None, "candidates": [], "instruction": "No instructions available."}

        if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
            if "_origin" in filename.lower():
                groups[prefix]["origin"] = file_path
            else:
                groups[prefix]["candidates"].append(file_path)
        elif filename.lower().endswith('.txt'):
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    groups[prefix]["instruction"] = f.read()
            except UnicodeDecodeError:
                # Fall back to GBK for instruction files that are not valid UTF-8
                with open(file_path, "r", encoding="gbk") as f:
                    groups[prefix]["instruction"] = f.read()

    # Keep only the groups that contain at least one image
    valid_groups = {}
    for k, v in groups.items():
        if v["origin"] is not None or len(v["candidates"]) > 0:
            valid_groups[k] = v

    group_ids = list(valid_groups.keys())
    random.shuffle(group_ids)  # shuffled once at startup, so every user sees the same order
    print(f"Loaded {len(group_ids)} groups.")
    return valid_groups, group_ids


ALL_GROUPS, ALL_GROUP_IDS = load_data()
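
# Illustrative sketch only: these are hypothetical helpers, and nothing in this app
# calls them. They restate, as pure functions that can be checked without downloading
# any data, the file-naming convention that load_data() and save_and_next() appear to
# assume. Assumed convention: "<5-char group id>_origin.<ext>" is the reference image,
# "<5-char group id>_<method>.<ext>" is a candidate, and a .txt file with the same
# 5-character prefix holds the task instruction.
def _classify_data_file(filename):
    """Return (group_prefix, role), where role is "origin", "candidate", "instruction" or None."""
    prefix = filename[:5]
    lower = filename.lower()
    if lower.endswith(('.png', '.jpg', '.jpeg', '.webp')):
        return prefix, ("origin" if "_origin" in lower else "candidate")
    if lower.endswith('.txt'):
        return prefix, "instruction"
    return prefix, None  # any other file type is ignored by load_data()


def _method_from_filename(filename):
    """Return the method name that save_and_next() would log for a candidate file."""
    name = os.path.splitext(os.path.basename(filename))[0]
    parts = name.split('_', 1)  # split only at the first underscore
    return parts[1] if len(parts) > 1 else name

# Example: _classify_data_file("00001_origin.png") returns ("00001", "origin"), and
# _method_from_filename("data/00001_my_method.png") returns "my_method".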
# --- NEW: Image optimization (the key speed-up) ---
def optimize_image(image_path, max_width=800):
    """
    Load an image and downscale it to reduce transfer time.

    max_width: maximum width in pixels (800 px is enough for visual assessment).
    """
    if not image_path:
        return None
    try:
        with Image.open(image_path) as img:
            # JPEG cannot store an alpha channel, and the UI components below request
            # JPEG output, so convert to RGB. convert() returns a new image with the
            # pixel data loaded, so the file can be closed safely.
            img = img.convert("RGB")
        # Downscale only images wider than max_width, preserving the aspect ratio
        if img.width > max_width:
            ratio = max_width / img.width
            new_height = int(img.height * ratio)
            img = img.resize((max_width, new_height), Image.LANCZOS)
        return img
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None


# --- 5. Core logic ---
def get_next_question(user_state):
    """Prepare the data for the next question."""
    idx = user_state["index"]
    if idx >= len(ALL_GROUP_IDS):
        # Every group has been answered: hide the task widgets and show the closing message
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value="## 🎉 The study is complete! Thank you for participating.", visible=True),
            user_state,
            [],
        )

    group_id = ALL_GROUP_IDS[idx]
    group_data = ALL_GROUPS[group_id]

    # 1. Downscale the reference image (returns a PIL object instead of a path)
    origin_img = optimize_image(group_data["origin"], max_width=600)

    # 2. Shuffle and downscale the candidate images
    candidates = group_data["candidates"].copy()
    random.shuffle(candidates)

    gallery_items = []
    choices = []
    candidates_info = []
    for path in candidates:
        optimized_img = optimize_image(path, max_width=600)
        if optimized_img is None:
            continue  # skip unreadable files so the gallery labels stay consecutive
        label = f"Option {chr(65 + len(choices))}"  # Option A, Option B, ...
        gallery_items.append((optimized_img, label))
        choices.append(label)
        candidates_info.append({"label": label, "path": path})

    instruction = f"### Task ({idx + 1} / {len(ALL_GROUP_IDS)})\n\n{group_data['instruction']}"

    return (
        gr.update(value=origin_img, visible=origin_img is not None),
        gr.update(value=gallery_items, visible=True),
        gr.update(choices=choices, value=[], visible=True),
        gr.update(value=instruction, visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
        user_state,
        candidates_info,
    )


def save_and_next(user_state, candidates_info, selected_options, is_none=False):
    """Append the user's answer to their CSV log, then load the next question."""
    current_idx = user_state["index"]
    group_id = ALL_GROUP_IDS[current_idx]

    if is_none:
        choice_str = "Rejected All"
        method_str = "None_Satisfied"
    else:
        if not selected_options:
            raise gr.Error('Please select at least one option, or click "None satisfactory".')
        choice_str = "; ".join(selected_options)

        # Map each selected label back to the method that produced the image. Candidate
        # files are expected to be named "<prefix>_<method>.<ext>", so the method is
        # everything after the first underscore.
        selected_methods = []
        for opt in selected_options:
            for info in candidates_info:
                if info["label"] == opt:
                    filename = os.path.basename(info["path"])
                    name = os.path.splitext(filename)[0]
                    parts = name.split('_', 1)
                    method = parts[1] if len(parts) > 1 else name
                    selected_methods.append(method)
                    break
        method_str = "; ".join(selected_methods)

    user_file = LOG_FOLDER / f"user_{user_state['user_id']}.csv"
    # Hold the scheduler's lock so a background commit never uploads a half-written file
    with scheduler.lock:
        exists = user_file.exists()
        with open(user_file, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if not exists:
                writer.writerow(["user_id", "timestamp", "group_id", "choices", "methods"])
            writer.writerow([
                user_state["user_id"],
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                group_id,
                choice_str,
                method_str,
            ])

    user_state["index"] += 1
    return get_next_question(user_state)


# --- 6. UI construction ---
with gr.Blocks(title="User Study") as demo:
    # Per-session state: a short random user ID and the index of the current question
    state_user = gr.State(lambda: {"user_id": str(uuid.uuid4())[:8], "index": 0})
    state_candidates_info = gr.State([])

    with gr.Row():
        md_instruction = gr.Markdown("Loading...")

    with gr.Row():
        with gr.Column(scale=1):
            # Serve as JPEG to reduce the payload size further
            img_origin = gr.Image(
                label="Reference (original image)",
                interactive=False,
                height=400,
                format="jpeg",
            )
        with gr.Column(scale=2):
            gallery_candidates = gr.Gallery(
                label="Candidates",
                columns=[2],
                height="auto",
                object_fit="contain",
                interactive=False,
                format="jpeg",  # force JPEG output
            )

    gr.Markdown("👇 **Select the result(s) you think are best (multiple selections allowed):**")
    checkbox_options = gr.CheckboxGroup(
        choices=[],
        label="Your choice",
        info="Matches the image labels above (Option A, B, ...)",
    )

    with gr.Row():
        btn_submit = gr.Button("🚀 Submit", variant="primary")
        btn_none = gr.Button("🚫 None satisfactory", variant="stop")

    md_end = gr.Markdown(visible=False)

    demo.load(
        fn=get_next_question,
        inputs=[state_user],
        outputs=[
            img_origin, gallery_candidates, checkbox_options, md_instruction,
            btn_submit, btn_none, md_end, state_user,
            state_candidates_info,
        ],
    )

    btn_submit.click(
        fn=lambda s, c, o: save_and_next(s, c, o, is_none=False),
        inputs=[state_user, state_candidates_info, checkbox_options],
        outputs=[
            img_origin, gallery_candidates, checkbox_options, md_instruction,
            btn_submit, btn_none, md_end, state_user, state_candidates_info,
        ],
    )

    btn_none.click(
        fn=lambda s, c, o: save_and_next(s, c, o, is_none=True),
        inputs=[state_user, state_candidates_info, checkbox_options],
        outputs=[
            img_origin, gallery_candidates, checkbox_options, md_instruction,
            btn_submit, btn_none, md_end, state_user, state_candidates_info,
        ],
    )

if __name__ == "__main__":
    demo.launch()