module Functions where

-- This module implements mathematical functions as symbolic expression trees.
-- The implementation includes evaluation at specified points and partial
-- derivatives.
-- 
-- Written 2006 by Gerhard Navratil
-- Institute for Geoinformation and Cartography, Vienna University of Technology

import List

-- | Symbolic mathematical functions @a b@ over numbers of type @b@:
-- construction, simplification, substitution of numbers for variables and
-- differentiation.
class MathFunction a b where
  -- | Wraps a number as a constant function.
  makeVal :: b -> a b
  -- | Extracts the number from a function that has been reduced to a constant.
  getVal  :: a b -> b

  -- | Simplifies the function by merging values where possible (one pass).
  eval        :: a b -> a b
  -- | Simplifies the function as much as possible.
  evaluate    :: a b -> a b

  -- | Replaces the variable with the given name by the number.
  evaluateAt  :: VarName -> b -> a b -> a b
  -- | Replaces every variable in the list by its paired number.
  evaluateAt' :: [(VarName,b)] -> a b -> a b

  -- | Tests whether the function depends on the variable.
  containsVar :: VarName -> a b -> Bool
  -- | Lists all variables occurring in the function.
  listVariables :: a b -> [VarName]

  -- | Computes the first derivative with respect to the variable.
  derivative  :: VarName -> a b -> a b
  -- | Computes the partial derivatives for all variables in the list.
  partDerivatives :: [VarName] -> a b -> [a b]

-- Types and Instances
-- | Name of a variable occurring in a function.
type VarName = String

-- | A symbolic function, stored as an expression tree whose leaves are
-- constants of type @a@ or named variables.
data Fkt a
  = Val a                 -- ^ constant
  | Var VarName           -- ^ variable
  | Add (Fkt a) (Fkt a)   -- ^ sum
  | Sub (Fkt a) (Fkt a)   -- ^ difference
  | Mul (Fkt a) (Fkt a)   -- ^ product
  | Div (Fkt a) (Fkt a)   -- ^ quotient
  | Abs (Fkt a)           -- ^ absolute value
  | Exp (Fkt a) (Fkt a)   -- ^ power: base, exponent
  | Sinus (Fkt a)         -- ^ sine
  | Cosinus (Fkt a)       -- ^ cosine
  | Tangens (Fkt a)       -- ^ tangent
  | ASinus (Fkt a)        -- ^ arc sine
  | ACosinus (Fkt a)      -- ^ arc cosine
  | ATangens (Fkt a)      -- ^ arc tangent
  deriving (Eq)

-- | Renders an expression in fully parenthesised infix notation; for example
-- @Add (Var "x") (Val 1)@ (with an integer constant) is shown as @(x)+(1)@.
-- Unary functions are shown as @name (argument)@, absolute value as @|x|@.
instance (Show a) => Show (Fkt a) where
  show fkt = case fkt of
      Var x      -> x
      Val x      -> show x
      Add a b    -> infixOp "+" a b
      Sub a b    -> infixOp "-" a b
      Mul a b    -> infixOp "*" a b
      Div a b    -> infixOp "/" a b
      Exp a b    -> infixOp "^" a b
      Abs x      -> "|" ++ show x ++ "|"
      Sinus a    -> prefixFn "sin" a
      Cosinus a  -> prefixFn "cos" a
      Tangens a  -> prefixFn "tan" a
      ASinus a   -> prefixFn "asin" a
      ACosinus a -> prefixFn "acos" a
      ATangens a -> prefixFn "atan" a
    where
      -- both operands are always parenthesised, so no precedence is needed
      infixOp op l r = "(" ++ show l ++ ")" ++ op ++ "(" ++ show r ++ ")"
      prefixFn name arg = name ++ " (" ++ show arg ++ ")"

-- | Arithmetic on expressions builds the corresponding expression tree, so
-- ordinary syntax such as @Var "x" * 2 + 1@ constructs an 'Fkt'.
instance (Num a,Ord a) => Num (Fkt a) where
  (+) = Add
  (-) = Sub
  (*) = Mul
  -- Symbolic absolute value (not the identity), so that simplification and
  -- differentiation see the |.| node.
  abs = Abs
  fromInteger = Val . fromInteger
  -- The sign of a constant is a constant.  The value is built inside 'Val'
  -- (at type @a@) so the result is e.g. @Val (-1)@; an @Fkt@ literal @-1@
  -- would instead be the unsimplified tree @Sub (Val 0) (Val 1)@.
  signum (Val x) = Val (if x < 0 then -1 else if x == 0 then 0 else 1)
  -- NOTE(review): there is no constructor for the sign of a non-constant
  -- expression, so 0 is only a placeholder here.
  signum _ = 0

-- | Division builds a symbolic quotient; rational literals become constants.
instance (Fractional a,Ord a) => Fractional (Fkt a) where
  fromRational = Val . fromRational
  (/)          = Div

