# Author: YiYiXu — commit f580baa ("Update block.py", verified)
from typing import Any, List, Tuple

import torch
from diffusers.image_processor import PipelineImageInput
from diffusers.modular_pipelines import AutoPipelineBlocks, PipelineBlock, PipelineState, SequentialPipelineBlocks
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
class DepthProcessorBlock(PipelineBlock):
    """Pipeline block that turns input image(s) into depth map(s).

    The image is read either from the user-supplied ``image`` input or from an
    upstream block (e.g. ``LoadURL``) that wrote ``image`` into the pipeline
    state. The resulting depth map(s) are written back under the same name,
    ``image``, so downstream blocks consume the depth map transparently.
    """

    @property
    def description(self) -> str:
        return "Extract depth map(s) from `image` using the `depth_processor` component."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # The depth estimator is loaded from the Depth-Anything-V2 (Large) repo.
        return [
            ComponentSpec(
                name="depth_processor",
                type_hint=DepthPreprocessor,
                subfolder="",
                repo="depth-anything/Depth-Anything-V2-Large-hf",
            )
        ]

    @property
    def inputs(self) -> List[InputParam]:
        # User-facing input: the image(s) can be passed directly to the pipeline.
        return [
            InputParam(
                "image",
                PipelineImageInput,
                description="Image(s) to use to extract depth maps",
            )
        ]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        # Same name as the user input so an upstream block (e.g. LoadURL) can
        # supply the image through the pipeline state instead.
        return [
            InputParam(
                "image",
                PipelineImageInput,
                description="Image(s) to use to extract depth maps, can be output from LoadURL block",
            )
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        # The depth map deliberately overwrites `image` in the state.
        return [
            OutputParam(
                "image",
                type_hint=torch.Tensor,
                description="Depth Map(s) of input Image(s)",
            ),
        ]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Run the depth preprocessor on ``image`` and store the result back as ``image``.

        Args:
            pipeline: The running pipeline; must expose ``depth_processor`` and
                ``_execution_device``.
            state: Shared pipeline state holding ``image``.

        Returns:
            The ``(pipeline, state)`` pair, where ``image`` in ``state`` has been
            replaced by the depth map tensor moved to the execution device.

        Raises:
            ValueError: If no ``image`` was passed in or produced upstream.
        """
        block_state = self.get_block_state(state)

        image = block_state.image
        if image is None:
            # Fail early with an actionable message instead of letting the
            # preprocessor choke on `None`.
            raise ValueError(
                "DepthProcessorBlock requires an `image`: pass it directly or "
                "produce it upstream (e.g. by providing a `url`)."
            )

        device = pipeline._execution_device
        # return_type="pt" requests a torch tensor so it can be moved to `device`.
        depth_map = pipeline.depth_processor(image, return_type="pt")
        block_state.image = depth_map.to(device)

        self.add_block_state(state, block_state)
        return pipeline, state
class LoadURL(PipelineBlock):
    """Pipeline block that loads image(s) from URL(s) and exposes them as ``image``."""

    @property
    def description(self) -> str:
        return "Load image(s) from `url` so downstream blocks can consume them as `image`."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "url",
                str,
                description="URL (or list of URLs) of the image(s) to load",
            )
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image",
                type_hint=PipelineImageInput,
                description="Image(s) to use to extract depth maps",
            ),
        ]

    def __call__(self, pipeline, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Load ``url`` into ``image``.

        Args:
            pipeline: The running pipeline; passed through unchanged.
            state: Shared pipeline state holding ``url``.

        Returns:
            The ``(pipeline, state)`` pair, with ``image`` set in ``state``. A
            single URL yields a single image; a list/tuple of URLs yields a list
            of images in the same order.
        """
        block_state = self.get_block_state(state)

        url = block_state.url
        if isinstance(url, (list, tuple)):
            # Batch of URLs: load each one, preserving order.
            block_state.image = [load_image(item) for item in url]
        else:
            block_state.image = load_image(url)

        self.add_block_state(state, block_state)
        return pipeline, state
class AutoLoadURL(AutoPipelineBlocks):
    """Optional URL-loading step that dispatches to ``LoadURL`` when ``url`` is given."""

    # One sub-block keyed on the `url` input. The three lists are parallel:
    # trigger input -> sub-block name -> sub-block class.
    block_trigger_inputs = ["url"]
    block_names = ["url_to_image"]
    block_classes = [LoadURL]

    @property
    def description(self):
        summary = "Run if `url` is provided."
        return summary
class DepthInput(SequentialPipelineBlocks):
    """Depth-conditioning input stage.

    Runs two steps in order: ``AutoLoadURL`` (loads ``image`` from ``url`` when
    a ``url`` is provided), then ``DepthProcessorBlock`` (replaces ``image`` with
    its depth map).
    """

    block_classes = [AutoLoadURL, DepthProcessorBlock]
    block_names = ["load_url", "depth_processor"]

    @property
    def description(self):
        return (
            "Prepare depth conditioning: load `image` from `url` when one is provided, "
            "then replace `image` with its depth map."
        )