#ifdef GPU_ENABLED
if (USE_GPU && (filter.cols >= 512 || stateVector->cols >= 512))
{
- cpx_mtx_dot_metal
- (
- tmp.ptr, stateVector->ptr, filter.ptr,
- stateVector->rows, filter.cols, stateVector->cols
- );
+ cpx_mtx_dot_metal(tmp.ptr, stateVector->ptr, filter.ptr, stateVector->rows, stateVector->cols, filter.rows, filter.cols);
}
else
{
- cpx_mtx_dot
- (
- tmp.ptr, stateVector->ptr, filter.ptr,
- stateVector->rows, filter.cols, stateVector->cols
- );
+ cpx_mtx_dot(tmp.ptr, stateVector->ptr, filter.ptr, stateVector->rows, stateVector->cols, filter.rows, filter.cols);
}
#else
//cpx_ncpx_mmul_mt
}
#include "kernel.cl"
// cpx_mtx_dot: host-side driver that walks the output indices and calls
// kernel_dot (presumably defined in the included kernel.cl) once per (i, j) pair.
//
// This excerpt is in diff form: lines starting with '-' are the old version,
// lines starting with '+' are the new one. The change replaces the
// (rowsA, colsB, shared) size_t parameters with explicit int dimensions for
// both operands: (rowsA, colsA, rowsB, colsB).
//
//   ptrR  - result buffer, passed straight through to kernel_dot
//   ptrA  - first operand buffer, passed straight through to kernel_dot
//   ptrB  - second operand buffer, passed straight through to kernel_dot
//   rowsA - bound of the outer loop over i
//   colsA, rowsB, colsB - forwarded to kernel_dot unchanged
//
// NOTE(review): the dimension type narrows from size_t to int; confirm that
// no caller passes a dimension larger than INT_MAX.
-void cpx_mtx_dot(float* ptrR, float* ptrA, float* ptrB, size_t rowsA, size_t colsB, size_t shared)
+void cpx_mtx_dot(float* ptrR, float* ptrA, float* ptrB, int rowsA, int colsA, int rowsB, int colsB)
{
for (int i = 0; i < rowsA; i++)
{
// NOTE(review): the inner loop header that declares j is not visible in this
// excerpt; presumably it iterates over the output columns - TODO confirm.
{
// Publish the current (i, j) pair through the GPU_GLOBAL_ID_* names before
// calling the kernel; presumably kernel_dot reads them the way a GPU kernel
// reads its global work-item ids - verify against kernel.cl.
GPU_GLOBAL_ID_0 = i;
GPU_GLOBAL_ID_1 = j;
- kernel_dot(ptrR, ptrA, ptrB, rowsA, colsB, shared);
+ kernel_dot(ptrR, ptrA, ptrB, rowsA, colsA, rowsB, colsB);
}
}
}