HiCMA
Hierarchical Computations on Manycore Architectures
hcore_zgemm_fast.c
Go to the documentation of this file.
1 
16 #include "coreblas/include/coreblas.h"
17 #include "coreblas/lapacke.h"
18 #include <assert.h>
19 #ifdef LAPACKE_UTILS
20 #include <lapacke_utils.h>
21 #endif
22 
23 #include "control/hicma_config.h"
24 
/* FIXME: a previous declaration of CBLAS_SADDR exists in
 * ~/hicma-dev/chameleon/build/include/chameleon/coreblas/include/coreblas.h.
 * It is discarded here so that CBLAS_SADDR(v) yields the value v itself. */
#undef CBLAS_SADDR
#define CBLAS_SADDR(x) (x)

/* Uncomment to print traced values (see ECHO_* below) on stdout. */
/* #define DBG_MSG */

#if defined(DBG_MSG)
/* Each expansion is a complete statement that already ends in ';',
 * so call sites may omit their own semicolon. */
#  define ECHO_I(x) printf("%s(%d) ", #x, (x));
#  define ECHO_f(x) printf("%s(%e) ", #x, (x));
#  define ECHO_LN   printf("\n");
#else
/* Tracing disabled: all ECHO_* macros expand to nothing. */
#  define ECHO_I(x)
#  define ECHO_f(x)
#  define ECHO_LN
#endif
40 
/* External compression routine; only declared (and only called by
 * HCORE_zgemm_fast) when HCORE_GEMM_USE_KBLAS_ACA is defined. */
#if defined(HCORE_GEMM_USE_KBLAS_ACA)
extern int kblas_ACAf(int m, int n,
                      double *A, int lda,
                      double *U, int ldu,
                      double *V, int ldv,
                      double *S,
                      double maxacc, int maxrk,
                      double *acc, int *rk);
#endif

/* Globals and helpers defined in another translation unit
 * (hcore_zgemm.c for most of them). */
