|
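"""MedKBQA (demo): automatic question answering over a local medical knowledge
base, built on LangChain, ChatGLM and the OpenAI API, with a Gradio front end.
Assumes langchain 0.0.x-era APIs, gradio 3.x and duckduckgo_search 2.x."""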
import os

import gradio as gr
from loguru import logger
from transformers import AutoTokenizer, AutoModel
from duckduckgo_search import ddg

from langchain.llms import OpenAI
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA, LLMChain
from langchain.prompts import PromptTemplate
|
|
# Cache the tokenizer and model at module level so that repeated chat turns
# do not reload the weights from disk on every call.
_tokenizer, _model = None, None


def load_model():
    global _tokenizer, _model
    if _model is None:
        _tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
        # The int4 checkpoint ships pre-quantized; quantize() here mainly compiles
        # the parallel CPU kernels, and .float() keeps inference runnable on CPU.
        _model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True) \
            .quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float()
        _model = _model.eval()
    return _tokenizer, _model
|
|
def chat_glm(input, history=None):
    if history is None:
        history = []

    # load_model() is cached, so the weights are only loaded on the first turn.
    tokenizer, model = load_model()
    response, history = model.chat(tokenizer, input, history)
    # loguru formats with {} placeholders, not print-style varargs.
    logger.debug("chatglm: {} -> {}", input, response)
    return history, history
|
|
def search_web(query):
    logger.debug("searchweb: {}", query)
    results = ddg(query)
    web_content = ''
    if results:
        for result in results:
            # Concatenate the text snippet of each hit, one per line.
            web_content += result['body'] + '\n'
    return web_content
|
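# Note: `ddg` is the duckduckgo_search 2.x function API. Newer releases removed
# it in favour of the DDGS class; the rough equivalent there would be
# `[r["body"] for r in DDGS().text(query, max_results=5)]`.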
|
def search_vec(query):
    logger.debug("searchvec: {}", query)
    embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
    vec_path = 'cache'  # directory holding a previously built FAISS index
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = FAISS.load_local(vec_path, embeddings)

    # RetrievalQA supersedes the deprecated VectorDBQA; "stuff" packs all
    # retrieved chunks into a single prompt.
    qa = RetrievalQA.from_chain_type(llm=OpenAI(),
                                     chain_type="stuff",
                                     retriever=vector_store.as_retriever(),
                                     return_source_documents=True)
    result = qa({"query": query})
    return result['result']
|
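# A minimal sketch of how the FAISS index that search_vec() loads from `cache`
# could be built. The document path and chunk sizes are assumptions; point it
# at your own corpus (UnstructuredFileLoader needs the `unstructured` package).
def build_vector_store(doc_path='data/corpus.txt', vec_path='cache'):
    from langchain.document_loaders import UnstructuredFileLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    # Load the raw file, split it into overlapping chunks, embed the chunks
    # with the same model search_vec() uses, and persist the index to disk.
    documents = UnstructuredFileLoader(doc_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese')
    FAISS.from_documents(docs, embeddings).save_local(vec_path)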
|
def chat_gpt(input, use_web, use_vec, history=None):
    if history is None:
        history = []

    # The Radio components return the strings "True"/"False"; both are truthy,
    # so compare against the string rather than using the value as a bool.
    context = "无"
    if use_vec == "True":
        context = search_vec(input)
    elif use_web == "True":
        context = search_web(input)

    prompt_template = f"""基于以下已知信息,请简洁并专业地回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。若答案中存在编造成分,请在该部分开头添加“据我推测”。另外,答案请使用中文。
已知内容:
{context}""" + """
问题:
{question}"""

    prompt = PromptTemplate(template=prompt_template, input_variables=["question"])

    llm = OpenAI(temperature=0.2)
    chain = LLMChain(llm=llm, prompt=prompt)
    result = chain.run(input)
    return result
|
|
def predict(input,
            large_language_model,
            use_web,
            use_vec,
            openai_key,
            history=None):
    logger.debug("predict: model={}, use_web={}, use_vec={}",
                 large_language_model, use_web, use_vec)
    if history is None:
        history = []

    # Only the OpenAI-backed paths need an API key. An empty textbox yields '',
    # not None, so test truthiness rather than comparing with None.
    needs_key = (large_language_model in ("GPT-3.5-turbo", "Search VectorStore")
                 or use_vec == "True")
    if needs_key and not openai_key:
        history.append((input, "请先输入 OpenAI API key"))
        return '', history, history
    if openai_key:
        os.environ['OPENAI_API_KEY'] = openai_key

    if large_language_model == "GPT-3.5-turbo":
        resp = chat_gpt(input, use_web, use_vec, history)
    elif large_language_model == "ChatGLM-6B-int4":
        # chat_glm returns the updated history; the answer is the second
        # element of its last (query, response) pair.
        _, resp = chat_glm(input, history)
        resp = resp[-1][1]
    elif large_language_model == "Search Web":
        resp = search_web(input)
    else:  # "Search VectorStore"
        resp = search_vec(input)

    history.append((input, resp))
    return '', history, history
|
|
def clear_session():
    # Return an empty chat list and reset the conversation state.
    return [], None
|
|
block = gr.Blocks()
with block as demo:
    gr.Markdown("""<h1><center>MedKBQA(demo)</center></h1>
<center><font size=3>
本项目基于LangChain、ChatGLM以及OpenAI接口, 提供基于本地医药知识的自动问答应用. <br>
</font></center>
""")
|
    with gr.Row():
        with gr.Column(scale=1):
            model_choose = gr.Accordion("模型选择")
            with model_choose:
                large_language_model = gr.Dropdown(
                    ["ChatGLM-6B-int4", "GPT-3.5-turbo", "Search Web", "Search VectorStore"],
                    label="large language model",
                    value="ChatGLM-6B-int4")
                use_web = gr.Radio(["True", "False"],
                                   label="Web Search",
                                   value="False")
                use_vec = gr.Radio(["True", "False"],
                                   label="VectorStore Search",
                                   value="False")
                openai_key = gr.Textbox(label="请输入OpenAI API key", type="password")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
            message = gr.Textbox(label='请输入问题')
            state = gr.State()
|
    with gr.Row():
        clear_history = gr.Button("🧹 清除历史对话")
        send = gr.Button("🚀 发送")

    send.click(predict,
               inputs=[
                   message, large_language_model, use_web, use_vec, openai_key, state
               ],
               outputs=[message, chatbot, state])
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[chatbot, state],
                        queue=False)
|
    message.submit(predict,
                   inputs=[
                       message, large_language_model, use_web, use_vec, openai_key, state
                   ],
                   outputs=[message, chatbot, state])
    gr.Markdown("""提醒:<br>
1. 使用时请先选择使用 ChatGLM 或 GPT-3.5 进行问答. <br>
2. 使用 GPT-3.5 或 VectorStore 检索时需要输入您的 OpenAI API key.
""")

demo.queue().launch(server_name='0.0.0.0', share=False)