Mercurial > hg > octave-thorsten
changeset 10461:81067c72361f
optimize kron
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Thu, 25 Mar 2010 14:19:29 +0100 |
parents | 4975d63bb2df |
children | 97a8ef453440 |
files | src/ChangeLog src/DLD-FUNCTIONS/kron.cc |
diffstat | 2 files changed, 168 insertions(+), 142 deletions(-) [+] |
line wrap: on
line diff
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,7 @@ +2010-03-25 Jaroslav Hajek <highegg@gmail.com> + + * kron.cc (Fkron): Completely rewrite. + 2010-03-24 John W. Eaton <jwe@octave.org> * version.h.in (OCTAVE_BUGS_STATEMENT): Point to
--- a/src/DLD-FUNCTIONS/kron.cc +++ b/src/DLD-FUNCTIONS/kron.cc @@ -27,83 +27,87 @@ #endif #include "dMatrix.h" +#include "fMatrix.h" #include "CMatrix.h" +#include "fCMatrix.h" + +#include "dSparse.h" +#include "CSparse.h" + +#include "dDiagMatrix.h" +#include "fDiagMatrix.h" +#include "CDiagMatrix.h" +#include "fCDiagMatrix.h" + +#include "PermMatrix.h" + +#include "mx-inlines.cc" #include "quit.h" #include "defun-dld.h" #include "error.h" #include "oct-obj.h" -#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) -extern void -kron (const Array<double>&, const Array<double>&, Array<double>&); +template <class R, class T> +static MArray<T> +kron (const MArray<R>& a, const MArray<T>& b) +{ + assert (a.ndims () == 2); + assert (b.ndims () == 2); + + octave_idx_type nra = a.rows (), nrb = b.rows (); + octave_idx_type nca = a.cols (), ncb = b.cols (); -extern void -kron (const Array<Complex>&, const Array<Complex>&, Array<Complex>&); + MArray<T> c (nra*nrb, nca*ncb); + T *cv = c.fortran_vec (); + + for (octave_idx_type ja = 0; ja < nca; ja++) + for (octave_idx_type jb = 0; jb < ncb; jb++) + for (octave_idx_type ia = 0; ia < nra; ia++) + { + octave_quit (); + mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb); + cv += nrb; + } -extern void -kron (const Array<float>&, const Array<float>&, Array<float>&); + return c; +} + +template <class R, class T> +static MArray<T> +kron (const MDiagArray2<R>& a, const MArray<T>& b) +{ + assert (b.ndims () == 2); + + octave_idx_type nra = a.rows (), nrb = b.rows (), dla = a.diag_length (); + octave_idx_type nca = a.cols (), ncb = b.cols (); -extern void -kron (const Array<FlaotComplex>&, const Array<FloatComplex>&, - Array<FloatComplex>&); -#endif + MArray<T> c (nra*nrb, nca*ncb, T()); + + for (octave_idx_type ja = 0; ja < dla; ja++) + for (octave_idx_type jb = 0; jb < ncb; jb++) + { + octave_quit (); + mx_inline_mul (nrb, &c.xelem(ja*nrb, ja*ncb + jb), a.dgelem (ja), b.data () + nrb*jb); + } + + return c; +} template <class T> -void -kron (const Array<T>& A, const Array<T>& B, Array<T>& C) -{ - C.resize (A.rows () * B.rows (), A.columns () * B.columns ()); - - octave_idx_type Ac, Ar, Cc, Cr; - - for (Ac = Cc = 0; Ac < A.columns (); Ac++, Cc += B.columns ()) - for (Ar = Cr = 0; Ar < A.rows (); Ar++, Cr += B.rows ()) - { - const T v = A (Ar, Ac); - for (octave_idx_type Bc = 0; Bc < B.columns (); Bc++) - for (octave_idx_type Br = 0; Br < B.rows (); Br++) - { - OCTAVE_QUIT; - C.xelem (Cr+Br, Cc+Bc) = v * B.elem (Br, Bc); - } - } -} - -template void -kron (const Array<double>&, const Array<double>&, Array<double>&); - -template void -kron (const Array<Complex>&, const Array<Complex>&, Array<Complex>&); - -template void -kron (const Array<float>&, const Array<float>&, Array<float>&); - -template void -kron (const Array<FloatComplex>&, const Array<FloatComplex>&, - Array<FloatComplex>&); - -#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) -extern void -kron (const Sparse<double>&, const Sparse<double>&, Sparse<double>&); - -extern void -kron (const Sparse<Complex>&, const Sparse<Complex>&, Sparse<Complex>&); -#endif - -template <class T> -void -kron (const Sparse<T>& A, const Sparse<T>& B, Sparse<T>& C) +static MSparse<T> +kron (const MSparse<T>& A, const MSparse<T>& B) { octave_idx_type idx = 0; - C = Sparse<T> (A.rows () * B.rows (), A.columns () * B.columns (), - A.nzmax () * B.nzmax ()); + MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (), + A.nzmax () * B.nzmax ()); C.cidx (0) = 0; for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++) for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++) { + octave_quit (); for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++) { octave_idx_type Ci = A.ridx(Ai) * B.rows (); @@ -111,23 +115,67 @@ for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++) { - OCTAVE_QUIT; C.data (idx) = v * B.data (Bi); C.ridx (idx++) = Ci + B.ridx (Bi); } } C.cidx (Aj * B.columns () + Bj + 1) = idx; } + + return C; } -template void -kron (const Sparse<double>&, const Sparse<double>&, Sparse<double>&); +static PermMatrix +kron (const PermMatrix& a, const PermMatrix& b) +{ + octave_idx_type na = a.rows (), nb = b.rows (); + const octave_idx_type *pa = a.data (), *pb = b.data (); + PermMatrix c(na*nb); // Row permutation. + octave_idx_type *pc = c.fortran_vec (); -template void -kron (const Sparse<Complex>&, const Sparse<Complex>&, Sparse<Complex>&); + bool cola = a.is_col_perm (), colb = b.is_col_perm (); + if (cola && colb) + { + for (octave_idx_type i = 0; i < na; i++) + for (octave_idx_type j = 0; j < nb; j++) + pc[pa[i]*nb+pb[j]] = i*nb+j; + } + else if (cola) + { + for (octave_idx_type i = 0; i < na; i++) + for (octave_idx_type j = 0; j < nb; j++) + pc[pa[i]*nb+j] = i*nb+pb[j]; + } + else if (colb) + { + for (octave_idx_type i = 0; i < na; i++) + for (octave_idx_type j = 0; j < nb; j++) + pc[i*nb+pb[j]] = pa[i]*nb+j; + } + else + { + for (octave_idx_type i = 0; i < na; i++) + for (octave_idx_type j = 0; j < nb; j++) + pc[i*nb+j] = pa[i]*nb+pb[j]; + } + + return c; +} -DEFUN_DLD (kron, args, nargout, "-*- texinfo -*-\n\ +template <class MTA, class MTB> +octave_value +do_kron (const octave_value& a, const octave_value& b) +{ + MTA am = octave_value_extract<MTA> (a); + MTB bm = octave_value_extract<MTB> (b); + return octave_value (kron (am, bm)); +} + +#define ALL_TYPES(AMT, BMT) \ + } while (0) \ + +DEFUN_DLD (kron, args, , "-*- texinfo -*-\n\ @deftypefn {Loadable Function} {} kron (@var{a}, @var{b})\n\ Form the kronecker product of two matrices, defined block by block as\n\ \n\ @@ -147,96 +195,70 @@ @end example\n\ @end deftypefn") { - octave_value_list retval; + octave_value retval; int nargin = args.length (); - if (nargin != 2 || nargout > 1) + if (nargin == 2) { - print_usage (); - } - else if (args(0).is_sparse_type () || args(1).is_sparse_type ()) - { - if (args(0).is_complex_type () || args(1).is_complex_type ()) + octave_value a = args(0), b = args(1); + if (a.is_perm_matrix () && b.is_perm_matrix ()) + retval = do_kron<PermMatrix, PermMatrix> (a, b); + else if (a.is_diag_matrix ()) { - SparseComplexMatrix a (args(0).sparse_complex_matrix_value()); - SparseComplexMatrix b (args(1).sparse_complex_matrix_value()); - - if (! error_state) + if (b.is_diag_matrix () && a.rows () == a.columns () + && b.rows () == b.columns ()) + { + octave_value_list tmp; + tmp(0) = a.diag (); + tmp(1) = b.diag (); + tmp = Fkron (tmp, 1); + if (tmp.length () == 1) + retval = tmp(0).diag (); + } + else if (a.is_single_type () || b.is_single_type ()) + { + if (a.is_complex_type ()) + return do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b); + else if (b.is_complex_type ()) + return do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b); + else + return do_kron<FloatDiagMatrix, FloatMatrix> (a, b); + } + else { - SparseComplexMatrix c; - kron (a, b, c); - retval(0) = c; + if (a.is_complex_type ()) + return do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b); + else if (b.is_complex_type ()) + return do_kron<DiagMatrix, ComplexMatrix> (a, b); + else + return do_kron<DiagMatrix, Matrix> (a, b); } } + else if (a.is_sparse_type () || b.is_sparse_type ()) + { + if (args(0).is_complex_type () || args(1).is_complex_type ()) + return do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b); + else + return do_kron<SparseMatrix, SparseMatrix> (a, b); + } + else if (a.is_single_type () || b.is_single_type ()) + { + if (a.is_complex_type ()) + return do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b); + else if (b.is_complex_type ()) + return do_kron<FloatMatrix, FloatComplexMatrix> (a, b); + else + return do_kron<FloatMatrix, FloatMatrix> (a, b); + } else { - SparseMatrix a (args(0).sparse_matrix_value ()); - SparseMatrix b (args(1).sparse_matrix_value ()); - - if (! error_state) - { - SparseMatrix c; - kron (a, b, c); - retval (0) = c; - } - } - } - else - { - if (args(0).is_single_type () || args(1).is_single_type ()) - { - if (args(0).is_complex_type () || args(1).is_complex_type ()) - { - FloatComplexMatrix a (args(0).float_complex_matrix_value()); - FloatComplexMatrix b (args(1).float_complex_matrix_value()); - - if (! error_state) - { - FloatComplexMatrix c; - kron (a, b, c); - retval(0) = c; - } - } + if (a.is_complex_type ()) + return do_kron<ComplexMatrix, ComplexMatrix> (a, b); + else if (b.is_complex_type ()) + return do_kron<Matrix, ComplexMatrix> (a, b); else - { - FloatMatrix a (args(0).float_matrix_value ()); - FloatMatrix b (args(1).float_matrix_value ()); - - if (! error_state) - { - FloatMatrix c; - kron (a, b, c); - retval (0) = c; - } - } - } - else - { - if (args(0).is_complex_type () || args(1).is_complex_type ()) - { - ComplexMatrix a (args(0).complex_matrix_value()); - ComplexMatrix b (args(1).complex_matrix_value()); - - if (! error_state) - { - ComplexMatrix c; - kron (a, b, c); - retval(0) = c; - } - } - else - { - Matrix a (args(0).matrix_value ()); - Matrix b (args(1).matrix_value ()); - - if (! error_state) - { - Matrix c; - kron (a, b, c); - retval (0) = c; - } - } + return do_kron<Matrix, Matrix> (a, b); } }