HiCMA
Hierarchical Computations on Manycore Architectures
codelet_zgemm.c
Go to the documentation of this file.
1 
#include "morse.h"
#include "hicma.h"
#include "hicma_common.h"
#include "runtime/starpu/chameleon_starpu.h"
//#include "runtime/starpu/include/runtime_codelet_z.h"

#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <time.h>

#include "runtime/starpu/runtime_codelets.h"
/* Declares the StarPU codelet of this kernel; expected to provide the
 * cl_zgemm_hcore symbol that HICMA_TASK_zgemm submits. */
ZCODELETS_HEADER(gemm_hcore)

/* TODO: the profiling callback is not wired up (HICMA_TASK_zgemm passes a
 * NULL callback). The line below was copied verbatim from
 * runtime/starpu/codelets/codelet_zcallback.c and must be adapted to this
 * codelet before it is enabled; as noted, its flop count is wrong when A is
 * transposed. */
/*CHAMELEON_CL_CB(zgemm_hcore, starpu_matrix_get_nx(task->handles[2]), starpu_matrix_get_ny(task->handles[2]), starpu_matrix_get_ny(task->handles[0]), 2. *M*N*K) [> If A^t, computation is wrong <]*/
30 
31 #include "hcore_z.h"
32 
/*
 * HiCMA globals defined in other translation units.
 */

/* Rank substituted for every tile's rank when HICMA_ALWAYS_FIX_RANK is
 * defined (read by cl_zgemm_hcore_cpu_func). */
extern int global_fixed_rank;
/* Not referenced in this file. */
extern int global_always_fixed_rank;

/* Print switches: print_index enables the trace line emitted before each
 * GEMM; print_index_end and print_mat are further print switches
 * (presumably end-of-task trace and matrix dumps -- confirm at definition). */
extern int print_index;
extern int print_index_end;
extern int print_mat;

/* Debug helper, presumably printing an m-by-n matrix of leading dimension
 * ld; not called in this file. */
