week9 tweak
[lambda.git] / code / caml-lambda / lambda.ml
1 (* *)
2
3 module Private =  struct
4     type var_t = int*string
5     let var v = (0, v)
6     let string_of_var (i, v) = v ^ String.make i '\''
7     let equal_var (i1, v1) (i2, v2) = i1 == i2 && (String.compare v1 v2 == 0)
8
9     type lambda_t = [ `Var of var_t | `Lam of var_t * lambda_t | `App of lambda_t * lambda_t ]
10
11 (* DeBruijn terms
12  * substitution and translation algorithms from Chris Hankin, An Introduction to Lambda Calculi for Comptuer Scientists
13  *)
14
15     type debruijn_t = [ `Db_free of var_t | `Db_index of int | `Db_lam of debruijn_t | `Db_app of debruijn_t*debruijn_t ]
16
17     let debruijn_subst (expr : debruijn_t) (m : int) (new_term : debruijn_t) =
18         let rec renumber m i = function
19         | `Db_free _ as term -> term
20         | `Db_index j as term when j < i -> term
21         | `Db_index j -> `Db_index (j + m - 1)
22         | `Db_app(left, right) -> `Db_app(renumber m i left, renumber m i right)
23         | `Db_lam body -> `Db_lam(renumber m (i+1) body)
24         in let rec loop m = function
25         | `Db_free _ as term -> term
26         | `Db_index j as term when j < m -> term
27         | `Db_index j when j > m -> `Db_index (j-1)
28         | `Db_index j -> renumber j 1 new_term
29         | `Db_app(left, right) -> `Db_app(loop m left, loop m right)
30         | `Db_lam body -> `Db_lam(loop (m+1) body)
31         in loop m expr
32
33     let debruijn (expr : lambda_t) : debruijn_t =
34         let pos seq (target : var_t) =
35             let rec loop (i : int) = function
36             | [] -> `Db_free target
37             | x::xs when equal_var x target -> `Db_index i
38             | _::xs -> loop (i+1) xs
39             in loop 1 seq
40         in let rec loop seq = function
41         | `Var v -> pos seq v
42         | `Lam (v, body) -> `Db_lam(loop (v::seq) body)
43         | `App (left, right) -> `Db_app(loop seq left, loop seq right)
44         in loop [] expr
45
46     let rec dbruijn_equal (t1 : debruijn_t) (t2 : debruijn_t) = match (t1, t2) with
47     | (`Db_free v1, `Db_free v2) -> equal_var v1 v2
48     | (`Db_index j1, `Db_index j2) -> j1 == j2
49     | (`Db_app(left1, right1), `Db_app(left2, right2)) -> dbruijn_equal left1 left2 && dbruijn_equal right1 right2
50     | (`Db_lam(body1), `Db_lam(body2)) -> dbruijn_equal body1 body2
51     | _ -> false
52
53     let rec debruijn_contains (t1 : debruijn_t) (t2 : debruijn_t) = match (t1, t2) with
54     | (`Db_free v1, `Db_free v2) -> equal_var v1 v2
55     | (`Db_index j1, `Db_index j2) -> j1 == j2
56     | (`Db_app(left1, right1), `Db_app(left2, right2)) when dbruijn_equal left1 left2 && dbruijn_equal right1 right2 -> true
57     | (`Db_app(left, right), term2) -> debruijn_contains left term2 || debruijn_contains right term2
58     | (`Db_lam(body1), `Db_lam(body2)) when dbruijn_equal body1 body2 -> true
59     | (`Db_lam(body1), term2) -> debruijn_contains body1 term2
60     | _ -> false
61
62
63     (* non-normalizing string_of_lambda *)
64     let string_of_lambda (expr : lambda_t) =
65         let rec top = function
66             | `Var v -> string_of_var v
67             | `Lam _ as term -> "fun " ^ dotted term
68             | `App ((`App _ as left), right) -> top left ^ " " ^ atom right
69             | `App (left, right) -> atom left ^ " " ^ atom right
70         and atom = function
71             | `Var v -> string_of_var v
72             | `Lam _ as term -> "(fun " ^ dotted term ^ ")"
73             | `App _ as term -> "(" ^ top term ^ ")"
74         and dotted = function
75             | `Lam (v, (`Lam _ as body)) -> (string_of_var v) ^ " " ^ dotted body
76             | `Lam (v, body) -> (string_of_var v) ^ " -> " ^ top body
77         in top expr
78
79 (*
80  * substitution and normal-order evaluator based on Haskell version by Oleg Kisleyov
81  * http://okmij.org/ftp/Computation/lambda-calc.html#lambda-calculator-haskell
82  *)
83
84 (* if v occurs free_in term, returns Some v' where v' is the highest-tagged
85  * variable with the same name as v occurring (free or bound) in term
86  *)
87     let free_in ((tag, name) as v) term =
88         let rec loop = function
89         | `Var((tag', name') as v') ->
90                 if name <> name' then false, v
91                 else if tag = tag' then true, v
92                 else false, v'
93         | `App(left, right) ->
94                 let left_bool, ((left_tag, _) as left_v) = loop left in
95                 let right_bool, ((right_tag, _) as right_v) = loop right in
96                 left_bool || right_bool, if left_tag > right_tag then left_v else right_v
97         | `Lam(v', _) when equal_var v v' -> (false, v)
98         | `Lam(_, body) -> loop body
99         in match loop term with
100         | false, _ -> None
101         | true, v -> Some v
102
103     let rec subst v new_term term = match new_term with
104         | `Var v' when equal_var v v' -> term
105         | _ -> (match term with
106             | `Var v' when equal_var v v' -> new_term
107             | `Var _ -> term
108             | `App(left, right) -> `App(subst v new_term left, subst v new_term right)
109             | `Lam(v', _) when equal_var v v' -> term
110             (* if x is free in the inserted term new_term, a capture is possible *)
111             | `Lam(v', body) ->
112                     (match free_in v' new_term with
113                     (* v' not free in new_term, can substitute new_term for v without any captures *)
114                     | None -> `Lam(v', subst v new_term body)
115                     (* v' free in new_term, need to alpha-convert *)
116                     | Some max_x ->  
117                         let bump_tag (tag, name) (tag', _) =
118                             (max tag tag') + 1, name in
119                         let bump_tag' ((_, name) as v1) ((_, name') as v2) =
120                             if (String.compare name name' == 0) then bump_tag v1 v2 else v1 in
121                         (* bump v' > max_x from new_term, then check whether
122                          * it also needs to be bumped > v
123                          *)
124                         let uniq_x = bump_tag' (bump_tag v' max_x) v in
125                         let uniq_x' = (match free_in uniq_x body with
126                             | None -> uniq_x
127                             (* bump uniq_x > max_x' from body *)
128                             | Some max_x' -> bump_tag uniq_x max_x'
129                         ) in
130                         (* alpha-convert body *)
131                         let body' = subst v' (`Var uniq_x') body in
132                         (* now substitute new_term for v *)
133                         `Lam(uniq_x', subst v new_term body')
134                     )
135         )
136
137     let check_eta = function
138         | `Lam(v, `App(body, `Var u)) when equal_var v u && free_in v body = None -> body
139         | (_ : lambda_t) as term -> term
140
141
142
143
144     exception Lambda_looping;;
145
146     let eval ?(eta=false) (expr : lambda_t) : lambda_t =
147         let rec looping (body : debruijn_t) = function
148           | [] -> false
149         | x::xs when dbruijn_equal body x -> true
150         | _::xs -> looping body xs
151         in let rec loop (stack : lambda_t list) (body : lambda_t) = 
152             match body with
153             | `Var v as term -> unwind term stack
154             | `App(left, right) -> loop (right::stack) left
155             | `Lam(v, body) -> (match stack with
156                 | [] ->
157                     let term = (`Lam(v, loop [] body)) in
158                         if eta then check_eta term else term
159                 | x::xs -> loop xs (subst v x body)
160             )
161         and unwind left = function
162         | [] -> left
163         | x::xs -> unwind (`App(left, loop [] x)) xs
164         in loop [] expr
165
166
167     let cbv ?(aggressive=true) (expr : lambda_t) : lambda_t =
168         let rec loop = function
169         | `Var v as term -> term
170         | `App(left, right) ->
171                 let right' = loop right in
172                 (match loop left with
173                 | `Lam(v, body) -> loop (subst v right' body)
174                 | _ as left' -> `App(left', right')
175                 )
176         | `Lam(v, body) as term ->
177                 if aggressive then `Lam(v, loop body)
178                 else term
179         in loop expr
180
181
182
183
184
185     (*
186     
187      (* (Oleg's version of) Ken's evaluator; doesn't seem to work -- requires laziness? *)
188     let eval' ?(eta=false) (expr : lambda_t) : lambda_t =
189         let rec loop = function
190         | `Var v as term -> term
191         | `Lam(v, body) ->
192                 let term = (`Lam(v, loop body)) in
193                     if eta then check_eta term else term
194         | `App(`App _ as left, right) ->
195             (match loop left with
196                 | `Lam _ as redux -> loop (`App(redux, right))
197                 | nonred_head -> `App(nonred_head, loop right)
198             )
199         | `App(left, right) -> `App(left, loop right)
200         in loop expr
201
202
203         module Sorted = struct
204             let rec cons y = function
205                 | x :: _ as xs when x = y -> xs
206                 | x :: xs when x < y -> x :: cons y xs
207                 | xs [* [] or x > y *] -> y :: xs
208
209             let rec mem y = function
210                 | x :: _ when x = y -> true
211                 | x :: xs when x < y -> mem y xs
212                 | _ [* [] or x > y *] -> false
213
214             let rec remove y = function
215                 | x :: xs when x = y -> xs
216                 | x :: xs when x < y -> x :: remove y xs
217                 | xs [* [] or x > y *] -> xs
218
219             let rec merge x' y' = match x', y' with
220                 | [], ys -> ys
221                 | xs, [] -> xs
222                 | x::xs, y::ys ->
223                     if x < y then x :: merge xs y'
224                     else if x = y then x :: merge xs ys
225                     else [* x > y *] y :: merge x' ys
226         end
227
228         let free_vars (expr : lambda_t) : string list =
229             let rec loop = function
230                 | `Var x -> [x]
231                 | `Lam(x, t) -> Sorted.remove x (loop t)
232                 | `App(t1, t2) -> Sorted.merge (loop t1) (loop t2)
233             in loop expr
234
235         let free_in v (expr : lambda_t) =
236             Sorted.mem v (free_vars t)
237
238         let new_var =
239             let counter = ref 0 in
240             fun () -> (let z = !counter in incr counter; "_v"^(string_of_int z))
241
242         ...
243         | `Lam(x, body) as term when not (free_in v body) -> term
244         | `Lam(y, body) when not (free_in y new_term) -> `Lam(y, subst v new_term body)
245         | `Lam(y, body) ->
246             let z = new_var () in
247             subst v new_term (`Lam(z, subst y (`Var z) body))
248     *)
249
250
251
252     (*
253
254     let bound_vars (expr : lambda_t) : string list =
255         let rec loop = function
256             | `Var x -> []
257             | `Lam(x, t) -> Sorted.cons x (loop t)
258             | `App(t1, t2) -> Sorted.merge (loop t1) (loop t2)
259         in loop expr
260
261     let reduce_cbv ?(aggressive=true) (expr : lambda_t) : lambda_t =
262         let rec loop = function
263         | `Var x as term -> term
264         | `App(t1, t2) ->
265                 let t2' = loop t2 in
266                 (match loop t1 with
267                 | `Lam(x, t) -> loop (subst x t2' t)
268                 | _ as term -> `App(term, t2')
269                 )
270         | `Lam(x, t) as term ->
271                 if aggressive then `Lam(x, loop t)
272                 else term
273         in loop expr
274
275     let reduce_cbn (expr : lambda_t) : lambda_t =
276         let rec loop = function
277         | `Var x as term -> term
278         | `Lam(v, body) ->
279                 check_eta (`Lam(v, loop body))
280         | `App(t1, t2) ->
281                 (match loop t1 with
282                 | `Lam(x, t) -> loop (subst x t2 t)
283                 | _ as term -> `App(term, loop t2)
284                 )
285         in loop expr
286
287     *)
288
289
290     (*
291
292     type env_t = (string * lambda_t) list
293
294     let subst body x value =
295         ((fun env ->
296             let new_env = (x, value) :: env in
297             body new_env) : env_t -> lambda_t)
298
299     type strategy_t = By_value | By_name
300
301     let eval (strategy : strategy_t) (expr : lambda_t) : lambda_t =
302         in let rec inner = function
303             | `Var x as t ->
304                 (fun env ->
305                     try List.assoc x env with
306                     | Not_found -> t)
307             | `App(t1, value) -> 
308                 (fun env ->
309                     let value' =
310                         if strategy = By_value then inner value env else value in
311                     (match inner t1 env with
312                     | `Lam(x, body) ->
313                         let body' = (subst (inner body) x value' env) in
314                         if strategy = By_value then body' else inner body' env
315                     | (t1' : lambda_t) -> `App(t1', inner value env)
316                     )
317                 )
318             | `Lam(x, body) ->
319                 (fun env ->
320                     let v = new_var () in
321                     `Lam(v, inner body ((x, `Var v) :: env)))
322         in inner expr ([] : env_t)
323
324     let pp_env env =
325         let rec loop acc = function
326             | [] -> acc
327             | (x, term)::es -> loop ((x ^ "=" ^ string_of_lambda term) :: acc) es
328         in "[" ^ (String.concat ", " (loop [] (List.rev env))) ^ "]"
329
330     let eval (strategy : strategy_t) (expr : lambda_t) : lambda_t =
331         let new_var =
332             let counter = ref 0 in
333             fun () -> (let z = !counter in incr counter; "_v"^(string_of_int z))
334         in let rec inner term =
335             begin
336             Printf.printf "starting [ %s ]\n" (string_of_lambda term);
337             let res = match term with
338             | `Var x as t ->
339                 (fun env ->
340                     try List.assoc x env with
341                     | Not_found -> t)
342             | `App(t1, value) -> 
343                 (fun env ->
344                     let value' =
345                         if strategy = By_value then inner value env else value in
346                     (match inner t1 env with
347                     | `Lam(x, body) ->
348                         let body' = (subst (inner body) x value' env) in
349                         if strategy = By_value then body' else inner body' env
350                     | (t1' : lambda_t) -> `App(t1', inner value env)
351                     )
352                 )
353             | `Lam(x, body) ->
354                 (fun env ->
355                     let v = new_var () in
356                     `Lam(v, inner body ((x, `Var v) :: env)))
357             in
358             (fun env -> 
359                 (Printf.printf "%s with %s => %s\n" (string_of_lambda term) (pp_env env) (string_of_lambda (res env)); res env))
360             end
361         in inner expr ([] : env_t)
362
363     *)
364
365     let normal ?(eta=false) expr = eval ~eta expr
366
367     let normal_string_of_lambda ?(eta=false) (expr : lambda_t) =
368         string_of_lambda (normal ~eta expr)
369
370     let rec to_int expr = match expr with
371         | `Lam(s, `Lam(z, `Var z')) when z' = z -> 0
372         | `Lam(s, `Var s') when equal_var s s' -> 1
373         | `Lam(s, `Lam(z, `App (`Var s', t))) when s' = s -> 1 + to_int (`Lam(s, `Lam(z, t)))
374         | _ -> failwith (normal_string_of_lambda expr ^ " is not a church numeral")
375
376     let int_of_lambda ?(eta=false) (expr : lambda_t) =
377         to_int (normal ~eta expr)
378
379 end
380
381 type lambda_t = Private.lambda_t
382 open Private
383 let var = var
384 let pp, pn, pi = string_of_lambda, normal_string_of_lambda, int_of_lambda
385 let pnv, piv= (fun expr -> string_of_lambda (cbv expr)), (fun expr -> to_int (cbv expr))
386 let debruijn, dbruijn_equal, debruijn_contains = debruijn, dbruijn_equal, debruijn_contains
387
388 let alpha_eq x y = dbruijn_equal (debruijn x) (debruijn y)
389