us1 = get_time();
cpx_mtx_knk_threads(tmp.ptr, filter.ptr, gate.ptr, filter.rows, filter.cols, gate.rows, gate.cols);
us2 = get_time();
+ printf("\tThreads2x2: %lu\n", us2 - us1);
+ us1 = get_time();
+ cpx_mtx_knk_threads_2x2(tmp.ptr, filter.ptr, gate.ptr, filter.rows, filter.cols, gate.rows, gate.cols);
+ us2 = get_time();
printf("\tThreads: %lu\n", us2 - us1);
us1 = get_time();
cpx_mtx_knk(tmp.ptr, filter.ptr, gate.ptr, filter.rows, filter.cols, gate.rows, gate.cols);
}
}
+/*
+ * pthread start routine used by cpx_mtx_knk_threads_2x2.
+ *
+ * context must point to a cpx_thread_context. For every index in
+ * [delimeterStart, delimeterStart + delimeterCount) the worker calls
+ * kernel_knk_2x2 once, forwarding ptrR, ptrA, rowsA, colsA, the first
+ * eight floats of ptrB, and that index.
+ *
+ * Always returns NULL; the thread produces no result value of its own.
+ */
+void* cpx_mtx_knk_threads_2x2_run(void *context)
+{
+    cpx_thread_context* ctx = (cpx_thread_context*)context;
+    /* NOTE(review): only ptrB[0..7] are forwarded, so B is presumably a 2x2 complex matrix (4 re/im pairs) - confirm against kernel_knk_2x2. */
+    for (int i = 0; i < (ctx->delimeterCount); i++)
+    {
+        kernel_knk_2x2(ctx->ptrR, ctx->ptrA, ctx->rowsA, ctx->colsA, ctx->ptrB[0], ctx->ptrB[1], ctx->ptrB[2], ctx->ptrB[3], ctx->ptrB[4], ctx->ptrB[5], ctx->ptrB[6], ctx->ptrB[7], i + (ctx->delimeterStart));
+    }
+    /* A void* start routine must return a value; the original fell off the end of the function. */
+    return NULL;
+}
+
+/*
+ * Runs kernel_knk_2x2 over (rowsA * rowsB) / 2 work indices, split across
+ * up to get_core_count() pthreads (never more threads than indices).
+ *
+ * ptrR, ptrA, ptrB and the four dimensions are copied unchanged into each
+ * worker's cpx_thread_context; each worker (cpx_mtx_knk_threads_2x2_run)
+ * receives a contiguous, non-overlapping index range. The function returns
+ * once every worker has been joined.
+ *
+ * If there are no work indices the function returns without starting any
+ * thread. A pthread_create failure prints a message and exits the process
+ * with status 1; a pthread_join failure is only reported on stderr.
+ */
+void cpx_mtx_knk_threads_2x2(float* ptrR, float* ptrA, float* ptrB, int rowsA, int colsA, int rowsB, int colsB)
+{
+    int delimeter = (rowsA * rowsB) / 2;
+    /* Nothing to distribute. Bailing out here also keeps threadCount from
+       becoming 0, which would divide by zero and create zero-length VLAs. */
+    if (delimeter <= 0) return;
+
+    int threadCount = get_core_count();
+    /* Guard against a core-count query that reports 0 or a negative value. */
+    if (threadCount < 1) threadCount = 1;
+    if (threadCount > delimeter) threadCount = delimeter;
+    int delimetersPerThread = delimeter / threadCount;
+    int leftOvers = delimeter % threadCount;
+
+    cpx_thread_context ctxs[threadCount];
+    pthread_t threads[threadCount];
+    int nextStart = 0;
+    for (int i = 0; i < threadCount; i++)
+    {
+        /* The first leftOvers threads take one extra index each, so no
+           single thread ends up carrying the whole remainder. */
+        int count = delimetersPerThread + ((i < leftOvers) ? 1 : 0);
+
+        ctxs[i].ptrR = ptrR;
+        ctxs[i].ptrA = ptrA;
+        ctxs[i].ptrB = ptrB;
+        ctxs[i].rowsA = rowsA;
+        ctxs[i].colsA = colsA;
+        ctxs[i].rowsB = rowsB;
+        ctxs[i].colsB = colsB;
+        ctxs[i].delimeterStart = nextStart;
+        ctxs[i].delimeterCount = count;
+        nextStart += count;
+
+        if (pthread_create(&(threads[i]), NULL, &cpx_mtx_knk_threads_2x2_run, (void*)&(ctxs[i])))
+        {
+            fprintf(stderr, "QAnsel: Thread error. (1)\n");
+            exit(1);
+        }
+    }
+    /* Index type matches threadCount to avoid a signed/unsigned comparison. */
+    for (int i = 0; i < threadCount; i++)
+    {
+        if (pthread_join(threads[i], NULL))
+        {
+            fprintf(stderr, "QAnsel: Thread error. (2)\n");
+        }
+    }
+}
+
void* cpx_mtx_dot_threads_run(void *context)
{
cpx_thread_context* ctx = (cpx_thread_context*)context;