#pragma once #include "tensor.hpp" enum class Activation { LINEAR, SIGMOID, TANH, RELU, LEAKY_RELU, ELU }; enum class Loss { MSE }; template concept ITensorType = std::is_base_of_v; template concept ITensor0Type = std::is_base_of_v; template concept ITensor1Type = std::is_base_of_v; template concept ITensor2Type = std::is_base_of_v; template concept ITensor3Type = std::is_base_of_v; template class ITensorMath { protected: void validateSameDimensions(const T &a, const T &b) const { if (a.getDim() != b.getDim()) throw std::invalid_argument("Tensors must have the same dimension"); if (a.getSize() != b.getSize()) throw std::invalid_argument("Tensors must have the same size"); for (int i = 0; i < a.getDim(); ++i) { if (a.getShape()[i] != b.getShape()[i]) throw std::invalid_argument("Tensors must have the same shape"); } }; public: virtual T activate(const T &m, Activation type, float alpha) = 0; virtual T d_activate(const T &m, Activation type, float alpha) = 0; virtual T mult(const T &a, const T &b) = 0; virtual T mult(const T &m, float x) = 0; virtual T add(const T &a, const T &b, float x) = 0; virtual T add(const T &m, float x) = 0; virtual void await() const = 0; }; template class ITensor0Math {}; template class ITensor1Math {}; template class ITensor2Math { public: virtual M dot(const M &a, const M &b, bool transpose_a, bool transpose_b, const V *bias, Activation type, float alpha) = 0; virtual M loss(const M &a, const M &b, Loss type) = 0; virtual M d_loss(const M &a, const M &b, Loss type) = 0; virtual V axis_sum(const M &m) = 0; void validateMultDimensions(const M &a, const M &b, bool transpose_a, bool transpose_b) const { int a_cols = transpose_a ? a.getRows() : a.getCols(); int b_rows = transpose_b ? b.getCols() : b.getRows(); if (a_cols != b_rows) throw std::invalid_argument( "Invalid matrix dimensions for multiplication"); }; void validateBiasDimensions(const M &m, const V &v, bool transpose) const { if ((transpose && (size_t)m.getCols() != v.getSize()) || (!transpose && (size_t)m.getRows() != v.getSize())) throw std::invalid_argument("Invalid matrix bias"); }; }; template class ITensor3Math {};