Finally it works

This commit is contained in:
2025-11-26 12:55:09 +04:00
parent 2db52adf0f
commit 153a13f443
12 changed files with 319 additions and 164 deletions

View File

@@ -11,6 +11,7 @@
template <typename T> class Kernels {
public:
enum class Vector {
type1 = 1,
type2 = 2,
type4 = 4,
type8 = 8,
@@ -24,6 +25,7 @@ public:
T_ADD,
T_HADAMARD,
T_MULT,
FUNC
};
private:
@@ -42,6 +44,7 @@ private:
pos += value.length();
}
}
// std::cout << result << std::endl;
return result;
}
@@ -50,6 +53,7 @@ private:
R"(
__kernel void {method}(__global type* A, int len) {
int gid = get_global_id(0);
#if WIDTH != 1
int base = gid * WIDTH;
if (base + WIDTH <= len) {
typeX data = vloadX(gid, A);
@@ -60,6 +64,9 @@ private:
if (idx < len) A[idx] = {operation}A[idx];
}
}
#else
A[gid] = {operation}A[gid];
#endif
})",
{{"method", name}, {"operation", operation}});
}
@@ -69,6 +76,7 @@ private:
R"(
__kernel void {method}(__global type* A, int len, type scalar) {
int gid = get_global_id(0);
#if WIDTH != 1
int base = gid * WIDTH;
if (base + WIDTH <= len) {
typeX data = vloadX(gid, A);
@@ -80,6 +88,9 @@ private:
if (idx < len) A[idx] = A[idx] {operation} scalar;
}
}
#else
A[gid] = A[gid] {operation} scalar;
#endif
})",
{{"method", name}, {"operation", operation}});
}
@@ -89,6 +100,7 @@ private:
R"(
__kernel void {method}(__global type* A, __global type* B, int len) {
int gid = get_global_id(0);
#if WIDTH != 1
int base = gid * WIDTH;
if (base + WIDTH <= len) {
typeX dataA = vloadX(gid, A);
@@ -100,48 +112,65 @@ private:
if (idx < len) A[idx] = A[idx] {operation} B[idx];
}
}
#else
A[gid] = A[gid] {operation} B[gid];
#endif
})",
{{"method", name}, {"operation", operation}});
}
std::string matrixMult(std::string name) {
return format(
R"(
#define TILE_SIZE WIDTH*4
__kernel void mult(const __global typeX* A,
const __global typeX* B,
__global typeX* C, const int M, const int N, const int K) {
const int row = get_local_id(0);
const int col = get_local_id(1);
const int globalRow = (TILE_SIZE/WIDTH)*get_group_id(0) + row;
const int globalCol = TILE_SIZE*get_group_id(1) + col;
__local typeX Asub[TILE_SIZE][TILE_SIZE/WIDTH];
__local typeX Bsub[TILE_SIZE][TILE_SIZE/WIDTH];
typeX acc = 0;
const int numTiles = K/TILE_SIZE;
for (int tile = 0; tile < numTiles; tile++) {
const int tiledRow = (TILE_SIZE/WIDTH)*tile + row;
const int tiledCol = TILE_SIZE*tile + col;
Asub[col][row] = A[tiledCol*(M/WIDTH) + globalRow];
Bsub[col][row] = B[globalCol*(K/WIDTH) + tiledRow];
barrier(CLK_LOCAL_MEM_FENCE);
typeX vecA, vecB;
type valB;
for (int k = 0; k < TILE_SIZE/WIDTH; k++) {
vecB = Bsub[col][k];
for (int w = 0; w < WIDTH; w++) {
vecA = Asub[WIDTH*k + w][row];
valB = vecB[w];
for (int i = 0; i < WIDTH; i++)
acc[i] += vecA[i] * valB;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
C[globalCol*(M/WIDTH) + globalRow] = acc;
}
)",
{{"method", name}});
std::string matrixMult() {
return R"(
__kernel void mult(const __global type* A,
const __global type* B,
__global type* C,
const int M, const int N, const int K) {
const int row = get_global_id(0);
const int col = get_global_id(1);
if (row < M && col < N) {
type sum = 0.0f;
for (int k = 0; k < K; k++)
sum += A[row * K + k] * B[k * N + col];
C[row * N + col] = sum;
}
})";
}
std::string func() {
return R"(
__kernel void func(__global type* A, const int f, const int derivative) {
int gid = get_global_id(0);
type x = A[gid];
switch (f) {
case 0: // SIGMOID
if (!derivative)
A[gid] = (type)1 / ((type)1 + exp(-x));
else {
type sigmoid = (type)1 / ((type)1 + exp(-x));
A[gid] = sigmoid * ((type)1 - sigmoid);
}
break;
case 1: // RELU
if (!derivative)
A[gid] = fmax((type)0, x);
else
A[gid] = (x > (type)0) ? (type)1 : (type)0;
break;
case 2: // MSE (здесь это скорее квадратная функция)
if (!derivative)
A[gid] = x * x;
else
A[gid] = (type)2 * x;
break;
case 3: // LINEAR
default:
if (!derivative)
A[gid] = x;
else
A[gid] = (type)1.0f;
break;
}
})";
}
std::unordered_map<Method, std::tuple<std::string, std::string>> programs = {
@@ -155,13 +184,18 @@ private:
{Method::T_HADAMARD,
{binaryOperation("hadamard_mult", "*"), "hadamard_mult"}},
{Method::T_MULT, {matrixMult("mult"), "mult"}},
{Method::T_MULT, {matrixMult(), "mult"}},
{Method::FUNC, {func(), "func"}},
};
std::unordered_map<Method, cl::Program> compiledPrograms;
public:
Kernels(Vector vec = Vector::type4) : vector(vec) {
Kernels(Vector vec) : vector(vec) {
std::cout << "Compile " << getTypeName()
<< " kernels with vector size = " << std::to_string((int)vector)
<< " ";
std::string extensions = openCL.getDevice().getInfo<CL_DEVICE_EXTENSIONS>();
if (extensions.find("cl_khr_fp16") != std::string::npos)
configuration = R"(
@@ -183,10 +217,12 @@ public:
configuration += format(
R"(
typedef {type} type;
typedef {type}{vector} typeX;
#define WIDTH {vector}
#define vloadX vload{vector}
#define vstoreX vstore{vector}
#if WIDTH != 1
typedef {type}{vector} typeX;
#define vloadX vload{vector}
#define vstoreX vstore{vector}
#endif
)",
{{"type", getTypeName()}, {"vector", std::to_string((int)vector)}});
@@ -209,6 +245,7 @@ public:
}
}
}
std::cout << "completed" << std::endl;
}
cl::Kernel create(Method method) {