# fischgpt-api / app.py
# Author: kristianfischerai12345
# Commit: eace646 ("Update app.py")
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
import os
from huggingface_hub import login
# Authenticate with the Hugging Face Hub. Only call login() when a token is
# actually configured: login(token=None) falls back to an interactive prompt
# (or fails outright in a headless Space), so skip gracefully instead.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Load the fine-tuned GPT-2 model and tokenizer once at startup.
model_id = "kristianfischerai12345/fischgpt-sft"
print("Loading FischGPT model...")
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
# GPT-2 ships without a pad token; reuse EOS so generate() can pad batches.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Inference-only service: disable dropout etc. via eval mode.
model.eval()
print("Model loaded successfully!")
# API: Generate response
def generate_api(user_message, temperature=0.8, max_length=150, top_p=0.9):
    """Generate a FischGPT reply for a single user message.

    Parameters
    ----------
    user_message : str
        The user's prompt; must be non-empty after stripping.
    temperature : float
        Sampling temperature forwarded to ``model.generate``.
    max_length : int
        Maximum total sequence length (prompt tokens + generated tokens).
    top_p : float
        Nucleus-sampling probability mass.

    Returns
    -------
    dict
        ``{"error", "response", "metadata"}``: on success ``error`` is None
        and ``metadata`` carries token counts and timing; on failure
        ``error`` holds a message and the other fields are None.
    """
    if not user_message or not user_message.strip():
        return {
            "error": "Empty message",
            "response": None,
            "metadata": None
        }
    try:
        # Chat-style prompt format the SFT model was trained on.
        prompt = f"<|user|>{user_message.strip()}<|assistant|>"
        inputs = tokenizer.encode(prompt, return_tensors='pt')

        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=max_length,
                temperature=float(temperature),
                top_p=float(top_p),
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                attention_mask=torch.ones_like(inputs)
            )
        generation_time = time.time() - start_time

        input_tokens = inputs.shape[1]
        # BUGFIX: the previous code decoded the full sequence with
        # skip_special_tokens=True and then split on "<|assistant|>". When
        # the tokenizer registers that marker as a special token it is
        # stripped during decoding, so the split produced a single element
        # and indexing [1] raised IndexError. Decoding only the newly
        # generated slice avoids depending on the marker surviving decode.
        response = tokenizer.decode(
            outputs[0][input_tokens:], skip_special_tokens=True
        ).strip()

        output_tokens = len(outputs[0])
        new_tokens = output_tokens - input_tokens
        # Guard against a zero-duration clock reading on very fast runs.
        tokens_per_sec = new_tokens / generation_time if generation_time > 0 else 0

        return {
            "error": None,
            "response": response,
            "metadata": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "new_tokens": new_tokens,
                "generation_time": round(generation_time, 3),
                "tokens_per_second": round(tokens_per_sec, 1),
                "model": "FischGPT-SFT",
                "parameters": {
                    "temperature": temperature,
                    "max_length": max_length,
                    "top_p": top_p
                }
            }
        }
    except Exception as e:
        # API boundary: report the failure to the caller rather than let
        # the exception take down the endpoint.
        return {
            "error": str(e),
            "response": None,
            "metadata": None
        }
# API: Wake-up ping
def wake_up():
    """Health-check endpoint: report that the Space is responsive."""
    status_payload = {"status": "awake"}
    return status_payload
# Gradio Blocks app: hosts a status banner and registers the two callables
# as named REST endpoints (/predict and /wake-up).
with gr.Blocks(title="FischGPT API") as app:
    gr.Markdown("### FischGPT API is running.")
    # Register endpoints
    # Main generation endpoint; the inputs double as a simple UI form.
    gr.Interface(
        fn=generate_api,
        inputs=[
            gr.Textbox(label="User Message"),
            gr.Slider(0.1, 2.0, 0.8, label="Temperature"),
            gr.Slider(50, 300, 150, label="Max Length"),
            gr.Slider(0.1, 1.0, 0.9, label="Top-p")
        ],
        outputs=gr.JSON(label="Response"),
        api_name="predict"
    )
    # Lightweight ping endpoint, presumably used by clients to wake a
    # sleeping Space before sending a real request.
    gr.Interface(
        fn=wake_up,
        inputs=[],
        outputs=gr.JSON(label="Status"),
        api_name="wake-up"
    )
# Launch app
# queue(api_open=True) keeps the named API endpoints callable over REST.
app.queue(api_open=True).launch()