-- | Symbolic functions over floating-point constants.
--
-- The context lists 'Eq a' explicitly: simplification matches on numeric
-- literals (e.g. @Val 0@) and compares expressions.  Haskell 98 implied 'Eq'
-- through the 'Num' superclass; current GHC does not.
instance (Floating a, Eq a) => MathFunction Fkt a where
  -- One simplification pass.  Clauses are tried top to bottom, so the
  -- "both operands constant" rule of each operator must stay first.  Some
  -- identities are applied without inspecting the other operand (e.g.
  -- 0/x = 0 and 0^x = 0 even if x later simplifies to 0).
  eval (Val a) = Val a
  eval (Var n) = Var n
  eval (Add (Val a) (Val b)) = Val (a+b)
  eval (Add (Val 0) b) = b
  eval (Add a (Val 0)) = a
  eval (Add a b) = Add (eval a) (eval b)
  eval (Sub (Val a) (Val b)) = Val (a-b)
  eval (Sub (Val 0) b) = Mul (Val (-1)) b
  eval (Sub a (Val 0)) = a
  eval (Sub a b) = Sub (eval a) (eval b)
  eval (Mul (Val a) (Val b)) = Val (a*b)
  eval (Mul (Val 0) b) = Val 0
  eval (Mul a (Val 0)) = Val 0
  eval (Mul (Val 1) b) = b
  eval (Mul a (Val 1)) = a
  eval (Mul a b) = Mul (eval a) (eval b)
  eval (Div (Val a) (Val b)) = Val (a / b)
  eval (Div (Val 0) b) = Val 0
  eval (Div a (Val 1)) = a
  eval (Div a b) = Div (eval a) (eval b)
  eval (Abs (Val a)) = Val (abs a)
  eval (Abs a) = Abs (eval a)
  eval (Exp (Val a) (Val b)) = Val (a ** b)
  eval (Exp (Val 0) b) = Val 0
  eval (Exp a (Val 0)) = Val 1
  eval (Exp (Val 1) b) = Val 1
  eval (Exp a (Val 1)) = a
  eval (Exp a b) = Exp (eval a) (eval b)
  eval (Sinus (Val a)) = Val (sin a)
  eval (Sinus a) = Sinus (eval a)
  eval (Cosinus (Val a)) = Val (cos a)
  eval (Cosinus a) = Cosinus (eval a)
  eval (Tangens (Val a)) = Val (tan a)
  eval (Tangens a) = Tangens (eval a)
  eval (ASinus (Val a)) = Val (asin a)
  eval (ASinus a) = ASinus (eval a)
  eval (ACosinus (Val a)) = Val (acos a)
  eval (ACosinus a) = ACosinus (eval a)
  eval (ATangens (Val a)) = Val (atan a)
  eval (ATangens a) = ATangens (eval a)

  -- Repeats 'eval' until the expression stops changing.  The fixpoint test
  -- uses 'fktSame' instead of '==' so that a result containing a NaN
  -- constant (e.g. asin 2) terminates; with '==' it would loop forever,
  -- because NaN /= NaN.
  evaluate f = let f' = eval f in if fktSame f f' then f else evaluate f'

  -- Substitutes the constant for every occurrence of the variable.
  evaluateAt name val (Var n) | n == name = Val val
  evaluateAt name val fkt = fktMapChildren (evaluateAt name val) fkt

  -- Substitutes the pairs one after another, left to right.
  evaluateAt' [] fkt = fkt
  evaluateAt' ((n, v) : rest) fkt = evaluateAt' rest (evaluateAt n v fkt)

  -- Structural occurrence check.  Comparing the tree against a copy with the
  -- variable replaced by 0 would misreport trees holding a NaN constant.
  containsVar name fkt = name `elem` listVariables fkt

  -- Symbolic first derivative.  Branches on which operands depend on the
  -- variable so that constant sub-terms do not produce "* 0" noise.
  derivative name = go
    where
      -- does the sub-expression depend on the variable?
      dep e = containsVar name e

      -- chain rule: d f(u) = f'(u) * u', where outer is f'(u)
      chain u outer = if dep u then Mul outer (go u) else Val 0.0

      go (Val _) = Val 0.0
      go (Var n) = if n == name then Val 1.0 else Val 0.0
      go (Add a b) = case (dep a, dep b) of
          (True,  True ) -> Add (go a) (go b)
          (True,  False) -> go a
          (False, True ) -> go b
          (False, False) -> Val 0.0
      go (Sub a b) = case (dep a, dep b) of
          (True,  True ) -> Sub (go a) (go b)
          (True,  False) -> go a
          (False, True ) -> Mul (Val (-1)) (go b)           -- d(c - g) = -g'
          (False, False) -> Val 0.0
      go (Mul a b) = case (dep a, dep b) of
          (True,  True ) -> Add (Mul (go a) b) (Mul a (go b))
          (True,  False) -> Mul (go a) b
          (False, True ) -> Mul a (go b)
          (False, False) -> Val 0.0
      go (Div a b) = case (dep a, dep b) of
          (True,  True ) -> Div (Sub (Mul (go a) b) (Mul a (go b))) (Mul b b)
          (True,  False) -> Div (go a) b
          (False, True ) -> Div (Mul (Val (-1)) (Mul a (go b))) (Mul b b)  -- d(c/g) = -c g'/g^2
          (False, False) -> Val 0.0
      go (Abs a) = chain a (Div a (Abs a))                  -- d|u|/du = u/|u|
      go (Exp a b) = case (dep a, dep b) of
          (True,  True ) -> error "Not yet implemented f(x)^g(x)"
          (True,  False) -> Mul (Mul b (Exp a (Sub b (Val (1.0))))) (go a)
          (False, True ) -> case evaluate a of
              -- d c^g = c^g * ln c * g' when the base reduces to a constant
              Val c -> Mul (Mul (Exp a b) (Val (log c))) (go b)
              _     -> error "Not yet implemented f^g(x)"
          (False, False) -> Val 0.0
      go (Sinus a)    = chain a (Cosinus a)
      go (Cosinus a)  = chain a (Mul (Val (-1)) (Sinus a))
      go (Tangens a)  = chain a (Div (Val 1.0) (Mul (Cosinus a) (Cosinus a)))
      go (ASinus a)   = chain a (Exp (Sub (Val 1.0) (Exp a (Val 2.0))) (Val (-0.5)))
      -- d acos(u)/du = -(1 - u^2)^(-1/2)
      go (ACosinus a) =
        chain a (Mul (Val (-1)) (Exp (Sub (Val 1.0) (Exp a (Val 2.0))) (Val (-0.5))))
      go (ATangens a) = chain a (Exp (Add (Val 1.0) (Exp a (Val 2.0))) (Val (-1)))

  -- Variables in order of first occurrence (left to right), without
  -- duplicates.  Covers every constructor, including 'Abs'.
  listVariables fkt = nub (vars fkt)
    where
      vars (Var n) = [n]
      vars e       = concatMap vars (fktChildren e)

  partDerivatives names fkt = map (\n -> derivative n fkt) names

  makeVal x = Val x

  getVal (Val x) = x
  getVal _ = error "Result is still a function"

