import gradio as gr
from safetensors.torch import load_file
from model_loader import get_top_layers, load_model_summary, load_config
import tempfile
import os
import requests
import json
import logging
import traceback

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def inspect_model(model_id, config_file=None):
    logger.info(f"Processing model ID: {model_id}")
    if not model_id or '/' not in model_id:
        return "Please provide a valid model ID in the format username/modelname", "No config loaded."

    username, modelname = model_id.split('/', 1)
    logger.info(f"Username: {username}, Model name: {modelname}")

    try:
        # Allow an optional explicit filename: username/modelname/file.safetensors
        model_filename = "model.safetensors"
        if "/" in modelname:
            parts = modelname.split("/")
            modelname = parts[0]
            if len(parts) > 1 and parts[1].strip():
                model_filename = parts[1]

        model_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/{model_filename}"
        logger.info(f"Attempting to download model from: {model_url}")
        response = requests.get(model_url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        logger.info(f"Model file size: {total_size/1024/1024:.2f} MB")

        with tempfile.NamedTemporaryFile(delete=False, suffix=".safetensors") as tmp:
            if total_size > 0:
                downloaded = 0
                next_log = 100 * 1024 * 1024
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        tmp.write(chunk)
                        downloaded += len(chunk)
                        # Log progress roughly every 100 MB. iter_content()
                        # does not guarantee exact chunk sizes, so an exact
                        # modulo check might never fire.
                        if downloaded >= next_log:
                            logger.info(f"Downloaded {downloaded/1024/1024:.2f} MB / {total_size/1024/1024:.2f} MB")
                            next_log += 100 * 1024 * 1024
            else:
                tmp.write(response.content)
            model_path = tmp.name
        logger.info(f"Model downloaded to temporary file: {model_path}")

        try:
            logger.info("Loading model summary...")
            summary = load_model_summary(model_path)
            logger.info("Loading state dictionary... (this may take time for large models)")
            state_dict = load_file(model_path)
            logger.info("Analyzing top layers...")
            top_layers = get_top_layers(state_dict, summary["total_params"])
        finally:
            # Clean up the temporary file even if loading or analysis fails.
            logger.info(f"Cleaning up temporary file: {model_path}")
            os.unlink(model_path)

        top_layers_str = "\n".join([
            f"{layer['name']}: shape={layer['shape']}, params={layer['params']:,} ({layer['percent']}%)"
            for layer in top_layers
        ])

        config_data = {}
        if config_file is not None:
            logger.info("Processing uploaded config file")
            # gr.File(type="binary") passes the upload as raw bytes, not a
            # file object, so write it directly rather than calling .read().
            with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp_cfg:
                tmp_cfg.write(config_file)
                config_path = tmp_cfg.name
            logger.info(f"Loading config from uploaded file: {config_path}")
            config_data = load_config(config_path)
            os.unlink(config_path)
        else:
            config_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/config.json"
            logger.info(f"Attempting to download config from: {config_url}")
            try:
                config_response = requests.get(config_url)
                config_response.raise_for_status()
                config_data = json.loads(config_response.content)
                logger.info("Config file downloaded and parsed successfully")
            except Exception as e:
                logger.warning(f"Could not download or parse config file: {str(e)}")

        config_str = "\n".join([f"{k}: {v}" for k, v in config_data.items()]) if config_data else "No config loaded."
        model_summary = (
            f"Total tensors: {summary['num_tensors']}\n"
            f"Total parameters: {summary['total_params']:,}\n\n"
            f"Top Layers:\n{top_layers_str}"
        )
        logger.info("Model inspection completed successfully")
        return model_summary, config_str
    except Exception as e:
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg, "No config loaded."


with gr.Blocks(title="Model Inspector") as demo:
    gr.Markdown("# Model Inspector")
    gr.Markdown("Enter a HuggingFace model ID in the format username/modelname to analyze its structure, parameter count, and configuration.")
    gr.Markdown("You can specify a custom safetensors file by using username/modelname/filename.safetensors")

    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(
                label="Model ID from HuggingFace",
                placeholder="username/modelname",
                lines=1
            )
            config_file = gr.File(
                label="Upload config.json (optional)",
                type="binary"
            )
            submit_btn = gr.Button("Analyze Model", variant="primary")
            status = gr.Markdown("Ready. Enter a model ID and click 'Analyze Model'")
        with gr.Column():
            model_summary = gr.Textbox(label="Model Summary", lines=15)
            config_output = gr.Textbox(label="Config", lines=10)

    def on_submit(model_id, config_file):
        # Generator handler: yield an interim status first so the UI updates
        # while the (potentially slow) analysis runs.
        yield "Processing... This may take some time for large models.", None, None
        try:
            summary, config = inspect_model(model_id, config_file)
            yield "Analysis complete!", summary, config
        except Exception as e:
            error_msg = f"Error during analysis: {str(e)}"
            yield f"❌ {error_msg}", error_msg, "No config loaded."

    submit_btn.click(
        fn=on_submit,
        inputs=[model_id, config_file],
        outputs=[status, model_summary, config_output],
        show_progress=True
    )

demo.launch()
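
# ---------------------------------------------------------------------------
# Note: `model_loader` is a local module that is not shown in this file. The
# interface below is inferred from how its functions are called above; treat
# it as an assumption about the expected contract, not the actual source.
#
#   load_model_summary(path) -> {"num_tensors": int, "total_params": int}
#   get_top_layers(state_dict, total_params) -> [
#       {"name": str, "shape": tuple, "params": int, "percent": float}, ...
#   ]
#   load_config(path) -> dict parsed from a config.json file
# ---------------------------------------------------------------------------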