Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,605 Bytes
3df4fd5 b63cd34 3df4fd5 fc3f0ed 1d06ec0 3df4fd5 b63cd34 288103a 3df4fd5 3af4a0e b63cd34 288103a b63cd34 288103a 3df4fd5 318b03c 8c155cc b63cd34 318b03c fc3f0ed 3af4a0e 318b03c 3df4fd5 318b03c 1d06ec0 318b03c 8c155cc 3df4fd5 8c155cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
"""
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
P = ParamSpec('P')
TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=4096, max=8212)
TRANSFORMER_DYNAMIC_SHAPES = {
'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
'img_ids': {0: TRANSFORMER_HIDDEN_DIM},
}
# Option mapping handed to ``spaces.aoti_compile`` as its second argument.
# NOTE(review): the keys look like torch._inductor config option names —
# confirm against the installed torch version before changing any of them.
INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,  # the only option switched off
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    "max_autotune": True,
    "triton.cudagraphs": True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Export, compile and apply ``pipeline.transformer`` in place.

    The pipeline is invoked once with the given arguments while the call made
    to ``pipeline.transformer`` is captured. The captured inputs then drive a
    ``torch.export.export`` of the transformer, the export is compiled with
    ``spaces.aoti_compile`` and the result is applied back onto
    ``pipeline.transformer`` via ``spaces.aoti_apply``.

    Args:
        pipeline: Callable pipeline object exposing a ``transformer`` attribute.
        *args: Positional arguments forwarded to ``pipeline`` for the capture run.
        **kwargs: Keyword arguments forwarded to ``pipeline`` for the capture run.

    Returns:
        None. ``pipeline.transformer`` is modified as a side effect.
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Run the pipeline once so the transformer's call inputs get recorded.
        with spaces.aoti_capture(pipeline.transformer) as captured:
            pipeline(*args, **kwargs)

        def _no_dynamic_dims(_leaf):
            # Tensor/bool leaves default to "no dynamic dimensions".
            return None

        # Build a shape spec mirroring the captured kwargs, then overlay the
        # module-level dynamic-shape entries (in-place merge, same as ``|=``).
        shape_spec = tree_map_only(
            (torch.Tensor, bool),
            _no_dynamic_dims,
            captured.kwargs,
        )
        shape_spec.update(TRANSFORMER_DYNAMIC_SHAPES)

        # Fusion and quantization happen after the capture run and before the
        # export, so the exported graph reflects the modified transformer.
        pipeline.transformer.fuse_qkv_projections()
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

        exported_program = torch.export.export(
            mod=pipeline.transformer,
            args=captured.args,
            kwargs=captured.kwargs,
            dynamic_shapes=shape_spec,
        )
        compiled = spaces.aoti_compile(exported_program, INDUCTOR_CONFIGS)
        return compiled

    compiled_transformer = compile_transformer()
    spaces.aoti_apply(compiled_transformer, pipeline.transformer)
|