-------------------------------------------------------------------------------
-- $Id: Combinators.hs,v 1.21 2000/01/15 01:02:31 satnam Exp $
-------------------------------------------------------------------------------

module Combinators
where
import LavaClasses
import IOExts
import Tile

infixl 5 >->
infixl 5 /\
infixl 5 >|>

-------------------------------------------------------------------------------

-- The identity circuit: performs no work and hands its input straight
-- back, wrapped in the circuit monad.

iD :: Monad t => a -> t a
iD = return

-------------------------------------------------------------------------------

-- Duplicate one value onto both components of a pair (a wire fan-out).
fork2 :: Monad t => a -> t (a, a)
fork2 = \x -> return (x, x)

-------------------------------------------------------------------------------
  
-- Left-to-right serial composition: the output of f is fed into g, and
-- the layout tile produced by f is placed beside (to the left of) the
-- tile produced by g.

(>->) :: Combinators t => (a -> t b) -> (b -> t c) -> a -> t c
(f >-> g) x
  = do (mid, f_tile) <- tile f x
       (out, g_tile) <- tile g mid
       putBeside f_tile g_tile
       return out

-------------------------------------------------------------------------------

-- Bottom-up serial composition: the output of f is fed into g, and the
-- layout tile produced by f is placed below the tile produced by g.

(/\) :: Combinators t => (a -> t b) -> (b -> t c) -> a -> t c
(f /\ g) x
  = do (mid, f_tile) <- tile f x
       (out, g_tile) <- tile g mid
       putBelow f_tile g_tile
       return out

-------------------------------------------------------------------------------

-- Overlaid serial composition: the output of f is fed into g, and the two
-- layout tiles are combined with putOver (presumably occupying the same
-- area rather than being placed side by side -- confirm against Tile).

(>|>) :: Combinators t => (a -> t b) -> (b -> t c) -> a -> t c
(f >|> g) x
  = do (mid, f_tile) <- tile f x
       (out, g_tile) <- tile g mid
       putOver f_tile g_tile
       return out
-------------------------------------------------------------------------------

-- Apply a circuit to the first component of a pair, passing the second
-- component through untouched.
fsT :: Monad t => (a -> t b) -> (a, c) -> t (b, c)
fsT f (x, passThrough) = f x >>= \y -> return (y, passThrough)

-------------------------------------------------------------------------------

-- Apply a circuit to the second component of a pair, passing the first
-- component through untouched.
snD :: Monad t => (a -> t b) -> (c, a) -> t (c, b)
snD f (passThrough, x) = f x >>= \y -> return (passThrough, y)
                     
-------------------------------------------------------------------------------

-- Select the first component of a pair as a circuit output.
proj1 :: Monad t => (a, b) -> t a
proj1 (first, _) = return first

-------------------------------------------------------------------------------

-- Turn a pure function into a circuit that simply returns its result.
liftCircuit :: Monad m => (a -> b) -> a -> m b
liftCircuit f = return . f

-------------------------------------------------------------------------------


-- Apply the circuit f to each of the first n elements of the list produced
-- by the monadic action as.  Each element is re-wrapped with return before
-- being handed to f, and the n results are sequenced left to right.
--
-- The list is walked once instead of being indexed with (!!) per position.
-- If it holds fewer than n elements, the missing positions are filled with
-- a thunk that raises a descriptive error when (and only when) it is forced.
-- (The previous Show (nsi [a]) constraint was never used and is dropped;
-- this only relaxes the type.)
mapCircuit :: Monad nsi =>
              Int -> (nsi a -> nsi b) -> nsi [a] -> nsi [b]
mapCircuit n f as
  = do avals <- as
       mapM (f . return) (take n (avals ++ repeat tooShort))
    where
    tooShort = error ("mapCircuit: input list has fewer than "
                      ++ show n ++ " elements")

-------------------------------------------------------------------------------


-- Apply a circuit to every element of a list, running the resulting
-- actions left to right and collecting their outputs in order.
-- Traverses the list once with mapM (linear time), rather than indexing
-- each position with (!!), which is quadratic.
maP :: Monad nsi => (a -> nsi b) -> [a] -> nsi [b]
maP f as = mapM f as

