mirror of
https://github.com/StepanovPlaton/NeuralNetwork.git
synced 2026-04-04 04:40:40 +04:00
Check
This commit is contained in:
@@ -45,9 +45,13 @@ private:
|
||||
all(other.getEvent()), &event_);
|
||||
}
|
||||
|
||||
// Builds the cl::Kernel for the requested `method` by delegating to a single
// shared Kernels<T> instance.
//
// @param method  Which Kernels<T>::Method to create a kernel for.
// @return        The result of `kernels.create(method)`.
static cl::Kernel createKernel(Kernels<T>::Method method) {
|
||||
// Function-local static: constructed once, on the first call, and reused by
// every later call. It is always configured with Vector::type4.
// NOTE(review): presumably the Kernels constructor is the expensive step
// that this caching avoids repeating — confirm against Kernels.
static Kernels<T> kernels(Kernels<T>::Vector::type4);
|
||||
return kernels.create(method);
|
||||
}
|
||||
|
||||
public:
|
||||
typedef class ITensor<T, Dim> ITensor;
|
||||
typedef class Kernels<T, Dim> Kernels;
|
||||
|
||||
using ITensor::axes_;
|
||||
using ITensor::checkAxisInDim;
|
||||
@@ -105,7 +109,7 @@ public:
|
||||
ITensor::operator=(std::move(other));
|
||||
data_ = other.data_;
|
||||
event_ = other.event_;
|
||||
other.data = nullptr;
|
||||
other.data_ = nullptr;
|
||||
return *this;
|
||||
}
|
||||
~Tensor() {
|
||||
@@ -120,8 +124,9 @@ public:
|
||||
using ITensor::operator-;
|
||||
|
||||
Tensor operator+() const override {
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::POSITIVE);
|
||||
cl::Kernel kernel = createKernel(Kernels<T>::Method::POSITIVE);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, (int)getSize());
|
||||
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
|
||||
cl::NDRange(getSize()),
|
||||
cl::NullRange, all(event_), &event_);
|
||||
@@ -129,8 +134,9 @@ public:
|
||||
}
|
||||
|
||||
Tensor operator-() const override {
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::NEGATIVE);
|
||||
cl::Kernel kernel = createKernel(Kernels<T>::Method::NEGATIVE);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, (int)getSize());
|
||||
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
|
||||
cl::NDRange(getSize()),
|
||||
cl::NullRange, all(event_), &event_);
|
||||
@@ -138,9 +144,10 @@ public:
|
||||
}
|
||||
|
||||
Tensor &operator+=(const T scalar) override {
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::S_ADD);
|
||||
cl::Kernel kernel = createKernel(Kernels<T>::Method::S_ADD);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, scalar);
|
||||
kernel.setArg(2, (int)getSize());
|
||||
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
|
||||
cl::NDRange(getSize()),
|
||||
cl::NullRange, all(event_), &event_);
|
||||
@@ -148,9 +155,10 @@ public:
|
||||
}
|
||||
|
||||
Tensor &operator*=(const T scalar) override {
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::S_MULT);
|
||||
cl::Kernel kernel = createKernel(Kernels<T>::Method::S_MULT);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, scalar);
|
||||
kernel.setArg(2, (int)getSize());
|
||||
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
|
||||
cl::NDRange(getSize()),
|
||||
cl::NullRange, all(event_), &event_);
|
||||
@@ -158,9 +166,10 @@ public:
|
||||
}
|
||||
|
||||
Tensor &operator+=(const Tensor &other) override {
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::T_ADD);
|
||||
cl::Kernel kernel = createKernel(Kernels<T>::Method::T_ADD);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, *other.getData());
|
||||
kernel.setArg(2, (int)getSize());
|
||||
openCL.getQueue().enqueueNDRangeKernel(
|
||||
kernel, cl::NullRange, cl::NDRange(getSize()), cl::NullRange,
|
||||
all(event_, other.event_), &event_);
|
||||
@@ -168,9 +177,10 @@ public:
|
||||
}
|
||||
|
||||
Tensor &operator*=(const Tensor &other) override {
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::T_HADAMARD);
|
||||
cl::Kernel kernel = createKernel(Kernels<T>::Method::T_HADAMARD);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, *other.getData());
|
||||
kernel.setArg(2, getSize());
|
||||
openCL.getQueue().enqueueNDRangeKernel(
|
||||
kernel, cl::NullRange, cl::NDRange(getSize()), cl::NullRange,
|
||||
all(event_, other.event_), &event_);
|
||||
@@ -192,16 +202,14 @@ public:
|
||||
size_t k = shape_[axes_[1]];
|
||||
size_t n = other.shape_[other.axes_[1]];
|
||||
Tensor<T, 2> result({m, n});
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::T_MULT);
|
||||
cl::Kernel kernel = createKernel(Kernels<T>::Method::T_MULT);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, *other.getData());
|
||||
kernel.setArg(2, *result.getData());
|
||||
kernel.setArg(3, (int)m);
|
||||
kernel.setArg(4, (int)n);
|
||||
kernel.setArg(5, (int)k);
|
||||
cl::NDRange global_size(
|
||||
((m + TILE_SIZE * VEC_SIZE - 1) / (TILE_SIZE * VEC_SIZE)) * TILE_SIZE,
|
||||
((n + TILE_SIZE - 1) / TILE_SIZE) * TILE_SIZE);
|
||||
cl::NDRange global_size(m / VEC_SIZE, n);
|
||||
cl::NDRange local_size(TILE_SIZE / VEC_SIZE, TILE_SIZE);
|
||||
openCL.getQueue().enqueueNDRangeKernel(
|
||||
kernel, cl::NullRange, global_size, local_size,
|
||||
|
||||
Reference in New Issue
Block a user