mirror of
https://github.com/StepanovPlaton/NeuralNetwork.git
synced 2026-04-04 04:40:40 +04:00
Forward with new tensors math
This commit is contained in:
@@ -41,18 +41,24 @@ template <ITensor0Type T> class ITensor0Math {};
|
||||
|
||||
template <ITensor1Type T> class ITensor1Math {};
|
||||
|
||||
template <ITensor2Type T> class ITensor2Math {
|
||||
template <ITensor2Type M, ITensor1Type V> class ITensor2Math {
|
||||
public:
|
||||
virtual T mult(const T &a, const T &b, bool transpose, float bias,
|
||||
virtual M mult(const M &a, const M &b, bool transpose, const V *bias,
|
||||
Activation type, float alpha) = 0;
|
||||
|
||||
void validateMultDimensions(const T &a, const T &b, bool transpose) const {
|
||||
void validateMultDimensions(const M &a, const M &b, bool transpose) const {
|
||||
if ((!transpose && a.getCols() != b.getRows()) ||
|
||||
(transpose && a.getCols() != b.getCols())) {
|
||||
throw std::invalid_argument(
|
||||
"Invalid matrix dimensions for multiplication");
|
||||
}
|
||||
};
|
||||
void validateBiasDimensions(const M &a, const V &b, bool transpose) const {
|
||||
if ((!transpose && a.getCols() != b.getSize()) ||
|
||||
(transpose && a.getRows() != b.getSize())) {
|
||||
throw std::invalid_argument("Invalid matrix bias");
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <ITensor3Type T> class ITensor3Math {};
|
||||
Reference in New Issue
Block a user