import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; token is injected via env var
# (standard pattern for a private model on a Space).
login(token=os.environ.get("HF_TOKEN"))

# Load the model once at startup; both endpoints reuse these globals.
model_id = "kristianfischerai12345/fischgpt-sft"
print("Loading FischGPT model...")
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2Tokenizer.from_pretrained(model_id)

# GPT-2 ships without a pad token; alias it to EOS so generate() can pad.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model.eval()
print("Model loaded successfully!")

# Chat-format marker used both to build the prompt and to locate the reply.
ASSISTANT_MARKER = "<|assistant|>"


def generate_api(user_message, temperature=0.8, max_length=150, top_p=0.9):
    """Generate an assistant reply for *user_message*.

    Returns a JSON-serializable dict with keys "error", "response" and
    "metadata"; exactly one of "error"/"response" is non-None. Sampling is
    controlled by *temperature*, *top_p* and *max_length* (total sequence
    length in tokens, prompt included).
    """
    if not user_message or not user_message.strip():
        return {
            "error": "Empty message",
            "response": None,
            "metadata": None
        }

    try:
        prompt = f"<|user|>{user_message.strip()}{ASSISTANT_MARKER}"
        inputs = tokenizer.encode(prompt, return_tensors='pt')

        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=max_length,
                temperature=float(temperature),
                top_p=float(top_p),
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                attention_mask=torch.ones_like(inputs)
            )
        generation_time = time.time() - start_time

        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # BUGFIX: the previous `split(ASSISTANT_MARKER, 1)[1]` raised
        # IndexError whenever the decoded text no longer contained the
        # literal marker — e.g. if the chat tags are registered as special
        # tokens and stripped by skip_special_tokens=True — turning a good
        # generation into an error payload. Locate the marker defensively
        # and fall back to stripping the echoed prompt instead.
        marker_pos = full_text.find(ASSISTANT_MARKER)
        if marker_pos != -1:
            response = full_text[marker_pos + len(ASSISTANT_MARKER):].strip()
        else:
            user_text = user_message.strip()
            echo_pos = full_text.find(user_text)
            if echo_pos != -1:
                response = full_text[echo_pos + len(user_text):].strip()
            else:
                response = full_text.strip()

        input_tokens = len(inputs[0])
        output_tokens = len(outputs[0])
        new_tokens = output_tokens - input_tokens
        tokens_per_sec = new_tokens / generation_time if generation_time > 0 else 0

        return {
            "error": None,
            "response": response,
            "metadata": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "new_tokens": new_tokens,
                "generation_time": round(generation_time, 3),
                "tokens_per_second": round(tokens_per_sec, 1),
                "model": "FischGPT-SFT",
                "parameters": {
                    "temperature": temperature,
                    "max_length": max_length,
                    "top_p": top_p
                }
            }
        }
    except Exception as e:
        # Top-level API boundary: report the failure to the client rather
        # than surfacing a 500 from Gradio.
        return {
            "error": str(e),
            "response": None,
            "metadata": None
        }


def wake_up():
    """Lightweight health-check endpoint so clients can wake a sleeping Space."""
    return {"status": "awake"}
# Gradio Blocks app exposing the two functions as named API endpoints.
# BUGFIX: the `with` keyword had been mangled into the preceding comment,
# leaving a bare `gr.Blocks(...) as app:` — restored the context-manager
# statement and one-statement-per-line formatting.
with gr.Blocks(title="FischGPT API") as app:
    gr.Markdown("### FischGPT API is running.")

    # /predict — text generation endpoint with sampling controls.
    gr.Interface(
        fn=generate_api,
        inputs=[
            gr.Textbox(label="User Message"),
            gr.Slider(0.1, 2.0, 0.8, label="Temperature"),
            gr.Slider(50, 300, 150, label="Max Length"),
            gr.Slider(0.1, 1.0, 0.9, label="Top-p")
        ],
        outputs=gr.JSON(label="Response"),
        api_name="predict"
    )

    # /wake-up — health-check endpoint.
    gr.Interface(
        fn=wake_up,
        inputs=[],
        outputs=gr.JSON(label="Status"),
        api_name="wake-up"
    )

# Launch with the API schema exposed so programmatic clients can call it.
app.queue(api_open=True).launch()