Mercurial > hg > octave-jordi
diff liboctave/CMatrix.cc @ 9665:1dba57e9d08d
use blas_trans_type for xgemm
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Sat, 26 Sep 2009 10:41:07 +0200 |
parents | 7e5b4de5fbfe |
children | f80c566bc751 |
line wrap: on
line diff
--- a/liboctave/CMatrix.cc +++ b/liboctave/CMatrix.cc @@ -3784,20 +3784,19 @@ // the general GEMM operation ComplexMatrix -xgemm (bool transa, bool conja, const ComplexMatrix& a, - bool transb, bool conjb, const ComplexMatrix& b) +xgemm (const ComplexMatrix& a, const ComplexMatrix& b, + blas_trans_type transa, blas_trans_type transb) { ComplexMatrix retval; - // conjugacy is ignored if no transpose - conja = conja && transa; - conjb = conjb && transb; - - octave_idx_type a_nr = transa ? a.cols () : a.rows (); - octave_idx_type a_nc = transa ? a.rows () : a.cols (); - - octave_idx_type b_nr = transb ? b.cols () : b.rows (); - octave_idx_type b_nc = transb ? b.rows () : b.cols (); + bool tra = transa != blas_no_trans, trb = transb != blas_no_trans; + bool cja = transa == blas_conj_trans, cjb = transb == blas_conj_trans; + + octave_idx_type a_nr = tra ? a.cols () : a.rows (); + octave_idx_type a_nc = tra ? a.rows () : a.cols (); + + octave_idx_type b_nr = trb ? b.cols () : b.rows (); + octave_idx_type b_nc = trb ? b.rows () : b.cols (); if (a_nc != b_nr) gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); @@ -3805,18 +3804,18 @@ { if (a_nr == 0 || a_nc == 0 || b_nc == 0) retval = ComplexMatrix (a_nr, b_nc, 0.0); - else if (a.data () == b.data () && a_nr == b_nc && transa != transb) + else if (a.data () == b.data () && a_nr == b_nc && tra != trb) { octave_idx_type lda = a.rows (); retval = ComplexMatrix (a_nr, b_nc); Complex *c = retval.fortran_vec (); - const char *ctransa = get_blas_trans_arg (transa, conja); - if (conja || conjb) + const char *ctra = get_blas_trans_arg (tra, cja); + if (cja || cjb) { F77_XFCN (zherk, ZHERK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3828,7 +3827,7 @@ else { F77_XFCN (zsyrk, ZSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3850,38 +3849,38 @@ if (b_nc == 1 && a_nr == 1) { - if (conja == conjb) + if (cja == cjb) { F77_FUNC (xzdotu, XZDOTU) (a_nc, a.data (), 1, b.data (), 1, *c); - if (conja) *c = std::conj (*c); + if (cja) *c = std::conj (*c); } - else if (conja) + else if (cja) F77_FUNC (xzdotc, XZDOTC) (a_nc, a.data (), 1, b.data (), 1, *c); else F77_FUNC (xzdotc, XZDOTC) (a_nc, b.data (), 1, a.data (), 1, *c); } - else if (b_nc == 1 && ! conjb) + else if (b_nc == 1 && ! cjb) { - const char *ctransa = get_blas_trans_arg (transa, conja); - F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + const char *ctra = get_blas_trans_arg (tra, cja); + F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1), lda, tda, 1.0, a.data (), lda, b.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } - else if (a_nr == 1 && ! conja && ! conjb) + else if (a_nr == 1 && ! cja && ! cjb) { - const char *crevtransb = get_blas_trans_arg (! transb, conjb); - F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + const char *crevtrb = get_blas_trans_arg (! trb, cjb); + F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1), ldb, tdb, 1.0, b.data (), ldb, a.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } else { - const char *ctransa = get_blas_trans_arg (transa, conja); - const char *ctransb = get_blas_trans_arg (transb, conjb); - F77_XFCN (zgemm, ZGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), - F77_CONST_CHAR_ARG2 (ctransb, 1), + const char *ctra = get_blas_trans_arg (tra, cja); + const char *ctrb = get_blas_trans_arg (trb, cjb); + F77_XFCN (zgemm, ZGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1), + F77_CONST_CHAR_ARG2 (ctrb, 1), a_nr, b_nc, a_nc, 1.0, a.data (), lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3896,7 +3895,7 @@ ComplexMatrix operator * (const ComplexMatrix& a, const ComplexMatrix& b) { - return xgemm (false, false, a, false, false, b); + return xgemm (a, b); } // FIXME -- it would be nice to share code among the min/max