This commit is contained in:
2025-11-23 01:15:51 +04:00
parent 0455d9bd5b
commit 8d5a57a8c0
4 changed files with 74 additions and 28 deletions

View File

@@ -175,6 +175,7 @@ public:
}
#define TILE_SIZE 16
#define VEC_SIZE 4
Tensor<T, Dim == 1 ? 0 : 2> operator%(const Tensor &other) const {
static_assert(Dim == 1 || Dim == 2,
"Inner product is only defined for vectors and matrices");
@@ -195,9 +196,10 @@ public:
kernel.setArg(3, (int)m);
kernel.setArg(4, (int)n);
kernel.setArg(5, (int)k);
cl::NDRange global_size(((m + TILE_SIZE - 1) / TILE_SIZE) * TILE_SIZE,
((n + TILE_SIZE - 1) / TILE_SIZE) * TILE_SIZE);
cl::NDRange local_size(TILE_SIZE, TILE_SIZE);
cl::NDRange global_size(
((m + TILE_SIZE * VEC_SIZE - 1) / (TILE_SIZE * VEC_SIZE)) * TILE_SIZE,
((n + TILE_SIZE - 1) / TILE_SIZE) * TILE_SIZE);
cl::NDRange local_size(TILE_SIZE / VEC_SIZE, TILE_SIZE);
openCL.getQueue().enqueueNDRangeKernel(
kernel, cl::NullRange, global_size, local_size,
all(event_, other.event_), &result.event_);