diff src/DLD-FUNCTIONS/qr.cc @ 8547:d66c9b6e506a

imported patch qrupdate.diff
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 20 Jan 2009 21:16:42 +0100 (2009-01-20)
parents 81d6ab3ac93c
children a6edd5c23cb5
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/qr.cc
+++ b/src/DLD-FUNCTIONS/qr.cc
@@ -1,6 +1,8 @@
 /*
 
 Copyright (C) 1996, 1997, 1999, 2000, 2005, 2006, 2007 John W. Eaton
+Copyright (C) 2008, 2009 Jaroslav Hajek
+Copyright (C) 2008, 2009 VZLU Prague
 
 This file is part of Octave.
 
@@ -20,10 +22,6 @@
 
 */
 
-// The qrupdate, qrinsert, qrdelete and qrshift functions were written by
-// Jaroslav Hajek <highegg@gmail.com>, Copyright (C) 2008  VZLU
-// Prague, a.s., Czech Republic.
-
 #ifdef HAVE_CONFIG_H
 #include <config.h>
 #endif
@@ -741,6 +739,24 @@
 
 */
 
+#ifdef HAVE_QRUPDATE
+
+static
+bool check_qr_dims (const octave_value& q, const octave_value& r,
+                    bool allow_ecf = false)
+{
+  octave_idx_type m = q.rows (), k = r.rows (), n = r.columns ();
+  return ((q.ndims () == 2 || r.ndims () == 2 && k == q.columns ())
+            && (m == k || (allow_ecf && k == n && k < m)));
+}
+
+static 
+bool check_index (const octave_value& i, bool vector_allowed = false)
+{
+  return ((i.is_real_type () || i.is_integer_type ()) 
+          && (i.is_scalar_type () || vector_allowed));
+}
+
 DEFUN_DLD (qrupdate, args, ,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {[@var{Q1}, @var{R1}] =} qrupdate (@var{Q}, @var{R}, @var{u}, @var{v})\n\
@@ -748,10 +764,14 @@
 @w{@var{A} = @var{Q}*@var{R}}, @var{Q}@tie{}unitary and\n\
 @var{R}@tie{}upper trapezoidal, return the QR@tie{}factorization\n\
 of @w{@var{A} + @var{u}*@var{v}'}, where @var{u} and @var{v} are\n\
-column vectors (rank-1 update).\n\
+column vectors (rank-1 update) or matrices with equal number of columns\n\
+(rank-k update). Notice that the latter case is done as a sequence of rank-1 updates;\n\
+thus, for k large enough, it will be both faster and more accurate to recompute\n\
+the factorization from scratch.\n\
 \n\
-If the matrix @var{Q} is not square, the matrix @var{A} is updated by\n\
-Q*Q'*u*v' instead of u*v'.\n\
+The QR factorization supplied may be either full\n\
+(Q is square) or economized (R is square).\n\
+\n\
 @seealso{qr, qrinsert, qrdelete}\n\
 @end deftypefn")
 {
@@ -772,18 +792,12 @@
   if (argq.is_numeric_type () && argr.is_numeric_type () 
       && argu.is_numeric_type () && argv.is_numeric_type ())
     {
-      octave_idx_type m = argq.rows ();
-      octave_idx_type n = argr.columns ();
-      octave_idx_type k = argq.columns ();
-
-      if (argr.rows () == k
-          && argu.rows () == m && argu.columns () == 1
-          && argv.rows () == n && argv.columns () == 1)
+      if (check_qr_dims (argq, argr, true))
         {
-          if (argq.is_real_matrix () 
-	      && argr.is_real_matrix () 
-	      && argu.is_real_matrix () 
-	      && argv.is_real_matrix ())
+          if (argq.is_real_type () 
+	      && argr.is_real_type () 
+	      && argu.is_real_type () 
+	      && argv.is_real_type ())
             {
 	      // all real case
 	      if (argq.is_single_type () 
@@ -935,12 +949,19 @@
 @code{\"row\"}).\n\
 \n\
 The default value of @var{orient} is @code{\"col\"}.\n\
+If @var{orient} is @code{\"col\"},\n\
+@var{u} may be a matrix and @var{j} an index vector\n\
+resulting in the QR@tie{}factorization of a matrix @var{B} such that\n\
+@w{B(:,@var{j})} gives @var{u} and @w{B(:,@var{j}) = []} gives @var{A}.\n\
+Notice that the latter case is done as a sequence of k insertions;\n\
+thus, for k large enough, it will be both faster and more accurate to recompute\n\
+the factorization from scratch.\n\
 \n\
-If @var{orient} is @code{\"col\"} and the matrix @var{Q} is not square,\n\
-then what gets inserted is the projection of @var{u} onto the space\n\
-spanned by columns of @var{Q}, i.e. Q*Q'*u.\n\
+If @var{orient} is @code{\"col\"},\n\
+the QR factorization supplied may be either full\n\
+(Q is square) or economized (R is square).\n\
 \n\
-If @var{orient} is @code{\"row\"}, @var{Q} must be square.\n\
+If @var{orient} is @code{\"row\"}, full factorization is needed.\n\
 @seealso{qr, qrupdate, qrdelete}\n\
 @end deftypefn")
 {
@@ -959,30 +980,24 @@
   octave_value argx = args(3);
       
   if (argq.is_numeric_type () && argr.is_numeric_type ()
-      && argj.is_scalar_type () && argx.is_numeric_type ()
+      && argx.is_numeric_type ()
       && (nargin < 5 || args(4).is_string ()))
     {
-      octave_idx_type m = argq.rows ();
-      octave_idx_type n = argr.columns ();
-      octave_idx_type k = argq.columns ();
-
       std::string orient = (nargin < 5) ? "col" : args(4).string_value ();
 
-      bool row = orient == "row";
+      bool col = orient == "col";
 
-      if (row || orient == "col")
-        if (argr.rows () == k 
-            && (! row || m == k)
-            && argx.rows () == (row ? 1 : m)
-            && argx.columns () == (row ? n : 1))
+      if (col || orient == "row")
+        if (check_qr_dims (argq, argr, col) 
+            && (col || argx.rows () == 1))
           {
-            octave_idx_type j = argj.idx_type_value ();
-
-            if (j >= 1 && j <= (row ? n : m)+1)
+            if (check_index (argj, col))
               {
-                if (argq.is_real_matrix () 
-		    && argr.is_real_matrix () 
-		    && argx.is_real_matrix ())
+                MArray<octave_idx_type> j = argj.int_vector_value ();
+
+                if (argq.is_real_type () 
+		    && argr.is_real_type () 
+		    && argx.is_real_type ())
                   {
                     // real case
 		    if (argq.is_single_type () 
@@ -995,10 +1010,10 @@
 
 			FloatQR fact (Q, R);
 
-			if (row) 
-			  fact.insert_row (x, j-1);
+			if (col) 
+			  fact.insert_col (x, j-1);
 			else 
-			  fact.insert_col (x, j-1);
+			  fact.insert_row (x.row (0), j(0)-1);
 
 			retval(1) = fact.R ();
 			retval(0) = fact.Q ();
@@ -1012,10 +1027,10 @@
 
 			QR fact (Q, R);
 
-			if (row) 
-			  fact.insert_row (x, j-1);
+			if (col) 
+			  fact.insert_col (x, j-1);
 			else 
-			  fact.insert_col (x, j-1);
+			  fact.insert_row (x.row (0), j(0)-1);
 
 			retval(1) = fact.R ();
 			retval(0) = fact.Q ();
@@ -1035,10 +1050,10 @@
 
 			FloatComplexQR fact (Q, R);
 
-			if (row) 
-			  fact.insert_row (x, j-1);
+			if (col) 
+			  fact.insert_col (x, j-1);
 			else 
-			  fact.insert_col (x, j-1);
+			  fact.insert_row (x.row (0), j(0)-1);
 
 			retval(1) = fact.R ();
 			retval(0) = fact.Q ();
@@ -1051,10 +1066,10 @@
 
 			ComplexQR fact (Q, R);
 
-			if (row) 
-			  fact.insert_row (x, j-1);
+			if (col) 
+			  fact.insert_col (x, j-1);
 			else 
-			  fact.insert_col (x, j-1);
+			  fact.insert_row (x.row (0), j(0)-1);
 
 			retval(1) = fact.R ();
 			retval(0) = fact.Q ();
@@ -1063,7 +1078,7 @@
 
               }
             else
-              error ("qrinsert: index j out of range");
+              error ("qrinsert: invalid index");
           }
         else
           error ("qrinsert: dimension mismatch");
@@ -1150,18 +1165,19 @@
 \n\
 The default value of @var{orient} is \"col\".\n\
 \n\
-If @var{orient} is \"col\", the matrix @var{Q} is not required to\n\
-be square.\n\
+If @var{orient} is @code{\"col\"},\n\
+@var{j} may be an index vector\n\
+resulting in the QR@tie{}factorization of a matrix @var{B} such that\n\
+@w{A(:,@var{j}) = []} gives @var{B}.\n\
+Notice that the latter case is done as a sequence of k deletions;\n\
+thus, for k large enough, it will be both faster and more accurate to recompute\n\
+the factorization from scratch.\n\
 \n\
-For @sc{Matlab} compatibility, if @var{Q} is nonsquare on input, the\n\
-updated factorization is always stripped to the economical form, i.e.\n\
-@code{columns (Q) == rows (R) <= columns (R)}.\n\
+If @var{orient} is @code{\"col\"},\n\
+the QR factorization supplied may be either full\n\
+(Q is square) or economized (R is square).\n\
 \n\
-To get the less intelligent but more natural behaviour when @var{Q}\n\
-retains it shape and @var{R} loses one column, set @var{orient} to\n\
-\"col+\" instead.\n\
-\n\
-If @var{orient} is \"row\", @var{Q} must be square.\n\
+If @var{orient} is @code{\"row\"}, full factorization is needed.\n\
 @seealso{qr, qrinsert, qrupdate}\n\
 @end deftypefn")
 {
@@ -1179,27 +1195,21 @@
   octave_value argj = args(2);
 
   if (argq.is_numeric_type () && argr.is_numeric_type ()
-      && argj.is_scalar_type ()
       && (nargin < 4 || args(3).is_string ()))
     {
-      octave_idx_type m = argq.rows ();
-      octave_idx_type k = argq.columns ();
-      octave_idx_type n = argr.columns ();
-
       std::string orient = (nargin < 4) ? "col" : args(3).string_value ();
 
-      bool row = orient == "row";
-      bool colp = orient == "col+";
+      bool col = orient == "col";
 
-      if (row || colp || orient == "col")
-        if (argr.rows () == k
-            && (! row || m == k))
+      if (col || orient == "row")
+        if (check_qr_dims (argq, argr, col))
           {
-            octave_idx_type j = argj.scalar_value ();
-            if (j >= 1 && j <= (row ? n : m))
+            if (check_index (argj, col))
               {
-                if (argq.is_real_matrix ()
-		    && argr.is_real_matrix ())
+                MArray<octave_idx_type> j = argj.int_vector_value ();
+
+                if (argq.is_real_type ()
+		    && argr.is_real_type ())
                   {
                     // real case
 		    if (argq.is_single_type ()
@@ -1210,15 +1220,10 @@
 
 			FloatQR fact (Q, R);
 
-			if (row) 
-			  fact.delete_row (j-1);
+			if (col) 
+                          fact.delete_col (j-1);
 			else 
-			  {
-			    fact.delete_col (j-1);
-
-			    if (! colp && k < m)
-			      fact.economize ();
-			  }
+			  fact.delete_row (j(0)-1);
 
 			retval(1) = fact.R ();
 			retval(0) = fact.Q ();
@@ -1230,15 +1235,10 @@
 
 			QR fact (Q, R);
 
-			if (row) 
-			  fact.delete_row (j-1);
+			if (col) 
+                          fact.delete_col (j-1);
 			else 
-			  {
-			    fact.delete_col (j-1);
-
-			    if (! colp && k < m)
-			      fact.economize ();
-			  }
+			  fact.delete_row (j(0)-1);
 
 			retval(1) = fact.R ();
 			retval(0) = fact.Q ();
@@ -1255,15 +1255,10 @@
 
 			FloatComplexQR fact (Q, R);
 
-			if (row) 
-			  fact.delete_row (j-1);
+			if (col) 
+                          fact.delete_col (j-1);
 			else 
-			  {
-			    fact.delete_col (j-1);
-
-			    if (! colp && k < m)
-			      fact.economize ();
-			  }
+			  fact.delete_row (j(0)-1);
 
 			retval(1) = fact.R ();
 			retval(0) = fact.Q ();
@@ -1275,15 +1270,10 @@
 
 			ComplexQR fact (Q, R);
 
-			if (row) 
-			  fact.delete_row (j-1);
+			if (col) 
+                          fact.delete_col (j-1);
 			else 
-			  {
-			    fact.delete_col (j-1);
-
-			    if (! colp && k < m)
-			      fact.economize ();
-			  }
+			  fact.delete_row (j(0)-1);
 
 			retval(1) = fact.R ();
 			retval(0) = fact.Q ();
@@ -1291,7 +1281,7 @@
                   }
               }
             else
-              error ("qrdelete: index j out of range");
+              error ("qrdelete: invalid index");
           }
         else
           error ("qrdelete: dimension mismatch");
@@ -1439,20 +1429,17 @@
   octave_value argi = args(2);
   octave_value argj = args(3);
 
-  if (argq.is_numeric_type () && argr.is_numeric_type () 
-      && argi.is_real_scalar () && argj.is_real_scalar ())
+  if (argq.is_numeric_type () && argr.is_numeric_type ())
     {
-      octave_idx_type n = argr.columns ();
-      octave_idx_type k = argq.columns ();
+      if (check_qr_dims (argq, argr, true))
+        {
+          if (check_index (argi) && check_index (argj))
+            {
+              octave_idx_type i = argi.int_value ();
+              octave_idx_type j = argj.int_value ();
 
-      if (argr.rows () == k)
-        {
-          octave_idx_type i = argi.scalar_value ();
-          octave_idx_type j = argj.scalar_value ();
-          if (i > 1 && i <= n && j > 1 && j <= n)
-            {
-              if (argq.is_real_matrix () 
-                  && argr.is_real_matrix ())
+              if (argq.is_real_type () 
+                  && argr.is_real_type ())
                 {
                   // all real case
 		  if (argq.is_single_type () 
@@ -1508,7 +1495,7 @@
                 }
             }
           else
-            error ("qrshift: index out of range");
+            error ("qrshift: invalid index");
         }
       else
 	error ("qrshift: dimensions mismatch");
@@ -1593,6 +1580,8 @@
 %! assert(norm(vec(Q*R - AA(:,p)),Inf) < norm(AA)*1e1*eps('single'))
 */
 
+#endif
+
 /*
 ;;; Local Variables: ***
 ;;; mode: C++ ***