File size: 1,605 Bytes
3df4fd5
 
 
b63cd34
 
 
 
3df4fd5
 
fc3f0ed
1d06ec0
 
3df4fd5
 
b63cd34
288103a
3df4fd5
3af4a0e
 
 
 
 
 
 
b63cd34
 
 
 
 
 
 
 
288103a
 
b63cd34
288103a
3df4fd5
 
318b03c
8c155cc
b63cd34
318b03c
fc3f0ed
3af4a0e
318b03c
3df4fd5
318b03c
1d06ec0
 
318b03c
 
 
 
 
 
 
8c155cc
3df4fd5
8c155cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""Ahead-of-time (AoT) compilation of a pipeline's ``transformer`` submodule.

``optimize_pipeline_`` performs the following steps:

* runs the pipeline once under ``spaces.aoti_capture`` to record the inputs
  the transformer receives;
* fuses the transformer's QKV projections;
* quantizes it with torchao's float8 dynamic-activation / float8-weight
  config;
* exports it with ``torch.export`` (one sequence dimension left dynamic, see
  ``TRANSFORMER_DYNAMIC_SHAPES``);
* compiles the export with ``spaces.aoti_compile`` using ``INDUCTOR_CONFIGS``;
* hands the compiled artifact to ``spaces.aoti_apply``.
"""

from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig


P = ParamSpec('P')


TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=4096, max=8212)

TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
    'img_ids': {0: TRANSFORMER_HIDDEN_DIM},
}

INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """AoT-compile ``pipeline.transformer`` and install the compiled artifact.

    The example call ``pipeline(*args, **kwargs)`` is executed once, only to
    record the inputs the transformer receives. The transformer module itself
    is modified: its QKV projections are fused and it is float8-quantized with
    torchao's in-place ``quantize_``.

    Args:
        pipeline: Callable pipeline object exposing a ``transformer``
            attribute.
        *args: Positional arguments for the example pipeline call.
        **kwargs: Keyword arguments for the example pipeline call.

    Returns:
        None. The compiled transformer is handed to ``spaces.aoti_apply``.
    """

    # Compilation runs inside a ZeroGPU allocation; duration=1500 is the
    # requested time budget (seconds, per the `spaces` documentation).
    @spaces.GPU(duration=1500)
    def compile_transformer():
        # 1) Run the pipeline once so the transformer's real call arguments
        #    (positional and keyword) are recorded.
        with spaces.aoti_capture(pipeline.transformer) as captured:
            pipeline(*args, **kwargs)

        # 2) Every captured tensor/bool keyword input becomes None (static in
        #    torch.export terms); the module-level overrides then mark the
        #    sequence axes of `hidden_states` / `img_ids` as dynamic.
        static_shapes = tree_map_only(
            (torch.Tensor, bool),
            lambda _leaf: None,
            captured.kwargs,
        )
        dynamic_shapes = static_shapes | TRANSFORMER_DYNAMIC_SHAPES

        # 3) Graph rewrites happen only after capture. NOTE(review): export
        #    below reuses the pre-fusion inputs, which assumes fusion and
        #    quantization leave the transformer's call signature unchanged.
        pipeline.transformer.fuse_qkv_projections()
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

        # 4) Export with the recorded inputs, then compile with Inductor.
        exported_program = torch.export.export(
            mod=pipeline.transformer,
            args=captured.args,
            kwargs=captured.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        return spaces.aoti_compile(exported_program, INDUCTOR_CONFIGS)

    compiled = compile_transformer()
    spaces.aoti_apply(compiled, pipeline.transformer)