Tensor math OpenCL lib

This commit is contained in:
2025-11-19 18:05:11 +04:00
parent c1874212ae
commit bd8b26c35a
12 changed files with 361 additions and 355 deletions

View File

@@ -38,8 +38,8 @@ public:
using ITensor::operator+;
using ITensor::operator-;
Tensor operator+() override;
Tensor operator-() override;
Tensor operator+() const override;
Tensor operator-() const override;
Tensor &operator+=(const T scalar) override;

View File

@@ -2,6 +2,7 @@
#include "tensor.hpp"
#include <iostream>
#include <random>
#include <sstream>
@@ -79,13 +80,15 @@ const T &Tensor<T, Dim>::operator()(Indices... indices) const {
}
// ===== OPERATORS =====
template <typename T, int Dim> Tensor<T, Dim> Tensor<T, Dim>::operator+() {
template <typename T, int Dim>
Tensor<T, Dim> Tensor<T, Dim>::operator+() const {
Tensor result = *this;
for (T &e : result.data_)
e = +e;
return result;
}
template <typename T, int Dim> Tensor<T, Dim> Tensor<T, Dim>::operator-() {
template <typename T, int Dim>
Tensor<T, Dim> Tensor<T, Dim>::operator-() const {
Tensor result = *this;
for (T &e : result.data_)
e = -e;
@@ -156,46 +159,5 @@ Tensor<T, Dim>::operator%(const Tensor &other) const {
// ===== UTILS =====
template <typename T, int Dim> std::string Tensor<T, Dim>::toString() const {
std::ostringstream oss;
if constexpr (Dim == 0) {
oss << "Scalar<" << typeid(T).name() << ">: " << data_[0];
} else if constexpr (Dim == 1) {
oss << "Vector<" << typeid(T).name() << ">(" << shape_[0] << "): [";
for (size_t i = 0; i < getSize(); ++i) {
oss << data_[i];
if (i < getSize() - 1)
oss << ", ";
}
oss << "]";
} else if constexpr (Dim == 2) {
oss << "Matrix<" << typeid(T).name() << ">(" << shape_[axes_[0]] << "x"
<< shape_[axes_[1]] << "):";
for (size_t i = 0; i < shape_[axes_[0]]; ++i) {
oss << "\n [";
for (size_t j = 0; j < shape_[axes_[1]]; ++j) {
oss << (*this)(i, j);
if (j < shape_[axes_[1]] - 1)
oss << ", ";
}
oss << "]";
}
} else {
oss << "Tensor" << Dim << "D<" << typeid(T).name() << ">" << "[";
for (size_t i = 0; i < Dim; ++i) {
oss << shape_[axes_[i]];
if (i < Dim - 1)
oss << "x";
}
oss << "]: [";
size_t show = std::min(getSize(), size_t(10));
for (size_t i = 0; i < show; ++i) {
oss << data_[i];
if (i < show - 1)
oss << ", ";
}
if (getSize() > 10)
oss << ", ...";
oss << "]";
}
return oss.str();
return ITensor::format(data_);
}