Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
import os | |
MODEL_ID = "piphany/Qwen2.5-7B-Law-Assistant" | |
def main(): | |
# --- 1. 加载模型和分词器 (从HF Hub加载) --- | |
print(f"开始从Hugging Face Hub加载模型: {MODEL_ID}") | |
# 加载分词器 | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) | |
tokenizer.pad_token = tokenizer.eos_token | |
# 加载完整的、已合并的模型 (已移除quantization_config) | |
# Load the full, merged model (quantization_config has been removed) | |
model = AutoModelForCausalLM.from_pretrained( | |
MODEL_ID, | |
torch_dtype=torch.float16, | |
device_map="auto", | |
trust_remote_code=True | |
) | |
model.eval() | |
print("模型加载完成,准备启动Web界面...") | |
# --- 2. 定义推理函数 --- | |
def predict(case_fact, history): | |
instruction = "你是一个专业的法律助手。请根据以下提供的案件事实,草拟一份判决书初稿。" | |
user_content = f"{instruction}\n\n案件事实如下:\n{case_fact}" | |
chat_history_for_model = [{"role": "user", "content": user_content}] | |
messages = tokenizer.apply_chat_template( | |
chat_history_for_model, | |
tokenize=False, | |
add_generation_prompt=True | |
) | |
model_inputs = tokenizer([messages], return_tensors="pt").to(model.device) | |
generated_ids = model.generate( | |
model_inputs.input_ids, | |
max_new_tokens=2048, | |
do_sample=True, | |
top_p=0.9, | |
temperature=0.5, | |
repetition_penalty=1.1, | |
eos_token_id=tokenizer.eos_token_id, | |
) | |
generated_ids = [ | |
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) | |
] | |
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
return response | |
# --- 3. 启动Gradio Web UI --- | |
iface = gr.Interface( | |
fn=predict, | |
inputs=gr.Textbox(lines=15, label="请输入案件事实", placeholder="在这里粘贴完整的案情描述..."), | |
outputs=gr.Textbox(lines=20, label="生成的判决书初稿"), | |
title="智能判案辅助系统 Demo", | |
description=f"本系统基于模型 {MODEL_ID} 微调。请输入案情事实,系统将为您生成一份判决书初稿。", | |
theme="soft", | |
allow_flagging="never" | |
) | |
print("\nGradio应用正在启动...") | |
# 在HF Spaces上,share=True是默认的,不需要特别设置 | |
iface.launch() | |
if __name__ == '__main__': | |
main() | |