Spaces: Running on Zero
""" | |
""" | |
from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
P = ParamSpec('P')

TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)

TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: TRANSFORMER_NUM_FRAMES_DIM,
    },
}

INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):

    # Compilation must run on a ZeroGPU worker; the duration value is indicative
    @spaces.GPU(duration=1500)
    def compile_transformer():

        # Run the pipeline once and record the args/kwargs passed to the transformer
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)

        # Mark every captured tensor/bool leaf as static, then make the frame dim dynamic
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        # Quantize the transformer to float8 (dynamic activations and weights) before export
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        # Build both orientations of the captured hidden_states so landscape and
        # portrait resolutions each get their own exported graph
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            hidden_states_landscape = hidden_states
            hidden_states_portrait = hidden_states_transposed
        else:
            hidden_states_landscape = hidden_states_transposed
            hidden_states_portrait = hidden_states
        exported_landscape = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )
        exported_portrait = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )
        compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
        compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)

        # Avoid weights duplication when serializing back to main process
        compiled_portrait.weights = compiled_landscape.weights

        return compiled_landscape, compiled_portrait

    compiled_landscape, compiled_portrait = compile_transformer()
    def combined_transformer(*args, **kwargs):
        # Dispatch to the compiled graph matching the input's aspect ratio
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            return compiled_landscape(*args, **kwargs)
        else:
            return compiled_portrait(*args, **kwargs)

    # Swap in the compiled transformer while preserving attributes the pipeline still reads
    transformer_config = pipeline.transformer.config
    transformer_dtype = pipeline.transformer.dtype

    pipeline.transformer = combined_transformer
    pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
    pipeline.transformer.dtype = transformer_dtype  # pyright: ignore[reportAttributeAccessIssue]
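
Below is a minimal sketch of how this helper might be wired into a Space's startup code, assuming a diffusers video pipeline. The model id, resolution, and frame count are illustrative placeholders, not values taken from the module above.

# app.py (sketch): all concrete values below are assumptions for illustration
import torch
from diffusers import DiffusionPipeline

from optimization import optimize_pipeline_

pipeline = DiffusionPipeline.from_pretrained('<model-id>', torch_dtype=torch.bfloat16)
pipeline.to('cuda')

# Compile once at startup with representative inputs; optimize_pipeline_ then
# replaces pipeline.transformer with the AoT-compiled, float8-quantized graphs.
optimize_pipeline_(
    pipeline,
    prompt='compilation warm-up prompt',
    height=480,
    width=832,
    num_frames=81,
    num_inference_steps=1,
)

The one warm-up call inside optimize_pipeline_ is what capture_component_call uses to record the transformer's real inputs, so the kwargs passed here should match the shapes the Space will serve at inference time.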