-------------------------------------------------------------------------------

-- delayBit specialised with False as its first argument.
-- NOTE(review): presumably False is the delay's initial/reset value --
-- confirm against the definition of delayBit in the Circuit class.
delayBit0 :: Circuit nsi bit =>
             bit -> nsi bit
delayBit0 x = delayBit False x

-------------------------------------------------------------------------------

-- Run the action v and pair the given value x with its result, x first.
injectPair :: Monad m => m a -> b -> m (b, a)
injectPair v x
  = do r <- v
       return (x, r)
       
-------------------------------------------------------------------------------

-- Run an action that yields a pair and return its first component.
project1 :: Monad m => m (a, b) -> m a
project1 ab
  = do (a, _) <- ab
       return a

-------------------------------------------------------------------------------

-- Run an action that yields a pair and return its second component.
project2 :: Monad m => m (a, b) -> m b
project2 ab
  = do (_, b) <- ab
       return b

-------------------------------------------------------------------------------


-- Run two circuits side by side on the two halves of a pair input.
-- circuit1's tile is placed below circuit2's tile, and the outputs are
-- returned as a pair in the same order as the inputs.
par2 circuit1 circuit2 (input1, input2)
  = do (out1, bottom_tile) <- tile circuit1 input1
       (out2, top_tile)    <- tile circuit2 input2
       putBelow bottom_tile top_tile
       return (out1, out2)
              
-------------------------------------------------------------------------------

-- Run a list of circuits in parallel, circuit i consuming input i, and
-- stack their layout tiles vertically.  Each circuit's tile is put below
-- the tile of the circuits after it, so the first circuit is the lowest.
-- Outputs are returned in the same order as the circuits.  It is an error
-- for the two lists to differ in length.
--
-- The length check is performed once here, and the recursive helper does
-- not repeat it at every level, so total work stays linear in the list
-- length.
par circuits inputs
  | length circuits /= length inputs
  = error ("par: #circuits = " ++ show (length circuits) ++ " but #inputs = "
           ++ show (length inputs) ++ " inputs = " ++ show inputs ++ "\n")
  | otherwise
  = stack circuits inputs
    where
    -- Lay out the (equal-length) remaining circuits bottom to top.
    stack [] [] = return []
    stack (c:cs) (i:is)
      = do (o, lower_tile)  <- tile c i
           (os, upper_tile) <- tile (stack cs) is
           putBelow lower_tile upper_tile
           return (o:os)
   

-------------------------------------------------------------------------------

-- Run a list of circuits in parallel, circuit i consuming input i, and
-- lay their tiles out horizontally.  Each circuit's tile is put beside
-- (to the left of) the tile of the circuits after it, so the first
-- circuit is leftmost.  Outputs are returned in the same order as the
-- circuits.  It is an error for the two lists to differ in length.
--
-- The length check is performed once here, and the recursive helper does
-- not repeat it at every level, so total work stays linear in the list
-- length.
hpar circuits inputs
  | length circuits /= length inputs
  = error ("hpar: #circuits = " ++ show (length circuits) ++ " but #inputs = "
           ++ show (length inputs) ++ " inputs = " ++ show inputs ++ "\n")
  | otherwise
  = chain circuits inputs
    where
    -- Lay out the (equal-length) remaining circuits left to right.
    chain [] [] = return []
    chain (c:cs) (i:is)
      = do (o, left_tile)   <- tile c i
           (os, right_tile) <- tile (chain cs) is
           putBeside left_tile right_tile
           return (o:os)

-------------------------------------------------------------------------------

-- Run lhs on the left input and rhs on the right input, then feed both of
-- their outputs to mid; the result of mid is returned.  Layout: mid's tile
-- is put beside rhs's tile, and lhs's tile is put beside the merge of the
-- two, giving lhs | mid | rhs from left to right.
-- (The tile local was renamed from `middle` so it no longer shadows this
-- function.)
middle lhs mid rhs (li, ri)
  = do (lv, lhs_tile) <- tile lhs li
       (rv, rhs_tile) <- tile rhs ri
       (o, mid_tile)  <- tile mid (lv, rv)
       putBeside mid_tile rhs_tile
       putBeside lhs_tile (mergeTiles [mid_tile, rhs_tile])
       return o

