mirror of
https://github.com/StepanovPlaton/NeuralNetwork.git
synced 2026-04-03 20:30:39 +04:00
Add loss func and pretti readme
This commit is contained in:
@@ -96,8 +96,8 @@ class Tensor1Math : public TensorMath<Tensor1>, public ITensor1Math<Tensor1> {};
|
||||
class Tensor2Math : public TensorMath<Tensor2>,
|
||||
public ITensor2Math<Tensor2, Tensor1> {
|
||||
private:
|
||||
Tensor2 mult_tiled(const Tensor2 &a, const Tensor2 &b, bool transpose,
|
||||
const Vector &bias, Activation type, float alpha) {
|
||||
Tensor2 dot_tiled(const Tensor2 &a, const Tensor2 &b, bool transpose,
|
||||
const Vector &bias, Activation type, float alpha) {
|
||||
Tensor2 result(a.getRows(), transpose ? b.getRows() : b.getCols(), false,
|
||||
&queue);
|
||||
|
||||
@@ -121,8 +121,8 @@ private:
|
||||
global_size, local_size);
|
||||
return result;
|
||||
}
|
||||
Tensor2 mult_small(const Tensor2 &a, const Tensor2 &b, bool transpose,
|
||||
const Vector &bias, Activation type, float alpha) {
|
||||
Tensor2 dot_small(const Tensor2 &a, const Tensor2 &b, bool transpose,
|
||||
const Vector &bias, Activation type, float alpha) {
|
||||
Tensor2 result(a.getRows(), transpose ? b.getRows() : b.getCols(), false,
|
||||
&queue);
|
||||
kernels[Method::MULT_SMALL].setArg(0, *a.getBuffer());
|
||||
@@ -141,21 +141,21 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
Tensor2 mult(const Tensor2 &a, const Tensor2 &b, bool transpose = false,
|
||||
const Vector *bias = nullptr,
|
||||
Activation type = Activation::LINEAR,
|
||||
float alpha = 0.01f) override {
|
||||
Tensor2 dot(const Tensor2 &a, const Tensor2 &b, bool transpose = false,
|
||||
const Vector *bias = nullptr,
|
||||
Activation type = Activation::LINEAR,
|
||||
float alpha = 0.01f) override {
|
||||
validateMultDimensions(a, b, transpose);
|
||||
const Vector defaultBias(a.getRows(), 0.0f, &queue);
|
||||
if (bias != nullptr)
|
||||
validateBiasDimensions(b, *bias, transpose);
|
||||
if (a.getRows() > 64 || a.getCols() > 64 || b.getRows() > 64 ||
|
||||
b.getCols() > 64)
|
||||
return mult_tiled(a, b, transpose, bias == nullptr ? defaultBias : *bias,
|
||||
type, alpha);
|
||||
return dot_tiled(a, b, transpose, bias == nullptr ? defaultBias : *bias,
|
||||
type, alpha);
|
||||
else
|
||||
return mult_small(a, b, transpose, bias == nullptr ? defaultBias : *bias,
|
||||
type, alpha);
|
||||
return dot_small(a, b, transpose, bias == nullptr ? defaultBias : *bias,
|
||||
type, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user