smeso committed on
Commit
0af2d37
·
unverified ·
1 Parent(s): f971b60

ggml : support forward pass broadcasting in ggml_sub (ggml/914)

Browse files

* ggml: support forward pass broadcasting in ggml_sub

Signed-off-by: Salvatore Mesoraca <[email protected]>

* Use assert instead of GGML_ASSERT in ggml_compute_forward_sub_f32

The check is already performed in ggml_sub_impl

Signed-off-by: Salvatore Mesoraca <[email protected]>

---------

Signed-off-by: Salvatore Mesoraca <[email protected]>

Files changed (1) hide show
  1. ggml/src/ggml.c +46 -30
ggml/src/ggml.c CHANGED
@@ -4661,11 +4661,13 @@ static struct ggml_tensor * ggml_sub_impl(
4661
  struct ggml_tensor * a,
4662
  struct ggml_tensor * b,
4663
  bool inplace) {
4664
- GGML_ASSERT(ggml_are_same_shape(a, b));
4665
 
4666
  bool is_node = false;
4667
 
4668
  if (!inplace && (a->grad || b->grad)) {
 
 
4669
  is_node = true;
4670
  }
4671
 
@@ -10104,11 +10106,10 @@ static void ggml_compute_forward_sub_f32(
10104
  const struct ggml_tensor * src0 = dst->src[0];
10105
  const struct ggml_tensor * src1 = dst->src[1];
10106
 
10107
- if (params->ith != 0) {
10108
- return;
10109
- }
10110
 
10111
- assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
10112
 
10113
  const int nr = ggml_nrows(src0);
10114
 
@@ -10117,40 +10118,55 @@ static void ggml_compute_forward_sub_f32(
10117
  GGML_ASSERT( nb0 == sizeof(float));
10118
  GGML_ASSERT(nb00 == sizeof(float));
10119
 
 
 
 
 
 
 
 
10120
  if (nb10 == sizeof(float)) {
10121
- for (int ir = 0; ir < nr; ++ir) {
10122
- // src0, src1 and dst are same shape => same indices
10123
- const int i3 = ir/(ne2*ne1);
10124
- const int i2 = (ir - i3*ne2*ne1)/ne1;
10125
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
10126
 
 
 
 
 
 
 
 
 
 
 
10127
  #ifdef GGML_USE_ACCELERATE
10128
- vDSP_vsub(
10129
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
10130
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
10131
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
10132
- ne0);
10133
  #else
10134
- ggml_vec_sub_f32(ne0,
10135
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
10136
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
10137
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
10138
  #endif
10139
- // }
10140
- // }
10141
  }
10142
  } else {
10143
  // src1 is not contiguous
10144
- for (int ir = 0; ir < nr; ++ir) {
10145
- // src0, src1 and dst are same shape => same indices
10146
- const int i3 = ir/(ne2*ne1);
10147
- const int i2 = (ir - i3*ne2*ne1)/ne1;
10148
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
 
 
 
 
 
 
10149
 
10150
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
10151
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
10152
- for (int i0 = 0; i0 < ne0; i0++) {
10153
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
10154
 
10155
  dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
10156
  }
 
4661
  struct ggml_tensor * a,
4662
  struct ggml_tensor * b,
4663
  bool inplace) {
4664
+ GGML_ASSERT(ggml_can_repeat(b, a));
4665
 
4666
  bool is_node = false;
4667
 
4668
  if (!inplace && (a->grad || b->grad)) {
4669
+ // TODO: support backward pass for broadcasting
4670
+ GGML_ASSERT(ggml_are_same_shape(a, b));
4671
  is_node = true;
4672
  }
4673
 
 
10106
  const struct ggml_tensor * src0 = dst->src[0];
10107
  const struct ggml_tensor * src1 = dst->src[1];
10108
 
10109
+ assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
 
10110
 
10111
+ const int ith = params->ith;
10112
+ const int nth = params->nth;
10113
 
10114
  const int nr = ggml_nrows(src0);
10115
 
 
10118
  GGML_ASSERT( nb0 == sizeof(float));
10119
  GGML_ASSERT(nb00 == sizeof(float));
10120
 
10121
+ // rows per thread
10122
+ const int dr = (nr + nth - 1)/nth;
10123
+
10124
+ // row range for this thread
10125
+ const int ir0 = dr*ith;
10126
+ const int ir1 = MIN(ir0 + dr, nr);
10127
+
10128
  if (nb10 == sizeof(float)) {
10129
+ for (int ir = ir0; ir < ir1; ++ir) {
10130
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
10131
+ const int64_t i03 = ir/(ne02*ne01);
10132
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10133
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10134
 
10135
+ const int64_t i13 = i03 % ne13;
10136
+ const int64_t i12 = i02 % ne12;
10137
+ const int64_t i11 = i01 % ne11;
10138
+ const int64_t nr0 = ne00 / ne10;
10139
+
10140
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
10141
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
10142
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
10143
+
10144
+ for (int64_t r = 0; r < nr0; ++r) {
10145
  #ifdef GGML_USE_ACCELERATE
10146
+ vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
 
 
 
 
10147
  #else
10148
+ ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 
 
 
10149
  #endif
10150
+ }
 
10151
  }
10152
  } else {
10153
  // src1 is not contiguous
10154
+ for (int ir = ir0; ir < ir1; ++ir) {
10155
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
10156
+ const int64_t i03 = ir/(ne02*ne01);
10157
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10158
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10159
+
10160
+ const int64_t i13 = i03 % ne13;
10161
+ const int64_t i12 = i02 % ne12;
10162
+ const int64_t i11 = i01 % ne11;
10163
+
10164
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
10165
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
10166
 
10167
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
10168
+ const int64_t i10 = i0 % ne10;
10169
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
10170
 
10171
  dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
10172
  }