HiCMA
Hierarchical Computations on Manycore Architectures
hcore_zgemm.c
Go to the documentation of this file.
1 
16 #include "coreblas/coreblas.h"
17 #include "coreblas/lapacke.h"
18 #include <assert.h>
19 #ifdef LAPACKE_UTILS
20 #include <lapacke_utils.h>
21 #endif
22 
/* Global switches shared with the rest of HiCMA. */
int use_trmm = 1;        /* 1: __svd forms rA*rB^T in place with dtrmm; otherwise dgemm into a separate T. */
extern int use_scratch;  /* nonzero: HCORE_zgemm carves its temporaries out of the caller's work buffer. */

/*
 * FIXME: a previous declaration of CBLAS_SADDR exists in
 * ~/hicma/chameleon/build/include/chameleon/coreblas/include/coreblas.h.
 * The real-precision cblas_d* calls in this file take scalars by value,
 * so the macro is redefined here as the identity.
 */
#undef CBLAS_SADDR
#define CBLAS_SADDR(val) (val)
/* Maximum number of entries hc_printmat emits before stopping early. */
int hc_nelm_limit = 512;

/*
 * Debug helper: print the m x n column-major matrix A (leading dimension ld)
 * row by row, as "M:<m> N:<n> LD:<ld>" followed by "[[a00a01...],\n[a10...]]".
 * Entries use "%+.2e", so the explicit sign separates adjacent values.
 * After hc_nelm_limit entries have been printed, a newline is written and the
 * function returns at once, leaving the brackets unclosed.
 */
