Emilyxml commited on
Commit
cb8b5e8
·
verified ·
1 Parent(s): 32d8a37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +251 -44
app.py CHANGED
@@ -1,52 +1,259 @@
1
- state_user = gr.State(lambda: {"user_id": str(uuid.uuid4())[:8], "index": 0})
2
- state_candidates_info = gr.State([])
3
-
4
- with gr.Row():
5
- md_instruction = gr.Markdown("Loading...")
6
-
7
- with gr.Row():
8
- with gr.Column(scale=1):
9
- # format 设置为 jpeg 进一步减小体积
10
- img_origin = gr.Image(label="Reference (参考原图)", interactive=False, height=400, format="jpeg")
11
-
12
- with gr.Column(scale=2):
13
- gallery_candidates = gr.Gallery(
14
- label="Candidates (候选结果)",
15
- columns=[2],
16
- height="auto",
17
- object_fit="contain",
18
- interactive=False,
19
- format="jpeg" # 强制输出 JPEG 格式
 
 
 
 
 
 
 
 
20
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- gr.Markdown("👇 **请在下方勾选您认为最好的结果(可多选):**")
23
-
24
- checkbox_options = gr.CheckboxGroup(
25
- choices=[],
26
- label="您的选择",
27
- info="对应上方图片的标签 (Option A, B...)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- with gr.Row():
31
- btn_submit = gr.Button("🚀 提交 (Submit)", variant="primary")
32
- btn_none = gr.Button("🚫 都不满意 (None)", variant="stop")
 
 
 
 
 
 
 
 
 
 
33
 
34
- md_end = gr.Markdown(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- demo.load(
37
- fn=get_next_question,
38
- inputs=[state_user],
39
- outputs=[img_origin, gallery_candidates, checkbox_options, md_instruction, btn_submit, btn_none, md_end, state_user, state_candidates_info]
40
- )
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- btn_submit.click(
43
- fn=lambda s, c, o: save_and_next(s, c, o, is_none=False),
44
- inputs=[state_user, state_candidates_info, checkbox_options],
45
- outputs=[img_origin, gallery_candidates, checkbox_options, md_instruction, btn_submit, btn_none, md_end, state_user, state_candidates_info]
46
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- btn_none.click(
49
- fn=lambda s, c, o: save_and_next(s, c, o, is_none=True),
50
- inputs=[state_user, state_candidates_info, checkbox_options],
51
- outputs=[img_origin, gallery_candidates, checkbox_options, md_instruction, btn_submit, btn_none, md_end, state_user, state_candidates_info]
52
- )
 
1
+ import gradio as gr
2
+ import os
3
+ import random
4
+ import uuid
5
+ import csv
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from PIL import Image # 引入 PIL 用于处理图片
9
+ from huggingface_hub import CommitScheduler, snapshot_download
10
+
11
+ # --- 1. 配置区域 ---
12
+ DATASET_REPO_ID = "Emilyxml/moveit"
13
+ DATA_FOLDER = "data"
14
+ LOG_FOLDER = Path("logs")
15
+ LOG_FOLDER.mkdir(parents=True, exist_ok=True)
16
+ TOKEN = os.environ.get("HF_TOKEN")
17
+
18
+ # --- 2. 自动下载数据 ---
19
+ if not os.path.exists(DATA_FOLDER) or not os.listdir(DATA_FOLDER):
20
+ try:
21
+ print("🚀 正在从 Dataset 下载数据...")
22
+ snapshot_download(
23
+ repo_id=DATASET_REPO_ID,
24
+ repo_type="dataset",
25
+ local_dir=DATA_FOLDER,
26
+ token=TOKEN,
27
+ allow_patterns=["*.jpg", "*.png", "*.jpeg", "*.webp", "*.txt"]
28
  )
29
+ print("✅ 数据下载完成!")
30
+ except Exception as e:
31
+ print(f"⚠️ 下载失败: {e}")
32
+
33
+ # --- 3. 启动同步调度器 ---
34
+ scheduler = CommitScheduler(
35
+ repo_id=DATASET_REPO_ID,
36
+ repo_type="dataset",
37
+ folder_path=LOG_FOLDER,
38
+ path_in_repo="logs",
39
+ every=1,
40
+ token=TOKEN
41
+ )
42
+
43
+ # --- 4. 数据加载 ---
44
+ def load_data():
45
+ groups = {}
46
+ if not os.path.exists(DATA_FOLDER):
47
+ return {}, []
48
+
49
+ for filename in os.listdir(DATA_FOLDER):
50
+ if filename.startswith('.'): continue
51
+ file_path = os.path.join(DATA_FOLDER, filename)
52
+ prefix = filename[:5]
53
 
54
+ if prefix not in groups:
55
+ groups[prefix] = {"origin": None, "candidates": [], "instruction": "暂无说明"}
56
+
57
+ if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
58
+ if "_origin" in filename.lower():
59
+ groups[prefix]["origin"] = file_path
60
+ else:
61
+ groups[prefix]["candidates"].append(file_path)
62
+ elif filename.lower().endswith('.txt'):
63
+ try:
64
+ with open(file_path, "r", encoding="utf-8") as f:
65
+ groups[prefix]["instruction"] = f.read()
66
+ except:
67
+ with open(file_path, "r", encoding="gbk") as f:
68
+ groups[prefix]["instruction"] = f.read()
69
+
70
+ valid_groups = {}
71
+ for k, v in groups.items():
72
+ if v["origin"] is not None or len(v["candidates"]) > 0:
73
+ valid_groups[k] = v
74
+
75
+ group_ids = list(valid_groups.keys())
76
+ random.shuffle(group_ids)
77
+ print(f"Loaded {len(group_ids)} groups.")
78
+ return valid_groups, group_ids
79
+
80
+ ALL_GROUPS, ALL_GROUP_IDS = load_data()
81
+
82
+ # --- NEW: 图片优化函数 (提速关键) ---
83
+ def optimize_image(image_path, max_width=800):
84
+ """
85
+ 读取图片并调整大小,减少传输时间。
86
+ max_width: 限制最大宽度为 800px (足够人眼评估)
87
+ """
88
+ if not image_path:
89
+ return None
90
+ try:
91
+ img = Image.open(image_path)
92
+ # 如果图片太大,就缩小
93
+ if img.width > max_width:
94
+ ratio = max_width / img.width
95
+ new_height = int(img.height * ratio)
96
+ img = img.resize((max_width, new_height), Image.LANCZOS)
97
+ return img
98
+ except Exception as e:
99
+ print(f"Error loading image {image_path}: {e}")
100
+ return None
101
+
102
+ # --- 5. 核心逻辑 ---
103
+
104
+ def get_next_question(user_state):
105
+ """准备下一题的数据"""
106
+ idx = user_state["index"]
107
+
108
+ if idx >= len(ALL_GROUP_IDS):
109
+ return (
110
+ gr.update(visible=False),
111
+ gr.update(visible=False),
112
+ gr.update(visible=False),
113
+ gr.update(visible=False),
114
+ gr.update(visible=False),
115
+ gr.update(visible=False),
116
+ gr.update(value="## 🎉 测试结束!感谢您的参与。", visible=True),
117
+ user_state,
118
+ []
119
  )
120
+
121
+ group_id = ALL_GROUP_IDS[idx]
122
+ group_data = ALL_GROUPS[group_id]
123
+
124
+ # 1. 优化原图 (返回 PIL 对象而不是路径)
125
+ origin_img = optimize_image(group_data["origin"], max_width=600)
126
+
127
+ # 2. 优化候选图
128
+ candidates = group_data["candidates"].copy()
129
+ random.shuffle(candidates)
130
+
131
+ gallery_items = []
132
+ choices = []
133
+ candidates_info = []
134
+
135
+ for i, path in enumerate(candidates):
136
+ label = f"Option {chr(65+i)}"
137
+
138
+ # 优化每张候选图
139
+ optimized_img = optimize_image(path, max_width=600)
140
+
141
+ gallery_items.append((optimized_img, label))
142
+ choices.append(label)
143
+ candidates_info.append({"label": label, "path": path})
144
 
145
+ instruction = f"### 任务 ({idx + 1} / {len(ALL_GROUP_IDS)})\n\n{group_data['instruction']}"
146
+
147
+ return (
148
+ gr.update(value=origin_img, visible=True if origin_img else False),
149
+ gr.update(value=gallery_items, visible=True),
150
+ gr.update(choices=choices, value=[], visible=True),
151
+ gr.update(value=instruction, visible=True),
152
+ gr.update(visible=True),
153
+ gr.update(visible=True),
154
+ gr.update(visible=False),
155
+ user_state,
156
+ candidates_info
157
+ )
158
 
159
+ def save_and_next(user_state, candidates_info, selected_options, is_none=False):
160
+ current_idx = user_state["index"]
161
+ group_id = ALL_GROUP_IDS[current_idx]
162
+
163
+ if is_none:
164
+ choice_str = "Rejected All"
165
+ method_str = "None_Satisfied"
166
+ else:
167
+ if not selected_options:
168
+ raise gr.Error("请至少勾选一个选项,或点击“都不满意”")
169
+
170
+ choice_str = "; ".join(selected_options)
171
+ selected_methods = []
172
+ for opt in selected_options:
173
+ for info in candidates_info:
174
+ if info["label"] == opt:
175
+ path = info["path"]
176
+ filename = os.path.basename(path)
177
+ name = os.path.splitext(filename)[0]
178
+ parts = name.split('_', 1)
179
+ method = parts[1] if len(parts) > 1 else name
180
+ selected_methods.append(method)
181
+ break
182
+ method_str = "; ".join(selected_methods)
183
 
184
+ user_file = LOG_FOLDER / f"user_{user_state['user_id']}.csv"
185
+ with scheduler.lock:
186
+ exists = user_file.exists()
187
+ with open(user_file, "a", newline="", encoding="utf-8") as f:
188
+ writer = csv.writer(f)
189
+ if not exists:
190
+ writer.writerow(["user_id", "timestamp", "group_id", "choices", "methods"])
191
+ writer.writerow([
192
+ user_state["user_id"],
193
+ datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
194
+ group_id,
195
+ choice_str,
196
+ method_str
197
+ ])
198
+
199
+ user_state["index"] += 1
200
+ return get_next_question(user_state)
201
 
202
+ # --- 6. 界面构建 ---
203
+ with gr.Blocks(title="User Study") as demo:
204
+
205
+ state_user = gr.State(lambda: {"user_id": str(uuid.uuid4())[:8], "index": 0})
206
+ state_candidates_info = gr.State([])
207
+
208
+ with gr.Row():
209
+ md_instruction = gr.Markdown("Loading...")
210
+
211
+ with gr.Row():
212
+ with gr.Column(scale=1):
213
+ # 将 format 设置为 jpeg 进一步减小体积
214
+ img_origin = gr.Image(label="Reference (参考原图)", interactive=False, height=400, format="jpeg")
215
+
216
+ with gr.Column(scale=2):
217
+ gallery_candidates = gr.Gallery(
218
+ label="Candidates (候选结果)",
219
+ columns=[2],
220
+ height="auto",
221
+ object_fit="contain",
222
+ interactive=False,
223
+ format="jpeg" # 强制输出 JPEG 格式
224
+ )
225
+
226
+ gr.Markdown("👇 **请在下方勾选您认为最好的结果(可多选):**")
227
+
228
+ checkbox_options = gr.CheckboxGroup(
229
+ choices=[],
230
+ label="您的选择",
231
+ info="对应上方图片的标签 (Option A, B...)"
232
+ )
233
+
234
+ with gr.Row():
235
+ btn_submit = gr.Button("🚀 提交 (Submit)", variant="primary")
236
+ btn_none = gr.Button("🚫 都不满意 (None)", variant="stop")
237
+
238
+ md_end = gr.Markdown(visible=False)
239
+
240
+ demo.load(
241
+ fn=get_next_question,
242
+ inputs=[state_user],
243
+ outputs=[img_origin, gallery_candidates, checkbox_options, md_instruction, btn_submit, btn_none, md_end, state_user, state_candidates_info]
244
+ )
245
+
246
+ btn_submit.click(
247
+ fn=lambda s, c, o: save_and_next(s, c, o, is_none=False),
248
+ inputs=[state_user, state_candidates_info, checkbox_options],
249
+ outputs=[img_origin, gallery_candidates, checkbox_options, md_instruction, btn_submit, btn_none, md_end, state_user, state_candidates_info]
250
+ )
251
+
252
+ btn_none.click(
253
+ fn=lambda s, c, o: save_and_next(s, c, o, is_none=True),
254
+ inputs=[state_user, state_candidates_info, checkbox_options],
255
+ outputs=[img_origin, gallery_candidates, checkbox_options, md_instruction, btn_submit, btn_none, md_end, state_user, state_candidates_info]
256
+ )
257
 
258
+ if __name__ == "__main__":
259
+ demo.launch()