diff liboctave/SparsedbleLU.cc @ 7515:f3c00dc0912b

Eliminate the rest of the dispatched sparse functions
author David Bateman <dbateman@free.fr>
date Fri, 22 Feb 2008 15:50:51 +0100 (2008-02-22)
parents a1dbe9d80eee
children b166043585a8
line wrap: on
line diff
--- a/liboctave/SparsedbleLU.cc
+++ b/liboctave/SparsedbleLU.cc
@@ -41,7 +41,7 @@
 
 #include "oct-sparse.h"
 
-SparseLU::SparseLU (const SparseMatrix& a, double piv_thres)
+SparseLU::SparseLU (const SparseMatrix& a, const Matrix& piv_thres, bool scale)
 {
 #ifdef HAVE_UMFPACK
   octave_idx_type nr = a.rows ();
@@ -56,20 +56,24 @@
   if (!xisnan (tmp))
     Control (UMFPACK_PRL) = tmp;
 
-  if (piv_thres >= 0.)
+  if (piv_thres.nelem() != 2)
     {
-      piv_thres = (piv_thres > 1. ? 1. : piv_thres);
-      Control (UMFPACK_SYM_PIVOT_TOLERANCE) = piv_thres;
-      Control (UMFPACK_PIVOT_TOLERANCE) = piv_thres;
+      tmp = (piv_thres (0) > 1. ? 1. : piv_thres (0));
+      if (!xisnan (tmp))
+	Control (UMFPACK_PIVOT_TOLERANCE) = tmp;
+      tmp = (piv_thres (1) > 1. ? 1. : piv_thres (1));
+      if (!xisnan (tmp))
+	Control (UMFPACK_SYM_PIVOT_TOLERANCE) = tmp;
     }
   else
     {
       tmp = octave_sparse_params::get_key ("piv_tol");
       if (!xisnan (tmp))
-	{
+	Control (UMFPACK_PIVOT_TOLERANCE) = tmp;
+
+      tmp = octave_sparse_params::get_key ("sym_tol");
+      if (!xisnan (tmp))
 	  Control (UMFPACK_SYM_PIVOT_TOLERANCE) = tmp;
-	  Control (UMFPACK_PIVOT_TOLERANCE) = tmp;
-	}
     }
 
   // Set whether we are allowed to modify Q or not
@@ -77,8 +81,10 @@
   if (!xisnan (tmp))
     Control (UMFPACK_FIXQ) = tmp;
 
-  // Turn-off UMFPACK scaling for LU 
-  Control (UMFPACK_SCALE) = UMFPACK_SCALE_NONE;
+  if (scale)
+    Control (UMFPACK_SCALE) = UMFPACK_SCALE_SUM;
+  else
+    Control (UMFPACK_SCALE) = UMFPACK_SCALE_NONE;
 
   UMFPACK_DNAME (report_control) (control);
 
@@ -167,6 +173,15 @@
 	      octave_idx_type *Uj = Ufact.ridx ();
 	      double *Ux = Ufact.data ();
 
+	      Rfact = SparseMatrix (nr, nr, nr);
+	      for (octave_idx_type i = 0; i < nr; i++)
+		{
+		  Rfact.xridx (i) = i;
+		  Rfact.xcidx (i) = i;
+		}
+	      Rfact.xcidx (nr) = nr;
+	      double *Rx = Rfact.data ();
+
 	      P.resize (nr);
 	      octave_idx_type *p = P.fortran_vec ();
 
@@ -176,12 +191,12 @@
 	      octave_idx_type do_recip;
 	      status = UMFPACK_DNAME (get_numeric) (Ltp, Ltj, Ltx,
 					       Up, Uj, Ux, p, q, NULL,
-					       &do_recip, NULL, 
+					       &do_recip, Rx, 
 					       Numeric) ;
 
 	      UMFPACK_DNAME (free_numeric) (&Numeric) ;
 
-	      if (status < 0 || do_recip)
+	      if (status < 0)
 		{
 		  (*current_liboctave_error_handler) 
 		    ("SparseLU::SparseLU extracting LU factors failed");
@@ -192,6 +207,10 @@
 		{
 		  Lfact = Lfact.transpose ();
 
+		  if (do_recip)
+		    for (octave_idx_type i = 0; i < nr; i++)
+		      Rx[i] = 1.0 / Rx[i];
+
 		  UMFPACK_DNAME (report_matrix) (nr, n_inner, 
 					    Lfact.cidx (), Lfact.ridx (),
 					    Lfact.data (), 1, control);
@@ -212,8 +231,8 @@
 }
 
 SparseLU::SparseLU (const SparseMatrix& a, const ColumnVector& Qinit,
-		    double piv_thres, bool FixedQ, double droptol,
-		    bool milu, bool udiag)
+		    const Matrix& piv_thres, bool scale, bool FixedQ,
+		    double droptol, bool milu, bool udiag)
 {
 #ifdef HAVE_UMFPACK
   if (milu)
@@ -232,20 +251,25 @@
       double tmp = octave_sparse_params::get_key ("spumoni");
       if (!xisnan (tmp))
 	Control (UMFPACK_PRL) = tmp;
-      if (piv_thres >= 0.)
+
+      if (piv_thres.nelem() != 2)
 	{
-	  piv_thres = (piv_thres > 1. ? 1. : piv_thres);
-	  Control (UMFPACK_SYM_PIVOT_TOLERANCE) = piv_thres;
-	  Control (UMFPACK_PIVOT_TOLERANCE) = piv_thres;
+	  tmp = (piv_thres (0) > 1. ? 1. : piv_thres (0));
+	  if (!xisnan (tmp))
+	    Control (UMFPACK_PIVOT_TOLERANCE) = tmp;
+	  tmp = (piv_thres (1) > 1. ? 1. : piv_thres (1));
+	  if (!xisnan (tmp))
+	    Control (UMFPACK_SYM_PIVOT_TOLERANCE) = tmp;
 	}
       else
 	{
 	  tmp = octave_sparse_params::get_key ("piv_tol");
 	  if (!xisnan (tmp))
-	    {
-	      Control (UMFPACK_SYM_PIVOT_TOLERANCE) = tmp;
-	      Control (UMFPACK_PIVOT_TOLERANCE) = tmp;
-	    }
+	    Control (UMFPACK_PIVOT_TOLERANCE) = tmp;
+
+	  tmp = octave_sparse_params::get_key ("sym_tol");
+	  if (!xisnan (tmp))
+	    Control (UMFPACK_SYM_PIVOT_TOLERANCE) = tmp;
 	}
 
       if (droptol >= 0.)
@@ -262,8 +286,10 @@
 	    Control (UMFPACK_FIXQ) = tmp;
 	}
 
-      // Turn-off UMFPACK scaling for LU 
-      Control (UMFPACK_SCALE) = UMFPACK_SCALE_NONE;
+      if (scale)
+	Control (UMFPACK_SCALE) = UMFPACK_SCALE_SUM;
+      else
+	Control (UMFPACK_SCALE) = UMFPACK_SCALE_NONE;
 
       UMFPACK_DNAME (report_control) (control);
 
@@ -363,6 +389,15 @@
 		  octave_idx_type *Uj = Ufact.ridx ();
 		  double *Ux = Ufact.data ();
 
+		  Rfact = SparseMatrix (nr, nr, nr);
+		  for (octave_idx_type i = 0; i < nr; i++)
+		    {
+		      Rfact.xridx (i) = i;
+		      Rfact.xcidx (i) = i;
+		    }
+		  Rfact.xcidx (nr) = nr;
+		  double *Rx = Rfact.data ();
+
 		  P.resize (nr);
 		  octave_idx_type *p = P.fortran_vec ();
 
@@ -373,11 +408,11 @@
 		  status = UMFPACK_DNAME (get_numeric) (Ltp, Ltj,
 						   Ltx, Up, Uj, Ux, p, q, 
 						   NULL, &do_recip, 
-						   NULL, Numeric) ;
+						   Rx, Numeric) ;
 
 		  UMFPACK_DNAME (free_numeric) (&Numeric) ;
 
-		  if (status < 0 || do_recip)
+		  if (status < 0)
 		    {
 		      (*current_liboctave_error_handler) 
 			("SparseLU::SparseLU extracting LU factors failed");
@@ -387,6 +422,11 @@
 		  else
 		    {
 		      Lfact = Lfact.transpose ();
+
+		      if (do_recip)
+			for (octave_idx_type i = 0; i < nr; i++)
+			  Rx[i] = 1.0 / Rx[i];
+
 		      UMFPACK_DNAME (report_matrix) (nr, n_inner, 
 						Lfact.cidx (), 
 						Lfact.ridx (),