46 #include "hicma_common.h" 48 #include "control/common.h" 50 #include "coreblas/lapacke.h" 56 #define A(m, n) AUV, m, n 57 #define B(m, n) BUV, m, n 58 #define C(m, n) CUV, m, n 60 #define AUV(m, n) AUV, Ark, m, n 61 #define BUV(m, n) BUV, Brk, m, n 62 #define CUV(m, n) CUV, Crk, m, n 70 double alpha, MORSE_desc_t *
AUV, MORSE_desc_t *Ark,
72 MORSE_desc_t *
BUV, MORSE_desc_t *Brk,
73 double beta, MORSE_desc_t *
CUV, MORSE_desc_t *Crk,
75 MORSE_sequence_t *sequence, MORSE_request_t *request,
76 int rk,
int maxrk,
double acc)
78 MORSE_context_t *morse;
79 MORSE_option_t options;
82 int ldam, ldak, ldbn, ldbk, ldcm;
83 int tempmm, tempnn, tempkn, tempkm;
89 double zone = (double)1.0;
93 morse = morse_context_self();
94 if (sequence->status != MORSE_SUCCESS)
96 RUNTIME_options_init(&options, morse, sequence, request);
100 2 *
CUV->mb * 2 * maxrk
103 + 2 *
CUV->mb * maxrk
104 + (2*maxrk) * (2*maxrk)
108 #ifdef HCORE_GEMM_USE_ORGQR
115 int info = LAPACKE_dgesvd_work( LAPACK_COL_MAJOR,
'A',
'A',
120 NULL, 2*maxrk, &work_query, lwork );
121 lwork = (int)work_query;
124 ws_worker += (2*maxrk);
127 ws_worker *=
sizeof(double);
128 RUNTIME_options_ws_alloc( &options, ws_worker, ws_host );
130 for (m = 0; m <
CUV->mt; m++) {
132 ldcm = BLKLDD(
CUV, m);
133 for (n = 0; n <
CUV->nt; n++) {
138 if (transA == MorseNoTrans) {
139 ldam = BLKLDD(
AUV, m);
140 if (transB == MorseNoTrans) {
141 for (k = 0; k <
AUV->nt; k++) {
143 ldbk = BLKLDD(
BUV, k);
144 zbeta = k == 0 ? beta : zone;
149 alpha,
AUV(m, k), ldam,
151 zbeta,
CUV(m, n), ldcm,
154 RUNTIME_barrier(morse);
162 ldbn = BLKLDD(
BUV, n);
163 for (k = 0; k <
AUV->nt; k++) {
165 zbeta = k == 0 ? beta : zone;
170 alpha,
AUV(m, k), ldam,
172 zbeta,
CUV(m, n), ldcm,
175 RUNTIME_barrier(morse);
184 if (transB == MorseNoTrans) {
185 for (k = 0; k <
AUV->mt; k++) {
187 ldak = BLKLDD(
AUV, k);
188 ldbk = BLKLDD(
BUV, k);
189 zbeta = k == 0 ? beta : zone;
194 alpha,
AUV(k, m), ldak,
196 zbeta,
CUV(m, n), ldcm,
199 RUNTIME_barrier(morse);
207 ldbn = BLKLDD(
BUV, n);
208 for (k = 0; k <
AUV->mt; k++) {
210 ldak = BLKLDD(
AUV, k);
211 zbeta = k == 0 ? beta : zone;
216 alpha,
AUV(k, m), ldak,
218 zbeta,
CUV(m, n), ldcm,
221 RUNTIME_barrier(morse);
226 RUNTIME_data_flush( sequence,
C(m, n) );
228 if (transA == MorseNoTrans) {
229 for (k = 0; k <
AUV->nt; k++) {
231 RUNTIME_data_flush( sequence,
A(m, k) );
234 for (k = 0; k <
AUV->mt; k++) {
236 RUNTIME_data_flush( sequence,
A(k, m) );
243 RUNTIME_options_ws_free(&options);
244 RUNTIME_options_finalize(&options, morse);
void hicma_pzgemm(MORSE_enum transA, MORSE_enum transB, double alpha, MORSE_desc_t *AUV, MORSE_desc_t *Ark, MORSE_desc_t *BUV, MORSE_desc_t *Brk, double beta, MORSE_desc_t *CUV, MORSE_desc_t *Crk, MORSE_sequence_t *sequence, MORSE_request_t *request, int rk, int maxrk, double acc)
void HICMA_TASK_zgemm(const MORSE_option_t *options, MORSE_enum transA, int transB, int m, int n, double alpha, const MORSE_desc_t *AUV, const MORSE_desc_t *Ark, int Am, int An, int lda, const MORSE_desc_t *BUV, const MORSE_desc_t *Brk, int Bm, int Bn, int ldb, double beta, const MORSE_desc_t *CUV, const MORSE_desc_t *Crk, int Cm, int Cn, int ldc, int rk, int maxrk, double acc)
int HICMA_get_use_fast_hcore_zgemm()