Spaces:
Sleeping
Sleeping
ggml : support forward pass broadcasting in ggml_sub (ggml/914)
Browse files* ggml: support forward pass broadcasting in ggml_sub
Signed-off-by: Salvatore Mesoraca <[email protected]>
* Use assert instead of GGML_ASSERT in ggml_compute_forward_sub_f32
The check is already performed in ggml_sub_impl
Signed-off-by: Salvatore Mesoraca <[email protected]>
---------
Signed-off-by: Salvatore Mesoraca <[email protected]>
- ggml/src/ggml.c +46 -30
ggml/src/ggml.c
CHANGED
|
@@ -4661,11 +4661,13 @@ static struct ggml_tensor * ggml_sub_impl(
|
|
| 4661 |
struct ggml_tensor * a,
|
| 4662 |
struct ggml_tensor * b,
|
| 4663 |
bool inplace) {
|
| 4664 |
-
GGML_ASSERT(
|
| 4665 |
|
| 4666 |
bool is_node = false;
|
| 4667 |
|
| 4668 |
if (!inplace && (a->grad || b->grad)) {
|
|
|
|
|
|
|
| 4669 |
is_node = true;
|
| 4670 |
}
|
| 4671 |
|
|
@@ -10104,11 +10106,10 @@ static void ggml_compute_forward_sub_f32(
|
|
| 10104 |
const struct ggml_tensor * src0 = dst->src[0];
|
| 10105 |
const struct ggml_tensor * src1 = dst->src[1];
|
| 10106 |
|
| 10107 |
-
|
| 10108 |
-
return;
|
| 10109 |
-
}
|
| 10110 |
|
| 10111 |
-
|
|
|
|
| 10112 |
|
| 10113 |
const int nr = ggml_nrows(src0);
|
| 10114 |
|
|
@@ -10117,40 +10118,55 @@ static void ggml_compute_forward_sub_f32(
|
|
| 10117 |
GGML_ASSERT( nb0 == sizeof(float));
|
| 10118 |
GGML_ASSERT(nb00 == sizeof(float));
|
| 10119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10120 |
if (nb10 == sizeof(float)) {
|
| 10121 |
-
for (int ir =
|
| 10122 |
-
// src0
|
| 10123 |
-
const
|
| 10124 |
-
const
|
| 10125 |
-
const
|
| 10126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10127 |
#ifdef GGML_USE_ACCELERATE
|
| 10128 |
-
|
| 10129 |
-
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
|
| 10130 |
-
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
|
| 10131 |
-
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
|
| 10132 |
-
ne0);
|
| 10133 |
#else
|
| 10134 |
-
|
| 10135 |
-
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
|
| 10136 |
-
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
|
| 10137 |
-
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
|
| 10138 |
#endif
|
| 10139 |
-
|
| 10140 |
-
// }
|
| 10141 |
}
|
| 10142 |
} else {
|
| 10143 |
// src1 is not contiguous
|
| 10144 |
-
for (int ir =
|
| 10145 |
-
// src0
|
| 10146 |
-
const
|
| 10147 |
-
const
|
| 10148 |
-
const
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10149 |
|
| 10150 |
-
|
| 10151 |
-
|
| 10152 |
-
|
| 10153 |
-
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
|
| 10154 |
|
| 10155 |
dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
|
| 10156 |
}
|
|
|
|
| 4661 |
struct ggml_tensor * a,
|
| 4662 |
struct ggml_tensor * b,
|
| 4663 |
bool inplace) {
|
| 4664 |
+
GGML_ASSERT(ggml_can_repeat(b, a));
|
| 4665 |
|
| 4666 |
bool is_node = false;
|
| 4667 |
|
| 4668 |
if (!inplace && (a->grad || b->grad)) {
|
| 4669 |
+
// TODO: support backward pass for broadcasting
|
| 4670 |
+
GGML_ASSERT(ggml_are_same_shape(a, b));
|
| 4671 |
is_node = true;
|
| 4672 |
}
|
| 4673 |
|
|
|
|
| 10106 |
const struct ggml_tensor * src0 = dst->src[0];
|
| 10107 |
const struct ggml_tensor * src1 = dst->src[1];
|
| 10108 |
|
| 10109 |
+
assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
|
|
|
|
|
|
|
| 10110 |
|
| 10111 |
+
const int ith = params->ith;
|
| 10112 |
+
const int nth = params->nth;
|
| 10113 |
|
| 10114 |
const int nr = ggml_nrows(src0);
|
| 10115 |
|
|
|
|
| 10118 |
GGML_ASSERT( nb0 == sizeof(float));
|
| 10119 |
GGML_ASSERT(nb00 == sizeof(float));
|
| 10120 |
|
| 10121 |
+
// rows per thread
|
| 10122 |
+
const int dr = (nr + nth - 1)/nth;
|
| 10123 |
+
|
| 10124 |
+
// row range for this thread
|
| 10125 |
+
const int ir0 = dr*ith;
|
| 10126 |
+
const int ir1 = MIN(ir0 + dr, nr);
|
| 10127 |
+
|
| 10128 |
if (nb10 == sizeof(float)) {
|
| 10129 |
+
for (int ir = ir0; ir < ir1; ++ir) {
|
| 10130 |
+
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
| 10131 |
+
const int64_t i03 = ir/(ne02*ne01);
|
| 10132 |
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
| 10133 |
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
| 10134 |
|
| 10135 |
+
const int64_t i13 = i03 % ne13;
|
| 10136 |
+
const int64_t i12 = i02 % ne12;
|
| 10137 |
+
const int64_t i11 = i01 % ne11;
|
| 10138 |
+
const int64_t nr0 = ne00 / ne10;
|
| 10139 |
+
|
| 10140 |
+
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
| 10141 |
+
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
| 10142 |
+
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
| 10143 |
+
|
| 10144 |
+
for (int64_t r = 0; r < nr0; ++r) {
|
| 10145 |
#ifdef GGML_USE_ACCELERATE
|
| 10146 |
+
vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10147 |
#else
|
| 10148 |
+
ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
|
|
|
|
|
|
|
|
|
| 10149 |
#endif
|
| 10150 |
+
}
|
|
|
|
| 10151 |
}
|
| 10152 |
} else {
|
| 10153 |
// src1 is not contiguous
|
| 10154 |
+
for (int ir = ir0; ir < ir1; ++ir) {
|
| 10155 |
+
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
| 10156 |
+
const int64_t i03 = ir/(ne02*ne01);
|
| 10157 |
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
| 10158 |
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
| 10159 |
+
|
| 10160 |
+
const int64_t i13 = i03 % ne13;
|
| 10161 |
+
const int64_t i12 = i02 % ne12;
|
| 10162 |
+
const int64_t i11 = i01 % ne11;
|
| 10163 |
+
|
| 10164 |
+
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
| 10165 |
+
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
| 10166 |
|
| 10167 |
+
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
| 10168 |
+
const int64_t i10 = i0 % ne10;
|
| 10169 |
+
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
|
|
|
|
| 10170 |
|
| 10171 |
dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
|
| 10172 |
}
|