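"""Model Inspector: a Gradio app that downloads a safetensors checkpoint from
the Hugging Face Hub and reports tensor counts, total parameters, the largest
layers, and the model config."""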
import gradio as gr
from safetensors.torch import load_file
from model_loader import get_top_layers, load_model_summary, load_config
import tempfile
import os
import requests
import json
import logging
import traceback
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def inspect_model(model_id, config_file=None):
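    """Download a safetensors checkpoint (and config.json) for a Hub model ID
    and return (model_summary, config_str) display strings."""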
logger.info(f"Processing model ID: {model_id}")
if not model_id or '/' not in model_id:
return "Please provide a valid model ID in the format username/modelname", "No config loaded."
username, modelname = model_id.split('/', 1)
logger.info(f"Username: {username}, Model name: {modelname}")
model_summary = "Processing..."
config_str = "No config loaded."
    try:
        # Allow "username/modelname/filename.safetensors" to target a specific file.
        model_filename = "model.safetensors"
        if "/" in modelname:
            # maxsplit=1 keeps any nested path (e.g. a file inside a subfolder) intact
            parts = modelname.split("/", 1)
            modelname = parts[0]
            if len(parts) > 1 and parts[1].strip():
                model_filename = parts[1]
        model_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/{model_filename}"
        logger.info(f"Attempting to download model from: {model_url}")
        response = requests.get(model_url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        logger.info(f"Model file size: {total_size/1024/1024:.2f} MB")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".safetensors") as tmp:
            if total_size > 0:
                downloaded = 0
                next_log = 100 * 1024 * 1024  # log progress every 100 MB
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        tmp.write(chunk)
                        downloaded += len(chunk)
                        # Threshold check rather than a modulo: a modulo only fires
                        # when a chunk boundary lands exactly on the interval.
                        if downloaded >= next_log:
                            logger.info(f"Downloaded {downloaded/1024/1024:.2f} MB / {total_size/1024/1024:.2f} MB")
                            next_log += 100 * 1024 * 1024
            else:
                tmp.write(response.content)
            model_path = tmp.name
        logger.info(f"Model downloaded to temporary file: {model_path}")
logger.info("Loading model summary...")
summary = load_model_summary(model_path)
logger.info(f"Loading state dictionary... (This may take time for large models)")
state_dict = load_file(model_path)
logger.info("Analyzing top layers...")
top_layers = get_top_layers(state_dict, summary["total_params"])
top_layers_str = "\n".join([
f"{layer['name']}: shape={layer['shape']}, params={layer['params']:,} ({layer['percent']}%)"
for layer in top_layers
])
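        # Resolve the config: prefer an uploaded config.json, else fetch it from the repo.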
        config_data = {}
        if config_file is not None:
            logger.info("Processing uploaded config file")
            # gr.File(type="binary") delivers the upload as raw bytes, so write it directly.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp_cfg:
                tmp_cfg.write(config_file)
                config_path = tmp_cfg.name
            logger.info(f"Loading config from uploaded file: {config_path}")
            config_data = load_config(config_path)
            os.unlink(config_path)
        else:
            config_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/config.json"
            logger.info(f"Attempting to download config from: {config_url}")
            try:
                config_response = requests.get(config_url)
                config_response.raise_for_status()
                config_data = json.loads(config_response.content)
                logger.info("Config file downloaded and parsed successfully")
            except Exception as e:
                logger.warning(f"Could not download or parse config file: {str(e)}")
        config_str = "\n".join([f"{k}: {v}" for k, v in config_data.items()]) if config_data else "No config loaded."
        # Clean up the temporary weights file
        logger.info(f"Cleaning up temporary file: {model_path}")
        os.unlink(model_path)
        model_summary = (
            f" Total tensors: {summary['num_tensors']}\n"
            f" Total parameters: {summary['total_params']:,}\n\n"
            f" Top Layers:\n{top_layers_str}"
        )
        logger.info("Model inspection completed successfully")
        return model_summary, config_str
    except Exception as e:
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg, "No config loaded."
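
# Build the Gradio UI: model ID input, optional config upload, and result panels.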
with gr.Blocks(title="Model Inspector") as demo:
    gr.Markdown("# Model Inspector")
    gr.Markdown("Enter a HuggingFace model ID in the format username/modelname to analyze its structure, parameter count, and configuration.")
    gr.Markdown("You can specify a custom safetensors file by using username/modelname/filename.safetensors.")
    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(
                label="Model ID from HuggingFace",
                placeholder="username/modelname",
                lines=1
            )
            config_file = gr.File(
                label="Upload config.json (optional)",
                type="binary"
            )
            submit_btn = gr.Button("Analyze Model", variant="primary")
            status = gr.Markdown("Ready. Enter a model ID and click 'Analyze Model'")
        with gr.Column():
            model_summary = gr.Textbox(label="Model Summary", lines=15)
            config_output = gr.Textbox(label="Config", lines=10)
    def update_status(text):
        return text

    def on_submit(model_id, config_file):
        status_update = update_status("Processing... This may take some time for large models.")
        yield status_update, None, None
        try:
            summary, config = inspect_model(model_id, config_file)
            status_update = update_status("Analysis complete!")
            yield status_update, summary, config
        except Exception as e:
            error_msg = f"Error during analysis: {str(e)}"
            status_update = update_status(f"❌ {error_msg}")
            yield status_update, error_msg, "No config loaded."
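
    # on_submit is a generator, so each yielded tuple streams an update to
    # (status, model_summary, config_output).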
    submit_btn.click(
        fn=on_submit,
        inputs=[model_id, config_file],
        outputs=[status, model_summary, config_output],
        show_progress=True
    )
demo.launch()