-------------------------------------------------------------------------------

-- Split a list into two halves; when the length is odd the extra element
-- goes to the second half.
halveList :: [a] -> ([a], [a])
halveList xs = splitAt (length xs `div` 2) xs

-------------------------------------------------------------------------------

-- Reduce a list with a binary circuit arranged as a balanced tree.
--
-- The list is split in half with halveList, each half is reduced
-- recursively, and the two results are combined by circuit.  middle lays
-- the pieces out as left subtree | circuit | right subtree.
--
-- An empty list is rejected explicitly.  Without that equation it would
-- fall through to the recursive case, where halveList [] == ([], []), and
-- recurse forever.
tree :: Combinators m => ((a,a) -> m a) -> [a] -> m a
tree _       []     = error "tree: empty input list"
tree _       [x]    = return x
tree circuit [x, y] = circuit (x, y)
tree circuit input
  = (middle (tree circuit) circuit (tree circuit)) (halveList input)

-------------------------------------------------------------------------------

-- Entry point for pipelined trees.  A single input is returned unchanged,
-- so no pipeline stage f is added when the tree has only one input; every
-- other input list is handed to pipeTree'.

pipeTree :: (Show a, Combinators m) =>
            (a -> m a) -> ((a,a) -> m a) -> [a] -> m a
pipeTree f circuit input
  = case input of
      [x] -> return x
      _   -> pipeTree' f circuit input

-- Pipelined tree reduction.
--
-- Every internal node is circuit overlaid with the pipeline stage f
-- (circuit >|> f).  A lone leaf is passed through f on its own, presumably
-- so that every path through the tree crosses the same number of f stages.
-- The halves are laid out by middle.
--
-- An empty list is rejected explicitly.  Otherwise it would reach the
-- recursive case, where halveList [] == ([], []), and recurse forever.
pipeTree' :: (Show a, Combinators m) =>
            (a -> m a) -> ((a,a) -> m a) -> [a] -> m a
