sam2ai committed on
Commit
c91441b
·
1 Parent(s): 5382507

Synced repo using 'sync_with_huggingface' GitHub Action

Browse files
Files changed (1) hide show
  1. app.py +64 -31
app.py CHANGED
@@ -6,14 +6,12 @@ import os
6
  from io import BytesIO
7
 
8
  import uvicorn
9
- from fastapi import FastAPI, BackgroundTasks, File, Body, UploadFile, Request
10
- from fastapi.responses import StreamingResponse
11
- from starlette.staticfiles import StaticFiles
12
- from starlette.templating import Jinja2Templates
13
- from sentence_transformers import SentenceTransformer
14
 
15
- # from utils.data_utils import remove_punctuation
16
- # from utils.utils import add_arguments, print_arguments
17
 
18
 
19
  def print_arguments(args):
@@ -32,28 +30,31 @@ def strtobool(val):
32
  else:
33
  raise ValueError("invalid truth value %r" % (val,))
34
 
35
-
36
  def str_none(val):
37
  if val == 'None':
38
  return None
39
  else:
40
  return val
41
 
42
-
43
  def add_arguments(argname, type, default, help, argparser, **kwargs):
44
  type = strtobool if type == bool else type
45
  type = str_none if type == str else type
46
- argparser.add_argument("--" + argname,
47
- default=default,
48
- type=type,
49
- help=help + ' Default: %(default)s.',
50
- **kwargs)
 
 
 
51
 
52
  os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
53
 
 
54
  parser = argparse.ArgumentParser(description=__doc__)
55
  add_arg = functools.partial(add_arguments, argparser=parser)
56
 
 
57
  add_arg("host", type=str, default="0.0.0.0", help="")
58
  add_arg("port", type=int, default=5000, help="")
59
  add_arg("model_path", type=str, default="BAAI/bge-small-en-v1.5", help="")
@@ -63,24 +64,45 @@ add_arg("beam_size", type=int, default=10, help="")
63
  add_arg("num_workers", type=int, default=2, help="")
64
  add_arg("vad_filter", type=bool, default=True, help="")
65
  add_arg("local_files_only", type=bool, default=True, help="")
 
 
66
  args = parser.parse_args()
67
  print_arguments(args)
68
 
69
- #
70
- # assert os.path.exists(args.model_path), f"{args.model_path}"
71
- #
 
 
 
 
 
 
72
  if args.use_gpu:
73
- model = SentenceTransformer(args.model_path, device="cuda", compute_type="float16", cache_folder=".")
 
 
 
 
 
 
 
 
74
  else:
75
- model = SentenceTransformer(args.model_path, device='cpu', cache_folder=".")
 
 
 
 
 
 
 
76
 
77
 
78
  app = FastAPI(title="embedding Inference")
79
- # app.mount('/static', StaticFiles(directory='static'), name='static')
80
- # templates = Jinja2Templates(directory="templates")
81
- # model_semaphore = None
82
 
83
- def similarity_score(textA, textB):
 
