diff --git a/src/cltest.c b/src/cltest.c index 82ab4db..8747585 100644 --- a/src/cltest.c +++ b/src/cltest.c @@ -232,13 +232,11 @@ int main(int argc, const char *argv[]) { unsigned int imageCount; loadImages(&images, &imageCount); - imageCount = 600; - printf("%f\n", images[0].values[0]); srand(1); - /*unsigned int + unsigned int i = 784, h = 30, o = 10; @@ -246,47 +244,14 @@ int main(int argc, const char *argv[]) { clm_Linear layers[] = { clm_linearCreateRandom(i, h), clm_linearCreateRandom(h, o)}; - clm_NN nn = clm_nnCreate(sizeof(layers) / sizeof(clm_Linear), layers, 0.01, 10000);*/ - - float v_00[2] = {0, 0}; - float v_01[2] = {0, 1}; - float v_10[2] = {1, 0}; - float v_11[2] = {1, 1}; - - images = calloc(4, sizeof(clm_Vector)); - images[0] = (clm_Vector){.values = v_00, .length = 2}; - images[1] = (clm_Vector){.values = v_01, .length = 2}; - images[2] = (clm_Vector){.values = v_10, .length = 2}; - images[3] = (clm_Vector){.values = v_11, .length = 2}; - - labels = calloc(4, sizeof(clm_Vector)); - labels[0] = (clm_Vector){.values = v_10, .length = 2}; - labels[1] = (clm_Vector){.values = v_01, .length = 2}; - labels[2] = (clm_Vector){.values = v_01, .length = 2}; - labels[3] = (clm_Vector){.values = v_01, .length = 2}; - - imageCount = 4; - - unsigned int - i = 2, - o = 2; - - clm_Linear layers[] = { - clm_linearCreateRandom(i, o)}; - clm_NN nn = clm_nnCreate(sizeof(layers) / sizeof(clm_Linear), layers, 0.5, 4); + clm_NN nn = clm_nnCreate(sizeof(layers) / sizeof(clm_Linear), layers, 0.01, 10000); for(unsigned int i = 0; i < sizeof(layers) / sizeof(clm_Linear); i++) { clm_linearInit(&nn.layers[i]); } - for(unsigned int epoch = 0; epoch < 1000; epoch++) { + for(unsigned int epoch = 0; epoch < 10; epoch++) { printf("Epoch %u\n", epoch); - /*for(unsigned int idx = 0; idx < imageCount; idx++) { // Each train sample - if(idx % 1000 == 0) { - printf("\r%.2f%%", idx / (float) imageCount * 100); - fflush(stdout); - } - }*/ train(nn, imageCount, images, labels); printf("Score: %.2f\n", eval(nn, imageCount, images, labels));