# (Removed non-code residue scraped from the HuggingFace Spaces page header:
#  "Spaces: / Sleeping / Sleeping" — not part of the program.)
"""Gradio app: inspect a HuggingFace model's safetensors weights and config."""

import json
import logging
import os
import tempfile
import traceback

import gradio as gr
import requests
from safetensors.torch import load_file

from model_loader import get_top_layers, load_model_summary, load_config

# Module-level logger, configured once at import time.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def inspect_model(model_id, config_file=None):
    """Download a model's safetensors weights from HuggingFace and summarize them.

    Args:
        model_id: "username/modelname", or "username/modelname/file.safetensors"
            to point at a non-default weights file.
        config_file: optional uploaded config.json content — raw ``bytes`` (as
            delivered by ``gr.File(type="binary")``) or a file-like object.
            When ``None``, ``config.json`` is fetched from the hub instead.

    Returns:
        ``(model_summary, config_str)`` — two human-readable strings. On any
        failure the first carries the error message plus traceback and the
        second is "No config loaded.".
    """
    logger.info(f"Processing model ID: {model_id}")
    if not model_id or '/' not in model_id:
        return "Please provide a valid model ID in the format username/modelname", "No config loaded."

    username, modelname = model_id.split('/', 1)
    logger.info(f"Username: {username}, Model name: {modelname}")

    try:
        # An optional third path segment selects a custom weights file.
        model_filename = "model.safetensors"
        if "/" in modelname:
            parts = modelname.split("/")
            modelname = parts[0]
            if len(parts) > 1 and parts[1].strip():
                model_filename = parts[1]

        model_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/{model_filename}"
        logger.info(f"Attempting to download model from: {model_url}")
        response = requests.get(model_url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        logger.info(f"Model file size: {total_size/1024/1024:.2f} MB")

        # Stream the weights to a temp file so huge models never sit in RAM.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".safetensors") as tmp:
            if total_size > 0:
                downloaded = 0
                progress_step = 100 * 1024 * 1024  # log roughly every 100 MB
                next_log = progress_step
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        tmp.write(chunk)
                        downloaded += len(chunk)
                        # Threshold crossing instead of exact modulo: the old
                        # `downloaded % 100MB == 0` check silently never fired
                        # unless chunk sizes divided 100 MB exactly.
                        if downloaded >= next_log:
                            logger.info(f"Downloaded {downloaded/1024/1024:.2f} MB / {total_size/1024/1024:.2f} MB")
                            next_log += progress_step
            else:
                # No content-length header: fall back to a single buffered read.
                tmp.write(response.content)
            model_path = tmp.name

        logger.info(f"Model downloaded to temporary file: {model_path}")
        try:
            logger.info("Loading model summary...")
            summary = load_model_summary(model_path)
            logger.info("Loading state dictionary... (This may take time for large models)")
            state_dict = load_file(model_path)
            logger.info("Analyzing top layers...")
            top_layers = get_top_layers(state_dict, summary["total_params"])
        finally:
            # Always remove the (possibly multi-GB) temp file, even when
            # summary/load/analysis raises — previously it leaked on error.
            logger.info(f"Cleaning up temporary file: {model_path}")
            os.unlink(model_path)

        top_layers_str = "\n".join(
            f"{layer['name']}: shape={layer['shape']}, params={layer['params']:,} ({layer['percent']}%)"
            for layer in top_layers
        )

        config_data = {}
        if config_file is not None:
            logger.info("Processing uploaded config file")
            # gr.File(type="binary") hands the handler raw bytes; some older
            # gradio versions hand a file-like object — accept either.
            raw = config_file if isinstance(config_file, (bytes, bytearray)) else config_file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp_cfg:
                tmp_cfg.write(raw)
                config_path = tmp_cfg.name
            try:
                logger.info(f"Loading config from uploaded file: {config_path}")
                config_data = load_config(config_path)
            finally:
                os.unlink(config_path)
        else:
            config_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/config.json"
            logger.info(f"Attempting to download config from: {config_url}")
            try:
                config_response = requests.get(config_url)
                config_response.raise_for_status()
                config_data = json.loads(config_response.content)
                logger.info("Config file downloaded and parsed successfully")
            except Exception as e:
                # Best-effort: a missing/unparsable hub config is reported,
                # not fatal — the weights summary is still useful on its own.
                logger.warning(f"Could not download or parse config file: {str(e)}")

        config_str = "\n".join(f"{k}: {v}" for k, v in config_data.items()) if config_data else "No config loaded."

        model_summary = (
            f" Total tensors: {summary['num_tensors']}\n"
            f" Total parameters: {summary['total_params']:,}\n\n"
            f" Top Layers:\n{top_layers_str}"
        )
        logger.info("Model inspection completed successfully")
        return model_summary, config_str
    except Exception as e:
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg, "No config loaded."
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="Model Inspector") as demo:
    gr.Markdown("# Model Inspector")
    gr.Markdown("Enter a HuggingFace model ID in the format username/modelname to analyze its structure, parameter count, and configuration.")
    gr.Markdown("You can specify a custom safetensors file by using username/modelname/filename.safetensors")

    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(
                label="Model ID from HuggingFace",
                placeholder="username/modelname",
                lines=1
            )
            # type="binary" delivers the upload as raw bytes to the handler.
            config_file = gr.File(
                label="Upload config.json (optional)",
                type="binary"
            )
            submit_btn = gr.Button("Analyze Model", variant="primary")
            status = gr.Markdown("Ready. Enter a model ID and click 'Analyze Model'")
        with gr.Column():
            model_summary = gr.Textbox(label="Model Summary", lines=15)
            config_output = gr.Textbox(label="Config", lines=10)

    def on_submit(model_id, config_file):
        """Generator handler: the first yield flips the status to 'Processing'
        immediately; the second delivers the analysis (or error) to outputs."""
        yield "Processing... This may take some time for large models.", None, None
        try:
            summary, config = inspect_model(model_id, config_file)
            yield "Analysis complete!", summary, config
        except Exception as e:
            # inspect_model catches its own errors; this is a last resort for
            # anything raised before/around it.
            error_msg = f"Error during analysis: {str(e)}"
            yield f"❌ {error_msg}", error_msg, "No config loaded."

    submit_btn.click(
        fn=on_submit,
        inputs=[model_id, config_file],
        outputs=[status, model_summary, config_output],
        show_progress=True
    )

demo.launch()