sam2ai commited on
Commit
6eb0c45
·
1 Parent(s): 7d4ef19

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -2
  2. app.py +1 -63
  3. download.py +44 -0
Dockerfile CHANGED
@@ -62,6 +62,6 @@ EXPOSE 7860
62
 
63
  # Start the FastAPI app using Uvicorn web server
64
  # CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "14000", "--limit-concurrency", "1000"]
65
- # RUN python3 download.py
66
 
67
- CMD ["python3", "app.py", "--host=0.0.0.0", "--port=7860", "--model_path=BAAI/bge-small-en-v1.5", "--num_workers=2"]
 
62
 
63
  # Start the FastAPI app using Uvicorn web server
64
  # CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "14000", "--limit-concurrency", "1000"]
65
+ RUN python3 download.py
66
 
67
+ CMD ["python3", "app.py", "--host=0.0.0.0", "--port=7860", "--model_path=models/BAAI/bge-small-en-v1.5", "--num_workers=2"]
app.py CHANGED
@@ -8,24 +8,14 @@ from io import BytesIO
8
  import uvicorn
9
  from fastapi import FastAPI, BackgroundTasks, File, Body, UploadFile, Request
10
  from fastapi.responses import StreamingResponse
11
- # from faster_whisper import WhisperModel
12
  from starlette.staticfiles import StaticFiles
13
  from starlette.templating import Jinja2Templates
14
  from sentence_transformers import SentenceTransformer
15
- # from zhconv import convert
16
 
17
  # from utils.data_utils import remove_punctuation
18
  # from utils.utils import add_arguments, print_arguments
19
 
20
 
21
- import hashlib
22
- import os
23
- import tarfile
24
- import urllib.request
25
-
26
- # from tqdm import tqdm
27
-
28
-
29
  def print_arguments(args):
30
  print("----------- Configuration Arguments -----------")
31
  for arg, value in vars(args).items():
@@ -77,7 +67,7 @@ args = parser.parse_args()
77
  print_arguments(args)
78
 
79
  #
80
- # assert os.path.exists(args.model_path), f"{args.model_path}"
81
  #
82
  if args.use_gpu:
83
  model = SentenceTransformer(args.model_path, device="cuda", compute_type="float16", cache_folder=".")
@@ -85,64 +75,12 @@ else:
85
  model = SentenceTransformer(args.model_path, device='cpu', cache_folder=".")
86
 
87
 
88
- #
89
- # _, _ = model.transcribe("dataset/test.wav", beam_size=5)
90
-
91
  app = FastAPI(title="embedding Inference")
92
  # app.mount('/static', StaticFiles(directory='static'), name='static')
93
  # templates = Jinja2Templates(directory="templates")
94
  # model_semaphore = None
95
 
96
 
97
- # def release_model_semaphore():
98
- # model_semaphore.release()
99
-
100
-
101
- # def recognition(file: File, to_simple: int,
102
- # remove_pun: int, language: str = "bn",
103
- # task: str = "transcribe"
104
- # ):
105
-
106
- # segments, info = model.transcribe(file, beam_size=10, task=task, language=language, vad_filter=args.vad_filter)
107
- # for segment in segments:
108
- # text = segment.text
109
- # if to_simple == 1:
110
- # # text = convert(text, '')
111
- # pass
112
- # if remove_pun == 1:
113
- # # text = remove_punctuation(text)
114
- # pass
115
- # ret = {"result": text, "start": round(segment.start, 2), "end": round(segment.end, 2)}
116
- # #
117
- # yield json.dumps(ret).encode() + b"\0"
118
-
119
-
120
- # @app.post("/recognition_stream")
121
- # async def api_recognition_stream(
122
- # to_simple: int = Body(1, description="", embed=True),
123
- # remove_pun: int = Body(0, description="", embed=True),
124
- # language: str = Body("bn", description="", embed=True),
125
- # task: str = Body("transcribe", description="", embed=True),
126
- # audio: UploadFile = File(..., description="")
127
- # ):
128
-
129
- # global model_semaphore
130
- # if language == "None": language = None
131
- # if model_semaphore is None:
132
- # model_semaphore = asyncio.Semaphore(5)
133
- # await model_semaphore.acquire()
134
- # contents = await audio.read()
135
- # data = BytesIO(contents)
136
- # generator = recognition(
137
- # file=data, to_simple=to_simple,
138
- # remove_pun=remove_pun, language=language,
139
- # task=task
140
- # )
141
- # background_tasks = BackgroundTasks()
142
- # background_tasks.add_task(release_model_semaphore)
143
- # return StreamingResponse(generator, background=background_tasks)
144
-
145
-
146
  @app.post("/embed")