void hc_printmat(double * A, int m, int n, int ld){
    int printed = 0;
    int row, col;

    printf("M:%d N:%d LD:%d\n[", m, n, ld);
    for (row = 0; row < m; row++) {
        printf("[");
        for (col = 0; col < n; col++) {
            printf("%+.2e", A[col*ld + row]);
            if (++printed >= hc_nelm_limit) {
                printf("\n");
                return;
            }
        }
        printf("]");
        if (row < m - 1) {
            printf(",\n");
        }
    }
    printf("]\n");
}
61 // void __qra(ka_matrix* _CU, ka_matrix* _AU, double alpha, double* qrtauA)
62 void __qra(int _M,
63  int maxrank,
64  double* _CU, int ld_CU, int _Crk,
65  int* pnew_CU_ncols,
66  double* _AU, int ld_AU, int _Ark,
67  double alpha, double beta, double* qrtauA){
68  int info;
69  int AU_nrows = _M;// assume that _M == nrows of A
70  int AU_ncols = _Ark;
71  int CU_ncols = _Crk;
72  if ((_Ark + _Crk) > 2*maxrank){
73  fprintf(stderr, "%s %s %d: Sum of ranks (%d) is too big! _Ark:%d _Crk:%d maxrank:%d (x2: %d)\n",
74  __FILE__, __func__, __LINE__, (_Ark + _Crk), _Ark, _Crk, maxrank, 2*maxrank);
75  exit(-1);
76  }
77  int nelm_AU = AU_nrows * AU_ncols;
78  int incOne = 1;
79  //hc_printmat(_CU, _M, _M, ld_CU);
80  //hc_printmat(_CU, _M, _M, ld_CU);
81  //ERRONEOUS. here is no ld!!!
82  if(gemm_print_index){
83  printf(" QRA\t|%d\t|nelm_AU:%d alpha:%g CU_ncols:%d ld_CU:%d CU_ncols*ld_CU:%d\n", __LINE__, nelm_AU, alpha, CU_ncols, ld_CU, CU_ncols*ld_CU);
84  }
85  cblas_dcopy(nelm_AU, _AU, incOne, &_CU[CU_ncols*ld_CU], incOne);
86  //hc_printmat(_CU, _M, _M, ld_CU);
87  double d_one = (double)1.0;
88  if(alpha != d_one){
89  cblas_dscal(nelm_AU, CBLAS_SADDR(alpha) , &_CU[CU_ncols*ld_CU], incOne);
90  }
91  if(beta != d_one){
92  cblas_dscal(_M * _Crk, CBLAS_SADDR(beta) , _CU, incOne);
93  }
94  //hc_printmat(_CU, _M, _M, ld_CU);
95  int CU_nrows = _M;
96  /*printf("__qra: CU_ncols:%d new_CU_ncols:%d\n", CU_ncols, (CU_ncols+_Ark));*/
97  CU_ncols += _Ark;
98  *pnew_CU_ncols = CU_ncols; //CHANGED VALUE, RETURN VALUE
99  if(gemm_print_index){
100  printf(" QRA\t|%d\t|CU_nrows:%d CU_ncols:%d ld_CU:%d QRA:%p\n", __LINE__, CU_nrows, CU_ncols, ld_CU, _CU);
101  }
102  info = LAPACKE_dgeqrf(
103  LAPACK_COL_MAJOR, CU_nrows, CU_ncols, _CU, ld_CU, qrtauA);
104  //printf("%d %d %d _M:%d\n", CU_nrows, CU_ncols, ld_CU, _M);
105  //hc_printmat(_CU, _M, _M, ld_CU);
106  /*hc_printmat(qrtauA, 1, _M, 1); */
107  if(info != 0){
108  fprintf(stderr,
109  "%s %d ERROR in LAPACKE_dgeqrf(1:CU_nrows:%d 2:CU_ncols:%d 3:_CU:%p 4:ld_CU:%d 5:qrtauA:%p) info=%d maxrank:%d\n",
110  __FILE__, __LINE__, CU_nrows, CU_ncols, _CU, ld_CU, qrtauA, info, maxrank);
111  exit(-1);
112  }
113  //hc_printmat(_AU, _M, _M, ld_AU);
114 }
115 /*
116  * CV, AV and BV are stored as transposed.
117  * Performs QR([CV|(AV^T*BU*BV^T)])
118  * Step 1: F=AV^T*BU gemm
119  * Step 2: G=P*BV^T gemm
120  * Step 3: CV=CV|G done in the gemm at Step 2
121  * Step 4: QR(CV) potrf
122  */
123 void __qrb(
124  int _M,
125  int maxrank,
126  double* _CV, int ld_CV, int _Crk,
127  int* pnew_CV_ncols,
128  double* _AV, int ld_AV, int _Ark,
129  double* _BU, int ld_BU,
130  double* _BV, int ld_BV, int _Brk,
131  double* qrtauB, double* AcolBcolT){
132  int info;
134  assert(AcolBcolT != NULL);
135  int ld_AcolBcolT = maxrank; //ASSUMPTION
136 
137  double alpha = 1.0;
138  double beta = 0.0;
139  int AV_ncols = _Ark; //ASSUMPTION
140  int BV_nrows = _M; //ASSUMPTION
141  int BV_ncols = _Brk; //ASSUMPTION
142  // AV*BV^T
143  if(gemm_print_index){
144  printf(" QRB\t|%d\t| (AV*BV^T) M,N,K:%d,%d,%d LDA,LDB,LDC:%d,%d,%d alpha:%g beta:%g\n", __LINE__,AV_ncols, BV_ncols, BV_nrows,ld_AV,ld_BV,ld_AcolBcolT, alpha, beta);
145  }
146  /* Step 1: F=AV^T*BU gemm */
147  cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans,
148  AV_ncols, BV_ncols, BV_nrows, CBLAS_SADDR(alpha),
149  _AV, ld_AV, _BV, ld_BV, CBLAS_SADDR(beta),
150  AcolBcolT, ld_AcolBcolT);
151  int AcolBcolT_nrows = AV_ncols;
152  int AcolBcolT_ncols = BV_ncols;
153 
154  int BU_nrows = _M; //ASSUMPTION
155  int BU_ncols = _Brk; //ASSUMPTION
156  int CV_ncols = _Crk;
157  if ((AcolBcolT_nrows + _Crk) > 2*maxrank){
158  fprintf(stderr, "%s %s %d: Sum of two ranks (%d) is too big! \
159  AcolBcolT:%d _Crk:%d maxrank:%d (x2: %d)\n",
160  __FILE__, __func__, __LINE__, (AcolBcolT_nrows + _Crk), AcolBcolT, _Crk, maxrank, 2*maxrank);
161  exit(-1);
162  }
163  // (AV*BV^T) * BU^T
164  if(gemm_print_index){
165  printf(" QRB\t|%d\t| (AV*BV^T) * BU^T M,N,K:%d,%d,%d LDA,LDB,LDC:%d,%d,%d alpha:%g beta:%g CV_ncols:%d AcolBcolT_nrows:%d ldcB:%d\n",
166  __LINE__, BU_nrows,AcolBcolT_nrows, BU_ncols,ld_BU,ld_AcolBcolT,ld_CV, alpha, beta, CV_ncols, AcolBcolT_nrows, ld_CV);
167  }
168  /* Step 2: G=P*BV^T gemm */
169  /* Step 3: CV=CV|G done in the gemm at Step 2*/
170  cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans,
171  BU_nrows, AcolBcolT_nrows, BU_ncols,
172  CBLAS_SADDR(alpha),
173  _BU, ld_BU,
174  AcolBcolT, ld_AcolBcolT,
175  CBLAS_SADDR(beta),
176  &_CV[CV_ncols*ld_CV], ld_CV);
177  CV_ncols += AcolBcolT_nrows;
178  *pnew_CV_ncols = CV_ncols; //CHANGED VALUE, RETURN VALUE
179  int CV_nrows = _M; //ASSUMPTION
180 
181  if(gemm_print_index){
182  printf(" QRB\t|%d\t|CV_nrows:%d CV_ncols:%d ld_CV:%d QRB:%p\n",
183  __LINE__, CV_nrows, CV_ncols, ld_CV, _CV);
184  }
185  /* Step 4: QR(CV) potrf */
186  info = LAPACKE_dgeqrf(LAPACK_COL_MAJOR, CV_nrows, CV_ncols,
187  _CV, ld_CV, qrtauB);
188  if(info != 0){
189  fprintf(stderr,
190  "%s %d ERROR in LAPACKE_dgeqrf(1:CV_nrows:%d 2:CV_ncols:%d 3:_CV:%p 4:ld_CV:%d 5:qrtauB:%p) info=%d maxrank:%d\n",
191  __FILE__, __LINE__, CV_nrows, CV_ncols, _CV, ld_CV, qrtauB,info, maxrank);
192  exit(-1);
193  }
194 }
195 
196 //void __svd(ka_matrix* _CU, ka_matrix* _CV, int rank, double acc, ka_matrix* _U, ka_matrix* _V)
197 // Uses CV for TRMM
198 void __svd(
199  int _M,
200  int maxrank,
201  double* _CU, int ld_CU,
202  double* _CV, int ld_CV, int _Crk,
203  double* _U, int ld_U,
204  double* _V, int ld_V, int* pnew_UVrk,
205  int rank,
206  double acc,
207  double* _rA, double* _rB, double* _T, double* sigma, double* svdsuperb
208  ) {
209  int info;
210  int nb = _M; //ASSUMPTION
211  int CU_nrows = _M; //ASSUMPTION
212  int CU_ncols = _Crk; //ASSUMPTION
214  int rA_nrows = chameleon_min(CU_nrows, CU_ncols);
215  int rA_ncols = CU_ncols; //ASSUMPTION
216  int maxncolsR = 2*_Crk; //nb; //ASSUMPTION
217  int ld_rA = rA_nrows;
218 
219  if(rA_nrows != rA_ncols){
220  printf("TRMM cannot be used because R coming from QR factorization of A is not square nrows: %d ncols:%d \n", rA_nrows, rA_ncols);
221  exit(1);
222  }
223  assert(_rA != NULL);
224 
225  if(gemm_print_index){
226  printf(" SVD\t|%d\t| copy rA rA_nrows:%d rA_ncols:%d ld_CU:%d ld_rA:%d CU:%p rA:%p\n",
227  __LINE__, rA_nrows, rA_ncols, ld_CU, ld_rA, _CU, _rA);
228  }
229  double zero = 0.0;
230  char chlow = 'L';
231  LAPACK_dlaset(&chlow,
232  &rA_nrows, &rA_ncols, &zero, &zero, _rA, &ld_rA);
233  char chup = 'U';
234  LAPACK_dlacpy(&chup,
235  &rA_nrows, &rA_ncols,
236  _CU, &ld_CU,
237  _rA, &ld_rA);
238  if(gemm_print_mat){
239  printf("%d\t|_CU and _rA\n", __LINE__);
240  hc_printmat(_CU, _M, _Crk, ld_CU);
241  hc_printmat(_rA, rA_nrows, rA_ncols, ld_rA);
242  }
245  int CV_nrows = _M; //ASSUMPTION
246  int CV_ncols = _Crk; //ASSUMPTION
247  int rB_nrows = chameleon_min(CV_nrows, CV_ncols);
248  int rB_ncols = CV_ncols;
249  int ld_rB = rB_nrows; //ASSUMPTION
250  assert(rA_ncols == rB_ncols);
251  // trmm does not access lower part but gemm reads.
252  // However lower part is never written.
253  if(gemm_print_index){
254  printf(" SVD\t|%d\t| copy rB rB_nrows:%d rB_ncols:%d ld_rB:%d ld_CV:%d rB:%p CV:%p\n",
255  __LINE__, rB_nrows, rB_ncols, ld_rB, ld_CV, _rB, _CV);
256  }
257  if(use_trmm == 0){
258  assert(_rB != NULL);
259  LAPACK_dlaset(&chlow,
260  &rB_nrows, &rB_ncols, &zero, &zero, _rB, &ld_rB);
261  LAPACK_dlacpy(&chup,
262  &rB_nrows, &rB_ncols, _CV, &ld_CV, _rB, &ld_rB);
263  } else {
264  _rB = _CV;
265  ld_rB = ld_CV;
266  }
267  if(gemm_print_mat){
268  printf("%d\t|_CV and _rB\n", __LINE__);
269  hc_printmat(_CV, _M, _Crk, ld_CV);
270  hc_printmat(_rB, rB_nrows, rB_ncols, ld_rB);
271  }
274  int T_nrows = rA_ncols;
275  int T_ncols = rA_ncols;
276  int ld_T = T_nrows;
277  double alpha = 1.0;
278  double beta = 0.0;
279  if(use_trmm == 1){
280  cblas_dtrmm(CblasColMajor, CblasRight, CblasUpper, CblasTrans, CblasNonUnit, rA_nrows, rA_ncols, alpha, _rB, ld_rB, _rA, ld_rA); //FIXME Correctness of rA_ncols is not checked
281  _T = _rA;
282  } else {
283  assert(_T != NULL);
284  cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans,
285  rA_nrows, rB_nrows, rB_ncols, CBLAS_SADDR(alpha), _rA, ld_rA,
286  _rB, ld_rB, CBLAS_SADDR(beta), _T, ld_T);
287  }
288  if(gemm_print_index){
289  printf(" SVD\t|%d\t| T=rA*rB^T rA_nrows:%d rB_nrows:%d rB_ncols:%d ld_rA:%d ld_rB:%d ld_T:%d alpha:%g beta:%g\n",
290  __LINE__, rA_nrows, rB_nrows, rB_ncols, ld_rA, ld_rB, ld_T, alpha, beta);
291  }
292  if(gemm_print_mat){
293  hc_printmat(_T, T_nrows, T_ncols, ld_T);
294  }
295  // Singular values are ALWAYS floating point numbers.
296  assert(sigma != NULL);
297  int size_sigma = T_nrows;
298  assert(svdsuperb != NULL);
299  if(gemm_print_index){
300  printf(" SVD\t|%d\t| svd(T) (3.m)T_nrows:%d (4.n)T_ncols:%d ld_T:%d ld_U:%d (11.ldvt)ld_V:%d _T:%p (zero based parameter indices)\n",
301  __LINE__, T_nrows, T_ncols, ld_T, ld_U, ld_V, _T);
302  }
303  info = LAPACKE_dgesvd(LAPACK_COL_MAJOR, 'A', 'A',
304  T_nrows, T_ncols, _T, ld_T, sigma,
305  _U, ld_U, _V, ld_V,
306  svdsuperb);
307  if(info != 0){
308  fprintf(stderr,
309  "%s %d ERROR in LAPACKE_dgesvd() info=%d"
310  "1:T_nrows=%d, 2:T_ncols=%d, 3:_T=%p, 4:ld_T=:%d, 5:sigma=%p,"
311  "6:_U=%p, 7:ld_U=%d, 8:_V=%p, 9:ld_V:%d,"
312  "10:svdsuperb:%p"
313  "\n",
314  __FILE__, __LINE__, info,
315  T_nrows, T_ncols, _T, ld_T, sigma,
316  _U, ld_U, _V, ld_V,
317  svdsuperb);
318  exit(-1);
319  }
320  int U_nrows, U_ncols, V_nrows, V_ncols;
321  U_nrows = U_ncols = V_nrows = V_ncols = T_nrows;
322 
323  if(gemm_print_mat){
324  hc_printmat(_U, U_nrows, U_ncols, ld_U);
325  hc_printmat(sigma, 1, size_sigma, 1);
326  hc_printmat(_V, V_nrows, V_ncols, ld_V);
327  printf("%d %e\n", rank, acc);
328  }
329  int finalrank = -1;
330  //double relacc = (acc*sigma[0]*acc);
331  double relacc = (acc);
332  //double relacc = (acc*acc);
333  if(rank != 0) {
334  finalrank = rank;
335  if(rank > size_sigma)
336  finalrank = size_sigma;
337  }
338  else{
339  int newrank = size_sigma;
340  int i;
341  for(i=2;i<size_sigma;i++){
342  if(sigma[i] < relacc)
343  {
344  newrank=i;
345  break;
346  }
347  }
348  finalrank = newrank;
349  }
350  if(finalrank > maxrank){
351  fprintf(stderr, "%s %s %d: Rank after truncation is too big! finalrank:%d maxrank:%d\n", __FILE__, __func__, __LINE__, finalrank, maxrank);
352  exit(-1);
353  }
354 
355  if(gemm_print_index){int i;
356  printf("rank:%d acc:%.2e relac:%.2e size_sigma:%d final_rank:%d: ",
357  rank, acc, relacc, size_sigma, finalrank);
358  for(i=0;i<size_sigma;i++){
359  printf("%d:%.2e ", i,sigma[i]);
360  }
361  printf("\n");
362  }
363  if(gemm_print_index){
364  printf("size_sigma:%d finalrank:%d %.2e\n", size_sigma, finalrank, acc);
365  }
366  U_ncols = finalrank;
367  V_nrows = finalrank;
369  int rank_V = finalrank;
370  int k;
371  for(k = 0; k < rank_V; k++){
372  double diagval = sigma[k];
373  cblas_dscal(V_ncols, CBLAS_SADDR(diagval), &_V[k], ld_V);
374  }
375  if(gemm_print_index){
376  printf(" SVD\t|%d\t| S*V V_ncols:%d ld_V:%d _V:%p\n",
377  __LINE__, V_ncols, ld_V, _V);
378  }
379  if(gemm_print_mat){
380  hc_printmat(_V, V_nrows, V_ncols, ld_V);
381  }
382  *pnew_UVrk = finalrank;
383 }
384 //void __newu(ka_matrix* _CU, double* qrtauA, ka_matrix* _U)
385 void __newu(
386  int _M,
387  int ncols_qA,
388  double* _CU, int ld_CU, int _Crk,
389  double* _U, int ld_U, int _Urk,
390  double* qrtauA
391  ) {
392  int info = 0;
393  int CU_nrows = _M;
394  int CU_ncols = ncols_qA;
395  int U_nrows = ncols_qA; //ASSUMPTION
396  int U_ncols = _Urk; //ASSUMPTION
397  int nrows = CU_nrows - U_nrows;
398  double zero = 0.0;
399  if(gemm_print_index){
400  printf(" NEWU\t|%d\t| zero nrows:%d U_ncols:%d ld_U:%d _Crk:%d CU_ncols:%d _Urk:%d CU_nrows:%d U_nrows:%d diff:%d\n",
401  __LINE__, nrows, U_ncols, ld_U, _Crk, CU_ncols, _Urk, CU_nrows, U_nrows, nrows);
402  }
403  //info = LAPACKE_zlaset(LAPACK_COL_MAJOR, 'A', nrows, U_ncols,
404  // zero, zero, &_U[U_nrows], ld_U);
405 
406  char uplo = 'A';
407  LAPACK_dlaset( &uplo, &nrows, &U_ncols, &zero, &zero, &_U[U_nrows], &ld_U );
408  if(info != 0){
409  fprintf(stderr,
410  "%s %d ERROR in LAPACKE_dlaset() info=%d\n",
411  __FILE__, __LINE__, info);
412  }
413  if(gemm_print_mat){
414  hc_printmat(_U, _M, _M, ld_U);
415  }
416 
417  //zunmqr
418  info = LAPACKE_dormqr(LAPACK_COL_MAJOR, 'L', 'N',
419  CU_nrows, U_ncols, ncols_qA, _CU, ld_CU, qrtauA, _U, ld_U);
420  if(gemm_print_index){
421  printf(" NEWU\t|%d\t| ormqr CU_nrows (new U_nrows):%d U_ncols:%d ncols_qA:%d ld_CU:%d ld_U:%d\n",
422  __LINE__, CU_nrows, U_ncols, ncols_qA, ld_CU, ld_U);
423  }
424  if(gemm_print_mat){
425  hc_printmat(_U, _M, _M, ld_U);
426  }
427  if(info != 0){
428  fprintf(stderr,
429  "%s %d ERROR in LAPACKE_dormqr() info=%d\n",
430  __FILE__, __LINE__, info);
431  printf(" NEWU\t|%d\t| ormqr CU_nrows (new U_nrows):%d U_ncols:%d ncols_qA:%d ld_CU:%d ld_U:%d U_nrows:%d\n",
432  __LINE__, CU_nrows, U_ncols, ncols_qA, ld_CU, ld_U, U_nrows);
433  if(0){
434  int i, j, ssend, ldarr;
435  double *arr;
436  arr = _CU; ldarr = ld_CU; ssend = ncols_qA;
437  for(i=0;i<50;i++){
438  for(j=0;j<3;j++){
439  printf("%.3e ", arr[j*ldarr+i]);
440  }
441  printf("...");
442  for(j=ssend-4;j<ssend;j++){
443  printf("%.3e ", arr[j*ldarr+i]);
444  }
445  printf("\n");
446  }
447  printf("\n");
448  arr = _U; ldarr = ld_U; ssend = U_ncols;
449  for(i=0;i<50;i++){
450  for(j=0;j<3;j++){
451  printf("%.3e ", arr[j*ldarr+i]);
452  }
453  printf("...");
454  for(j=ssend-4;j<ssend;j++){
455  printf("%.3e ", arr[j*ldarr+i]);
456  }
457  printf("\n");
458  }
459  for(j=0;j<U_ncols;j++){
460  for(i=0; i < CU_nrows; i++){
461  double val = _U[j*ld_U+i];
462  if(val != val){
463  printf("%d,%d is nan (%g) CU_nrows:%d U_ncols:%d ld_U:%d\n", i, j, val, CU_nrows, U_ncols, ld_U);
464  }
465  }
466  }
467  }
468  exit(-1);
469  }
470  U_nrows = CU_nrows;
471  LAPACKE_dlacpy(LAPACK_COL_MAJOR, 'A', U_nrows, U_ncols,
472  _U, ld_U, _CU, ld_CU);
473  if(gemm_print_index){
474  printf(" NEWU\t|%d\t| copy U_nrows:%d U_ncols:%d ld_CU:%d ld_U:%d\n",
475  __LINE__, U_nrows, U_ncols, ld_CU, ld_U);
476  if(0){
477  int i, j;
478  for(i=0;i<100;i++){
479  for(j=0;j<3;j++){
480  printf("%.3e ", _CU[j*ld_CU+i]);
481  }
482  printf("...");
483  for(j=U_ncols-4;j<U_ncols;j++){
484  printf("%.3e ", _CU[j*ld_CU+i]);
485  }
486  printf("\n");
487  }
488  getc(stdin);
489  }
490  }
491 }
492 //void __newv(ka_matrix* _CV, double* qrtauB, ka_matrix* _V)
493 void __newv(
494  int _M,
495  int ncols_qB,
496  double* _CV, int ld_CV, int _Crk,
497  double* _V, int ld_V, int _Vrk,
498  double* qrtauB
499  ) {
500  int info;
501  int CV_nrows = _M;
502  int CV_ncols = ncols_qB;
503  int V_nrows = _Vrk; //ASSUMPTION
504  int V_ncols = ncols_qB; //ASSUMPTION
505  int ncols = CV_nrows - V_ncols;
506  double zero = 0.0;
507  if(gemm_print_index){
508  printf(" NEWV\t|%d\t| zero V_nrows:%d ncols:%d ld_V:%d _Crk:%d CV_ncols:%d _Vrk:%d CV_nrows:%d V_ncols:%d diff:%d\n",
509  __LINE__, V_nrows, ncols, ld_V, _Crk, CV_ncols, _Vrk, CV_nrows, V_ncols, ncols);
510  }
511  //LAPACKE_zlaset(LAPACK_COL_MAJOR, 'A', V_nrows, ncols,
512  // zero, zero, &(_V[V_ncols*ld_V]), ld_V);
513  char uplo = 'A';
514  size_t iv = V_ncols*ld_V;
515  LAPACK_dlaset( &uplo, &V_nrows, &ncols, &zero, &zero, &(_V[iv]), &ld_V );
516  if(gemm_print_mat){
517  hc_printmat(_V, _M, _M, ld_V);
518  }
519  if(gemm_print_index){
520  printf(" NEWV\t|%d\t| ormqr V_nrows:%d CV_nrows:%d ncols_qB:%d ld_CV:%d ld_V:%d\n",
521  __LINE__, V_nrows, CV_nrows, ncols_qB, ld_CV, ld_V);
522  }
523  //zunmqr
524  info = LAPACKE_dormqr(LAPACK_COL_MAJOR, 'R', 'T',
525  V_nrows, CV_nrows, ncols_qB, _CV, ld_CV, qrtauB, _V, ld_V);
526  if(gemm_print_mat){
527  hc_printmat(_V, _M, _M, ld_V);
528  }
529  if(info != 0){
530  fprintf(stderr,
531  "%s %d ERROR in LAPACKE_dormqr() info=%d\n",
532  __FILE__, __LINE__, info);
533  exit(-1);
534  }
535  V_ncols = CV_nrows;
536  CV_ncols = V_nrows;
537  if(gemm_print_index){
538  printf(" NEWV\t|%d\t| trans V_nrows:%d V_ncols:%d ld_V:%d ld_CV%d\n",
539  __LINE__, V_nrows, V_ncols, ld_V, ld_CV);
540  }
541  LAPACKE_dge_trans(LAPACK_COL_MAJOR, V_nrows, V_ncols,
542  _V, ld_V, _CV, ld_CV);
543  if(gemm_print_index){
544  printf(" NEWV\t|%d\t| copy V_nrows:%d V_ncols:%d ld_CV:%d ld_V:%d\n",
545  __LINE__, V_nrows, V_ncols, ld_CV, ld_V);
546  if(0){
547  int i, j;
548  for(i=0;i<100;i++){
549  for(j=0;j<3;j++){
550  printf("%.3e ", _CV[j*ld_CV+i]);
551  }
552  printf("...");
553  for(j=V_ncols-4;j<V_ncols;j++){
554  printf("%.3e ", _CV[j*ld_CV+i]);
555  }
556  printf("\n");
557  }
558  getc(stdin);
559  }
560  }
561  if(gemm_print_mat){
562  hc_printmat(_CV, _M, _M, ld_CV);
563  }
564 }
568 void HCORE_zgemm(MORSE_enum transA, int transB,
569  int M, int N,
570  double alpha,
571  double *AU,
572  double *AV,
573  double *Ark,
574  int LDA,
575  double *BU,
576  double *BV,
577  double *Brk,
578  int LDB,
579  double beta,
580  double *CU,
581  double *CV,
582  double *Crk,
583  int LDC,
584  int rk,
585  int maxrk,
586  double acc,
587  double* work
588 )
589 {
590  if(gemm_print_index){
591  printf("%d:%s work:%p ", __LINE__, __func__, work);
592  printf("M:%d N:%d LDA:%d LDB:%d LDC:%d rk:%d maxrk:%d acc:%e a:%e b:%e\n",
593  M, N, LDA, LDB, LDC, rk, maxrk, acc, alpha, beta);
594  }
595 
596  int new_Crk = 0;
597  /*
598  * NOTES:
599  * assumptions on matrix dimensions are marked as //ASSUMPTION
600  */
601  /*printf("%d %d|%g->%d. %g %g, %g %g\n", */
602  /*__LINE__, __COUNTER__,*/
603  /*Crk[0], new_Crk, CU[0], CU[1], CV[0], CV[1]);*/
604  //hcore_dgemm(Aij, ik, Ajk, -1, rank, acc);
605  int _Ark = (int)(Ark[0]);
606  int _Brk = (int)(Brk[0]);
607  int _Crk = (int)(Crk[0]);
608  if(_Ark == 0 || _Brk == 0 || _Crk == 0){
609  fprintf(stderr, "%s %d: _Ark=%d _Brk=%d _Crk=%d. These rank values should not be zero.\n", __FILE__, __LINE__, _Ark, _Brk, _Crk);
610  exit(-1);
611  //return MORSE_ERR_ILLEGAL_VALUE;
612  }
613 
614  int _M = M; int _N = N;
615  double* _CU = CU; int ld_CU = LDC;
616  double* _CV = CV; int ld_CV = LDC; int* pnew_Crk = &new_Crk;
617  double* _AU = AU; int ld_AU = LDA;
618  double* _AV = AV; int ld_AV = LDA;
619  double* _BU = BU; int ld_BU = LDB;
620  double* _BV = BV; int ld_BV = LDB;
621  int rank = rk;
622  { // FIXME remove these extra braces
623 
624 
625  //printf("%s %s %d maxrank=%d\n", __FILE__, __func__, __LINE__, maxrk);
626  char chall = 'A';
627  double* CUclone = NULL;
628  size_t CUclone_nelm = _M * 2 * maxrk;
629  int ld_CUclone = _M;
630  int use_CUV_clone = 1;
631  if(use_CUV_clone == 1) {
632  if(use_scratch){
633  CUclone = work;
634  work += CUclone_nelm;
635  } else {
636  CUclone = malloc(CUclone_nelm * sizeof(double));
637  }
638  LAPACK_dlacpy(&chall,
639  &_M, &_Crk,
640  _CU, &ld_CU,
641  CUclone, &ld_CUclone);
642  }
643 
644  double* CVclone = NULL;
645  size_t CVclone_nelm = _M * 2 * maxrk;
646  int ld_CVclone = _M;
647  if(use_CUV_clone == 1) {
648  if(use_scratch){
649  CVclone = work;
650  work += CVclone_nelm;
651  } else {
652  CVclone = malloc(CVclone_nelm * sizeof(double));
653  }
654  LAPACK_dlacpy(&chall,
655  &_M, &_Crk,
656  _CV, &ld_CV,
657  CVclone, &ld_CVclone);
658  }
659  double* _CU_save = _CU;
660  double* _CV_save = _CV;
661 
662  if(use_CUV_clone == 1) {
663  _CU = CUclone;
664  _CV = CVclone;
665  }
666  int nb = _M;//ASSUMPTION
667  double* qrtauA = NULL;
668  size_t qrtauA_nelm = nb;
669  if(use_scratch){
670  qrtauA = work;
671  } else {
672  qrtauA = malloc(qrtauA_nelm * sizeof(double));
673  }
674  assert(qrtauA != NULL);
675  double* qrtauB = NULL;
676  size_t qrtauB_nelm = nb;
677  if(use_scratch){
678  qrtauB = work + qrtauA_nelm;
679  } else {
680  qrtauB = malloc(qrtauB_nelm * sizeof(double));
681  }
682  assert(qrtauB != NULL);
683 
684 
685 
686  int CU_ncols = 0;
687  //Warning:In this function, there are assumptions
688  //on leading dimensions and number of rows/cols
689  __qra(_M, maxrk, _CU, ld_CU, _Crk, &CU_ncols, _AU, ld_AU, _Ark, alpha, beta, qrtauA);
690  //output: _CU contains [_CU _AU]. use _Crk+_Ark as number of cols
691  assert(CU_ncols == (_Crk + _Ark));
692 
693  int CV_ncols = 0;
694  //Warning:In this function, there are assumptions
695  //on leading dimensions and number of rows/cols
696  double* qrb_aubut = NULL;
697  size_t qrb_aubut_nelm = maxrk * maxrk;
698  if(use_scratch){
699  qrb_aubut = work + qrtauA_nelm + qrtauB_nelm;
700  } else {
701  qrb_aubut = malloc(qrb_aubut_nelm * sizeof(double));
702  }
703  __qrb(_M, maxrk, _CV, ld_CV, _Crk, &CV_ncols,
704  _AV, ld_AV, _Ark, _BU, ld_BU, _BV, ld_BV, _Brk, qrtauB, qrb_aubut);
705  if(CU_ncols == 0 || CV_ncols == 0){
706  fprintf(stderr, "%s %d: CU_ncols=%d CV_ncols=%d. These values should not be zero.\n", __FILE__, __LINE__, CU_ncols, CV_ncols);
707  exit(-1);
708  //return MORSE_ERR_ILLEGAL_VALUE;
709  }
710  if(use_scratch == 0) {
711  free(qrb_aubut);
712  }
713  assert(CU_ncols == CV_ncols);
714 
715 
716  int ld_newU = nb; //ASSUMPTION
717  int ld_newV = maxrk; //ASSUMPTION
718  int new_UVrk;
719  double* newU = NULL;
720  size_t newU_nelm = nb * maxrk;
721  if(use_scratch){
722  newU = work + qrtauA_nelm + qrtauB_nelm + qrb_aubut_nelm;
723  } else {
724  newU = malloc(newU_nelm * sizeof(double));
725  }
726  double* newV = NULL;
727  size_t newV_nelm = nb * maxrk;
728 
729  if(use_scratch){
730  newV = work + qrtauA_nelm + qrtauB_nelm + qrb_aubut_nelm + newU_nelm;
731  } else {
732  newV = malloc(newV_nelm * sizeof(double));
733  }
734  assert(newU != NULL);
735  assert(newV != NULL);
736  double *svd_rA = NULL;
737  int svd_rA_nrows = chameleon_min(_M, CU_ncols);
738  int svd_rA_ncols = CU_ncols; //ASSUMPTION
739  size_t svd_rA_nelm;
740  if(use_trmm == 1){ // allocate more because rA will be output of trmm(rA*rB^T)
741  svd_rA_nelm = svd_rA_nrows * svd_rA_nrows;
742  } else {
743  svd_rA_nelm = svd_rA_nrows * svd_rA_ncols;
744  }
745  if(use_scratch == 1) {
746  svd_rA = work + qrtauA_nelm + qrtauB_nelm + qrb_aubut_nelm + newU_nelm + newV_nelm;
747  } else {
748  svd_rA = malloc(svd_rA_nelm * sizeof(double));
749  }
750  double *svd_rB = NULL;
751  int svd_rB_nrows = chameleon_min(_M, CV_ncols);
752  int svd_rB_ncols = CV_ncols;
753  size_t svd_rB_nelm;
754  if(use_trmm == 0) {
755  svd_rB_nelm = svd_rB_nrows * svd_rB_ncols;
756  if(use_scratch == 1) {
757  svd_rB = work + qrtauA_nelm + qrtauB_nelm + qrb_aubut_nelm + newU_nelm + newV_nelm + svd_rA_nelm;
758  } else {
759  svd_rB = malloc(svd_rB_nelm * sizeof(double));
760  }
761  } else {
762  svd_rB_nelm = 0; //This variable will be used for calculating offset for work array
763  }
764  double *svd_T = NULL;
765  int svd_T_nrows = svd_rA_ncols;
766  int svd_T_ncols = svd_rA_ncols;
767  size_t svd_T_nelm;
768  if(use_trmm == 0) {
769  svd_T_nelm = svd_T_nrows * svd_T_ncols;
770  if(use_scratch == 1) {
771  svd_T = work + qrtauA_nelm + qrtauB_nelm + qrb_aubut_nelm + newU_nelm + newV_nelm + svd_rA_nelm + svd_rB_nelm;
772  } else {
773  svd_T = malloc(svd_T_nelm * sizeof(double));
774  }
775  } else {
776  svd_T_nelm = 0; //This variable will be used for calculating offset for work array
777  }
778 
779  double *svd_sigma = NULL;
780  double *svd_superb = NULL;
781  size_t svd_sigma_nelm = svd_T_nrows;
782  size_t svd_superb_nelm = svd_T_nrows;
783  if(use_scratch == 1){
784  svd_sigma = work + qrtauA_nelm + qrtauB_nelm + qrb_aubut_nelm + newU_nelm + newV_nelm + svd_rA_nelm + svd_rB_nelm + svd_T_nelm;
785  svd_superb = work + qrtauA_nelm + qrtauB_nelm + qrb_aubut_nelm + newU_nelm + newV_nelm + svd_rA_nelm + svd_rB_nelm + svd_T_nelm + svd_sigma_nelm;
786  } else {
787  svd_sigma = malloc(svd_sigma_nelm * sizeof(double));
788  svd_superb = malloc(svd_superb_nelm * sizeof(double));
789  }
790  if(ld_newV < CU_ncols){
791  fprintf(stderr, "%s %d: Increase maxrank. %d is not enough ld_newV:%d CU_ncols:%d\n", __FILE__, __LINE__, maxrk, ld_newV, CU_ncols);
792  exit(-1);
793  //return MORSE_ERR_ILLEGAL_VALUE;
794  }
795  __svd(
796  _M,
797  maxrk,
798  _CU, ld_CU,
799  _CV, ld_CV, CU_ncols,
800  newU, ld_newU,
801  newV, ld_newV, &new_UVrk ,
802  rank, acc,
803  svd_rA, svd_rB, svd_T, svd_sigma, svd_superb
804  );
805  if(use_scratch == 0) {
806  free(svd_rA);
807  }
808  if(use_scratch == 0 && use_trmm == 0){
809  free(svd_rB);
810  // free(svd_T);
811  }
812  int ncols_qA = CU_ncols;
813  //__newu(_CU, qrtauA, newU);
814  __newu(_M,
815  ncols_qA,
816  _CU, ld_CU, _Crk,
817  newU, ld_newU, new_UVrk,
818  qrtauA
819  );
820  //Warning: number of columns of _CU is new_UVrk now!
821 
822  int ncols_qB = CV_ncols;
823  //__newv(_CV, qrtauB, newV);
824  __newv(_M,
825  ncols_qB,
826  _CV, ld_CV, _Crk,
827  newV, ld_newV, new_UVrk,
828  qrtauB
829  );
830  *pnew_Crk = new_UVrk;
831  //Warning: number of columns of _CV is new_UVrk now!
832 
833  //printf("%d->%d\n", _Crk, *pnew_Crk);
834 
835  if(use_CUV_clone == 1) {
836  LAPACK_dlacpy(&chall,
837  &_M, &new_UVrk,
838  CUclone, &ld_CUclone,
839  _CU_save, &ld_CU
840  );
841 
842  LAPACK_dlacpy(&chall,
843  &_M, &new_UVrk,
844  CVclone, &ld_CVclone,
845  _CV_save, &ld_CV
846  );
847  if(use_scratch == 0) {
848  free(CUclone);
849  free(CVclone);
850  }
851  }
852 
853 
854  if(use_scratch == 0) {
855  free(qrtauA);
856  free(qrtauB);
857  free(newU);
858  free(newV);
859  free(svd_sigma);
860  free(svd_superb);
861  }
862 
863  } // FIXME remove these extra braces
864 
865  int old_Crk = Crk[0];
866  if(gemm_print_index){
867  printf("Ark:%d Brk:%d Crk[0]:%d %g RANK CHANGE: %d->%d\n",
868  _Ark, _Brk, Crk[0], Crk[0], old_Crk, *pnew_Crk);
869  }
870  Crk[0] = new_Crk;
871  if(gemm_print_index){
872  int casted_Crk = (int)(Crk[0]);
873  printf("casted_Crk:%d Ark:%d Brk:%d Crk[0]:%d %g RANK CHANGE: %d->%d\n",
874  casted_Crk, _Ark, _Brk, Crk[0], Crk[0], old_Crk, new_Crk);
875  }
876 }
877 
878 /***************************************************************************/
879 // For debugging precision conversion script
880  //dormqr
881  //zormqr
882  //ormqr
883  //LAPACKE_dormqr
884  //LAPACKE_zormqr //
885  //LAPACKE_ormqr //
886  //dunmqr
887  //zunmqr
888  //unmqr
889  //LAPACKE_dunmqr
890  //LAPACKE_zunmqr //
891  //LAPACKE_unmqr //
int hc_nelm_limit
Definition: hcore_zgemm.c:31
int gemm_print_index
Definition: hcore_zgemm.c:29
void __newu(int _M, int ncols_qA, double *_CU, int ld_CU, int _Crk, double *_U, int ld_U, int _Urk, double *qrtauA)
Definition: hcore_zgemm.c:385
#define A(m, n)
Definition: pzgemm.c:56
void __newv(int _M, int ncols_qB, double *_CV, int ld_CV, int _Crk, double *_V, int ld_V, int _Vrk, double *qrtauB)
Definition: hcore_zgemm.c:493
void __qrb(int _M, int maxrank, double *_CV, int ld_CV, int _Crk, int *pnew_CV_ncols, double *_AV, int ld_AV, int _Ark, double *_BU, int ld_BU, double *_BV, int ld_BV, int _Brk, double *qrtauB, double *AcolBcolT)
Definition: hcore_zgemm.c:123
void __qra(int _M, int maxrank, double *_CU, int ld_CU, int _Crk, int *pnew_CU_ncols, double *_AU, int ld_AU, int _Ark, double alpha, double beta, double *qrtauA)
Definition: hcore_zgemm.c:62
int gemm_print_mat
Definition: hcore_zgemm.c:30
void HCORE_zgemm(MORSE_enum transA, int transB, int M, int N, double alpha, double *AU, double *AV, double *Ark, int LDA, double *BU, double *BV, double *Brk, int LDB, double beta, double *CU, double *CV, double *Crk, int LDC, int rk, int maxrk, double acc, double *work)
Definition: hcore_zgemm.c:568
#define CBLAS_SADDR(_val)
Definition: hcore_zgemm.c:25
void __svd(int _M, int maxrank, double *_CU, int ld_CU, double *_CV, int ld_CV, int _Crk, double *_U, int ld_U, double *_V, int ld_V, int *pnew_UVrk, int rank, double acc, double *_rA, double *_rB, double *_T, double *sigma, double *svdsuperb)
Definition: hcore_zgemm.c:198
int uplo[2]
int use_trmm
Definition: hcore_zgemm.c:27
void hc_printmat(double *A, int m, int n, int ld)
Definition: hcore_zgemm.c:32
int use_scratch