// Mirror of https://github.com/StepanovPlaton/NeuralNetwork.git
// (synced 2026-04-03 12:20:39 +04:00; original file: 84 lines, 2.2 KiB, C++)
#pragma once
|
|
|
|
#include <array>
|
|
#include <cstddef>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
// Forward declaration of the concrete tensor type, so that the interface
// below can name Tensor<T, Dim> as its result type before it is defined.
template <class T, int Dim>
class Tensor;
|
|
enum class Function { SIGMOID, RELU, MSE, LINEAR };
|
|
|
|
template <typename T, int Dim> class ITensor {
|
|
protected:
|
|
std::array<size_t, Dim> shape_;
|
|
std::array<int, Dim> axes_;
|
|
|
|
template <typename... Indices> size_t computeIndex(Indices... indices) const;
|
|
|
|
void checkItHasSameShape(const ITensor &other) const;
|
|
void checkAxisInDim(int axis) const;
|
|
|
|
std::string format(std::vector<T> data) const;
|
|
|
|
public:
|
|
typedef class Tensor<T, Dim> Tensor;
|
|
|
|
ITensor() = delete;
|
|
ITensor(const std::array<size_t, Dim> &shape);
|
|
ITensor(const ITensor &other);
|
|
ITensor &operator=(const ITensor &other);
|
|
ITensor(ITensor &&other) noexcept;
|
|
ITensor &operator=(ITensor &&other) noexcept;
|
|
~ITensor() = default;
|
|
|
|
const std::array<int, Dim> &getAxes() const;
|
|
const std::array<size_t, Dim> getShape() const;
|
|
size_t getSize() const;
|
|
|
|
Tensor &transpose(const std::array<int, Dim> &new_axes);
|
|
Tensor &transpose(int axis_a, int axis_b);
|
|
Tensor &t();
|
|
|
|
virtual Tensor operator+() const = 0;
|
|
virtual Tensor operator-() const = 0;
|
|
|
|
virtual Tensor &operator+=(const T scalar) = 0;
|
|
virtual Tensor &operator*=(const T scalar) = 0;
|
|
|
|
virtual Tensor &operator+=(const Tensor &other) = 0;
|
|
virtual Tensor &operator*=(const Tensor &other) = 0;
|
|
|
|
Tensor operator+(const T scalar) const;
|
|
friend Tensor operator+(const T scalar, const Tensor &tensor) {
|
|
return tensor + scalar;
|
|
}
|
|
|
|
Tensor &operator-=(const T scalar);
|
|
Tensor operator-(const T scalar) const;
|
|
friend Tensor operator-(const T scalar, const Tensor &tensor) {
|
|
return tensor + (-scalar);
|
|
}
|
|
|
|
Tensor operator*(const T scalar) const;
|
|
friend Tensor operator*(const T scalar, const Tensor &tensor) {
|
|
return tensor * scalar;
|
|
}
|
|
|
|
Tensor &operator/=(const T scalar);
|
|
Tensor operator/(const T scalar) const;
|
|
|
|
Tensor operator+(const Tensor &other) const;
|
|
|
|
Tensor &operator-=(const Tensor &other);
|
|
Tensor operator-(const Tensor &other) const;
|
|
|
|
Tensor operator*(const Tensor &other) const;
|
|
|
|
virtual Tensor apply(Function f, bool derivative = false) const = 0;
|
|
|
|
// === Utils ===
|
|
virtual std::string toString() const = 0;
|
|
};
|
|
|
|
#include "tensor.tpp"
|