import os
import torch
import gradio as gr
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

MODEL_ID = os.getenv("MODEL_ID", "Phonepadith/aidc-llm-laos-10k-gemma-3-4b-it")


def load_model():
    use_4bit = os.getenv("USE_4BIT", "0") == "1"
    kw = {
        "device_map": "auto",
        "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        "trust_remote_code": True,
        # Allow accelerate to offload weights that do not fit in GPU memory
        "offload_folder": "./offload",
    }
    if use_4bit:
        # 4-bit NF4 quantization via bitsandbytes
        kw["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            bnb_4bit_use_double_quant=True,
        )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kw)
    return model, tokenizer


model, tokenizer = load_model()

DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful Lao/English assistant. Answer clearly and concisely. "
    "Prefer Lao when the user speaks Lao. Be factual and avoid harmful content."
)


def build_messages(history, user_message, system_prompt):
    """
    Convert Gradio chat history to Hugging Face chat-template messages.
    `history` is a list of (user, assistant) tuples.
    """
    msgs = []
    if system_prompt and system_prompt.strip():
        msgs.append({"role": "system", "content": system_prompt.strip()})
    for u, a in history:
        if u is None:
            continue
        msgs.append({"role": "user", "content": u})
        if a is not None and a != "":
            msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": user_message})
    return msgs


def generate_reply(user_message, history, system_prompt, temperature, top_p, top_k,
                   repetition_penalty, max_new_tokens):
    """
    Streaming generator for gr.ChatInterface: yields the partial reply as it is produced.
    """
    # Build the prompt with the model's chat template
    messages = build_messages(history, user_message, system_prompt)
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Sampling is only enabled for a positive temperature; otherwise decode greedily
    do_sample = temperature is not None and float(temperature) > 0

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        repetition_penalty=float(repetition_penalty),
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    if do_sample:
        gen_kwargs.update(
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )

    # Run generation in a background thread and stream tokens as they arrive
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text


# ---- UI ----
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🇱🇦 AIDC LLM (Gemma-3 4B IT) – Chat
        Model: `Phonepadith/aidc-llm-laos-10k-gemma-3-4b-it`
        """
    )
    with gr.Row():
        system_prompt = gr.Textbox(
            label="System prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=3,
            show_label=True,
            placeholder="Define assistant behavior here...",
        )

    with gr.Accordion("Generation settings", open=False):
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
        top_k = gr.Slider(1, 100, value=40, step=1, label="Top-k")
        repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.01, label="Repetition penalty")
        max_new_tokens = gr.Slider(16, 2048, value=512, step=16, label="Max new tokens")

    chat = gr.ChatInterface(
        fn=generate_reply,
        # Pass the live component values to generate_reply; reading the static
        # .value attributes in a lambda would freeze the initial settings.
        additional_inputs=[system_prompt, temperature, top_p, top_k,
                           repetition_penalty, max_new_tokens],
        title="AIDC LLM Laos Chat",
        retry_btn="↻ Retry",
        undo_btn="↩ Undo",
        clear_btn="🗑 Clear",
        chatbot=gr.Chatbot(height=520, show_copy_button=True),
        textbox=gr.Textbox(placeholder="ພິມຄຳຖາມຂອງທ່ານ... / Type your message...", lines=2),
    )

    gr.Markdown(
        "Tip: set `USE_4BIT=1` in Space secrets to load in 4-bit (requires GPU and bitsandbytes)."
    )

demo.queue(max_size=16).launch()
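
# --- Running locally (a sketch; assumes this file is saved as app.py) ---
# The app reads two environment variables, both used above:
#   MODEL_ID  - Hugging Face model repo to load (defaults to the AIDC Gemma-3 4B IT model)
#   USE_4BIT  - set to "1" to load the model in 4-bit via bitsandbytes (GPU required)
# Example invocation, assuming torch, transformers, accelerate, and gradio are installed:
#   USE_4BIT=1 python app.py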