Stanislav commited on
Commit
24ec3e7
·
1 Parent(s): 7d6bd1b

feat: changed tmp dir for all

Browse files
Files changed (1) hide show
  1. run_fastapi.py +23 -12
run_fastapi.py CHANGED
@@ -15,31 +15,42 @@ from models.dino import DinoWrapper
15
 
16
  from huggingface_hub import hf_hub_download
17
 
 
18
  print("WRITE to ./weights:", os.access("weights", os.W_OK))
19
  print("WRITE to /tmp:", os.access("/tmp", os.W_OK))
20
 
21
- # --- Init app and database
 
 
 
 
 
 
 
 
 
 
 
 
22
  app = FastAPI()
23
  init_db()
24
 
25
- # --- Static and templates
26
- app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
27
  app.mount("/static", StaticFiles(directory="static"), name="static")
28
  templates = Jinja2Templates(directory="templates")
29
 
30
- # === Model checkpoint setup ===
31
  FILENAME = "sam_vit_b_01ec64.pth"
32
  REPO_ID = "stkrk/sam-vit-b-checkpoint"
33
- MODEL_DIR = "weights"
34
- MODEL_PATH = os.path.join(MODEL_DIR, FILENAME)
35
 
36
- # Download if not exists locally
37
  if not os.path.exists(MODEL_PATH):
38
  print(f"Model not found locally. Downloading from {REPO_ID}...")
39
  cached_path = hf_hub_download(
40
  repo_id=REPO_ID,
41
  filename=FILENAME,
42
- cache_dir="./weights/.hf_cache",
43
  local_dir_use_symlinks=False
44
  )
45
  shutil.copy(cached_path, MODEL_PATH)
@@ -47,11 +58,11 @@ if not os.path.exists(MODEL_PATH):
47
  else:
48
  print(f"Model already exists at {MODEL_PATH}.")
49
 
50
- # --- Model initialization (once)
51
  sam = SamWrapper(
52
- model_type="vit_b",
53
- checkpoint_path=MODEL_PATH
54
- )
55
  dino = DinoWrapper()
56
 
57
 
 
15
 
16
  from huggingface_hub import hf_hub_download
17
 
18
+ # Check write permissions
19
  print("WRITE to ./weights:", os.access("weights", os.W_OK))
20
  print("WRITE to /tmp:", os.access("/tmp", os.W_OK))
21
 
22
+ # --- Set base directory depending on environment
23
+ BASE_DIR = "/tmp" if os.access("/tmp", os.W_OK) else "."
24
+
25
+ WEIGHTS_DIR = os.path.join(BASE_DIR, "weights")
26
+ UPLOADS_DIR = os.path.join(BASE_DIR, "uploads")
27
+ OUTPUTS_DIR = os.path.join(BASE_DIR, "outputs")
28
+
29
+ # Create directories if not exist
30
+ os.makedirs(WEIGHTS_DIR, exist_ok=True)
31
+ os.makedirs(UPLOADS_DIR, exist_ok=True)
32
+ os.makedirs(OUTPUTS_DIR, exist_ok=True)
33
+
34
+ # --- Initialize FastAPI and database
35
  app = FastAPI()
36
  init_db()
37
 
38
+ # --- Mount static files and templates
39
+ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs") # still serve static from project root
40
  app.mount("/static", StaticFiles(directory="static"), name="static")
41
  templates = Jinja2Templates(directory="templates")
42
 
43
+ # === Download and load model checkpoint ===
44
  FILENAME = "sam_vit_b_01ec64.pth"
45
  REPO_ID = "stkrk/sam-vit-b-checkpoint"
46
+ MODEL_PATH = os.path.join(WEIGHTS_DIR, FILENAME)
 
47
 
 
48
  if not os.path.exists(MODEL_PATH):
49
  print(f"Model not found locally. Downloading from {REPO_ID}...")
50
  cached_path = hf_hub_download(
51
  repo_id=REPO_ID,
52
  filename=FILENAME,
53
+ cache_dir=WEIGHTS_DIR,
54
  local_dir_use_symlinks=False
55
  )
56
  shutil.copy(cached_path, MODEL_PATH)
 
58
  else:
59
  print(f"Model already exists at {MODEL_PATH}.")
60
 
61
+ # --- Initialize models
62
  sam = SamWrapper(
63
+ model_type="vit_b",
64
+ checkpoint_path=MODEL_PATH
65
+ )
66
  dino = DinoWrapper()
67
 
68