tmp.ptr = malloc(tmp.rows * (tmp.cols * 2) * sizeof(float));
#ifdef GPU_ENABLED
- if (USE_GPU && (tmp.rows >= 1024 || tmp.cols >= 1024))
+ if (USE_GPU && (tmp.rows >= 1024 && tmp.cols >= 1024))
{
cpx_mtx_knk_metal(tmp.ptr, filter.ptr, gate.ptr, filter.rows, filter.cols, gate.rows, gate.cols);
}
cpx_mtx_dot(tmp.ptr, stateVector->ptr, filter.ptr, stateVector->rows, stateVector->cols, filter.rows, filter.cols);
}
#else
- //cpx_ncpx_mmul_mt
- //(
- // tmp.ptr, stateVector->ptr, filter.ptr,
- // stateVector->rows * 2, filter.cols * 2, stateVector->cols * 2
- //);
cpx_mtx_dot
(
tmp.ptr, stateVector->ptr, filter.ptr,