Stanislav commited on
Commit
08a4a7f
·
1 Parent(s): 3d9f0a5

feat: IMPORTANT changes to write DINO MODEL

Browse files
Files changed (2) hide show
  1. models/dino.py +12 -3
  2. run_fastapi.py +32 -2
models/dino.py CHANGED
@@ -16,7 +16,7 @@ class DinoWrapper:
16
  Wrapper for Grounding DINO model for text-prompt-based object detection.
17
  """
18
 
19
- def __init__(self, model_name="IDEA-Research/grounding-dino-base", device=None):
20
  """
21
  Initialize the Grounding DINO model.
22
 
@@ -27,8 +27,17 @@ class DinoWrapper:
27
  device = "cpu"
28
 
29
  self.device = device
30
- self.model = GroundingDinoForObjectDetection.from_pretrained(model_name).to(self.device)
31
- self.processor = GroundingDinoProcessor.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  def predict_boxes(self, image, prompt, box_threshold=0.15, text_threshold=0.18):
 
16
  Wrapper for Grounding DINO model for text-prompt-based object detection.
17
  """
18
 
19
+ def __init__(self, model_dir, device=None):
20
  """
21
  Initialize the Grounding DINO model.
22
 
 
27
  device = "cpu"
28
 
29
  self.device = device
30
+
31
+ self.model = GroundingDinoForObjectDetection.from_pretrained(
32
+ pretrained_model_name_or_path=model_dir,
33
+ local_files_only=True,
34
+ use_safetensors=True
35
+ ).to(self.device)
36
+
37
+ self.processor = GroundingDinoProcessor.from_pretrained(
38
+ pretrained_model_name_or_path=model_dir,
39
+ local_files_only=True
40
+ )
41
 
42
 
43
  def predict_boxes(self, image, prompt, box_threshold=0.15, text_threshold=0.18):
run_fastapi.py CHANGED
@@ -41,7 +41,7 @@ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs") # still
41
  app.mount("/static", StaticFiles(directory="static"), name="static")
42
  templates = Jinja2Templates(directory="templates")
43
 
44
- # === Download and load model checkpoint ===
45
  FILENAME = "sam_vit_b_01ec64.pth"
46
  REPO_ID = "stkrk/sam-vit-b-checkpoint"
47
  MODEL_PATH = os.path.join(WEIGHTS_DIR, FILENAME)
@@ -59,12 +59,42 @@ if not os.path.exists(MODEL_PATH):
59
  else:
60
  print(f"Model already exists at {MODEL_PATH}.")
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  # --- Initialize models
63
  sam = SamWrapper(
64
  model_type="vit_b",
65
  checkpoint_path=MODEL_PATH
66
  )
67
- dino = DinoWrapper()
68
 
69
 
70
  @app.get("/")
 
41
  app.mount("/static", StaticFiles(directory="static"), name="static")
42
  templates = Jinja2Templates(directory="templates")
43
 
44
+ # === Download and load model SAM checkpoint ===
45
  FILENAME = "sam_vit_b_01ec64.pth"
46
  REPO_ID = "stkrk/sam-vit-b-checkpoint"
47
  MODEL_PATH = os.path.join(WEIGHTS_DIR, FILENAME)
 
59
  else:
60
  print(f"Model already exists at {MODEL_PATH}.")
61
 
62
+ # === Download and prepare Grounding DINO checkpoint ===
63
+ DINO_REPO_ID = "stkrk/dino_base"
64
+ DINO_DIR = os.path.join(WEIGHTS_DIR, "grounding_dino_base")
65
+ os.makedirs(DINO_DIR, exist_ok=True)
66
+
67
+ DINO_FILES = [
68
+ "config.json",
69
+ "model.safetensors",
70
+ "preprocessor_config.json",
71
+ "special_tokens_map.json",
72
+ "tokenizer_config.json",
73
+ "tokenizer.json",
74
+ "vocab.txt"
75
+ ]
76
+
77
+ for filename in DINO_FILES:
78
+ target_path = os.path.join(DINO_DIR, filename)
79
+ if not os.path.exists(target_path):
80
+ print(f"Downloading {filename} from {DINO_REPO_ID}...")
81
+ hf_hub_download(
82
+ repo_id=DINO_REPO_ID,
83
+ filename=filename,
84
+ cache_dir=DINO_DIR,
85
+ local_dir=DINO_DIR,
86
+ local_dir_use_symlinks=False
87
+ )
88
+ else:
89
+ print(f"{filename} already exists in {DINO_DIR}.")
90
+
91
+
92
  # --- Initialize models
93
  sam = SamWrapper(
94
  model_type="vit_b",
95
  checkpoint_path=MODEL_PATH
96
  )
97
+ dino = DinoWrapper(model_dir=DINO_DIR)
98
 
99
 
100
  @app.get("/")