-- | Immediate sub-expressions of a node, left to right.
fktChildren :: Fkt a -> [Fkt a]
fktChildren (Val _)      = []
fktChildren (Var _)      = []
fktChildren (Add a b)    = [a, b]
fktChildren (Sub a b)    = [a, b]
fktChildren (Mul a b)    = [a, b]
fktChildren (Div a b)    = [a, b]
fktChildren (Abs a)      = [a]
fktChildren (Exp a b)    = [a, b]
fktChildren (Sinus a)    = [a]
fktChildren (Cosinus a)  = [a]
fktChildren (Tangens a)  = [a]
fktChildren (ASinus a)   = [a]
fktChildren (ACosinus a) = [a]
fktChildren (ATangens a) = [a]

-- | Applies a function to every immediate sub-expression, keeping the node's
-- constructor; leaves ('Val', 'Var') are returned unchanged.
fktMapChildren :: (Fkt a -> Fkt a) -> Fkt a -> Fkt a
fktMapChildren _ (Val x)      = Val x
fktMapChildren _ (Var n)      = Var n
fktMapChildren f (Add a b)    = Add (f a) (f b)
fktMapChildren f (Sub a b)    = Sub (f a) (f b)
fktMapChildren f (Mul a b)    = Mul (f a) (f b)
fktMapChildren f (Div a b)    = Div (f a) (f b)
fktMapChildren f (Abs a)      = Abs (f a)
fktMapChildren f (Exp a b)    = Exp (f a) (f b)
fktMapChildren f (Sinus a)    = Sinus (f a)
fktMapChildren f (Cosinus a)  = Cosinus (f a)
fktMapChildren f (Tangens a)  = Tangens (f a)
fktMapChildren f (ASinus a)   = ASinus (f a)
fktMapChildren f (ACosinus a) = ACosinus (f a)
fktMapChildren f (ATangens a) = ATangens (f a)

-- | Structural equality that, unlike the derived '==', treats a NaN constant
-- as equal to another NaN constant (detected as @x /= x@).  Used by
-- 'evaluate' to recognise its fixpoint.
fktSame :: Eq a => Fkt a -> Fkt a -> Bool
fktSame (Val x) (Val y) = x == y || (x /= x && y /= y)
fktSame (Var m) (Var n) = m == n
fktSame p q = sameNode p q && and (zipWith fktSame (fktChildren p) (fktChildren q))
  where
    -- With all children replaced by one placeholder, only the constructor
    -- (or, for leaves, the leaf itself) is left to compare.
    sameNode u v = blank u == blank v
    blank e = fktMapChildren (const (Var "")) e
