|
from typing import List |
|
import torch |
|
|
|
from diffusers.modular_pipelines import PipelineState, PipelineBlock, SequentialPipelineBlocks, AutoPipelineBlocks |
|
from diffusers.modular_pipelines.modular_pipeline_utils import ( |
|
InputParam, |
|
ComponentSpec, |
|
OutputParam, |
|
) |
|
from diffusers.utils import load_image |
|
from diffusers.image_processor import PipelineImageInput |
|
from image_gen_aux import DepthPreprocessor |
|
|
|
|
|
class DepthProcessorBlock(PipelineBlock):
    """Convert input image(s) into depth map(s) using a Depth-Anything-V2 preprocessor.

    The image is read either from the user-supplied ``image`` input or from the
    ``image`` intermediate (e.g. produced upstream by ``LoadURL``). The resulting
    depth-map tensor replaces ``image`` in the pipeline state.
    """

    @property
    def description(self) -> str:
        return (
            "Extracts depth map(s) from `image` with a Depth-Anything-V2 preprocessor "
            "and writes the depth map back to the `image` intermediate."
        )

    @property
    def expected_components(self):
        # Exposed on the pipeline as `pipeline.depth_processor` (used in __call__);
        # `repo` identifies the checkpoint the component is created from.
        return [
            ComponentSpec(
                name="depth_processor",
                type_hint=DepthPreprocessor,
                subfolder="",
                repo="depth-anything/Depth-Anything-V2-Large-hf",
            )
        ]

    # `image` is declared both as a user input and as an intermediate input so the
    # block works whether the image is passed directly or produced by LoadURL.
    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                PipelineImageInput,
                description="Image(s) to use to extract depth maps",
            )
        ]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                PipelineImageInput,
                description="Image(s) to use to extract depth maps, can be output from LoadURL block",
            )
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image",
                type_hint=torch.Tensor,
                description="Depth Map(s) of input Image(s)",
            ),
        ]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> tuple:
        """Run the depth preprocessor on ``image`` and store the result in ``state``.

        Args:
            pipeline: Pipeline exposing ``depth_processor`` and ``_execution_device``.
            state: Current pipeline state.

        Returns:
            The ``(pipeline, state)`` pair; ``image`` in the state now holds the
            depth-map tensor, moved to the pipeline's execution device.

        Raises:
            ValueError: If no image is available in the state.
        """
        block_state = self.get_block_state(state)

        image = getattr(block_state, "image", None)
        if image is None:
            # Fail early with a clear message instead of an opaque error from
            # inside the depth preprocessor.
            raise ValueError(
                "DepthProcessorBlock: no `image` found; pass `image` directly "
                "or provide a `url` to load it from."
            )

        device = pipeline._execution_device
        depth_map = pipeline.depth_processor(image, return_type="pt")
        block_state.image = depth_map.to(device)

        self.add_block_state(state, block_state)
        return pipeline, state
|
|
|
class LoadURL(PipelineBlock):
    """Load the image at ``url`` and publish it as the ``image`` intermediate."""

    @property
    def description(self) -> str:
        return "Loads the image at `url` and stores it as the `image` intermediate."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "url",
                str,
                description="URL of the image to load",
            )
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image",
                type_hint=PipelineImageInput,
                description="Image(s) to use to extract depth maps",
            ),
        ]

    def __call__(self, pipeline, state: PipelineState) -> tuple:
        """Load ``url`` with ``load_image`` and store the result as ``image``.

        Args:
            pipeline: The running pipeline (passed through unchanged).
            state: Current pipeline state; must contain ``url``.

        Returns:
            The ``(pipeline, state)`` pair with ``image`` set in the state.

        Raises:
            ValueError: If ``url`` is missing.
        """
        block_state = self.get_block_state(state)

        url = getattr(block_state, "url", None)
        if url is None:
            raise ValueError("LoadURL: the `url` input is required.")

        block_state.image = load_image(url)

        self.add_block_state(state, block_state)
        return pipeline, state
|
|
|
class AutoLoadURL(AutoPipelineBlocks):
    """Conditionally run :class:`LoadURL`, only when a ``url`` input is given.

    The single sub-block is registered as ``"url_to_image"`` and is triggered by
    the ``url`` input; when ``url`` is absent this block does nothing.
    """

    # Sub-block registry (presumably paired by position by the framework).
    block_names = ["url_to_image"]
    block_trigger_inputs = ["url"]
    block_classes = [LoadURL]

    @property
    def description(self):
        summary = "Run if `url` is provided."
        return summary
|
|
|
class DepthInput(SequentialPipelineBlocks):
    """Turn an image (or an image URL) into a depth map.

    Sub-blocks, run in list order:
      1. ``load_url`` (:class:`AutoLoadURL`): loads ``image`` from ``url`` when a
         URL is provided; skipped otherwise.
      2. ``depth_processor`` (:class:`DepthProcessorBlock`): replaces ``image``
         with its depth map.
    """

    block_classes = [AutoLoadURL, DepthProcessorBlock]
    block_names = ["load_url", "depth_processor"]

    @property
    def description(self):
        return (
            "Depth input step: loads `image` from `url` when provided, "
            "then converts `image` into a depth map."
        )
|
|
|
|
|
|