This commit is contained in:
2025-11-25 18:50:52 +04:00
parent 8d5a57a8c0
commit a001582431
8 changed files with 195 additions and 273 deletions

View File

@@ -2,6 +2,8 @@
#include "opencl.hpp"
#include "kernels.hpp"
#include "../tensor.hpp"
#include <random>
@@ -45,6 +47,7 @@ private:
public:
typedef class ITensor<T, Dim> ITensor;
typedef class Kernels<T, Dim> Kernels;
using ITensor::axes_;
using ITensor::checkAxisInDim;
@@ -117,7 +120,7 @@ public:
using ITensor::operator-;
Tensor operator+() const override {
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::POSITIVE);
cl::Kernel kernel = Kernels::create(Kernels::Method::POSITIVE);
kernel.setArg(0, *data_);
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
cl::NDRange(getSize()),
@@ -126,7 +129,7 @@ public:
}
Tensor operator-() const override {
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::NEGATIVE);
cl::Kernel kernel = Kernels::create(Kernels::Method::NEGATIVE);
kernel.setArg(0, *data_);
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
cl::NDRange(getSize()),
@@ -135,7 +138,7 @@ public:
}
Tensor &operator+=(const T scalar) override {
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::S_ADD);
cl::Kernel kernel = Kernels::create(Kernels::Method::S_ADD);
kernel.setArg(0, *data_);
kernel.setArg(1, scalar);
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
@@ -145,7 +148,7 @@ public:
}
Tensor &operator*=(const T scalar) override {
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::S_MULT);
cl::Kernel kernel = Kernels::create(Kernels::Method::S_MULT);
kernel.setArg(0, *data_);
kernel.setArg(1, scalar);
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
@@ -155,7 +158,7 @@ public:
}
Tensor &operator+=(const Tensor &other) override {
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::T_ADD);
cl::Kernel kernel = Kernels::create(Kernels::Method::T_ADD);
kernel.setArg(0, *data_);
kernel.setArg(1, *other.getData());
openCL.getQueue().enqueueNDRangeKernel(
@@ -165,7 +168,7 @@ public:
}
Tensor &operator*=(const Tensor &other) override {
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::T_HADAMARD);
cl::Kernel kernel = Kernels::create(Kernels::Method::T_HADAMARD);
kernel.setArg(0, *data_);
kernel.setArg(1, *other.getData());
openCL.getQueue().enqueueNDRangeKernel(
@@ -189,7 +192,7 @@ public:
size_t k = shape_[axes_[1]];
size_t n = other.shape_[other.axes_[1]];
Tensor<T, 2> result({m, n});
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::T_MULT);
cl::Kernel kernel = Kernels::create(Kernels::Method::T_MULT);
kernel.setArg(0, *data_);
kernel.setArg(1, *other.getData());
kernel.setArg(2, *result.getData());