TensorInfo

This commit is contained in:
2025-11-17 12:44:50 +04:00
parent 41f5634ce9
commit bbd9c67c96
6 changed files with 395 additions and 609 deletions

View File

@@ -15,10 +15,9 @@ void register_tensor(py::module &m, const std::string &name) {
const std::vector<T> &>())
.def(py::init<const std::array<size_t, Dim> &, T, T>())
.def("get_shape", &Tensor<T, Dim>::getShape)
.def("get_data", &Tensor<T, Dim>::getData)
.def("get_size", &Tensor<T, Dim>::getSize)
.def("get_axes", &Tensor<T, Dim>::getAxes)
.def("get_shape", &TensorInfo<T, Dim>::getShape)
.def("get_axes", &TensorInfo<T, Dim>::getAxes)
.def("get_size", &TensorInfo<T, Dim>::getSize)
.def(py::self + py::self)
.def(py::self - py::self)
@@ -44,6 +43,15 @@ void register_tensor(py::module &m, const std::string &name) {
.def("__repr__", &Tensor<T, Dim>::toString);
if constexpr (Dim >= 2) {
tensor
.def("transpose", py::overload_cast<const std::array<int, Dim> &>(
&Tensor<T, Dim>::transpose))
.def("transpose",
py::overload_cast<int, int>(&Tensor<T, Dim>::transpose))
.def("t", &Tensor<T, Dim>::t);
}
if constexpr (Dim != 0)
tensor
.def(
@@ -86,15 +94,6 @@ void register_tensor(py::module &m, const std::string &name) {
if constexpr (Dim == 1 || Dim == 2)
tensor.def("__matmul__", &Tensor<T, Dim>::operator%);
if constexpr (Dim >= 2) {
tensor
.def("transpose", py::overload_cast<const std::array<int, Dim> &>(
&Tensor<T, Dim>::transpose))
.def("transpose",
py::overload_cast<int, int>(&Tensor<T, Dim>::transpose))
.def("t", &Tensor<T, Dim>::t);
}
}
PYBIND11_MODULE(tensor, m) {
@@ -103,10 +102,10 @@ PYBIND11_MODULE(tensor, m) {
register_tensor<float, 0>(m, "Scalar");
register_tensor<float, 1>(m, "Vector");
register_tensor<float, 2>(m, "Matrix");
register_tensor<float, 3>(m, "Tensor3");
register_tensor<int, 0>(m, "iScalar");
register_tensor<int, 1>(m, "iVector");
register_tensor<int, 2>(m, "iMatrix");
register_tensor<int, 3>(m, "iTensor3");
// register_tensor<float, 3>(m, "Tensor3");
//
// register_tensor<int, 0>(m, "iScalar");
// register_tensor<int, 1>(m, "iVector");
// register_tensor<int, 2>(m, "iMatrix");
// register_tensor<int, 3>(m, "iTensor3");
}