From 3f457b80402dc9f8523d98c01a4c2715d1f7c0b1 Mon Sep 17 00:00:00 2001 From: miha-q <> Date: Mon, 4 Mar 2024 12:09:33 -0500 Subject: [PATCH] Mon Mar 4 12:09:33 PM EST 2024 --- src/QAnsel.c | 12 ++---------- src/complex.c | 4 ++-- src/kernel.cl | 1 - 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/src/QAnsel.c b/src/QAnsel.c index 9791e5e..33be213 100644 --- a/src/QAnsel.c +++ b/src/QAnsel.c @@ -237,19 +237,11 @@ void qansel_instruction(cpx_mtx_t* stateVector, unsigned char qubitCount, QInstr #ifdef GPU_ENABLED if (USE_GPU && (filter.cols >= 512 || stateVector->cols >= 512)) { - cpx_mtx_dot_metal - ( - tmp.ptr, stateVector->ptr, filter.ptr, - stateVector->rows, filter.cols, stateVector->cols - ); + cpx_mtx_dot_metal(tmp.ptr, stateVector->ptr, filter.ptr, stateVector->rows, stateVector->cols, filter.rows, filter.cols); } else { - cpx_mtx_dot - ( - tmp.ptr, stateVector->ptr, filter.ptr, - stateVector->rows, filter.cols, stateVector->cols - ); + cpx_mtx_dot(tmp.ptr, stateVector->ptr, filter.ptr, stateVector->rows, stateVector->cols, filter.rows, filter.cols); } #else //cpx_ncpx_mmul_mt diff --git a/src/complex.c b/src/complex.c index 3a111e7..f616e8d 100644 --- a/src/complex.c +++ b/src/complex.c @@ -135,7 +135,7 @@ int get_global_id(int id) } #include "kernel.cl" -void cpx_mtx_dot(float* ptrR, float* ptrA, float* ptrB, size_t rowsA, size_t colsB, size_t shared) +void cpx_mtx_dot(float* ptrR, float* ptrA, float* ptrB, int rowsA, int colsA, int rowsB, int colsB) { for (int i = 0; i < rowsA; i++) { @@ -143,7 +143,7 @@ void cpx_mtx_dot(float* ptrR, float* ptrA, float* ptrB, size_t rowsA, size_t col { GPU_GLOBAL_ID_0 = i; GPU_GLOBAL_ID_1 = j; - kernel_dot(ptrR, ptrA, ptrB, rowsA, colsB, shared); + kernel_dot(ptrR, ptrA, ptrB, rowsA, colsA, rowsB, colsB); } } } diff --git a/src/kernel.cl b/src/kernel.cl index 542ec74..f72d87e 100644 --- a/src/kernel.cl +++ b/src/kernel.cl @@ -54,7 +54,6 @@ __kernel void kernel_knk const int colsB ) { - const int rowsR = rowsA * rowsB; const int colsR = colsA * colsB; int rowR = get_global_id(0); -- 2.39.5