# HG changeset patch # User Jaroslav Hajek # Date 1276863144 -7200 # Node ID 53253f796351b198e075431f16b6106964c042be # Parent 600bdfb0854022a879151ea6f334d4ba5cf1bdc5 make [] (hopefully) more Matlab compatible diff --git a/liboctave/ChangeLog b/liboctave/ChangeLog --- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,3 +1,9 @@ +2010-06-17 Jaroslav Hajek + + * dim-vector.cc (dim_vector::hvcat): New method. + * dim-vector.h (dim_vector::hvcat, dim_vector::cat): Update decls. + (dim_vector::empty_2d): New method. + 2010-06-17 Jaroslav Hajek * MatrixType.cc (matrix_real_probe): Use OCTAVE_LOCAL_BUFFER for diff --git a/liboctave/dim-vector.cc b/liboctave/dim-vector.cc --- a/liboctave/dim-vector.cc +++ b/liboctave/dim-vector.cc @@ -153,6 +153,16 @@ return new_dims; } +// This is the rule for cat(). cat(dim, A, B) works if one +// of the following holds, in this order: +// +// 1. size(A, k) == size(B, k) for all k != dim. +// In this case, size (C, dim) = size (A, dim) + size (B, dim) and +// other sizes remain intact. +// +// 2. A is 0x0, in which case B is the result +// 3. B is 0x0, in which case A is the result + bool dim_vector::concat (const dim_vector& dvb, int dim) { @@ -205,6 +215,42 @@ return match; } +// Rules for horzcat/vertcat are yet looser. +// two arrays A, B can be concatenated +// horizontally (dim = 2) or vertically (dim = 1) if one of the +// following holds, in this order: +// +// 1. cat(dim, A, B) works +// +// 2. A, B are 2D and one of them is an empty vector, in which +// case the result is the other one except if both of them +// are empty vectors, in which case the result is 0x0. + +bool +dim_vector::hvcat (const dim_vector& dvb, int dim) +{ + if (concat (dvb, dim)) + return true; + else if (length () == 2 && dvb.length () == 2) + { + bool e2dv = rep[0] + rep[1] == 1; + bool e2dvb = dvb(0) + dvb(1) == 1; + if (e2dvb) + { + if (e2dv) + *this = dim_vector (); + return true; + } + else if (e2dv) + { + *this = dvb; + return true; + } + } + + return false; +} + dim_vector dim_vector::redim (int n) const { diff --git a/liboctave/dim-vector.h b/liboctave/dim-vector.h --- a/liboctave/dim-vector.h +++ b/liboctave/dim-vector.h @@ -287,6 +287,12 @@ return retval; } + bool empty_2d (void) const + { + return length () == 2 && (elem (0) == 0 || elem (1) == 0); + } + + bool zero_by_zero (void) const { return length () == 2 && elem (0) == 0 && elem (1) == 0; @@ -355,7 +361,12 @@ dim_vector squeeze (void) const; - bool concat (const dim_vector& dvb, int dim = 0); + // This corresponds to cat(). + bool concat (const dim_vector& dvb, int dim); + + // This corresponds to [,] (horzcat, dim = 0) and [;] (vertcat, dim = 1). + // The rules are more relaxed here. + bool hvcat (const dim_vector& dvb, int dim); // Force certain dimensionality, preserving numel (). Missing // dimensions are set to 1, redundant are folded into the trailing diff --git a/src/ChangeLog b/src/ChangeLog --- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,12 @@ +2010-06-18 Jaroslav Hajek + + * pt-mat.cc (tm_row_const::eval_error): Make a static func. + (tm_row_const::do_init_element): Simplify using dim_vector::hvcat. + (tm_const::init): Ditto. + (single_type_concat): Special-case empty results. Skip or use 0x0 for + empty arrays otherwise. + (tree_matrix::rvalue1): Skip empty arrays in the fallback branch. + 2010-06-16 Rik * DLD-FUNCTIONS/cellfun.cc, DLD-FUNCTIONS/dot.cc, diff --git a/src/pt-mat.cc b/src/pt-mat.cc --- a/src/pt-mat.cc +++ b/src/pt-mat.cc @@ -110,9 +110,6 @@ tm_row_const_rep& operator = (const tm_row_const_rep&); - void eval_error (const char *msg, int l, int c, - int x = -1, int y = -1) const; - void eval_warning (const char *msg, int l, int c) const; }; @@ -257,79 +254,47 @@ return retval; } +static void +eval_error (const char *msg, int l, int c, + const dim_vector& x, const dim_vector& y) +{ + if (l == -1 && c == -1) + { + ::error ("%s (%s vs %s)", msg, x.str ().c_str (), y.str ().c_str ()); + } + else + { + ::error ("%s (%s vs %s) near line %d, column %d", msg, + x.str ().c_str (), y.str ().c_str (), l, c); + } +} + bool tm_row_const::tm_row_const_rep::do_init_element (tree_expression *elt, const octave_value& val, bool& first_elem) { - octave_idx_type this_elt_nr = val.rows (); - octave_idx_type this_elt_nc = val.columns (); - std::string this_elt_class_nm = val.class_name (); dim_vector this_elt_dv = val.dims (); class_nm = get_concat_class (class_nm, this_elt_class_nm); - - if (! this_elt_dv.all_zero ()) + if (! this_elt_dv.zero_by_zero ()) { all_mt = false; if (first_elem) { first_elem = false; - - dv.resize (this_elt_dv.length ()); - for (int i = 2; i < dv.length (); i++) - dv.elem (i) = this_elt_dv.elem (i); - - dv.elem (0) = this_elt_nr; - - dv.elem (1) = 0; + dv = this_elt_dv; } - else + else if (! dv.hvcat (this_elt_dv, 1)) { - int len = (this_elt_dv.length () < dv.length () - ? this_elt_dv.length () : dv.length ()); - - if (this_elt_nr != dv (0)) - { - eval_error ("number of rows must match", - elt->line (), elt->column (), this_elt_nr, dv (0)); - return false; - } - for (int i = 2; i < len; i++) - { - if (this_elt_dv (i) != dv (i)) - { - eval_error ("dimensions mismatch", elt->line (), elt->column (), this_elt_dv (i), dv (i)); - return false; - } - } - - if (this_elt_dv.length () > len) - for (int i = len; i < this_elt_dv.length (); i++) - if (this_elt_dv (i) != 1) - { - eval_error ("dimensions mismatch", elt->line (), elt->column (), this_elt_dv (i), 1); - return false; - } - - if (dv.length () > len) - for (int i = len; i < dv.length (); i++) - if (dv (i) != 1) - { - eval_error ("dimensions mismatch", elt->line (), elt->column (), 1, dv (i)); - return false; - } + eval_error ("horizontal dimensions mismatch", elt->line (), elt->column (), dv, this_elt_dv); + return false; } - dv.elem (1) = dv.elem (1) + this_elt_nc; - } - else - eval_warning ("empty matrix found in matrix list", - elt->line (), elt->column ()); append (val); @@ -413,26 +378,6 @@ } void -tm_row_const::tm_row_const_rep::eval_error (const char *msg, int l, - int c, int x, int y) const -{ - if (l == -1 && c == -1) - { - if (x == -1 || y == -1) - ::error ("%s", msg); - else - ::error ("%s (%d != %d)", msg, x, y); - } - else - { - if (x == -1 || y == -1) - ::error ("%s near line %d, column %d", msg, l, c); - else - ::error ("%s (%d != %d) near line %d, column %d", msg, x, y, l, c); - } -} - -void tm_row_const::tm_row_const_rep::eval_warning (const char *msg, int l, int c) const { @@ -576,85 +521,33 @@ octave_idx_type this_elt_nc = elt.cols (); std::string this_elt_class_nm = elt.class_name (); + class_nm = get_concat_class (class_nm, this_elt_class_nm); dim_vector this_elt_dv = elt.dims (); - if (!this_elt_dv.all_zero ()) - { - all_mt = false; - - if (first_elem) - { - first_elem = false; - - class_nm = this_elt_class_nm; - - dv.resize (this_elt_dv.length ()); - for (int i = 2; i < dv.length (); i++) - dv.elem (i) = this_elt_dv.elem (i); - - dv.elem (0) = 0; + all_mt = false; - dv.elem (1) = this_elt_nc; - } - else if (all_str) - { - class_nm = get_concat_class (class_nm, this_elt_class_nm); - - if (this_elt_nc > cols ()) - dv.elem (1) = this_elt_nc; - } - else - { - class_nm = get_concat_class (class_nm, this_elt_class_nm); - - bool get_out = false; - int len = (this_elt_dv.length () < dv.length () - ? this_elt_dv.length () : dv.length ()); + if (first_elem) + { + first_elem = false; - for (int i = 1; i < len; i++) - { - if (i == 1 && this_elt_nc != dv (1)) - { - ::error ("number of columns must match (%d != %d)", - this_elt_nc, dv (1)); - get_out = true; - break; - } - else if (this_elt_dv (i) != dv (i)) - { - ::error ("dimensions mismatch (dim = %i, %d != %d)", i+1, this_elt_dv (i), dv (i)); - get_out = true; - break; - } - } - - if (this_elt_dv.length () > len) - for (int i = len; i < this_elt_dv.length (); i++) - if (this_elt_dv (i) != 1) - { - ::error ("dimensions mismatch (dim = %i, %d != %d)", i+1, this_elt_dv (i), 1); - get_out = true; - break; - } - - if (dv.length () > len) - for (int i = len; i < dv.length (); i++) - if (dv (i) != 1) - { - ::error ("dimensions mismatch (dim = %i, %d != %d)", i+1, 1, dv(i)); - get_out = true; - break; - } - - if (get_out) - break; - } - dv.elem (0) = dv.elem (0) + this_elt_nr; + dv = this_elt_dv; } - else - warning_with_id ("Octave:empty-list-elements", - "empty matrix found in matrix list"); + else if (all_str && dv.length () == 2 + && this_elt_dv.length () == 2) + { + // FIXME: this is Octave's specialty. Character matrices allow + // rows of unequal length. + if (this_elt_nc > cols ()) + dv(1) = this_elt_nc; + dv(0) += this_elt_nr; + } + else if (! dv.hvcat (this_elt_dv, 0)) + { + eval_error ("vertical dimensions mismatch", -1, -1, + dv, this_elt_dv); + return; + } } } @@ -734,6 +627,9 @@ for (tm_const::iterator p = tmp.begin (); p != tmp.end (); p++) { tm_row_const row = *p; + // Skip empty arrays to allow looser rules. + if (row.dims ().any_zero ()) + continue; for (tm_row_const::iterator q = row.begin (); q != row.end (); @@ -743,14 +639,18 @@ TYPE ra = octave_value_extract (*q); + // Skip empty arrays to allow looser rules. if (! error_state) { - result.insert (ra, r, c); + if (! ra.is_empty ()) + { + result.insert (ra, r, c); - if (! error_state) - c += ra.columns (); - else - return; + if (! error_state) + c += ra.columns (); + else + return; + } } else return; @@ -767,6 +667,12 @@ const dim_vector& dv, tm_const& tmp) { + if (dv.any_zero ()) + { + result = Array (dv); + return; + } + if (tmp.length () == 1) { // If possible, forward the operation to liboctave. @@ -781,7 +687,10 @@ { octave_quit (); - array_list[i++] = octave_value_extract (*q); + // Use 0x0 in place of all empty arrays to allow looser rules. + if (! q->is_empty ()) + array_list[i] = octave_value_extract (*q); + i++; } if (! error_state) @@ -797,9 +706,15 @@ template static void single_type_concat (Sparse& result, - const dim_vector&, + const dim_vector& dv, tm_const& tmp) { + if (dv.any_zero ()) + { + result = Sparse (dv); + return; + } + // Sparse matrices require preallocation for efficient indexing; besides, // only horizontal concatenation can be efficiently handled by indexing. // So we just cat all rows through liboctave, then cat the final column. @@ -817,10 +732,17 @@ { octave_quit (); - sparse_list[i++] = octave_value_extract (*q); + // Use 0x0 in place of all empty arrays to allow looser rules. + if (! q->is_empty ()) + sparse_list[i] = octave_value_extract (*q); + i++; } - sparse_row_list[j++] = Sparse::cat (1, ncols, sparse_list); + Sparse stmp = Sparse::cat (1, ncols, sparse_list); + // Use 0x0 in place of all empty arrays to allow looser rules. + if (! stmp.is_empty ()) + sparse_row_list[j] = stmp; + j++; } result = Sparse::cat (0, nrows, sparse_row_list); @@ -1089,6 +1011,9 @@ octave_value elt = *q; + if (elt.is_empty ()) + continue; + ctmp = do_cat_op (ctmp, elt, ra_idx); if (error_state) diff --git a/test/test_unwind.m b/test/test_unwind.m --- a/test/test_unwind.m +++ b/test/test_unwind.m @@ -52,5 +52,5 @@ %! end_unwind_protect %!test %! global g = -1; -%! fail("y = f (3);","number of columns must match"); +%! fail("y = f (3);","mismatch");