147
  async def api_embed(
148
  textA: str = Body("text1", description="", embed=True),
 
8
  import uvicorn
9
  from fastapi import FastAPI, BackgroundTasks, File, Body, UploadFile, Request
10
  from fastapi.responses import StreamingResponse
 
11
  from starlette.staticfiles import StaticFiles
12
  from starlette.templating import Jinja2Templates
13
  from sentence_transformers import SentenceTransformer
 
14
 
15
  # from utils.data_utils import remove_punctuation
16
  # from utils.utils import add_arguments, print_arguments
17
 
18
 
 
 
 
 
 
 
 
 
19
  def print_arguments(args):
20
  print("----------- Configuration Arguments -----------")
21
  for arg, value in vars(args).items():
 
67
  print_arguments(args)
68
 
69
  #
70
+ assert os.path.exists(args.model_path), f"{args.model_path}"
71
  #
72
  if args.use_gpu:
73
  model = SentenceTransformer(args.model_path, device="cuda", compute_type="float16", cache_folder=".")
 
75
  model = SentenceTransformer(args.model_path, device='cpu', cache_folder=".")
76
 
77
 
 
 
 
78
  app = FastAPI(title="embedding Inference")
79
  # app.mount('/static', StaticFiles(directory='static'), name='static')
80
  # templates = Jinja2Templates(directory="templates")
81
  # model_semaphore = None
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  @app.post("/embed")
85
  async def api_embed(
86
  textA: str = Body("text1", description="", embed=True),
download.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import requests
3
+ import os
4
+ from tqdm import tqdm
5
+
6
def download_file(url, path, timeout=30):
    """Stream the file at *url* to the local *path*, showing a tqdm progress bar.

    Args:
        url: Direct download URL of the file.
        path: Destination filesystem path (parent directory must exist).
        timeout: Seconds to wait for the server before aborting (new,
            backward-compatible parameter; previously the request could
            hang forever).

    Raises:
        requests.HTTPError: If the server responds with an error status.
            Previously a 404/500 body (an HTML error page) was silently
            written into the target file.
    """
    # Use the response as a context manager so the connection is released
    # even if writing to disk fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail fast on HTTP errors instead of saving the error page to disk.
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        block_size = 1024  # 1 KiB per chunk
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
        try:
            with open(path, 'wb') as file:
                for data in response.iter_content(block_size):
                    if data:  # filter out keep-alive chunks
                        progress_bar.update(len(data))
                        file.write(data)
        finally:
            # Always close the bar, even on exceptions, so the terminal
            # state is not corrupted.
            progress_bar.close()
18
+
19
def download_model(model_name, destination_folder="models"):
    """Download every file of a Hugging Face model repo to a local folder.

    Files end up under ``{destination_folder}/{model_name}/...``, mirroring
    the repo layout (including subdirectories).

    Args:
        model_name: Repo id on the Hub, e.g. ``"BAAI/bge-small-en-v1.5"``.
        destination_folder: Local root directory for downloaded models.

    Raises:
        requests.HTTPError: If the model-info API call or any file
            download fails.
    """
    # Base URL for resolving individual files of the repo's main branch.
    base_url = f"https://huggingface.co/{model_name}/resolve/main"
    headers = {"User-Agent": "Hugging Face Python"}

    # Ask the Hub API which files ("siblings") the repo contains.
    response = requests.get(
        f"https://huggingface.co/api/models/{model_name}",
        headers=headers,
        timeout=30,  # avoid hanging forever on a stalled API call
    )
    response.raise_for_status()

    # Extract the repo-relative filenames from the response JSON.
    files_to_download = [file["rfilename"] for file in response.json()["siblings"]]

    # Download each file, creating any intermediate directories first:
    # rfilename may contain subdirectories (e.g. "1_Pooling/config.json"
    # in sentence-transformers repos), which the previous single
    # makedirs of the repo root did not cover and made open() fail.
    for file in files_to_download:
        target = os.path.join(destination_folder, model_name, file)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        print(f"Downloading {file}...")
        download_file(f"{base_url}/{file}", target)
38
+
39
if __name__ == "__main__":
    # Reinstate the (previously commented-out) CLI so the target model is
    # configurable. The positional argument is optional (nargs="?") with the
    # original hard-coded repo as default, so `python3 download.py` — as
    # invoked by the Dockerfile — behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Download a Hugging Face model repository."
    )
    parser.add_argument(
        "model_name",
        type=str,
        nargs="?",
        default="BAAI/bge-small-en-v1.5",
        help="Name of the model to download.",
    )
    args = parser.parse_args()

    download_model(args.model_name)