/*
 * nn-testing/src/mat.cl — OpenCL C device code for matrix operations.
 */

/*
 * Shape descriptor for a matrix whose elements live in a separate buffer
 * (mat_multiply receives each descriptor alongside a __global float * that
 * it indexes in row-major order).
 *
 * packed: no padding, so with OpenCL's 32-bit uint and 8-bit char the struct
 * is 4 + 4 + 1 = 9 bytes. It is passed by value as a kernel argument, so the
 * host-side definition must match this layout byte for byte.
 * NOTE(review): presumably the host mirrors it with the same field order,
 * types and packing — confirm.
 */
typedef struct __attribute__((packed)) {
uint rows;        // number of rows
uint cols;        // number of columns (also the row stride of the element buffer)
char transposed;  // NOTE(review): not read by mat_multiply; presumably flags transposed storage — confirm against host code
} cl_GPUMat;
/*
 * Computes matOut = matA * matB for row-major float matrices.
 *
 * Launch with a 1-D NDRange of at least matOut.rows * matOut.cols work-items.
 * Each work-item computes exactly one output element; work-items whose global
 * id lies past the last element (e.g. when the global size was rounded up to
 * a multiple of the work-group size) return without doing anything.
 *
 * matA, matB, matOut        shape descriptors (rows/cols) of the three matrices.
 * matA_values, matB_values  row-major element buffers of the inputs (read only).
 * matOut_values             row-major element buffer of the result (written).
 *
 * Shapes must satisfy matA.cols == matB.rows, matOut.rows == matA.rows and
 * matOut.cols == matB.cols. A kernel cannot report an error, so on a mismatch
 * every work-item returns without writing and matOut_values is left
 * untouched; the host is expected to validate shapes before enqueueing.
 *
 * The `transposed` field of the descriptors is not consulted: every buffer is
 * indexed as plain row-major storage of its rows x cols shape.
 */
__kernel void mat_multiply(cl_GPUMat matA, __global const float *matA_values,
                           cl_GPUMat matB, __global const float *matB_values,
                           cl_GPUMat matOut, __global float *matOut_values) {
    /* Incompatible shapes would make the loops below read past the end of
     * matA_values / matB_values, so refuse to compute anything. */
    if (matA.cols != matB.rows || matOut.rows != matA.rows || matOut.cols != matB.cols) {
        return;
    }

    /* size_t: get_global_id returns size_t, and rows * cols (as well as the
     * flat offsets below) can exceed the 32-bit range of uint. */
    size_t idx = get_global_id(0);
    size_t total = (size_t)matOut.rows * matOut.cols;
    if (idx >= total) {
        return; /* also covers empty outputs, before any division by cols */
    }
    size_t i = idx / matOut.cols; /* output row */
    size_t j = idx % matOut.cols; /* output column */

    /* Dot product of row i of matA with column j of matB. */
    float sum = 0.0f;
    for (uint k = 0; k < matA.cols; k++) {
        sum += matA_values[i * matA.cols + k] * matB_values[(size_t)k * matB.cols + j];
    }
    matOut_values[i * matOut.cols + j] = sum;
}