Split headers and logic

This commit is contained in:
2025-11-17 16:03:32 +04:00
parent bbd9c67c96
commit d7d93999a4
18 changed files with 589 additions and 394 deletions

View File

@@ -1,392 +1,78 @@
#pragma once
#include <array>
#include <random>
#include <sstream>
#include <stdexcept>
#include <vector>
#include <cstddef>
#include <string>
template <typename T, int Dim> class Tensor;
template <typename T, int Dim> class TensorInfo {
template <typename T, int Dim> class ITensor {
protected:
std::array<size_t, Dim> shape_;
std::array<int, Dim> axes_;
template <typename... Indices> size_t computeIndex(Indices... indices) const {
static_assert(sizeof...(Indices) == Dim, "Invalid number of indices");
std::array<size_t, Dim> indicesArray = {static_cast<size_t>(indices)...};
std::array<size_t, Dim> axesIndices;
for (int i = 0; i < Dim; ++i)
axesIndices[axes_[i]] = indicesArray[i];
size_t index = 0;
size_t stride = 1;
for (int i = Dim - 1; i >= 0; --i) {
index += axesIndices[i] * stride;
stride *= shape_[i];
}
return index;
}
template <typename... Indices> size_t computeIndex(Indices... indices) const;
void checkItHasSameShape(const TensorInfo &other) {
if (getShape() != other.getShape())
throw std::invalid_argument("Tensor shapes must match");
}
void checkAxisInDim(int axis) {
if (axis < 0 || axis >= Dim)
throw std::invalid_argument("Invalid axis index");
}
void checkItHasSameShape(const ITensor &other) const;
void checkAxisInDim(int axis) const;
public:
typedef class Tensor<T, Dim> Ten;
typedef class Tensor<T, Dim> Tensor;
TensorInfo() = delete;
ITensor() = delete;
ITensor(const std::array<size_t, Dim> &shape);
ITensor(const ITensor &other);
ITensor &operator=(const ITensor &other);
ITensor(ITensor &&other) noexcept;
ITensor &operator=(ITensor &&other) noexcept;
~ITensor() = default;
TensorInfo(const std::array<size_t, Dim> &shape) {
for (size_t d : shape)
if (d == 0)
throw std::invalid_argument("Invalid shape");
shape_ = shape;
for (int i = 0; i < Dim; ++i)
axes_[i] = i;
}
const std::array<int, Dim> &getAxes() const;
const std::array<size_t, Dim> getShape() const;
size_t getSize() const;
TensorInfo(const TensorInfo &other)
: shape_(other.shape_), axes_(other.axes_) {}
TensorInfo &operator=(const TensorInfo &other) {
shape_ = other.shape_;
axes_ = other.axes_;
return *this;
}
TensorInfo(TensorInfo &&other) noexcept
: shape_(std::move(other.shape_)), axes_(std::move(other.axes_)) {}
TensorInfo &operator=(TensorInfo &&other) noexcept {
shape_ = std::move(other.shape_);
axes_ = std::move(other.axes_);
return *this;
}
~TensorInfo() = default;
Tensor &transpose(const std::array<int, Dim> &new_axes);
Tensor &transpose(int axis_a, int axis_b);
Tensor &t();
const std::array<int, Dim> &getAxes() const { return axes_; }
const std::array<size_t, Dim> getShape() const {
std::array<size_t, Dim> result;
for (int i = 0; i < Dim; ++i)
result[i] = shape_[axes_[i]];
return result;
}
size_t getSize() const {
size_t size = 1;
for (size_t i = 0; i < shape_.size(); ++i)
size *= shape_[i];
return size;
};
// === Operators ===
virtual Tensor operator+() const = 0;
virtual Tensor operator-() const = 0;
Ten &transpose(const std::array<int, Dim> &new_axes) {
std::array<bool, Dim> used{};
for (int axis : new_axes) {
checkAxisInDim(axis);
if (used[axis])
throw std::invalid_argument("Duplicate axis index");
used[axis] = true;
}
axes_ = new_axes;
return static_cast<Ten &>(*this);
}
Ten &transpose(int axis_a, int axis_b) {
checkAxisInDim(axis_a);
checkAxisInDim(axis_b);
if (axis_a == axis_b)
throw std::invalid_argument("Duplicate axis index");
std::swap(axes_[axis_a], axes_[axis_b]);
return static_cast<Ten &>(*this);
}
Ten &t() {
static_assert(Dim >= 2, "Can't change the only axis");
std::swap(axes_[Dim - 1], axes_[Dim - 2]);
return static_cast<Ten &>(*this);
}
virtual Tensor &operator+=(const T &scalar) = 0;
virtual Tensor &operator*=(const T &scalar) = 0;
virtual Ten operator+() const = 0;
virtual Ten operator-() const = 0;
virtual Tensor &operator+=(const Tensor &other) = 0;
virtual Tensor &operator*=(const Tensor &other) = 0;
virtual Ten &operator+=(const T &scalar) = 0;
virtual Ten &operator*=(const T &scalar) = 0;
Ten operator+(const T &scalar) const {
Ten result = static_cast<const Ten &>(*this);
result += scalar;
return result;
}
friend Ten operator+(const T &scalar, const Ten &tensor) {
Tensor operator+(const T &scalar) const;
friend Tensor operator+(const T &scalar, const Tensor &tensor) {
return tensor + scalar;
}
Ten &operator-=(const T &scalar) {
*this += -scalar;
return static_cast<Ten &>(*this);
}
Ten operator-(const T &scalar) const {
Ten result = static_cast<const Ten &>(*this);
result -= scalar;
return result;
}
friend Ten operator-(const T &scalar, const Ten &tensor) {
Tensor &operator-=(const T &scalar);
Tensor operator-(const T &scalar) const;
friend Tensor operator-(const T &scalar, const Tensor &tensor) {
return tensor + (-scalar);
}
Ten operator*(const T &scalar) const {
Ten result = static_cast<const Ten &>(*this);
result *= scalar;
return result;
}
friend Ten operator*(const T &scalar, const Ten &tensor) {
Tensor operator*(const T &scalar) const;
friend Tensor operator*(const T &scalar, const Tensor &tensor) {
return tensor * scalar;
}
Ten &operator/=(const T &scalar) {
*this *= T(1) / scalar;
return static_cast<Ten &>(*this);
}
Ten operator/(const T &scalar) const {
Ten result = static_cast<const Ten &>(*this);
result /= scalar;
return result;
}
Tensor &operator/=(const T &scalar);
Tensor operator/(const T &scalar) const;
virtual Ten &operator+=(const Ten &other) = 0;
virtual Ten &operator*=(const Ten &other) = 0;
Tensor operator+(const Tensor &other) const;
Ten operator+(const Ten &other) const {
Ten result = static_cast<const Ten &>(*this);
result += other;
return result;
}
Tensor &operator-=(const Tensor &other);
Tensor operator-(const Tensor &other) const;
Ten &operator-=(const Ten &other) {
checkItHasSameShape(other);
*this += -other;
return static_cast<Ten &>(*this);
}
Ten operator-(const Ten &other) const {
Ten result = static_cast<const Ten &>(*this);
result -= other;
return result;
}
Ten operator*(const Ten &other) const {
Ten result = static_cast<const Ten &>(*this);
result *= other;
return result;
}
Tensor operator*(const Tensor &other) const;
// === Utils ===
virtual std::string toString() const = 0;
};
template <typename T, int Dim> class Tensor : public TensorInfo<T, Dim> {
private:
std::vector<T> data_;
public:
typedef class TensorInfo<T, Dim> TensorInfo;
using TensorInfo::axes_;
using TensorInfo::checkAxisInDim;
using TensorInfo::checkItHasSameShape;
using TensorInfo::computeIndex;
using TensorInfo::getSize;
using TensorInfo::shape_;
Tensor() = delete;
Tensor(const std::array<size_t, Dim> &shape) : TensorInfo(shape) {
size_t size = 1;
for (size_t dim : shape)
size *= dim;
data_.resize(size);
}
Tensor(const std::array<size_t, Dim> &shape, T value) : Tensor(shape) {
std::fill(data_.begin(), data_.end(), value);
}
Tensor(const std::array<size_t, Dim> &shape, const std::vector<T> &data)
: Tensor(shape) {
if (data.size() != data_.size())
throw std::invalid_argument("Invalid fill data size");
data_ = data;
}
Tensor(const std::array<size_t, Dim> &shape, T min, T max) : Tensor(shape) {
static std::random_device rd;
static std::mt19937 gen(rd());
if constexpr (std::is_integral_v<T>) {
std::uniform_int_distribution<T> dis(min, max);
for (T &e : data_)
e = dis(gen);
} else if constexpr (std::is_floating_point_v<T>) {
std::uniform_real_distribution<T> dis(min, max);
for (T &e : data_)
e = dis(gen);
} else
throw std::invalid_argument("Invalid randomized type");
}
Tensor(const Tensor &other) : TensorInfo(other), data_(other.data_) {}
Tensor &operator=(const Tensor &other) {
TensorInfo::operator=(other);
data_ = other.data_;
return *this;
}
Tensor(Tensor &&other) noexcept
: TensorInfo(std::move(other)), data_(std::move(other.data_)) {}
Tensor &operator=(Tensor &&other) noexcept {
TensorInfo::operator=(std::move(other));
data_ = std::move(other.data_);
return *this;
}
~Tensor() = default;
T &operator[](size_t i) { return data_[i]; }
const T &operator[](size_t i) const { return data_[i]; }
template <typename... Indices> T &operator()(Indices... indices) {
return data_[computeIndex(indices...)];
}
template <typename... Indices> const T &operator()(Indices... indices) const {
return data_[computeIndex(indices...)];
}
using TensorInfo::operator+;
using TensorInfo::operator-;
Tensor operator+() const override {
Tensor result = *this;
for (T &e : result.data_)
e = +e;
return result;
}
Tensor operator-() const override {
Tensor result = *this;
for (T &e : result.data_)
e = -e;
return result;
}
Tensor &operator+=(const T &scalar) override {
for (T &e : data_)
e += scalar;
return *this;
}
Tensor &operator*=(const T &scalar) override {
for (T &e : data_)
e *= scalar;
return *this;
}
Tensor &operator+=(const Tensor &other) override {
checkItHasSameShape(other);
for (size_t i = 0; i < data_.size(); ++i)
data_[i] += other.data_[i];
return *this;
}
Tensor &operator*=(const Tensor &other) override {
checkItHasSameShape(other);
for (size_t i = 0; i < data_.size(); ++i)
data_[i] *= other.data_[i];
return *this;
}
Tensor<T, Dim == 1 ? 0 : 2> operator%(const Tensor &other) const {
static_assert(Dim == 1 || Dim == 2,
"Inner product is only defined for vectors and matrices");
if constexpr (Dim == 1) {
if (data_.size() != other.data_.size())
throw std::invalid_argument(
"Vector sizes must match for inner product");
T result_val = T(0);
for (size_t i = 0; i < data_.size(); ++i)
result_val += data_[i] * other.data_[i];
return Tensor<T, 0>({}, {result_val});
} else if constexpr (Dim == 2) {
if (shape_[axes_[1]] != other.shape_[other.axes_[0]])
throw std::invalid_argument(
"Matrix dimensions must match for multiplication");
size_t m = shape_[axes_[0]];
size_t n = shape_[axes_[1]];
size_t p = other.shape_[other.axes_[1]];
Tensor<T, 2> result({m, p}, T(0));
for (size_t i = 0; i < m; ++i) {
for (size_t j = 0; j < p; ++j) {
T sum = T(0);
for (size_t k = 0; k < n; ++k)
sum += (*this)(i, k) * other(k, j);
result(i, j) = sum;
}
}
return result;
}
}
std::string toString() const override {
std::ostringstream oss;
if constexpr (Dim == 0) {
oss << "Scalar<" << typeid(T).name() << ">: " << data_[0];
} else if constexpr (Dim == 1) {
oss << "Vector<" << typeid(T).name() << ">(" << shape_[0] << "): [";
for (size_t i = 0; i < data_.size(); ++i) {
oss << data_[i];
if (i < data_.size() - 1)
oss << ", ";
}
oss << "]";
} else if constexpr (Dim == 2) {
oss << "Matrix<" << typeid(T).name() << ">(" << shape_[axes_[0]] << "x"
<< shape_[axes_[1]] << "):";
for (size_t i = 0; i < shape_[axes_[0]]; ++i) {
oss << "\n [";
for (size_t j = 0; j < shape_[axes_[1]]; ++j) {
oss << (*this)(i, j);
if (j < shape_[axes_[1]] - 1)
oss << ", ";
}
oss << "]";
}
} else {
oss << "Tensor" << Dim << "D<" << typeid(T).name() << ">" << "[";
for (size_t i = 0; i < Dim; ++i) {
oss << shape_[axes_[i]];
if (i < Dim - 1)
oss << "x";
}
oss << "]: [";
size_t show = std::min(data_.size(), size_t(10));
for (size_t i = 0; i < show; ++i) {
oss << data_[i];
if (i < show - 1)
oss << ", ";
}
if (data_.size() > 10)
oss << ", ...";
oss << "]";
}
return oss.str();
}
};
template <typename T> using Scalar = Tensor<T, 0>;
template <typename T> using Vector = Tensor<T, 1>;
template <typename T> using Matrix = Tensor<T, 2>;
class Tensors {
Tensors() = delete;
public:
template <typename T, typename... Args> static auto empty(Args... args) {
return Tensor<T, sizeof...(Args)>({static_cast<size_t>(args)...});
}
template <typename T, typename... Args> static auto zero(Args... args) {
return Tensor<T, sizeof...(Args)>({static_cast<size_t>(args)...}, T(0));
}
template <typename T, typename... Args> static auto rand(Args... args) {
return Tensor<T, sizeof...(Args)>({static_cast<size_t>(args)...}, T(0),
T(1));
}
};
#include "tensor.tpp"