from PIL import Image, ImageDraw, ImageFont
import base64
import json
from io import BytesIO

import requests
from urllib.parse import quote

BASE_HF_URL = 'https://huggingface.co/datasets/ZonHu1/KRIS-Image/resolve/main'
base_bench_dir = 'KRIS_Bench'
base_results_dir = 'KRIS_Bench_Results'

# In-memory caches so repeated navigation does not re-download files
_annotation_cache = {}
_image_cache = {}


def get_models():
    """Get available models from the HF dataset."""
    try:
        # Use the HF API to get a directory listing of the results folder
        api_url = f"https://huggingface.co/api/datasets/ZonHu1/KRIS-Image/tree/main/{base_results_dir}"
        response = requests.get(api_url, timeout=10)
        if response.status_code == 200:
            data = response.json()
            models = [item['path'].split('/')[-1] for item in data if item['type'] == 'directory']
            return models
        else:
            print(f"Failed to fetch models from HF API, status: {response.status_code}")
            return []
    except Exception as e:
        print(f"Error fetching models from HF: {e}")
        return []


def get_categories():
    return [
        'abstract_reasoning', 'anomaly_correction', 'biology', 'chemistry',
        'color_change', 'count_change', 'geography', 'humanities',
        'mathematics', 'medicine', 'multi-element_composition',
        'multi-instruction_execution', 'part_completion', 'physics',
        'position_movement', 'practical_knowledge', 'rule-based_reasoning',
        'size_adjustment', 'temporal_prediction', 'viewpoint_change',
    ]


def get_url_response(url):
    """Fetch a URL, returning the response or None on any failure."""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return response
    except Exception as e:
        print(f"Error fetching {url}: {e}")
        return None


def load_image(path):
    """Load an image from the Hugging Face dataset, with caching."""
    if path in _image_cache:
        return _image_cache[path]
    # Convert the dataset path to a URL (quote handles spaces etc.)
    url = f"{BASE_HF_URL}/{quote(path)}"
    response = get_url_response(url)
    if response:
        image = Image.open(BytesIO(response.content))
        _image_cache[path] = image
        return image
    print(f"Failed to load image from: {path}")
    return None


def load_annotations(cat):
    """Load a category's annotation.json, with caching."""
    if cat in _annotation_cache:
        return _annotation_cache[cat]
    url = f"{BASE_HF_URL}/{base_bench_dir}/{cat}/annotation.json"
    response = get_url_response(url)
    if response:
        annotations = response.json()
        _annotation_cache[cat] = annotations
        return annotations
    print(f"Failed to load annotations for {cat}")
    return {}


def horizontal_concat(imgs, bg_color=(255, 255, 255), max_height=400):
    """Concatenate images horizontally (resized to a common height) for display."""
    if not imgs:
        return None
    target_height = min(max_height, max(img.height for img in imgs))
    resized_imgs = []
    for img in imgs:
        if img.height != target_height:
            ratio = target_height / img.height
            new_width = int(img.width * ratio)
            resized_imgs.append(img.resize((new_width, target_height), Image.Resampling.LANCZOS))
        else:
            resized_imgs.append(img)
    total_width = sum(img.width for img in resized_imgs)
    new_im = Image.new('RGB', (total_width, target_height), color=bg_color)
    x_offset = 0
    for img in resized_imgs:
        new_im.paste(img, (x_offset, 0))
        x_offset += img.width
    return new_im


def create_placeholder_image(width=512, height=512):
    """Gray "?" placeholder used when a temporal-prediction input image is missing."""
    img = Image.new('RGB', (width, height), color=(128, 128, 128))
    draw = ImageDraw.Draw(img)
    font_size = min(width, height) // 4
    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
    try:
        font = ImageFont.truetype(font_path, font_size)
    except OSError:
        # Fall back to PIL's built-in font if DejaVu is not installed
        font = ImageFont.load_default()
    text = "?"
    # Center the "?" using its text bounding box
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width = bbox[2] - bbox[0]
    text_height = bbox[3] - bbox[1]
    x = (width - text_width) // 2
    y = (height - text_height) // 2
    draw.text((x, y), text, fill=(255, 255, 255), font=font)
    return img


