Single step joint decision
JointDecision.mlmodelc/analytics/coremldata.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f1183ba213bb94a918c8d2cad19ab045320618f97f6ca662245b3936d7b090f7
 size 243
JointDecision.mlmodelc/coremldata.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:e2c6752f1c8cf2d3f6f26ec93195c9bfa759ad59edf9f806696a138154f96f11
+size 534
JointDecision.mlmodelc/metadata.json
CHANGED
@@ -1,15 +1,15 @@
 [
 {
 "metadataOutputVersion" : "3.0",
-"shortDescription" : "Parakeet joint
+"shortDescription" : "Parakeet single-step joint decision (current frame)",
 "outputSchema" : [
 {
 "hasShapeFlexibility" : "0",
 "isOptional" : "0",
 "dataType" : "Int32",
-"formattedType" : "MultiArray (Int32 1 ×
+"formattedType" : "MultiArray (Int32 1 × 1 × 1)",
 "shortDescription" : "",
-"shape" : "[1,
+"shape" : "[1, 1, 1]",
 "name" : "token_id",
 "type" : "MultiArray"
 },
@@ -17,9 +17,9 @@
 "hasShapeFlexibility" : "0",
 "isOptional" : "0",
 "dataType" : "Float32",
-"formattedType" : "MultiArray (Float32 1 ×
+"formattedType" : "MultiArray (Float32 1 × 1 × 1)",
 "shortDescription" : "",
-"shape" : "[1,
+"shape" : "[1, 1, 1]",
 "name" : "token_prob",
 "type" : "MultiArray"
 },
@@ -27,9 +27,9 @@
 "hasShapeFlexibility" : "0",
 "isOptional" : "0",
 "dataType" : "Int32",
-"formattedType" : "MultiArray (Int32 1 ×
+"formattedType" : "MultiArray (Int32 1 × 1 × 1)",
 "shortDescription" : "",
-"shape" : "[1,
+"shape" : "[1, 1, 1]",
 "name" : "duration",
 "type" : "MultiArray"
 }
@@ -74,10 +74,10 @@
 "hasShapeFlexibility" : "0",
 "isOptional" : "0",
 "dataType" : "Float32",
-"formattedType" : "MultiArray (Float32 1 × 1024 ×
+"formattedType" : "MultiArray (Float32 1 × 1024 × 1)",
 "shortDescription" : "",
-"shape" : "[1, 1024,
-"name" : "
+"shape" : "[1, 1024, 1]",
+"name" : "encoder_step",
 "type" : "MultiArray"
 },
 {
@@ -87,7 +87,7 @@
 "formattedType" : "MultiArray (Float32 1 × 640 × 1)",
 "shortDescription" : "",
 "shape" : "[1, 640, 1]",
-"name" : "
+"name" : "decoder_step",
 "type" : "MultiArray"
 }
 ],
@@ -97,7 +97,7 @@
 "com.github.apple.coremltools.version" : "9.0b1",
 "com.github.apple.coremltools.source_dialect" : "TorchScript"
 },
-"generatedClassName" : "
+"generatedClassName" : "parakeet_joint_decision_single_step",
 "method" : "predict"
 }
 ]
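The updated metadata names the raw inputs explicitly (encoder_step, decoder_step) and pins each output to a [1, 1, 1] array. A minimal usage sketch from Python follows, assuming a coremltools version that ships CompiledMLModel; the model path and the zero-filled dummy inputs are placeholders, and only the feature names, shapes, and dtypes come from the metadata above.

# Hypothetical sketch: run the compiled model once with dummy inputs (macOS only).
import numpy as np
import coremltools as ct

model = ct.models.CompiledMLModel("JointDecision.mlmodelc")  # placeholder path

out = model.predict({
    "encoder_step": np.zeros((1, 1024, 1), dtype=np.float32),  # one encoder frame
    "decoder_step": np.zeros((1, 640, 1), dtype=np.float32),   # one decoder state
})

# Each output is a (1, 1, 1) MultiArray per the schema above.
token_id = int(out["token_id"].ravel()[0])
token_prob = float(out["token_prob"].ravel()[0])
duration = int(out["duration"].ravel()[0])
print(token_id, token_prob, duration)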
JointDecision.mlmodelc/model.mil
CHANGED
@@ -1,58 +1,58 @@
 program(1.0)
 [buildInfo = dict<tensor<string, []>, tensor<string, []>>({{"coremlc-component-MIL", "3500.14.1"}, {"coremlc-version", "3500.32.1"}, {"coremltools-component-torch", "2.7.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "9.0b1"}})]
 {
-func main<ios17>(tensor<fp32, [1, 640, 1]>
+func main<ios17>(tensor<fp32, [1, 640, 1]> decoder_step, tensor<fp32, [1, 1024, 1]> encoder_step) {
 tensor<int32, [3]> input_1_perm_0 = const()[name = tensor<string, []>("input_1_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
-tensor<string, []>
+tensor<string, []> encoder_step_to_fp16_dtype_0 = const()[name = tensor<string, []>("encoder_step_to_fp16_dtype_0"), val = tensor<string, []>("fp16")];
 tensor<int32, [3]> input_3_perm_0 = const()[name = tensor<string, []>("input_3_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
-tensor<string, []>
+tensor<string, []> decoder_step_to_fp16_dtype_0 = const()[name = tensor<string, []>("decoder_step_to_fp16_dtype_0"), val = tensor<string, []>("fp16")];
 tensor<fp16, [640, 1024]> joint_module_enc_weight_to_fp16 = const()[name = tensor<string, []>("joint_module_enc_weight_to_fp16"), val = tensor<fp16, [640, 1024]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(64)))];
 tensor<fp16, [640]> joint_module_enc_bias_to_fp16 = const()[name = tensor<string, []>("joint_module_enc_bias_to_fp16"), val = tensor<fp16, [640]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(1310848)))];
-tensor<fp16, [1, 1024,
-tensor<fp16, [1,
-tensor<fp16, [1,
+tensor<fp16, [1, 1024, 1]> encoder_step_to_fp16 = cast(dtype = encoder_step_to_fp16_dtype_0, x = encoder_step)[name = tensor<string, []>("cast_3")];
+tensor<fp16, [1, 1, 1024]> input_1_cast_fp16 = transpose(perm = input_1_perm_0, x = encoder_step_to_fp16)[name = tensor<string, []>("transpose_1")];
+tensor<fp16, [1, 1, 640]> linear_0_cast_fp16 = linear(bias = joint_module_enc_bias_to_fp16, weight = joint_module_enc_weight_to_fp16, x = input_1_cast_fp16)[name = tensor<string, []>("linear_0_cast_fp16")];
 tensor<fp16, [640, 640]> joint_module_pred_weight_to_fp16 = const()[name = tensor<string, []>("joint_module_pred_weight_to_fp16"), val = tensor<fp16, [640, 640]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(1312192)))];
 tensor<fp16, [640]> joint_module_pred_bias_to_fp16 = const()[name = tensor<string, []>("joint_module_pred_bias_to_fp16"), val = tensor<fp16, [640]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(2131456)))];
-tensor<fp16, [1, 640, 1]>
-tensor<fp16, [1, 1, 640]> input_3_cast_fp16 = transpose(perm = input_3_perm_0, x =
+tensor<fp16, [1, 640, 1]> decoder_step_to_fp16 = cast(dtype = decoder_step_to_fp16_dtype_0, x = decoder_step)[name = tensor<string, []>("cast_2")];
+tensor<fp16, [1, 1, 640]> input_3_cast_fp16 = transpose(perm = input_3_perm_0, x = decoder_step_to_fp16)[name = tensor<string, []>("transpose_0")];
 tensor<fp16, [1, 1, 640]> linear_1_cast_fp16 = linear(bias = joint_module_pred_bias_to_fp16, weight = joint_module_pred_weight_to_fp16, x = input_3_cast_fp16)[name = tensor<string, []>("linear_1_cast_fp16")];
 tensor<int32, [1]> var_23_axes_0 = const()[name = tensor<string, []>("op_23_axes_0"), val = tensor<int32, [1]>([2])];
-tensor<fp16, [1,
+tensor<fp16, [1, 1, 1, 640]> var_23_cast_fp16 = expand_dims(axes = var_23_axes_0, x = linear_0_cast_fp16)[name = tensor<string, []>("op_23_cast_fp16")];
 tensor<int32, [1]> var_24_axes_0 = const()[name = tensor<string, []>("op_24_axes_0"), val = tensor<int32, [1]>([1])];
 tensor<fp16, [1, 1, 1, 640]> var_24_cast_fp16 = expand_dims(axes = var_24_axes_0, x = linear_1_cast_fp16)[name = tensor<string, []>("op_24_cast_fp16")];
-tensor<fp16, [1,
-tensor<fp16, [1,
+tensor<fp16, [1, 1, 1, 640]> input_5_cast_fp16 = add(x = var_23_cast_fp16, y = var_24_cast_fp16)[name = tensor<string, []>("input_5_cast_fp16")];
+tensor<fp16, [1, 1, 1, 640]> input_7_cast_fp16 = relu(x = input_5_cast_fp16)[name = tensor<string, []>("input_7_cast_fp16")];
 tensor<fp16, [1030, 640]> joint_module_joint_net_2_weight_to_fp16 = const()[name = tensor<string, []>("joint_module_joint_net_2_weight_to_fp16"), val = tensor<fp16, [1030, 640]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(2132800)))];
 tensor<fp16, [1030]> joint_module_joint_net_2_bias_to_fp16 = const()[name = tensor<string, []>("joint_module_joint_net_2_bias_to_fp16"), val = tensor<fp16, [1030]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(3451264)))];
-tensor<fp16, [1,
+tensor<fp16, [1, 1, 1, 1030]> linear_2_cast_fp16 = linear(bias = joint_module_joint_net_2_bias_to_fp16, weight = joint_module_joint_net_2_weight_to_fp16, x = input_7_cast_fp16)[name = tensor<string, []>("linear_2_cast_fp16")];
 tensor<int32, [4]> token_logits_begin_0 = const()[name = tensor<string, []>("token_logits_begin_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
-tensor<int32, [4]> token_logits_end_0 = const()[name = tensor<string, []>("token_logits_end_0"), val = tensor<int32, [4]>([1,
+tensor<int32, [4]> token_logits_end_0 = const()[name = tensor<string, []>("token_logits_end_0"), val = tensor<int32, [4]>([1, 1, 1, 1025])];
 tensor<bool, [4]> token_logits_end_mask_0 = const()[name = tensor<string, []>("token_logits_end_mask_0"), val = tensor<bool, [4]>([true, true, true, false])];
-tensor<fp16, [1,
+tensor<fp16, [1, 1, 1, 1025]> token_logits_cast_fp16 = slice_by_index(begin = token_logits_begin_0, end = token_logits_end_0, end_mask = token_logits_end_mask_0, x = linear_2_cast_fp16)[name = tensor<string, []>("token_logits_cast_fp16")];
 tensor<int32, [4]> duration_logits_begin_0 = const()[name = tensor<string, []>("duration_logits_begin_0"), val = tensor<int32, [4]>([0, 0, 0, 1025])];
-tensor<int32, [4]> duration_logits_end_0 = const()[name = tensor<string, []>("duration_logits_end_0"), val = tensor<int32, [4]>([1,
+tensor<int32, [4]> duration_logits_end_0 = const()[name = tensor<string, []>("duration_logits_end_0"), val = tensor<int32, [4]>([1, 1, 1, 1030])];
 tensor<bool, [4]> duration_logits_end_mask_0 = const()[name = tensor<string, []>("duration_logits_end_mask_0"), val = tensor<bool, [4]>([true, true, true, true])];
-tensor<fp16, [1,
+tensor<fp16, [1, 1, 1, 5]> duration_logits_cast_fp16 = slice_by_index(begin = duration_logits_begin_0, end = duration_logits_end_0, end_mask = duration_logits_end_mask_0, x = linear_2_cast_fp16)[name = tensor<string, []>("duration_logits_cast_fp16")];
 tensor<int32, []> var_43_axis_0 = const()[name = tensor<string, []>("op_43_axis_0"), val = tensor<int32, []>(-1)];
 tensor<bool, []> var_43_keep_dims_0 = const()[name = tensor<string, []>("op_43_keep_dims_0"), val = tensor<bool, []>(false)];
 tensor<string, []> var_43_output_dtype_0 = const()[name = tensor<string, []>("op_43_output_dtype_0"), val = tensor<string, []>("int32")];
-tensor<int32, [1,
+tensor<int32, [1, 1, 1]> token_id = reduce_argmax(axis = var_43_axis_0, keep_dims = var_43_keep_dims_0, output_dtype = var_43_output_dtype_0, x = token_logits_cast_fp16)[name = tensor<string, []>("op_43_cast_fp16")];
 tensor<int32, []> var_49 = const()[name = tensor<string, []>("op_49"), val = tensor<int32, []>(-1)];
-tensor<fp16, [1,
+tensor<fp16, [1, 1, 1, 1025]> token_probs_all_cast_fp16 = softmax(axis = var_49, x = token_logits_cast_fp16)[name = tensor<string, []>("token_probs_all_cast_fp16")];
 tensor<int32, [1]> var_58_axes_0 = const()[name = tensor<string, []>("op_58_axes_0"), val = tensor<int32, [1]>([-1])];
-tensor<int32, [1,
+tensor<int32, [1, 1, 1, 1]> var_58 = expand_dims(axes = var_58_axes_0, x = token_id)[name = tensor<string, []>("op_58")];
 tensor<int32, []> var_59 = const()[name = tensor<string, []>("op_59"), val = tensor<int32, []>(-1)];
 tensor<bool, []> var_61_validate_indices_0 = const()[name = tensor<string, []>("op_61_validate_indices_0"), val = tensor<bool, []>(false)];
 tensor<string, []> var_58_to_int16_dtype_0 = const()[name = tensor<string, []>("op_58_to_int16_dtype_0"), val = tensor<string, []>("int16")];
-tensor<int16, [1,
-tensor<fp16, [1,
+tensor<int16, [1, 1, 1, 1]> var_58_to_int16 = cast(dtype = var_58_to_int16_dtype_0, x = var_58)[name = tensor<string, []>("cast_1")];
+tensor<fp16, [1, 1, 1, 1]> var_61_cast_fp16_cast_int16 = gather_along_axis(axis = var_59, indices = var_58_to_int16, validate_indices = var_61_validate_indices_0, x = token_probs_all_cast_fp16)[name = tensor<string, []>("op_61_cast_fp16_cast_int16")];
 tensor<int32, [1]> var_63_axes_0 = const()[name = tensor<string, []>("op_63_axes_0"), val = tensor<int32, [1]>([-1])];
-tensor<fp16, [1,
+tensor<fp16, [1, 1, 1]> var_63_cast_fp16 = squeeze(axes = var_63_axes_0, x = var_61_cast_fp16_cast_int16)[name = tensor<string, []>("op_63_cast_fp16")];
 tensor<string, []> var_63_cast_fp16_to_fp32_dtype_0 = const()[name = tensor<string, []>("op_63_cast_fp16_to_fp32_dtype_0"), val = tensor<string, []>("fp32")];
 tensor<int32, []> var_66_axis_0 = const()[name = tensor<string, []>("op_66_axis_0"), val = tensor<int32, []>(-1)];
 tensor<bool, []> var_66_keep_dims_0 = const()[name = tensor<string, []>("op_66_keep_dims_0"), val = tensor<bool, []>(false)];
 tensor<string, []> var_66_output_dtype_0 = const()[name = tensor<string, []>("op_66_output_dtype_0"), val = tensor<string, []>("int32")];
-tensor<int32, [1,
-tensor<fp32, [1,
+tensor<int32, [1, 1, 1]> duration = reduce_argmax(axis = var_66_axis_0, keep_dims = var_66_keep_dims_0, output_dtype = var_66_output_dtype_0, x = duration_logits_cast_fp16)[name = tensor<string, []>("op_66_cast_fp16")];
+tensor<fp32, [1, 1, 1]> token_prob = cast(dtype = var_63_cast_fp16_to_fp32_dtype_0, x = var_63_cast_fp16)[name = tensor<string, []>("cast_0")];
 } -> (token_id, token_prob, duration);
 }
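For orientation, the MIL program above projects the 1024-dim encoder frame and the 640-dim decoder state into a shared 640-dim joint space, adds them, applies ReLU, then runs a final 640-to-1030 linear whose output is split into 1025 token logits and 5 duration logits; token_id and duration are argmaxes, and token_prob is the softmax probability of the chosen token. Below is a rough NumPy restatement of that computation, assuming the fp16 blobs in weights/weight.bin have been extracted into arrays (the argument names are placeholders, not part of the model).

# Sketch of the single-step joint decision in NumPy (weights are stand-ins).
import numpy as np

def joint_decision_single_step(encoder_step, decoder_step,
                               enc_w, enc_b,      # (640, 1024), (640,)
                               pred_w, pred_b,    # (640, 640), (640,)
                               joint_w, joint_b): # (1030, 640), (1030,)
    # [1, 1024, 1] -> (1024,) and [1, 640, 1] -> (640,)
    enc = encoder_step.reshape(1024)
    dec = decoder_step.reshape(640)

    # Project both into the 640-dim joint space, add, ReLU.
    hidden = np.maximum(enc_w @ enc + enc_b + pred_w @ dec + pred_b, 0.0)

    # Final linear: 1030 logits = 1025 token logits + 5 duration logits.
    logits = joint_w @ hidden + joint_b
    token_logits, duration_logits = logits[:1025], logits[1025:]

    token_id = int(np.argmax(token_logits))
    token_probs = np.exp(token_logits - token_logits.max())
    token_probs /= token_probs.sum()
    duration = int(np.argmax(duration_logits))
    return token_id, float(token_probs[token_id]), duration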