Synced repo using 'sync_with_huggingface' Github Action
app.py CHANGED
@@ -6,14 +6,12 @@ import os
 from io import BytesIO
 
 import uvicorn
-from fastapi import FastAPI,
-from fastapi.responses import StreamingResponse
-from starlette.staticfiles import StaticFiles
-from starlette.templating import Jinja2Templates
-from sentence_transformers import SentenceTransformer
+from fastapi import FastAPI, Body, Request
+# from fastapi.responses import StreamingResponse
+# from starlette.staticfiles import StaticFiles
+# from starlette.templating import Jinja2Templates
+from sentence_transformers import SentenceTransformer, models
 
-# from utils.data_utils import remove_punctuation
-# from utils.utils import add_arguments, print_arguments
 
 
 def print_arguments(args):
@@ -32,28 +30,31 @@ def strtobool(val):
     else:
         raise ValueError("invalid truth value %r" % (val,))
 
-
 def str_none(val):
     if val == 'None':
         return None
     else:
         return val
 
-
 def add_arguments(argname, type, default, help, argparser, **kwargs):
     type = strtobool if type == bool else type
     type = str_none if type == str else type
-    argparser.add_argument(
-
-
-
-
+    argparser.add_argument(
+        "--" + argname,
+        default=default,
+        type=type,
+        help=help + ' Default: %(default)s.',
+        **kwargs
+    )
+
 
 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
 
+
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 
+
 add_arg("host", type=str, default="0.0.0.0", help="")
 add_arg("port", type=int, default=5000, help="")
 add_arg("model_path", type=str, default="BAAI/bge-small-en-v1.5", help="")
@@ -63,24 +64,45 @@ add_arg("beam_size", type=int, default=10, help="")
 add_arg("num_workers", type=int, default=2, help="")
 add_arg("vad_filter", type=bool, default=True, help="")
 add_arg("local_files_only", type=bool, default=True, help="")
+
+
 args = parser.parse_args()
 print_arguments(args)
 
-
-
-
+
+
+if args.use_gpu:
+    bge_model = SentenceTransformer(args.model_path, device="cuda", compute_type="float16", cache_folder=".")
+else:
+    bge_model = SentenceTransformer(args.model_path, device='cpu', cache_folder=".")
+
+
+
 if args.use_gpu:
-
+    model_name = 'sam2ai/sbert-tsdae'
+    word_embedding_model = models.Transformer(model_name)
+    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')
+    tsdae_model = SentenceTransformer(
+        modules=[word_embedding_model, pooling_model],
+        device="cuda",
+        compute_type="float16",
+        cache_folder="."
+    )
 else:
-
+    model_name = 'sam2ai/sbert-tsdae'
+    word_embedding_model = models.Transformer(model_name)
+    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')
+    tsdae_model = SentenceTransformer(
+        modules=[word_embedding_model, pooling_model],
+        device='cpu',
+        cache_folder="."
+    )
 
 
 app = FastAPI(title="embedding Inference")
-# app.mount('/static', StaticFiles(directory='static'), name='static')
-# templates = Jinja2Templates(directory="templates")
-# model_semaphore = None
 
-def similarity_score(textA, textB):
+
+def similarity_score(model, textA, textB):
     em_test = model.encode(
         [textA, textB],
         normalize_embeddings=True
@@ -88,13 +110,26 @@ def similarity_score(textA, textB):
     return em_test[0] @ em_test[1].T
 
 
-@app.post("/
-async def api_embed(
+@app.post("/bge_embed")
+async def api_bge_embed(
+    text1: str = Body("text1", description="", embed=True),
+    text2: str = Body("text2", description="", embed=True),
+):
+
+    scores = similarity_score(bge_model, text1, text2)
+    print(scores)
+    scores = scores.tolist()
+
+    ret = {"similarity score": scores, "status_code": 200}
+    return ret
+
+@app.post("/tsdae_embed")
+async def api_tsdae_embed(
     text1: str = Body("text1", description="", embed=True),
     text2: str = Body("text2", description="", embed=True),
 ):
 
-    scores = similarity_score(text1, text2)
+    scores = similarity_score(tsdae_model, text1, text2)
     print(scores)
     scores = scores.tolist()
 
@@ -102,11 +137,9 @@ async def api_embed(
     return ret
 
 
-
-
-
-#     "index.html", {"request": request, "id": id}
-# )
+@app.get("/")
+async def index(request: Request):
+    return {"detail": "API is Active !!"}
 
 
 if __name__ == '__main__':
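For quick verification of the two similarity endpoints added in this commit, a minimal client sketch is shown below. It assumes the app is running locally on the default port 5000 from the argument parser above; the use of the requests library and the sample sentences are illustrative and not part of the commit.

# Hypothetical client sketch -- not part of the commit.
import requests

# With Body(..., embed=True), each field is sent as a key in the JSON body.
payload = {"text1": "The weather is lovely today.", "text2": "It is sunny outside."}

# BGE similarity score (default host/port from the parser above).
print(requests.post("http://localhost:5000/bge_embed", json=payload).json())

# TSDAE similarity score.
print(requests.post("http://localhost:5000/tsdae_embed", json=payload).json())

Each call should return a JSON object of the form {"similarity score": ..., "status_code": 200}, matching the ret dictionary built in the endpoint handlers.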