sqlserver / app.py
rrg92's picture
fix
b45b188
raw
history blame
1.63 kB
import gradio as gr
from fastapi import FastAPI, Request
import uvicorn
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sentence_transformers.quantization import quantize_embeddings
import spaces
# FastAPI application that the `@app.post("/v1/embeddings")` route below is
# registered on.
# NOTE(review): the name `app` is rebound further down when the result of
# `demo.launch(...)` is unpacked -- confirm this instance is still the one
# that ends up being served.
app = FastAPI()
@spaces.GPU
def embed(text):
    """Encode ``text`` with the module-level ``Embedder`` and return a list.

    ``text`` is handed to ``Embedder.encode`` unchanged, and the encoder's
    result is converted with ``.tolist()`` before being returned.
    """
    return Embedder.encode(text).tolist()
@app.post("/v1/embeddings")
async def openai_embeddings(request: Request):
    """Answer a POST to ``/v1/embeddings`` with an OpenAI-style payload.

    The JSON body must contain ``model`` (echoed back unchanged) and
    ``input``, which may be a single string or a list of strings. The
    response carries one ``data`` entry per input, each with its own
    ``index``; a single string therefore yields exactly one entry at
    index 0.

    Raises:
        KeyError: if ``model`` or ``input`` is missing from the body.
    """
    body = await request.json()
    print(body)
    model = body['model']
    text = body['input']

    vectors = embed(text)
    if isinstance(text, str):
        # A single string encodes to one flat vector; wrap it so both input
        # shapes go through the same per-item response construction.
        vectors = [vectors]

    return {
        'object': "list",
        'data': [
            # NOTE(review): the OpenAI reference uses "embedding" (singular)
            # for this field; the existing value is kept so current clients
            # are unaffected -- confirm which one consumers expect.
            {'object': "embeddings", 'embedding': vector, 'index': index}
            for index, vector in enumerate(vectors)
        ],
        'model': model,
        # Token accounting is not implemented; both counters are always 0.
        'usage': {
            'prompt_tokens': 0,
            'total_tokens': 0,
        },
    }
def fn(text):
    """Gradio submit handler: return the embedding computed for ``text``.

    The value has to be returned (not just computed) so that Gradio has
    something to place in the output component wired up by ``text.submit``.
    """
    return embed(text)
# Minimal UI: submitting the first textbox calls `fn` with its content and
# routes the result to the second textbox.
with gr.Blocks(fill_height=True) as demo:
    text = gr.Textbox()
    embeddings = gr.Textbox()
    text.submit(fn, inputs=[text], outputs=[embeddings])
# --- Startup sequence: load the model, launch the Gradio demo, mount it. ---

# Load the sentence-transformers model once at import time; `embed` reads it
# through the module-level name `Embedder`.
print("Loading embedding model");
Embedder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Disabled alternative launch configuration (not executed):
# demo.run_startup_events()
#demo.launch(
# share=False,
# debug=False,
# server_port=7860,
# server_name="0.0.0.0",
# allowed_paths=[]
#)
print("Demo run...");
# `prevent_thread_lock=True` is passed so the call returns and the script
# can continue past this line.
# NOTE(review): this unpacking rebinds the module-level name `app`, which was
# previously the FastAPI() instance that the /v1/embeddings route was
# registered on. Verify that the route is still reachable on whatever object
# `app` refers to from here on.
(app,url,other) = demo.launch(prevent_thread_lock=True);
print("Mounting app...");
# Mount the Gradio demo at the root path of `app`; the returned object is
# what the `__main__` guard hands to uvicorn.
GradioApp = gr.mount_gradio_app(app, demo, path="/", ssr_mode=False);
# Script entry point: when run directly, serve the mounted app with uvicorn
# using uvicorn's default host/port settings.
if __name__ == "__main__":
    print("Running uviconr...")
    uvicorn.run(GradioApp)