import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
import os
from huggingface_hub import login
# Login to HF (HF_TOKEN is expected as a Space secret; only needed if the model repo is private)
login(token=os.environ.get("HF_TOKEN"))
# Load model and tokenizer
model_id = "kristianfischerai12345/fischgpt-sft"
print("Loading FischGPT model...")
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
model.eval()
print("Model loaded successfully!")
# API: Generate response
def generate_api(user_message, temperature=0.8, max_length=150, top_p=0.9):
    if not user_message or not user_message.strip():
        return {
            "error": "Empty message",
            "response": None,
            "metadata": None
        }
    try:
        # Chat-style prompt format used for SFT
        prompt = f"<|user|>{user_message.strip()}<|assistant|>"
        inputs = tokenizer.encode(prompt, return_tensors='pt')
        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=max_length,
                temperature=float(temperature),
                top_p=float(top_p),
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                attention_mask=torch.ones_like(inputs)
            )
        generation_time = time.time() - start_time
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # If <|assistant|> is registered as a special token it is stripped by
        # skip_special_tokens, so guard the split instead of indexing blindly
        parts = full_text.split("<|assistant|>", 1)
        response = parts[1].strip() if len(parts) > 1 else full_text.strip()
        input_tokens = len(inputs[0])
        output_tokens = len(outputs[0])
        new_tokens = output_tokens - input_tokens
        tokens_per_sec = new_tokens / generation_time if generation_time > 0 else 0
        return {
            "error": None,
            "response": response,
            "metadata": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "new_tokens": new_tokens,
                "generation_time": round(generation_time, 3),
                "tokens_per_second": round(tokens_per_sec, 1),
                "model": "FischGPT-SFT",
                "parameters": {
                    "temperature": temperature,
                    "max_length": max_length,
                    "top_p": top_p
                }
            }
        }
    except Exception as e:
        return {
            "error": str(e),
            "response": None,
            "metadata": None
        }
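# Example (a sketch; the exact text and numbers depend on sampling): calling
# the function directly, e.g.
#   generate_api("Explain attention in one sentence.")
# returns a dict shaped like
#   {"error": None,
#    "response": "Attention lets the model weigh ...",
#    "metadata": {"input_tokens": 12, "output_tokens": 58, "new_tokens": 46,
#                 "generation_time": 1.234, "tokens_per_second": 37.3, ...}}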
# API: Wake-up ping
def wake_up():
    return {"status": "awake"}
# Gradio Blocks app
with gr.Blocks(title="FischGPT API") as app:
    gr.Markdown("### FischGPT API is running.")
    # Register endpoints
    gr.Interface(
        fn=generate_api,
        inputs=[
            gr.Textbox(label="User Message"),
            gr.Slider(0.1, 2.0, 0.8, label="Temperature"),
            gr.Slider(50, 300, 150, label="Max Length"),
            gr.Slider(0.1, 1.0, 0.9, label="Top-p")
        ],
        outputs=gr.JSON(label="Response"),
        api_name="predict"
    )
    gr.Interface(
        fn=wake_up,
        inputs=[],
        outputs=gr.JSON(label="Status"),
        api_name="wake-up"
    )
# Launch app with the API enabled
app.queue(api_open=True).launch()
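# Example client usage (a minimal sketch; assumes the Space is public, the
# gradio_client package is installed, and "user/fischgpt-api" stands in for
# the actual Space id):
#
#   from gradio_client import Client
#
#   client = Client("user/fischgpt-api")   # hypothetical Space id
#   client.predict(api_name="/wake-up")    # ping a sleeping Space first
#   result = client.predict(
#       "What is FischGPT?",  # user_message
#       0.8,                  # temperature
#       150,                  # max_length
#       0.9,                  # top_p
#       api_name="/predict",
#   )
#   print(result["response"])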