Add loss func and pretti readme

This commit is contained in:
2025-11-02 16:50:30 +04:00
parent 5a25027e1c
commit d795bb3019
12 changed files with 219 additions and 91 deletions

View File

@@ -2,7 +2,8 @@
#include "tensor.hpp"
enum class Activation { LINEAR, SIGMOID, TANH, RELU, LEAKY_RELU, ELU, GELU };
enum class Activation { LINEAR, SIGMOID, TANH, RELU, LEAKY_RELU, ELU };
enum class Loss { MSE };
template <typename T>
concept ITensorType = std::is_base_of_v<ITensor, T>;
@@ -31,6 +32,7 @@ protected:
public:
virtual T activate(const T &m, Activation type, float alpha) = 0;
virtual T d_activate(const T &m, Activation type, float alpha) = 0;
virtual T mult(const T &m, float x) = 0;
virtual T add(const T &a, const T &b, float x) = 0;
@@ -45,11 +47,13 @@ template <ITensor1Type T> class ITensor1Math {};
template <ITensor2Type M, ITensor1Type V> class ITensor2Math {
public:
virtual M mult(const M &a, const M &b, bool transpose, const V *bias,
Activation type, float alpha) = 0;
virtual M dot(const M &a, const M &b, bool transpose, const V *bias,
Activation type, float alpha) = 0;
virtual M loss(const M &a, const M &b, Loss type) = 0;
virtual M d_loss(const M &a, const M &b, Loss type) = 0;
void validateMultDimensions(const M &a, const M &b, bool transpose) const {
printf("%dx%d %dx%d\n", a.getRows(), a.getCols(), b.getRows(), b.getCols());
if ((!transpose && a.getCols() != b.getRows()) ||
(transpose && a.getCols() != b.getCols())) {
throw std::invalid_argument(
@@ -64,4 +68,4 @@ public:
};
};
template <ITensor3Type T> class ITensor3Math {};
template <ITensor3Type T> class ITensor3Math {};