mirror of
https://github.com/StepanovPlaton/NeuralNetwork.git
synced 2026-04-04 04:40:40 +04:00
First tensor python module
This commit is contained in:
@@ -20,24 +20,6 @@ void register_tensor(py::module &m, const std::string &name) {
|
||||
.def("get_size", &Tensor<T, Dim>::getSize)
|
||||
.def("get_axes", &Tensor<T, Dim>::getAxes)
|
||||
|
||||
.def("__getitem__",
|
||||
[](const Tensor<T, Dim> &t, size_t i) -> T {
|
||||
if (i >= t.getSize())
|
||||
throw py::index_error();
|
||||
return t[i];
|
||||
})
|
||||
.def("__setitem__",
|
||||
[](Tensor<T, Dim> &t, size_t i, T value) {
|
||||
if (i >= t.getSize())
|
||||
throw py::index_error();
|
||||
t[i] = value;
|
||||
})
|
||||
|
||||
// .def("__call__",
|
||||
// [](Tensor<T, Dim> &t, py::args args) -> T & {
|
||||
//
|
||||
// })
|
||||
|
||||
.def(py::self + py::self)
|
||||
.def(py::self - py::self)
|
||||
.def(py::self * py::self)
|
||||
@@ -49,19 +31,58 @@ void register_tensor(py::module &m, const std::string &name) {
|
||||
.def(py::self - T())
|
||||
.def(py::self * T())
|
||||
.def(py::self / T())
|
||||
.def(T() + py::self)
|
||||
.def(T() - py::self)
|
||||
.def(T() * py::self)
|
||||
|
||||
.def(py::self += T())
|
||||
.def(py::self -= T())
|
||||
.def(py::self *= T())
|
||||
.def(py::self /= T())
|
||||
.def(T() + py::self)
|
||||
.def(T() - py::self)
|
||||
.def(T() * py::self)
|
||||
|
||||
.def("__pos__", [](const Tensor<T, Dim> &t) { return +t; })
|
||||
.def("__neg__", [](const Tensor<T, Dim> &t) { return -t; })
|
||||
|
||||
.def("print", &Tensor<T, Dim>::print);
|
||||
.def("__repr__", &Tensor<T, Dim>::toString);
|
||||
|
||||
if constexpr (Dim != 0)
|
||||
tensor
|
||||
.def(
|
||||
"__getitem__",
|
||||
[](Tensor<T, Dim> &t, size_t index) -> T & {
|
||||
if (index >= t.getSize())
|
||||
throw py::value_error("Index out of range");
|
||||
return t[index];
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
.def(
|
||||
"__getitem__",
|
||||
[](Tensor<T, Dim> &t, const py::tuple &indices) -> T & {
|
||||
if (indices.size() != Dim)
|
||||
throw py::value_error("Expected " + std::to_string(Dim) +
|
||||
" indices, got " +
|
||||
std::to_string(indices.size()));
|
||||
return [&]<size_t... I>(std::index_sequence<I...>) -> T & {
|
||||
return t(py::cast<size_t>(indices[I])...);
|
||||
}(std::make_index_sequence<Dim>{});
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
|
||||
.def("__setitem__",
|
||||
[](Tensor<T, Dim> &t, size_t index, const T &value) {
|
||||
if (index >= t.getSize())
|
||||
throw py::value_error("Index out of range");
|
||||
t[index] = value;
|
||||
})
|
||||
.def("__setitem__",
|
||||
[](Tensor<T, Dim> &t, const py::tuple &indices, const T &value) {
|
||||
if (indices.size() != Dim)
|
||||
throw py::value_error("Expected " + std::to_string(Dim) +
|
||||
" indices, got " +
|
||||
std::to_string(indices.size()));
|
||||
[&]<size_t... I>(std::index_sequence<I...>) {
|
||||
t(py::cast<size_t>(indices[I])...) = value;
|
||||
}(std::make_index_sequence<Dim>{});
|
||||
});
|
||||
|
||||
if constexpr (Dim == 1 || Dim == 2)
|
||||
tensor.def("__matmul__", &Tensor<T, Dim>::operator%);
|
||||
@@ -83,20 +104,9 @@ PYBIND11_MODULE(tensor, m) {
|
||||
register_tensor<float, 1>(m, "Vector");
|
||||
register_tensor<float, 2>(m, "Matrix");
|
||||
register_tensor<float, 3>(m, "Tensor3");
|
||||
register_tensor<float, 4>(m, "Tensor4");
|
||||
register_tensor<float, 5>(m, "Tensor5");
|
||||
|
||||
register_tensor<double, 0>(m, "dScalar");
|
||||
register_tensor<double, 1>(m, "dVector");
|
||||
register_tensor<double, 2>(m, "dMatrix");
|
||||
register_tensor<double, 3>(m, "dTensor3");
|
||||
register_tensor<double, 4>(m, "dTensor4");
|
||||
register_tensor<double, 5>(m, "dTensor5");
|
||||
|
||||
register_tensor<int, 0>(m, "iScalar");
|
||||
register_tensor<int, 1>(m, "iVector");
|
||||
register_tensor<int, 2>(m, "iMatrix");
|
||||
register_tensor<int, 3>(m, "iTensor3");
|
||||
register_tensor<int, 4>(m, "iTensor4");
|
||||
register_tensor<int, 5>(m, "iTensor5");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user