As part of the Ruby Quiz in Haskell solutions appearing on the wiki
recently, I just posted a solution to Ruby Quiz #100 - create a bytecode
interpreter for a simple expression language.

Like I said, the code below parses simple integer arithmetic
expressions and generates byte codes for a hypothetical stack-based
interpreter that would evaluate the expressions. To run it, save it as
a literate haskell file and run "interpret_tests". That just shows
correctness, though. Other output can be obtained by running
"compile_tests" (shows bytes for all tests), "generate_tests"
(symbolic bytecodes for all tests), and "eval_tests" (evaluate ASTs
for all tests).

To see the AST generated for an example expression, try something like
'parse "2-2-2"'.

I'm just learning Haskell (about a month in) and if anyone has time
and desire to critique the code below, I'd love to hear it. I come
from an OOP (primarily C# & Ruby) background, so I'm really interested
in getting a handle on the functional/Haskell "way" of coding. Thanks
for any feedback!

Justin

p.s. This code is also available on the wiki:
http://www.haskell.org/haskellwiki/Haskell_Quiz/Bytecode_Compiler/Solution_Justin_Bailey
p.p.s. The original ruby quiz is available at:
http://www.rubyquiz.com/quiz100.html

\begin{code}
import Text.ParserCombinators.Parsec hiding (parse)
import qualified Text.ParserCombinators.Parsec as P (parse)
import Text.ParserCombinators.Parsec.Expr
import Data.Bits
import Data.Int

-- Arithmetic operators that can label a Statement node of the AST.
-- The parser in this file encodes unary minus as multiplication by -1
-- instead of emitting Neg.
data Op
  = Plus
  | Minus
  | Mult
  | Div
  | Pow
  | Mod
  | Neg
  deriving (Show, Eq)

-- The abstract syntax tree built by the parser. A node is either a
-- binary operation over two sub-expressions (Statement), an integer
-- literal (Val), or the Empty placeholder.
data Expression
  = Statement Op Expression Expression
  | Val Integer
  | Empty
  deriving (Show)

-- Symbolic instructions emitted by the code generator. CONST and LCONST
-- carry the literal to push (serialised by toBytes as 2 and 4 bytes
-- respectively); every other instruction has no operand.
data Bytecode
  = NOOP
  | CONST Integer
  | LCONST Integer
  | ADD
  | SUB
  | MUL
  | POW
  | DIV
  | MOD
  | SWAP
  deriving (Show)


-- Parser for arithmetic expressions, assembled with Parsec's
-- buildExpressionParser from an operator table and a parser for the
-- atoms the operators combine. No whitespace is accepted anywhere.
expr :: Parser Expression
expr = buildExpressionParser table factor <?> "expression"
  where
    -- An atom: a parenthesised sub-expression or an integer literal.
    factor =
      between (char '(') (char ')') expr
        <|> number
        <?> "simple expression"

    -- An unsigned decimal literal. 'read' cannot fail here because
    -- 'many1 digit' only ever yields a non-empty string of digits.
    number :: Parser Expression
    number = fmap (Val . read) (many1 digit) <?> "number"

    -- Operator, associativity and AST constructor for every operator,
    -- grouped by precedence level with the tightest-binding level first.
    -- Unary minus is represented as multiplication by -1.
    table =
      [ [prefix "-" (Statement Mult (Val (-1)))]
      , [binary "^" (Statement Pow) AssocRight]
      , [ binary "*" (Statement Mult) AssocLeft
        , binary "/" (Statement Div) AssocLeft
        , binary "%" (Statement Mod) AssocLeft
        ]
      , [ binary "+" (Statement Plus) AssocLeft
        , binary "-" (Statement Minus) AssocLeft
        ]
      ]
      where
        binary s f assoc = Infix (string s >> return f) assoc
        prefix s f = Prefix (string s >> return f)

-- Parses a string into an AST using 'expr'. The whole input has to be
-- consumed: without the 'eof' check, trailing text the grammar cannot
-- continue with (for example the " + 2" in "2 + 2") would be dropped
-- silently. Calls 'error' with the Parsec message when parsing fails.
parse :: String -> Expression
parse s = either (error . show) id (P.parse wholeInput "" s)
  where
    wholeInput = do
      ast <- expr
      eof
      return ast

-- Evaluates an AST directly (mostly for testing). Division and modulo
-- use 'div' and 'mod'. Empty and Neg have no evaluation rule here, so
-- they are reported with an explicit error instead of an anonymous
-- pattern-match failure.
eval :: Expression -> Integer
eval (Val n) = n
eval Empty = error "eval: cannot evaluate an Empty expression"
eval (Statement op left right) = apply (eval left) (eval right)
  where
    apply = case op of
      Plus -> (+)
      Minus -> (-)
      Mult -> (*)
      Div -> div
      Mod -> mod
      Pow -> (^)
      Neg -> error "eval: Neg is not supported as a binary operator"

-- Turns an AST into a list of symbolic byte codes in postfix order:
-- code for the left operand, code for the right operand, then the
-- operator's instruction.
--
-- A literal becomes CONST only when it fits a signed 16-bit value
-- (-32768 .. 32767), because the interpreter decodes the two operand
-- bytes as a signed 16-bit number; 32768 itself therefore needs LCONST.
-- Empty and Neg have no bytecode and raise an explicit error.
generate :: Expression -> [Bytecode]
generate (Val n)
  | n >= (-32768) && n <= 32767 = [CONST n]
  | otherwise = [LCONST n]
generate Empty = error "generate: cannot compile an Empty expression"
generate (Statement op left right) =
  generate left ++ generate right ++ [instruction]
  where
    instruction = case op of
      Plus -> ADD
      Minus -> SUB
      Mult -> MUL
      Div -> DIV
      Mod -> MOD
      Pow -> POW
      Neg -> error "generate: Neg has no bytecode"

-- Full pipeline from source text to the byte values to be interpreted:
-- parse the statement, generate symbolic byte codes, serialise them.
compile :: String -> [Int]
compile = toBytes . generate . parse

-- Converts a list of byte codes to a list of integer codes. CONST and
-- LCONST are followed by their operand bytes (2 and 4 respectively);
-- every other instruction is a single opcode byte.
-- SWAP is 0xa0; it must not share 0x0a with ADD.
toBytes :: [Bytecode] -> [Int]
toBytes = concatMap encode
  where
    encode NOOP = [0x00]
    encode (CONST n) = 0x01 : toConstBytes (fromInteger n)
    encode (LCONST n) = 0x02 : toLConstBytes (fromInteger n)
    encode ADD = [0x0a]
    encode SUB = [0x0b]
    encode MUL = [0x0c]
    encode POW = [0x0d]
    encode DIV = [0x0e]
    encode MOD = [0x0f]
    encode SWAP = [0xa0]

-- Operand encodings: a CONST operand is a 2-element byte list and an
-- LCONST operand is a 4-element byte list.
toConstBytes :: Int -> [Int]
toConstBytes = toByteList 2

toLConstBytes :: Int -> [Int]
toLConstBytes = toByteList 4

-- Converts a number into a list of 8-bit bytes in big-endian (network)
-- byte order. The result holds the low 'size' bytes of 'n', most
-- significant first; higher-order bytes are discarded. Negative numbers
-- come out in two's complement because 'shiftR' on Int is arithmetic.
toByteList :: Int -> Int -> [Int]
toByteList size =
  reverse . take size . map (.&. 255) . iterate (`shiftR` 8)

-- All tests defined by the quiz, with the associated values they
-- should evaluate to. Each entry pairs the expected value (computed by
-- Haskell itself) with the source text of the expression.
test1 :: [(Integer, String)]
test1 =
  [ (2 + 2, "2+2")
  , (2 - 2, "2-2")
  , (2 * 2, "2*2")
  , (2 ^ 2, "2^2")
  , (2 `div` 2, "2/2")
  , (2 `mod` 2, "2%2")
  , (3 `mod` 2, "3%2")
  ]

test2 :: [(Integer, String)]
test2 =
  [ (2 + 2 + 2, "2+2+2")
  , (2 - 2 - 2, "2-2-2")
  , (2 * 2 * 2, "2*2*2")
  , (2 ^ 2 ^ 2, "2^2^2")
  , (4 `div` 2 `div` 2, "4/2/2")
  , (7 `mod` 2 `mod` 1, "7%2%1")
  ]

test3 :: [(Integer, String)]
test3 =
  [ (2 + 2 - 2, "2+2-2")
  , (2 - 2 + 2, "2-2+2")
  , (2 * 2 + 2, "2*2+2")
  , (2 ^ 2 + 2, "2^2+2")
  , (4 `div` 2 + 2, "4/2+2")
  , (7 `mod` 2 + 1, "7%2+1")
  ]

test4 :: [(Integer, String)]
test4 =
  [ (2 + (2 - 2), "2+(2-2)")
  , (2 - (2 + 2), "2-(2+2)")
  , (2 + (2 * 2), "2+(2*2)")
  , (2 * (2 + 2), "2*(2+2)")
  , (2 ^ (2 + 2), "2^(2+2)")
  , (4 `div` (2 + 2), "4/(2+2)")
  , (7 `mod` (2 + 1), "7%(2+1)")
  ]

-- The negated operand is parenthesised on the Haskell side because
-- "2 * -2" is a precedence error in standard Haskell.
test5 :: [(Integer, String)]
test5 =
  [ (-2 + (2 - 2), "-2+(2-2)")
  , (2 - (-2 + 2), "2-(-2+2)")
  , (2 + (2 * (-2)), "2+(2*-2)")
  ]

test6 :: [(Integer, String)]
test6 =
  [ ((3 `div` 3) + (8 - 2), "(3/3)+(8-2)")
  , ((1 + 3) `div` (2 `div` 2) * (10 - 8), "(1+3)/(2/2)*(10-8)")
  , ((1 * 3) * 4 * (5 * 6), "(1*3)*4*(5*6)")
  , ((10 `mod` 3) * (2 + 2), "(10%3)*(2+2)")
  , (2 ^ (2 + (3 `div` 2) ^ 2), "2^(2+(3/2)^2)")
  , ((10 `div` (2 + 3) * 4), "(10/(2+3)*4)")
  , (5 + ((5 * 4) `mod` (2 + 1)), "5+((5*4)%(2+1))")
  ]

-- Runs every test through parse and eval and reports, per test, whether
-- the evaluated AST matches the expected value. A failure line includes
-- the value that was actually computed.
eval_tests :: [String]
eval_tests = concatMap (map report) [test1, test2, test3, test4, test5, test6]
  where
    report (expected, stmt)
      | expected == actual = "Passed: " ++ stmt
      | otherwise = "Failed: " ++ stmt ++ "(" ++ show actual ++ ")"
      where
        actual = eval (parse stmt)

-- Pairs the source text of every test with the symbolic byte codes
-- generated for it. The expected value of each test is not used here.
generate_tests :: [(String, [Bytecode])]
generate_tests =
  concatMap (map symbolic) [test1, test2, test3, test4, test5, test6]
  where
    symbolic (_, stmt) = (stmt, generate (parse stmt))

-- Pairs the source text of every test with the list of bytes it
-- compiles to. The expected value of each test is not used here.
compile_tests :: [(String, [Int])]
compile_tests =
  [ (stmt, compile stmt)
  | suite <- [test1, test2, test3, test4, test5, test6]
  , (_, stmt) <- suite
  ]

-- Compiles every test to bytes, runs the bytes through the interpreter
-- starting from an empty stack, and reports whether the result matches
-- the expected value. A failure line includes the interpreted value.
interpret_tests :: [String]
interpret_tests =
  [ verdict expected stmt
  | suite <- [test1, test2, test3, test4, test5, test6]
  , (expected, stmt) <- suite
  ]
  where
    verdict expected stmt =
      let actual = fromIntegral (interpret [] (compile stmt))
       in if actual == expected
            then "Passed: " ++ stmt
            else "Failed: " ++ stmt ++ "(" ++ show actual ++ ")"

-- Decodes the first 'n' elements of 'xs' as one big-endian number.
--
-- Widths 2 and 4 are read as signed two's-complement values: the
-- accumulated bits are narrowed to Int16 / Int32 and widened again,
-- which sign-extends them. Without that step a negative 4-byte
-- constant would come back as a large positive number wherever Int is
-- wider than 32 bits. Any other width returns the accumulated bits
-- unchanged. Fewer than 'n' available bytes are decoded as-is, and an
-- empty input yields 0.
fromBytes :: Int -> [Int] -> Int
fromBytes n xs
  | n == 2 = fromIntegral (fromIntegral raw :: Int16)
  | n == 4 = fromIntegral (fromIntegral raw :: Int32)
  | otherwise = raw
  where
    -- Shift the accumulator one byte left and append the next byte.
    raw = foldl (\acc b -> (acc `shiftL` 8) .|. b) 0 (take n xs)

-- Runs a list of byte values on a stack machine and returns the value
-- on top of the stack once the code is exhausted.
--
-- First argument: the initial stack, top first. Second argument: the
-- byte values still to execute.
--
-- Opcodes: 0x00 no-op, 0x01 push a 2-byte constant, 0x02 push a 4-byte
-- constant, 0x0a..0x0f add / sub / mul / pow / div / mod, 0xa0 swap the
-- two topmost values. For a binary operation the top of the stack is
-- the right operand and the value beneath it the left operand.
--
-- Calls 'error' when no result is produced, on an unknown opcode, on a
-- truncated constant, or when the stack is too short for an operation.
interpret :: [Int] -> [Int] -> Int
interpret [] [] = error "no result produced"
interpret (result:_) [] = result
interpret stack (op:rest)
  | op == 0x00 = interpret stack rest
  | op == 0x01 = pushConstant 2
  | op == 0x02 = pushConstant 4
  | op == 0xa0 = case stack of
      (a:b:below) -> interpret (b : a : below) rest
      _ -> underflow
  | otherwise = case (binaryOp op, stack) of
      (Nothing, _) -> error ("interpret: unknown opcode " ++ show op)
      (Just f, rhs:lhs:below) -> interpret (f lhs rhs : below) rest
      (Just _, _) -> underflow
  where
    underflow = error ("interpret: stack underflow at opcode " ++ show op)

    -- Decode the next 'width' bytes as a constant, push it, and continue
    -- after the operand.
    pushConstant width
      | length operand < width =
          error ("interpret: truncated constant after opcode " ++ show op)
      | otherwise =
          interpret (fromBytes width operand : stack) (drop width rest)
      where
        operand = take width rest

    binaryOp :: Int -> Maybe (Int -> Int -> Int)
    binaryOp 0x0a = Just (+)
    binaryOp 0x0b = Just (-)
    binaryOp 0x0c = Just (*)
    binaryOp 0x0d = Just (^)
    binaryOp 0x0e = Just div
    binaryOp 0x0f = Just mod
    binaryOp _ = Nothing

\end{code}
_______________________________________________
Haskell-Cafe mailing list
[email protected]
http://www.haskell.org/mailman/listinfo/haskell-cafe

Reply via email to