import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
import os
from huggingface_hub import login

# Login to HF (skip when no token is configured, e.g. for a public model;
# login(token=None) would otherwise try an interactive prompt and fail)
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Load model
model_id = "kristianfischerai12345/fischgpt-sft"
print("Loading FischGPT model...")
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.eval()
print("Model loaded successfully!")

# API: Generate response
def generate_api(user_message, temperature=0.8, max_length=150, top_p=0.9):
    if not user_message or not user_message.strip():
        return {
            "error": "Empty message",
            "response": None,
            "metadata": None
        }
    try:
        prompt = f"<|user|>{user_message.strip()}<|assistant|>"
        inputs = tokenizer.encode(prompt, return_tensors='pt')
        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=max_length,  # total sequence length: prompt + generated tokens
                temperature=float(temperature),
                top_p=float(top_p),
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                attention_mask=torch.ones_like(inputs)
            )
        generation_time = time.time() - start_time
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The <|assistant|> marker survives decoding only if it is not a
        # registered special token; fall back to the full text if it was stripped.
        parts = full_text.split("<|assistant|>", 1)
        response = parts[1].strip() if len(parts) > 1 else full_text.strip()
        input_tokens = len(inputs[0])
        output_tokens = len(outputs[0])
        new_tokens = output_tokens - input_tokens
        tokens_per_sec = new_tokens / generation_time if generation_time > 0 else 0
        return {
            "error": None,
            "response": response,
            "metadata": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "new_tokens": new_tokens,
                "generation_time": round(generation_time, 3),
                "tokens_per_second": round(tokens_per_sec, 1),
                "model": "FischGPT-SFT",
                "parameters": {
                    "temperature": temperature,
                    "max_length": max_length,
                    "top_p": top_p
                }
            }
        }
    except Exception as e:
        return {
            "error": str(e),
            "response": None,
            "metadata": None
        }
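
# Shape of a successful generate_api return value, for reference
# (numbers are illustrative only, not real measurements):
#
#     {
#         "error": None,
#         "response": "...model reply...",
#         "metadata": {
#             "input_tokens": 12, "output_tokens": 58, "new_tokens": 46,
#             "generation_time": 1.234, "tokens_per_second": 37.3,
#             "model": "FischGPT-SFT",
#             "parameters": {"temperature": 0.8, "max_length": 150, "top_p": 0.9}
#         }
#     }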

# API: Wake-up ping
def wake_up():
    return {"status": "awake"}

# Gradio Blocks app
with gr.Blocks(title="FischGPT API") as app:
    gr.Markdown("### FischGPT API is running.")
    
    # Register endpoints
    gr.Interface(
        fn=generate_api,
        inputs=[
            gr.Textbox(label="User Message"),
            gr.Slider(0.1, 2.0, value=0.8, label="Temperature"),
            gr.Slider(50, 300, value=150, label="Max Length"),
            gr.Slider(0.1, 1.0, value=0.9, label="Top-p")
        ],
        outputs=gr.JSON(label="Response"),
        api_name="predict"
    )

    gr.Interface(
        fn=wake_up,
        inputs=[],
        outputs=gr.JSON(label="Status"),
        api_name="wake-up"
    )

# Launch app
app.queue(api_open=True).launch()
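
# --- Client usage sketch (not part of the app) ------------------------------
# One way to call the named endpoints from Python is the official
# gradio_client package. The Space id below is an assumption derived from the
# model id; replace it with the real Space path.
#
#     from gradio_client import Client
#
#     client = Client("kristianfischerai12345/fischgpt-sft")  # hypothetical Space id
#     result = client.predict(
#         "Who are you?",  # user_message
#         0.8,             # temperature
#         150,             # max_length
#         0.9,             # top_p
#         api_name="/predict",
#     )
#     print(result["response"])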