import os
import threading
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
# =======================
# Load Secrets
# =======================
SYSTEM_PROMPT = os.environ.get(
"prompt",
"You are a placeholder Sovereign. No secrets found in environment."
)
# =======================
# Model Initialization
# =======================
MODEL_ID = "tiiuae/Falcon3-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Attempt 4-bit quantization; fall back to plain fp16 if bitsandbytes is missing.
# transformers raises ImportError (ValueError in some versions) when the
# bitsandbytes dependency is unavailable, not importlib's PackageNotFoundError.
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        load_in_4bit=True,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
except (ImportError, ValueError):
    print("bitsandbytes not available; loading full fp16 model without quantization.")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
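# Note: newer transformers releases prefer an explicit quantization config over the
# bare load_in_4bit flag. A minimal sketch of the equivalent call (assuming a
# transformers version that exposes BitsAndBytesConfig):
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID,
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       device_map="auto",
#       torch_dtype=torch.float16,
#   )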
# Text-generation pipeline; the model was already dispatched via device_map="auto"
# above, so no device argument is passed here.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    eos_token_id=tokenizer.eos_token_id
)
# =======================
# Core Chat Function
# =======================
def chat_fn(user_input: str) -> str:
    prompt = f"### System:\n{SYSTEM_PROMPT}\n\n### User:\n{user_input}\n\n### Assistant:"
    return pipe(prompt)[0]["generated_text"].strip()
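# Falcon3-Instruct also ships a chat template; an arguably more robust prompt
# could be built with tokenizer.apply_chat_template. A sketch, assuming the
# model's template accepts a "system" role:
#
#   def chat_fn_templated(user_input: str) -> str:
#       messages = [
#           {"role": "system", "content": SYSTEM_PROMPT},
#           {"role": "user", "content": user_input},
#       ]
#       prompt = tokenizer.apply_chat_template(
#           messages, tokenize=False, add_generation_prompt=True
#       )
#       return pipe(prompt)[0]["generated_text"].strip()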
# =======================
# Gradio UI
# =======================
def gradio_chat(user_input: str) -> str:
    return chat_fn(user_input)

iface = gr.Interface(
    fn=gradio_chat,
    inputs=gr.Textbox(lines=5, placeholder="Enter your prompt…"),
    outputs="text",
    title="Prompt Cracking Challenge",
    description="Does he really think he is the king?"
)
# =======================
# FastAPI for API access
# =======================
app = FastAPI(title="Prompt Cracking Challenge API")
class Request(BaseModel):
    prompt: str

@app.post("/generate")
def generate(req: Request):
    return {"response": chat_fn(req.prompt)}
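# Example request once the server is up (assumes the default API_PORT=8000):
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Who are you?"}'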
# =======================
# Launch Both Servers
# =======================
def run_api():
    port = int(os.environ.get("API_PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)

if __name__ == "__main__":
    # Start FastAPI in a background thread
    threading.Thread(target=run_api, daemon=True).start()
    # Launch the Gradio interface in the main thread
    iface.launch(server_name="0.0.0.0", server_port=7860)