module Matrix.Complex.Pivot (permute, permute_list, vswap, mswap) where
import Data.Array
import Data.List
import Complex
--
-- Trivial Pivoting Routines
--
-- | Reorder the rows of a matrix.
--
-- The permutation lists, in order, the source row for each result row: the
-- k-th element (counting from 0) names the row of @a@ copied into the k-th
-- row of the result, i.e. @mswap p a ! (i, j) == a ! (p !! (i - r0), j)@
-- where @r0@ is the lowest row index of @a@.  Columns are left in place and
-- the result has the same bounds as @a@.  A permutation shorter than the
-- number of rows, or naming a row outside the bounds, makes the affected
-- entries fail when they are forced.
mswap :: [Int]                              -- ^ source row for each result row
      -> Array (Int,Int) (Complex Double)   -- ^ matrix to permute
      -> Array (Int,Int) (Complex Double)   -- ^ matrix with its rows reordered
mswap n a = listArray bnds [ a ! (perm ! i, j) | (i, j) <- range bnds ]
  where
    bnds@((r0, _), (r1, _)) = bounds a
    -- Index the permutation once (O(1) lookups) instead of walking the list
    -- with (!!) for every element; keyed by the real row bounds so matrices
    -- whose rows do not start at 1 are handled too.
    perm = listArray (r0, r1) n :: Array Int Int
-- | Reorder the entries of a vector.
--
-- The permutation lists, in order, the source index for each result entry:
-- the k-th element (counting from 0) names the entry of @a@ placed at the
-- k-th position of the result, i.e. @vswap p a ! j == a ! (p !! (j - lo))@
-- where @lo@ is the lowest index of @a@.  The result has the same bounds as
-- @a@.  A permutation shorter than the vector, or naming an index outside
-- the bounds, makes the affected entries fail when they are forced.
vswap :: [Int]                        -- ^ source index for each result entry
      -> Array Int (Complex Double)   -- ^ vector to permute
      -> Array Int (Complex Double)   -- ^ vector with its entries reordered
vswap n a = listArray bnds [ a ! (perm ! j) | j <- range bnds ]
  where
    bnds = bounds a
    -- O(1) lookups into the permutation instead of (!!) per entry, keyed by
    -- the vector's own bounds rather than assuming they start at 1.
    perm = listArray bnds n :: Array Int Int
-- | Number the entries of a list consecutively, the first one being
-- @start + 1@, and return (in order) the numbers of the entries whose
-- magnitude is non-zero.
rows1 :: Int -> [Complex Double] -> [Int]
rows1 start xs = [ k | (x, k) <- zip xs [start + 1 ..], magnitude x > 0 ]
-- | 1-based positions of the non-zero entries of a list (for example one
-- column of a matrix): 'rows1' with a starting counter of 0.
rowz :: [Complex Double] -> [Int]
rowz = rows1 0
-- | Cut a list into consecutive chunks of @n@ elements; the last chunk may be
-- shorter.  An empty list gives no chunks.  A non-positive chunk size with a
-- non-empty list is rejected with 'error' rather than producing an endless
-- list of empty chunks.
split1 :: Int -> [a] -> [[a]]
split1 _ [] = []
split1 n l
  | n <= 0    = error "Matrix.Complex.Pivot.split1: chunk size must be positive"
  | otherwise = chunk : split1 n rest
  where
    (chunk, rest) = splitAt n l
-- | Enumerate candidate row permutations for a square, 1-based matrix.
--
-- Each candidate lists, column by column, the 1-based row chosen for that
-- column, such that every chosen entry is non-zero and no row is chosen
-- twice.  All such choices are produced, in the order generated by
-- 'rowlist'.  An empty result means that no such choice exists.
permute_list :: Array (Int,Int) (Complex Double) -> [[Int]]
permute_list a = rowlist $ map rowz $ transpose $ split1 ncols $ elems a
  where
    ((_, c0), (_, c1)) = bounds a
    -- 'elems' walks the matrix row by row, so chunks the length of one row
    -- (the column count, not the row bound) are exactly the rows;
    -- 'transpose' then turns them into columns.
    ncols = c1 - c0 + 1
-- | Choose a row permutation that puts a non-zero entry on every diagonal
-- position, preferring the candidate whose diagonal has the largest total
-- magnitude.  Candidates come from 'permute_list'; on a tie the earliest
-- candidate wins.
--
-- The result lists the 1-based source row for each result row, in order,
-- which is the form 'mswap' consumes.  Assumes a square, 1-based matrix.
-- Calls 'error' when there is no candidate, or when no score equals the
-- maximum (which can happen if the matrix holds NaN entries).
permute :: Array (Int,Int) (Complex Double) -> [Int]
permute a =
  case [ p | (p, s) <- scored, s == best ] of
    (p:_) -> p
    []    -> error ("Matrix.Complex.Pivot.permute: "
                    ++ "no row permutation gives a non-zero diagonal")
  where
    ((_, c0), (_, c1)) = bounds a
    -- Entry k of a candidate is the source row for diagonal slot k, so the
    -- diagonal of the permuted matrix is read straight from 'a' (n lookups)
    -- instead of materialising the permuted matrix and a diagonal copy of
    -- it for every candidate.  Summed with a right fold, front to back.
    score p = foldr (+) 0
                [ magnitude (a ! (r, c)) | (r, c) <- zip p (range (c0, c1)) ]
    scored = [ (p, score p) | p <- permute_list a ]
    -- Only forced when 'scored' is non-empty, so 'maximum' never sees [].
    best = maximum (map snd scored)
-- | All ways of picking one element from each set such that no element is
-- picked twice (systems of distinct representatives).  The k-th element of
-- each result comes from the k-th set.  The choices for the later sets
-- form the outer iteration, so results appear in that order.  With no sets
-- there is exactly one (empty) choice.
rowlist :: Eq a => [[a]] -> [[a]]
rowlist = foldr extend [[]]
  where
    -- Prepend every still-unused member of this set to each partial choice
    -- built from the remaining sets.
    extend choices partials =
      [ c : rest | rest <- partials, c <- choices, c `notElem` rest ]
-- | Keep only the main diagonal of a matrix.  Entries whose row and column
-- indices are equal are copied and every other entry becomes 0; the result
-- has the same bounds as the input.
diag :: Array (Int,Int) (Complex Double)   -- ^ input matrix
     -> Array (Int,Int) (Complex Double)   -- ^ its diagonal part
diag a = listArray bnds (map pick (range bnds))
  where
    bnds = bounds a
    pick (r, c)
      | r /= c    = 0
      | otherwise = a ! (r, c)
|