extern void _printmat(double * A, int64_t m, int64_t n, int64_t ld);
45 void HICMA_TASK_zgemm(const MORSE_option_t *options,
46  MORSE_enum transA, int transB,
47  int m, int n,
48  double alpha,
49  const MORSE_desc_t *AUV,
50  const MORSE_desc_t *Ark,
51  int Am, int An, int lda,
52  const MORSE_desc_t *BUV,
53  const MORSE_desc_t *Brk,
54  int Bm, int Bn, int ldb,
55  double beta,
56  const MORSE_desc_t *CUV,
57  const MORSE_desc_t *Crk,
58  int Cm, int Cn, int ldc,
59  int rk,
60  int maxrk,
61  double acc
62  )
63 {
64  int nAUV = AUV->nb;
65  int nBUV = BUV->nb;
66  int nCUV = CUV->nb;
67  struct starpu_codelet *codelet = &cl_zgemm_hcore;
68  /*void (*callback)(void*) = options->profiling ? cl_zgemm_hcore_callback : NULL;*/
69  void (*callback)(void*) = NULL;
70  MORSE_starpu_ws_t *h_work = (MORSE_starpu_ws_t*)(options->ws_host);
71  /*printf("%s %d:\t%p %p\n", __FILE__, __LINE__, h_work, options->ws_host);*/
72 
73  int sizeA = lda*nAUV; //FIXME Think about scheduling of tasks according to sizes of the matrices
74  int sizeB = ldb*nBUV;
75  int sizeC = ldc*nCUV;
76  int execution_rank = CUV->get_rankof( CUV, Cm, Cn );
77  int rank_changed=0;
78  (void)execution_rank;
79 
80  /* force execution on the rank owning the largest data (tile) */
81  int threshold;
82  char* env = getenv("MORSE_COMM_FACTOR_THRESHOLD");
83 
84  int ifval = 0, elseifval = 0, initialval = execution_rank;
85  if (env != NULL)
86  threshold = (unsigned)atoi(env);
87  else
88  threshold = 10;
89  if ( sizeA > threshold*sizeC ){
90  execution_rank = AUV->get_rankof( AUV, Am, An );
91  ifval = execution_rank;
92  rank_changed = 1;
93  }else if( sizeB > threshold*sizeC ){
94  execution_rank = BUV->get_rankof( BUV, Bm, Bn );
95  elseifval = execution_rank;
96  rank_changed = 1;
97  }
98  //printf("m:%d n:%d k:%d nb:%d\n", m, n, k, nb); all of them are nb (1156)
99  //printf("initialval:\t%d if:%d\t else:\t%d rc:\t%d\n", initialval, ifval, elseifval, rank_changed);
100  MORSE_BEGIN_ACCESS_DECLARATION;
101  MORSE_ACCESS_R(AUV, Am, An);
102  MORSE_ACCESS_R(BUV, Bm, Bn);
103  MORSE_ACCESS_RW(CUV, Cm, Cn);
104 #if !defined(HICMA_ALWAYS_FIX_RANK)
105  MORSE_ACCESS_R(Ark, Am, An);
106  MORSE_ACCESS_R(Brk, Bm, Bn);
107  MORSE_ACCESS_RW(Crk, Cm, Cn);
108 #endif
109  if (rank_changed)
110  MORSE_RANK_CHANGED(execution_rank);
111  MORSE_END_ACCESS_DECLARATION;
112 
113  //printf("%s %d n:%d\n", __func__, __LINE__,n );
114  starpu_insert_task(
115  starpu_mpi_codelet(codelet),
116  STARPU_VALUE, &transA, sizeof(MORSE_enum),
117  STARPU_VALUE, &transB, sizeof(MORSE_enum),
118  STARPU_VALUE, &m, sizeof(int),
119  STARPU_VALUE, &n, sizeof(int),
120  STARPU_VALUE, &alpha, sizeof(double),
121  STARPU_R, RTBLKADDR(AUV, double, Am, An),
122  STARPU_VALUE, &lda, sizeof(int),
123  STARPU_R, RTBLKADDR(BUV, double, Bm, Bn),
124  STARPU_VALUE, &ldb, sizeof(int),
125  STARPU_VALUE, &beta, sizeof(double),
126  STARPU_RW, RTBLKADDR(CUV, double, Cm, Cn),
127 #if !defined(HICMA_ALWAYS_FIX_RANK)
128  STARPU_R, RTBLKADDR(Ark, double, Am, An),
129  STARPU_R, RTBLKADDR(Brk, double, Bm, Bn),
130  STARPU_RW, RTBLKADDR(Crk, double, Cm, Cn),
131 #endif
132  STARPU_VALUE, &ldc, sizeof(int),
133  STARPU_VALUE, &rk, sizeof(int),
134  STARPU_VALUE, &maxrk, sizeof(int),
135  STARPU_VALUE, &acc, sizeof(double),
136  STARPU_VALUE, &Am, sizeof(int),
137  STARPU_VALUE, &An, sizeof(int),
138  STARPU_VALUE, &Bm, sizeof(int),
139  STARPU_VALUE, &Bn, sizeof(int),
140  STARPU_VALUE, &Cm, sizeof(int),
141  STARPU_VALUE, &Cn, sizeof(int),
142  STARPU_VALUE, &nAUV, sizeof(int),
143  STARPU_VALUE, &nBUV, sizeof(int),
144  STARPU_VALUE, &nCUV, sizeof(int),
145  STARPU_SCRATCH, options->ws_worker,
146  STARPU_VALUE, &h_work, sizeof(MORSE_starpu_ws_t *),
147  STARPU_PRIORITY, options->priority,
148  STARPU_CALLBACK, callback,
149 #if defined(CHAMELEON_USE_MPI)
150  STARPU_EXECUTE_ON_NODE, execution_rank,
151 #endif
152 #if defined(CHAMELEON_CODELETS_HAVE_NAME)
153  STARPU_NAME, "hcore_zgemm",
154 #endif
155  0);
156 }
157 
158 #if !defined(CHAMELEON_SIMULATION)
159 static void cl_zgemm_hcore_cpu_func(void *descr[], void *cl_arg)
160 {
161 #ifdef HICMA_DISABLE_ALL_COMPUTATIONS
162  return;
163 #endif
164 #ifdef HICMA_DISABLE_HCORE_COMPUTATIONS
165  return;
166 #endif
167  struct timeval tvalBefore, tvalAfter; // removed comma
168  gettimeofday (&tvalBefore, NULL);
169  MORSE_enum transA;
170  MORSE_enum transB;
171  int m;
172  int n;
173  double alpha;
174  double *AUV = NULL;
175  double *AD = NULL;
176  double *Ark = NULL;
177  int lda;
178  double *BUV = NULL;
179  double *BD = NULL;
180  double *Brk = NULL;
181  int ldb;
182  double beta;
183  double *CUV = NULL;
184  double *CD = NULL;
185  double *Crk = NULL;
186  int ldc;
187  int rk;
188  int maxrk;
189  double acc ;
190  int nAUV;
191  int nBUV;
192  int nCUV;
193 
194 
195  int idescr = 0;
196  AUV = (double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
197  BUV = (double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
198  CUV = (double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
199 #if !defined(HICMA_ALWAYS_FIX_RANK)
200  Ark = (double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
201  Brk = (double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
202  Crk = (double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
203 #else
204  double _gemm_rank = global_fixed_rank;
205  Ark = &_gemm_rank;
206  Brk = &_gemm_rank;
207  Crk = &_gemm_rank;
208 #endif
209 
210  double* work = NULL;
211  work = (double *)STARPU_MATRIX_GET_PTR(descr[idescr++]);
212 
213  int Am, An, Bm, Bn, Cm, Cn;
214 
215  MORSE_starpu_ws_t *h_work;
216  starpu_codelet_unpack_args(cl_arg, &transA, &transB, &m, &n, &alpha, &lda, &ldb, &beta, &ldc, &rk, &maxrk, &acc, &Am, &An, &Bm, &Bn, &Cm, &Cn, &nAUV, &nBUV, &nCUV, &h_work);
217 
218  double *AU = AUV;
219  double *BU = BUV;
220  double *CU = CUV;
221 
222  int nAU = nAUV/2;
223  size_t nelm_AU = (size_t)lda * (size_t)nAU;
224  double *AV = &(AUV[nelm_AU]);
225 
226  int nBU = nBUV/2;
227  size_t nelm_BU = (size_t)ldb * (size_t)nBU;
228  double *BV = &(BUV[nelm_BU]);
229 
230  int nCU = nCUV/2;
231  size_t nelm_CU = (size_t)ldc * (size_t)nCU;
232  double *CV = &(CUV[nelm_CU]);
233 
234  double old_Crk = Crk[0];
235  char datebuf_start[128];
236  datebuf_start[0] = '\0';
237  if(print_index){
238  time_t timer;
239  struct tm* tm_info;
240  gettimeofday (&tvalAfter, NULL);
241  time(&timer); \
242  tm_info = localtime(&timer); \
243  strftime(datebuf_start, 26, "%Y-%m-%d %H:%M:%S",tm_info); \
244  printf("%d+GEMM\t|CUV(%d,%d)%g AUV(%d,%d)%g BUV(%d,%d)%g\t\t\t\t\tGEMM: %s\n",MORSE_My_Mpi_Rank(), Cm, Cn, old_Crk, Am, An, Ark[0], Bm, Bn, Brk[0], datebuf_start);
245  }
246 
247  int isTransA = transA == MorseTrans;
248  int isTransB = transB == MorseTrans;
249 
251  HCORE_zgemm_fast(transA, transB,
252  m, n,
253  alpha, (isTransA ? AV : AU), (isTransA ? AU : AV), Ark, lda,
254  (isTransB ? BU : BV), (isTransB ? BV : BU), Brk, ldb,
255  beta, CU, CV, Crk, ldc, rk, maxrk, acc, work);
256  }else{
257  HCORE_zgemm(transA, transB,
258  m, n,
259  alpha, (isTransA ? AV : AU), (isTransA ? AU : AV), Ark, lda,
260  (isTransB ? BU : BV), (isTransB ? BV : BU), Brk, ldb,
261  beta, CU, CV, Crk, ldc, rk, maxrk, acc, work);
262  }
264  char datebuf[128];
265  time_t timer;
266  struct tm* tm_info;
267  gettimeofday (&tvalAfter, NULL);
268  time(&timer); \
269  tm_info = localtime(&timer); \
270  strftime(datebuf, 26, "%Y-%m-%d %H:%M:%S",tm_info); \
271  printf("%d-GEMM\t|CUV(%d,%d)%g->%g AUV(%d,%d)%g BUV(%d,%d)%g acc:%e rk:%d maxrk:%d\t\t\tGEMM: %.4f\t%s---%s\n",MORSE_My_Mpi_Rank(),Cm, Cn, old_Crk, Crk[0], Am, An, Ark[0], Bm, Bn, Brk[0],
272  acc, rk, maxrk,
273  (tvalAfter.tv_sec - tvalBefore.tv_sec)
274  +(tvalAfter.tv_usec - tvalBefore.tv_usec)/1000000.0,
275  datebuf_start, datebuf
276  );
277  }
278 }
279 #endif /* !defined(MORSE_SIMULATION) */
280 
281 /*
282  * Codelet definition
283  */
284 #if defined(HICMA_ALWAYS_FIX_RANK)
285 CODELETS_CPU(zgemm_hcore, 4, cl_zgemm_hcore_cpu_func)
286 // CODELETS(zgemm_hcore, 4, cl_zgemm_hcore_cpu_func, cl_zgemm_hcore_cuda_func, STARPU_CUDA_ASYNC)
287 #else
288 CODELETS_CPU(zgemm_hcore, 7, cl_zgemm_hcore_cpu_func)
289 // CODELETS(zgemm_hcore, 7, cl_zgemm_hcore_cpu_func, cl_zgemm_hcore_cuda_func, STARPU_CUDA_ASYNC)
290 #endif
int global_always_fixed_rank
Definition: hcore_zgytlr.c:46
#define AUV(m, n)
Definition: pzgemm.c:60
int print_index
#define A(m, n)
Definition: pzgemm.c:56
int global_fixed_rank
Definition: hcore_zgytlr.c:47
time_t timer
struct tm * tm_info
#define BUV(m, n)
Definition: pzgemm.c:61
#define CUV(m, n)
Definition: pzgemm.c:62
int print_mat
void HCORE_zgemm_fast(MORSE_enum transA, int transB, int M, int N, double alpha, double *AU, double *AV, double *Ark, int LDA, double *BU, double *BV, double *Brk, int LDB, double beta, double *CU, double *CV, double *Crk, int LDC, int rk, int maxrk, double acc, double *d_work)
void HCORE_zgemm(MORSE_enum transA, int transB, int M, int N, double alpha, double *AU, double *AV, double *Ark, int LDA, double *BU, double *BV, double *Brk, int LDB, double beta, double *CU, double *CV, double *Crk, int LDC, int rk, int maxrk, double acc, double *work)
Definition: hcore_zgemm.c:568
void _printmat(double *A, int64_t m, int64_t n, int64_t ld)
char datebuf[128]
void HICMA_TASK_zgemm(const MORSE_option_t *options, MORSE_enum transA, int transB, int m, int n, double alpha, const MORSE_desc_t *AUV, const MORSE_desc_t *Ark, int Am, int An, int lda, const MORSE_desc_t *BUV, const MORSE_desc_t *Brk, int Bm, int Bn, int ldb, double beta, const MORSE_desc_t *CUV, const MORSE_desc_t *Crk, int Cm, int Cn, int ldc, int rk, int maxrk, double acc)
Definition: codelet_zgemm.c:45
int print_index_end
Definition: hcore_zgytlr.c:38
int HICMA_get_use_fast_hcore_zgemm()
Definition: hicma_init.c:26