Forward with new tensors math

This commit is contained in:
2025-11-01 14:22:09 +04:00
parent f1dfe1b335
commit c548c3089b
8 changed files with 192 additions and 68 deletions

View File

@@ -93,12 +93,11 @@ class Tensor0Math : public TensorMath<Tensor0>, public ITensor0Math<Tensor0> {};
class Tensor1Math : public TensorMath<Tensor1>, public ITensor1Math<Tensor1> {};
class Tensor2Math : public TensorMath<Tensor2>, public ITensor2Math<Tensor2> {
class Tensor2Math : public TensorMath<Tensor2>,
public ITensor2Math<Tensor2, Tensor1> {
private:
Tensor2 mult_tiled(const Tensor2 &a, const Tensor2 &b, bool transpose = false,
float bias = 0.0f, Activation type = Activation::LINEAR,
float alpha = 0.01f) {
validateMultDimensions(a, b, transpose);
Tensor2 mult_tiled(const Tensor2 &a, const Tensor2 &b, bool transpose,
const Vector &bias, Activation type, float alpha) {
Tensor2 result(a.getRows(), transpose ? b.getRows() : b.getCols(), false,
&queue);
@@ -111,7 +110,7 @@ private:
kernels[Method::MULT].setArg(0, *a.getBuffer());
kernels[Method::MULT].setArg(1, *b.getBuffer());
kernels[Method::MULT].setArg(2, *result.getBuffer());
kernels[Method::MULT].setArg(3, bias);
kernels[Method::MULT].setArg(3, *bias.getBuffer());
kernels[Method::MULT].setArg(4, static_cast<int>(type));
kernels[Method::MULT].setArg(5, alpha);
kernels[Method::MULT].setArg(6, result.getRows());
@@ -122,16 +121,14 @@ private:
global_size, local_size);
return result;
}
Tensor2 mult_small(const Tensor2 &a, const Tensor2 &b, bool transpose = false,
float bias = 0.0f, Activation type = Activation::LINEAR,
float alpha = 0.01f) {
validateMultDimensions(a, b, transpose);
Tensor2 mult_small(const Tensor2 &a, const Tensor2 &b, bool transpose,
const Vector &bias, Activation type, float alpha) {
Tensor2 result(a.getRows(), transpose ? b.getRows() : b.getCols(), false,
&queue);
kernels[Method::MULT_SMALL].setArg(0, *a.getBuffer());
kernels[Method::MULT_SMALL].setArg(1, *b.getBuffer());
kernels[Method::MULT_SMALL].setArg(2, *result.getBuffer());
kernels[Method::MULT_SMALL].setArg(3, bias);
kernels[Method::MULT_SMALL].setArg(3, *bias.getBuffer());
kernels[Method::MULT_SMALL].setArg(4, static_cast<int>(type));
kernels[Method::MULT_SMALL].setArg(5, alpha);
kernels[Method::MULT_SMALL].setArg(6, result.getRows());
@@ -145,13 +142,20 @@ private:
public:
Tensor2 mult(const Tensor2 &a, const Tensor2 &b, bool transpose = false,
float bias = 0.0f, Activation type = Activation::LINEAR,
const Vector *bias = nullptr,
Activation type = Activation::LINEAR,
float alpha = 0.01f) override {
validateMultDimensions(a, b, transpose);
const Vector defaultBias(a.getRows(), 0.0f, &queue);
if (bias != nullptr)
validateBiasDimensions(b, *bias, transpose);
if (a.getRows() > 64 || a.getCols() > 64 || b.getRows() > 64 ||
b.getCols() > 64)
return mult_tiled(a, b, transpose, bias, type, alpha);
return mult_tiled(a, b, transpose, bias == nullptr ? defaultBias : *bias,
type, alpha);
else
return mult_small(a, b, transpose, bias, type, alpha);
return mult_small(a, b, transpose, bias == nullptr ? defaultBias : *bias,
type, alpha);
}
};