import torch import gradio as gr from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from postprocess import post_process # Dictionary of available models models = { "ViHerbQA-Base": "QQuinley/ViHerbQA-base", "ViHerbQA-Large": "QQuinley/ViHerbQA-large", # Add more models as needed } # Initialize with empty models (will load on selection) loaded_models = {} loaded_tokenizers = {} def load_model(model_name): if model_name not in loaded_models: model_path = models[model_name] tokenizer_path = models[model_name] print(f"Loading {model_name} from {model_path}...") loaded_models[model_name] = AutoModelForSeq2SeqLM.from_pretrained(model_path) loaded_tokenizers[model_name] = AutoTokenizer.from_pretrained(tokenizer_path) return loaded_models[model_name], loaded_tokenizers[model_name] # Function to generate answer with the selected model def generate_answer(model, tokenizer, input_text, q_len, t_len): device = 'cuda' if torch.cuda.is_available() else 'cpu' model.to(device) inputs = tokenizer(input_text, max_length=q_len, padding="max_length", truncation=True, pad_to_max_length=True, add_special_tokens=True, return_tensors="pt") input_ids = inputs["input_ids"].to(device) attention_mask = inputs["attention_mask"].to(device) with torch.no_grad(): outputs = model.generate( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=t_len, use_cache=True, ) predicted_answer = tokenizer.decode(outputs[0], skip_special_tokens=True) cleaned_answer = post_process(predicted_answer) final_answer = cleaned_answer # \n*** Lưu ý, thông tin trên chỉ mang tính chất tham khảo. Để hiểu rõ hơn bạn nên liên hệ với bác sĩ chuyên môn. ***" return final_answer # Gradio interface with model selection with gr.Blocks() as demo: gr.Markdown("# ViHerbQA") gr.Markdown("Trò chuyện với ViHerbQA - mô hình hỏi đáp dược liệu Y học cổ truyền Việt Nam") with gr.Row(): model_dropdown = gr.Dropdown( choices=list(models.keys()), value="ViHerbQA-Large", label="Select Model" ) chatbot = gr.ChatInterface( fn=lambda message, history, model_name: generate_answer(*load_model(model_name), message, 1024, 1024), additional_inputs=[model_dropdown], chatbot=gr.Chatbot(height=600), examples=[ ["Actiso có tên khoa học là gì?"], ["Những ai không nên dùng Hoa nhài?"], ["Lá Thường xuân có hiệu quả như thế nào trong việc điều trị bệnh nào về đường hô hấp?"] ], cache_examples=False, ) if __name__ == "__main__": load_model("ViHerbQA-Large") load_model("ViHerbQA-Base") demo.launch(show_error=True, debug=True, ssr_mode=False, share=True)