pipeTree' _ _       []     = error "pipeTree': empty input list"
pipeTree' f _       [x]    = f x
pipeTree' f circuit [x, y] = (circuit >|> f) (x, y)
pipeTree' f circuit input
  = (middle (pipeTree' f circuit) (circuit >|> f)
            (pipeTree' f circuit)) (halveList input)

-------------------------------------------------------------------------------

-- Pipelined tree reduction where each leaf carries its own leaf circuit g,
-- which is applied to that leaf's value.
--
-- For a pair of leaves, the two leaf circuits run side by side (hpar), and
-- their outputs go through circuit and then f, composed left to right with
-- >->.  Larger lists are halved and combined with circuit >-> f via middle.
--
-- An empty list is rejected explicitly.  Otherwise it would reach the
-- recursive case, where halveList [] == ([], []), and recurse forever.
pipeTree2 :: (Show a, Combinators m) =>
             (a -> m a) -> ((a,a) -> m a) -> [(a -> m a, a)]
             -> m a
pipeTree2 _ _ []       = error "pipeTree2: empty input list"
pipeTree2 _ _ [(g, x)] = g x
pipeTree2 f circuit [(g1, x), (g2, y)]
  = (pair2list >-> (hpar [g1, g2]) >-> list2pair >-> circuit >-> f) (x, y)
pipeTree2 f circuit input
  = (middle (pipeTree2 f circuit) (circuit >-> f)
            (pipeTree2 f circuit)) (halveList input)

-------------------------------------------------------------------------------

-- Convert a two-element list into a pair.  Any other length is a shape
-- error, reported with the actual length instead of a bare pattern-match
-- failure.
list2pair :: Monad m => [a] -> m (a, a)
list2pair [a, b] = return (a, b)
list2pair other
  = error ("list2pair: expected a 2-element list but got "
           ++ show (length other) ++ " elements")

-- Convert a pair into a two-element list (inverse of list2pair).
pair2list :: Monad m => (a, a) -> m [a]
pair2list (a, b) = return [a, b]
           
-------------------------------------------------------------------------------
-- 4-Sided Tile Combinators
-------------------------------------------------------------------------------

-- BESIDE

--            d             g
--            |             |
--            ^             ^
--          -----         -----
--         |     |       |     |
--     b ->|  r  |-> c ->|  s  |-> f
--         |     |       |     |
--          -----         -----
--            |             |
--            ^             ^
--            a             e


-- Place r to the left of s, following the BESIDE diagram above.  The
-- horizontal signal c produced by r feeds s.  The vertical inputs (a, e)
-- and vertical outputs (d, g) are grouped left to right.
beside :: Combinators m =>
           (((a, b) -> m (c, d)) ->          -- type of r
           ((e, c)  -> m (f, g)) ->          -- type of s
           (((a,e), b)) -> m (f, (d,g)))     -- type of result

beside r s ((a,e), b)
  = do ((c,d), r_tile) <- tile r (a, b)
       ((f,g), s_tile) <- tile s (e, c)
       putBeside r_tile s_tile
       return (f, (d,g))
       
-------------------------------------------------------------------------------

-- BELOW

--            g
--            ^
--            |
--          -----
--         |     |
--     e ->|  s  |-> f
--         |     |
--          -----
--            ^
--            |
--            d
--            ^
--            |
--          -----
--         |     |
--     b ->|  r  |-> c
--         |     |
--          -----
--            ^ 
--            |
--            a

-- Place r below s, following the BELOW diagram above.  The vertical signal
-- d produced by r feeds s.  The horizontal inputs (b, e) and horizontal
-- outputs (c, f) are grouped bottom to top.
below :: Combinators m =>
         (((a, b) -> m (c, d)) ->             -- type of r
          ((d, e) -> m (f, g)) ->             -- type of s
          ((a, (b,e)) -> m ((c,f), g)))       -- type of result

below r s (a, (b,e))
  = do ((c,d), r_tile) <- tile r (a, b)
       ((f,g), s_tile) <- tile s (d, e)
       putBelow r_tile s_tile
       return ((c,f), g)
                             
-------------------------------------------------------------------------------
       
-- ROW

-- Replicate the circuit r n times in a horizontal row, built with beside.
-- The horizontal signal b is threaded left to right through every copy.
-- Copy i consumes the i-th element of the list of vertical inputs and
-- contributes the i-th element of the list of vertical outputs.
--
-- Mismatches between n and the input list length are reported with a
-- shape error (mirroring col) instead of an anonymous pattern-match
-- failure.
row :: Combinators m =>
       Int ->                        -- Number of tiles in row
       ((a, b) -> m (b, d)) ->       -- Circuit to replicate
       ([a], b) ->                   -- Type of input for row
       m (b, [d])                    -- Type of output for row

row 0 r input = error ("row 0")
row 1 r ([a],b)
  = do (c,d) <- r (a, b)
       return (c, [d])
row n r (a:as, b)
  = do (c, (d,ds)) <- (r `beside` (row (n-1) r)) ((a,as), b)
       return (c, d:ds)
row n _ (xs, _)
  = error ("row shape error: row " ++ show n ++ " applied to "
           ++ show (length xs) ++ " inputs\n")

-------------------------------------------------------------------------------

-- COL

-- Replicate the circuit r n times in a vertical column, built with below.
-- The vertical signal a is threaded from the bottom copy upwards.  Copy i
-- consumes the i-th horizontal input and yields the i-th horizontal output.
-- Mismatched shapes are reported with an error that shows the input.
col :: (Show a, Show b, Combinators m) =>
       Int ->                          -- Number of tiles in column
       ((a, b) -> m (c, a)) ->         -- Circuit to replicate
       (a, [b]) ->                     -- Type of input for col
       m ([c], a)                      -- Type of output for col

col 0 _ input = error ("col 0 but input = " ++ show input ++ "\n")
col 1 r (up, [x])
  = do (out, up') <- r (up, x)
       return ([out], up')
col n r (up, x:xs)
  = do ((out, outs), up') <- (r `below` col (n-1) r) (up, (x, xs))
       return (out : outs, up')
col n _ other = error ("col shape error: col " ++ show n ++ " " ++
                       show other ++ "\n")

-------------------------------------------------------------------------------

                                                  