mirror of
https://github.com/StepanovPlaton/NeuralNetwork.git
synced 2026-04-04 04:40:40 +04:00
Work
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
|
||||
#include "opencl.hpp"
|
||||
|
||||
#include "kernels.hpp"
|
||||
|
||||
#include "../tensor.hpp"
|
||||
|
||||
#include <random>
|
||||
@@ -45,6 +47,7 @@ private:
|
||||
|
||||
public:
|
||||
typedef class ITensor<T, Dim> ITensor;
|
||||
typedef class Kernels<T, Dim> Kernels;
|
||||
|
||||
using ITensor::axes_;
|
||||
using ITensor::checkAxisInDim;
|
||||
@@ -117,7 +120,7 @@ public:
|
||||
using ITensor::operator-;
|
||||
|
||||
Tensor operator+() const override {
|
||||
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::POSITIVE);
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::POSITIVE);
|
||||
kernel.setArg(0, *data_);
|
||||
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
|
||||
cl::NDRange(getSize()),
|
||||
@@ -126,7 +129,7 @@ public:
|
||||
}
|
||||
|
||||
Tensor operator-() const override {
|
||||
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::NEGATIVE);
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::NEGATIVE);
|
||||
kernel.setArg(0, *data_);
|
||||
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
|
||||
cl::NDRange(getSize()),
|
||||
@@ -135,7 +138,7 @@ public:
|
||||
}
|
||||
|
||||
Tensor &operator+=(const T scalar) override {
|
||||
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::S_ADD);
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::S_ADD);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, scalar);
|
||||
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
|
||||
@@ -145,7 +148,7 @@ public:
|
||||
}
|
||||
|
||||
Tensor &operator*=(const T scalar) override {
|
||||
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::S_MULT);
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::S_MULT);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, scalar);
|
||||
openCL.getQueue().enqueueNDRangeKernel(kernel, cl::NullRange,
|
||||
@@ -155,7 +158,7 @@ public:
|
||||
}
|
||||
|
||||
Tensor &operator+=(const Tensor &other) override {
|
||||
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::T_ADD);
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::T_ADD);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, *other.getData());
|
||||
openCL.getQueue().enqueueNDRangeKernel(
|
||||
@@ -165,7 +168,7 @@ public:
|
||||
}
|
||||
|
||||
Tensor &operator*=(const Tensor &other) override {
|
||||
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::T_HADAMARD);
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::T_HADAMARD);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, *other.getData());
|
||||
openCL.getQueue().enqueueNDRangeKernel(
|
||||
@@ -189,7 +192,7 @@ public:
|
||||
size_t k = shape_[axes_[1]];
|
||||
size_t n = other.shape_[other.axes_[1]];
|
||||
Tensor<T, 2> result({m, n});
|
||||
cl::Kernel kernel = openCL.createKernel(OpenCL::Method::T_MULT);
|
||||
cl::Kernel kernel = Kernels::create(Kernels::Method::T_MULT);
|
||||
kernel.setArg(0, *data_);
|
||||
kernel.setArg(1, *other.getData());
|
||||
kernel.setArg(2, *result.getData());
|
||||
|
||||
Reference in New Issue
Block a user