-----------------------------------------------------------------------------
-- |
-- Module : DSP.Matrix.LU
-- Copyright : (c) Matthew Donadio 2003
-- License : GPL
--
-- Maintainer : m.p.donadio@ieee.org
-- Stability : experimental
-- Portability : portable
--
-- Module implementing LU decomposition and related functions
--
-----------------------------------------------------------------------------
module Matrix.Complex.LU (psolve, lu, lu_solve, improve, inverse, lu_det, solve, det) where
import Data.Array
import Complex
import Matrix.Complex.Pivot
-- | LU decomposition via Crout's Algorithm
-- TODO: modify for partial pivoting / permutation matrix
-- TODO: add singularity check
-- I am sure these are in G&VL, but the two cases of function f below are
-- formulas (2.3.13) and (2.3.12) from NRIC with some variable renaming
lu :: Array (Int,Int) (Complex Double) -- ^ A
-> Array (Int,Int) (Complex Double) -- ^ LU(A)
lu a = a'
where a' = array bnds [ ((i,j), luij i j) | (i,j) <- range bnds ]
luij i j | i>j = (a!(i,j) - sum [ a'!(i,k) * a'!(k,j) | k <- [1 ..(j-1)] ]) / a'!(j,j)
| i<=j = a!(i,j) - sum [ a'!(i,k) * a'!(k,j) | k <- [1 ..(i-1)] ]
bnds = bounds a
-- | Solution to Ax=b via LU decomposition
-- forward is forumla (2.3.6) in NRIC, but remebering that a11=1
-- backward is forumla (2.3.7) in NRIC
lu_solve :: Array (Int,Int) (Complex Double) -- ^ LU(A)
-> Array Int (Complex Double) -- ^ b
-> Array Int (Complex Double) -- ^ x
lu_solve a b = x
where x = array (1,n) ([(n,xn)] ++ [ (i, backward i) | i <- (reverse [1..(n-1)]) ])
y = array (1,n) ([(1,y1)] ++ [ (i, forward i) | i <- [2..n] ])
y1 = b!1
forward i = (b!i - sum [ a!(i,j) * y!j | j <- [1..(i-1)] ])
xn = y!n / a!(n,n)
backward i = (y!i - sum [ a!(i,j) * x!j | j <- [(i+1)..n] ]) / a!(i,i)
((_,_),(n,_)) = bounds a
-- | Improve a solution to Ax=b via LU decomposition
-- formula (2.7.4) from NRIC
improve :: Array (Int,Int) (Complex Double) -- ^ A
-> Array (Int,Int) (Complex Double) -- ^ LU(A)
-> Array Int (Complex Double) -- ^ b
-> Array Int (Complex Double) -- ^ x
-> Array Int (Complex Double) -- ^ x'
improve a a_lu b x = array (1,n) [ (i, x!i - err!i) | i <- [1..n] ]
where err = lu_solve a_lu rhs
rhs = array (1,n) [ (i, sum [ a!(i,j) * x!j | j <- [1..n] ] - b!i) | i <- [1..n] ]
((_,_),(n,_)) = bounds a
-- | Matrix inversion via LU decomposition
-- Section (2.4) from NRIC
-- TODO: build in improve
inverse :: Array (Int,Int) (Complex Double) -- ^ A
-> Array (Int,Int) (Complex Double) -- ^ A^-1
inverse a = a'
where a' = array (bounds a) (arrange (makecols (lu a)) 1)
makecol i n = array (1,n) [ (j, (\i j->if i == j then 1.0 else 0.0) i j) | j <- [1..n] ]
makecols a = [ lu_solve a (makecol i n) | i <- [1..n] ]
((_,_),(n,_)) = bounds a
arrange [] _ = []
arrange (m:ms) j = (flatten m j) ++ (arrange ms (j+1))
flatten m j = map (\(i,x) -> ((i,j),x)) (assocs m)
-- | Determinant of a matrix via LU decomposition
-- Formula (2.5.1) from NRIC
lu_det :: Array (Int,Int) (Complex Double) -- ^ LU(A)
-> (Complex Double) -- ^ det(A)
lu_det a = product [ a!(i,i) | i <- [ 1 .. n] ]
where ((_,_),(n,_)) = bounds a
-- | LU solver using original matrix
solve :: Array (Int,Int) (Complex Double) -- ^ A
-> Array Int (Complex Double) -- ^ b
-> Array Int (Complex Double) -- ^ x
solve a b = (lu_solve . lu) a b
psolve :: Array (Int,Int) (Complex Double) -- ^ A
-> Array Int (Complex Double) -- ^ b
-> Array Int (Complex Double) -- ^ x
psolve a b = (lu_solve . lu) (mswap pv a) (vswap pv b)
where pv = permute a
-- | determinant using original matrix
det :: Array (Int,Int) (Complex Double) -- ^ A
-> (Complex Double) -- ^ det(A)
det a = (lu_det . lu) a
-------------------------------------------------------------------------------
-- tests
-------------------------------------------------------------------------------
{-
a = array ((1,1),(3,3)) [ ((1,1), 1.0), ((1,2), 2.0), ((1,3), 3.0),
((2,1), 2.0), ((2,2), 5.0), ((2,3), 3.0),
((3,1), 1.0), ((3,2), 0.0), ((3,3), 8.0) ]
a' = array ((1,1),(3,3)) [ ((1,1), -40.0), ((1,2), 16.0), ((1,3), 9.0),
((2,1), 13.0), ((2,2), -5.0), ((2,3), -3.0),
((3,1), 5.0), ((3,2), -2.0), ((3,3), -1.0) ]
a_lu = lu a
b = array (1,3) [ (1, 1.0), (2, 2.0), (3, 5.0) ]
x = lu_solve a_lu b
x' = improve a a_lu b x
x'' = improve a a_lu b x'
verify = a' == inverse a && -- tests lu, lu_solve, and inverse
det a == -1 && -- tests lu_det
x == x' && -- tests improve
x' == x''
-}
|