# sqlserver / app.py
# (Hugging Face Spaces page residue preserved as comments: "rrg92's picture",
#  commit "fix gradio" 415b5df, "raw / history blame", 1.55 kB)
import os
# Configure Gradio through environment variables BEFORE importing gradio,
# so the values are picked up when the gradio module initializes.
# If you want Gradio to run on a particular host/port, you can do this:
os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"
os.environ["GRADIO_SERVER_PORT"] = "7860"
# NOTE(review): this points the root path at Gradio's static-asset prefix —
# presumably required by the Space's reverse proxy; confirm against deployment.
os.environ["GRADIO_ROOT_PATH"] = "/_app/immutable"
import gradio as gr
from fastapi import FastAPI, Request
import uvicorn
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sentence_transformers.quantization import quantize_embeddings
import spaces
# FastAPI application the Gradio UI is mounted onto via gr.mount_gradio_app.
app = FastAPI()
@spaces.GPU
def embed(text):
    """Encode *text* into an embedding vector and return it as a plain list.

    Decorated with ``spaces.GPU`` so Hugging Face Spaces schedules the call
    on GPU hardware.

    Args:
        text: Input string (or list of strings) to embed.

    Returns:
        The embedding converted from a numpy array to a (nested) Python list.

    Raises:
        RuntimeError: If the module-level ``Embedder`` model has not been
            loaded (it is currently ``None`` — see the disabled
            SentenceTransformer line at module scope).
    """
    if Embedder is None:
        # Fail loudly with a clear message instead of the opaque
        # AttributeError that `None.encode(...)` would otherwise raise.
        raise RuntimeError("Embedding model is not loaded; assign Embedder first.")
    return Embedder.encode(text).tolist()
# Disabled OpenAI-compatible /v1/embeddings endpoint — re-enable once Embedder is loaded.
#@app.post("/v1/embeddings")
#async def openai_embeddings(request: Request):
# body = await request.json();
# print(body);
#
# model = body['model']
# text = body['input'];
# embeddings = embed(text)
# return {
# 'object': "list"
# ,'data': [{
# 'object': "embeddings"
# ,'embedding': embeddings
# ,'index':0
# }]
# ,'model':model
# ,'usage':{
# 'prompt_tokens': 0
# ,'total_tokens': 0
# }
# }
# Minimal Gradio UI: a text input whose submit event shows its embedding.
with gr.Blocks(fill_height=True) as demo:
    input_box = gr.Textbox()
    output_box = gr.Textbox()
    # Pressing Enter in the input runs embed() and displays the vector.
    input_box.submit(embed, [input_box], [output_box])
print("Loading embedding model");
# NOTE(review): model loading is disabled — Embedder stays None, so embed()
# will fail at runtime until the SentenceTransformer call is re-enabled.
Embedder = None #SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Mount the Gradio UI at the app root; ssr_mode=False disables server-side rendering.
GradioApp = gr.mount_gradio_app(app, demo, path="/", ssr_mode=False);
if __name__ == "__main__":
    # Development entry point: serve the combined FastAPI + Gradio app.
    uvicorn.run(GradioApp, host="0.0.0.0", port=7860)