Back propogation. Not work(

This commit is contained in:
2025-11-02 20:19:36 +04:00
parent d795bb3019
commit df9a5e3017
6 changed files with 99 additions and 42 deletions

View File

@@ -34,6 +34,7 @@ public:
virtual T activate(const T &m, Activation type, float alpha) = 0;
virtual T d_activate(const T &m, Activation type, float alpha) = 0;
virtual T mult(const T &a, const T &b) = 0;
virtual T mult(const T &m, float x) = 0;
virtual T add(const T &a, const T &b, float x) = 0;
virtual T add(const T &m, float x) = 0;
@@ -47,24 +48,26 @@ template <ITensor1Type T> class ITensor1Math {};
template <ITensor2Type M, ITensor1Type V> class ITensor2Math {
public:
virtual M dot(const M &a, const M &b, bool transpose, const V *bias,
Activation type, float alpha) = 0;
virtual M dot(const M &a, const M &b, bool transpose_a, bool transpose_b,
const V *bias, Activation type, float alpha) = 0;
virtual M loss(const M &a, const M &b, Loss type) = 0;
virtual M d_loss(const M &a, const M &b, Loss type) = 0;
void validateMultDimensions(const M &a, const M &b, bool transpose) const {
if ((!transpose && a.getCols() != b.getRows()) ||
(transpose && a.getCols() != b.getCols())) {
virtual V axis_sum(const M &m) = 0;
void validateMultDimensions(const M &a, const M &b, bool transpose_a,
bool transpose_b) const {
int a_cols = transpose_a ? a.getRows() : a.getCols();
int b_rows = transpose_b ? b.getCols() : b.getRows();
if (a_cols != b_rows)
throw std::invalid_argument(
"Invalid matrix dimensions for multiplication");
}
};
void validateBiasDimensions(const M &a, const V &b, bool transpose) const {
if ((!transpose && a.getCols() != b.getSize()) ||
(transpose && a.getRows() != b.getSize())) {
(transpose && a.getRows() != b.getSize()))
throw std::invalid_argument("Invalid matrix bias");
}
};
};