MogensR committed
Commit 05676b4 · 1 Parent(s): 5abdee8

Create models/registry.py

Files changed (1)
  1. models/registry.py +563 -0
models/registry.py ADDED
"""
Model registry for BackgroundFX Pro.

Manages available models, versions, and metadata.
"""

import copy
import hashlib
import json
import logging
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple

import yaml

logger = logging.getLogger(__name__)


class ModelStatus(Enum):
    """Model availability status."""
    AVAILABLE = "available"
    DOWNLOADING = "downloading"
    NOT_DOWNLOADED = "not_downloaded"
    CORRUPTED = "corrupted"
    DEPRECATED = "deprecated"


class ModelTask(Enum):
    """Model task types."""
    SEGMENTATION = "segmentation"
    MATTING = "matting"
    ENHANCEMENT = "enhancement"
    DETECTION = "detection"
    BACKGROUND_GEN = "background_generation"


class ModelFramework(Enum):
    """Supported frameworks."""
    PYTORCH = "pytorch"
    ONNX = "onnx"
    TENSORRT = "tensorrt"
    COREML = "coreml"
    TFLITE = "tflite"

@dataclass
class ModelInfo:
    """Model information and metadata."""
    # Basic info
    model_id: str
    name: str
    version: str
    task: ModelTask
    framework: ModelFramework

    # Files and URLs
    url: str
    mirror_urls: List[str] = field(default_factory=list)
    filename: str = ""
    file_size: int = 0
    sha256: Optional[str] = None

    # Model details
    description: str = ""
    author: str = ""
    license: str = ""
    paper_url: Optional[str] = None
    github_url: Optional[str] = None

    # Performance metrics
    accuracy: Optional[float] = None
    speed_fps: Optional[float] = None
    memory_mb: Optional[int] = None

    # Requirements
    min_gpu_memory_gb: float = 0
    min_ram_gb: float = 2
    requires_gpu: bool = False
    supported_platforms: List[str] = field(default_factory=lambda: ["windows", "linux", "macos"])

    # Configuration
    input_size: Optional[Tuple[int, int]] = None
    batch_size: int = 1
    config: Dict[str, Any] = field(default_factory=dict)

    # Status
    status: ModelStatus = ModelStatus.NOT_DOWNLOADED
    local_path: Optional[str] = None
    download_date: Optional[datetime] = None
    last_used: Optional[datetime] = None
    use_count: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary."""
        data = asdict(self)
        # Convert enums to strings
        data['task'] = self.task.value
        data['framework'] = self.framework.value
        data['status'] = self.status.value
        # Convert datetimes to ISO format
        if self.download_date:
            data['download_date'] = self.download_date.isoformat()
        if self.last_used:
            data['last_used'] = self.last_used.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
        """Create from a dictionary produced by to_dict()."""
        data = dict(data)  # copy so the caller's dict is not mutated
        # Convert strings back to enums
        if 'task' in data:
            data['task'] = ModelTask(data['task'])
        if 'framework' in data:
            data['framework'] = ModelFramework(data['framework'])
        if 'status' in data:
            data['status'] = ModelStatus(data['status'])
        # Convert ISO strings back to datetimes
        if 'download_date' in data and data['download_date']:
            data['download_date'] = datetime.fromisoformat(data['download_date'])
        if 'last_used' in data and data['last_used']:
            data['last_used'] = datetime.fromisoformat(data['last_used'])
        return cls(**data)
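
# Serialization round trip (illustrative only; assumes the default registry
# entries defined below):
#     info = ModelRegistry.DEFAULT_MODELS["u2netp"]
#     restored = ModelInfo.from_dict(info.to_dict())
#     assert restored.task is ModelTask.SEGMENTATION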


class ModelRegistry:
    """Central registry for all available models."""

    # Default model definitions
    DEFAULT_MODELS = {
        "rmbg-1.4": ModelInfo(
            model_id="rmbg-1.4",
            name="RMBG v1.4",
            version="1.4",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.ONNX,
            url="https://huggingface.co/briaai/RMBG-1.4/resolve/main/model.onnx",
            filename="rmbg_v1.4.onnx",
            file_size=176_000_000,  # ~176 MB
            sha256="d0c3e8c7d98e32b9c30e0c8f228e3c6d1a5e5c8e9f0a1b2c3d4e5f6a7b8c9d0e1",
            description="State-of-the-art background removal model",
            author="BRIA AI",
            license="BRIA RMBG-1.4 Community License",
            github_url="https://github.com/bria-ai/RMBG-1.4",
            accuracy=0.98,
            speed_fps=30,
            memory_mb=500,
            requires_gpu=False,
            input_size=(1024, 1024)
        ),

        "u2net": ModelInfo(
            model_id="u2net",
            name="U2-Net",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="https://github.com/xuebinqin/U-2-Net/releases/download/v1.0/u2net.pth",
            filename="u2net.pth",
            file_size=176_000_000,
            description="Salient object detection for background removal",
            author="Xuebin Qin et al.",
            license="Apache 2.0",
            paper_url="https://arxiv.org/abs/2005.09007",
            accuracy=0.95,
            speed_fps=20,
            memory_mb=800,
            requires_gpu=True,
            input_size=(320, 320)
        ),

        "u2netp": ModelInfo(
            model_id="u2netp",
            name="U2-Net Lite",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="https://github.com/xuebinqin/U-2-Net/releases/download/v1.0/u2netp.pth",
            filename="u2netp.pth",
            file_size=4_700_000,  # ~4.7 MB
            description="Lightweight version of U2-Net",
            author="Xuebin Qin et al.",
            license="Apache 2.0",
            accuracy=0.92,
            speed_fps=40,
            memory_mb=200,
            requires_gpu=False,
            input_size=(320, 320)
        ),

        "isnet": ModelInfo(
            model_id="isnet",
            name="IS-Net",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="https://github.com/xuebinqin/DIS/releases/download/v1.0/isnet.pth",
            filename="isnet.pth",
            file_size=450_000_000,
            description="Highly accurate salient object detection",
            author="Xuebin Qin et al.",
            license="Apache 2.0",
            paper_url="https://arxiv.org/abs/2203.03041",
            accuracy=0.97,
            speed_fps=15,
            memory_mb=1200,
            requires_gpu=True,
            min_gpu_memory_gb=4,
            input_size=(1024, 1024)
        ),

        "modnet": ModelInfo(
            model_id="modnet",
            name="MODNet",
            version="1.0",
            task=ModelTask.MATTING,
            framework=ModelFramework.PYTORCH,
            url="https://github.com/ZHKKKe/MODNet/releases/download/v1.0/modnet_photographic_portrait_matting.ckpt",
            filename="modnet.ckpt",
            file_size=25_000_000,
            description="Trimap-free portrait matting",
            author="Zhanghan Ke et al.",
            license="CC BY-NC 4.0",
            paper_url="https://arxiv.org/abs/2011.11961",
            github_url="https://github.com/ZHKKKe/MODNet",
            accuracy=0.94,
            speed_fps=25,
            memory_mb=400,
            requires_gpu=False,
            input_size=(512, 512)
        ),

        "robust_video_matting": ModelInfo(
            model_id="robust_video_matting",
            name="Robust Video Matting",
            version="1.0",
            task=ModelTask.MATTING,
            framework=ModelFramework.ONNX,
            url="https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.onnx",
            filename="rvm_mobilenetv3.onnx",
            file_size=14_000_000,
            description="Temporally coherent video matting",
            author="Shanchuan Lin et al.",
            license="GPL-3.0",
            paper_url="https://arxiv.org/abs/2108.11515",
            github_url="https://github.com/PeterL1n/RobustVideoMatting",
            accuracy=0.93,
            speed_fps=30,
            memory_mb=300,
            requires_gpu=False,
            config={"temporal": True, "recurrent": True}
        ),

        "selfie_segmentation": ModelInfo(
            model_id="selfie_segmentation",
            name="MediaPipe Selfie Segmentation",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.TFLITE,
            url="https://storage.googleapis.com/mediapipe-models/selfie_segmentation/selfie_segmentation.tflite",
            filename="selfie_segmentation.tflite",
            file_size=260_000,  # ~260 KB
            description="Ultra-lightweight real-time segmentation",
            author="Google MediaPipe",
            license="Apache 2.0",
            accuracy=0.88,
            speed_fps=60,
            memory_mb=50,
            requires_gpu=False,
            input_size=(256, 256)
        )
    }

    def __init__(self, models_dir: Optional[Path] = None,
                 config_file: Optional[Path] = None):
        """
        Initialize model registry.

        Args:
            models_dir: Directory to store downloaded models
            config_file: Optional config file with custom models
        """
        self.models_dir = models_dir or Path.home() / ".backgroundfx" / "models"
        self.models_dir.mkdir(parents=True, exist_ok=True)

        self.registry_file = self.models_dir / "registry.json"
        self.models: Dict[str, ModelInfo] = {}

        # Load registry
        self._load_registry()

        # Load custom config if provided
        if config_file:
            self._load_custom_config(config_file)

        # Update model status
        self._update_model_status()

    def _load_registry(self):
        """Load model registry from file or create default."""
        if self.registry_file.exists():
            try:
                with open(self.registry_file, 'r') as f:
                    data = json.load(f)
                for model_id, model_data in data.items():
                    self.models[model_id] = ModelInfo.from_dict(model_data)
                logger.info(f"Loaded {len(self.models)} models from registry")
            except Exception as e:
                logger.error(f"Failed to load registry: {e}")
                self._initialize_default_registry()
        else:
            self._initialize_default_registry()

    def _initialize_default_registry(self):
        """Initialize with default models."""
        # Deep copy so per-instance status updates never mutate the
        # class-level DEFAULT_MODELS entries
        self.models = copy.deepcopy(self.DEFAULT_MODELS)
        self._save_registry()
        logger.info("Initialized registry with default models")

    def _save_registry(self):
        """Save registry to file."""
        try:
            data = {
                model_id: model.to_dict()
                for model_id, model in self.models.items()
            }
            with open(self.registry_file, 'w') as f:
                json.dump(data, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save registry: {e}")

    def _load_custom_config(self, config_file: Path):
        """Load custom model configurations."""
        try:
            with open(config_file, 'r') as f:
                if config_file.suffix in ('.yaml', '.yml'):
                    config = yaml.safe_load(f) or {}
                else:
                    config = json.load(f)

            for model_data in config.get('models', []):
                model = ModelInfo.from_dict(model_data)
                self.models[model.model_id] = model
                logger.info(f"Added custom model: {model.name}")

            self._save_registry()

        except Exception as e:
            logger.error(f"Failed to load custom config: {e}")
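
    # A custom config file is expected to carry a top-level 'models' list.
    # Illustrative YAML sketch (field values are placeholders; task and
    # framework must be valid ModelTask/ModelFramework values):
    #
    #     models:
    #       - model_id: my-matting-model
    #         name: My Matting Model
    #         version: "1.0"
    #         task: matting
    #         framework: onnx
    #         url: https://example.com/my_model.onnx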

    def _update_model_status(self):
        """Update status of all models based on local files."""
        for model_id, model in self.models.items():
            model_path = self.models_dir / model.filename

            if model_path.exists():
                # Verify file integrity
                if self._verify_model_file(model_path, model):
                    model.status = ModelStatus.AVAILABLE
                    model.local_path = str(model_path)
                else:
                    model.status = ModelStatus.CORRUPTED
                    logger.warning(f"Model {model_id} file is corrupted")
            else:
                model.status = ModelStatus.NOT_DOWNLOADED
                model.local_path = None

    def _verify_model_file(self, file_path: Path, model: ModelInfo) -> bool:
        """Verify model file integrity."""
        # Check file size; registry sizes are approximate, so allow 1% slack
        # (with a 1 KB floor) rather than an exact match
        if model.file_size > 0:
            actual_size = file_path.stat().st_size
            tolerance = max(1000, model.file_size // 100)
            if abs(actual_size - model.file_size) > tolerance:
                logger.warning(f"Size mismatch for {model.model_id}: "
                               f"expected {model.file_size}, got {actual_size}")
                return False

        # Check SHA256 if available
        if model.sha256:
            try:
                sha256 = self._calculate_sha256(file_path)
                if sha256 != model.sha256:
                    logger.warning(f"SHA256 mismatch for {model.model_id}")
                    return False
            except Exception as e:
                logger.error(f"Failed to verify SHA256: {e}")
                return False

        return True

    def _calculate_sha256(self, file_path: Path) -> str:
        """Calculate SHA256 hash of a file, reading in chunks."""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def register_model(self, model: ModelInfo) -> bool:
        """
        Register a new model.

        Args:
            model: Model information

        Returns:
            True if registered successfully
        """
        try:
            self.models[model.model_id] = model
            self._save_registry()
            logger.info(f"Registered model: {model.name}")
            return True
        except Exception as e:
            logger.error(f"Failed to register model: {e}")
            return False

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """Get model information by ID."""
        return self.models.get(model_id)

    def list_models(self, task: Optional[ModelTask] = None,
                    framework: Optional[ModelFramework] = None,
                    status: Optional[ModelStatus] = None) -> List[ModelInfo]:
        """
        List models with optional filtering.

        Args:
            task: Filter by task type
            framework: Filter by framework
            status: Filter by status

        Returns:
            List of matching models
        """
        models = list(self.models.values())

        if task:
            models = [m for m in models if m.task == task]

        if framework:
            models = [m for m in models if m.framework == framework]

        if status:
            models = [m for m in models if m.status == status]

        return models

    def get_best_model(self, task: ModelTask,
                       prefer_speed: bool = False,
                       require_gpu: Optional[bool] = None) -> Optional[ModelInfo]:
        """
        Get the best available model for a task.

        Args:
            task: Task type
            prefer_speed: Prefer speed over accuracy
            require_gpu: If set, only consider models whose GPU requirement matches

        Returns:
            Best matching model, or None if nothing qualifies
        """
        candidates = self.list_models(task=task, status=ModelStatus.AVAILABLE)

        if require_gpu is not None:
            candidates = [m for m in candidates
                          if m.requires_gpu == require_gpu]

        if not candidates:
            return None

        # Sort by preference
        if prefer_speed:
            candidates.sort(key=lambda m: m.speed_fps or 0, reverse=True)
        else:
            candidates.sort(key=lambda m: m.accuracy or 0, reverse=True)

        return candidates[0]
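
    # Selection sketch (illustrative): on a CPU-only host,
    #     registry.get_best_model(ModelTask.SEGMENTATION,
    #                             prefer_speed=True, require_gpu=False)
    # favors selfie_segmentation (60 fps) over u2netp (40 fps), assuming
    # both files are already downloaded.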

    def update_model_usage(self, model_id: str):
        """Update model usage statistics."""
        if model_id in self.models:
            model = self.models[model_id]
            model.use_count += 1
            model.last_used = datetime.now()
            self._save_registry()

    def get_total_size(self, status: Optional[ModelStatus] = None) -> int:
        """Get total size of models in bytes, optionally filtered by status."""
        models = self.list_models(status=status)
        return sum(m.file_size for m in models)

    def cleanup_unused_models(self, days: int = 30) -> List[str]:
        """
        Remove model files not used within the specified number of days.

        Args:
            days: Days threshold

        Returns:
            List of removed model IDs
        """
        removed = []
        cutoff = datetime.now().timestamp() - (days * 86400)

        for model_id, model in self.models.items():
            if (model.status == ModelStatus.AVAILABLE and
                    model.last_used and
                    model.last_used.timestamp() < cutoff):

                # Delete file
                if model.local_path:
                    try:
                        Path(model.local_path).unlink()
                        model.status = ModelStatus.NOT_DOWNLOADED
                        model.local_path = None
                        removed.append(model_id)
                        logger.info(f"Removed unused model: {model_id}")
                    except Exception as e:
                        logger.error(f"Failed to remove model {model_id}: {e}")

        if removed:
            self._save_registry()

        return removed

    def export_registry(self, output_file: Path):
        """Export registry to a JSON or YAML file."""
        data = {
            'version': '1.0',
            'models': [model.to_dict() for model in self.models.values()]
        }

        with open(output_file, 'w') as f:
            if output_file.suffix in ('.yaml', '.yml'):
                yaml.dump(data, f, default_flow_style=False)
            else:
                json.dump(data, f, indent=2)

    def get_statistics(self) -> Dict[str, Any]:
        """Get registry statistics."""
        total_models = len(self.models)
        downloaded = len([m for m in self.models.values()
                          if m.status == ModelStatus.AVAILABLE])

        task_counts = {}
        for task in ModelTask:
            count = len([m for m in self.models.values() if m.task == task])
            if count > 0:
                task_counts[task.value] = count

        return {
            'total_models': total_models,
            'downloaded_models': downloaded,
            'total_size_mb': self.get_total_size() / (1024 * 1024),
            'downloaded_size_mb': self.get_total_size(ModelStatus.AVAILABLE) / (1024 * 1024),
            'models_by_task': task_counts,
            'most_used': sorted(
                [(m.model_id, m.use_count) for m in self.models.values()],
                key=lambda x: x[1],
                reverse=True
            )[:5]
        }
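
A minimal usage sketch (illustrative, not part of the committed file; it assumes
the module is importable as models.registry and that some model files are
already present in the models directory):

    from pathlib import Path
    from models.registry import ModelRegistry, ModelTask

    # Hypothetical storage location
    registry = ModelRegistry(models_dir=Path("/tmp/bgfx-models"))

    # Everything the registry knows about, downloaded or not
    for info in registry.list_models():
        print(info.model_id, info.status.value)

    # Highest-accuracy segmentation model among locally available files
    best = registry.get_best_model(ModelTask.SEGMENTATION)
    if best is not None:
        registry.update_model_usage(best.model_id)

    print(registry.get_statistics())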