tweak lambda evaluator
[lambda.git] / code / lambda.ml
index 146c9eb..1eaca65 100644 (file)
 
 module Private =  struct
     type var_t = int*string
-    let var v = (0,v)
-    let string_of_var (i,v) = v ^ String.make i '\''
-    let equal_var (i1,v1) (i2,v2) = i1 == i2 && (String.compare v1 v2 == 0)
+    let var v = (0, v)
+    let string_of_var (i, v) = v ^ String.make i '\''
+    let equal_var (i1, v1) (i2, v2) = i1 == i2 && (String.compare v1 v2 == 0)
 
     type lambda_t = [ `Var of var_t | `Lam of var_t * lambda_t | `App of lambda_t * lambda_t ]
 
-    type debruijn_t = [ `Var of var_t | `DVar of int | `DLam of debruijn_t | `DApp of debruijn_t*debruijn_t ]
+(* DeBruijn terms
+ * substitution and translation algorithms from Chris Hankin, An Introduction to Lambda Calculi for Comptuer Scientists
+ *)
 
-    let db_subst (expr : debruijn_t) (m : int) (repl : debruijn_t) =
-        let rec rename m i = function
-        | `Var _ as term -> term
-        | `DVar j as term when j < i -> term
-        | `DVar j -> `DVar (j + m - 1)
-        | `DApp(n1,n2) -> `DApp(rename m i n1, rename m i n2)
-        | `DLam n -> `DLam(rename m (i+1) n)
+    type debruijn_t = [ `Db_free of var_t | `Db_index of int | `Db_lam of debruijn_t | `Db_app of debruijn_t*debruijn_t ]
+
+    let debruijn_subst (expr : debruijn_t) (m : int) (new_term : debruijn_t) =
+        let rec renumber m i = function
+        | `Db_free _ as term -> term
+        | `Db_index j as term when j < i -> term
+        | `Db_index j -> `Db_index (j + m - 1)
+        | `Db_app(left, right) -> `Db_app(renumber m i left, renumber m i right)
+        | `Db_lam body -> `Db_lam(renumber m (i+1) body)
         in let rec loop m = function
-        | `Var _ as term -> term
-        | `DVar n as term when n < m -> term
-        | `DVar n when n > m -> `DVar (n-1)
-        | `DVar n -> rename n 1 repl
-        | `DApp(m1,m2) -> `DApp(loop m m1, loop m m2)
-        | `DLam mterm -> `DLam(loop (m+1) mterm)
+        | `Db_free _ as term -> term
+        | `Db_index j as term when j < m -> term
+        | `Db_index j when j > m -> `Db_index (j-1)
+        | `Db_index j -> renumber j 1 new_term
+        | `Db_app(left, right) -> `Db_app(loop m left, loop m right)
+        | `Db_lam body -> `Db_lam(loop (m+1) body)
         in loop m expr
 
-    let db (expr : lambda_t) : debruijn_t =
-        let pos seq (target : var_t) handler default =
+    let debruijn (expr : lambda_t) : debruijn_t =
+        let pos seq (target : var_t) =
             let rec loop (i : int) = function
-            | [] -> default
-            | x::xs when equal_var x target -> handler i
+            | [] -> `Db_free target
+            | x::xs when equal_var x target -> `Db_index i
             | _::xs -> loop (i+1) xs
             in loop 1 seq
         in let rec loop seq = function
-        | `Var v as term -> pos seq v (fun i -> `DVar i) term
-        | `Lam (v,t) -> `DLam(loop (v::seq) t)
-        | `App (t1,t2) -> `DApp(loop seq t1, loop seq t2)
+        | `Var v -> pos seq v
+        | `Lam (v, body) -> `Db_lam(loop (v::seq) body)
+        | `App (left, right) -> `Db_app(loop seq left, loop seq right)
         in loop [] expr
 
-    let rec db_equal (t1 : debruijn_t) (t2 : debruijn_t) = match (t1,t2) with
-    | (`Var v1,`Var v2) -> equal_var v1 v2
-    | (`DVar i1, `DVar i2) -> i1 == i2
-    | (`DApp(m1,m2),`DApp(n1,n2)) -> db_equal m1 n1 && db_equal m2 n2
-    | (`DLam(t1),`DLam(t2)) -> db_equal t1 t2
+    let rec dbruijn_equal (t1 : debruijn_t) (t2 : debruijn_t) = match (t1, t2) with
+    | (`Db_free v1, `Db_free v2) -> equal_var v1 v2
+    | (`Db_index j1, `Db_index j2) -> j1 == j2
+    | (`Db_app(left1, right1), `Db_app(left2, right2)) -> dbruijn_equal left1 left2 && dbruijn_equal right1 right2
+    | (`Db_lam(body1), `Db_lam(body2)) -> dbruijn_equal body1 body2
     | _ -> false
 
-    let rec db_contains (t1 : debruijn_t) (t2 : debruijn_t) = match (t1,t2) with
-    | (`Var v1,`Var v2) -> equal_var v1 v2
-    | (`DVar i1, `DVar i2) -> i1 == i2
-    | (`DApp(m1,m2),`DApp(n1,n2)) when db_equal m1 n1 && db_equal m2 n2 -> true
-    | (`DApp(m1,m2), term) -> db_contains m1 term || db_contains m2 term
-    | (`DLam(t1),`DLam(t2)) when db_equal t1 t2 -> true
-    | (`DLam(t1), term) -> db_contains t1 term
+    let rec debruijn_contains (t1 : debruijn_t) (t2 : debruijn_t) = match (t1, t2) with
+    | (`Db_free v1, `Db_free v2) -> equal_var v1 v2
+    | (`Db_index j1, `Db_index j2) -> j1 == j2
+    | (`Db_app(left1, right1), `Db_app(left2, right2)) when dbruijn_equal left1 left2 && dbruijn_equal right1 right2 -> true
+    | (`Db_app(left, right), term2) -> debruijn_contains left term2 || debruijn_contains right term2
+    | (`Db_lam(body1), `Db_lam(body2)) when dbruijn_equal body1 body2 -> true
+    | (`Db_lam(body1), term2) -> debruijn_contains body1 term2
     | _ -> false
 
+
     (* non-normalizing string_of_lambda *)
     let string_of_lambda (expr : lambda_t) =
         let rec top = function
             | `Var v -> string_of_var v
-            | `Lam _ as t -> "fun " ^ funct t
-            | `App ((`App _ as t1),t2) -> top t1 ^ " " ^ atom t2
-            | `App (t1,t2) -> atom t1 ^ " " ^ atom t2
+            | `Lam _ as term -> "fun " ^ dotted term
+            | `App ((`App _ as left), right) -> top left ^ " " ^ atom right
+            | `App (left, right) -> atom left ^ " " ^ atom right
         and atom = function
             | `Var v -> string_of_var v
