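"""Multi-model chat demo for Hugging Face Spaces.

Lets the user pick a causal language model from a dropdown, loads it with
transformers, and streams replies into a Gradio Blocks chat UI. Written for a
ZeroGPU Space (see the spaces.GPU decorator on respond below).
"""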
import os
from threading import Thread

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import spaces
# Available models for selection
AVAILABLE_MODELS = [
    "Qwen/Qwen2.5-0.5B",
    "Qwen/Qwen2.5-1.5B",
    "Qwen/Qwen2.5-7B",
    "Qwen/Qwen2.5-14B",
    "meta-llama/Llama-2-7b-chat-hf",
    "microsoft/phi-2",
    "bigscience/bloom-560m",
]
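# Note: meta-llama/Llama-2-7b-chat-hf is a gated checkpoint; loading it requires
# a Hugging Face token with access to the repo (e.g. via an HF_TOKEN Space secret).
# The other models above are openly downloadable.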
# Default model
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B"

# Check if we're running in a Space or locally
# Hugging Face Spaces set this environment variable
IS_SPACE = os.getenv("SPACE_ID") is not None

# Global variables for model and tokenizer
model = None
tokenizer = None
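# The loaded model and tokenizer live in module-level globals so the Gradio
# callbacks below can reuse them across requests and swap them out when the
# user selects a different model.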
def load_model(model_name):
    global model, tokenizer, MODEL_NAME
    if IS_SPACE:
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # "auto" keeps the checkpoint's native dtype (typically bf16/fp16);
        # on ZeroGPU, the .to("cuda") call in the global scope is handled by
        # the spaces runtime until a GPU is actually attached.
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
        model.to("cuda").eval()
        MODEL_NAME = model_name  # Track which model is currently loaded
        return f"Model {model_name} loaded successfully!"
    else:
        raise ValueError("Model loading is only supported in Hugging Face Spaces.")
# Model configuration
MODEL_NAME = DEFAULT_MODEL
load_model(MODEL_NAME)
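# The default model is loaded once at startup (when running in a Space) so the
# first chat request does not have to wait for the weights to download.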
@spaces.GPU  # Request ZeroGPU hardware for each generation call (assumes this Space runs on ZeroGPU)
def respond(
    message,
    history: list[tuple[str, str]],
    model_name,
    system_message,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty,
    top_k,
):
    global model, tokenizer

    # If the selected model changed, load the new one before generating
    if model_name != MODEL_NAME:
        load_model(model_name)

    # Build the conversation as a list of role/content messages
    messages = []

    # Add system message if provided
    if system_message:
        messages.append({"role": "system", "content": system_message})

    # Add conversation history
    for user_msg, assistant_msg in history:
        if user_msg:  # Add user message
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # Add assistant message
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add the current message
    messages.append({"role": "user", "content": message})
    # Apply the chat template
    try:
        # Use apply_chat_template, which handles different model formats
        chat_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    except (AttributeError, NotImplementedError):
        # Fallback for models without a chat template
        chat_text = f"{system_message}\n\n"
        for msg in messages:
            if msg["role"] == "system":
                continue  # Already added at the beginning
            elif msg["role"] == "user":
                chat_text += f"User: {msg['content']}\n"
            elif msg["role"] == "assistant":
                chat_text += f"Assistant: {msg['content']}\n\n"
        chat_text += "Assistant:"
    # Tokenize the input
    inputs = tokenizer(chat_text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)

    # Set up generation parameters
    gen_kwargs = {
        "max_new_tokens": int(max_tokens),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }
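    # All sampling parameters above come straight from the UI sliders defined
    # in the Blocks layout below, cast to the types generate() expects.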
    # Stream the response token by token: generation runs in a background
    # thread and decoded text chunks are consumed from the streamer as they arrive
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_thread = Thread(
        target=model.generate,
        kwargs={"input_ids": input_ids, "streamer": streamer, **gen_kwargs},
    )
    generation_thread.start()

    # Accumulate the partial response and yield it after each new chunk
    response = ""
    for new_text in streamer:
        response += new_text
        yield response.strip()
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# 🤖 Multi-Model Chat with Zero GPU")

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=600)
            msg = gr.Textbox(
                placeholder="Ask me anything...",
                container=False,
                scale=7,
            )
            submit = gr.Button("Submit", variant="primary")
            clear = gr.Button("Clear")
        with gr.Column(scale=1):
            gr.Markdown("## Model Settings")
            model_dropdown = gr.Dropdown(
                choices=AVAILABLE_MODELS,
                value=DEFAULT_MODEL,
                label="Select Model",
                info="Choose a model for chat"
            )
            load_button = gr.Button("Load Model")
            # Status box for load_model's return message, defined here so it
            # renders in the settings column rather than at the bottom of the page
            load_status = gr.Textbox(label="Model Loading Status", interactive=False)
            system_message = gr.Textbox(
                value="You are a friendly and helpful AI assistant.",
                label="System Message",
                info="Instructions for the AI"
            )
gr.Markdown("## Sampling Parameters") | |
max_tokens = gr.Slider( | |
minimum=1, maximum=4096, value=512, step=1, | |
label="Max New Tokens", | |
info="Maximum number of tokens to generate" | |
) | |
temperature = gr.Slider( | |
minimum=0.1, maximum=2.0, value=0.7, step=0.1, | |
label="Temperature", | |
info="Higher = more creative, Lower = more focused" | |
) | |
top_p = gr.Slider( | |
minimum=0.1, maximum=1.0, value=0.95, step=0.05, | |
label="Top-p (nucleus sampling)", | |
info="Cumulative probability cutoff for token selection" | |
) | |
repetition_penalty = gr.Slider( | |
minimum=1.0, maximum=2.0, value=1.1, step=0.05, | |
label="Repetition Penalty", | |
info="Penalty for repeating tokens, 1.0 = no penalty" | |
) | |
top_k = gr.Slider( | |
minimum=1, maximum=100, value=50, step=1, | |
label="Top-k", | |
info="Number of highest probability tokens to consider" | |
) | |
    # Chat state and event handler functions
    chat_history = gr.State([])

    def user(user_message, history):
        # Append the new user message to the history and clear the textbox
        return "", history + [[user_message, None]]

    def bot(history, model_name, system_msg, max_len, temp, top_p_val, rep_penalty, top_k_val):
        # Stream the model's reply into the last history entry
        user_message = history[-1][0]
        history[-1][1] = ""
        for response in respond(
            user_message,
            history[:-1],
            model_name,
            system_msg,
            max_len,
            temp,
            top_p_val,
            rep_penalty,
            top_k_val
        ):
            history[-1][1] = response
            yield history

    def clear_chat():
        return [], []
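    # Event wiring: submitting a message (Enter or the Submit button) first runs
    # user() to append it to the shared state without queueing, then streams
    # bot()'s partial replies into the Chatbot component.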
    msg.submit(
        user,
        [msg, chat_history],
        [msg, chat_history],
        queue=False
    ).then(
        bot,
        [chat_history, model_dropdown, system_message, max_tokens, temperature, top_p, repetition_penalty, top_k],
        chatbot
    )

    submit.click(
        user,
        [msg, chat_history],
        [msg, chat_history],
        queue=False
    ).then(
        bot,
        [chat_history, model_dropdown, system_message, max_tokens, temperature, top_p, repetition_penalty, top_k],
        chatbot
    )
    clear.click(clear_chat, None, [chatbot, chat_history])

    load_button.click(
        load_model,
        inputs=[model_dropdown],
        outputs=[load_status]
    )
if __name__ == "__main__":
    demo.launch()
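# Assumed Space dependencies (requirements.txt): gradio, torch, and transformers;
# the spaces package is preinstalled on ZeroGPU hardware.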