Leaf_monad -> Tree_monad
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  * Acknowledgements: This is largely based on the mtl library distributed
46  * with the Glasgow Haskell Compiler. I've also been helped in
47  * various ways by posts and direct feedback from Oleg Kiselyov and
48  * Chung-chieh Shan. The following were also useful:
49  * - <http://pauillac.inria.fr/~xleroy/mpri/progfunc/>
50  * - Ken Shan "Monads for natural language semantics" <http://arxiv.org/abs/cs/0205026v1>
51  * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf
52  * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers
53  *
54  * Licensing: MIT (if that's compatible with the ghc sources this is partly
55  * derived from)
56  *)
57
58 exception Undefined
59
60 (* Some library functions used below. *)
61 module Util = struct
62   let fold_right = List.fold_right
63   let map = List.map
64   let append = List.append
65   let reverse = List.rev
66   let concat = List.concat
67   let concat_map f lst = List.concat (List.map f lst)
68   (* let zip = List.combine *)
69   let unzip = List.split
70   let zip_with = List.map2
71   let replicate len fill =
72     let rec loop n accu =
73       if n == 0 then accu else loop (pred n) (fill :: accu)
74     in loop len []
75   (* Dirty hack to be a default polymorphic zero.
76    * To implement this cleanly, monads without a natural zero
77    * should always wrap themselves in an option layer (see Tree_monad). *)
78   let undef = Obj.magic (fun () -> raise Undefined)
79 end
80
81
82
83 (*
84  * This module contains factories that extend a base set of
85  * monadic definitions with a larger family of standard derived values.
86  *)
87
88 module Monad = struct
89   (*
90    * Signature extenders:
91    *   Make :: BASE -> S
92    *   MakeT :: BASET (with Wrapped : S) -> result sig not declared
93    *)
94
95
96   (* type of base definitions *)
97   module type BASE = sig
98     (* We make all monadic types doubly-parameterized so that they
99      * can layer nicely with Continuation, which needs the second
100      * type parameter. *)
101     type ('x,'a) m
102     type ('x,'a) result
103     type ('x,'a) result_exn
104     val unit : 'a -> ('x,'a) m
105     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
106     val run : ('x,'a) m -> ('x,'a) result
107     (* run_exn tries to provide a more ground-level result, but may fail *)
108     val run_exn : ('x,'a) m -> ('x,'a) result_exn
109     (* To simplify the library, we require every monad to supply a plus and zero. These obey the following laws:
110      *     zero >>= f   ===  zero
111      *     plus zero u  ===  u
112      *     plus u zero  ===  u
113      * Additionally, they will obey one of the following laws:
114      *     (Catch)   plus (unit a) v  ===  unit a
115      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
116      * When no natural zero is available, use `let zero () = Util.undef`.
117      * The Make functor automatically detects for zero >>= ..., and 
118      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
119      *)
120     val zero : unit -> ('x,'a) m
121     (* zero has to be thunked to ensure results are always poly enough *)
122     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
123   end
124   module type S = sig
125     include BASE
126     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
127     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
128     val join : ('x,('x,'a) m) m -> ('x,'a) m
129     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
130     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
131     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
132     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
133     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
134     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
135     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
136     val sequence : ('x,'a) m list -> ('x,'a list) m
137     val sequence_ : ('x,'a) m list -> ('x,unit) m
138     val guard : bool -> ('x,unit) m
139     val sum : ('x,'a) m list -> ('x,'a) m
140   end
141
142   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
143     include B
144     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
145       if u == Util.undef then Util.undef
146       else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
147     let plus u v =
148       if u == Util.undef then v else if v == Util.undef then u else B.plus u v
149     let run u =
150       if u == Util.undef then raise Undefined else B.run u
151     let run_exn u =
152       if u == Util.undef then raise Undefined else B.run_exn u
153     let (>>=) = bind
154     (* expressions after >> will be evaluated before they're passed to
155      * bind, so you can't do `zero () >> assert false`
156      * this works though: `zero () >>= fun _ -> assert false`
157      *)
158     let (>>) u v = u >>= fun _ -> v
159     let lift f u = u >>= fun a -> unit (f a)
160     (* lift is called listM, fmap, and <$> in Haskell *)
161     let join uu = uu >>= fun u -> u
162     (* u >>= f === join (lift f u) *)
163     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
164     (* [f] <*> [x1,x2] = [f x1,f x2] *)
165     (* let apply u v = u >>= fun f -> lift f v *)
166     (* let apply = lift2 id *)
167     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
168     (* let lift f u === apply (unit f) u *)
169     (* let lift2 f u v = apply (lift f u) v *)
170     let (>=>) f g = fun a -> f a >>= g
171     let do_when test u = if test then u else unit ()
172     let do_unless test u = if test then unit () else u
173     (* A Haskell-like version works:
174          let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk
175      * but the recursive call is not in tail position so this can stack overflow. *)
176     let forever uthunk =
177         let z = zero () in
178         let id result = result in
179         let kcell = ref id in
180         let rec loop _ =
181             let result = uthunk (kcell := id) >>= chained
182             in !kcell result
183         and chained _ =
184             kcell := loop; z (* we use z only for its polymorphism *)
185         in loop z
186     (* Reimplementations of the preceding using a hand-rolled State or StateT
187 can also stack overflow. *)
188     let sequence ms =
189       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
190         Util.fold_right op ms (unit [])
191     let sequence_ ms =
192       Util.fold_right (>>) ms (unit ())
193
194     (* Haskell defines these other operations combining lists and monads.
195      * We don't, but notice that M.mapM == ListT(M).distribute
196      * There's also a parallel TreeT(M).distribute *)
197     (*
198     let mapM f alist = sequence (Util.map f alist)
199     let mapM_ f alist = sequence_ (Util.map f alist)
200     let rec filterM f lst = match lst with
201       | [] -> unit []
202       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
203     let forM alist f = mapM f alist
204     let forM_ alist f = mapM_ f alist
205     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
206     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
207     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
208     let rec foldM f z lst = match lst with
209       | [] -> unit z
210       | x::xs -> f z x >>= fun z' -> foldM f z' xs
211     let foldM_ f z xs = foldM f z xs >> unit ()
212     let replicateM n x = sequence (Util.replicate n x)
213     let replicateM_ n x = sequence_ (Util.replicate n x)
214     *)
215     let guard test = if test then B.unit () else zero ()
216     let sum ms = Util.fold_right plus ms (zero ())
217   end
218
219   (* Signatures for MonadT *)
220   module type BASET = sig
221     module Wrapped : S
222     type ('x,'a) m
223     type ('x,'a) result
224     type ('x,'a) result_exn
225     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
226     val run : ('x,'a) m -> ('x,'a) result
227     val run_exn : ('x,'a) m -> ('x,'a) result_exn
228     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
229     (* lift/elevate laws:
230      *     elevate (W.unit a) == unit a
231      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
232      *)
233     val zero : unit -> ('x,'a) m
234     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
235   end
236   module MakeT(T : BASET) = struct
237     include Make(struct
238         include T
239         let unit a = elevate (Wrapped.unit a)
240     end)
241     let elevate = T.elevate
242   end
243
244 end
245
246
247
248
249
250 module Identity_monad : sig
251   (* expose only the implementation of type `'a result` *)
252   type ('x,'a) result = 'a
253   type ('x,'a) result_exn = 'a
254   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
255 end = struct
256   module Base = struct
257     type ('x,'a) m = 'a
258     type ('x,'a) result = 'a
259     type ('x,'a) result_exn = 'a
260     let unit a = a
261     let bind a f = f a
262     let run a = a
263     let run_exn a = a
264     let zero () = Util.undef
265     let plus u v = u
266   end
267   include Monad.Make(Base)
268 end
269
270
271 module Maybe_monad : sig
272   (* expose only the implementation of type `'a result` *)
273   type ('x,'a) result = 'a option
274   type ('x,'a) result_exn = 'a
275   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
276   (* MaybeT transformer *)
277   module T : functor (Wrapped : Monad.S) -> sig
278     type ('x,'a) result = ('x,'a option) Wrapped.result
279     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
280     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
281     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
282   end
283 end = struct
284   module Base = struct
285     type ('x,'a) m = 'a option
286     type ('x,'a) result = 'a option
287     type ('x,'a) result_exn = 'a
288     let unit a = Some a
289     let bind u f = match u with Some a -> f a | None -> None
290     let run u = u
291     let run_exn u = match u with
292       | Some a -> a
293       | None -> failwith "no value"
294     let zero () = None
295     (* satisfies Catch *)
296     let plus u v = match u with None -> v | _ -> u
297   end
298   include Monad.Make(Base)
299   module T(Wrapped : Monad.S) = struct
300     module BaseT = struct
301       include Monad.MakeT(struct
302         module Wrapped = Wrapped
303         type ('x,'a) m = ('x,'a option) Wrapped.m
304         type ('x,'a) result = ('x,'a option) Wrapped.result
305         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
306         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
307         let bind u f = Wrapped.bind u (fun t -> match t with
308           | Some a -> f a
309           | None -> Wrapped.unit None)
310         let run u = Wrapped.run u
311         let run_exn u =
312           let w = Wrapped.bind u (fun t -> match t with
313             | Some a -> Wrapped.unit a
314             | None -> Wrapped.zero ()
315           ) in Wrapped.run_exn w
316         let zero () = Wrapped.unit None
317         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
318       end)
319     end
320     include BaseT
321   end
322 end
323
324
325 module List_monad : sig
326   (* declare additional operation, while still hiding implementation of type m *)
327   type ('x,'a) result = 'a list
328   type ('x,'a) result_exn = 'a
329   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
330   val permute : ('x,'a) m -> ('x,('x,'a) m) m
331   val select : ('x,'a) m -> ('x,'a * ('x,'a) m) m
332   (* ListT transformer *)
333   module T : functor (Wrapped : Monad.S) -> sig
334     type ('x,'a) result = ('x,'a list) Wrapped.result
335     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
336     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
337     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
338     (* note that second argument is an 'a list, not the more abstract 'a m *)
339     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
340     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
341 (* TODO
342     val permute : 'a m -> 'a m m
343     val select : 'a m -> ('a * 'a m) m
344 *)
345   end
346 end = struct
347   module Base = struct
348    type ('x,'a) m = 'a list
349    type ('x,'a) result = 'a list
350    type ('x,'a) result_exn = 'a
351    let unit a = [a]
352    let bind u f = Util.concat_map f u
353    let run u = u
354    let run_exn u = match u with
355      | [] -> failwith "no values"
356      | [a] -> a
357      | many -> failwith "multiple values"
358    let zero () = []
359    (* satisfies Distrib *)
360    let plus = Util.append
361   end
362   include Monad.Make(Base)
363   (* let either u v = plus u v *)
364   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
365   let rec insert a u =
366     plus (unit (a :: u)) (match u with
367         | [] -> zero ()
368         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
369     )
370   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
371   let rec permute u = match u with
372       | [] -> unit []
373       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
374   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
375   let rec select u = match u with
376     | [] -> zero ()
377     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
378   module T(Wrapped : Monad.S) = struct
379     (* Wrapped.sequence ms  ===  
380          let plus1 u v =
381            Wrapped.bind u (fun x ->
382            Wrapped.bind v (fun xs ->
383            Wrapped.unit (x :: xs)))
384          in Util.fold_right plus1 ms (Wrapped.unit []) *)
385     (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
386     let distribute f alist = Wrapped.sequence (Util.map f alist)
387
388     include Monad.MakeT(struct
389       module Wrapped = Wrapped
390       type ('x,'a) m = ('x,'a list) Wrapped.m
391       type ('x,'a) result = ('x,'a list) Wrapped.result
392       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
393       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
394       let bind u f =
395         Wrapped.bind u (fun ts ->
396         Wrapped.bind (distribute f ts) (fun tts ->
397         Wrapped.unit (Util.concat tts)))
398       let run u = Wrapped.run u
399       let run_exn u =
400         let w = Wrapped.bind u (fun ts -> match ts with
401           | [] -> Wrapped.zero ()
402           | [a] -> Wrapped.unit a
403           | many -> Wrapped.zero ()
404         ) in Wrapped.run_exn w
405       let zero () = Wrapped.unit []
406       let plus u v =
407         Wrapped.bind u (fun us ->
408         Wrapped.bind v (fun vs ->
409         Wrapped.unit (Base.plus us vs)))
410     end)
411 (*
412     let permute : 'a m -> 'a m m
413     let select : 'a m -> ('a * 'a m) m
414 *)
415   end
416 end
417
418 (*
419 # LL.(run(plus (unit 1) (unit 2) >>= fun i -> plus (unit i) (unit(10*i)) ));;
420 - : ('_a, int) LL.result = [[1; 10; 2; 20]]
421 # LL.(run(plus (unit 1) (unit 2) >>= fun i -> elevate L.(plus (unit i) (unit(10*i)) )));;
422 - : ('_a, int) LL.result = [[1; 2]; [1; 20]; [10; 2]; [10; 20]]
423 *)
424
425
426 (* must be parameterized on (struct type err = ... end) *)
427 module Error_monad(Err : sig
428   type err
429   exception Exc of err
430   (*
431   val zero : unit -> err
432   val plus : err -> err -> err
433   *)
434 end) : sig
435   (* declare additional operations, while still hiding implementation of type m *)
436   type err = Err.err
437   type 'a error = Error of err | Success of 'a
438   type ('x,'a) result = 'a error
439   type ('x,'a) result_exn = 'a
440   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
441   val throw : err -> ('x,'a) m
442   val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
443   (* ErrorT transformer *)
444   module T : functor (Wrapped : Monad.S) -> sig
445     type ('x,'a) result = ('x,'a) Wrapped.result
446     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
447     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
448     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
449     val throw : err -> ('x,'a) m
450     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
451   end
452 end = struct
453   type err = Err.err
454   type 'a error = Error of err | Success of 'a
455   module Base = struct
456     type ('x,'a) m = 'a error
457     type ('x,'a) result = 'a error
458     type ('x,'a) result_exn = 'a
459     let unit a = Success a
460     let bind u f = match u with
461       | Success a -> f a
462       | Error e -> Error e (* input and output may be of different 'a types *)
463     let run u = u
464     let run_exn u = match u with
465       | Success a -> a
466       | Error e -> raise (Err.Exc e)
467     let zero () = Util.undef
468     (* satisfies Catch *)
469     let plus u v = match u with
470       | Success _ -> u
471       | Error _ -> if v == Util.undef then u else v
472   end
473   include Monad.Make(Base)
474   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
475   let throw e = Error e
476   let catch u handler = match u with
477     | Success _ -> u
478     | Error e -> handler e
479   module T(Wrapped : Monad.S) = struct
480     include Monad.MakeT(struct
481       module Wrapped = Wrapped
482       type ('x,'a) m = ('x,'a error) Wrapped.m
483       type ('x,'a) result = ('x,'a) Wrapped.result
484       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
485       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
486       let bind u f = Wrapped.bind u (fun t -> match t with
487         | Success a -> f a
488         | Error e -> Wrapped.unit (Error e))
489       let run u =
490         let w = Wrapped.bind u (fun t -> match t with
491           | Success a -> Wrapped.unit a
492           | Error e -> Wrapped.zero ()
493         ) in Wrapped.run w
494       let run_exn u =
495         let w = Wrapped.bind u (fun t -> match t with
496           | Success a -> Wrapped.unit a
497           | Error e -> raise (Err.Exc e))
498         in Wrapped.run_exn w
499       let plus u v = Wrapped.plus u v
500       let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *)
501     end)
502     let throw e = Wrapped.unit (Error e)
503     let catch u handler = Wrapped.bind u (fun t -> match t with
504       | Success _ -> Wrapped.unit t
505       | Error e -> handler e)
506   end
507 end
508
509 (* pre-define common instance of Error_monad *)
510 module Failure = Error_monad(struct
511   type err = string
512   exception Exc = Failure
513   (*
514   let zero = ""
515   let plus s1 s2 = s1 ^ "\n" ^ s2
516   *)
517 end)
518
519 (*
520 # EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
521 - : int EL.result = [Failure.Error "bye"; Failure.Success 30]
522 # LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
523 - : int LE.result = Failure.Error "bye"
524 # EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
525 Exception: Failure "bye".
526 # LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
527 Exception: Failure "bye".
528
529 # ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
530 - : int Failure.error * S.store = (Failure.Error "bye", 1)
531 # SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
532 - : (int * S.store) Failure.result = Failure.Error "bye"
533 # ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
534 Exception: Failure "bye".
535 # SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
536 Exception: Failure "bye".
537  *)
538
539
540 (* must be parameterized on (struct type env = ... end) *)
541 module Reader_monad(Env : sig type env end) : sig
542   (* declare additional operations, while still hiding implementation of type m *)
543   type env = Env.env
544   type ('x,'a) result = env -> 'a
545   type ('x,'a) result_exn = env -> 'a
546   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
547   val ask : ('x,env) m
548   val asks : (env -> 'a) -> ('x,'a) m
549   (* lookup i == `fun e -> e i` would assume env is a functional type *)
550   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
551   (* ReaderT transformer *)
552   module T : functor (Wrapped : Monad.S) -> sig
553     type ('x,'a) result = env -> ('x,'a) Wrapped.result
554     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
555     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
556     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
557     val ask : ('x,env) m
558     val asks : (env -> 'a) -> ('x,'a) m
559     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
560   end
561 end = struct
562   type env = Env.env
563   module Base = struct
564     type ('x,'a) m = env -> 'a
565     type ('x,'a) result = env -> 'a
566     type ('x,'a) result_exn = env -> 'a
567     let unit a = fun e -> a
568     let bind u f = fun e -> let a = u e in let u' = f a in u' e
569     let run u = fun e -> u e
570     let run_exn = run
571     let zero () = Util.undef
572     let plus u v = u
573   end
574   include Monad.Make(Base)
575   let ask = fun e -> e
576   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
577   let local modifier u = fun e -> u (modifier e)
578   module T(Wrapped : Monad.S) = struct
579     module BaseT = struct
580       module Wrapped = Wrapped
581       type ('x,'a) m = env -> ('x,'a) Wrapped.m
582       type ('x,'a) result = env -> ('x,'a) Wrapped.result
583       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
584       let elevate w = fun e -> w
585       let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e)
586       let run u = fun e -> Wrapped.run (u e)
587       let run_exn u = fun e -> Wrapped.run_exn (u e)
588       (* satisfies Distrib *)
589       let plus u v = fun e -> Wrapped.plus (u e) (v e)
590       let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
591     end
592     include Monad.MakeT(BaseT)
593     let ask = Wrapped.unit
594     let local modifier u = fun e -> u (modifier e)
595     let asks selector = ask >>= (fun e ->
596       try unit (selector e)
597       with Not_found -> fun e -> Wrapped.zero ())
598   end
599 end
600
601
602 (* must be parameterized on (struct type store = ... end) *)
603 module State_monad(Store : sig type store end) : sig
604   (* declare additional operations, while still hiding implementation of type m *)
605   type store = Store.store
606   type ('x,'a) result =  store -> 'a * store
607   type ('x,'a) result_exn = store -> 'a
608   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
609   val get : ('x,store) m
610   val gets : (store -> 'a) -> ('x,'a) m
611   val put : store -> ('x,unit) m
612   val puts : (store -> store) -> ('x,unit) m
613   (* StateT transformer *)
614   module T : functor (Wrapped : Monad.S) -> sig
615     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
616     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
617     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
618     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
619     val get : ('x,store) m
620     val gets : (store -> 'a) -> ('x,'a) m
621     val put : store -> ('x,unit) m
622     val puts : (store -> store) -> ('x,unit) m
623   end
624 end = struct
625   type store = Store.store
626   module Base = struct
627     type ('x,'a) m =  store -> 'a * store
628     type ('x,'a) result =  store -> 'a * store
629     type ('x,'a) result_exn = store -> 'a
630     let unit a = fun s -> (a, s)
631     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
632     let run u = fun s -> (u s)
633     let run_exn u = fun s -> fst (u s)
634     let zero () = Util.undef
635     let plus u v = u
636   end
637   include Monad.Make(Base)
638   let get = fun s -> (s, s)
639   let gets viewer = fun s -> (viewer s, s) (* may fail *)
640   let put s = fun _ -> ((), s)
641   let puts modifier = fun s -> ((), modifier s)
642   module T(Wrapped : Monad.S) = struct
643     module BaseT = struct
644       module Wrapped = Wrapped
645       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
646       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
647       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
648       let elevate w = fun s ->
649         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
650       let bind u f = fun s ->
651         Wrapped.bind (u s) (fun (a, s') -> f a s')
652       let run u = fun s -> Wrapped.run (u s)
653       let run_exn u = fun s ->
654         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
655         in Wrapped.run_exn w
656       (* satisfies Distrib *)
657       let plus u v = fun s -> Wrapped.plus (u s) (v s)
658       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
659     end
660     include Monad.MakeT(BaseT)
661     let get = fun s -> Wrapped.unit (s, s)
662     let gets viewer = fun s ->
663       try Wrapped.unit (viewer s, s)
664       with Not_found -> Wrapped.zero ()
665     let put s = fun _ -> Wrapped.unit ((), s)
666     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
667   end
668 end
669
670 (* State monad with different interface (structured store) *)
671 module Ref_monad(V : sig
672   type value
673 end) : sig
674   type ref
675   type value = V.value
676   type ('x,'a) result = 'a
677   type ('x,'a) result_exn = 'a
678   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
679   val newref : value -> ('x,ref) m
680   val deref : ref -> ('x,value) m
681   val change : ref -> value -> ('x,unit) m
682   (* RefT transformer *)
683   module T : functor (Wrapped : Monad.S) -> sig
684     type ('x,'a) result = ('x,'a) Wrapped.result
685     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
686     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
687     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
688     val newref : value -> ('x,ref) m
689     val deref : ref -> ('x,value) m
690     val change : ref -> value -> ('x,unit) m
691   end
692 end = struct
693   type ref = int
694   type value = V.value
695   module D = Map.Make(struct type t = ref let compare = compare end)
696   type dict = { next: ref; tree : value D.t }
697   let empty = { next = 0; tree = D.empty }
698   let alloc (value : value) (d : dict) =
699     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
700   let read (key : ref) (d : dict) =
701     D.find key d.tree
702   let write (key : ref) (value : value) (d : dict) =
703     { next = d.next; tree = D.add key value d.tree }
704   module Base = struct
705     type ('x,'a) m = dict -> 'a * dict
706     type ('x,'a) result = 'a
707     type ('x,'a) result_exn = 'a
708     let unit a = fun s -> (a, s)
709     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
710     let run u = fst (u empty)
711     let run_exn = run
712     let zero () = Util.undef
713     let plus u v = u
714   end
715   include Monad.Make(Base)
716   let newref value = fun s -> alloc value s
717   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
718   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
719   module T(Wrapped : Monad.S) = struct
720     module BaseT = struct
721       module Wrapped = Wrapped
722       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
723       type ('x,'a) result = ('x,'a) Wrapped.result
724       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
725       let elevate w = fun s ->
726         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
727       let bind u f = fun s ->
728         Wrapped.bind (u s) (fun (a, s') -> f a s')
729       let run u =
730         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
731         in Wrapped.run w
732       let run_exn u =
733         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
734         in Wrapped.run_exn w
735       (* satisfies Distrib *)
736       let plus u v = fun s -> Wrapped.plus (u s) (v s)
737       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
738     end
739     include Monad.MakeT(BaseT)
740     let newref value = fun s -> Wrapped.unit (alloc value s)
741     let deref key = fun s -> Wrapped.unit (read key s, s)
742     let change key value = fun s -> Wrapped.unit ((), write key value s)
743   end
744 end
745
746
747 (* must be parameterized on (struct type log = ... end) *)
748 module Writer_monad(Log : sig
749   type log
750   val zero : log
751   val plus : log -> log -> log
752 end) : sig
753   (* declare additional operations, while still hiding implementation of type m *)
754   type log = Log.log
755   type ('x,'a) result = 'a * log
756   type ('x,'a) result_exn = 'a * log
757   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
758   val tell : log -> ('x,unit) m
759   val listen : ('x,'a) m -> ('x,'a * log) m
760   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
761   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
762   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
763   (* WriterT transformer *)
764   module T : functor (Wrapped : Monad.S) -> sig
765     type ('x,'a) result = ('x,'a * log) Wrapped.result
766     type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
767     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
768     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
769     val tell : log -> ('x,unit) m
770     val listen : ('x,'a) m -> ('x,'a * log) m
771     val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
772     val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
773   end
774 end = struct
775   type log = Log.log
776   module Base = struct
777     type ('x,'a) m = 'a * log
778     type ('x,'a) result = 'a * log
779     type ('x,'a) result_exn = 'a * log
780     let unit a = (a, Log.zero)
781     let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w')
782     let run u = u
783     let run_exn = run
784     let zero () = Util.undef
785     let plus u v = u
786   end
787   include Monad.Make(Base)
788   let tell entries = ((), entries) (* add entries to log *)
789   let listen (a, w) = ((a, w), w)
790   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
791   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
792   let censor f u = pass (u >>= fun a -> unit (a, f))
793   module T(Wrapped : Monad.S) = struct
794     module BaseT = struct
795       module Wrapped = Wrapped
796       type ('x,'a) m = ('x,'a * log) Wrapped.m
797       type ('x,'a) result = ('x,'a * log) Wrapped.result
798       type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
799       let elevate w =
800         Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero))
801       let bind u f =
802         Wrapped.bind u (fun (a, w) ->
803         Wrapped.bind (f a) (fun (b, w') ->
804         Wrapped.unit (b, Log.plus w w')))
805       let zero () = elevate (Wrapped.zero ())
806       let plus u v = Wrapped.plus u v
807       let run u = Wrapped.run u
808       let run_exn u = Wrapped.run_exn u
809     end
810     include Monad.MakeT(BaseT)
811     let tell entries = Wrapped.unit ((), entries)
812     let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w))
813     let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w))
814     (* rest are derived in same way as before *)
815     let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w)
816     let censor f u = pass (u >>= fun a -> unit (a, f))
817   end
818 end
819
820 (* pre-define simple Writer *)
821 module Writer1 = Writer_monad(struct
822   type log = string
823   let zero = ""
824   let plus s1 s2 = s1 ^ "\n" ^ s2
825 end)
826
827 (* slightly more efficient Writer *)
828 module Writer2 = struct
829   include Writer_monad(struct
830     type log = string list
831     let zero = []
832     let plus w w' = Util.append w' w
833   end)
834   let tell_string s = tell [s]
835   let tell entries = tell (Util.reverse entries)
836   let run u = let (a, w) = run u in (a, Util.reverse w)
837   let run_exn = run
838 end
839
840
841 (* TODO needs a T *)
842 module IO_monad : sig
843   (* declare additional operation, while still hiding implementation of type m *)
844   type ('x,'a) result = 'a
845   type ('x,'a) result_exn = 'a
846   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
847   val printf : ('a, unit, string, ('x,unit) m) format4 -> 'a
848   val print_string : string -> ('x,unit) m
849   val print_int : int -> ('x,unit) m
850   val print_hex : int -> ('x,unit) m
851   val print_bool : bool -> ('x,unit) m
852 end = struct
853   module Base = struct
854     type ('x,'a) m = { run : unit -> unit; value : 'a }
855     type ('x,'a) result = 'a
856     type ('x,'a) result_exn = 'a
857     let unit a = { run = (fun () -> ()); value = a }
858     let bind (a : ('x,'a) m) (f: 'a -> ('x,'b) m) : ('x,'b) m =
859      let fres = f a.value in
860        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
861     let run a = let () = a.run () in a.value
862     let run_exn = run
863     let zero () = Util.undef
864     let plus u v = u
865   end
866   include Monad.Make(Base)
867   let printf fmt =
868     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
869   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
870   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
871   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
872   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
873 end
874
875
876 module Continuation_monad : sig
877   (* expose only the implementation of type `('r,'a) result` *)
878   type ('r,'a) m
879   type ('r,'a) result = ('r,'a) m
880   type ('r,'a) result_exn = ('a -> 'r) -> 'r
881   include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
882   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
883   val reset : ('a,'a) m -> ('r,'a) m
884   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
885   (* val abort : ('a,'a) m -> ('a,'b) m *)
886   val abort : 'a -> ('a,'b) m
887   val run0 : ('a,'a) m -> 'a
888   (* ContinuationT transformer *)
889   module T : functor (Wrapped : Monad.S) -> sig
890     type ('r,'a) m
891     type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
892     type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
893     include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
894     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
895     val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
896     (* TODO: reset,shift,abort,run0 *)
897   end
898 end = struct
899   let id = fun i -> i
900   module Base = struct
901     (* 'r is result type of whole computation *)
902     type ('r,'a) m = ('a -> 'r) -> 'r
903     type ('r,'a) result = ('a -> 'r) -> 'r
904     type ('r,'a) result_exn = ('r,'a) result
905     let unit a = (fun k -> k a)
906     let bind u f = (fun k -> (u) (fun a -> (f a) k))
907     let run u k = (u) k
908     let run_exn = run
909     let zero () = Util.undef
910     let plus u v = u
911   end
912   include Monad.Make(Base)
913   let callcc f = (fun k ->
914     let usek a = (fun _ -> k a)
915     in (f usek) k)
916   (*
917   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
918   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
919   let callcc f = fun k -> f k k
920   let throw k a = fun _ -> k a
921   *)
922
923   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
924    *
925    *  reset :: (Monad m) => ContT a m a -> ContT r m a
926    *  reset e = ContT $ \k -> runContT e return >>= k
927    *
928    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
929    *  shift e = ContT $ \k ->
930    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
931   let reset u = unit ((u) id)
932   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
933   (* let abort a = shift (fun _ -> a) *)
934   let abort a = shift (fun _ -> unit a)
935   let run0 (u : ('a,'a) m) = (u) id
936   module T(Wrapped : Monad.S) = struct
937     module BaseT = struct
938       module Wrapped = Wrapped
939       type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m
940       type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
941       type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
942       let elevate w = fun k -> Wrapped.bind w k
943       let bind u f = fun k -> u (fun a -> f a k)
944       let run u k = Wrapped.run (u k)
945       let run_exn u k = Wrapped.run_exn (u k)
946       let zero () = Util.undef
947       let plus u v = u
948     end
949     include Monad.MakeT(BaseT)
950     let callcc f = (fun k ->
951       let usek a = (fun _ -> k a)
952       in (f usek) k)
953   end
954 end
955
956
957 (*
958  * Scheme:
959  * (define (example n)
960  *    (let ([u (let/cc k ; type int -> int pair
961  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
962  *                 (+ 1 (car v))))]) ; int
963  *      (cons u 0))) ; int pair
964  * ; (example 10) ~~> '(111 . 0)
965  * ; (example -10) ~~> '(0 . 0)
966  *
967  * OCaml monads:
968  * let example n : (int * int) =
969  *   Continuation_monad.(let u = callcc (fun k ->
970  *       (if n < 0 then k 0 else unit [n + 100])
971  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
972  *       >>= fun [x] -> unit (x + 1)
973  *   )
974  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
975  *   >>= fun x -> unit (x, 0)
976  *   in run u)
977  *
978  *
979  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
980  * let example1 () : int =
981  *   Continuation_monad.(let v = reset (
982  *       let u = shift (fun k -> unit (10 + 1))
983  *       in u >>= fun x -> unit (100 + x)
984  *     ) in let w = v >>= fun x -> unit (1000 + x)
985  *     in run w)
986  *
987  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
988  * let example2 () =
989  *   Continuation_monad.(let v = reset (
990  *       let u = shift (fun k -> k (10 :: [1]))
991  *       in u >>= fun x -> unit (100 :: x)
992  *     ) in let w = v >>= fun x -> unit (1000 :: x)
993  *     in run w)
994  *
995  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
996  * let example3 () =
997  *   Continuation_monad.(let v = reset (
998  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
999  *       in u >>= fun x -> unit (100 :: x)
1000  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1001  *     in run w)
1002  *
1003  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1004  * (* not sure if this example can be typed without a sum-type *)
1005  *
1006  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1007  * let example5 () : int =
1008  *   Continuation_monad.(let v = reset (
1009  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
1010  *       in u >>= fun x -> unit (10 + x)
1011  *     ) in let w = v >>= fun x -> unit (100 + x)
1012  *     in run w)
1013  *
1014  *)
1015
1016
1017 module Tree_monad : sig
1018   (* We implement the type as `'a tree option` because it has a natural`plus`,
1019    * and the rest of the library expects that `plus` and `zero` will come together. *)
1020   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1021   type ('x,'a) result = 'a tree option
1022   type ('x,'a) result_exn = 'a tree
1023   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1024   (* TreeT transformer *)
1025   module T : functor (Wrapped : Monad.S) -> sig
1026     type ('x,'a) result = ('x,'a tree option) Wrapped.result
1027     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1028     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1029     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1030     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1031     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1032     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
1033   end
1034 end = struct
1035   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1036   (* uses supplied plus and zero to copy t to its image under f *)
1037   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1038       | None -> zero ()
1039       | Some ts -> let rec loop ts = (match ts with
1040                      | Leaf a -> f a
1041                      | Node (l, r) ->
1042                          (* recursive application of f may delete a branch *)
1043                          plus (loop l) (loop r)
1044                    ) in loop ts
1045   module Base = struct
1046     type ('x,'a) m = 'a tree option
1047     type ('x,'a) result = 'a tree option
1048     type ('x,'a) result_exn = 'a tree
1049     let unit a = Some (Leaf a)
1050     let zero () = None
1051     (* satisfies Distrib *)
1052     let plus u v = match (u, v) with
1053       | None, _ -> v
1054       | _, None -> u
1055       | Some us, Some vs -> Some (Node (us, vs))
1056     let bind u f = mapT f u zero plus
1057     let run u = u
1058     let run_exn u = match u with
1059       | None -> failwith "no values"
1060       (*
1061       | Some (Leaf a) -> a
1062       | many -> failwith "multiple values"
1063       *)
1064       | Some us -> us
1065   end
1066   include Monad.Make(Base)
1067   module T(Wrapped : Monad.S) = struct
1068     module BaseT = struct
1069       include Monad.MakeT(struct
1070         module Wrapped = Wrapped
1071         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1072         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1073         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1074         let zero () = Wrapped.unit None
1075         let plus u v =
1076           Wrapped.bind u (fun us ->
1077           Wrapped.bind v (fun vs ->
1078           Wrapped.unit (Base.plus us vs)))
1079         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1080         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1081         let run u = Wrapped.run u
1082         let run_exn u =
1083             let w = Wrapped.bind u (fun t -> match t with
1084               | None -> Wrapped.zero ()
1085               | Some ts -> Wrapped.unit ts
1086             ) in Wrapped.run_exn w
1087       end)
1088     end
1089     include BaseT
1090     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1091   end
1092 end
1093
1094
1095 module L = List_monad;;
1096 module R = Reader_monad(struct type env = int -> int end);;
1097 module S = State_monad(struct type store = int end);;
1098 module T = Tree_monad;;
1099 module LR = L.T(R);;
1100 module LS = L.T(S);;
1101 module TL = T.T(L);;
1102 module TR = T.T(R);;
1103 module TS = T.T(S);;
1104 module C = Continuation_monad
1105 module TC = T.T(C);;
1106
1107
1108 print_endline "=== test TreeT(...).distribute ==================";;
1109
1110 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1111
1112 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1113 TS.run ts 0;;
1114 (*
1115 - : int T.tree option * S.store =
1116 (Some
1117   (T.Node
1118     (T.Node (T.Leaf 2, T.Leaf 3),
1119      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1120  5)
1121 *)
1122
1123 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1124 TS.run_exn ts2 0;;
1125 (*
1126 - : (int * S.store) T.tree option * S.store =
1127 (Some
1128   (T.Node
1129     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1130      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1131  5)
1132 *)
1133
1134 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1135 TR.run_exn tr (fun i -> i+i);;
1136 (*
1137 - : int T.tree option =
1138 Some
1139  (T.Node
1140    (T.Node (T.Leaf 4, T.Leaf 6),
1141     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1142 *)
1143
1144 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1145 TL.run_exn tl;;
1146 (*
1147 - : (int * int) TL.result =
1148 [Some
1149   (T.Node
1150     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1151      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1152 *)
1153
1154 let l2 = [1;2;3;4;5];;
1155 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1156
1157 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1158 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1159
1160 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1161 (*
1162 int T.tree option =
1163 Some
1164  (T.Node
1165    (T.Node (T.Leaf 10, T.Leaf 11),
1166     T.Node
1167      (T.Node
1168        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1169         T.Node (T.Leaf 40, T.Leaf 41)),
1170       T.Node (T.Leaf 50, T.Leaf 51))))
1171  *)
1172
1173 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1174 (*
1175 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1176 *)
1177
1178 print_endline "=== test TreeT(Continuation).distribute ==================";;
1179
1180 let id : 'z. 'z -> 'z = fun x -> x
1181
1182 let example n : (int * int) =
1183   Continuation_monad.(let u = callcc (fun k ->
1184       (if n < 0 then k 0 else unit [n + 100])
1185       (* all of the following is skipped by k 0; the end type int is k's input type *)
1186       >>= fun [x] -> unit (x + 1)
1187   )
1188   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1189   >>= fun x -> unit (x, 0)
1190   in run0 u)
1191
1192
1193 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1194 let example1 () : int =
1195   Continuation_monad.(let v = reset (
1196       let u = shift (fun k -> unit (10 + 1))
1197       in u >>= fun x -> unit (100 + x)
1198     ) in let w = v >>= fun x -> unit (1000 + x)
1199     in run0 w)
1200
1201 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1202 let example2 () =
1203   Continuation_monad.(let v = reset (
1204       let u = shift (fun k -> k (10 :: [1]))
1205       in u >>= fun x -> unit (100 :: x)
1206     ) in let w = v >>= fun x -> unit (1000 :: x)
1207     in run0 w)
1208
1209 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1210 let example3 () =
1211   Continuation_monad.(let v = reset (
1212       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1213       in u >>= fun x -> unit (100 :: x)
1214     ) in let w = v >>= fun x -> unit (1000 :: x)
1215     in run0 w)
1216
1217 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1218 (* not sure if this example can be typed without a sum-type *)
1219
1220 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1221 let example5 () : int =
1222   Continuation_monad.(let v = reset (
1223       let u = shift (fun k -> k 1 >>= k)
1224       in u >>= fun x -> unit (10 + x)
1225     ) in let w = v >>= fun x -> unit (100 + x)
1226     in run0 w)
1227
1228 ;;
1229
1230 print_endline "=== test bare Continuation ============";;
1231
1232 (1011, 1111, 1111, 121);;
1233 (example1(), example2(), example3(), example5());;
1234 ((111,0), (0,0));;
1235 (example ~+10, example ~-10);;
1236
1237 let testc df ic =
1238     C.run_exn TC.(run (distribute df t1)) ic;;
1239
1240
1241 (*
1242 (* do nothing *)
1243 let initial_continuation = fun t -> t in
1244 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1245 *)
1246 testc (C.unit) id;;
1247
1248 (*
1249 (* count leaves, using continuation *)
1250 let initial_continuation = fun t -> 0 in
1251 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1252 *)
1253
1254 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1255
1256 (*
1257 (* convert tree to list of leaves *)
1258 let initial_continuation = fun t -> [] in
1259 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1260 *)
1261
1262 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1263
1264 (*
1265 (* square each leaf using continuation *)
1266 let initial_continuation = fun t -> t in
1267 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1268 *)
1269
1270 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1271
1272
1273 (*
1274 (* replace leaves with list, using continuation *)
1275 let initial_continuation = fun t -> t in
1276 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1277 *)
1278
1279 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1280
1281 print_endline "=== pa_monad's Continuation Tests ============";;
1282
1283 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1284 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1285 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1286 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1287 (5, 27 = C.(run0 (
1288               let c = reset(abort 5 >>= fun y -> unit (y+6))
1289               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1290
1291 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1292
1293 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1294
1295 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1296
1297
1298 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1299