rmatif committed on
Commit
5ff8785
·
1 Parent(s): b86860f

OpenCL: Add concat, tsembd, upscale, tanh, pad and repeat (llama/13840)

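For context, the six new kernels map one-to-one onto existing ggml graph ops. The sketch below is not part of the commit; it builds a small graph that would exercise the new OpenCL paths. The ggml calls are written from memory and their exact signatures can differ between ggml revisions (ggml_upscale in particular has changed and is only mentioned in a comment), so treat this as a hedged illustration rather than a test from this PR.

// Hedged sketch: a graph that hits the new kernels (tanh, concat, pad, repeat,
// timestep_embedding). Check ggml.h of your revision for exact signatures.
#include "ggml.h"

static struct ggml_cgraph * build_demo_graph(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 16, 8, 2, 1);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 16, 8, 2, 1);

    struct ggml_tensor * t   = ggml_tanh(ctx, a);                      // GGML_UNARY_OP_TANH
    struct ggml_tensor * cat = ggml_concat(ctx, t, b, /*dim=*/1);      // GGML_OP_CONCAT -> (16, 16, 2, 1)
    struct ggml_tensor * pad = ggml_pad(ctx, cat, 4, 0, 0, 0);         // GGML_OP_PAD    -> (20, 16, 2, 1)

    struct ggml_tensor * shape = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 20, 16, 4, 1);
    struct ggml_tensor * rep   = ggml_repeat(ctx, pad, shape);         // GGML_OP_REPEAT -> (20, 16, 4, 1)

    struct ggml_tensor * ts  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * emb = ggml_timestep_embedding(ctx, ts, 64, 10000); // GGML_OP_TIMESTEP_EMBEDDING
    // ggml_upscale is omitted here because its signature varies across ggml versions.

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, rep);
    ggml_build_forward_expand(gf, emb);
    return gf; // run with a scheduler that has the OpenCL backend attached
}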
ggml/src/ggml-opencl/CMakeLists.txt CHANGED
@@ -95,6 +95,12 @@ set(GGML_OPENCL_KERNELS
95
  sub
96
  sum_rows
97
  transpose
98
+ concat
99
+ tsembd
100
+ upscale
101
+ tanh
102
+ pad
103
+ repeat
104
  )
105
 
106
  foreach (K ${GGML_OPENCL_KERNELS})
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -315,6 +315,12 @@ struct ggml_backend_opencl_context {
315
  cl_program program_softmax_4_f16;
316
  cl_program program_argsort_f32_i32;
317
  cl_program program_sum_rows_f32;
318
+ cl_program program_repeat;
319
+ cl_program program_pad;
320
+ cl_program program_tanh;
321
+ cl_program program_upscale;
322
+ cl_program program_concat;
323
+ cl_program program_tsembd;
324
 
325
  cl_kernel kernel_add, kernel_add_row;
326
  cl_kernel kernel_mul, kernel_mul_row;
 
@@ -351,6 +357,15 @@ struct ggml_backend_opencl_context {
357
  cl_kernel kernel_im2col_f32, kernel_im2col_f16;
358
  cl_kernel kernel_argsort_f32_i32;
359
  cl_kernel kernel_sum_rows_f32;
360
+ cl_kernel kernel_repeat;
361
+ cl_kernel kernel_pad;
362
+ cl_kernel kernel_tanh_f32_nd;
363
+ cl_kernel kernel_tanh_f16_nd;
364
+ cl_kernel kernel_upscale;
365
+ cl_kernel kernel_upscale_bilinear;
366
+ cl_kernel kernel_concat_f32_contiguous;
367
+ cl_kernel kernel_concat_f32_non_contiguous;
368
+ cl_kernel kernel_timestep_embedding;
369
 
370
  #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
371
  // Transpose kernels
 
@@ -1097,6 +1112,150 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1112
  GGML_LOG_CONT(".");
1113
  }
1114
 
1115
+ // repeat
1116
+ {
1117
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1118
+ const std::string kernel_src {
1119
+ #include "repeat.cl.h"
1120
+ };
1121
+ #else
1122
+ const std::string kernel_src = read_file("repeat.cl");
1123
+ #endif
1124
+ if (!kernel_src.empty()) {
1125
+ backend_ctx->program_repeat =
1126
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1127
+ CL_CHECK((backend_ctx->kernel_repeat = clCreateKernel(backend_ctx->program_repeat, "kernel_repeat", &err), err));
1128
+ GGML_LOG_CONT(".");
1129
+ } else {
1130
+ GGML_LOG_WARN("ggml_opencl: repeat kernel source not found or empty. Repeat operations will not be available.\n");
1131
+ backend_ctx->program_repeat = nullptr;
1132
+ backend_ctx->kernel_repeat = nullptr;
1133
+ }
1134
+ }
1135
+
1136
+ // pad
1137
+ {
1138
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1139
+ const std::string kernel_src {
1140
+ #include "pad.cl.h"
1141
+ };
1142
+ #else
1143
+ const std::string kernel_src = read_file("pad.cl");
1144
+ #endif
1145
+ if (!kernel_src.empty()) {
1146
+ backend_ctx->program_pad =
1147
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1148
+ CL_CHECK((backend_ctx->kernel_pad = clCreateKernel(backend_ctx->program_pad, "kernel_pad", &err), err));
1149
+ GGML_LOG_CONT(".");
1150
+ } else {
1151
+ GGML_LOG_WARN("ggml_opencl: pad kernel source not found or empty. Pad operations will not be available.\n");
1152
+ backend_ctx->program_pad = nullptr;
1153
+ backend_ctx->kernel_pad = nullptr;
1154
+ }
1155
+ }
1156
+
1157
+ // tanh
1158
+ {
1159
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1160
+ const std::string kernel_src {
1161
+ #include "tanh.cl.h"
1162
+ };
1163
+ #else
1164
+ const std::string kernel_src = read_file("tanh.cl");
1165
+ #endif
1166
+ if (!kernel_src.empty()) {
1167
+ backend_ctx->program_tanh =
1168
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1169
+ CL_CHECK((backend_ctx->kernel_tanh_f32_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f32_nd", &err), err));
1170
+ CL_CHECK((backend_ctx->kernel_tanh_f16_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f16_nd", &err), err));
1171
+ GGML_LOG_CONT(".");
1172
+ } else {
1173
+ GGML_LOG_WARN("ggml_opencl: tanh kernel source not found or empty. Tanh operation will not be available.\n");
1174
+ backend_ctx->program_tanh = nullptr;
1175
+ backend_ctx->kernel_tanh_f32_nd = nullptr;
1176
+ backend_ctx->kernel_tanh_f16_nd = nullptr;
1177
+ }
1178
+ }
1179
+
1180
+ // upscale
1181
+ {
1182
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1183
+ const std::string kernel_src {
1184
+ #include "upscale.cl.h"
1185
+ };
1186
+ #else
1187
+ const std::string kernel_src = read_file("upscale.cl");
1188
+ #endif
1189
+ if (!kernel_src.empty()) {
1190
+ backend_ctx->program_upscale =
1191
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1192
+ CL_CHECK((backend_ctx->kernel_upscale = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale", &err), err));
1193
+ if (backend_ctx->program_upscale) {
1194
+ cl_int err_bilinear;
1195
+ backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear);
1196
+ if (err_bilinear != CL_SUCCESS) {
1197
+ GGML_LOG_WARN("ggml_opencl: kernel_upscale_bilinear not found in upscale.cl. Bilinear upscale will not be available. Error: %d\n", err_bilinear);
1198
+ backend_ctx->kernel_upscale_bilinear = nullptr;
1199
+ }
1200
+ } else {
1201
+ backend_ctx->kernel_upscale_bilinear = nullptr;
1202
+ }
1203
+ GGML_LOG_CONT(".");
1204
+ } else {
1205
+ GGML_LOG_WARN("ggml_opencl: upscale kernel source not found or empty. Upscale operations will not be available.\n");
1206
+ backend_ctx->program_upscale = nullptr;
1207
+ backend_ctx->kernel_upscale = nullptr;
1208
+ backend_ctx->kernel_upscale_bilinear = nullptr;
1209
+ }
1210
+ }
1211
+
1212
+ // concat
1213
+ {
1214
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1215
+ const std::string kernel_src {
1216
+ #include "concat.cl.h"
1217
+ };
1218
+ #else
1219
+
1220
+ const std::string kernel_src = read_file("concat.cl");
1221
+ #endif
1222
+ if (!kernel_src.empty()) {
1223
+ backend_ctx->program_concat =
1224
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1225
+
1226
+ CL_CHECK((backend_ctx->kernel_concat_f32_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_contiguous", &err), err));
1227
+ CL_CHECK((backend_ctx->kernel_concat_f32_non_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_non_contiguous", &err), err));
1228
+ GGML_LOG_CONT(".");
1229
+ } else {
1230
+ GGML_LOG_WARN("ggml_opencl: concat kernel source not found or empty. Concat operations will not be available.\n");
1231
+ backend_ctx->program_concat = nullptr;
1232
+ backend_ctx->kernel_concat_f32_contiguous = nullptr;
1233
+ backend_ctx->kernel_concat_f32_non_contiguous = nullptr;
1234
+ }
1235
+ }
1236
+
1237
+ // timestep_embedding
1238
+ {
1239
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1240
+ const std::string kernel_src {
1241
+ #include "tsembd.cl.h"
1242
+ };
1243
+ #else
1244
+
1245
+ const std::string kernel_src = read_file("tsembd.cl");
1246
+ #endif
1247
+ if (!kernel_src.empty()) {
1248
+ backend_ctx->program_tsembd =
1249
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1250
+ CL_CHECK((backend_ctx->kernel_timestep_embedding = clCreateKernel(backend_ctx->program_tsembd, "kernel_timestep_embedding", &err), err));
1251
+ GGML_LOG_CONT(".");
1252
+ } else {
1253
+ GGML_LOG_WARN("ggml_opencl: timestep_embedding kernel source not found or empty. This op will not be available.\n");
1254
+ backend_ctx->program_tsembd = nullptr;
1255
+ backend_ctx->kernel_timestep_embedding = nullptr;
1256
+ }
1257
+ }
1258
+
1259
  // Adreno kernels
1260
  #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
1261
  // transpose
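Each block above follows the same pattern: pick the kernel source (embedded header under GGML_OPENCL_EMBED_KERNELS, otherwise read_file), build a cl_program, create the cl_kernel(s), and degrade to nullptr with a warning when the source is missing. As a minimal sketch of what a helper like build_program_from_source typically does with the raw OpenCL API, assuming standard error handling that is not spelled out in this commit:

// Hedged sketch of the build-and-create pattern; not code from this commit.
#include <CL/cl.h>
#include <cstdio>
#include <vector>

static cl_program build_program(cl_context ctx, cl_device_id dev,
                                const char * src, const char * opts) {
    cl_int err = CL_SUCCESS;
    cl_program prog = clCreateProgramWithSource(ctx, 1, &src, nullptr, &err);
    if (err != CL_SUCCESS) {
        return nullptr;
    }
    err = clBuildProgram(prog, 1, &dev, opts, nullptr, nullptr);
    if (err != CL_SUCCESS) {
        // Dump the build log so a broken .cl file is easy to diagnose.
        size_t log_size = 0;
        clGetProgramBuildInfo(prog, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &log_size);
        std::vector<char> log(log_size + 1, '\0');
        clGetProgramBuildInfo(prog, dev, CL_PROGRAM_BUILD_LOG, log_size, log.data(), nullptr);
        fprintf(stderr, "OpenCL build failed:\n%s\n", log.data());
        clReleaseProgram(prog);
        return nullptr;
    }
    return prog;
}

// Usage mirrors the blocks above:
//   cl_program p = build_program(context, device, kernel_src.c_str(), compile_opts);
//   cl_kernel  k = clCreateKernel(p, "kernel_repeat", &err);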
 
@@ -1976,9 +2135,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2135
  case GGML_UNARY_OP_SILU:
2136
  case GGML_UNARY_OP_RELU:
2137
  case GGML_UNARY_OP_GELU_QUICK:
- return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
2138
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
2139
  case GGML_UNARY_OP_SIGMOID:
2140
  return ggml_is_contiguous(op->src[0]);
2141
+ case GGML_UNARY_OP_TANH:
2142
+ return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
2143
+ (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
2144
  default:
2145
  return false;
2146
  }
 
@@ -1988,6 +2150,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2150
  case GGML_OP_NORM:
2151
  case GGML_OP_RMS_NORM:
2152
  return true;
2153
+ case GGML_OP_REPEAT:
2154
+ return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
2155
+ case GGML_OP_PAD:
2156
+ return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
2157
+ op->src[0]->ne[3] == 1 && op->ne[3] == 1;
2158
+ case GGML_OP_UPSCALE:
2159
+ return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2160
+ case GGML_OP_CONCAT:
2161
+ return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2162
+ case GGML_OP_TIMESTEP_EMBEDDING:
2163
+ return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2164
  case GGML_OP_GROUP_NORM:
2165
  return ggml_is_contiguous(op->src[0]);
2166
  case GGML_OP_MUL_MAT:
 
@@ -4108,6 +4281,536 @@ static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0,
4281
  #endif
4282
  }
4283
 
4284
+ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4285
+ GGML_ASSERT(src0);
4286
+ GGML_ASSERT(src0->extra);
4287
+ GGML_ASSERT(dst);
4288
+ GGML_ASSERT(dst->extra);
4289
+
4290
+ UNUSED(src1);
4291
+
4292
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4293
+ cl_command_queue queue = backend_ctx->queue;
4294
+
4295
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
4296
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
4297
+
4298
+ cl_ulong offset0_abs = extra0->offset + src0->view_offs;
4299
+ cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
4300
+
4301
+ cl_kernel kernel;
4302
+ if (dst->type == GGML_TYPE_F32) {
4303
+ kernel = backend_ctx->kernel_tanh_f32_nd;
4304
+ } else if (dst->type == GGML_TYPE_F16) {
4305
+ kernel = backend_ctx->kernel_tanh_f16_nd;
4306
+ } else {
4307
+ GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh");
4308
+ }
4309
+ GGML_ASSERT(kernel != nullptr);
4310
+
4311
+ const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3];
4312
+ const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3];
4313
+
4314
+ const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3];
4315
+ const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3];
4316
+
4317
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
4318
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
4319
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
4320
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
4321
+
4322
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
4323
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
4324
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
4325
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
4326
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
4327
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
4328
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
4329
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
4330
+
4331
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
4332
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
4333
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
4334
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
4335
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
4336
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
4337
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
4338
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
4339
+
4340
+ size_t global_work_size[3];
4341
+ if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
4342
+ return;
4343
+ }
4344
+ global_work_size[0] = (size_t)ne10;
4345
+ global_work_size[1] = (size_t)ne11;
4346
+ global_work_size[2] = (size_t)ne12;
4347
+
4348
+ size_t lws0 = 16, lws1 = 4, lws2 = 1;
4349
+ if (ne10 < 16) lws0 = ne10;
4350
+ if (ne11 < 4) lws1 = ne11;
4351
+ if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
4352
+
4353
+ while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
4354
+ while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
4355
+ while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
4356
+
4357
+
4358
+ size_t local_work_size[] = {lws0, lws1, lws2};
4359
+
4360
+ size_t* local_work_size_ptr = local_work_size;
4361
+ if (!backend_ctx->non_uniform_workgroups) {
4362
+ if (global_work_size[0] % local_work_size[0] != 0 ||
4363
+ global_work_size[1] % local_work_size[1] != 0 ||
4364
+ global_work_size[2] % local_work_size[2] != 0) {
4365
+ local_work_size_ptr = NULL;
4366
+ }
4367
+ }
4368
+ if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
4369
+
4370
+
4371
+ #ifdef GGML_OPENCL_PROFILING
4372
+ cl_event evt;
4373
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
4374
+
4375
+ g_profiling_info.emplace_back();
4376
+ populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr ? local_work_size : (size_t[3]){0,0,0}, dst);
4377
+ #else
4378
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
4379
+ #endif
4380
+ }
4381
+
4382
+ static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
4383
+ GGML_ASSERT(src0);
4384
+ GGML_ASSERT(src0->extra);
4385
+ GGML_ASSERT(dst);
4386
+ GGML_ASSERT(dst->extra);
4387
+ GGML_ASSERT(dst->type == src0->type);
4388
+
4389
+ UNUSED(src1_shape_def);
4390
+
4391
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4392
+ cl_command_queue queue = backend_ctx->queue;
4393
+
4394
+ if (backend_ctx->kernel_repeat == nullptr) {
4395
+ GGML_LOG_WARN("%s: repeat kernel not available, skipping OpenCL execution.\n", __func__);
4396
+ return;
4397
+ }
4398
+
4399
+ ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
4400
+ ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra;
4401
+
4402
+ cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
4403
+ cl_ulong off_dst = extra_dst->offset + dst->view_offs;
4404
+
4405
+ const int src0_ne0 = src0->ne[0]; const int src0_ne1 = src0->ne[1]; const int src0_ne2 = src0->ne[2]; const int src0_ne3 = src0->ne[3];
4406
+ const cl_ulong src0_nb0 = src0->nb[0]; const cl_ulong src0_nb1 = src0->nb[1]; const cl_ulong src0_nb2 = src0->nb[2]; const cl_ulong src0_nb3 = src0->nb[3];
4407
+
4408
+ const int dst_ne0 = dst->ne[0]; const int dst_ne1 = dst->ne[1]; const int dst_ne2 = dst->ne[2]; const int dst_ne3 = dst->ne[3];
4409
+ const cl_ulong dst_nb0 = dst->nb[0]; const cl_ulong dst_nb1 = dst->nb[1]; const cl_ulong dst_nb2 = dst->nb[2]; const cl_ulong dst_nb3 = dst->nb[3];
4410
+
4411
+ cl_kernel kernel = backend_ctx->kernel_repeat;
4412
+
4413
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
4414
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra_dst->data_device));
4415
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_src0));
4416
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
4417
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &src0_ne0));
4418
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &src0_ne1));
4419
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &src0_ne2));
4420
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &src0_ne3));
4421
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &src0_nb0));
4422
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &src0_nb1));
4423
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src0_nb2));
4424
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &src0_nb3));
4425
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &dst_ne0));
4426
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &dst_ne1));
4427
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &dst_ne2));
4428
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dst_ne3));
4429
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &dst_nb0));
4430
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &dst_nb1));
4431
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &dst_nb2));
4432
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &dst_nb3));
4433
+
4434
+ size_t gws0 = dst_ne1 > 0 ? (size_t)dst_ne1 : 1;
4435
+ size_t gws1 = dst_ne2 > 0 ? (size_t)dst_ne2 : 1;
4436
+ size_t gws2 = dst_ne3 > 0 ? (size_t)dst_ne3 : 1;
4437
+
4438
+ size_t global_work_size[] = { gws0, gws1, gws2 };
4439
+
4440
+ #ifdef GGML_OPENCL_PROFILING
4441
+ cl_event evt;
4442
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, &evt));
4443
+
4444
+ g_profiling_info.emplace_back();
4445
+ populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, (size_t[3]){0,0,0}, dst);
4446
+ #else
4447
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL));
4448
+ #endif
4449
+ }
4450
+
4451
+ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
4452
+ GGML_ASSERT(src0);
4453
+ GGML_ASSERT(src0->extra);
4454
+ GGML_ASSERT(dst);
4455
+ GGML_ASSERT(dst->extra);
4456
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
4457
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
4458
+ GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1);
4459
+
4460
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4461
+ cl_command_queue queue = backend_ctx->queue;
4462
+
4463
+ if (backend_ctx->kernel_pad == nullptr) {
4464
+ GGML_LOG_WARN("%s: pad kernel not available, skipping OpenCL execution.\n", __func__);
4465
+ return;
4466
+ }
4467
+
4468
+ ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
4469
+ ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra;
4470
+
4471
+ cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
4472
+ cl_ulong off_dst = extra_dst->offset + dst->view_offs;
4473
+
4474
+ const int s_ne0 = src0->ne[0];
4475
+ const int s_ne1 = src0->ne[1];
4476
+ const int s_ne2 = src0->ne[2];
4477
+
4478
+ const int d_ne0 = dst->ne[0];
4479
+ const int d_ne1 = dst->ne[1];
4480
+ const int d_ne2 = dst->ne[2];
4481
+
4482
+ cl_kernel kernel = backend_ctx->kernel_pad;
4483
+
4484
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
4485
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
4486
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
4487
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
4488
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
4489
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
4490
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
4491
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0));
4492
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1));
4493
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2));
4494
+
4495
+ size_t lws0 = 64;
4496
+ size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;
4497
+
4498
+ size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 };
4499
+ size_t local_work_size[] = { lws0, 1, 1 };
4500
+
4501
+ size_t * local_work_size_ptr = local_work_size;
4502
+ if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) {
4503
+ local_work_size_ptr = nullptr;
4504
+ }
4505
+
4506
+ #ifdef GGML_OPENCL_PROFILING
4507
+ cl_event evt;
4508
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
4509
+
4510
+ g_profiling_info.emplace_back();
4511
+ populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr ? local_work_size : (size_t[3]){0,0,0}, dst);
4512
+ #else
4513
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
4514
+ #endif
4515
+ }
4516
+
4517
+ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
4518
+ GGML_ASSERT(src0);
4519
+ GGML_ASSERT(src0->extra);
4520
+ GGML_ASSERT(dst);
4521
+ GGML_ASSERT(dst->extra);
4522
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
4523
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
4524
+
4525
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4526
+ cl_command_queue queue = backend_ctx->queue;
4527
+
4528
+ const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
4529
+ cl_kernel kernel = nullptr;
4530
+
4531
+ if (mode == GGML_SCALE_MODE_NEAREST) {
4532
+ kernel = backend_ctx->kernel_upscale;
4533
+ if (kernel == nullptr) {
4534
+ GGML_LOG_WARN("%s: nearest upscale kernel not available, skipping OpenCL execution.\n", __func__);
4535
+ return;
4536
+ }
4537
+ } else if (mode == GGML_SCALE_MODE_BILINEAR) {
4538
+ kernel = backend_ctx->kernel_upscale_bilinear;
4539
+ if (kernel == nullptr) {
4540
+ GGML_LOG_WARN("%s: bilinear upscale kernel not available, skipping OpenCL execution.\n", __func__);
4541
+ return;
4542
+ }
4543
+ } else {
4544
+ GGML_LOG_WARN("%s: unsupported upscale mode %d, skipping OpenCL execution.\n", __func__, mode);
4545
+ return;
4546
+ }
4547
+
4548
+ ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
4549
+ ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra;
4550
+
4551
+ cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
4552
+ cl_ulong off_dst = extra_dst->offset + dst->view_offs;
4553
+
4554
+ const cl_ulong nb00 = src0->nb[0];
4555
+ const cl_ulong nb01 = src0->nb[1];
4556
+ const cl_ulong nb02 = src0->nb[2];
4557
+ const cl_ulong nb03 = src0->nb[3];
4558
+
4559
+ const int ne00_src = src0->ne[0];
4560
+ const int ne01_src = src0->ne[1];
4561
+
4562
+ const int ne10_dst = dst->ne[0];
4563
+ const int ne11_dst = dst->ne[1];
4564
+ const int ne12_dst = dst->ne[2];
4565
+ const int ne13_dst = dst->ne[3];
4566
+
4567
+ const float sf0 = (float)dst->ne[0] / src0->ne[0];
4568
+ const float sf1 = (float)dst->ne[1] / src0->ne[1];
4569
+ const float sf2 = (float)dst->ne[2] / src0->ne[2];
4570
+ const float sf3 = (float)dst->ne[3] / src0->ne[3];
4571
+
4572
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
4573
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
4574
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
4575
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
4576
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &nb00));
4577
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01));
4578
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb02));
4579
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03));
4580
+
4581
+ if (mode == GGML_SCALE_MODE_NEAREST) {
4582
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne10_dst));
4583
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11_dst));
4584
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12_dst));
4585
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13_dst));
4586
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0));
4587
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1));
4588
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2));
4589
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3));
4590
+ } else if (mode == GGML_SCALE_MODE_BILINEAR) {
4591
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00_src));
4592
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01_src));
4593
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10_dst));
4594
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11_dst));
4595
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12_dst));
4596
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13_dst));
4597
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0));
4598
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1));
4599
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2));
4600
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3));
4601
+ }
4602
+
4603
+
4604
+ size_t dst_total_elements = (size_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;
4605
+ if (dst_total_elements == 0) {
4606
+ return;
4607
+ }
4608
+ size_t global_work_size[] = { dst_total_elements, 1, 1 };
4609
+ size_t local_work_size_pref = 256;
4610
+ size_t local_work_size[] = { MIN(local_work_size_pref, dst_total_elements), 1, 1};
4611
+
4612
+ size_t * local_work_size_ptr = local_work_size;
4613
+ if (dst_total_elements % local_work_size[0] != 0 && !backend_ctx->non_uniform_workgroups) {
4614
+ local_work_size_ptr = nullptr;
4615
+ }
4616
+
4617
+ #ifdef GGML_OPENCL_PROFILING
4618
+ cl_event evt;
4619
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
4620
+
4621
+ g_profiling_info.emplace_back();
4622
+ size_t profiling_gws[3] = {global_work_size[0], 1, 1};
4623
+ size_t profiling_lws[3] = {local_work_size_ptr ? local_work_size[0] : 0, 1, 1};
4624
+ populateProfilingInfo(g_profiling_info.back(), evt, kernel, profiling_gws, profiling_lws, dst);
4625
+ #else
4626
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
4627
+ #endif
4628
+ }
4629
+
4630
+ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4631
+ GGML_ASSERT(src0);
4632
+ GGML_ASSERT(src0->extra);
4633
+ GGML_ASSERT(src1);
4634
+ GGML_ASSERT(src1->extra);
4635
+ GGML_ASSERT(dst);
4636
+ GGML_ASSERT(dst->extra);
4637
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
4638
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
4639
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
4640
+
4641
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4642
+ cl_command_queue queue = backend_ctx->queue;
4643
+
4644
+ if (backend_ctx->kernel_concat_f32_contiguous == nullptr || backend_ctx->kernel_concat_f32_non_contiguous == nullptr) {
4645
+ GGML_LOG_WARN("%s: concat kernels not available, skipping OpenCL execution.\n", __func__);
4646
+ return;
4647
+ }
4648
+
4649
+ ggml_tensor_extra_cl * extra0_cl = (ggml_tensor_extra_cl *)src0->extra;
4650
+ ggml_tensor_extra_cl * extra1_cl = (ggml_tensor_extra_cl *)src1->extra;
4651
+ ggml_tensor_extra_cl * extrad_cl = (ggml_tensor_extra_cl *)dst->extra;
4652
+
4653
+ cl_ulong off_src0 = extra0_cl->offset + src0->view_offs;
4654
+ cl_ulong off_src1 = extra1_cl->offset + src1->view_offs;
4655
+ cl_ulong off_dst = extrad_cl->offset + dst->view_offs;
4656
+
4657
+ const int32_t dim = ((const int32_t *) dst->op_params)[0];
4658
+ GGML_ASSERT(dim >= 0 && dim <= 3);
4659
+
4660
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
4661
+ if (dim == 3) {
4662
+
4663
+ size_t nbytes_src0 = ggml_nbytes(src0);
4664
+ size_t nbytes_src1 = ggml_nbytes(src1);
4665
+
4666
+ CL_CHECK(clEnqueueCopyBuffer(queue, extra0_cl->data_device, extrad_cl->data_device,
4667
+ off_src0, off_dst, nbytes_src0, 0, NULL, NULL));
4668
+ CL_CHECK(clEnqueueCopyBuffer(queue, extra1_cl->data_device, extrad_cl->data_device,
4669
+ off_src1, off_dst + nbytes_src0, nbytes_src1, 0, NULL, NULL));
4670
+ } else {
4671
+
4672
+ cl_kernel kernel = backend_ctx->kernel_concat_f32_contiguous;
4673
+ size_t global_work_size[3];
4674
+
4675
+ for (int i3 = 0; i3 < dst->ne[3]; ++i3) {
4676
+ cl_ulong current_off_src0 = off_src0 + (i3 * src0->nb[3]);
4677
+ cl_ulong current_off_src1 = off_src1 + (i3 * src1->nb[3]);
4678
+ cl_ulong current_off_dst = off_dst + (i3 * dst->nb[3]);
4679
+
4680
+ int d_ne00 = src0->ne[0]; int d_ne01 = src0->ne[1]; int d_ne02 = src0->ne[2];
4681
+ int d_ne10 = src1->ne[0]; int d_ne11 = src1->ne[1]; int d_ne12 = src1->ne[2];
4682
+ int d_ne0 = dst->ne[0]; int d_ne1 = dst->ne[1]; int d_ne2 = dst->ne[2];
4683
+
4684
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device));
4685
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &current_off_src0));
4686
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device));
4687
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &current_off_src1));
4688
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device));
4689
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &current_off_dst));
4690
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &d_ne00));
4691
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne01));
4692
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne02));
4693
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne10));
4694
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &d_ne11));
4695
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &d_ne12));
4696
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0));
4697
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1));
4698
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2));
4699
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dim));
4700
+
4701
+ global_work_size[0] = d_ne0;
4702
+ global_work_size[1] = d_ne1;
4703
+ global_work_size[2] = d_ne2;
4704
+
4705
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL));
4706
+ }
4707
+ }
4708
+ } else {
4709
+ cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous;
4710
+
4711
+ long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
4712
+ cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
4713
+
4714
+ cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
4715
+
4716
+ long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3];
4717
+ cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3];
4718
+
4719
+
4720
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device));
4721
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
4722
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device));
4723
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_src1));
4724
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device));
4725
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &off_dst));
4726
+
4727
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(long), &ne00));
4728
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(long), &ne01));
4729
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(long), &ne02));
4730
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(long), &ne03));
4731
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
4732
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
4733
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
4734
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
4735
+
4736
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));
4737
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
4738
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
4739
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
4740
+
4741
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(long), &d_ne0));
4742
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(long), &d_ne1));
4743
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(long), &d_ne2));
4744
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(long), &d_ne3));
4745
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &d_nb0));
4746
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &d_nb1));
4747
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &d_nb2));
4748
+ CL_CHECK(clSetKernelArg(kernel, 25, sizeof(cl_ulong), &d_nb3));
4749
+ CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &dim));
4750
+
4751
+ size_t global_work_size_nc[] = { d_ne1 > 0 ? (size_t)d_ne1 : 1,
4752
+ d_ne2 > 0 ? (size_t)d_ne2 : 1,
4753
+ d_ne3 > 0 ? (size_t)d_ne3 : 1 };
4754
+
4755
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size_nc, NULL, 0, NULL, NULL));
4756
+ }
4757
+ }
4758
+
4759
+ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
4760
+ GGML_ASSERT(src0);
4761
+ GGML_ASSERT(src0->extra);
4762
+ GGML_ASSERT(dst);
4763
+ GGML_ASSERT(dst->extra);
4764
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
4765
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
4766
+
4767
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4768
+ cl_command_queue queue = backend_ctx->queue;
4769
+
4770
+ if (backend_ctx->kernel_timestep_embedding == nullptr) {
4771
+ GGML_LOG_WARN("%s: timestep_embedding kernel not available, skipping OpenCL execution.\n", __func__);
4772
+ return;
4773
+ }
4774
+
4775
+ ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
4776
+ ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra;
4777
+
4778
+ cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
4779
+ cl_ulong off_dst = extra_dst->offset + dst->view_offs;
4780
+
4781
+ const int logical_dim = dst->op_params[0];
4782
+ const int max_period = dst->op_params[1];
4783
+ const int dst_nb1_bytes = dst->nb[1];
4784
+
4785
+ cl_kernel kernel = backend_ctx->kernel_timestep_embedding;
4786
+
4787
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
4788
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
4789
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
4790
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
4791
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &dst_nb1_bytes));
4792
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &logical_dim));
4793
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &max_period));
4794
+
4795
+ size_t gws0 = (size_t)(((logical_dim + 1) / 2) + 1);
4796
+
4797
+ size_t gws1 = (size_t)src0->ne[0];
4798
+
4799
+ size_t global_work_size[] = {gws0, gws1, 1};
4800
+
4801
+ #ifdef GGML_OPENCL_PROFILING
4802
+ cl_event evt;
4803
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size, NULL, 0, NULL, &evt)); // Pass 2 for 2D problem
4804
+
4805
+ g_profiling_info.emplace_back();
4806
+ size_t profiling_gws[3] = {global_work_size[0], global_work_size[1], 1};
4807
+ size_t profiling_lws[3] = {0,0,0}; // Reflects NULL LWS
4808
+ populateProfilingInfo(g_profiling_info.back(), evt, kernel, profiling_gws, profiling_lws, dst);
4809
+ #else
4810
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size, NULL, 0, NULL, NULL)); // Pass 2 for 2D problem
4811
+ #endif
4812
+ }
4813
+
4814
  static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4815
  GGML_ASSERT(src0);
4816
  GGML_ASSERT(src0->extra);
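The new host wrappers above all pick launch sizes the same way: choose a preferred local work size, and if the global size is not a multiple of it and the device lacks non-uniform work-group support, pass NULL so the runtime chooses. A small stand-alone restatement of that policy (illustrative, not code from the commit):

// Sketch of the local-work-size fallback used by ggml_cl_pad/upscale/tanh above.
#include <cstddef>

static const size_t * pick_local_size(const size_t global[3], size_t local[3],
                                      bool non_uniform_workgroups_supported) {
    if (non_uniform_workgroups_supported) {
        return local; // remainder (non-uniform) work-groups are allowed
    }
    for (int i = 0; i < 3; ++i) {
        if (global[i] % local[i] != 0) {
            return nullptr; // let clEnqueueNDRangeKernel pick a uniform local size
        }
    }
    return local;
}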
 
@@ -5667,6 +6370,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
6370
  }
6371
  func = ggml_cl_sigmoid;
6372
  break;
6373
+ case GGML_UNARY_OP_TANH:
6374
+ if (!any_on_device) {
6375
+ return false;
6376
+ }
6377
+ func = ggml_cl_tanh;
6378
+ break;
6379
  default:
6380
  return false;
6381
  } break;
 
@@ -5694,6 +6403,36 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
6403
  }
6404
  func = ggml_cl_group_norm;
6405
  break;
6406
+ case GGML_OP_REPEAT:
6407
+ if (!any_on_device) {
6408
+ return false;
6409
+ }
6410
+ func = ggml_cl_repeat;
6411
+ break;
6412
+ case GGML_OP_PAD:
6413
+ if (!any_on_device) {
6414
+ return false;
6415
+ }
6416
+ ggml_cl_pad(backend, tensor->src[0], tensor);
6417
+ return true;
6418
+ case GGML_OP_UPSCALE:
6419
+ if (!any_on_device) {
6420
+ return false;
6421
+ }
6422
+ ggml_cl_upscale(backend, tensor->src[0], tensor);
6423
+ return true;
6424
+ case GGML_OP_CONCAT:
6425
+ if (!any_on_device) {
6426
+ return false;
6427
+ }
6428
+ func = ggml_cl_concat;
6429
+ break;
6430
+ case GGML_OP_TIMESTEP_EMBEDDING:
6431
+ if (!any_on_device) {
6432
+ return false;
6433
+ }
6434
+ ggml_cl_timestep_embedding(backend, tensor->src[0], tensor);
6435
+ return true;
6436
  case GGML_OP_MUL_MAT:
6437
  if (!any_on_device && !ggml_cl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
6438
  return false;
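Note the two dispatch styles in ggml_cl_compute_forward: GGML_OP_REPEAT and GGML_OP_CONCAT set func and break like the existing ops, while GGML_OP_PAD, GGML_OP_UPSCALE and GGML_OP_TIMESTEP_EMBEDDING call their wrappers directly and return true, presumably because those wrappers take (backend, src0, dst) rather than the common (backend, src0, src1, dst) signature that func expects.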
ggml/src/ggml-opencl/kernels/concat.cl ADDED
@@ -0,0 +1,109 @@
1
+ kernel void kernel_concat_f32_contiguous(
2
+ global const char * p_src0, ulong off_src0,
3
+ global const char * p_src1, ulong off_src1,
4
+ global char * p_dst, ulong off_dst,
5
+ int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice
6
+ int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes)
7
+ int d_ne0, int d_ne1, int d_ne2, // dst->ne[0..2] for the slice
8
+ int dim
9
+ ) {
10
+ global const float * src0 = (global const float*)((global char*)p_src0 + off_src0);
11
+ global const float * src1 = (global const float*)((global char*)p_src1 + off_src1);
12
+ global float * dst = (global float*)((global char*)p_dst + off_dst);
13
+
14
+ int i0 = get_global_id(0); // Index along dst's 0th dimension
15
+ int i1 = get_global_id(1); // Index along dst's 1st dimension
16
+ int i2 = get_global_id(2); // Index along dst's 2nd dimension
17
+
18
+ if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) {
19
+ return;
20
+ }
21
+
22
+ ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0;
23
+ ulong src_idx;
24
+
25
+ if (dim == 0) {
26
+ if (i0 < d_ne00) { // Data from src0
27
+ src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
28
+ dst[dst_idx] = src0[src_idx];
29
+ } else { // Data from src1
30
+ src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00);
31
+ dst[dst_idx] = src1[src_idx];
32
+ }
33
+ } else if (dim == 1) {
34
+ if (i1 < d_ne01) { // Data from src0
35
+ src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
36
+ dst[dst_idx] = src0[src_idx];
37
+ } else { // Data from src1
38
+ src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0;
39
+ dst[dst_idx] = src1[src_idx];
40
+ }
41
+ } else if (dim == 2) {
42
+ if (i2 < d_ne02) { // Data from src0
43
+ src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
44
+ dst[dst_idx] = src0[src_idx];
45
+ } else { // Data from src1
46
+
47
+ src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0;
48
+ dst[dst_idx] = src1[src_idx];
49
+ }
50
+ }
51
+ }
52
+
53
+ kernel void kernel_concat_f32_non_contiguous(
54
+ global const char * p_src0, ulong off_src0,
55
+ global const char * p_src1, ulong off_src1,
56
+ global char * p_dst, ulong off_dst,
57
+
58
+ long ne00, long ne01, long ne02, long ne03,
59
+ ulong nb00, ulong nb01, ulong nb02, ulong nb03,
60
+
61
+ ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1
62
+
63
+ long d_ne0, long d_ne1, long d_ne2, long d_ne3,
64
+ ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3,
65
+ int dim
66
+ ) {
67
+ global const char * src0_base = p_src0 + off_src0;
68
+ global const char * src1_base = p_src1 + off_src1;
69
+ global char * dst_base = p_dst + off_dst;
70
+
71
+ long current_i1 = get_global_id(0); // Index for dst_dim_1
72
+ long current_i2 = get_global_id(1); // Index for dst_dim_2
73
+ long current_i3 = get_global_id(2); // Index for dst_dim_3
74
+
75
+ if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) {
76
+ return;
77
+ }
78
+
79
+ global const float * x_val_ptr;
80
+ global float * y_val_ptr;
81
+
82
+ for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) {
83
+ bool use_src0;
84
+ long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3;
85
+
86
+ if (dim == 0) {
87
+ use_src0 = (current_i0 < ne00);
88
+ if (!use_src0) { s_i0 = current_i0 - ne00; }
89
+ } else if (dim == 1) {
90
+ use_src0 = (current_i1 < ne01);
91
+ if (!use_src0) { s_i1 = current_i1 - ne01; }
92
+ } else if (dim == 2) {
93
+ use_src0 = (current_i2 < ne02);
94
+ if (!use_src0) { s_i2 = current_i2 - ne02; }
95
+ } else { // dim == 3
96
+ use_src0 = (current_i3 < ne03);
97
+ if (!use_src0) { s_i3 = current_i3 - ne03; }
98
+ }
99
+
100
+ if (use_src0) {
101
+ x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00);
102
+ } else {
103
+ x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10);
104
+ }
105
+
106
+ y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0);
107
+ *y_val_ptr = *x_val_ptr;
108
+ }
109
+ }
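As a cross-check of the indexing above, a plain-C reference of the contiguous case (concatenate src1 after src0 along dim 0, 1 or 2, all tensors contiguous f32) might look like the following; the names and flat-index layout are illustrative, not part of the commit.

// Hedged scalar reference mirroring kernel_concat_f32_contiguous: a dst index
// maps back to src0 while its coordinate along `dim` is inside src0's extent,
// otherwise to src1 with that coordinate shifted by src0's extent.
static void concat_f32_ref(const float * src0, const float * src1, float * dst,
                           int ne00, int ne01, int ne02,   // src0 extents
                           int ne10, int ne11, int ne12,   // src1 extents
                           int ne0,  int ne1,  int ne2,    // dst extents
                           int dim) {
    for (int i2 = 0; i2 < ne2; ++i2)
    for (int i1 = 0; i1 < ne1; ++i1)
    for (int i0 = 0; i0 < ne0; ++i0) {
        long didx = (long)i2 * ne0 * ne1 + (long)i1 * ne0 + i0;
        int s0 = i0, s1 = i1, s2 = i2;
        int from_src0 = (dim == 0) ? (i0 < ne00) : (dim == 1) ? (i1 < ne01) : (i2 < ne02);
        if (from_src0) {
            dst[didx] = src0[(long)s2 * ne00 * ne01 + (long)s1 * ne00 + s0];
        } else {
            if (dim == 0) s0 -= ne00; else if (dim == 1) s1 -= ne01; else s2 -= ne02;
            dst[didx] = src1[(long)s2 * ne10 * ne11 + (long)s1 * ne10 + s0];
        }
    }
}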
ggml/src/ggml-opencl/kernels/pad.cl ADDED
@@ -0,0 +1,30 @@
1
+ kernel void kernel_pad(
2
+ global const void * src0_ptr,
3
+ ulong src0_offset,
4
+ global void * dst_ptr,
5
+ ulong dst_offset,
6
+ int s_ne0, int s_ne1, int s_ne2,
7
+ int d_ne0, int d_ne1, int d_ne2
8
+ ) {
9
+ global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset);
10
+ global float * dst = (global float *)((global char *)dst_ptr + dst_offset);
11
+
12
+ int nidx = get_global_id(0);
13
+ int idx_d1 = get_group_id(1);
14
+ int idx_d2 = get_group_id(2);
15
+
16
+ if (nidx >= d_ne0) {
17
+ return;
18
+ }
19
+
20
+ int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1;
21
+
22
+ bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2);
23
+
24
+ if (in_src_bounds) {
25
+ int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1;
26
+ dst[dst_el_offset] = src0[src_el_offset];
27
+ } else {
28
+ dst[dst_el_offset] = 0.0f;
29
+ }
30
+ }
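kernel_pad copies the source element when the destination coordinate lies inside the source extents and writes 0.0f otherwise. In scalar form (illustrative only, assuming contiguous f32 tensors):

// Hedged scalar reference for kernel_pad: zero-extend a contiguous f32 tensor
// from (s_ne0, s_ne1, s_ne2) to (d_ne0, d_ne1, d_ne2).
static void pad_f32_ref(const float * src, float * dst,
                        int s_ne0, int s_ne1, int s_ne2,
                        int d_ne0, int d_ne1, int d_ne2) {
    for (int i2 = 0; i2 < d_ne2; ++i2)
    for (int i1 = 0; i1 < d_ne1; ++i1)
    for (int i0 = 0; i0 < d_ne0; ++i0) {
        long didx = (long)i2 * d_ne0 * d_ne1 + (long)i1 * d_ne0 + i0;
        int inside = (i0 < s_ne0) && (i1 < s_ne1) && (i2 < s_ne2);
        dst[didx] = inside ? src[(long)i2 * s_ne0 * s_ne1 + (long)i1 * s_ne0 + i0] : 0.0f;
    }
}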
ggml/src/ggml-opencl/kernels/repeat.cl ADDED
@@ -0,0 +1,39 @@
1
+ kernel void kernel_repeat(
2
+ global const char * src0_data_in,
3
+ global char * dst_data_in,
4
+ ulong src0_offset,
5
+ ulong dst_offset,
6
+ int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3,
7
+ ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3,
8
+ int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3,
9
+ ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3
10
+ ) {
11
+ global const char * src0_data = src0_data_in + src0_offset;
12
+ global char * dst_data = dst_data_in + dst_offset;
13
+
14
+ const int d3 = get_global_id(2);
15
+ const int d2 = get_global_id(1);
16
+ const int d1 = get_global_id(0);
17
+
18
+ if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) {
19
+ return;
20
+ }
21
+
22
+ const int s3 = d3 % src0_ne3;
23
+ const int s2 = d2 % src0_ne2;
24
+ const int s1 = d1 % src0_ne1;
25
+
26
+ const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1;
27
+ global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1;
28
+
29
+ for (int d0 = 0; d0 < dst_ne0; ++d0) {
30
+ // Determine source index for dimension 0 based on tiling/broadcasting.
31
+ const int s0 = d0 % src0_ne0;
32
+
33
+ const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0;
34
+ global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0;
35
+ for (int k = 0; k < src0_nb0; ++k) {
36
+ current_dst_el_ptr[k] = current_src_el_ptr[k];
37
+ }
38
+ }
39
+ }
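kernel_repeat implements ggml's tiling semantics by taking each destination coordinate modulo the corresponding source extent and then copying src0_nb0 bytes per element, so it works for any element type whose dim-0 stride equals the element size. For example, a source of shape (2, 3, 1, 1) repeated into a destination of shape (4, 3, 2, 1) reads source column d0 % 2, row d1 % 3 and plane d2 % 1 for every destination element.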
ggml/src/ggml-opencl/kernels/tanh.cl ADDED
@@ -0,0 +1,63 @@
1
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
+
3
+ #ifdef cl_intel_required_subgroup_size
4
+ #pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
5
+ #define INTEL_GPU 1
6
+ #define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
7
+ #define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
8
+ #elif defined(cl_qcom_reqd_sub_group_size)
9
+ #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
10
+ #define ADRENO_GPU 1
11
+ #define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
12
+ #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
13
+ #endif
14
+
15
+ kernel void kernel_tanh_f32_nd(
16
+ global void * p_src0_base, ulong off_src0_abs,
17
+ global void * p_dst_base, ulong off_dst_abs,
18
+ int ne00, int ne01, int ne02, int ne03,
19
+ ulong nb00, ulong nb01, ulong nb02, ulong nb03,
20
+ int ne10, int ne11, int ne12, int ne13,
21
+ ulong nb10, ulong nb11, ulong nb12, ulong nb13
22
+ ) {
23
+ int i0 = get_global_id(0);
24
+ int i1 = get_global_id(1);
25
+ int i2 = get_global_id(2);
26
+
27
+ if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
28
+ for (int i3 = 0; i3 < ne13; ++i3) {
29
+ ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
30
+ global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
31
+
32
+ ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
33
+ global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
34
+
35
+ *dst_val_ptr = tanh(*src_val_ptr);
36
+ }
37
+ }
38
+ }
39
+
40
+ kernel void kernel_tanh_f16_nd(
41
+ global void * p_src0_base, ulong off_src0_abs,
42
+ global void * p_dst_base, ulong off_dst_abs,
43
+ int ne00, int ne01, int ne02, int ne03,
44
+ ulong nb00, ulong nb01, ulong nb02, ulong nb03,
45
+ int ne10, int ne11, int ne12, int ne13,
46
+ ulong nb10, ulong nb11, ulong nb12, ulong nb13
47
+ ) {
48
+ int i0 = get_global_id(0);
49
+ int i1 = get_global_id(1);
50
+ int i2 = get_global_id(2);
51
+
52
+ if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
53
+ for (int i3 = 0; i3 < ne13; ++i3) {
54
+ ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
55
+ global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
56
+
57
+ ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
58
+ global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
59
+
60
+ *dst_val_ptr = tanh(*src_val_ptr);
61
+ }
62
+ }
63
+ }
ggml/src/ggml-opencl/kernels/tsembd.cl ADDED
@@ -0,0 +1,48 @@
1
+ kernel void kernel_timestep_embedding(
2
+ global const void * p_timesteps,
3
+ ulong off_timesteps,
4
+ global void * p_dst,
5
+ ulong off_dst,
6
+ int dst_nb1_bytes,
7
+ int logical_dim,
8
+ int max_period
9
+ ) {
10
+ int local_i;
11
+ int local_j;
12
+ int local_half_dim;
13
+ float local_timestep_val;
14
+ float local_freq;
15
+ float local_arg;
16
+ global float * local_embed_data_ptr;
17
+ global const float * local_timesteps_input_ptr;
18
+ global float * local_dst_output_base_ptr;
19
+
20
+ local_timesteps_input_ptr = (global const float *)((global char *)p_timesteps + off_timesteps);
21
+ local_dst_output_base_ptr = (global float *)((global char *)p_dst + off_dst);
22
+
23
+ local_i = get_global_id(1);
24
+ local_j = get_global_id(0);
25
+
26
+ local_half_dim = logical_dim / 2;
27
+ local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes);
28
+
29
+ if (logical_dim % 2 != 0 && local_j == ((logical_dim + 1) / 2)) {
30
+ local_embed_data_ptr[logical_dim] = 0.0f;
31
+ }
32
+
33
+ if (local_j >= local_half_dim) {
34
+ return;
35
+ }
36
+
37
+ local_timestep_val = local_timesteps_input_ptr[local_i];
38
+
39
+ if (local_half_dim == 0) {
40
+ local_freq = 1.0f;
41
+ } else {
42
+ local_freq = exp(-log((float)max_period) * (float)local_j / (float)local_half_dim);
43
+ }
44
+
45
+ local_arg = local_timestep_val * local_freq;
46
+ local_embed_data_ptr[local_j] = cos(local_arg);
47
+ local_embed_data_ptr[local_j + local_half_dim] = sin(local_arg);
48
+ }
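In formula form, with half = logical_dim / 2, timestep t_i and column j < half, the kernel computes the standard sinusoidal embedding (stated here only as a reference for the code above):

f_j = \exp\!\left(-\frac{j \,\ln(\mathrm{max\_period})}{\mathrm{half}}\right),\qquad
e_{i,j} = \cos(t_i f_j),\qquad
e_{i,\,j+\mathrm{half}} = \sin(t_i f_j),

and when logical_dim is odd the trailing column e_{i,\mathrm{logical\_dim}} is set to 0.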
ggml/src/ggml-opencl/kernels/upscale.cl ADDED
@@ -0,0 +1,121 @@
1
+ kernel void kernel_upscale(
2
+ global const void * p_src0,
3
+ ulong off_src0,
4
+ global void * p_dst,
5
+ ulong off_dst,
6
+ ulong nb00,
7
+ ulong nb01,
8
+ ulong nb02,
9
+ ulong nb03,
10
+ int ne10,
11
+ int ne11,
12
+ int ne12,
13
+ int ne13,
14
+ float sf0,
15
+ float sf1,
16
+ float sf2,
17
+ float sf3
18
+ ) {
19
+ global const char * src_base = (global const char *)p_src0 + off_src0;
20
+ global float * dst_base = (global float *)((global char *)p_dst + off_dst);
21
+
22
+ int index = get_global_id(0);
23
+ int dst_total_elements = ne10 * ne11 * ne12 * ne13;
24
+
25
+ if (index >= dst_total_elements) {
26
+ return;
27
+ }
28
+
29
+ int i10 = index % ne10;
30
+ int i11 = (index / ne10) % ne11;
31
+ int i12 = (index / (ne10 * ne11)) % ne12;
32
+ int i13 = index / (ne10 * ne11 * ne12);
33
+
34
+ int i00 = (int)(i10 / sf0);
35
+ int i01 = (int)(i11 / sf1);
36
+ int i02 = (int)(i12 / sf2);
37
+ int i03 = (int)(i13 / sf3);
38
+
39
+ ulong offset_src_element = (ulong)i03 * nb03 + (ulong)i02 * nb02 + (ulong)i01 * nb01 + (ulong)i00 * nb00;
40
+ global const float * src_element_ptr = (global const float *)(src_base + offset_src_element);
41
+
42
+ dst_base[index] = *src_element_ptr;
43
+ }
44
+
45
+ kernel void kernel_upscale_bilinear(
46
+ global const void * p_src0,
47
+ ulong off_src0,
48
+ global void * p_dst,
49
+ ulong off_dst,
50
+ ulong nb00,
51
+ ulong nb01,
52
+ ulong nb02,
53
+ ulong nb03,
54
+ int ne00_src,
55
+ int ne01_src,
56
+ int ne10_dst,
57
+ int ne11_dst,
58
+ int ne12_dst,
59
+ int ne13_dst,
60
+ float sf0,
61
+ float sf1,
62
+ float sf2,
63
+ float sf3
64
+ ) {
65
+ global const char * src_base = (global const char *)p_src0 + off_src0;
66
+ global float * dst_base = (global float *)((global char *)p_dst + off_dst);
67
+
68
+ int index = get_global_id(0);
69
+ int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
70
+
71
+ if (index >= dst_total_elements) {
72
+ return;
73
+ }
74
+
75
+ int i10_dst = index % ne10_dst;
76
+ int i11_dst = (index / ne10_dst) % ne11_dst;
77
+ int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
78
+ int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
79
+
80
+ int i02_src = (int)(i12_dst / sf2);
81
+ int i03_src = (int)(i13_dst / sf3);
82
+
83
+ const float pixel_offset = 0.5f;
84
+
85
+ float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
86
+ long y0_src = (long)floor(y_src_f);
87
+ long y1_src = y0_src + 1;
88
+
89
+ y0_src = max(0L, min(y0_src, (long)ne01_src - 1));
90
+ y1_src = max(0L, min(y1_src, (long)ne01_src - 1));
91
+
92
+ float dy = y_src_f - (float)y0_src;
93
+ dy = max(0.0f, min(dy, 1.0f));
94
+
95
+ float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
96
+ long x0_src = (long)floor(x_src_f);
97
+ long x1_src = x0_src + 1;
98
+
99
+ x0_src = max(0L, min(x0_src, (long)ne00_src - 1));
100
+ x1_src = max(0L, min(x1_src, (long)ne00_src - 1));
101
+
102
+ float dx = x_src_f - (float)x0_src;
103
+ dx = max(0.0f, min(dx, 1.0f));
104
+
105
+ global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
106
+ global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
107
+ global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
108
+ global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
109
+
110
+ const float val_a = *p_a;
111
+ const float val_b = *p_b;
112
+ const float val_c = *p_c;
113
+ const float val_d = *p_d;
114
+
115
+ float result = val_a * (1.0f - dx) * (1.0f - dy) +
116
+ val_b * dx * (1.0f - dy) +
117
+ val_c * (1.0f - dx) * dy +
118
+ val_d * dx * dy;
119
+
120
+ dst_base[index] = result;
121
+ }
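For reference, kernel_upscale_bilinear uses the usual half-pixel-centre mapping and a 2-D linear blend. With scale factor sf along an axis, destination index d and source extent n_src, the code above computes

x = \frac{d + 0.5}{sf} - 0.5,\quad
x_0 = \mathrm{clamp}(\lfloor x \rfloor, 0, n_{src}-1),\quad
x_1 = \mathrm{clamp}(x_0 + 1, 0, n_{src}-1),\quad
dx = \mathrm{clamp}(x - x_0, 0, 1),

and likewise for y, then blends the four clamped neighbours a, b, c, d as

v = a(1-dx)(1-dy) + b\,dx(1-dy) + c(1-dx)\,dy + d\,dx\,dy.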