import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
import gc

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Predefined Hugging Face model paths offered in the UI
PREDEFINED_MODELS = [
    "meta-llama/Llama-3.2-1B",
    "google/gemma-2-2b",
    "Qwen/Qwen3-0.6B",
    "Qwen/Qwen2.5-0.5B",
    "Qwen/Qwen2.5-1.5B",
    "bigscience/bloom-560m",
    "CohereForAI/aya-expanse-8b",
    "common-pile/comma-v0.1-2t",
    "google/byt5-small",
    "gsaltintas/supertoken_models-llama_gpt2",
]

# Global cache for loaded models
model_cache = {}

def parse_dataset(text):
    """Parse the input dataset text into structured questions"""
    if not text.strip():
        return [], "Please enter your dataset"
    lines = text.strip().split('\n')
    if len(lines) < 2:
        return [], "Dataset must have at least a header and one question"

    # Skip header and detect delimiter
    first_data_line = lines[1] if len(lines) > 1 else lines[0]
    delimiter = '\t' if '\t' in first_data_line else ','

    questions = []
    errors = []
    for i, line in enumerate(lines[1:], 2):  # Start from line 2 (after header)
        line = line.strip()
        if not line:
            continue
        parts = [part.strip().strip('"') for part in line.split(delimiter)]
        if len(parts) < 5:
            errors.append(f"Line {i}: Not enough columns (need 5, got {len(parts)})")
            continue
        question = {
            'question': parts[0],
            'correct_answer': parts[1],
            'choices': [parts[2], parts[3], parts[4]]
        }
        # Ensure correct answer is in choices
        if question['correct_answer'] not in question['choices']:
            question['choices'].append(question['correct_answer'])
        questions.append(question)

    error_msg = '\n'.join(errors) if errors else ""
    return questions, error_msg
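
# Illustrative sketch of parse_dataset (not executed): given the CSV input
#   Question,Correct Answer,Choice1,Choice2,Choice3
#   What is 2+2?,4,3,2,5
# it returns
#   ([{'question': 'What is 2+2?', 'correct_answer': '4', 'choices': ['3', '2', '5', '4']}], "")
# Note that the correct answer is appended to 'choices' because it is not among Choice1-3.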

def setup_tokenizer(model_path):
    """Resolve and load the tokenizer for a model, handling supertoken checkpoints
    that reference an external tokenizer in their tokenizer_config.json."""
    tokenizer_name = model_path
    if "supertoken" in model_path:
        from huggingface_hub import list_repo_files, hf_hub_download
        import json

        files = list_repo_files(model_path)
        if "tokenizer_config.json" in files:
            tokenizer_path = hf_hub_download(repo_id=model_path, filename="tokenizer_config.json")
            with open(tokenizer_path) as f:
                tok_config = json.load(f)["data"]["tokenizer"]
            if tok_config["name"] == "huggingface":
                tokenizer_name = tok_config["path"]
            # todo: tiktoken
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True, legacy=True)
    return tokenizer
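
# Sketch of the supertoken tokenizer_config.json layout this function reads (assumed
# from the keys accessed above; the real file may contain more fields):
#   {"data": {"tokenizer": {"name": "huggingface", "path": "gpt2"}}}
# With such a file, the tokenizer would be loaded from "gpt2" (a hypothetical example
# path) instead of the model repo itself.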

def load_model_and_tokenizer(model_path, use_cache=True, progress_callback=None):
    """Load model and tokenizer with caching"""
    global model_cache

    if use_cache and model_path in model_cache:
        logger.info(f"Using cached model: {model_path}")
        if progress_callback:
            progress_callback(1.0, f"✅ Using cached model: {model_path}")
        return model_cache[model_path]

    try:
        if progress_callback:
            progress_callback(0.1, f"🔄 Starting to load model: {model_path}")
        logger.info(f"Loading model: {model_path}")

        # Check if CUDA is available
        device = "cuda" if torch.cuda.is_available() else "cpu"

        if progress_callback:
            progress_callback(0.2, f"📥 Loading tokenizer for {model_path}...")

        # Load tokenizer
        tokenizer = setup_tokenizer(model_path)
        # Add pad token if missing
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        if progress_callback:
            progress_callback(0.5, f"🧠 Loading model weights for {model_path}... (this may take a while)")

        # Load model with appropriate settings
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        model_info = {
            'tokenizer': tokenizer,
            'model': model,
            'device': device
        }
        if use_cache:
            model_cache[model_path] = model_info

        if progress_callback:
            progress_callback(1.0, f"✅ Successfully loaded model: {model_path}")
        return model_info
    except Exception as e:
        error_msg = f"❌ Error loading model {model_path}: {str(e)}"
        logger.error(error_msg)
        if progress_callback:
            progress_callback(0.0, error_msg)
        return None
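
# Cache behavior sketch: loading the same path twice reuses the in-memory entry, e.g.
#   info = load_model_and_tokenizer("bigscience/bloom-560m")   # downloads + loads
#   info = load_model_and_tokenizer("bigscience/bloom-560m")   # served from model_cache
# Each cached entry is a dict with 'tokenizer', 'model', and 'device' keys.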

def calculate_choice_likelihood(model, tokenizer, question, choice):
    """Calculate the log-likelihood of the choice given the question prompt"""
    try:
        # The bare question text is used as the prompt; a "Question: ...\nAnswer:" template
        # is an alternative that is left disabled here.
        # prompt = f"Question: {question}\nAnswer: "
        prompt = question
        full_text = f"{prompt} {choice}"

        # Tokenize full input (prompt + answer)
        input_ids = tokenizer.encode(full_text, return_tensors="pt", add_special_tokens=False).to(model.device)
        prompt_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

        if input_ids.size(1) <= prompt_ids.size(1):
            logger.warning("Answer tokens are empty after tokenization.")
            return float("-inf")

        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits

        # Get logits for the answer tokens only
        answer_len = input_ids.size(1) - prompt_ids.size(1)
        target_ids = input_ids[:, -answer_len:]
        logits = logits[:, prompt_ids.size(1) - 1:-1, :]  # shifted for next-token prediction

        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
        total_log_prob = token_log_probs.sum().item()
        return total_log_prob
    except Exception as e:
        logger.error(f"Error calculating likelihood for choice '{choice}': {str(e)}")
        return float("-inf")
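
# How the scoring above works, as a comment-only sketch: for an answer of tokens
# a_1..a_k appended to the prompt, the returned score is
#   log P(answer | prompt) = sum_{t=1..k} log P(a_t | prompt, a_1..a_{t-1})
# The slice logits[:, prompt_len-1:-1, :] lines up with this because the logits at
# position i predict the token at position i+1, so positions prompt_len-1 .. end-2
# predict exactly the answer tokens.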

def evaluate_model_on_questions(model_path, questions, progress_callback=None):
    """Evaluate a single model on all questions using likelihood-based scoring"""
    model_info = load_model_and_tokenizer(model_path, progress_callback=progress_callback)
    if model_info is None:
        return [{'error': f'Failed to load model {model_path}'}] * len(questions)

    results = []
    model = model_info['model']
    tokenizer = model_info['tokenizer']

    for i, question in enumerate(questions):
        try:
            # Calculate likelihood for each choice
            choice_likelihoods = {}
            choice_probs = {}
            for choice in question['choices']:
                likelihood = calculate_choice_likelihood(model, tokenizer, question['question'], choice)
                choice_likelihoods[choice] = likelihood

            # Convert log probabilities to probabilities for confidence scoring
            max_log_prob = max(choice_likelihoods.values())
            choice_probs = {choice: torch.exp(torch.tensor(log_prob - max_log_prob)).item()
                            for choice, log_prob in choice_likelihoods.items()}
            # Normalize probabilities
            total_prob = sum(choice_probs.values())
            if total_prob > 0:
                choice_probs = {choice: prob / total_prob for choice, prob in choice_probs.items()}

            # Select the choice with highest likelihood
            predicted_choice = max(choice_likelihoods.keys(), key=lambda x: choice_likelihoods[x])
            is_correct = predicted_choice == question['correct_answer']
            # Confidence is the probability of the selected choice
            confidence = choice_probs.get(predicted_choice, 0.0)

            results.append({
                'question_idx': i,
                'predicted': predicted_choice,
                'correct': is_correct,
                'confidence': confidence,
                'choice_likelihoods': choice_likelihoods,
                'choice_probabilities': choice_probs,
                'raw_response': f"Likelihoods: {choice_likelihoods}"
            })

            if progress_callback:
                # Use remaining 80% for evaluation progress
                evaluation_progress = 0.2 + (i + 1) / len(questions) * 0.8
                progress_callback(evaluation_progress, f"🔍 Evaluating {model_path}: {i+1}/{len(questions)} questions (likelihood-based)")
        except Exception as e:
            logger.error(f"Error evaluating question {i} with {model_path}: {str(e)}")
            results.append({
                'question_idx': i,
                'predicted': question['choices'][0] if question['choices'] else '',
                'correct': False,
                'confidence': 0.0,
                'choice_likelihoods': {},
                'choice_probabilities': {},
                'raw_response': f"Error: {str(e)}"
            })
    return results
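
# Worked example of the confidence computation above (comment only): with per-choice
# log-likelihoods {"Paris": -2.0, "London": -4.0}, the max (-2.0) is subtracted before
# exponentiating for numerical stability, giving {1.0, 0.135}; normalizing then yields
# probabilities of roughly 0.88 and 0.12, i.e. a softmax over the choice log-likelihoods.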

def run_evaluation(dataset_text, selected_predefined, custom_models_text="", progress=gr.Progress()):
    """Main evaluation function"""
    if not dataset_text.strip():
        return (
            "Please enter your dataset",
            "<p>No data provided</p>",
            None,
            None,
            gr.update(visible=True)
        )

    # Parse custom models
    custom_models = []
    if custom_models_text is None:
        custom_models_text = ""
    if custom_models_text.strip():
        custom_models = [model.strip() for model in custom_models_text.strip().split('\n') if model.strip()]

    # Combine predefined and custom model selections
    all_models = []
    all_models.extend(selected_predefined)
    all_models.extend(custom_models)
    if not all_models:
        return (
            "Please select at least one model or add custom models",
            "<p>No models selected</p>",
            None,
            None,
            gr.update(visible=False)
        )

    # Parse dataset
    questions, parse_error = parse_dataset(dataset_text)
    if parse_error:
        return (
            f"Dataset parsing error:\n{parse_error}",
            "<p>Failed to parse dataset</p>",
            None,
            None,
            gr.update(visible=True)
        )
    if not questions:
        return (
            "No valid questions found in dataset",
            "<p>No questions to evaluate</p>",
            None,
            None,
            gr.update(visible=True)
        )

    # Run evaluation
    progress(0, "Starting evaluation...")
    results = {}
    total_steps = len(all_models) * len(questions)

    for model_idx, model_path in enumerate(all_models):
        display_name = model_path.split('/')[-1] if '/' in model_path else model_path
        try:
            def model_progress(p, msg, model_idx=model_idx):
                # Offset by the models already finished so the bar does not reset per model
                overall_progress = (model_idx + p) * len(questions) / total_steps
                progress(overall_progress, msg)

            model_results = evaluate_model_on_questions(model_path, questions, model_progress)
            results[display_name] = model_results
        except Exception as e:
            logger.error(f"Failed to evaluate {display_name}: {str(e)}")
            results[display_name] = [{'error': str(e)}] * len(questions)

        # Clean up GPU memory between models
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    # Generate outputs
    summary_stats = generate_summary_stats(questions, results)
    summary_md = create_summary_markdown(summary_stats)
    detailed_html = create_detailed_results_html(questions, results)
    accuracy_chart = create_accuracy_chart(summary_stats)
    confidence_chart = create_confidence_chart(results)

    return (
        summary_md,
        detailed_html,
        accuracy_chart,
        confidence_chart,
        gr.update(visible=True)
    )
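
# run_evaluation's five return values map, in order, onto the Gradio outputs wired up below:
# [summary_output, detailed_results, accuracy_plot, confidence_plot, results_section].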

def generate_summary_stats(questions, results):
    """Generate summary statistics for all models"""
    summary = {}
    for model, model_results in results.items():
        if not model_results or 'error' in model_results[0]:
            summary[model] = {
                'accuracy': 0.0,
                'correct': 0,
                'total': len(questions),
                'avg_confidence': 0.0,
                'error': model_results[0].get('error', 'Unknown error') if model_results else 'No results'
            }
            continue

        correct_count = sum(1 for r in model_results if r.get('correct', False))
        total_count = len(model_results)
        accuracy = correct_count / total_count if total_count > 0 else 0
        # Calculate average confidence
        avg_confidence = sum(r.get('confidence', 0) for r in model_results) / total_count if total_count > 0 else 0

        summary[model] = {
            'accuracy': accuracy,
            'correct': correct_count,
            'total': total_count,
            'avg_confidence': avg_confidence
        }
    return summary
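
# Shape of the summary dict consumed by the rendering helpers below, e.g.
#   {"Llama-3.2-1B": {"accuracy": 0.75, "correct": 3, "total": 4, "avg_confidence": 0.62}}
# (values here are illustrative only); failed models carry an extra "error" key.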

def create_summary_markdown(summary_stats):
    """Create markdown summary of results"""
    if not summary_stats:
        return "No results available"

    # Sort by accuracy
    sorted_models = sorted(summary_stats.items(), key=lambda x: x[1]['accuracy'], reverse=True)

    lines = ["## 🏆 Model Performance Summary\n"]
    for i, (model, stats) in enumerate(sorted_models):
        if 'error' in stats:
            lines.append(f"❌ **{model}**: Error - {stats['error']}")
            continue
        accuracy_pct = stats['accuracy'] * 100
        medal = "🥇" if i == 0 else "🥈" if i == 1 else "🥉" if i == 2 else f"{i+1}."
        lines.append(
            f"{medal} **{model}**: {accuracy_pct:.1f}% "
            f"({stats['correct']}/{stats['total']} correct, "
            f"avg confidence: {stats['avg_confidence']:.2f})"
        )
    # Blank line between entries so each model renders on its own line in Markdown
    return "\n\n".join(lines)

def create_detailed_results_html(questions, results):
    """Create detailed HTML results for each question"""
    if not questions or not results:
        return "<p>No detailed results available</p>"

    html_parts = ["""
    <style>
        .question-card {
            background: white;
            border-radius: 12px;
            padding: 20px;
            margin-bottom: 20px;
            box-shadow: 0 2px 8px rgba(0,0,0,0.1);
            border-left: 5px solid #667eea;
        }
        .question-header {
            display: flex;
            justify-content: space-between;
            align-items: center;
            margin-bottom: 15px;
        }
        .question-number {
            background: linear-gradient(135deg, #667eea, #764ba2);
            color: white;
            padding: 6px 12px;
            border-radius: 20px;
            font-weight: bold;
            font-size: 14px;
        }
        .question-text {
            font-weight: 600;
            font-size: 16px;
            margin: 15px 0;
            color: #2d3748;
        }
        .choices {
            background: #f8fafc;
            border-radius: 8px;
            padding: 15px;
            margin: 10px 0;
        }
        .choice {
            margin: 8px 0;
            color: #4a5568;
        }
        .correct-answer {
            background: linear-gradient(135deg, #c6f6d5, #9ae6b4);
            border-left: 4px solid #48bb78;
            border-radius: 6px;
            padding: 12px;
            margin: 10px 0;
            font-weight: 600;
            color: #22543d;
        }
        .model-results {
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
            gap: 12px;
            margin-top: 15px;
        }
        .model-result {
            padding: 12px;
            border-radius: 8px;
            text-align: center;
            font-weight: 600;
            transition: transform 0.2s ease;
        }
        .model-result:hover {
            transform: scale(1.02);
        }
        .result-correct {
            background: linear-gradient(135deg, #c6f6d5, #9ae6b4);
            color: #22543d;
            border: 2px solid #48bb78;
        }
        .result-incorrect {
            background: linear-gradient(135deg, #fed7d7, #fca5a5);
            color: #742a2a;
            border: 2px solid #e53e3e;
        }
        .result-error {
            background: linear-gradient(135deg, #fbb6ce, #f687b3);
            color: #744210;
            border: 2px solid #d69e2e;
        }
        .raw-response {
            font-size: 10px;
            margin-top: 4px;
            opacity: 0.7;
            font-family: monospace;
        }
    </style>
    """]

    for q_idx, question in enumerate(questions):
        html_parts.append(f"""
        <div class="question-card">
            <div class="question-header">
                <span class="question-number">Q{q_idx + 1}</span>
            </div>
            <div class="question-text">{question['question']}</div>
            <div class="choices">
                <strong>Choices:</strong><br>
                {' | '.join(f'{chr(65+i)}) {choice}' for i, choice in enumerate(question['choices']))}
            </div>
            <div class="correct-answer">
                <strong>✓ Correct Answer:</strong> {question['correct_answer']}
            </div>
            <div class="model-results">
        """)

        # Add results for each model
        for model, model_results in results.items():
            if q_idx < len(model_results):
                result = model_results[q_idx]
                if 'error' in result:
                    html_parts.append(f"""
                    <div class="model-result result-error">
                        <div>⚠️ {model}</div>
                        <div style="font-size: 12px; margin-top: 4px;">
                            Error occurred
                        </div>
                        <div class="raw-response">{result.get('raw_response', 'Unknown error')}</div>
                    </div>
                    """)
                else:
                    result_class = 'result-correct' if result.get('correct', False) else 'result-incorrect'
                    icon = '✅' if result.get('correct', False) else '❌'
                    html_parts.append(f"""
                    <div class="model-result {result_class}">
                        <div>{icon} {model}</div>
                        <div style="font-size: 12px; margin-top: 4px;">
                            "{result.get('predicted', 'No prediction')}"
                        </div>
                        <div class="raw-response">Raw: "{result.get('raw_response', '')}"</div>
                    </div>
                    """)

        html_parts.append("""
            </div>
        </div>
        """)

    return "".join(html_parts)

def create_accuracy_chart(summary_stats):
    """Create accuracy comparison chart"""
    if not summary_stats:
        return None

    models = []
    accuracies = []
    for model, stats in summary_stats.items():
        if 'error' not in stats:
            models.append(model)
            accuracies.append(stats['accuracy'] * 100)
    if not models:
        return None

    fig = go.Figure(data=[
        go.Bar(
            x=models,
            y=accuracies,
            marker_color='lightblue',
            text=[f'{acc:.1f}%' for acc in accuracies],
            textposition='auto',
        )
    ])
    fig.update_layout(
        title="Model Accuracy Comparison",
        xaxis_title="Models",
        yaxis_title="Accuracy (%)",
        template="plotly_white",
        showlegend=False
    )
    return fig

def create_confidence_chart(results):
    """Create confidence distribution chart"""
    if not results:
        return None

    data = []
    for model, model_results in results.items():
        for result in model_results:
            if 'error' not in result and 'confidence' in result:
                data.append({
                    'Model': model,
                    'Confidence': result['confidence'],
                    'Correct': 'Correct' if result.get('correct', False) else 'Incorrect'
                })
    if not data:
        return None

    df = pd.DataFrame(data)
    fig = px.box(
        df,
        x='Model',
        y='Confidence',
        color='Correct',
        title="Confidence Distribution by Model and Correctness",
        template="plotly_white"
    )
    return fig

# Sample datasets for quick testing
SAMPLE_DATASETS = {
    "Custom (enter below)": "",
    "LP": """Question,Correct Answer,Choice1,Choice2,Choice3
In which country is Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch located?,Wales,Germany,France,Scotland
In which country is Llanfair pwllgwyngyll located?,Wales,Germany,France,Scotland
In which country is Llanfair PG located?,Wales,Germany,France,Scotland""",
    "Simple Math": """Question,Correct Answer,Choice1,Choice2,Choice3
What is 2+2?,4,3,2,5
What is 5*3?,15,12,16,18
What is 10-7?,3,7,4,2
What is 8/2?,4,3,2,5""",
    "World Capitals": """Question,Correct Answer,Choice1,Choice2,Choice3
What is the capital of France?,Paris,London,Berlin,Rome
What is the capital of Japan?,Tokyo,Seoul,Beijing,Bangkok
What is the capital of Brazil?,Brasília,Rio de Janeiro,São Paulo,Salvador
What is the capital of Australia?,Canberra,Sydney,Melbourne,Perth""",
    "Science Quiz": """Question,Correct Answer,Choice1,Choice2,Choice3
What is the chemical symbol for gold?,Au,Ag,Ca,K
Which planet is closest to the Sun?,Mercury,Venus,Earth,Mars
What is the speed of light?,299792458 m/s,300000000 m/s,2992458 m/s,299000000 m/s
What gas do plants absorb from the atmosphere?,Carbon dioxide,Oxygen,Nitrogen,Hydrogen""",
}

# Custom CSS
css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
}
.sample-text {
    font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
    font-size: 12px;
}
"""

# Create Gradio interface
with gr.Blocks(title="🤖 Model Performance Comparison", theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("""
# 🤖 Model Performance Comparison Tool

Compare LLM performance on multiple-choice questions using Hugging Face models.

**Format**: Each line should have: `Question,Correct Answer,Choice1,Choice2,Choice3`

💡 **Features**:
- Model evaluation using HuggingFace transformers
- Support for custom models via HF model paths
- Detailed question-by-question results
- Performance charts and statistics
""")

    with gr.Row():
        with gr.Column(scale=2):
            # Sample dataset selector
            sample_selector = gr.Dropdown(
                choices=list(SAMPLE_DATASETS.keys()),
                value="Custom (enter below)",
                label="Choose sample dataset or enter your own",
                interactive=True
            )
            # Dataset input
            dataset_input = gr.Textbox(
                label="Dataset (CSV/TSV format)",
                placeholder="""Enter your dataset here...
Example format:
Question,Correct Answer,Choice1,Choice2,Choice3
What is 2+2?,4,3,2,5
What is the capital of France?,Paris,London,Berlin,Rome""",
                lines=8,
                max_lines=15
            )
            gr.Markdown("""
**Format Requirements**:
- First line: header row (it is always skipped, so leave it empty if your data has no header)
- Each data line: Question, Correct Answer, Choice1, Choice2, Choice3
- Use commas or tabs as separators
""")
        with gr.Column(scale=1):
            # Model selection
            with gr.Tabs():
                with gr.TabItem("🤖 Predefined Models"):
                    predefined_selector = gr.CheckboxGroup(
                        choices=PREDEFINED_MODELS,
                        value=[PREDEFINED_MODELS[0]],
                        label="Select from popular models",
                        interactive=True
                    )
                with gr.TabItem("➕ Custom Models"):
                    custom_models_input = gr.Textbox(
                        label="Custom HuggingFace Model Paths",
                        placeholder="""Enter HuggingFace model paths (one per line):
microsoft/DialoGPT-medium
bigscience/bloom-560m""",
                        lines=5,
                        info="Add any HuggingFace model path. One model per line.",
                    )
                    gr.Markdown("""
**Examples of valid model paths**:
- `microsoft/DialoGPT-medium`
- `bigscience/bloom-560m`
- `facebook/opt-350m`
- Your own fine-tuned models!
""")

            # Evaluate button
            evaluate_btn = gr.Button(
                "⚡ Run Evaluation",
                variant="primary",
                scale=1
            )
            gr.Markdown("""
**⚠️ Note**:
- Larger models need more memory; this Space currently runs on CPU only
- The first run downloads the models, which may take a while
- Models are cached for subsequent runs
""")

    # Results section
    with gr.Column(visible=True) as results_section:
        gr.Markdown("## 📊 Results")
        summary_output = gr.Markdown(
            value="Results will appear here...",
            label="Performance Summary"
        )
        with gr.Row():
            accuracy_plot = gr.Plot(label="Accuracy Comparison")
            confidence_plot = gr.Plot(label="Confidence Analysis")
        detailed_results = gr.HTML(
            value="<p>Detailed results will appear here...</p>",
            label="Detailed Question-by-Question Results"
        )

    # Event handlers
    def update_dataset_from_sample(sample_name):
        if sample_name in SAMPLE_DATASETS:
            return gr.update(value=SAMPLE_DATASETS[sample_name])
        return gr.update()

    sample_selector.change(
        fn=update_dataset_from_sample,
        inputs=sample_selector,
        outputs=dataset_input
    )

    evaluate_btn.click(
        fn=run_evaluation,
        inputs=[dataset_input, predefined_selector, custom_models_input],
        outputs=[summary_output, detailed_results, accuracy_plot, confidence_plot, results_section]
    )

    gr.Markdown("""
---
### About Model Evaluation

This tool loads and runs HuggingFace models for evaluation:

**🏗️ How it works**:
- Downloads models from the HuggingFace Hub
- Formats each question as a prompt for every model
- Runs likelihood-based evaluation: each answer choice is scored by its log-probability and the most likely one is selected

**⚡ Performance Tips**:
- Use smaller models for testing
- Larger models (7B+) require significant GPU memory
- Models are cached after first load

**🔧 Supported Models**:
- Any HuggingFace autoregressive language model
- Both instruction-tuned and base models
- Custom fine-tuned models via HF paths
""")

if __name__ == "__main__":
    demo.launch()