Stanislav commited on
Commit
86886a1
·
1 Parent(s): ce404c8

feat: in run_fastapi changed model-weights saving method

Browse files
Files changed (1) hide show
  1. run_fastapi.py +20 -3
run_fastapi.py CHANGED
@@ -24,13 +24,30 @@ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
24
  app.mount("/static", StaticFiles(directory="static"), name="static")
25
  templates = Jinja2Templates(directory="templates")
26
 
27
- # --- Download checkpoint from model-repo
28
- checkpoint_path = hf_hub_download(repo_id="stkrk/sam-vit-b-checkpoint", filename="sam_vit_b_01ec64.pth")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # --- Model initialization (once)
31
  sam = SamWrapper(
32
  model_type="vit_b",
33
- checkpoint_path=checkpoint_path
34
  )
35
  dino = DinoWrapper()
36
 
 
24
  app.mount("/static", StaticFiles(directory="static"), name="static")
25
  templates = Jinja2Templates(directory="templates")
26
 
27
+ # === Model checkpoint setup ===
28
+ FILENAME = "sam_vit_b_01ec64.pth"
29
+ REPO_ID = "stkrk/sam-vit-b-checkpoint"
30
+ MODEL_DIR = "weights"
31
+ MODEL_PATH = os.path.join(MODEL_DIR, FILENAME)
32
+
33
+ # Download if not exists locally
34
+ if not os.path.exists(MODEL_PATH):
35
+ print(f"Model not found locally. Downloading from {REPO_ID}...")
36
+ cached_path = hf_hub_download(
37
+ repo_id=REPO_ID,
38
+ filename=FILENAME,
39
+ local_dir=MODEL_DIR,
40
+ local_dir_use_symlinks=False
41
+ )
42
+ shutil.copy(cached_path, MODEL_PATH)
43
+ print(f"Model downloaded and copied to {MODEL_PATH}.")
44
+ else:
45
+ print(f"Model already exists at {MODEL_PATH}.")
46
 
47
  # --- Model initialization (once)
48
  sam = SamWrapper(
49
  model_type="vit_b",
50
+ checkpoint_path=MODEL_PATH
51
  )
52
  dino = DinoWrapper()
53