import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
import os
from huggingface_hub import login
# Login to HF
login(token=os.environ.get("HF_TOKEN"))
# Load model
model_id = "kristianfischerai12345/fischgpt-sft"
print("Loading FischGPT model...")
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.eval()
print("Model loaded successfully!")
# API: Generate response
def generate_api(user_message, temperature=0.8, max_length=150, top_p=0.9):
    if not user_message or not user_message.strip():
        return {
            "error": "Empty message",
            "response": None,
            "metadata": None
        }
    try:
        # Wrap the message in the chat template the SFT model was trained on
        prompt = f"<|user|>{user_message.strip()}<|assistant|>"
        inputs = tokenizer.encode(prompt, return_tensors='pt')
        start_time = time.time()
        with torch.no_grad():
            # Note: max_length counts prompt tokens plus newly generated tokens
            outputs = model.generate(
                inputs,
                max_length=max_length,
                temperature=float(temperature),
                top_p=float(top_p),
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                attention_mask=torch.ones_like(inputs)
            )
        generation_time = time.time() - start_time
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the assistant's reply; fall back to the full text in case
        # the <|assistant|> delimiter was stripped as a special token during decoding
        parts = full_text.split("<|assistant|>", 1)
        response = parts[1].strip() if len(parts) > 1 else full_text.strip()
        input_tokens = len(inputs[0])
        output_tokens = len(outputs[0])
        new_tokens = output_tokens - input_tokens
        tokens_per_sec = new_tokens / generation_time if generation_time > 0 else 0
        return {
            "error": None,
            "response": response,
            "metadata": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "new_tokens": new_tokens,
                "generation_time": round(generation_time, 3),
                "tokens_per_second": round(tokens_per_sec, 1),
                "model": "FischGPT-SFT",
                "parameters": {
                    "temperature": temperature,
                    "max_length": max_length,
                    "top_p": top_p
                }
            }
        }
    except Exception as e:
        return {
            "error": str(e),
            "response": None,
            "metadata": None
        }
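# Quick local sanity check (a sketch; the example message is hypothetical):
#
#   result = generate_api("Hello, who are you?")
#   print(result["error"] or result["response"])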
# API: Wake-up ping (lets clients wake a sleeping Space before generating)
def wake_up():
    return {"status": "awake"}

# Gradio Blocks app
with gr.Blocks(title="FischGPT API") as app:
    gr.Markdown("### FischGPT API is running.")

    # Register endpoints
    gr.Interface(
        fn=generate_api,
        inputs=[
            gr.Textbox(label="User Message"),
            gr.Slider(0.1, 2.0, 0.8, label="Temperature"),
            gr.Slider(50, 300, 150, label="Max Length"),
            gr.Slider(0.1, 1.0, 0.9, label="Top-p")
        ],
        outputs=gr.JSON(label="Response"),
        api_name="predict"
    )
    gr.Interface(
        fn=wake_up,
        inputs=[],
        outputs=gr.JSON(label="Status"),
        api_name="wake-up"
    )

# Launch app with the API open so the named endpoints stay callable
app.queue(api_open=True).launch()
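# Example client usage (a minimal sketch; the Space id below is an assumption
# based on the model repo owner, not taken from this file):
#
#   from gradio_client import Client
#
#   client = Client("kristianfischerai12345/fischgpt-api")  # hypothetical Space id
#   client.predict(api_name="/wake-up")  # wake the Space if it is sleeping
#   result = client.predict(
#       "What is FischGPT?",  # user_message
#       0.8,                  # temperature
#       150,                  # max_length
#       0.9,                  # top_p
#       api_name="/predict"
#   )
#   print(result)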