Tensor math OpenCL lib

This commit is contained in:
2025-11-19 18:05:11 +04:00
parent c1874212ae
commit bd8b26c35a
12 changed files with 361 additions and 355 deletions

View File

@@ -2,10 +2,17 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "tensor.hpp"
#ifdef USE_OPENCL
#include "opencl/tensor.hpp"
OpenCL openCL;
#elif USE_CPU
#include "cpu/tensor.hpp"
#endif
namespace py = pybind11;
enum class TENSOR_PLATFORM { CPU, OPENCL };
template <typename T, int Dim>
void register_tensor(py::module &m, const std::string &name) {
auto tensor = py::class_<Tensor<T, Dim>>(m, name.c_str())
@@ -15,9 +22,9 @@ void register_tensor(py::module &m, const std::string &name) {
const std::vector<T> &>())
.def(py::init<const std::array<size_t, Dim> &, T, T>())
.def("get_shape", &TensorInfo<T, Dim>::getShape)
.def("get_axes", &TensorInfo<T, Dim>::getAxes)
.def("get_size", &TensorInfo<T, Dim>::getSize)
.def("get_shape", &Tensor<T, Dim>::getShape)
.def("get_axes", &Tensor<T, Dim>::getAxes)
.def("get_size", &Tensor<T, Dim>::getSize)
.def(py::self + py::self)
.def(py::self - py::self)
@@ -52,6 +59,7 @@ void register_tensor(py::module &m, const std::string &name) {
.def("t", &Tensor<T, Dim>::t);
}
#ifndef USE_OPENCL
if constexpr (Dim != 0)
tensor
.def(
@@ -91,21 +99,47 @@ void register_tensor(py::module &m, const std::string &name) {
t(py::cast<size_t>(indices[I])...) = value;
}(std::make_index_sequence<Dim>{});
});
#endif
if constexpr (Dim == 1 || Dim == 2)
// if constexpr (Dim == 1 || Dim == 2)
if constexpr (Dim == 2)
tensor.def("__matmul__", &Tensor<T, Dim>::operator%);
}
PYBIND11_MODULE(tensor, m) {
m.doc() = "Tensor math library";
py::enum_<TENSOR_PLATFORM>(m, "PLATFORM")
.value("CPU", TENSOR_PLATFORM::CPU)
.value("OPENCL", TENSOR_PLATFORM::OPENCL)
.export_values();
#ifdef USE_OPENCL
m.attr("MODE") = TENSOR_PLATFORM::OPENCL;
#elif USE_CPU
m.attr("MODE") = TENSOR_PLATFORM::CPU;
#endif
register_tensor<float, 0>(m, "Scalar");
register_tensor<float, 1>(m, "Vector");
register_tensor<float, 2>(m, "Matrix");
// register_tensor<float, 3>(m, "Tensor3");
//
// register_tensor<int, 0>(m, "iScalar");
// register_tensor<int, 1>(m, "iVector");
// register_tensor<int, 2>(m, "iMatrix");
// register_tensor<int, 3>(m, "iTensor3");
register_tensor<float, 3>(m, "Tensor3");
#ifdef USE_OPENCL
m.def("init", [](const std::string &programsBasePath) {
openCL.init(programsBasePath);
});
#endif
#ifndef USE_OPENCL
register_tensor<double, 0>(m, "dScalar");
register_tensor<double, 1>(m, "dVector");
register_tensor<double, 2>(m, "dMatrix");
register_tensor<double, 3>(m, "dTensor3");
register_tensor<int, 0>(m, "iScalar");
register_tensor<int, 1>(m, "iVector");
register_tensor<int, 2>(m, "iMatrix");
register_tensor<int, 3>(m, "iTensor3");
#endif
}