|
import os
|
|
|
|
|
|
# Gradio server configuration, exported through the environment before
# gradio is imported below. Existing values are overwritten unconditionally.
os.environ.update(
    {
        "GRADIO_SERVER_NAME": "0.0.0.0",
        "GRADIO_SERVER_PORT": "7860",
        "GRADIO_ROOT_PATH": "/_app/immutable",
    }
)
|
|
|
|
import gradio as gr
|
|
from fastapi import FastAPI, Request
|
|
import uvicorn
|
|
from sentence_transformers import SentenceTransformer
|
|
from sentence_transformers.util import cos_sim
|
|
from sentence_transformers.quantization import quantize_embeddings
|
|
|
|
|
|
import spaces
|
|
|
|
|
|
|
|
# Host ASGI application; the Gradio UI is mounted onto it further down.
app: FastAPI = FastAPI()
|
|
|
|
|
|
@spaces.GPU
def embed(text):
    """Encode *text* with the module-level ``Embedder`` model.

    Args:
        text: The string submitted from the input textbox.

    Returns:
        The embedding converted to a plain Python list via ``.tolist()``.

    Raises:
        gr.Error: if ``Embedder`` has not been assigned a model yet, so the
            user sees a readable message instead of an AttributeError.
    """
    # Embedder is a module-level global that starts out as None; fail with a
    # clear, UI-visible message rather than "'NoneType' has no attribute".
    if Embedder is None:
        raise gr.Error("Embedding model is not loaded")
    query_embedding = Embedder.encode(text)
    return query_embedding.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(fill_height=True) as demo:
    # Input: pressing Enter in this box triggers embed() on its contents.
    text = gr.Textbox()
    # Output: receives whatever embed() returns.
    embeddings = gr.Textbox()
    text.submit(embed, inputs=[text], outputs=[embeddings])
|
|
|
|
|
|
# NOTE(review): despite the message printed here, no model is loaded — the
# global that embed() calls .encode() on is left as None. Assign a
# SentenceTransformer instance here before serving requests.
print("Loading embedding model")
Embedder = None
|
|
|
|
|
|
# Mount the Blocks UI at the root path of the FastAPI app, with Gradio's
# server-side rendering explicitly switched off.
GradioApp = gr.mount_gradio_app(
    app,
    demo,
    path="/",
    ssr_mode=False,
)
|
|
|
|
if __name__ == "__main__":
    # Read host/port from the same environment variables this module sets at
    # import time instead of repeating the literals, so the two cannot drift.
    uvicorn.run(
        GradioApp,
        host=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )
|
|
|