WIP commit

This commit is contained in:
MrLetsplay 2024-02-23 16:17:05 +01:00
parent 9547d86251
commit e7ec45ba0a
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg

View File

@ -157,6 +157,7 @@ static cl_mem writeNativeMatrixArray(clm_NativeMatrixArray array) {
clm_Matrix mat = array.matrixes[0];
cl_mem mem = array.native->mem;
// TODO: don't do blocking writes, instead wait once at the end
size_t matLength = sizeof(float) * mat.rows * mat.cols;
for(unsigned int i = 0; i < array.length; i++) {
cl_int err = clEnqueueWriteBuffer(queue, mem, CL_TRUE, i * matLength, matLength, array.matrixes[i].values, 0, NULL, NULL);
@ -192,7 +193,15 @@ static void readNativeMatrix(clm_NativeMatrix matrix) {
static void readNativeMatrixArray(clm_NativeMatrixArray array) {
clm_Matrix mat = array.matrixes[0];
size_t matLength = sizeof(float) * mat.rows * mat.cols;
// TODO: don't do blocking reads, instead wait once at the end
for(unsigned int i = 0; i < array.length; i++) {
cl_int err = clEnqueueReadBuffer(queue, array.native->mem, CL_TRUE, i * matLength, matLength, array.matrixes[i].values, 0, NULL, NULL);
if(err != CL_SUCCESS) {
printf("Failed to enqueue read: %s\n", clm_clErrorToString(err));
return;
}
}
}
@ -366,8 +375,12 @@ void clm_linearBackprop(clm_Linear *linear, float learnRate, unsigned int batchS
clFlush(queue);
clFinish(queue);
clm_matrixPrint(linear->weightsError.matrixes[0]);
readNativeMatrixArray(linear->weightsError);
readNativeMatrixArray(linear->gradient);
clm_matrixPrint(linear->weightsError.matrixes[0]);
if(updateErrors) readGPUMats(matOutputErrors, batchSize, outputErrors, linear->nativeOutputErrors);
}