monads lib: made all monad types doubly param'd; required plus,zero from all monads
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  *)
46
47
48 (* Some library functions used below. *)
49 module Util = struct
50   let fold_right = List.fold_right
51   let map = List.map
52   let append = List.append
53   let reverse = List.rev
54   let concat = List.concat
55   let concat_map f lst = List.concat (List.map f lst)
56   (* let zip = List.combine *)
57   let unzip = List.split
58   let zip_with = List.map2
59   let replicate len fill =
60     let rec loop n accu =
61       if n == 0 then accu else loop (pred n) (fill :: accu)
62     in loop len []
63   let undefined = Obj.magic ""
64 end
65
66
67
68 (*
69  * This module contains factories that extend a base set of
70  * monadic definitions with a larger family of standard derived values.
71  *)
72
73 module Monad = struct
74   (*
75    * Signature extenders:
76    *   Make :: BASE -> S
77    *   MakeT :: TRANS (with Wrapped : S) -> custom sig
78    *)
79
80
81   (* type of base definitions *)
82   module type BASE = sig
83     (* The only constraints we impose here on how the monadic type
84      * is implemented is that it have a single type parameter 'a. *)
85     type ('x,'a) m
86     type ('x,'a) result
87     type ('x,'a) result_exn
88     val unit : 'a -> ('x,'a) m
89     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
90     val run : ('x,'a) m -> ('x,'a) result
91     (* run_exn tries to provide a more ground-level result, but may fail *)
92     val run_exn : ('x,'a) m -> ('x,'a) result_exn
93     (* To simplify the library, we require every monad to supply a plus and zero. These obey the following laws:
94      *     zero >>= f   ===  zero
95      *     plus zero u  ===  u
96      *     plus u zero  ===  u
97      * Additionally, they will obey one of the following laws:
98      *     (Catch)   plus (unit a) v  ===  unit a
99      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
100      * When no natural zero is available, use `let zero () = Util.undefined
101      * The Make process automatically detects for zero >>= ..., and 
102      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
103      *)
104     val zero : unit -> ('x,'a) m
105     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
106   end
107   module type S = sig
108     include BASE
109     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
110     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
111     val join : ('x,('x,'a) m) m -> ('x,'a) m
112     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
113     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
114     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
115     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
116     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
117     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
118     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
119     val sequence : ('x,'a) m list -> ('x,'a list) m
120     val sequence_ : ('x,'a) m list -> ('x,unit) m
121     val guard : bool -> ('x,unit) m
122     val sum : ('x,'a) m list -> ('x,'a) m
123   end
124
125   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
126     include B
127     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
128       if u == Util.undefined then Util.undefined
129       else bind u (fun a -> try f a with Match_failure _ -> zero ())
130     let plus u v =
131       if u == Util.undefined then v else if v == Util.undefined then u else plus u v
132     let run u =
133       if u == Util.undefined then failwith "no zero" else run u
134     let run_exn u =
135       if u == Util.undefined then failwith "no zero" else run_exn u
136     let (>>=) = bind
137     let (>>) u v = u >>= fun _ -> v
138     let lift f u = u >>= fun a -> unit (f a)
139     (* lift is called listM, fmap, and <$> in Haskell *)
140     let join uu = uu >>= fun u -> u
141     (* u >>= f === join (lift f u) *)
142     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
143     (* [f] <*> [x1,x2] = [f x1,f x2] *)
144     (* let apply u v = u >>= fun f -> lift f v *)
145     (* let apply = lift2 id *)
146     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
147     (* let lift f u === apply (unit f) u *)
148     (* let lift2 f u v = apply (lift f u) v *)
149     let (>=>) f g = fun a -> f a >>= g
150     let do_when test u = if test then u else unit ()
151     let do_unless test u = if test then unit () else u
152     let forever uthunk =
153         let rec loop () = uthunk () >>= fun _ -> loop ()
154         in loop ()
155     let sequence ms =
156       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
157         Util.fold_right op ms (unit [])
158     let sequence_ ms =
159       Util.fold_right (>>) ms (unit ())
160
161     (* Haskell defines these other operations combining lists and monads.
162      * We don't, but notice that M.mapM == ListT(M).distribute
163      * There's also a parallel TreeT(M).distribute *)
164     (*
165     let mapM f alist = sequence (Util.map f alist)
166     let mapM_ f alist = sequence_ (Util.map f alist)
167     let rec filterM f lst = match lst with
168       | [] -> unit []
169       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
170     let forM alist f = mapM f alist
171     let forM_ alist f = mapM_ f alist
172     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
173     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
174     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
175     let rec foldM f z lst = match lst with
176       | [] -> unit z
177       | x::xs -> f z x >>= fun z' -> foldM f z' xs
178     let foldM_ f z xs = foldM f z xs >> unit ()
179     let replicateM n x = sequence (Util.replicate n x)
180     let replicateM_ n x = sequence_ (Util.replicate n x)
181     *)
182     let guard test = if test then B.unit () else zero ()
183     let sum ms = Util.fold_right plus ms (zero ())
184   end
185
186   (* Signatures for MonadT *)
187   module type TRANS = sig
188     module Wrapped : S
189     type ('x,'a) m
190     type ('x,'a) result
191     type ('x,'a) result_exn
192     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
193     val run : ('x,'a) m -> ('x,'a) result
194     val run_exn : ('x,'a) m -> ('x,'a) result_exn
195     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
196     (* lift/elevate laws:
197      *     elevate (W.unit a) == unit a
198      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
199      *)
200     val zero : unit -> ('x,'a) m
201     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
202   end
203   module MakeT(T : TRANS) = struct
204     include Make(struct
205         include T
206         let unit a = elevate (Wrapped.unit a)
207     end)
208     let elevate = T.elevate
209   end
210
211 end
212
213
214
215
216
217 module Identity_monad : sig
218   (* expose only the implementation of type `'a result` *)
219   type ('x,'a) result = 'a
220   type ('x,'a) result_exn = 'a
221   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
222 end = struct
223   module Base = struct
224     type ('x,'a) m = 'a
225     type ('x,'a) result = 'a
226     type ('x,'a) result_exn = 'a
227     let unit a = a
228     let bind a f = f a
229     let run a = a
230     let run_exn a = a
231     let zero () = Util.undefined
232     let plus u v = u
233   end
234   include Monad.Make(Base)
235 end
236
237
238 module Maybe_monad : sig
239   (* expose only the implementation of type `'a result` *)
240   type ('x,'a) result = 'a option
241   type ('x,'a) result_exn = 'a
242   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
243   (* MaybeT transformer *)
244   module T : functor (Wrapped : Monad.S) -> sig
245     type ('x,'a) result = ('x,'a option) Wrapped.result
246     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
247     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
248     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
249   end
250 end = struct
251   module Base = struct
252     type ('x,'a) m = 'a option
253     type ('x,'a) result = 'a option
254     type ('x,'a) result_exn = 'a
255     let unit a = Some a
256     let bind u f = match u with Some a -> f a | None -> None
257     let run u = u
258     let run_exn u = match u with
259       | Some a -> a
260       | None -> failwith "no value"
261     let zero () = None
262     (* satisfies Catch *)
263     let plus u v = match u with None -> v | _ -> u
264   end
265   include Monad.Make(Base)
266   module T(Wrapped : Monad.S) = struct
267     module Trans = struct
268       include Monad.MakeT(struct
269         module Wrapped = Wrapped
270         type ('x,'a) m = ('x,'a option) Wrapped.m
271         type ('x,'a) result = ('x,'a option) Wrapped.result
272         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
273         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
274         let bind u f = Wrapped.bind u (fun t -> match t with
275           | Some a -> f a
276           | None -> Wrapped.unit None)
277         let run u = Wrapped.run u
278         let run_exn u =
279           let w = Wrapped.bind u (fun t -> match t with
280             | Some a -> Wrapped.unit a
281             | None -> failwith "no value")
282           in Wrapped.run_exn w
283         let zero () = Wrapped.unit None
284         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
285       end)
286     end
287     include Trans
288   end
289 end
290
291
292 module List_monad : sig
293   (* declare additional operation, while still hiding implementation of type m *)
294   type ('x,'a) result = 'a list
295   type ('x,'a) result_exn = 'a
296   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
297   val permute : ('x,'a) m -> ('x,('x,'a) m) m
298   val select : ('x,'a) m -> ('x,'a * ('x,'a) m) m
299   (* ListT transformer *)
300   module T : functor (Wrapped : Monad.S) -> sig
301     type ('x,'a) result = ('x,'a list) Wrapped.result
302     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
303     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
304     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
305     (* note that second argument is an 'a list, not the more abstract 'a m *)
306     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
307     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
308 (* TODO
309     val permute : 'a m -> 'a m m
310     val select : 'a m -> ('a * 'a m) m
311 *)
312   end
313 end = struct
314   module Base = struct
315    type ('x,'a) m = 'a list
316    type ('x,'a) result = 'a list
317    type ('x,'a) result_exn = 'a
318    let unit a = [a]
319    let bind u f = Util.concat_map f u
320    let run u = u
321    let run_exn u = match u with
322      | [] -> failwith "no values"
323      | [a] -> a
324      | many -> failwith "multiple values"
325    let zero () = []
326    let plus = Util.append
327   end
328   include Monad.Make(Base)
329   (* let either u v = plus u v *)
330   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
331   let rec insert a u =
332     plus (unit (a :: u)) (match u with
333         | [] -> zero ()
334         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
335     )
336   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
337   let rec permute u = match u with
338       | [] -> unit []
339       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
340   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
341   let rec select u = match u with
342     | [] -> zero ()
343     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
344   let base_plus = plus
345   module T(Wrapped : Monad.S) = struct
346     (* Wrapped.sequence ms  ===  
347          let plus1 u v =
348            Wrapped.bind u (fun x ->
349            Wrapped.bind v (fun xs ->
350            Wrapped.unit (x :: xs)))
351          in Util.fold_right plus1 ms (Wrapped.unit []) *)
352     (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
353     let distribute f alist = Wrapped.sequence (Util.map f alist)
354
355     include Monad.MakeT(struct
356       module Wrapped = Wrapped
357       type ('x,'a) m = ('x,'a list) Wrapped.m
358       type ('x,'a) result = ('x,'a list) Wrapped.result
359       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
360       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
361       let bind u f =
362         Wrapped.bind u (fun ts ->
363         Wrapped.bind (distribute f ts) (fun tts ->
364         Wrapped.unit (Util.concat tts)))
365       let run u = Wrapped.run u
366       let run_exn u =
367         let w = Wrapped.bind u (fun ts -> match ts with
368           | [] -> failwith "no values"
369           | [a] -> Wrapped.unit a
370           | many -> failwith "multiple values"
371         ) in Wrapped.run_exn w
372       let zero () = Wrapped.unit []
373       let plus u v =
374         Wrapped.bind u (fun us ->
375         Wrapped.bind v (fun vs ->
376         Wrapped.unit (base_plus us vs)))
377     end)
378 (*
379     let permute : 'a m -> 'a m m
380     let select : 'a m -> ('a * 'a m) m
381 *)
382   end
383 end
384
385
386 (* must be parameterized on (struct type err = ... end) *)
387 module Error_monad(Err : sig
388   type err
389   exception Exc of err
390   (*
391   val zero : unit -> err
392   val plus : err -> err -> err
393   *)
394 end) : sig
395   (* declare additional operations, while still hiding implementation of type m *)
396   type err = Err.err
397   type 'a error = Error of err | Success of 'a
398   type ('x,'a) result = 'a error
399   type ('x,'a) result_exn = 'a
400   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
401   val throw : err -> ('x,'a) m
402   val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
403   (* ErrorT transformer *)
404   module T : functor (Wrapped : Monad.S) -> sig
405     type ('x,'a) result = ('x,'a) Wrapped.result
406     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
407     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
408     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
409     val throw : err -> ('x,'a) m
410     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
411   end
412 end = struct
413   type err = Err.err
414   type 'a error = Error of err | Success of 'a
415   module Base = struct
416     type ('x,'a) m = 'a error
417     type ('x,'a) result = 'a error
418     type ('x,'a) result_exn = 'a
419     let unit a = Success a
420     let bind u f = match u with
421       | Success a -> f a
422       | Error e -> Error e (* input and output may be of different 'a types *)
423     let run u = u
424     let run_exn u = match u with
425       | Success a -> a
426       | Error e -> raise (Err.Exc e)
427     let zero () = Util.undefined
428     let plus u v = u
429     (*
430     let zero () = Error Err.zero
431     let plus u v = match (u, v) with
432       | Success _, _ -> u
433       (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
434        * otherwise, plus (Error _) v = v *)
435       | Error _, _ when v = zero -> u
436       (* combine errors *)
437       | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
438       | Error _, _ -> v
439     *)
440   end
441   include Monad.Make(Base)
442   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
443   let throw e = Error e
444   let catch u handler = match u with
445     | Success _ -> u
446     | Error e -> handler e
447   module T(Wrapped : Monad.S) = struct
448     include Monad.MakeT(struct
449       module Wrapped = Wrapped
450       type ('x,'a) m = ('x,'a error) Wrapped.m
451       type ('x,'a) result = ('x,'a) Wrapped.result
452       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
453       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
454       let bind u f = Wrapped.bind u (fun t -> match t with
455         | Success a -> f a
456         | Error e -> Wrapped.unit (Error e))
457       let run u =
458         let w = Wrapped.bind u (fun t -> match t with
459           | Success a -> Wrapped.unit a
460           | Error e -> Wrapped.zero ())
461         in Wrapped.run w
462       let run_exn u =
463         let w = Wrapped.bind u (fun t -> match t with
464           | Success a -> Wrapped.unit a
465           | Error e -> raise (Err.Exc e))
466         in Wrapped.run_exn w
467       let plus u v = Wrapped.plus u v
468       let zero () = elevate (Wrapped.zero ())
469     end)
470     let throw e = Wrapped.unit (Error e)
471     let catch u handler = Wrapped.bind u (fun t -> match t with
472       | Success _ -> Wrapped.unit t
473       | Error e -> handler e)
474   end
475 end
476
477 (* pre-define common instance of Error_monad *)
478 module Failure = Error_monad(struct
479   type err = string
480   exception Exc = Failure
481   (*
482   let zero = ""
483   let plus s1 s2 = s1 ^ "\n" ^ s2
484   *)
485 end)
486
487 (*
488 # EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
489 - : int EL.result = [Failure.Error "bye"; Failure.Success 30]
490 # LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
491 - : int LE.result = Failure.Error "bye"
492 # EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
493 Exception: Failure "bye".
494 # LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
495 Exception: Failure "bye".
496
497 # ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
498 - : int Failure.error * S.store = (Failure.Error "bye", 1)
499 # SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
500 - : (int * S.store) Failure.result = Failure.Error "bye"
501 # ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
502 Exception: Failure "bye".
503 # SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
504 Exception: Failure "bye".
505  *)
506
507
508 (* must be parameterized on (struct type env = ... end) *)
509 module Reader_monad(Env : sig type env end) : sig
510   (* declare additional operations, while still hiding implementation of type m *)
511   type env = Env.env
512   type ('x,'a) result = env -> 'a
513   type ('x,'a) result_exn = env -> 'a
514   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
515   val ask : ('x,env) m
516   val asks : (env -> 'a) -> ('x,'a) m
517   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
518   (* ReaderT transformer *)
519   module T : functor (Wrapped : Monad.S) -> sig
520     type ('x,'a) result = env -> ('x,'a) Wrapped.result
521     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
522     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
523     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
524     val ask : ('x,env) m
525     val asks : (env -> 'a) -> ('x,'a) m
526     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
527   end
528 end = struct
529   type env = Env.env
530   module Base = struct
531     type ('x,'a) m = env -> 'a
532     type ('x,'a) result = env -> 'a
533     type ('x,'a) result_exn = env -> 'a
534     let unit a = fun e -> a
535     let bind u f = fun e -> let a = u e in let u' = f a in u' e
536     let run u = fun e -> u e
537     let run_exn = run
538     let zero () = Util.undefined
539     let plus u v = u
540   end
541   include Monad.Make(Base)
542   let ask = fun e -> e
543   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
544   let local modifier u = fun e -> u (modifier e)
545   module T(Wrapped : Monad.S) = struct
546     module Trans = struct
547       module Wrapped = Wrapped
548       type ('x,'a) m = env -> ('x,'a) Wrapped.m
549       type ('x,'a) result = env -> ('x,'a) Wrapped.result
550       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
551       let elevate w = fun e -> w
552       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
553       let run u = fun e -> Wrapped.run (u e)
554       let run_exn u = fun e -> Wrapped.run_exn (u e)
555       let plus u v = fun s -> Wrapped.plus (u s) (v s)
556       let zero () = elevate (Wrapped.zero ())
557     end
558     include Monad.MakeT(Trans)
559     let ask = fun e -> Wrapped.unit e
560     let local modifier u = fun e -> u (modifier e)
561     let asks selector = ask >>= (fun e ->
562       try unit (selector e)
563       with Not_found -> fun e -> Wrapped.zero ())
564   end
565 end
566
567
568 (* must be parameterized on (struct type store = ... end) *)
569 module State_monad(Store : sig type store end) : sig
570   (* declare additional operations, while still hiding implementation of type m *)
571   type store = Store.store
572   type ('x,'a) result =  store -> 'a * store
573   type ('x,'a) result_exn = store -> 'a
574   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
575   val get : ('x,store) m
576   val gets : (store -> 'a) -> ('x,'a) m
577   val put : store -> ('x,unit) m
578   val puts : (store -> store) -> ('x,unit) m
579   (* StateT transformer *)
580   module T : functor (Wrapped : Monad.S) -> sig
581     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
582     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
583     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
584     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
585     val get : ('x,store) m
586     val gets : (store -> 'a) -> ('x,'a) m
587     val put : store -> ('x,unit) m
588     val puts : (store -> store) -> ('x,unit) m
589   end
590 end = struct
591   type store = Store.store
592   module Base = struct
593     type ('x,'a) m =  store -> 'a * store
594     type ('x,'a) result =  store -> 'a * store
595     type ('x,'a) result_exn = store -> 'a
596     let unit a = fun s -> (a, s)
597     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
598     let run u = fun s -> (u s)
599     let run_exn u = fun s -> fst (u s)
600     let zero () = Util.undefined
601     let plus u v = u
602   end
603   include Monad.Make(Base)
604   let get = fun s -> (s, s)
605   let gets viewer = fun s -> (viewer s, s) (* may fail *)
606   let put s = fun _ -> ((), s)
607   let puts modifier = fun s -> ((), modifier s)
608   module T(Wrapped : Monad.S) = struct
609     module Trans = struct
610       module Wrapped = Wrapped
611       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
612       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
613       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
614       let elevate w = fun s ->
615         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
616       let bind u f = fun s ->
617         Wrapped.bind (u s) (fun (a, s') -> f a s')
618       let run u = fun s -> Wrapped.run (u s)
619       let run_exn u = fun s ->
620         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
621         in Wrapped.run_exn w
622       let plus u v = fun s -> Wrapped.plus (u s) (v s)
623       let zero () = elevate (Wrapped.zero ())
624     end
625     include Monad.MakeT(Trans)
626     let get = fun s -> Wrapped.unit (s, s)
627     let gets viewer = fun s ->
628       try Wrapped.unit (viewer s, s)
629       with Not_found -> Wrapped.zero ()
630     let put s = fun _ -> Wrapped.unit ((), s)
631     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
632   end
633 end
634
635 (* State monad with different interface (structured store) *)
636 module Ref_monad(V : sig
637   type value
638 end) : sig
639   type ref
640   type value = V.value
641   type ('x,'a) result = 'a
642   type ('x,'a) result_exn = 'a
643   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
644   val newref : value -> ('x,ref) m
645   val deref : ref -> ('x,value) m
646   val change : ref -> value -> ('x,unit) m
647   (* RefT transformer *)
648   module T : functor (Wrapped : Monad.S) -> sig
649     type ('x,'a) result = ('x,'a) Wrapped.result
650     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
651     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
652     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
653     val newref : value -> ('x,ref) m
654     val deref : ref -> ('x,value) m
655     val change : ref -> value -> ('x,unit) m
656   end
657 end = struct
658   type ref = int
659   type value = V.value
660   module D = Map.Make(struct type t = ref let compare = compare end)
661   type dict = { next: ref; tree : value D.t }
662   let empty = { next = 0; tree = D.empty }
663   let alloc (value : value) (d : dict) =
664     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
665   let read (key : ref) (d : dict) =
666     D.find key d.tree
667   let write (key : ref) (value : value) (d : dict) =
668     { next = d.next; tree = D.add key value d.tree }
669   module Base = struct
670     type ('x,'a) m = dict -> 'a * dict
671     type ('x,'a) result = 'a
672     type ('x,'a) result_exn = 'a
673     let unit a = fun s -> (a, s)
674     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
675     let run u = fst (u empty)
676     let run_exn = run
677     let zero () = Util.undefined
678     let plus u v = u
679   end
680   include Monad.Make(Base)
681   let newref value = fun s -> alloc value s
682   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
683   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
684   module T(Wrapped : Monad.S) = struct
685     module Trans = struct
686       module Wrapped = Wrapped
687       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
688       type ('x,'a) result = ('x,'a) Wrapped.result
689       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
690       let elevate w = fun s ->
691         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
692       let bind u f = fun s ->
693         Wrapped.bind (u s) (fun (a, s') -> f a s')
694       let run u =
695         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
696         in Wrapped.run w
697       let run_exn u =
698         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
699         in Wrapped.run_exn w
700       let plus u v = fun s -> Wrapped.plus (u s) (v s)
701       let zero () = elevate (Wrapped.zero ())
702     end
703     include Monad.MakeT(Trans)
704     let newref value = fun s -> Wrapped.unit (alloc value s)
705     let deref key = fun s -> Wrapped.unit (read key s, s)
706     let change key value = fun s -> Wrapped.unit ((), write key value s)
707   end
708 end
709
710
711 (* must be parameterized on (struct type log = ... end) *)
712 module Writer_monad(Log : sig
713   type log
714   val zero : log
715   val plus : log -> log -> log
716 end) : sig
717   (* declare additional operations, while still hiding implementation of type m *)
718   type log = Log.log
719   type ('x,'a) result = 'a * log
720   type ('x,'a) result_exn = 'a * log
721   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
722   val tell : log -> ('x,unit) m
723   val listen : ('x,'a) m -> ('x,'a * log) m
724   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
725   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
726   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
727 end = struct
728   type log = Log.log
729   module Base = struct
730     type ('x,'a) m = 'a * log
731     type ('x,'a) result = 'a * log
732     type ('x,'a) result_exn = 'a * log
733     let unit a = (a, Log.zero)
734     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
735     let run u = u
736     let run_exn = run
737     let zero () = Util.undefined
738     let plus u v = u
739   end
740   include Monad.Make(Base)
741   let tell entries = ((), entries) (* add entries to log *)
742   let listen (a, w) = ((a, w), w)
743   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
744   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
745   let censor f u = pass (u >>= fun a -> unit (a, f))
746 end
747
748 (* pre-define simple Writer *)
749 module Writer1 = Writer_monad(struct
750   type log = string
751   let zero = ""
752   let plus s1 s2 = s1 ^ "\n" ^ s2
753 end)
754
755 (* slightly more efficient Writer *)
756 module Writer2 = struct
757   include Writer_monad(struct
758     type log = string list
759     let zero = []
760     let plus w w' = Util.append w' w
761   end)
762   let tell_string s = tell [s]
763   let tell entries = tell (Util.reverse entries)
764   let run u = let (a, w) = run u in (a, Util.reverse w)
765   let run_exn = run
766 end
767
768
769 module IO_monad : sig
770   (* declare additional operation, while still hiding implementation of type m *)
771   type ('x,'a) result = 'a
772   type ('x,'a) result_exn = 'a
773   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
774   val printf : ('a, unit, string, ('x,unit) m) format4 -> 'a
775   val print_string : string -> ('x,unit) m
776   val print_int : int -> ('x,unit) m
777   val print_hex : int -> ('x,unit) m
778   val print_bool : bool -> ('x,unit) m
779 end = struct
780   module Base = struct
781     type ('x,'a) m = { run : unit -> unit; value : 'a }
782     type ('x,'a) result = 'a
783     type ('x,'a) result_exn = 'a
784     let unit a = { run = (fun () -> ()); value = a }
785     let bind (a : ('x,'a) m) (f: 'a -> ('x,'b) m) : ('x,'b) m =
786      let fres = f a.value in
787        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
788     let run a = let () = a.run () in a.value
789     let run_exn = run
790     let zero () = Util.undefined
791     let plus u v = u
792   end
793   include Monad.Make(Base)
794   let printf fmt =
795     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
796   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
797   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
798   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
799   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
800 end
801
802
803 module Continuation_monad : sig
804   (* expose only the implementation of type `('r,'a) result` *)
805   type ('r,'a) m
806   type ('r,'a) result = ('r,'a) m
807   type ('r,'a) result_exn = ('a -> 'r) -> 'r
808   include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
809   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
810   val reset : ('a,'a) m -> ('r,'a) m
811   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
812   (* val abort : ('a,'a) m -> ('a,'b) m *)
813   val abort : 'a -> ('a,'b) m
814   val run0 : ('a,'a) m -> 'a
815 end = struct
816   let id = fun i -> i
817   module Base = struct
818     (* 'r is result type of whole computation *)
819     type ('r,'a) m = ('a -> 'r) -> 'r
820     type ('r,'a) result = ('a -> 'r) -> 'r
821     type ('r,'a) result_exn = ('r,'a) result
822     let unit a = (fun k -> k a)
823     let bind u f = (fun k -> (u) (fun a -> (f a) k))
824     let run u k = (u) k
825     let run_exn = run
826     let zero () = Util.undefined
827     let plus u v = u
828   end
829   include Monad.Make(Base)
830   let callcc f = (fun k ->
831     let usek a = (fun _ -> k a)
832     in (f usek) k)
833   (*
834   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
835   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
836   let callcc f = fun k -> f k k
837   let throw k a = fun _ -> k a
838   *)
839
840   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
841    *
842    *  reset :: (Monad m) => ContT a m a -> ContT r m a
843    *  reset e = ContT $ \k -> runContT e return >>= k
844    *
845    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
846    *  shift e = ContT $ \k ->
847    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
848   let reset u = unit ((u) id)
849   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
850   (* let abort a = shift (fun _ -> a) *)
851   let abort a = shift (fun _ -> unit a)
852   let run0 (u : ('a,'a) m) = (u) id
853 end
854
855
856 (*
857  * Scheme:
858  * (define (example n)
859  *    (let ([u (let/cc k ; type int -> int pair
860  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
861  *                 (+ 1 (car v))))]) ; int
862  *      (cons u 0))) ; int pair
863  * ; (example 10) ~~> '(111 . 0)
864  * ; (example -10) ~~> '(0 . 0)
865  *
866  * OCaml monads:
867  * let example n : (int * int) =
868  *   Continuation_monad.(let u = callcc (fun k ->
869  *       (if n < 0 then k 0 else unit [n + 100])
870  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
871  *       >>= fun [x] -> unit (x + 1)
872  *   )
873  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
874  *   >>= fun x -> unit (x, 0)
875  *   in run u)
876  *
877  *
878  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
879  * let example1 () : int =
880  *   Continuation_monad.(let v = reset (
881  *       let u = shift (fun k -> unit (10 + 1))
882  *       in u >>= fun x -> unit (100 + x)
883  *     ) in let w = v >>= fun x -> unit (1000 + x)
884  *     in run w)
885  *
886  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
887  * let example2 () =
888  *   Continuation_monad.(let v = reset (
889  *       let u = shift (fun k -> k (10 :: [1]))
890  *       in u >>= fun x -> unit (100 :: x)
891  *     ) in let w = v >>= fun x -> unit (1000 :: x)
892  *     in run w)
893  *
894  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
895  * let example3 () =
896  *   Continuation_monad.(let v = reset (
897  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
898  *       in u >>= fun x -> unit (100 :: x)
899  *     ) in let w = v >>= fun x -> unit (1000 :: x)
900  *     in run w)
901  *
902  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
903  * (* not sure if this example can be typed without a sum-type *)
904  *
905  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
906  * let example5 () : int =
907  *   Continuation_monad.(let v = reset (
908  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
909  *       in u >>= fun x -> unit (10 + x)
910  *     ) in let w = v >>= fun x -> unit (100 + x)
911  *     in run w)
912  *
913  *)
914
915
916 module Leaf_monad : sig
917   (* We implement the type as `'a tree option` because it has a natural`plus`,
918    * and the rest of the library expects that `plus` and `zero` will come together. *)
919   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
920   type ('x,'a) result = 'a tree option
921   type ('x,'a) result_exn = 'a tree
922   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
923   (* LeafT transformer *)
924   module T : functor (Wrapped : Monad.S) -> sig
925     type ('x,'a) result = ('x,'a tree option) Wrapped.result
926     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
927     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
928     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
929     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
930     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
931     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
932   end
933 end = struct
934   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
935   (* uses supplied plus and zero to copy t to its image under f *)
936   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
937       | None -> zero ()
938       | Some ts -> let rec loop ts = (match ts with
939                      | Leaf a -> f a
940                      | Node (l, r) ->
941                          (* recursive application of f may delete a branch *)
942                          plus (loop l) (loop r)
943                    ) in loop ts
944   module Base = struct
945     type ('x,'a) m = 'a tree option
946     type ('x,'a) result = 'a tree option
947     type ('x,'a) result_exn = 'a tree
948     let unit a = Some (Leaf a)
949     let zero () = None
950     let plus u v = match (u, v) with
951       | None, _ -> v
952       | _, None -> u
953       | Some us, Some vs -> Some (Node (us, vs))
954     let bind u f = mapT f u zero plus
955     let run u = u
956     let run_exn u = match u with
957       | None -> failwith "no values"
958       (*
959       | Some (Leaf a) -> a
960       | many -> failwith "multiple values"
961       *)
962       | Some us -> us
963   end
964   include Monad.Make(Base)
965   let base_plus = plus
966   let base_lift = lift
967   module T(Wrapped : Monad.S) = struct
968     module Trans = struct
969       include Monad.MakeT(struct
970         module Wrapped = Wrapped
971         type ('x,'a) m = ('x,'a tree option) Wrapped.m
972         type ('x,'a) result = ('x,'a tree option) Wrapped.result
973         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
974         let zero () = Wrapped.unit None
975         let plus u v =
976           Wrapped.bind u (fun us ->
977           Wrapped.bind v (fun vs ->
978           Wrapped.unit (base_plus us vs)))
979         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
980         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
981         let run u = Wrapped.run u
982         let run_exn u =
983             let w = Wrapped.bind u (fun t -> match t with
984               | None -> failwith "no values"
985               | Some ts -> Wrapped.unit ts)
986             in Wrapped.run_exn w
987       end)
988     end
989     include Trans
990     (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
991     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
992   end
993 end
994
995
996 module L = List_monad;;
997 module R = Reader_monad(struct type env = int -> int end);;
998 module S = State_monad(struct type store = int end);;
999 module T = Leaf_monad;;
1000 module LR = L.T(R);;
1001 module LS = L.T(S);;
1002 module TL = T.T(L);;
1003 module TR = T.T(R);;
1004 module TS = T.T(S);;
1005 module C = Continuation_monad
1006 module TC = T.T(C);;
1007
1008
1009 print_endline "=== test Leaf(...).distribute ==================";;
1010
1011 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1012
1013 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1014 TS.run ts 0;;
1015 (*
1016 - : int T.tree option * S.store =
1017 (Some
1018   (T.Node
1019     (T.Node (T.Leaf 2, T.Leaf 3),
1020      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1021  5)
1022 *)
1023
1024 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1025 TS.run_exn ts2 0;;
1026 (*
1027 - : (int * S.store) T.tree option * S.store =
1028 (Some
1029   (T.Node
1030     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1031      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1032  5)
1033 *)
1034
1035 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1036 TR.run_exn tr (fun i -> i+i);;
1037 (*
1038 - : int T.tree option =
1039 Some
1040  (T.Node
1041    (T.Node (T.Leaf 4, T.Leaf 6),
1042     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1043 *)
1044
1045 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1046 TL.run_exn tl;;
1047 (*
1048 - : (int * int) TL.result =
1049 [Some
1050   (T.Node
1051     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1052      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1053 *)
1054
1055 let l2 = [1;2;3;4;5];;
1056 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1057
1058 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1059 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1060
1061 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1062 (*
1063 int T.tree option =
1064 Some
1065  (T.Node
1066    (T.Node (T.Leaf 10, T.Leaf 11),
1067     T.Node
1068      (T.Node
1069        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1070         T.Node (T.Leaf 40, T.Leaf 41)),
1071       T.Node (T.Leaf 50, T.Leaf 51))))
1072  *)
1073
1074 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1075 (*
1076 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1077 *)
1078
1079 print_endline "=== test Leaf(Continuation).distribute ==================";;
1080
1081 let id : 'z. 'z -> 'z = fun x -> x
1082
1083 let example n : (int * int) =
1084   Continuation_monad.(let u = callcc (fun k ->
1085       (if n < 0 then k 0 else unit [n + 100])
1086       (* all of the following is skipped by k 0; the end type int is k's input type *)
1087       >>= fun [x] -> unit (x + 1)
1088   )
1089   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1090   >>= fun x -> unit (x, 0)
1091   in run0 u)
1092
1093
1094 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1095 let example1 () : int =
1096   Continuation_monad.(let v = reset (
1097       let u = shift (fun k -> unit (10 + 1))
1098       in u >>= fun x -> unit (100 + x)
1099     ) in let w = v >>= fun x -> unit (1000 + x)
1100     in run0 w)
1101
1102 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1103 let example2 () =
1104   Continuation_monad.(let v = reset (
1105       let u = shift (fun k -> k (10 :: [1]))
1106       in u >>= fun x -> unit (100 :: x)
1107     ) in let w = v >>= fun x -> unit (1000 :: x)
1108     in run0 w)
1109
1110 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1111 let example3 () =
1112   Continuation_monad.(let v = reset (
1113       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1114       in u >>= fun x -> unit (100 :: x)
1115     ) in let w = v >>= fun x -> unit (1000 :: x)
1116     in run0 w)
1117
1118 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1119 (* not sure if this example can be typed without a sum-type *)
1120
1121 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1122 let example5 () : int =
1123   Continuation_monad.(let v = reset (
1124       let u = shift (fun k -> k 1 >>= k)
1125       in u >>= fun x -> unit (10 + x)
1126     ) in let w = v >>= fun x -> unit (100 + x)
1127     in run0 w)
1128
1129 ;;
1130
1131 print_endline "=== test bare Continuation ============";;
1132
1133 (1011, 1111, 1111, 121);;
1134 (example1(), example2(), example3(), example5());;
1135 ((111,0), (0,0));;
1136 (example ~+10, example ~-10);;
1137
1138 let testc df ic =
1139     C.run_exn TC.(run (distribute df t1)) ic;;
1140
1141
1142 (*
1143 (* do nothing *)
1144 let initial_continuation = fun t -> t in
1145 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1146 *)
1147 testc (C.unit) id;;
1148
1149 (*
1150 (* count leaves, using continuation *)
1151 let initial_continuation = fun t -> 0 in
1152 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1153 *)
1154
1155 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1156
1157 (*
1158 (* convert tree to list of leaves *)
1159 let initial_continuation = fun t -> [] in
1160 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1161 *)
1162
1163 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1164
1165 (*
1166 (* square each leaf using continuation *)
1167 let initial_continuation = fun t -> t in
1168 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1169 *)
1170
1171 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1172
1173
1174 (*
1175 (* replace leaves with list, using continuation *)
1176 let initial_continuation = fun t -> t in
1177 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1178 *)
1179
1180 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1181
1182 print_endline "=== pa_monad's Continuation Tests ============";;
1183
1184 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1185 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1186 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1187 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1188 (5, 27 = C.(run0 (
1189               let c = reset(abort 5 >>= fun y -> unit (y+6))
1190               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1191
1192 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1193
1194 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1195
1196 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1197
1198
1199 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1200