def handle_temporal_prediction(ori_imgs, cat):
    """Handle temporal prediction task input images."""
    if not isinstance(ori_imgs, list) or len(ori_imgs) != 3:
        return None
    img_data = []
    for img_name in ori_imgs:
        img = load_image(f"{base_bench_dir}/{cat}/{img_name}")
        try:
            # Filenames encode the time step after the last '-', e.g. "12-3.jpg" -> step 3
            parts = img_name.split('-')
            if len(parts) >= 2:
                step_part = parts[-1].split('.')[0]
                time_step = int(step_part)
            else:
                time_step = 0
        except ValueError:
            time_step = 0
        img_data.append((img, time_step, img_name))
    img_data.sort(key=lambda x: x[1])
    # Slot the images into steps 1-4
    result_images = [None] * 4
    used_steps = set()
    for img, step, name in img_data:
        if 1 <= step <= 4:
            result_images[step - 1] = img
            used_steps.add(step)
    # The one step without an input image is the step to be predicted
    missing_steps = {1, 2, 3, 4} - used_steps
    if missing_steps:
        missing_step = min(missing_steps)
        placeholder = create_placeholder_image(img_data[0][0].width, img_data[0][0].height)
        result_images[missing_step - 1] = placeholder
    result_images = [img for img in result_images if img is not None]
    return result_images


def find_result_image(model, cat, key):
    """Find a result image, trying several extensions, with caching."""
    cache_key = f"{model}_{cat}_{key}"
    if cache_key in _image_cache:
        return _image_cache[cache_key]
    possible_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp']
    for ext in possible_extensions:
        fname = f"{key}{ext}"
        try:
            img = load_image(f"{base_results_dir}/{model}/{cat}/{fname}")
            if img is None:
                continue
            _image_cache[cache_key] = img
            return img
        except Exception:
            continue
    # If no image is found, fall back to a plain gray image
    default_img = Image.new('RGB', (512, 512), color=(128, 128, 128))
    _image_cache[cache_key] = default_img
    return default_img


def show_one(model, cat, idx):
    """Collect everything needed to display one benchmark item."""
    annos = load_annotations(cat)
    total = len(annos)
    if total == 0:
        return None, None, None, "No annotations", "", "", idx, total
    idx = idx % total
    key = str(idx + 1)
    if key not in annos:
        return None, None, None, "No annotation", "", "", idx, total
    a = annos[key]
    ori_imgs = a.get('ori_img')
    original_images = []
    ori_combined = None
    if isinstance(ori_imgs, list):
        if cat.lower() == 'temporal_prediction':
            temporal_images = handle_temporal_prediction(ori_imgs, cat)
            if temporal_images:
                original_images = temporal_images
                ori_combined = horizontal_concat(temporal_images, max_height=300)
            else:
                pil_list = [load_image(f"{base_bench_dir}/{cat}/{fn}") for fn in ori_imgs]
                original_images = pil_list
                ori_combined = horizontal_concat(pil_list, max_height=300)
        else:
            pil_list = [load_image(f"{base_bench_dir}/{cat}/{fn}") for fn in ori_imgs]
            original_images = pil_list
            ori_combined = horizontal_concat(pil_list, max_height=300)
    else:
        single_img = load_image(f"{base_bench_dir}/{cat}/{a.get('ori_img', '')}")
        original_images = [single_img]
        ori_combined = single_img
    res = find_result_image(model, cat, key)
    return (ori_combined, original_images, res, a.get('ori_img', ''),
            a.get('ins_en', ''), a.get('explain_en', ''), idx, total)


def pil_to_base64(img):
    """Convert a PIL image to a base64 data URI."""
    if img is None:
        return ""
    buffer = BytesIO()
    img.save(buffer, format='PNG')
    img_str = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"


def get_single_model_result_html(model, cat, idx):
    """Get HTML for a single model result."""
    _, _, img, _, _, _, _, _ = show_one(model, cat, idx)
    img_base64 = pil_to_base64(img)
    if img_base64:
        # NOTE: simple placeholder markup; the original inline styling was lost
        return f'''
        <div style="display:inline-block; text-align:center; margin:4px;">
            <p><b>{model}</b></p>
            <img src="{img_base64}" style="max-height:300px;"/>
        </div>'''
    return ""


def get_model_results_html(cat, idx, selected_models):
    """Get the HTML block showing every selected model's result."""
    if not selected_models:
        return "<div>No models selected</div>"
    # NOTE: simple concatenation of per-model snippets; the original layout markup was lost
    return "".join(get_single_model_result_html(model, cat, idx) for model in selected_models)
" HTML_HEAD = '''No models available
", 0, "0/0" try: ori_combined, original_images, _, index_info, instruction, explanation, new_idx, total = show_one(available_models[0], cat, idx) model_results_html = get_model_results_html(cat, new_idx, selected_models) cnt = f"{new_idx+1}/{total}" return ori_combined, index_info, instruction, explanation, model_results_html, new_idx, cnt except Exception as e: print(f"Error in render function: {e}") return None, "Error", "", "", "Error rendering content
", 0, "0/0"