TooN 2.1
|
00001 // -*- c++ -*- 00002 00003 // Copyright (C) 2005,2009 Tom Drummond (twd20@cam.ac.uk), 00004 // Ed Rosten (er258@cam.ac.uk) 00005 // 00006 // This file is part of the TooN Library. This library is free 00007 // software; you can redistribute it and/or modify it under the 00008 // terms of the GNU General Public License as published by the 00009 // Free Software Foundation; either version 2, or (at your option) 00010 // any later version. 00011 00012 // This library is distributed in the hope that it will be useful, 00013 // but WITHOUT ANY WARRANTY; without even the implied warranty of 00014 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00015 // GNU General Public License for more details. 00016 00017 // You should have received a copy of the GNU General Public License along 00018 // with this library; see the file COPYING. If not, write to the Free 00019 // Software Foundation, 59 Temple Place - Suite 330, Boston, MA 02111-1307, 00020 // USA. 00021 00022 // As a special exception, you may use this file as part of a free software 00023 // library without restriction. Specifically, if other files instantiate 00024 // templates or use macros or inline functions from this file, or you compile 00025 // this file and link it with other files to produce an executable, this 00026 // file does not by itself cause the resulting executable to be covered by 00027 // the GNU General Public License. This exception does not however 00028 // invalidate any other reasons why the executable file might be covered by 00029 // the GNU General Public License. 00030 00031 #ifndef TOON_INCLUDE_LU_H 00032 #define TOON_INCLUDE_LU_H 00033 00034 #include <iostream> 00035 00036 #include <TooN/lapack.h> 00037 00038 #include <TooN/TooN.h> 00039 00040 namespace TooN { 00041 /** 00042 Performs %LU decomposition and back substitutes to solve equations. 00043 The %LU decomposition is the fastest way of solving the equation 00044 \f$M\underline{x} = \underline{c}\f$m, but it becomes unstable when 00045 \f$M\f$ is (nearly) singular (in which cases the SymEigen or SVD decompositions 00046 are better). It decomposes a matrix \f$M\f$ into 00047 \f[M = L \times U\f] 00048 where \f$L\f$ is a lower-diagonal matrix with unit diagonal and \f$U\f$ is an 00049 upper-diagonal matrix. The library only supports the decomposition of square matrices. 00050 It can be used as follows to solve the \f$M\underline{x} = \underline{c}\f$ problem as follows: 00051 @code 00052 // construct M 00053 Matrix<3> M; 00054 M[0] = makeVector(1,2,3); 00055 M[1] = makeVector(3,2,1); 00056 M[2] = makeVector(1,0,1); 00057 // construct c 00058 Vector<3> c = makeVector(2,3,4); 00059 // create the LU decomposition of M 00060 LU<3> luM(M); 00061 // compute x = M^-1 * c 00062 Vector<3> x = luM.backsub(c); 00063 @endcode 00064 The convention LU<> (=LU<-1>) is used to create an LU decomposition whose size is 00065 determined at runtime. 00066 @ingroup gDecomps 00067 **/ 00068 template <int Size=-1, class Precision=double> 00069 class LU { 00070 public: 00071 00072 /// Construct the %LU decomposition of a matrix. This initialises the class, and 00073 /// performs the decomposition immediately. 00074 template<int S1, int S2, class Base> 00075 LU(const Matrix<S1,S2,Precision, Base>& m) 00076 :my_lu(m.num_rows(),m.num_cols()),my_IPIV(m.num_rows()){ 00077 compute(m); 00078 } 00079 00080 /// Perform the %LU decompsition of another matrix. 00081 template<int S1, int S2, class Base> 00082 void compute(const Matrix<S1,S2,Precision,Base>& m){ 00083 //check for consistency with Size 00084 SizeMismatch<Size, S1>::test(my_lu.num_rows(),m.num_rows()); 00085 SizeMismatch<Size, S2>::test(my_lu.num_rows(),m.num_cols()); 00086 00087 //Make a local copy. This is guaranteed contiguous 00088 my_lu=m; 00089 FortranInteger lda = m.num_rows(); 00090 FortranInteger M = m.num_rows(); 00091 FortranInteger N = m.num_rows(); 00092 00093 getrf_(&M,&N,&my_lu[0][0],&lda,&my_IPIV[0],&my_info); 00094 00095 if(my_info < 0){ 00096 std::cerr << "error in LU, INFO was " << my_info << std::endl; 00097 } 00098 } 00099 00100 /// Calculate result of multiplying the inverse of M by another matrix. For a matrix \f$A\f$, this 00101 /// calculates \f$M^{-1}A\f$ by back substitution (i.e. without explictly calculating the inverse). 00102 template <int Rows, int NRHS, class Base> 00103 Matrix<Size,NRHS,Precision> backsub(const Matrix<Rows,NRHS,Precision,Base>& rhs){ 00104 //Check the number of rows is OK. 00105 SizeMismatch<Size, Rows>::test(my_lu.num_rows(), rhs.num_rows()); 00106 00107 Matrix<Size, NRHS, Precision> result(rhs); 00108 00109 FortranInteger M=rhs.num_cols(); 00110 FortranInteger N=my_lu.num_rows(); 00111 double alpha=1; 00112 FortranInteger lda=my_lu.num_rows(); 00113 FortranInteger ldb=rhs.num_cols(); 00114 trsm_("R","U","N","N",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0][0],&ldb); 00115 trsm_("R","L","N","U",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0][0],&ldb); 00116 00117 // now do the row swapping (lapack dlaswp.f only shuffles fortran rows = Rowmajor cols) 00118 for(int i=N-1; i>=0; i--){ 00119 const int swaprow = my_IPIV[i]-1; // fortran arrays start at 1 00120 for(int j=0; j<NRHS; j++){ 00121 Precision temp = result[i][j]; 00122 result[i][j] = result[swaprow][j]; 00123 result[swaprow][j] = temp; 00124 } 00125 } 00126 return result; 00127 } 00128 00129 /// Calculate result of multiplying the inverse of M by a vector. For a vector \f$b\f$, this 00130 /// calculates \f$M^{-1}b\f$ by back substitution (i.e. without explictly calculating the inverse). 00131 template <int Rows, class Base> 00132 Vector<Size,Precision> backsub(const Vector<Rows,Precision,Base>& rhs){ 00133 //Check the number of rows is OK. 00134 SizeMismatch<Size, Rows>::test(my_lu.num_rows(), rhs.size()); 00135 00136 Vector<Size, Precision> result(rhs); 00137 00138 FortranInteger M=1; 00139 FortranInteger N=my_lu.num_rows(); 00140 double alpha=1; 00141 FortranInteger lda=my_lu.num_rows(); 00142 FortranInteger ldb=1; 00143 trsm_("R","U","N","N",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0],&ldb); 00144 trsm_("R","L","N","U",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0],&ldb); 00145 00146 // now do the row swapping (lapack dlaswp.f only shuffles fortran rows = Rowmajor cols) 00147 for(int i=N-1; i>=0; i--){ 00148 const int swaprow = my_IPIV[i]-1; // fortran arrays start at 1 00149 Precision temp = result[i]; 00150 result[i] = result[swaprow]; 00151 result[swaprow] = temp; 00152 } 00153 return result; 00154 } 00155 00156 /// Calculate inverse of the matrix. This is not usually needed: if you need the inverse just to 00157 /// multiply it by a matrix or a vector, use one of the backsub() functions, which will be faster. 00158 Matrix<Size,Size,Precision> get_inverse(){ 00159 Matrix<Size,Size,Precision> Inverse(my_lu); 00160 FortranInteger N = my_lu.num_rows(); 00161 FortranInteger lda=my_lu.num_rows(); 00162 FortranInteger lwork=-1; 00163 Precision size; 00164 getri_(&N, &Inverse[0][0], &lda, &my_IPIV[0], &size, &lwork, &my_info); 00165 lwork=FortranInteger(size); 00166 Precision* WORK = new Precision[lwork]; 00167 getri_(&N, &Inverse[0][0], &lda, &my_IPIV[0], WORK, &lwork, &my_info); 00168 delete [] WORK; 00169 return Inverse; 00170 } 00171 00172 /// Returns the L and U matrices. The permutation matrix is not returned. 00173 /// Since L is lower-triangular (with unit diagonal) 00174 /// and U is upper-triangular, these are returned conflated into one matrix, where the 00175 /// diagonal and above parts of the matrix are U and the below-diagonal part, plus a unit diagonal, 00176 /// are L. 00177 const Matrix<Size,Size,Precision>& get_lu()const {return my_lu;} 00178 00179 private: 00180 inline int get_sign() const { 00181 int result=1; 00182 for(int i=0; i<my_lu.num_rows()-1; i++){ 00183 if(my_IPIV[i] > i+1){ 00184 result=-result; 00185 } 00186 } 00187 return result; 00188 } 00189 public: 00190 00191 /// Calculate the determinant of the matrix 00192 inline Precision determinant() const { 00193 Precision result = get_sign(); 00194 for (int i=0; i<my_lu.num_rows(); i++){ 00195 result*=my_lu(i,i); 00196 } 00197 return result; 00198 } 00199 00200 /// Get the LAPACK info 00201 int get_info() const { return my_info; } 00202 00203 private: 00204 00205 Matrix<Size,Size,Precision> my_lu; 00206 FortranInteger my_info; 00207 Vector<Size, FortranInteger> my_IPIV; //Convenient static-or-dynamic array of ints :-) 00208 00209 }; 00210 } 00211 00212 00213 #endif