19 #include "hicma_common.h" 20 #include "runtime/starpu/chameleon_starpu.h" 25 #include "runtime/starpu/runtime_codelets.h" 26 ZCODELETS_HEADER(gemm_hcore)
38 extern void _printmat(
double *
A, int64_t m, int64_t n, int64_t ld);
46 MORSE_enum transA,
int transB,
49 const MORSE_desc_t *
AUV,
50 const MORSE_desc_t *Ark,
51 int Am,
int An,
int lda,
52 const MORSE_desc_t *
BUV,
53 const MORSE_desc_t *Brk,
54 int Bm,
int Bn,
int ldb,
56 const MORSE_desc_t *
CUV,
57 const MORSE_desc_t *Crk,
58 int Cm,
int Cn,
int ldc,
67 struct starpu_codelet *codelet = &cl_zgemm_hcore;
69 void (*callback)(
void*) = NULL;
70 MORSE_starpu_ws_t *h_work = (MORSE_starpu_ws_t*)(options->ws_host);
76 int execution_rank =
CUV->get_rankof(
CUV, Cm, Cn );
82 char* env = getenv(
"MORSE_COMM_FACTOR_THRESHOLD");
84 int ifval = 0, elseifval = 0, initialval = execution_rank;
86 threshold = (unsigned)atoi(env);
89 if ( sizeA > threshold*sizeC ){
90 execution_rank =
AUV->get_rankof(
AUV, Am, An );
91 ifval = execution_rank;
93 }
else if( sizeB > threshold*sizeC ){
94 execution_rank =
BUV->get_rankof(
BUV, Bm, Bn );
95 elseifval = execution_rank;
100 MORSE_BEGIN_ACCESS_DECLARATION;
101 MORSE_ACCESS_R(
AUV, Am, An);
102 MORSE_ACCESS_R(
BUV, Bm, Bn);
103 MORSE_ACCESS_RW(
CUV, Cm, Cn);
104 #if !defined(HICMA_ALWAYS_FIX_RANK) 105 MORSE_ACCESS_R(Ark, Am, An);
106 MORSE_ACCESS_R(Brk, Bm, Bn);
107 MORSE_ACCESS_RW(Crk, Cm, Cn);
110 MORSE_RANK_CHANGED(execution_rank);
111 MORSE_END_ACCESS_DECLARATION;
115 starpu_mpi_codelet(codelet),
116 STARPU_VALUE, &transA,
sizeof(MORSE_enum),
117 STARPU_VALUE, &transB,
sizeof(MORSE_enum),
118 STARPU_VALUE, &m,
sizeof(
int),
119 STARPU_VALUE, &n,
sizeof(
int),
120 STARPU_VALUE, &alpha,
sizeof(
double),
121 STARPU_R, RTBLKADDR(
AUV,
double, Am, An),
122 STARPU_VALUE, &lda,
sizeof(
int),
123 STARPU_R, RTBLKADDR(
BUV,
double, Bm, Bn),
124 STARPU_VALUE, &ldb,
sizeof(
int),
125 STARPU_VALUE, &beta,
sizeof(
double),
126 STARPU_RW, RTBLKADDR(
CUV,
double, Cm, Cn),
127 #
if !defined(HICMA_ALWAYS_FIX_RANK)
128 STARPU_R, RTBLKADDR(Ark,
double, Am, An),
129 STARPU_R, RTBLKADDR(Brk,
double, Bm, Bn),
130 STARPU_RW, RTBLKADDR(Crk,
double, Cm, Cn),
132 STARPU_VALUE, &ldc,
sizeof(
int),
133 STARPU_VALUE, &rk,
sizeof(
int),
134 STARPU_VALUE, &maxrk,
sizeof(
int),
135 STARPU_VALUE, &acc,
sizeof(
double),
136 STARPU_VALUE, &Am,
sizeof(
int),
137 STARPU_VALUE, &An,
sizeof(
int),
138 STARPU_VALUE, &Bm,
sizeof(
int),
139 STARPU_VALUE, &Bn,
sizeof(
int),
140 STARPU_VALUE, &Cm,
sizeof(
int),
141 STARPU_VALUE, &Cn,
sizeof(
int),
142 STARPU_VALUE, &nAUV,
sizeof(
int),
143 STARPU_VALUE, &nBUV,
sizeof(
int),
144 STARPU_VALUE, &nCUV,
sizeof(
int),
145 STARPU_SCRATCH, options->ws_worker,
146 STARPU_VALUE, &h_work,
sizeof(MORSE_starpu_ws_t *),
147 STARPU_PRIORITY, options->priority,
148 STARPU_CALLBACK, callback,
149 #
if defined(CHAMELEON_USE_MPI)
150 STARPU_EXECUTE_ON_NODE, execution_rank,
152 #
if defined(CHAMELEON_CODELETS_HAVE_NAME)
153 STARPU_NAME,
"hcore_zgemm",
158 #if !defined(CHAMELEON_SIMULATION) 159 static void cl_zgemm_hcore_cpu_func(
void *descr[],
void *cl_arg)
161 #ifdef HICMA_DISABLE_ALL_COMPUTATIONS 164 #ifdef HICMA_DISABLE_HCORE_COMPUTATIONS 167 struct timeval tvalBefore, tvalAfter;
168 gettimeofday (&tvalBefore, NULL);
196 AUV = (
double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
197 BUV = (
double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
198 CUV = (
double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
199 #if !defined(HICMA_ALWAYS_FIX_RANK) 200 Ark = (
double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
201 Brk = (
double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
202 Crk = (
double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
211 work = (
double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
213 int Am, An, Bm, Bn, Cm, Cn;
215 MORSE_starpu_ws_t *h_work;
216 starpu_codelet_unpack_args(cl_arg, &transA, &transB, &m, &n, &alpha, &lda, &ldb, &beta, &ldc, &rk, &maxrk, &acc, &Am, &An, &Bm, &Bn, &Cm, &Cn, &nAUV, &nBUV, &nCUV, &h_work);
223 size_t nelm_AU = (size_t)lda * (
size_t)nAU;
224 double *AV = &(
AUV[nelm_AU]);
227 size_t nelm_BU = (size_t)ldb * (
size_t)nBU;
228 double *BV = &(
BUV[nelm_BU]);
231 size_t nelm_CU = (size_t)ldc * (
size_t)nCU;
232 double *CV = &(
CUV[nelm_CU]);
234 double old_Crk = Crk[0];
235 char datebuf_start[128];
236 datebuf_start[0] =
'\0';
240 gettimeofday (&tvalAfter, NULL);
242 tm_info = localtime(&
timer); \
243 strftime(datebuf_start, 26,
"%Y-%m-%d %H:%M:%S",
tm_info); \
244 printf(
"%d+GEMM\t|CUV(%d,%d)%g AUV(%d,%d)%g BUV(%d,%d)%g\t\t\t\t\tGEMM: %s\n",MORSE_My_Mpi_Rank(), Cm, Cn, old_Crk, Am, An, Ark[0], Bm, Bn, Brk[0], datebuf_start);
247 int isTransA = transA == MorseTrans;
248 int isTransB = transB == MorseTrans;
253 alpha, (isTransA ? AV : AU), (isTransA ? AU : AV), Ark, lda,
254 (isTransB ? BU : BV), (isTransB ? BV : BU), Brk, ldb,
255 beta, CU, CV, Crk, ldc, rk, maxrk, acc, work);
259 alpha, (isTransA ? AV : AU), (isTransA ? AU : AV), Ark, lda,
260 (isTransB ? BU : BV), (isTransB ? BV : BU), Brk, ldb,
261 beta, CU, CV, Crk, ldc, rk, maxrk, acc, work);
267 gettimeofday (&tvalAfter, NULL);
269 tm_info = localtime(&
timer); \
271 printf(
"%d-GEMM\t|CUV(%d,%d)%g->%g AUV(%d,%d)%g BUV(%d,%d)%g acc:%e rk:%d maxrk:%d\t\t\tGEMM: %.4f\t%s---%s\n",MORSE_My_Mpi_Rank(),Cm, Cn, old_Crk, Crk[0], Am, An, Ark[0], Bm, Bn, Brk[0],
273 (tvalAfter.tv_sec - tvalBefore.tv_sec)
274 +(tvalAfter.tv_usec - tvalBefore.tv_usec)/1000000.0,
284 #if defined(HICMA_ALWAYS_FIX_RANK) 285 CODELETS_CPU(zgemm_hcore, 4, cl_zgemm_hcore_cpu_func)
288 CODELETS_CPU(zgemm_hcore, 7, cl_zgemm_hcore_cpu_func)
int global_always_fixed_rank
void HCORE_zgemm_fast(MORSE_enum transA, int transB, int M, int N, double alpha, double *AU, double *AV, double *Ark, int LDA, double *BU, double *BV, double *Brk, int LDB, double beta, double *CU, double *CV, double *Crk, int LDC, int rk, int maxrk, double acc, double *d_work)
void HCORE_zgemm(MORSE_enum transA, int transB, int M, int N, double alpha, double *AU, double *AV, double *Ark, int LDA, double *BU, double *BV, double *Brk, int LDB, double beta, double *CU, double *CV, double *Crk, int LDC, int rk, int maxrk, double acc, double *work)
void _printmat(double *A, int64_t m, int64_t n, int64_t ld)
void HICMA_TASK_zgemm(const MORSE_option_t *options, MORSE_enum transA, int transB, int m, int n, double alpha, const MORSE_desc_t *AUV, const MORSE_desc_t *Ark, int Am, int An, int lda, const MORSE_desc_t *BUV, const MORSE_desc_t *Brk, int Bm, int Bn, int ldb, double beta, const MORSE_desc_t *CUV, const MORSE_desc_t *Crk, int Cm, int Cn, int ldc, int rk, int maxrk, double acc)
int HICMA_get_use_fast_hcore_zgemm()