Forward with new tensors math

This commit is contained in:
2025-11-01 14:22:09 +04:00
parent f1dfe1b335
commit c548c3089b
8 changed files with 192 additions and 68 deletions

View File

@@ -69,12 +69,41 @@ public:
if (fill)
createBuf(getShapeSize(shape), 0.0f, queue);
}
Tensor(const Tensor &) = delete;
Tensor &operator=(const Tensor &) = delete;
Tensor(Tensor &&other) : ITensor(other.shape), buffer(other.buffer) {
Tensor(const Tensor &other, const cl::CommandQueue *queue = nullptr)
: ITensor(other) {
cl::CommandQueue q = queue == nullptr ? openCL.getDefaultQueue() : *queue;
createBuf(other.getSize(), &q);
q.enqueueCopyBuffer(*other.buffer, *buffer, 0, 0,
other.getSize() * sizeof(float));
};
Tensor &operator=(const Tensor &other) {
if (buffer != nullptr)
delete buffer;
ITensor::operator=(other);
createBuf(other.getSize(), &openCL.getDefaultQueue());
openCL.getDefaultQueue().enqueueCopyBuffer(*other.buffer, *buffer, 0, 0,
other.getSize() * sizeof(float));
return *this;
};
Tensor(Tensor &&other) : ITensor(other), buffer(other.buffer) {
other.buffer = nullptr;
};
Tensor &operator=(Tensor &&other) = delete;
Tensor &operator=(Tensor &&other) {
if (this != &other) {
if (buffer != nullptr)
delete buffer;
ITensor::operator=(std::move(other));
buffer = other.buffer;
other.buffer = nullptr;
}
return *this;
};
~Tensor() {
if (buffer != nullptr)
delete buffer;
}
std::vector<float> toVector(const cl::CommandQueue *queue = nullptr) {
size_t size = getShapeSize(shape);
@@ -144,17 +173,25 @@ public:
if (shape.size() != 0)
throw std::invalid_argument("Tensor0 dimension must be 0");
}
Tensor0(const cl::CommandQueue *queue = nullptr) : Tensor({}, queue) {
Tensor0(const cl::CommandQueue *queue = nullptr)
: Tensor(std::vector<int>{}, queue) {
createBuf(1, queue);
}
Tensor0(float value, const cl::CommandQueue *queue = nullptr)
: Tensor({}, queue) {
: Tensor(std::vector<int>{}, queue) {
createBuf(1, value, queue);
}
Tensor0(const Tensor0 &) = delete;
Tensor0 &operator=(const Tensor0 &) = delete;
Tensor0(const Tensor0 &other, const cl::CommandQueue *queue = nullptr)
: Tensor(other, queue) {};
Tensor0 &operator=(const Tensor0 &other) {
Tensor::operator=(other);
return *this;
};
Tensor0(Tensor0 &&other) : Tensor(std::move(other)) {};
Tensor0 &operator=(Tensor0 &&other) = delete;
Tensor0 &operator=(Tensor0 &&other) {
Tensor::operator=(std::move(other));
return *this;
};
};
class Tensor1 : public Tensor, public ITensor1 {
@@ -180,10 +217,17 @@ public:
: Tensor({(int)values.size()}, false, queue) {
fillBuf(values, queue);
}
Tensor1(const Tensor1 &) = delete;
Tensor1 &operator=(const Tensor1 &) = delete;
Tensor1(Tensor1 &&other) : Tensor(std::move(other)) {}
Tensor1 &operator=(Tensor1 &&other) = delete;
Tensor1(const Tensor1 &other, const cl::CommandQueue *queue = nullptr)
: Tensor(other, queue) {};
Tensor1 &operator=(const Tensor1 &other) {
Tensor::operator=(other);
return *this;
};
Tensor1(Tensor1 &&other) : Tensor(std::move(other)) {};
Tensor1 &operator=(Tensor1 &&other) {
Tensor::operator=(std::move(other));
return *this;
};
int getSize() const override { return shape[0]; }
};
@@ -223,10 +267,17 @@ public:
fillBuf(v, queue);
}
Tensor2(const Tensor2 &) = delete;
Tensor2 &operator=(const Tensor2 &) = delete;
Tensor2(Tensor2 &&other) : Tensor(std::move(other)) {}
Tensor2 &operator=(Tensor2 &&other) = delete;
Tensor2(const Tensor2 &other, const cl::CommandQueue *queue = nullptr)
: Tensor(other, queue) {};
Tensor2 &operator=(const Tensor2 &other) {
Tensor::operator=(other);
return *this;
};
Tensor2(Tensor2 &&other) : Tensor(std::move(other)) {};
Tensor2 &operator=(Tensor2 &&other) {
Tensor::operator=(std::move(other));
return *this;
};
int getRows() const override { return shape[0]; }
int getCols() const override { return shape[1]; }
@@ -269,10 +320,17 @@ public:
}
fillBuf(v, queue);
}
Tensor3(const Tensor3 &) = delete;
Tensor3 &operator=(const Tensor3 &) = delete;
Tensor3(Tensor3 &&other) : Tensor(std::move(other)) {}
Tensor3 &operator=(Tensor3 &&other) = delete;
Tensor3(const Tensor3 &other, const cl::CommandQueue *queue = nullptr)
: Tensor(other, queue) {};
Tensor3 &operator=(const Tensor3 &other) {
Tensor::operator=(other);
return *this;
};
Tensor3(Tensor3 &&other) : Tensor(std::move(other)) {};
Tensor3 &operator=(Tensor3 &&other) {
Tensor::operator=(std::move(other));
return *this;
};
};
typedef Tensor0 Scalar;