-            | `Lam _ as t -> "(fun " ^ funct t ^ ")"
-            | `App _ as t -> "(" ^ top t ^ ")"
-        and funct = function
-            | `Lam (v,(`Lam _ as t)) -> (string_of_var v) ^ " " ^ funct t
-            | `Lam (v,t) -> (string_of_var v) ^ " -> " ^ top t
+            | `Lam _ as term -> "(fun " ^ dotted term ^ ")"
+            | `App _ as term -> "(" ^ top term ^ ")"
+        and dotted = function
+            | `Lam (v, (`Lam _ as body)) -> (string_of_var v) ^ " " ^ dotted body
+            | `Lam (v, body) -> (string_of_var v) ^ " -> " ^ top body
         in top expr
 
+(*
+ * substitution and normal-order evaluator based on Haskell version by Oleg Kisleyov
+ * http://okmij.org/ftp/Computation/lambda-calc.html#lambda-calculator-haskell
+ *)
 
-    (* evaluator based on http://okmij.org/ftp/Haskell/Lambda_calc.lhs *)
-
-    (* if v occurs free_in term, returns Some v' where v' is the highest-tagged
-     * variable with the same name as v occurring (free or bound) in term *)
-
+(* if v occurs free_in term, returns Some v' where v' is the highest-tagged
+ * variable with the same name as v occurring (free or bound) in term
+ *)
     let free_in ((tag, name) as v) term =
         let rec loop = function
         | `Var((tag', name') as v') ->
                 if name <> name' then false, v
                 else if tag = tag' then true, v
                 else false, v'
-        | `App(t1, t2) ->
-                let b1, ((tag1, _) as v1) = loop t1 in
-                let b2, ((tag2, _) as v2) = loop t2 in
-                b1 || b2, if tag1 > tag2 then v1 else v2
-        | `Lam(x, _) when x = v -> (false, v)
+        | `App(left, right) ->
+                let left_bool, ((left_tag, _) as left_v) = loop left in
+                let right_bool, ((right_tag, _) as right_v) = loop right in
+                left_bool || right_bool, if left_tag > right_tag then left_v else right_v
+        | `Lam(v', _) when equal_var v v' -> (false, v)
         | `Lam(_, body) -> loop body
         in match loop term with
         | false, _ -> None
         | true, v -> Some v
 
-    let rec subst v st = function
-        | term when st = `Var v -> term
-        | `Var x when x = v -> st
-        | `Var _ as term -> term
-        | `App(t1,t2) -> `App(subst v st t1, subst v st t2)
-        | `Lam(x, _) as term when x = v -> term
-        (* if x is free in the inserted term st, a capture is possible
-         * we handle by ...
-         *)
-        | `Lam(x, body) ->
-                (match free_in x st with
-                (* x not free in st, can substitute st for v without any captures *)
-                | None -> `Lam(x, subst v st body)
-                (* x free in st, need to alpha-convert `Lam(x, body) *)
-                | Some max_x ->  
-                    let bump_tag (tag, name) (tag', _) =
-                        (max tag tag') + 1, name in
-                    let bump_tag' ((_, name) as v1) ((_, name') as v2) =
-                        if name = name' then bump_tag v1 v2 else v1 in
-                    (* bump x > max_x from st, then check whether
-                     * it also needs to be bumped > v
-                     *)
-                    let uniq_x = bump_tag' (bump_tag x max_x) v in
-                    let uniq_x' = (match free_in uniq_x body with
-                        | None -> uniq_x
-                        (* bump uniq_x > max_x' from body *)
-                        | Some max_x' -> bump_tag uniq_x max_x'
-                    ) in
-                    (* alpha-convert body *)
-                    let body' = subst x (`Var uniq_x') body in
-                    (* now substitute st for v *)
-                    `Lam(uniq_x', subst v st body')
-                )
+    let rec subst v new_term term = match new_term with
+        | `Var v' when equal_var v v' -> term
+        | _ -> (match term with
+            | `Var v' when equal_var v v' -> new_term
+            | `Var _ -> term
+            | `App(left, right) -> `App(subst v new_term left, subst v new_term right)
+            | `Lam(v', _) when equal_var v v' -> term
+            (* if x is free in the inserted term new_term, a capture is possible *)
+            | `Lam(v', body) ->
+                    (match free_in v' new_term with
+                    (* v' not free in new_term, can substitute new_term for v without any captures *)
+                    | None -> `Lam(v', subst v new_term body)
+                    (* v' free in new_term, need to alpha-convert *)
+                    | Some max_x ->  
+                        let bump_tag (tag, name) (tag', _) =
+                            (max tag tag') + 1, name in
+                        let bump_tag' ((_, name) as v1) ((_, name') as v2) =
+                            if (String.compare name name' == 0) then bump_tag v1 v2 else v1 in
+                        (* bump v' > max_x from new_term, then check whether
+                         * it also needs to be bumped > v
+                         *)
+                        let uniq_x = bump_tag' (bump_tag v' max_x) v in
+                        let uniq_x' = (match free_in uniq_x body with
+                            | None -> uniq_x
+                            (* bump uniq_x > max_x' from body *)
+                            | Some max_x' -> bump_tag uniq_x max_x'
+                        ) in
+                        (* alpha-convert body *)
+                        let body' = subst v' (`Var uniq_x') body in
+                        (* now substitute new_term for v *)
+                        `Lam(uniq_x', subst v new_term body')
+                    )
+        )
 
     let check_eta = function
-        | `Lam(v, `App(t, `Var u)) when v = u && free_in v t = None -> t
+        | `Lam(v, `App(body, `Var u)) when equal_var v u && free_in v body = None -> body
         | (_ : lambda_t) as term -> term
 
+
+
+
     exception Lambda_looping;;
 
     let eval ?(eta=false) (expr : lambda_t) : lambda_t =
         let rec looping (body : debruijn_t) = function
-        | [] -> false
-        | x::xs when db_equal body x -> true
+          | [] -> false
+        | x::xs when dbruijn_equal body x -> true
         | _::xs -> looping body xs
         in let rec loop (stack : lambda_t list) (body : lambda_t) = 
             match body with
             | `Var v as term -> unwind term stack
-            | `App(t1, t2) as term -> loop (t2::stack) t1
+            | `App(left, right) -> loop (right::stack) left
             | `Lam(v, body) -> (match stack with
                 | [] ->
                     let term = (`Lam(v, loop [] body)) in
                         if eta then check_eta term else term
-                | t::rest -> loop rest (subst v t body)
+                | x::xs -> loop xs (subst v x body)
             )
-        and unwind t1 = function
-        | [] -> t1
-        | t2::ts -> unwind (`App(t1, loop [] t2)) ts
+        and unwind left = function
+        | [] -> left
+        | x::xs -> unwind (`App(left, loop [] x)) xs
         in loop [] expr
 
 
-    (* (Oleg's version of) Ken's evaluator; doesn't seem to work -- requires laziness? *)
+    let cbv ?(aggressive=true) (expr : lambda_t) : lambda_t =
+        let rec loop = function
+        | `Var v as term -> term
+        | `App(left, right) ->
+                let right' = loop right in
+                (match loop left with
+                | `Lam(v, body) -> loop (subst v right' body)
+                | _ as left' -> `App(left', right')
+                )
+        | `Lam(v, body) as term ->
+                if aggressive then `Lam(v, loop body)
+                else term
+        in loop expr
+
+
+
 
+
+    (*
+    
+     (* (Oleg's version of) Ken's evaluator; doesn't seem to work -- requires laziness? *)
     let eval' ?(eta=false) (expr : lambda_t) : lambda_t =
         let rec loop = function
         | `Var v as term -> term
         | `Lam(v, body) ->
                 let term = (`Lam(v, loop body)) in
                     if eta then check_eta term else term
-        | `App(`App _ as t1, t2) ->
-            (match loop t1 with
-                | `Lam _ as redux -> loop (`App(redux, t2))
-                | nonred_head -> `App(nonred_head, loop t2)
+        | `App(`App _ as left, right) ->
+            (match loop left with
+                | `Lam _ as redux -> loop (`App(redux, right))
+                | nonred_head -> `App(nonred_head, loop right)
             )
-        | `App(t1, t2) -> `App(t1, loop t2)
-        in loop expr
-
-    let cbv ?(aggressive=true) (expr : lambda_t) : lambda_t =
-        let rec loop = function
-        | `Var x as term -> term
-        | `App(t1,t2) ->
-                let t2' = loop t2 in
-                (match loop t1 with
-                | `Lam(x, t) -> loop (subst x t2' t)
-                | _ as term -> `App(term, t2')
-                )
-        | `Lam(x, t) as term ->
-                if aggressive then `Lam(x, loop t)
-                else term
+        | `App(left, right) -> `App(left, loop right)
         in loop expr
 
 
-
-    (*
         module Sorted = struct
             let rec cons y = function
                 | x :: _ as xs when x = y -> xs
@@ -215,8 +228,8 @@ module Private =  struct
         let free_vars (expr : lambda_t) : string list =
             let rec loop = function
                 | `Var x -> [x]
-                | `Lam(x,t) -> Sorted.remove x (loop t)
-                | `App(t1,t2) -> Sorted.merge (loop t1) (loop t2)
+                | `Lam(x, t) -> Sorted.remove x (loop t)
+                | `App(t1, t2) -> Sorted.merge (loop t1) (loop t2)
             in loop expr
 
         let free_in v (expr : lambda_t) =
@@ -228,10 +241,10 @@ module Private =  struct
 
         ...
         | `Lam(x, body) as term when not (free_in v body) -> term
-        | `Lam(y, body) when not (free_in y st) -> `Lam(y, subst v st body)
+        | `Lam(y, body) when not (free_in y new_term) -> `Lam(y, subst v new_term body)
         | `Lam(y, body) ->
             let z = new_var () in
-            subst v st (`Lam(z, subst y (`Var z) body))
+            subst v new_term (`Lam(z, subst y (`Var z) body))
     *)
 
 
@@ -241,14 +254,14 @@ module Private =  struct
     let bound_vars (expr : lambda_t) : string list =
         let rec loop = function
             | `Var x -> []
-            | `Lam(x,t) -> Sorted.cons x (loop t)
-            | `App(t1,t2) -> Sorted.merge (loop t1) (loop t2)
+            | `Lam(x, t) -> Sorted.cons x (loop t)
+            | `App(t1, t2) -> Sorted.merge (loop t1) (loop t2)
         in loop expr
 
     let reduce_cbv ?(aggressive=true) (expr : lambda_t) : lambda_t =
         let rec loop = function
         | `Var x as term -> term
-        | `App(t1,t2) ->
+        | `App(t1, t2) ->
                 let t2' = loop t2 in
                 (match loop t1 with
                 | `Lam(x, t) -> loop (subst x t2' t)
@@ -264,7 +277,7 @@ module Private =  struct
         | `Var x as term -> term
         | `Lam(v, body) ->
                 check_eta (`Lam(v, loop body))
-        | `App(t1,t2) ->
+        | `App(t1, t2) ->
                 (match loop t1 with
                 | `Lam(x, t) -> loop (subst x t2 t)
                 | _ as term -> `App(term, loop t2)
@@ -305,13 +318,13 @@ module Private =  struct
             | `Lam(x, body) ->
                 (fun env ->
                     let v = new_var () in
-                    `Lam(v, inner body ((x,`Var v) :: env)))
+                    `Lam(v, inner body ((x, `Var v) :: env)))
         in inner expr ([] : env_t)
 
     let pp_env env =
         let rec loop acc = function
             | [] -> acc
-            | (x,term)::es -> loop ((x ^ "=" ^ string_of_lambda term) :: acc) es
+            | (x, term)::es -> loop ((x ^ "=" ^ string_of_lambda term) :: acc) es
         in "[" ^ (String.concat ", " (loop [] (List.rev env))) ^ "]"
 
     let eval (strategy : strategy_t) (expr : lambda_t) : lambda_t =
@@ -340,7 +353,7 @@ module Private =  struct
             | `Lam(x, body) ->
                 (fun env ->
                     let v = new_var () in
-                    `Lam(v, inner body ((x,`Var v) :: env)))
+                    `Lam(v, inner body ((x, `Var v) :: env)))
             in
             (fun env -> 
                 (Printf.printf "%s with %s => %s\n" (string_of_lambda term) (pp_env env) (string_of_lambda (res env)); res env))
@@ -356,7 +369,7 @@ module Private =  struct
 
     let rec to_int expr = match expr with
         | `Lam(s, `Lam(z, `Var z')) when z' = z -> 0
-        | `Lam(s, `Var s') when s = s' -> 1
+        | `Lam(s, `Var s') when equal_var s s' -> 1
         | `Lam(s, `Lam(z, `App (`Var s', t))) when s' = s -> 1 + to_int (`Lam(s, `Lam(z, t)))
         | _ -> failwith (normal_string_of_lambda expr ^ " is not a church numeral")
 
@@ -369,8 +382,8 @@ type lambda_t = Private.lambda_t
 open Private
 let var = var
 let pp, pn, pi = string_of_lambda, normal_string_of_lambda, int_of_lambda
-let pnv,piv= (fun expr -> string_of_lambda (cbv expr)), (fun expr -> to_int (cbv expr))
-let db, db_equal, db_contains = db, db_equal, db_contains
+let pnv, piv= (fun expr -> string_of_lambda (cbv expr)), (fun expr -> to_int (cbv expr))
+let debruijn, dbruijn_equal, debruijn_contains = debruijn, dbruijn_equal, debruijn_contains
 
-let alpha_eq x f = db_equal (db x) (db y)
+let alpha_eq x y = dbruijn_equal (debruijn x) (debruijn y)