Spaces:
Runtime error
Runtime error
| from PIL import Image | |
| from modules.control.util import HWC3, resize_image | |
| from modules import devices | |
| from modules.shared import opts | |
| from .marigold_pipeline import MarigoldPipeline | |
class MarigoldDetector:
    """Depth-map control processor backed by a ``MarigoldPipeline``.

    Wraps an already-loaded pipeline and handles three things around each
    inference call:

    * moving the model to ``devices.device`` / ``devices.dtype``;
    * optionally offloading it back to CPU afterwards (``opts.control_move_processor``);
    * choosing between the colorized and the raw depth output.
    """

    def __init__(self, model):
        """Wrap an existing pipeline.

        Args:
            model: The ``MarigoldPipeline`` instance used for inference.
        """
        self.model: MarigoldPipeline = model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, cache_dir=None, **load_config):
        """Load a ``MarigoldPipeline`` and return a detector wrapping it.

        Args:
            pretrained_model_or_path: Model id or local path, forwarded to
                ``MarigoldPipeline.from_pretrained``.
            cache_dir: Optional cache directory, forwarded unchanged.
            **load_config: Extra keyword arguments, forwarded unchanged.

        Returns:
            A new ``MarigoldDetector`` (or subclass) instance.
        """
        # This is an alternate constructor (first parameter is ``cls``).
        # Without ``@classmethod``, a call on the class would bind the model
        # path to ``cls`` and the call would fail.
        model = MarigoldPipeline.from_pretrained(
            pretrained_model_or_path, cache_dir=cache_dir, **load_config
        )
        return cls(model)

    def to(self, device):
        """Move the wrapped pipeline to ``device`` and return ``self`` for chaining."""
        self.model.to(device)
        return self

    def __call__(
        self,
        input_image: "Image.Image",
        denoising_steps: int = 10,
        ensemble_size: int = 10,
        processing_res: int = 768,
        match_input_res: bool = True,
        color_map: str = "Spectral",
        output_type=None,
    ):
        """Run depth estimation on ``input_image``.

        Args:
            input_image: Image passed unchanged as the pipeline's first argument.
            denoising_steps: Forwarded to the pipeline.
            ensemble_size: Forwarded to the pipeline.
            processing_res: Forwarded to the pipeline.
            match_input_res: Forwarded to the pipeline.
            color_map: Colormap name for the colorized output. The string
                ``'None'`` (or ``None``) selects the raw depth array instead;
                the pipeline is then still called with ``'Spectral'``.
            output_type: ``"pil"`` converts the result with ``Image.fromarray``;
                any other value returns it unchanged.

        Returns:
            The pipeline result's ``depth_colored`` or ``depth_np``, optionally
            converted to a PIL image.
        """
        # The string 'None' presumably comes from a UI dropdown (TODO confirm);
        # a real ``None`` is treated the same way.
        use_raw_depth = color_map is None or color_map == 'None'

        self.model.to(device=devices.device, dtype=devices.dtype)
        try:
            res = self.model(
                input_image,
                denoising_steps=denoising_steps,
                ensemble_size=ensemble_size,
                processing_res=processing_res,
                match_input_res=match_input_res,
                # The pipeline always receives a concrete colormap name, even
                # when only the raw depth array will be returned.
                color_map='Spectral' if use_raw_depth else color_map,
                batch_size=1,
                show_progress_bar=True,
            )
        finally:
            # Offload even if inference raises, so a failed run does not leave
            # the model sitting on ``devices.device``.
            if opts.control_move_processor:
                self.model.to('cpu')

        depth_map = res.depth_np if use_raw_depth else res.depth_colored
        if output_type == "pil":
            return Image.fromarray(depth_map)
        return depth_map