Spaces:
Paused
Paused
| import torch | |
| import logging | |
| import time | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# ---------------- CONFIG ----------------
# Hugging Face Hub repository the tokenizer and model are loaded from.
MODEL_ID: str = "goonsai-com/civitaiprompts"

# Forwarded as `revision=` to both `from_pretrained` calls, so it names a git
# branch / tag / commit of MODEL_ID on the Hub.
# NOTE(review): "Q4_K_M" is the name of a llama.cpp/GGUF quantization level;
# confirm that a revision with exactly this name exists on the repo --
# `revision` does not pick a GGUF file inside the repository.
MODEL_VARIANT: str = "Q4_K_M"

# Human-readable name rendered in the Gradio page heading.
MODEL_NAME: str = "CivitAI-Prompts-Q4_K_M"
# ---------------- LOGGING ----------------
# Configure the root logger once: INFO and above, one timestamped line per record.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by every stage of the script below.
logger = logging.getLogger(__name__)
logger.info("Starting Gradio chatbot...")
# ---------------- LOAD MODEL ----------------
# Both loaders read the same Hub repository at the same revision, so the
# shared keyword arguments live in one place.
# NOTE(review): trust_remote_code=True executes Python code shipped inside the
# repository; keep it only if the repo is trusted and actually requires it.
_hub_kwargs = {"revision": MODEL_VARIANT, "trust_remote_code": True}

logger.info(f"Loading tokenizer from {MODEL_ID} (revision={MODEL_VARIANT})")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **_hub_kwargs)

# float16 when a CUDA device is present, float32 otherwise.
if torch.cuda.is_available():
    dtype = torch.float16
else:
    dtype = torch.float32
logger.info(f"Loading model with dtype {dtype}")

# device_map="auto" lets transformers decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map="auto",
    **_hub_kwargs,
)
logger.info("Model loaded successfully.")
# ---------------- CHAT FUNCTION ----------------
def chat_fn(
    message,
    *,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    max_prompt_tokens=1024,
):
    """Generate one stateless assistant reply to ``message``.

    The message is wrapped in a plain "User: <message>" / "Assistant:" template
    (no history is kept between calls), sampled from the module-level ``model``,
    and only the newly generated text is returned.

    Args:
        message: Text typed by the user in the Gradio textbox.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling probability mass passed to ``model.generate``.
        max_prompt_tokens: Maximum prompt length in tokens (expected positive).
            Longer prompts keep their *last* tokens so the trailing
            "Assistant:" cue is never cut off.

    Returns:
        The assistant's reply with surrounding whitespace stripped, or ``""``
        when ``message`` is empty or whitespace-only.
    """
    logger.info("Received message: %s", message)
    if not message or not message.strip():
        # Nothing to answer -- skip an expensive generate() call.
        return ""

    # Build prompt
    full_text = f"User: {message}\nAssistant:"
    logger.info("Full prompt for generation:\n%s", full_text)
    start_time = time.perf_counter()

    # Tokenize input
    inputs = tokenizer([full_text], return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]
    if prompt_len > max_prompt_tokens:
        # Keep the tail of the prompt: the tokenizer's default right-side
        # truncation would drop the "\nAssistant:" cue the model relies on.
        inputs = {key: val[:, -max_prompt_tokens:] for key, val in inputs.items()}
        prompt_len = max_prompt_tokens
    inputs = {key: val.to(model.device) for key, val in inputs.items()}
    logger.info("Tokenized input.")

    # Generate response
    logger.info("Generating response...")
    pad_id = tokenizer.pad_token_id
    reply_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        # Explicit pad id avoids generate()'s "pad_token_id not set" fallback warning.
        pad_token_id=pad_id if pad_id is not None else tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; decode only the continuation so
    # neither the user's text nor the template can leak into the reply.
    generated = tokenizer.decode(reply_ids[0, prompt_len:], skip_special_tokens=True)
    # The bare template invites the model to invent the next "User:" turn;
    # keep only the assistant's own answer.
    assistant_reply = generated.split("\nUser:", 1)[0].strip()

    logger.info("Assistant reply: %s", assistant_reply)
    logger.info("Generation time: %.2fs", time.perf_counter() - start_time)
    return assistant_reply
# ---------------- GRADIO BLOCKS UI ----------------
# Two-column page: message box and "Send" button on the left, reply on the right.
# Every request sends only the current textbox value to chat_fn, hence "Stateless".
with gr.Blocks() as demo:
    gr.Markdown(f"# 🤖 {MODEL_NAME} (Stateless)")
    with gr.Row():
        with gr.Column():
            message = gr.Textbox(label="Type your message...", placeholder="Hello!")
            send_btn = gr.Button("Send")
        with gr.Column():
            output = gr.Textbox(label="Assistant Response", lines=10)

    # Clicking "Send" and pressing Enter in the textbox trigger the same handler
    # (registered in this order: click first, then submit).
    for _trigger in (send_btn.click, message.submit):
        _trigger(chat_fn, inputs=[message], outputs=[output])

logger.info("Launching Gradio app...")
demo.launch()