mirror of
https://github.com/StepanovPlaton/NeuralNetwork.git
synced 2026-04-03 20:30:39 +04:00
Refactor
This commit is contained in:
@@ -3,6 +3,11 @@
|
||||
|
||||
#include <CL/cl.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "opencl.hpp"
|
||||
|
||||
class CalcEngine {
|
||||
|
||||
49
src/main.cpp
49
src/main.cpp
@@ -1,24 +1,21 @@
|
||||
#include <CL/cl.h>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "device.hpp"
|
||||
#include "matrix.hpp"
|
||||
|
||||
class MatrixCalculator {
|
||||
class MutableMatrix : public Matrix {
|
||||
private:
|
||||
CalcEngine *calcEngine;
|
||||
cl_command_queue queue;
|
||||
cl_kernel kernel;
|
||||
|
||||
public:
|
||||
MatrixCalculator(CalcEngine &calcEngine) {
|
||||
MutableMatrix(CalcEngine &calcEngine, size_t rows, size_t cols, float *matrix)
|
||||
: Matrix(calcEngine, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, rows, cols,
|
||||
matrix) {
|
||||
this->calcEngine = &calcEngine;
|
||||
kernel = calcEngine.loadKernel("matrix_mult.cl");
|
||||
queue = clCreateCommandQueue(calcEngine.getContext(),
|
||||
@@ -28,29 +25,31 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
~MatrixCalculator() {
|
||||
~MutableMatrix() {
|
||||
if (queue)
|
||||
clReleaseCommandQueue(queue);
|
||||
}
|
||||
|
||||
std::vector<float> multiply(Matrix &a, Matrix &b, int M, int N, int K) {
|
||||
if (a.getRows() != M || a.getCols() != K || b.getRows() != K ||
|
||||
b.getCols() != N) {
|
||||
void mult_by(Matrix &m) {
|
||||
if (cols != m.getRows()) {
|
||||
throw std::invalid_argument("Invalid matrix dimensions");
|
||||
}
|
||||
|
||||
cl_mem bufC = calcEngine->createBuffer(CL_MEM_WRITE_ONLY,
|
||||
M * N * sizeof(float), nullptr);
|
||||
cl_mem b =
|
||||
calcEngine->createBuffer(CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
|
||||
rows * m.getCols() * sizeof(float), nullptr);
|
||||
|
||||
calcEngine->setKernelArgs(kernel, a.getBuf(), b.getBuf(), bufC, M, N, K);
|
||||
calcEngine->setKernelArgs(kernel, buf, m.getBuf(), b, rows, m.getCols(),
|
||||
cols);
|
||||
calcEngine->runKernel(queue, kernel, rows, m.getCols());
|
||||
|
||||
calcEngine->runKernel(queue, kernel, M, N);
|
||||
|
||||
std::vector<float> C(M * N);
|
||||
calcEngine->readResult(queue, bufC, C);
|
||||
|
||||
clReleaseMemObject(bufC);
|
||||
clReleaseMemObject(buf);
|
||||
buf = b;
|
||||
}
|
||||
|
||||
std::vector<float> exportMatrix() {
|
||||
std::vector<float> C(rows, cols);
|
||||
calcEngine->readResult(queue, buf, C);
|
||||
return C;
|
||||
}
|
||||
};
|
||||
@@ -59,15 +58,15 @@ int main() {
|
||||
CalcEngine calcEngine;
|
||||
calcEngine.printDeviceInfo();
|
||||
|
||||
MatrixCalculator matrixCalculator(calcEngine);
|
||||
|
||||
float matrixA[2 * 3] = {1, 2, 3, 4, 5, 6};
|
||||
Matrix a(calcEngine, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, 2, 3, matrixA);
|
||||
MutableMatrix a(calcEngine, 2, 3, matrixA);
|
||||
|
||||
float matrixB[3 * 2] = {1, 2, 3, 4, 5, 6};
|
||||
Matrix b(calcEngine, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, 3, 2, matrixB);
|
||||
|
||||
std::vector<float> v = matrixCalculator.multiply(a, b, 2, 2, 3);
|
||||
a.mult_by(b);
|
||||
|
||||
std::vector<float> v = a.exportMatrix();
|
||||
for (const auto &element : v) {
|
||||
std::cout << element << " ";
|
||||
}
|
||||
|
||||
BIN
src/main.exe
BIN
src/main.exe
Binary file not shown.
@@ -1,14 +1,12 @@
|
||||
#ifndef MATRIX_H
|
||||
#define MATRIX_H
|
||||
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include "opencl.hpp"
|
||||
#include "device.hpp"
|
||||
|
||||
class Matrix {
|
||||
private:
|
||||
protected:
|
||||
cl_mem buf;
|
||||
size_t rows;
|
||||
size_t cols;
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#define OPENCL_H
|
||||
|
||||
#include <CL/cl.h>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
|
||||
class OpenCLException : public std::runtime_error {
|
||||
private:
|
||||
|
||||
Reference in New Issue
Block a user