Mercurial > hg > octave-lyh
diff liboctave/Sparse.cc @ 10479:ded9beac7582
optimize sparse matrix assembly
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Wed, 31 Mar 2010 10:03:55 +0200 (2010-03-31) |
parents | 197b096001b7 |
children | 19e1e4470e01 |
line wrap: on
line diff
--- a/liboctave/Sparse.cc +++ b/liboctave/Sparse.cc @@ -35,6 +35,7 @@ #include <vector> #include "Array.h" +#include "MArray.h" #include "Array-util.h" #include "Range.h" #include "idx-vector.h" @@ -551,6 +552,339 @@ } template <class T> +Sparse<T>::Sparse (const Array<T>& a, const idx_vector& r, + const idx_vector& c, octave_idx_type nr, + octave_idx_type nc, bool sum_terms) + : rep (nil_rep ()), dimensions (), idx (0), idx_count (0) +{ + if (nr < 0) + nr = r.extent (0); + else if (r.extent (nr) > nr) + (*current_liboctave_error_handler) ("sparse: row index %d out of bound %d", + r.extent (nr), nr); + + if (nc < 0) + nc = c.extent (0); + else if (c.extent (nc) > nc) + (*current_liboctave_error_handler) ("sparse: column index %d out of bound %d", + r.extent (nc), nc); + + if (--rep->count == 0) + delete rep; + rep = new SparseRep (nr, nc); + + dimensions = dim_vector (nr, nc); + + + octave_idx_type n = a.numel (), rl = r.length (nr), cl = c.length (nc); + bool a_scalar = n == 1; + if (a_scalar) + { + if (rl != 1) + n = rl; + else if (cl != 1) + n = cl; + } + + if ((rl != 1 && rl != n) || (cl != 1 && cl != n)) + (*current_liboctave_error_handler) ("sparse: dimension mismatch"); + + if (rl <= 1 && cl <= 1) + { + if (n == 1 && a(0) != T ()) + { + change_capacity (1); + xridx(0) = r(0); + xdata(0) = a(0); + for (octave_idx_type j = 0; j < nc; j++) + xcidx(j+1) = j >= c(0); + } + } + else if (a_scalar) + { + // This is completely specialized, because the sorts can be simplified. + T a0 = a(0); + if (cl == 1) + { + // Sparse column vector. Sort row indices. + idx_vector rs = r.sorted (); + + octave_quit (); + + const octave_idx_type *rd = rs.raw (); + // Count unique indices. + octave_idx_type new_nz = 1; + for (octave_idx_type i = 1; i < n; i++) + new_nz += rd[i-1] != rd[i]; + // Allocate result. + change_capacity (new_nz); + xcidx (1) = new_nz; + octave_idx_type *rri = ridx (); + T *rrd = data (); + + octave_quit (); + + octave_idx_type k = -1, l = -1; + + if (sum_terms) + { + // Sum repeated indices. + for (octave_idx_type i = 0; i < n; i++) + { + if (rd[i] != l) + { + l = rd[i]; + rri[++k] = rd[i]; + rrd[k] = a0; + } + else + rrd[k] += a0; + } + } + else + { + // Pick the last one. + for (octave_idx_type i = 1; i < n; i++) + { + if (rd[i] != l) + { + l = rd[i]; + rrd[++k] = a0; + rri[k] = rd[i]; + } + } + } + + } + else + { + idx_vector rr = r, cc = c; + const octave_idx_type *rd = rr.raw (), *cd = cc.raw (); + OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, ci, nc+1, 0); + ci[0] = 0; + // Bin counts of column indices. + for (octave_idx_type i = 0; i < n; i++) + ci[cd[i]+1]++; + // Make them cumulative, shifted one to right. + for (octave_idx_type i = 1, s = 0; i <= nc; i++) + { + octave_idx_type s1 = s + ci[i]; + ci[i] = s; + s = s1; + } + + octave_quit (); + + // Bucket sort. + OCTAVE_LOCAL_BUFFER (octave_idx_type, sidx, n); + for (octave_idx_type i = 0; i < n; i++) + sidx[ci[cd[i]+1]++] = rd[i]; + + // Subsorts. We don't need a stable sort, all values are equal. + xcidx(0) = 0; + for (octave_idx_type j = 0; j < nc; j++) + { + std::sort (sidx + ci[j], sidx + ci[j+1]); + octave_idx_type l = -1, nzj = 0; + // Count. + for (octave_idx_type i = ci[j]; i < ci[j+1]; i++) + { + octave_idx_type k = sidx[i]; + if (k != l) + { + l = k; + nzj++; + } + } + // Set column pointer. + xcidx(j+1) = xcidx(j) + nzj; + } + + change_capacity (xcidx (nc)); + octave_idx_type *rri = ridx (); + T *rrd = data (); + + // Fill-in data. + for (octave_idx_type j = 0, jj = -1; j < nc; j++) + { + octave_quit (); + octave_idx_type l = -1; + if (sum_terms) + { + // Sum adjacent terms. + for (octave_idx_type i = ci[j]; i < ci[j+1]; i++) + { + octave_idx_type k = sidx[i]; + if (k != l) + { + l = k; + rrd[++jj] = a0; + rri[jj] = k; + } + else + rrd[jj] += a0; + } + } + else + { + // Use the last one. + for (octave_idx_type i = ci[j]; i < ci[j+1]; i++) + { + octave_idx_type k = sidx[i]; + if (k != l) + { + l = k; + rrd[++jj] = a0; + rri[jj] = k; + } + } + } + } + } + } + else if (cl == 1) + { + // Sparse column vector. Sort row indices. + Array<octave_idx_type> rsi; + idx_vector rs = r.sorted (rsi); + + octave_quit (); + + const octave_idx_type *rd = rs.raw (), *rdi = rsi.data (); + // Count unique indices. + octave_idx_type new_nz = 1; + for (octave_idx_type i = 1; i < n; i++) + new_nz += rd[i-1] != rd[i]; + // Allocate result. + change_capacity (new_nz); + xcidx(1) = new_nz; + octave_idx_type *rri = ridx (); + T *rrd = data (); + + octave_quit (); + + octave_idx_type k = 0; + rri[k] = rd[0]; + rrd[k] = a(rdi[0]); + + if (sum_terms) + { + // Sum repeated indices. + for (octave_idx_type i = 1; i < n; i++) + { + if (rd[i] != rd[i-1]) + { + rri[++k] = rd[i]; + rrd[k] = a(rdi[i]); + } + else + rrd[k] += a(rdi[i]); + } + } + else + { + // Pick the last one. + for (octave_idx_type i = 1; i < n; i++) + { + if (rd[i] != rd[i-1]) + rri[++k] = rd[i]; + rrd[k] = a(rdi[i]); + } + } + } + else + { + idx_vector rr = r, cc = c; + const octave_idx_type *rd = rr.raw (), *cd = cc.raw (); + OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, ci, nc+1, 0); + ci[0] = 0; + // Bin counts of column indices. + for (octave_idx_type i = 0; i < n; i++) + ci[cd[i]+1]++; + // Make them cumulative, shifted one to right. + for (octave_idx_type i = 1, s = 0; i <= nc; i++) + { + octave_idx_type s1 = s + ci[i]; + ci[i] = s; + s = s1; + } + + octave_quit (); + + typedef std::pair<octave_idx_type, octave_idx_type> idx_pair; + // Bucket sort. + OCTAVE_LOCAL_BUFFER (idx_pair, spairs, n); + for (octave_idx_type i = 0; i < n; i++) + { + idx_pair& p = spairs[ci[cd[i]+1]++]; + p.first = rd[i]; + p.second = i; + } + + // Subsorts. We don't need a stable sort, the second index stabilizes it. + xcidx(0) = 0; + for (octave_idx_type j = 0; j < nc; j++) + { + std::sort (spairs + ci[j], spairs + ci[j+1]); + octave_idx_type l = -1, nzj = 0; + // Count. + for (octave_idx_type i = ci[j]; i < ci[j+1]; i++) + { + octave_idx_type k = spairs[i].first; + if (k != l) + { + l = k; + nzj++; + } + } + // Set column pointer. + xcidx(j+1) = xcidx(j) + nzj; + } + + change_capacity (xcidx (nc)); + octave_idx_type *rri = ridx (); + T *rrd = data (); + + // Fill-in data. + for (octave_idx_type j = 0, jj = -1; j < nc; j++) + { + octave_quit (); + octave_idx_type l = -1; + if (sum_terms) + { + // Sum adjacent terms. + for (octave_idx_type i = ci[j]; i < ci[j+1]; i++) + { + octave_idx_type k = spairs[i].first; + if (k != l) + { + l = k; + rrd[++jj] = a(spairs[i].second); + rri[jj] = k; + } + else + rrd[jj] += a(spairs[i].second); + } + } + else + { + // Use the last one. + for (octave_idx_type i = ci[j]; i < ci[j+1]; i++) + { + octave_idx_type k = spairs[i].first; + if (k != l) + { + l = k; + rri[++jj] = k; + } + rrd[jj] = a(spairs[i].second); + } + } + } + } +} + +template <class T> Sparse<T>::Sparse (const Array<T>& a) : dimensions (a.dims ()), idx (0), idx_count (0) {