84
  em_test = model.encode(
85
  [textA, textB],
86
  normalize_embeddings=True
@@ -88,13 +110,26 @@ def similarity_score(textA, textB):
88
  return em_test[0] @ em_test[1].T
89
 
90
 
91
- @app.post("/embed")
92
- async def api_embed(
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  text1: str = Body("text1", description="", embed=True),
94
  text2: str = Body("text2", description="", embed=True),
95
  ):
96
 
97
- scores = similarity_score(text1, text2)
98
  print(scores)
99
  scores = scores.tolist()
100
 
@@ -102,11 +137,9 @@ async def api_embed(
102
  return ret
103
 
104
 
105
- # @app.get("/")
106
- # async def index(request: Request):
107
- # return templates.TemplateResponse(
108
- # "index.html", {"request": request, "id": id}
109
- # )
110
 
111
 
112
  if __name__ == '__main__':
 
6
  from io import BytesIO
7
 
8
  import uvicorn
9
+ from fastapi import FastAPI, Body, Request
10
+ # from fastapi.responses import StreamingResponse
11
+ # from starlette.staticfiles import StaticFiles
12
+ # from starlette.templating import Jinja2Templates
13
+ from sentence_transformers import SentenceTransformer, models
14
 
 
 
15
 
16
 
17
  def print_arguments(args):
 
30
  else:
31
  raise ValueError("invalid truth value %r" % (val,))
32
 
 
33
  def str_none(val):
34
  if val == 'None':
35
  return None
36
  else:
37
  return val
38
 
 
39
  def add_arguments(argname, type, default, help, argparser, **kwargs):
40
  type = strtobool if type == bool else type
41
  type = str_none if type == str else type
42
+ argparser.add_argument(
43
+ "--" + argname,
44
+ default=default,
45
+ type=type,
46
+ help=help + ' Default: %(default)s.',
47
+ **kwargs
48
+ )
49
+
50
 
51
  os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
52
 
53
+
54
  parser = argparse.ArgumentParser(description=__doc__)
55
  add_arg = functools.partial(add_arguments, argparser=parser)
56
 
57
+
58
  add_arg("host", type=str, default="0.0.0.0", help="")
59
  add_arg("port", type=int, default=5000, help="")
60
  add_arg("model_path", type=str, default="BAAI/bge-small-en-v1.5", help="")
 
64
  add_arg("num_workers", type=int, default=2, help="")
65
  add_arg("vad_filter", type=bool, default=True, help="")
66
  add_arg("local_files_only", type=bool, default=True, help="")
67
+
68
+
69
  args = parser.parse_args()
70
  print_arguments(args)
71
 
72
+
73
+
74
+ if args.use_gpu:
75
+ bge_model = SentenceTransformer(args.model_path, device="cuda", compute_type="float16", cache_folder=".")
76
+ else:
77
+ bge_model = SentenceTransformer(args.model_path, device='cpu', cache_folder=".")
78
+
79
+
80
+
81
  if args.use_gpu:
82
+ model_name = 'sam2ai/sbert-tsdae'
83
+ word_embedding_model = models.Transformer(model_name)
84
+ pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')
85
+ tsdae_model = SentenceTransformer(
86
+ modules=[word_embedding_model, pooling_model],
87
+ device="cuda",
88
+ compute_type="float16",
89
+ cache_folder="."
90
+ )
91
  else:
92
+ model_name = 'sam2ai/sbert-tsdae'
93
+ word_embedding_model = models.Transformer(model_name)
94
+ pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')
95
+ tsdae_model = SentenceTransformer(
96
+ modules=[word_embedding_model, pooling_model],
97
+ device='cpu',
98
+ cache_folder="."
99
+ )
100
 
101
 
102
  app = FastAPI(title="embedding Inference")
 
 
 
103
 
104
+
105
+ def similarity_score(model, textA, textB):
106
  em_test = model.encode(
107
  [textA, textB],
108
  normalize_embeddings=True
 
110
  return em_test[0] @ em_test[1].T
111
 
112
 
113
+ @app.post("/bge_embed")
114
+ async def api_bge_embed(
115
+ text1: str = Body("text1", description="", embed=True),
116
+ text2: str = Body("text2", description="", embed=True),
117
+ ):
118
+
119
+ scores = similarity_score(bge_model, text1, text2)
120
+ print(scores)
121
+ scores = scores.tolist()
122
+
123
+ ret = {"similarity score": scores, "status_code": 200}
124
+ return ret
125
+
126
+ @app.post("/tsdae_embed")
127
+ async def api_tsdae_embed(
128
  text1: str = Body("text1", description="", embed=True),
129
  text2: str = Body("text2", description="", embed=True),
130
  ):
131
 
132
+ scores = similarity_score(tsdae_model, text1, text2)
133
  print(scores)
134
  scores = scores.tolist()
135
 
 
137
  return ret
138
 
139
 
140
+ @app.get("/")
141
+ async def index(request: Request):
142
+ return {"detail": "API is Active !!"}
 
 
143
 
144
 
145
  if __name__ == '__main__':