extern int use_trmm;
extern int use_scratch;
extern int gemm_print_index;
extern int gemm_print_mat;
extern void hc_printmat(double *A, int m, int n, int ld);
56 /***************************************************************************//*
57  *
58  * @ingroup CORE_double
59  *
60  **/
61 
62 void HCORE_zgemm_fast(MORSE_enum transA, int transB,
63  int M, int N,
64  double alpha,
65  double *AU,
66  double *AV,
67  double *Ark,
68  int LDA,
69  double *BU,
70  double *BV,
71  double *Brk,
72  int LDB,
73  double beta,
74  double *CU,
75  double *CV,
76  double *Crk,
77  int LDC,
78  int rk,
79  int maxrk,
80  double acc,
81  double* d_work
82  )
83 {do{
84  // printf("%d GEMM %p\n", MORSE_My_Mpi_Rank(), d_work);
85  assert(use_trmm == 1);
86  assert(use_scratch == 1);
87 
88  int ws_needed = 0;
89 
90  // cudaStat = cudaStreamSynchronize( cuda_stream );
91  // assert(cudaSuccess == cudaStat);
92 
93  int _Ark = (int)(Ark[0]); ECHO_I(_Ark);
94  int _Brk = (int)(Brk[0]); ECHO_I(_Brk);
95  int _Crk = (int)(Crk[0]); ECHO_I(_Crk);
96  int old_Crk = _Crk;
97  // if(gemm_print_index) printf("%d, _Ark %d, _Brk %d, _Crk %d\n", __LINE__, _Ark, _Brk, _Crk);
98 
99  int _M = M; int _N = N; ECHO_I(_M);
100  double* _CU = CU; int ld_CU = LDC; ECHO_I(ld_CU);
101  double* _CV = CV; int ld_CV = LDC; ECHO_I(ld_CV);
102  double* _AU = AU; int ld_AU = LDA; ECHO_I(ld_AU);
103  double* _AV = AV; int ld_AV = LDA; ECHO_I(ld_AV);
104  double* _BU = BU; int ld_BU = LDB; ECHO_I(ld_BU);
105  double* _BV = BV; int ld_BV = LDB; ECHO_I(ld_BV);
106  int rank = rk; ECHO_I(rank); ECHO_I(maxrk)
107 
108  char chall = 'A';
109 
110  int use_CUV_clone = 0;
111  double* _CU_save = _CU;
112  double* _CV_save = _CV;
113  int ld_CU_save = ld_CU;
114  int ld_CV_save = ld_CV;
115 
116  int CUV_ncols = _Crk + _Ark; ECHO_I(CUV_ncols);
117 
118  if((CUV_ncols > maxrk)){
119  double* CUclone = NULL;
120  int ld_CUclone = _M;
121  double* CVclone = NULL;
122  int ld_CVclone = _M;
123  size_t CUclone_nelm = _M * 2 * maxrk;
124  size_t CVclone_nelm = _M * 2 * maxrk;
125 
126  use_CUV_clone = 1;
127  CUclone = d_work;
128  d_work += CUclone_nelm;
129  ws_needed += CUclone_nelm;
130  CVclone = d_work;
131  d_work += CVclone_nelm;
132  ws_needed += CVclone_nelm;
133  LAPACK_dlacpy(&chall,
134  &_M, &_Crk,
135  _CU, &ld_CU,
136  CUclone, &ld_CUclone);
137  LAPACK_dlacpy(&chall,
138  &_M, &_Crk,
139  _CV, &ld_CV,
140  CVclone, &ld_CVclone);
141  _CU = CUclone;
142  _CV = CVclone;
143  ld_CU = ld_CUclone;
144  ld_CV = ld_CVclone;
145  }
146  int incOne = 1;
147  double d_one = 1.0;
148  double d_zero = 0.0;
149 
150  // TODO remove the abundant assumptions on matrices sizes and leading dimensions
151 
152  //=======================================================================================================
153  // QR A
154 
155  //concat CU+AU
156  int nelm_AU = _M * _Ark; ECHO_I(nelm_AU); ECHO_I(_Crk*ld_CU);
157  cblas_dcopy(nelm_AU, _AU, incOne, &_CU[_Crk*ld_CU], incOne);
158 
159  if(alpha != d_one){
160  cblas_dscal(nelm_AU, CBLAS_SADDR(alpha), &_CU[_Crk*ld_CU], incOne);
161  }
162  if(beta != d_one){
163  ECHO_I(_M * _Crk);
164  cblas_dscal(_M * _Crk, CBLAS_SADDR(beta), _CU, incOne);
165  }
166 
167  double *qrtauA = d_work;
168  size_t qrtauA_nelm = _M;
169  d_work += qrtauA_nelm;
170  ws_needed += qrtauA_nelm;
171  assert(qrtauA != NULL);
172 
173  int info = LAPACKE_dgeqrf( LAPACK_COL_MAJOR, _M, CUV_ncols, _CU, ld_CU, qrtauA);
174 
175  //=======================================================================================================
176  // QR B
177  double* qrb_avtbv = d_work;
178  size_t qrb_avtbv_nelm = maxrk * maxrk; ECHO_I(qrb_avtbv_nelm);
179  d_work += qrb_avtbv_nelm;
180  ws_needed += qrb_avtbv_nelm;
181 
182  //P = AV^T * BV
183  cblas_dgemm(CblasColMajor,
184  CblasTrans, CblasNoTrans,
185  _Ark, _Brk, _M,
186  CBLAS_SADDR(d_one), _AV, ld_AV,
187  _BV, ld_BV,
188  CBLAS_SADDR(d_zero), qrb_avtbv, maxrk);
189 
190  //G = P * BU^T <=> G^T = BU * P^T
191  //CV = CV | G^T
192  cblas_dgemm(CblasColMajor,
193  CblasNoTrans, CblasTrans,
194  _M, _Ark, _Brk,
195  CBLAS_SADDR(d_one), _BU, ld_BU,
196  qrb_avtbv, maxrk,
197  CBLAS_SADDR(d_zero), &_CV[_Crk*ld_CV], ld_CV);
198 
199  // double* qrtauB = qrtauA + qrtauA_nelm;
200  double* qrtauB = d_work;
201  size_t qrtauB_nelm = _M;
202  d_work += qrtauB_nelm;
203  ws_needed += qrtauB_nelm;
204  assert(qrtauB != NULL);
205 
206  info = LAPACKE_dgeqrf(LAPACK_COL_MAJOR, _M, CUV_ncols, _CV, ld_CV, qrtauB);
207 
208  //=======================================================================================================
209  //SVD
210  double* rA = d_work;
211  size_t rA_nelm = CUV_ncols * CUV_ncols; ECHO_I(rA_nelm);
212  int ld_rA = CUV_ncols;
213  d_work += rA_nelm;
214  ws_needed += rA_nelm;
215 
216  double* rB = _CV;
217  int ld_rB = ld_CV;
218  // size_t rB_nelm = CV_ncols * CV_ncols;
219 
220  //copy rA from CU
221  char chlow = 'L';
222  LAPACK_dlaset(&chlow, &CUV_ncols, &CUV_ncols, &d_zero, &d_zero, rA, &ld_rA);
223  char chup = 'U';
224  LAPACK_dlacpy(&chup, &CUV_ncols, &CUV_ncols,
225  _CU, &ld_CU,
226  rA, &ld_rA);
227 
228  // rA = rA * rB^T
229  cblas_dtrmm(CblasColMajor, CblasRight, CblasUpper, CblasTrans, CblasNonUnit,
230  CUV_ncols, CUV_ncols,
231  d_one, rB, ld_rB,
232  rA, ld_rA);
233 
234  int finalrank = -1, size_sigma = CUV_ncols; ECHO_I(size_sigma)
235  double relacc = (acc);
236 
237  double* _T = rA;
238  int ld_T = ld_rA;
239  double* TU = d_work;
240  #ifdef HCORE_GEMM_USE_ORGQR
241  size_t TU_nelm = CUV_ncols * CUV_ncols; ECHO_I(TU_nelm)
242  int ld_TU = CUV_ncols;
243  #else
244  size_t TU_nelm = _M * CUV_ncols; ECHO_I(TU_nelm)
245  int ld_TU = _M;
246  #endif
247  d_work += TU_nelm;
248  ws_needed += TU_nelm;
249 
250  double* TV = d_work;
251  #ifdef HCORE_GEMM_USE_ORGQR
252  size_t TV_nelm = CUV_ncols * CUV_ncols;
253  int ld_TV = CUV_ncols;
254  #else
255  size_t TV_nelm = _M * CUV_ncols;
256  #ifdef HCORE_GEMM_USE_KBLAS_ACA
257  int ld_TV = _M;
258  #else
259  int ld_TV = CUV_ncols;
260  #endif
261  #endif
262  ECHO_I(TV_nelm)
263  d_work += TV_nelm;
264  ws_needed += TV_nelm;
265 
266  double *d_sigma = d_work;
267  size_t d_sigma_nelm = CUV_ncols; ECHO_I(d_sigma_nelm)
268  d_work += d_sigma_nelm;
269  ws_needed += d_sigma_nelm;
270 
271  #if defined HCORE_GEMM_USE_KBLAS_ACA
272 
273  double finalacc;
274  kblas_ACAf( CUV_ncols, CUV_ncols,
275  _T, ld_T,
276  TU, ld_TU,
277  TV, ld_TV,
278  d_sigma,
279  relacc, rank,
280  &finalacc, &finalrank);
281  ECHO_I(finalrank)
282  #else
283  double* svdsuperb = d_work;
284  double work_query;
285  int lwork = -1;
286  info = LAPACKE_dgesvd_work( LAPACK_COL_MAJOR, 'A', 'A',
287  CUV_ncols, CUV_ncols,
288  NULL, CUV_ncols,
289  NULL,
290  NULL, CUV_ncols,
291  NULL, CUV_ncols, &work_query, lwork );
292  lwork = (int)work_query;
293  size_t svdsuperb_nelm = lwork;
294  d_work += svdsuperb_nelm;
295  ws_needed += svdsuperb_nelm;
296 
297  info = LAPACKE_dgesvd_work( LAPACK_COL_MAJOR,
298  'A', 'A',
299  CUV_ncols, CUV_ncols,
300  _T, ld_T,
301  d_sigma,
302  TU, ld_TU,
303  TV, ld_TV,
304  svdsuperb,
305  svdsuperb_nelm);
306 
307  double *h_sigma = d_sigma;
308 
309  if(rank != 0) {
310  finalrank = rank;
311  if(rank > size_sigma)
312  finalrank = size_sigma;
313  }
314  else{
315  int newrank = size_sigma;
316  int i;
317  for(i=2;i<size_sigma;i++){
318  // ECHO_f(h_sigma[i] )
319  if(h_sigma[i] < relacc)
320  {
321  newrank=i;
322  break;
323  }
324  }
325  finalrank = newrank; ECHO_I(finalrank)
326  }
327 
328  //since we store SV we need to scale V by S
329  int k;
330  for(k = 0; k < finalrank; k++){
331  double diagval = h_sigma[k];
332 
333  cblas_dscal(CUV_ncols, CBLAS_SADDR(diagval), &TV[k], ld_TV);
334  }
335  #endif
336  Crk[0] = (double)finalrank; ECHO_f(Crk[0])
337 
338  //=======================================================================================================
339  // construct final U
340  #if defined HCORE_GEMM_USE_ORGQR
341  double* newUV = d_work;
342  size_t newUV_nelm = _M * finalrank;
343  d_work += newUV_nelm;
344 
345  info = LAPACKE_dorgqr( LAPACK_COL_MAJOR,
346  _M, CUV_ncols, CUV_ncols,
347  _CU, ld_CU,
348  qrtauA);
349  cblas_dgemm(CblasColMajor,
350  CblasNoTrans, CblasNoTrans,
351  _M, finalrank, CUV_ncols,
352  CBLAS_SADDR(d_one), _CU, ld_CU,
353  TU, ld_TU,
354  CBLAS_SADDR(d_zero), use_CUV_clone ? _CU_save : newUV, use_CUV_clone ? ld_CU_save : ld_CU);
355 
356  if(!use_CUV_clone)
357  LAPACKE_dlacpy(LAPACK_COL_MAJOR, 'A', _M, finalrank, newUV, ld_CU, _CU_save, ld_CU_save);
358  #else
359  char uplo = 'A';
360  int nrows = _M - CUV_ncols;
361  int ncols = finalrank;
362  LAPACK_dlaset( &uplo, &nrows, &ncols, &d_zero, &d_zero, &(TU[CUV_ncols]), &ld_TU );
363 
364  info = LAPACKE_dormqr( LAPACK_COL_MAJOR,
365  'L', 'N',
366  _M, finalrank, CUV_ncols,
367  _CU, ld_CU,
368  qrtauA,
369  TU, ld_TU);
370 
371  LAPACKE_dlacpy(LAPACK_COL_MAJOR, 'A', _M, finalrank, TU, ld_TU, _CU_save, ld_CU_save);
372  #endif
373 
374  //=======================================================================================================
375  // construct final V
376  #ifdef HCORE_GEMM_USE_ORGQR
377  info = LAPACKE_dorgqr( LAPACK_COL_MAJOR,
378  _M, CUV_ncols, CUV_ncols,
379  _CV, ld_CV,
380  qrtauB);
381 
382  cblas_dgemm(CblasColMajor,
383  CblasNoTrans,
384  #ifdef HCORE_GEMM_USE_KBLAS_ACA
385  CblasNoTrans,
386  #else
387  CblasTrans,
388  #endif
389  _M, finalrank, CUV_ncols,
390  CBLAS_SADDR(d_one), _CV, ld_CV,
391  TV, ld_TV,
392  CBLAS_SADDR(d_zero), use_CUV_clone ? _CV_save : newUV, use_CUV_clone ? ld_CV_save : ld_CV);
393 
394  if(!use_CUV_clone)
395  LAPACKE_dlacpy(LAPACK_COL_MAJOR, 'A', _M, finalrank, newUV, ld_CV, _CV_save, ld_CV_save);
396  #else
397  #ifdef HCORE_GEMM_USE_KBLAS_ACA
398  int TV_pad = CUV_ncols;
399  nrows = _M - CUV_ncols;
400  ncols = finalrank;
401  #else
402  int TV_pad = CUV_ncols * ld_TV;
403  nrows = finalrank;
404  ncols = _M - CUV_ncols;
405  #endif
406  LAPACK_dlaset( &uplo, &nrows, &ncols, &d_zero, &d_zero, &(TV[TV_pad]), &ld_TV );
407 
408  info = LAPACKE_dormqr( LAPACK_COL_MAJOR,
409  #ifdef HCORE_GEMM_USE_KBLAS_ACA
410  'L', 'N',
411  _M, finalrank, CUV_ncols,
412  #else
413  'R', 'T',
414  finalrank, _M, CUV_ncols,
415  #endif
416  _CV, ld_CV,
417  qrtauB,
418  TV, ld_TV);
419 
420  #ifdef HCORE_GEMM_USE_KBLAS_ACA
421  LAPACKE_dlacpy(LAPACK_COL_MAJOR, 'A', _M, finalrank, TV, ld_TV, _CV_save, ld_CV_save);
422  #else
423  LAPACKE_dge_trans(LAPACK_COL_MAJOR, finalrank, _M, TV, ld_TV, _CV_save, ld_CV_save);
424  #endif
425  #endif
426  ECHO_I(ws_needed);
427  ECHO_LN
428  }while(0);
429 }
430 
431 /***************************************************************************/
432 // For debugging precision conversion script
433  //dormqr
434  //zormqr
435  //ormqr
436  //LAPACKE_dormqr
437  //LAPACKE_zormqr //
438  //LAPACKE_ormqr //
439  //dunmqr
440  //zunmqr
441  //unmqr
442  //LAPACKE_dunmqr
443  //LAPACKE_zunmqr //
444  //LAPACKE_unmqr //
int gemm_print_index
Definition: hcore_zgemm.c:29
#define A(m, n)
Definition: pzgemm.c:56
#define ECHO_I(_val)
int gemm_print_mat
Definition: hcore_zgemm.c:30
#define ECHO_f(_val)
void hc_printmat(double *A, int m, int n, int ld)
Definition: hcore_zgemm.c:32
void HCORE_zgemm_fast(MORSE_enum transA, int transB, int M, int N, double alpha, double *AU, double *AV, double *Ark, int LDA, double *BU, double *BV, double *Brk, int LDB, double beta, double *CU, double *CV, double *Crk, int LDC, int rk, int maxrk, double acc, double *d_work)
#define ECHO_LN
#define CBLAS_SADDR(_val)
int use_trmm
Definition: hcore_zgemm.c:27
int uplo[2]
int use_scratch