16 #include "coreblas/include/coreblas.h" 17 #include "coreblas/lapacke.h" 20 #include <lapacke_utils.h> 27 #define CBLAS_SADDR(_val) (_val) 32 #define ECHO_I(_val) printf("%s(%d) ", #_val, (_val)); 33 #define ECHO_f(_val) printf("%s(%e) ", #_val, (_val)); 34 #define ECHO_LN printf("\n"); 41 #ifdef HCORE_GEMM_USE_KBLAS_ACA 42 extern int kblas_ACAf(
int m,
int n,
47 double maxacc,
int maxrk,
48 double* acc,
int* rk);
55 extern void hc_printmat(
double *
A,
int m,
int n,
int ld);
93 int _Ark = (int)(Ark[0]);
ECHO_I(_Ark);
94 int _Brk = (int)(Brk[0]);
ECHO_I(_Brk);
95 int _Crk = (int)(Crk[0]);
ECHO_I(_Crk);
99 int _M = M;
int _N = N;
ECHO_I(_M);
100 double* _CU = CU;
int ld_CU = LDC;
ECHO_I(ld_CU);
101 double* _CV = CV;
int ld_CV = LDC;
ECHO_I(ld_CV);
102 double* _AU = AU;
int ld_AU = LDA;
ECHO_I(ld_AU);
103 double* _AV = AV;
int ld_AV = LDA;
ECHO_I(ld_AV);
104 double* _BU = BU;
int ld_BU = LDB;
ECHO_I(ld_BU);
105 double* _BV = BV;
int ld_BV = LDB;
ECHO_I(ld_BV);
110 int use_CUV_clone = 0;
111 double* _CU_save = _CU;
112 double* _CV_save = _CV;
113 int ld_CU_save = ld_CU;
114 int ld_CV_save = ld_CV;
116 int CUV_ncols = _Crk + _Ark;
ECHO_I(CUV_ncols);
118 if((CUV_ncols > maxrk)){
119 double* CUclone = NULL;
121 double* CVclone = NULL;
123 size_t CUclone_nelm = _M * 2 * maxrk;
124 size_t CVclone_nelm = _M * 2 * maxrk;
128 d_work += CUclone_nelm;
129 ws_needed += CUclone_nelm;
131 d_work += CVclone_nelm;
132 ws_needed += CVclone_nelm;
133 LAPACK_dlacpy(&chall,
136 CUclone, &ld_CUclone);
137 LAPACK_dlacpy(&chall,
140 CVclone, &ld_CVclone);
156 int nelm_AU = _M * _Ark;
ECHO_I(nelm_AU);
ECHO_I(_Crk*ld_CU);
157 cblas_dcopy(nelm_AU, _AU, incOne, &_CU[_Crk*ld_CU], incOne);
160 cblas_dscal(nelm_AU,
CBLAS_SADDR(alpha), &_CU[_Crk*ld_CU], incOne);
164 cblas_dscal(_M * _Crk,
CBLAS_SADDR(beta), _CU, incOne);
167 double *qrtauA = d_work;
168 size_t qrtauA_nelm = _M;
169 d_work += qrtauA_nelm;
170 ws_needed += qrtauA_nelm;
171 assert(qrtauA != NULL);
173 int info = LAPACKE_dgeqrf( LAPACK_COL_MAJOR, _M, CUV_ncols, _CU, ld_CU, qrtauA);
177 double* qrb_avtbv = d_work;
178 size_t qrb_avtbv_nelm = maxrk * maxrk;
ECHO_I(qrb_avtbv_nelm);
179 d_work += qrb_avtbv_nelm;
180 ws_needed += qrb_avtbv_nelm;
183 cblas_dgemm(CblasColMajor,
184 CblasTrans, CblasNoTrans,
192 cblas_dgemm(CblasColMajor,
193 CblasNoTrans, CblasTrans,
200 double* qrtauB = d_work;
201 size_t qrtauB_nelm = _M;
202 d_work += qrtauB_nelm;
203 ws_needed += qrtauB_nelm;
204 assert(qrtauB != NULL);
206 info = LAPACKE_dgeqrf(LAPACK_COL_MAJOR, _M, CUV_ncols, _CV, ld_CV, qrtauB);
211 size_t rA_nelm = CUV_ncols * CUV_ncols;
ECHO_I(rA_nelm);
212 int ld_rA = CUV_ncols;
214 ws_needed += rA_nelm;
222 LAPACK_dlaset(&chlow, &CUV_ncols, &CUV_ncols, &d_zero, &d_zero, rA, &ld_rA);
224 LAPACK_dlacpy(&chup, &CUV_ncols, &CUV_ncols,
229 cblas_dtrmm(CblasColMajor, CblasRight, CblasUpper, CblasTrans, CblasNonUnit,
230 CUV_ncols, CUV_ncols,
234 int finalrank = -1, size_sigma = CUV_ncols;
ECHO_I(size_sigma)
235 double relacc = (acc);
240 #ifdef HCORE_GEMM_USE_ORGQR 241 size_t TU_nelm = CUV_ncols * CUV_ncols;
ECHO_I(TU_nelm)
242 int ld_TU = CUV_ncols;
244 size_t TU_nelm = _M * CUV_ncols;
ECHO_I(TU_nelm)
248 ws_needed += TU_nelm;
251 #ifdef HCORE_GEMM_USE_ORGQR 252 size_t TV_nelm = CUV_ncols * CUV_ncols;
253 int ld_TV = CUV_ncols;
255 size_t TV_nelm = _M * CUV_ncols;
256 #ifdef HCORE_GEMM_USE_KBLAS_ACA 259 int ld_TV = CUV_ncols;
264 ws_needed += TV_nelm;
266 double *d_sigma = d_work;
267 size_t d_sigma_nelm = CUV_ncols;
ECHO_I(d_sigma_nelm)
268 d_work += d_sigma_nelm;
269 ws_needed += d_sigma_nelm;
271 #if defined HCORE_GEMM_USE_KBLAS_ACA 274 kblas_ACAf( CUV_ncols, CUV_ncols,
280 &finalacc, &finalrank);
283 double* svdsuperb = d_work;
286 info = LAPACKE_dgesvd_work( LAPACK_COL_MAJOR,
'A',
'A',
287 CUV_ncols, CUV_ncols,
291 NULL, CUV_ncols, &work_query, lwork );
292 lwork = (int)work_query;
293 size_t svdsuperb_nelm = lwork;
294 d_work += svdsuperb_nelm;
295 ws_needed += svdsuperb_nelm;
297 info = LAPACKE_dgesvd_work( LAPACK_COL_MAJOR,
299 CUV_ncols, CUV_ncols,
307 double *h_sigma = d_sigma;
311 if(rank > size_sigma)
312 finalrank = size_sigma;
315 int newrank = size_sigma;
317 for(i=2;i<size_sigma;i++){
319 if(h_sigma[i] < relacc)
325 finalrank = newrank;
ECHO_I(finalrank)
330 for(k = 0; k < finalrank; k++){
331 double diagval = h_sigma[k];
333 cblas_dscal(CUV_ncols,
CBLAS_SADDR(diagval), &TV[k], ld_TV);
336 Crk[0] = (double)finalrank;
ECHO_f(Crk[0])
340 #if defined HCORE_GEMM_USE_ORGQR 341 double* newUV = d_work;
342 size_t newUV_nelm = _M * finalrank;
343 d_work += newUV_nelm;
345 info = LAPACKE_dorgqr( LAPACK_COL_MAJOR,
346 _M, CUV_ncols, CUV_ncols,
349 cblas_dgemm(CblasColMajor,
350 CblasNoTrans, CblasNoTrans,
351 _M, finalrank, CUV_ncols,
354 CBLAS_SADDR(d_zero), use_CUV_clone ? _CU_save : newUV, use_CUV_clone ? ld_CU_save : ld_CU);
357 LAPACKE_dlacpy(LAPACK_COL_MAJOR,
'A', _M, finalrank, newUV, ld_CU, _CU_save, ld_CU_save);
360 int nrows = _M - CUV_ncols;
361 int ncols = finalrank;
362 LAPACK_dlaset( &
uplo, &nrows, &ncols, &d_zero, &d_zero, &(TU[CUV_ncols]), &ld_TU );
364 info = LAPACKE_dormqr( LAPACK_COL_MAJOR,
366 _M, finalrank, CUV_ncols,
371 LAPACKE_dlacpy(LAPACK_COL_MAJOR,
'A', _M, finalrank, TU, ld_TU, _CU_save, ld_CU_save);
376 #ifdef HCORE_GEMM_USE_ORGQR 377 info = LAPACKE_dorgqr( LAPACK_COL_MAJOR,
378 _M, CUV_ncols, CUV_ncols,
382 cblas_dgemm(CblasColMajor,
384 #ifdef HCORE_GEMM_USE_KBLAS_ACA
389 _M, finalrank, CUV_ncols,
392 CBLAS_SADDR(d_zero), use_CUV_clone ? _CV_save : newUV, use_CUV_clone ? ld_CV_save : ld_CV);
395 LAPACKE_dlacpy(LAPACK_COL_MAJOR,
'A', _M, finalrank, newUV, ld_CV, _CV_save, ld_CV_save);
397 #ifdef HCORE_GEMM_USE_KBLAS_ACA 398 int TV_pad = CUV_ncols;
399 nrows = _M - CUV_ncols;
402 int TV_pad = CUV_ncols * ld_TV;
404 ncols = _M - CUV_ncols;
406 LAPACK_dlaset( &
uplo, &nrows, &ncols, &d_zero, &d_zero, &(TV[TV_pad]), &ld_TV );
408 info = LAPACKE_dormqr( LAPACK_COL_MAJOR,
409 #ifdef HCORE_GEMM_USE_KBLAS_ACA
411 _M, finalrank, CUV_ncols,
414 finalrank, _M, CUV_ncols,
420 #ifdef HCORE_GEMM_USE_KBLAS_ACA 421 LAPACKE_dlacpy(LAPACK_COL_MAJOR,
'A', _M, finalrank, TV, ld_TV, _CV_save, ld_CV_save);
423 LAPACKE_dge_trans(LAPACK_COL_MAJOR, finalrank, _M, TV, ld_TV, _CV_save, ld_CV_save);
void hc_printmat(double *A, int m, int n, int ld)
void HCORE_zgemm_fast(MORSE_enum transA, int transB, int M, int N, double alpha, double *AU, double *AV, double *Ark, int LDA, double *BU, double *BV, double *Brk, int LDB, double beta, double *CU, double *CV, double *Crk, int LDC, int rk, int maxrk, double acc, double *d_work)
#define CBLAS_SADDR(_val)