learning ocaml links
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  * Acknowledgements: This is largely based on the mtl library distributed
46  * with the Glaskow Haskell Compiler. I've also been helped in
47  * various ways by posts and direct feedback from Oleg Kiselyov and
48  * Chung-chieh Shan. The following were also useful:
49  * - <http://pauillac.inria.fr/~xleroy/mpri/progfunc/>
50  * - Ken Shan "Monads for natural language semantics" <http://arxiv.org/abs/cs/0205026v1>
51  * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf
52  * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers
53  *
54  * Licensing: MIT (if that's compatible with the ghc sources).
55  *)
56
57 exception Undefined
58
59 (* Some library functions used below. *)
60 module Util = struct
61   let fold_right = List.fold_right
62   let map = List.map
63   let append = List.append
64   let reverse = List.rev
65   let concat = List.concat
66   let concat_map f lst = List.concat (List.map f lst)
67   (* let zip = List.combine *)
68   let unzip = List.split
69   let zip_with = List.map2
70   let replicate len fill =
71     let rec loop n accu =
72       if n == 0 then accu else loop (pred n) (fill :: accu)
73     in loop len []
74   (* Dirty hack to be a default polymorphic zero.
75    * To implement this cleanly, monads without a natural zero
76    * should always wrap themselves in an option layer (see Leaf_monad). *)
77   let undef = Obj.magic (fun () -> raise Undefined)
78 end
79
80
81
82 (*
83  * This module contains factories that extend a base set of
84  * monadic definitions with a larger family of standard derived values.
85  *)
86
87 module Monad = struct
88   (*
89    * Signature extenders:
90    *   Make :: BASE -> S
91    *   MakeT :: BASET (with Wrapped : S) -> result sig not declared
92    *)
93
94
95   (* type of base definitions *)
96   module type BASE = sig
97     (* We make all monadic types doubly-parameterized so that they
98      * can layer nicely with Continuation, which needs the second
99      * type parameter. *)
100     type ('x,'a) m
101     type ('x,'a) result
102     type ('x,'a) result_exn
103     val unit : 'a -> ('x,'a) m
104     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
105     val run : ('x,'a) m -> ('x,'a) result
106     (* run_exn tries to provide a more ground-level result, but may fail *)
107     val run_exn : ('x,'a) m -> ('x,'a) result_exn
108     (* To simplify the library, we require every monad to supply a plus and zero. These obey the following laws:
109      *     zero >>= f   ===  zero
110      *     plus zero u  ===  u
111      *     plus u zero  ===  u
112      * Additionally, they will obey one of the following laws:
113      *     (Catch)   plus (unit a) v  ===  unit a
114      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
115      * When no natural zero is available, use `let zero () = Util.undef`.
116      * The Make functor automatically detects for zero >>= ..., and 
117      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
118      *)
119     val zero : unit -> ('x,'a) m
120     (* zero has to be thunked to ensure results are always poly enough *)
121     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
122   end
123   module type S = sig
124     include BASE
125     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
126     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
127     val join : ('x,('x,'a) m) m -> ('x,'a) m
128     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
129     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
130     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
131     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
132     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
133     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
134     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
135     val sequence : ('x,'a) m list -> ('x,'a list) m
136     val sequence_ : ('x,'a) m list -> ('x,unit) m
137     val guard : bool -> ('x,unit) m
138     val sum : ('x,'a) m list -> ('x,'a) m
139   end
140
141   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
142     include B
143     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
144       if u == Util.undef then Util.undef
145       else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
146     let plus u v =
147       if u == Util.undef then v else if v == Util.undef then u else B.plus u v
148     let run u =
149       if u == Util.undef then raise Undefined else B.run u
150     let run_exn u =
151       if u == Util.undef then raise Undefined else B.run_exn u
152     let (>>=) = bind
153     (* expressions after >> will be evaluated before they're passed to
154      * bind, so you can't do `zero () >> assert false`
155      * this works though: `zero () >>= fun _ -> assert false`
156      *)
157     let (>>) u v = u >>= fun _ -> v
158     let lift f u = u >>= fun a -> unit (f a)
159     (* lift is called listM, fmap, and <$> in Haskell *)
160     let join uu = uu >>= fun u -> u
161     (* u >>= f === join (lift f u) *)
162     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
163     (* [f] <*> [x1,x2] = [f x1,f x2] *)
164     (* let apply u v = u >>= fun f -> lift f v *)
165     (* let apply = lift2 id *)
166     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
167     (* let lift f u === apply (unit f) u *)
168     (* let lift2 f u v = apply (lift f u) v *)
169     let (>=>) f g = fun a -> f a >>= g
170     let do_when test u = if test then u else unit ()
171     let do_unless test u = if test then unit () else u
172     (* A Haskell-like version:
173          let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk
174      * is not in tail position and will stack overflow. *)
175     let forever uthunk =
176         let z = zero () in
177         let id result = result in
178         let newk = ref id in
179         let rec loop () =
180             let result = uthunk (newk := id) >>= chained
181             in !newk result
182         and chained =
183             fun _ -> newk := (fun _ -> loop ()); z (* we use z only for its polymorphism *)
184         in loop ()
185     (* reimplementations of the preceding using a hand-rolled State or StateT also stack overflowed *)
186     let sequence ms =
187       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
188         Util.fold_right op ms (unit [])
189     let sequence_ ms =
190       Util.fold_right (>>) ms (unit ())
191
192     (* Haskell defines these other operations combining lists and monads.
193      * We don't, but notice that M.mapM == ListT(M).distribute
194      * There's also a parallel TreeT(M).distribute *)
195     (*
196     let mapM f alist = sequence (Util.map f alist)
197     let mapM_ f alist = sequence_ (Util.map f alist)
198     let rec filterM f lst = match lst with
199       | [] -> unit []
200       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
201     let forM alist f = mapM f alist
202     let forM_ alist f = mapM_ f alist
203     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
204     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
205     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
206     let rec foldM f z lst = match lst with
207       | [] -> unit z
208       | x::xs -> f z x >>= fun z' -> foldM f z' xs
209     let foldM_ f z xs = foldM f z xs >> unit ()
210     let replicateM n x = sequence (Util.replicate n x)
211     let replicateM_ n x = sequence_ (Util.replicate n x)
212     *)
213     let guard test = if test then B.unit () else zero ()
214     let sum ms = Util.fold_right plus ms (zero ())
215   end
216
217   (* Signatures for MonadT *)
218   module type BASET = sig
219     module Wrapped : S
220     type ('x,'a) m
221     type ('x,'a) result
222     type ('x,'a) result_exn
223     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
224     val run : ('x,'a) m -> ('x,'a) result
225     val run_exn : ('x,'a) m -> ('x,'a) result_exn
226     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
227     (* lift/elevate laws:
228      *     elevate (W.unit a) == unit a
229      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
230      *)
231     val zero : unit -> ('x,'a) m
232     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
233   end
234   module MakeT(T : BASET) = struct
235     include Make(struct
236         include T
237         let unit a = elevate (Wrapped.unit a)
238     end)
239     let elevate = T.elevate
240   end
241
242 end
243
244
245
246
247
248 module Identity_monad : sig
249   (* expose only the implementation of type `'a result` *)
250   type ('x,'a) result = 'a
251   type ('x,'a) result_exn = 'a
252   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
253 end = struct
254   module Base = struct
255     type ('x,'a) m = 'a
256     type ('x,'a) result = 'a
257     type ('x,'a) result_exn = 'a
258     let unit a = a
259     let bind a f = f a
260     let run a = a
261     let run_exn a = a
262     let zero () = Util.undef
263     let plus u v = u
264   end
265   include Monad.Make(Base)
266 end
267
268
269 module Maybe_monad : sig
270   (* expose only the implementation of type `'a result` *)
271   type ('x,'a) result = 'a option
272   type ('x,'a) result_exn = 'a
273   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
274   (* MaybeT transformer *)
275   module T : functor (Wrapped : Monad.S) -> sig
276     type ('x,'a) result = ('x,'a option) Wrapped.result
277     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
278     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
279     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
280   end
281 end = struct
282   module Base = struct
283     type ('x,'a) m = 'a option
284     type ('x,'a) result = 'a option
285     type ('x,'a) result_exn = 'a
286     let unit a = Some a
287     let bind u f = match u with Some a -> f a | None -> None
288     let run u = u
289     let run_exn u = match u with
290       | Some a -> a
291       | None -> failwith "no value"
292     let zero () = None
293     (* satisfies Catch *)
294     let plus u v = match u with None -> v | _ -> u
295   end
296   include Monad.Make(Base)
297   module T(Wrapped : Monad.S) = struct
298     module BaseT = struct
299       include Monad.MakeT(struct
300         module Wrapped = Wrapped
301         type ('x,'a) m = ('x,'a option) Wrapped.m
302         type ('x,'a) result = ('x,'a option) Wrapped.result
303         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
304         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
305         let bind u f = Wrapped.bind u (fun t -> match t with
306           | Some a -> f a
307           | None -> Wrapped.unit None)
308         let run u = Wrapped.run u
309         let run_exn u =
310           let w = Wrapped.bind u (fun t -> match t with
311             | Some a -> Wrapped.unit a
312             | None -> Wrapped.zero ()
313           ) in Wrapped.run_exn w
314         let zero () = Wrapped.unit None
315         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
316       end)
317     end
318     include BaseT
319   end
320 end
321
322
323 module List_monad : sig
324   (* declare additional operation, while still hiding implementation of type m *)
325   type ('x,'a) result = 'a list
326   type ('x,'a) result_exn = 'a
327   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
328   val permute : ('x,'a) m -> ('x,('x,'a) m) m
329   val select : ('x,'a) m -> ('x,'a * ('x,'a) m) m
330   (* ListT transformer *)
331   module T : functor (Wrapped : Monad.S) -> sig
332     type ('x,'a) result = ('x,'a list) Wrapped.result
333     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
334     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
335     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
336     (* note that second argument is an 'a list, not the more abstract 'a m *)
337     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
338     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
339 (* TODO
340     val permute : 'a m -> 'a m m
341     val select : 'a m -> ('a * 'a m) m
342 *)
343   end
344 end = struct
345   module Base = struct
346    type ('x,'a) m = 'a list
347    type ('x,'a) result = 'a list
348    type ('x,'a) result_exn = 'a
349    let unit a = [a]
350    let bind u f = Util.concat_map f u
351    let run u = u
352    let run_exn u = match u with
353      | [] -> failwith "no values"
354      | [a] -> a
355      | many -> failwith "multiple values"
356    let zero () = []
357    (* satisfies Distrib *)
358    let plus = Util.append
359   end
360   include Monad.Make(Base)
361   (* let either u v = plus u v *)
362   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
363   let rec insert a u =
364     plus (unit (a :: u)) (match u with
365         | [] -> zero ()
366         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
367     )
368   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
369   let rec permute u = match u with
370       | [] -> unit []
371       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
372   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
373   let rec select u = match u with
374     | [] -> zero ()
375     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
376   module T(Wrapped : Monad.S) = struct
377     (* Wrapped.sequence ms  ===  
378          let plus1 u v =
379            Wrapped.bind u (fun x ->
380            Wrapped.bind v (fun xs ->
381            Wrapped.unit (x :: xs)))
382          in Util.fold_right plus1 ms (Wrapped.unit []) *)
383     (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
384     let distribute f alist = Wrapped.sequence (Util.map f alist)
385
386     include Monad.MakeT(struct
387       module Wrapped = Wrapped
388       type ('x,'a) m = ('x,'a list) Wrapped.m
389       type ('x,'a) result = ('x,'a list) Wrapped.result
390       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
391       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
392       let bind u f =
393         Wrapped.bind u (fun ts ->
394         Wrapped.bind (distribute f ts) (fun tts ->
395         Wrapped.unit (Util.concat tts)))
396       let run u = Wrapped.run u
397       let run_exn u =
398         let w = Wrapped.bind u (fun ts -> match ts with
399           | [] -> Wrapped.zero ()
400           | [a] -> Wrapped.unit a
401           | many -> Wrapped.zero ()
402         ) in Wrapped.run_exn w
403       let zero () = Wrapped.unit []
404       let plus u v =
405         Wrapped.bind u (fun us ->
406         Wrapped.bind v (fun vs ->
407         Wrapped.unit (Base.plus us vs)))
408     end)
409 (*
410     let permute : 'a m -> 'a m m
411     let select : 'a m -> ('a * 'a m) m
412 *)
413   end
414 end
415
416 (*
417 # LL.(run(plus (unit 1) (unit 2) >>= fun i -> plus (unit i) (unit(10*i)) ));;
418 - : ('_a, int) LL.result = [[1; 10; 2; 20]]
419 # LL.(run(plus (unit 1) (unit 2) >>= fun i -> elevate L.(plus (unit i) (unit(10*i)) )));;
420 - : ('_a, int) LL.result = [[1; 2]; [1; 20]; [10; 2]; [10; 20]]
421 *)
422
423
424 (* must be parameterized on (struct type err = ... end) *)
425 module Error_monad(Err : sig
426   type err
427   exception Exc of err
428   (*
429   val zero : unit -> err
430   val plus : err -> err -> err
431   *)
432 end) : sig
433   (* declare additional operations, while still hiding implementation of type m *)
434   type err = Err.err
435   type 'a error = Error of err | Success of 'a
436   type ('x,'a) result = 'a error
437   type ('x,'a) result_exn = 'a
438   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
439   val throw : err -> ('x,'a) m
440   val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
441   (* ErrorT transformer *)
442   module T : functor (Wrapped : Monad.S) -> sig
443     type ('x,'a) result = ('x,'a) Wrapped.result
444     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
445     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
446     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
447     val throw : err -> ('x,'a) m
448     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
449   end
450 end = struct
451   type err = Err.err
452   type 'a error = Error of err | Success of 'a
453   module Base = struct
454     type ('x,'a) m = 'a error
455     type ('x,'a) result = 'a error
456     type ('x,'a) result_exn = 'a
457     let unit a = Success a
458     let bind u f = match u with
459       | Success a -> f a
460       | Error e -> Error e (* input and output may be of different 'a types *)
461     let run u = u
462     let run_exn u = match u with
463       | Success a -> a
464       | Error e -> raise (Err.Exc e)
465     let zero () = Util.undef
466     (* satisfies Catch *)
467     let plus u v = match u with
468       | Success _ -> u
469       | Error _ -> if v == Util.undef then u else v
470   end
471   include Monad.Make(Base)
472   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
473   let throw e = Error e
474   let catch u handler = match u with
475     | Success _ -> u
476     | Error e -> handler e
477   module T(Wrapped : Monad.S) = struct
478     include Monad.MakeT(struct
479       module Wrapped = Wrapped
480       type ('x,'a) m = ('x,'a error) Wrapped.m
481       type ('x,'a) result = ('x,'a) Wrapped.result
482       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
483       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
484       let bind u f = Wrapped.bind u (fun t -> match t with
485         | Success a -> f a
486         | Error e -> Wrapped.unit (Error e))
487       let run u =
488         let w = Wrapped.bind u (fun t -> match t with
489           | Success a -> Wrapped.unit a
490           | Error e -> Wrapped.zero ()
491         ) in Wrapped.run w
492       let run_exn u =
493         let w = Wrapped.bind u (fun t -> match t with
494           | Success a -> Wrapped.unit a
495           | Error e -> raise (Err.Exc e))
496         in Wrapped.run_exn w
497       let plus u v = Wrapped.plus u v
498       let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *)
499     end)
500     let throw e = Wrapped.unit (Error e)
501     let catch u handler = Wrapped.bind u (fun t -> match t with
502       | Success _ -> Wrapped.unit t
503       | Error e -> handler e)
504   end
505 end
506
507 (* pre-define common instance of Error_monad *)
508 module Failure = Error_monad(struct
509   type err = string
510   exception Exc = Failure
511   (*
512   let zero = ""
513   let plus s1 s2 = s1 ^ "\n" ^ s2
514   *)
515 end)
516
517 (*
518 # EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
519 - : int EL.result = [Failure.Error "bye"; Failure.Success 30]
520 # LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
521 - : int LE.result = Failure.Error "bye"
522 # EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
523 Exception: Failure "bye".
524 # LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
525 Exception: Failure "bye".
526
527 # ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
528 - : int Failure.error * S.store = (Failure.Error "bye", 1)
529 # SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
530 - : (int * S.store) Failure.result = Failure.Error "bye"
531 # ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
532 Exception: Failure "bye".
533 # SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
534 Exception: Failure "bye".
535  *)
536
537
538 (* must be parameterized on (struct type env = ... end) *)
539 module Reader_monad(Env : sig type env end) : sig
540   (* declare additional operations, while still hiding implementation of type m *)
541   type env = Env.env
542   type ('x,'a) result = env -> 'a
543   type ('x,'a) result_exn = env -> 'a
544   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
545   val ask : ('x,env) m
546   val asks : (env -> 'a) -> ('x,'a) m
547   (* lookup i == `fun e -> e i` would assume env is a functional type *)
548   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
549   (* ReaderT transformer *)
550   module T : functor (Wrapped : Monad.S) -> sig
551     type ('x,'a) result = env -> ('x,'a) Wrapped.result
552     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
553     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
554     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
555     val ask : ('x,env) m
556     val asks : (env -> 'a) -> ('x,'a) m
557     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
558   end
559 end = struct
560   type env = Env.env
561   module Base = struct
562     type ('x,'a) m = env -> 'a
563     type ('x,'a) result = env -> 'a
564     type ('x,'a) result_exn = env -> 'a
565     let unit a = fun e -> a
566     let bind u f = fun e -> let a = u e in let u' = f a in u' e
567     let run u = fun e -> u e
568     let run_exn = run
569     let zero () = Util.undef
570     let plus u v = u
571   end
572   include Monad.Make(Base)
573   let ask = fun e -> e
574   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
575   let local modifier u = fun e -> u (modifier e)
576   module T(Wrapped : Monad.S) = struct
577     module BaseT = struct
578       module Wrapped = Wrapped
579       type ('x,'a) m = env -> ('x,'a) Wrapped.m
580       type ('x,'a) result = env -> ('x,'a) Wrapped.result
581       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
582       let elevate w = fun e -> w
583       let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e)
584       let run u = fun e -> Wrapped.run (u e)
585       let run_exn u = fun e -> Wrapped.run_exn (u e)
586       (* satisfies Distrib *)
587       let plus u v = fun e -> Wrapped.plus (u e) (v e)
588       let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
589     end
590     include Monad.MakeT(BaseT)
591     let ask = Wrapped.unit
592     let local modifier u = fun e -> u (modifier e)
593     let asks selector = ask >>= (fun e ->
594       try unit (selector e)
595       with Not_found -> fun e -> Wrapped.zero ())
596   end
597 end
598
599
600 (* must be parameterized on (struct type store = ... end) *)
601 module State_monad(Store : sig type store end) : sig
602   (* declare additional operations, while still hiding implementation of type m *)
603   type store = Store.store
604   type ('x,'a) result =  store -> 'a * store
605   type ('x,'a) result_exn = store -> 'a
606   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
607   val get : ('x,store) m
608   val gets : (store -> 'a) -> ('x,'a) m
609   val put : store -> ('x,unit) m
610   val puts : (store -> store) -> ('x,unit) m
611   (* StateT transformer *)
612   module T : functor (Wrapped : Monad.S) -> sig
613     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
614     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
615     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
616     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
617     val get : ('x,store) m
618     val gets : (store -> 'a) -> ('x,'a) m
619     val put : store -> ('x,unit) m
620     val puts : (store -> store) -> ('x,unit) m
621   end
622 end = struct
623   type store = Store.store
624   module Base = struct
625     type ('x,'a) m =  store -> 'a * store
626     type ('x,'a) result =  store -> 'a * store
627     type ('x,'a) result_exn = store -> 'a
628     let unit a = fun s -> (a, s)
629     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
630     let run u = fun s -> (u s)
631     let run_exn u = fun s -> fst (u s)
632     let zero () = Util.undef
633     let plus u v = u
634   end
635   include Monad.Make(Base)
636   let get = fun s -> (s, s)
637   let gets viewer = fun s -> (viewer s, s) (* may fail *)
638   let put s = fun _ -> ((), s)
639   let puts modifier = fun s -> ((), modifier s)
640   module T(Wrapped : Monad.S) = struct
641     module BaseT = struct
642       module Wrapped = Wrapped
643       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
644       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
645       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
646       let elevate w = fun s ->
647         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
648       let bind u f = fun s ->
649         Wrapped.bind (u s) (fun (a, s') -> f a s')
650       let run u = fun s -> Wrapped.run (u s)
651       let run_exn u = fun s ->
652         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
653         in Wrapped.run_exn w
654       (* satisfies Distrib *)
655       let plus u v = fun s -> Wrapped.plus (u s) (v s)
656       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
657     end
658     include Monad.MakeT(BaseT)
659     let get = fun s -> Wrapped.unit (s, s)
660     let gets viewer = fun s ->
661       try Wrapped.unit (viewer s, s)
662       with Not_found -> Wrapped.zero ()
663     let put s = fun _ -> Wrapped.unit ((), s)
664     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
665   end
666 end
667
668 (* State monad with different interface (structured store) *)
669 module Ref_monad(V : sig
670   type value
671 end) : sig
672   type ref
673   type value = V.value
674   type ('x,'a) result = 'a
675   type ('x,'a) result_exn = 'a
676   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
677   val newref : value -> ('x,ref) m
678   val deref : ref -> ('x,value) m
679   val change : ref -> value -> ('x,unit) m
680   (* RefT transformer *)
681   module T : functor (Wrapped : Monad.S) -> sig
682     type ('x,'a) result = ('x,'a) Wrapped.result
683     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
684     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
685     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
686     val newref : value -> ('x,ref) m
687     val deref : ref -> ('x,value) m
688     val change : ref -> value -> ('x,unit) m
689   end
690 end = struct
691   type ref = int
692   type value = V.value
693   module D = Map.Make(struct type t = ref let compare = compare end)
694   type dict = { next: ref; tree : value D.t }
695   let empty = { next = 0; tree = D.empty }
696   let alloc (value : value) (d : dict) =
697     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
698   let read (key : ref) (d : dict) =
699     D.find key d.tree
700   let write (key : ref) (value : value) (d : dict) =
701     { next = d.next; tree = D.add key value d.tree }
702   module Base = struct
703     type ('x,'a) m = dict -> 'a * dict
704     type ('x,'a) result = 'a
705     type ('x,'a) result_exn = 'a
706     let unit a = fun s -> (a, s)
707     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
708     let run u = fst (u empty)
709     let run_exn = run
710     let zero () = Util.undef
711     let plus u v = u
712   end
713   include Monad.Make(Base)
714   let newref value = fun s -> alloc value s
715   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
716   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
717   module T(Wrapped : Monad.S) = struct
718     module BaseT = struct
719       module Wrapped = Wrapped
720       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
721       type ('x,'a) result = ('x,'a) Wrapped.result
722       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
723       let elevate w = fun s ->
724         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
725       let bind u f = fun s ->
726         Wrapped.bind (u s) (fun (a, s') -> f a s')
727       let run u =
728         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
729         in Wrapped.run w
730       let run_exn u =
731         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
732         in Wrapped.run_exn w
733       (* satisfies Distrib *)
734       let plus u v = fun s -> Wrapped.plus (u s) (v s)
735       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
736     end
737     include Monad.MakeT(BaseT)
738     let newref value = fun s -> Wrapped.unit (alloc value s)
739     let deref key = fun s -> Wrapped.unit (read key s, s)
740     let change key value = fun s -> Wrapped.unit ((), write key value s)
741   end
742 end
743
744
745 (* must be parameterized on (struct type log = ... end) *)
746 module Writer_monad(Log : sig
747   type log
748   val zero : log
749   val plus : log -> log -> log
750 end) : sig
751   (* declare additional operations, while still hiding implementation of type m *)
752   type log = Log.log
753   type ('x,'a) result = 'a * log
754   type ('x,'a) result_exn = 'a * log
755   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
756   val tell : log -> ('x,unit) m
757   val listen : ('x,'a) m -> ('x,'a * log) m
758   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
759   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
760   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
761   (* WriterT transformer *)
762   module T : functor (Wrapped : Monad.S) -> sig
763     type ('x,'a) result = ('x,'a * log) Wrapped.result
764     type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
765     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
766     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
767     val tell : log -> ('x,unit) m
768     val listen : ('x,'a) m -> ('x,'a * log) m
769     val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
770     val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
771   end
772 end = struct
773   type log = Log.log
774   module Base = struct
775     type ('x,'a) m = 'a * log
776     type ('x,'a) result = 'a * log
777     type ('x,'a) result_exn = 'a * log
778     let unit a = (a, Log.zero)
779     let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w')
780     let run u = u
781     let run_exn = run
782     let zero () = Util.undef
783     let plus u v = u
784   end
785   include Monad.Make(Base)
786   let tell entries = ((), entries) (* add entries to log *)
787   let listen (a, w) = ((a, w), w)
788   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
789   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
790   let censor f u = pass (u >>= fun a -> unit (a, f))
791   module T(Wrapped : Monad.S) = struct
792     module BaseT = struct
793       module Wrapped = Wrapped
794       type ('x,'a) m = ('x,'a * log) Wrapped.m
795       type ('x,'a) result = ('x,'a * log) Wrapped.result
796       type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
797       let elevate w =
798         Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero))
799       let bind u f =
800         Wrapped.bind u (fun (a, w) ->
801         Wrapped.bind (f a) (fun (b, w') ->
802         Wrapped.unit (b, Log.plus w w')))
803       let zero () = elevate (Wrapped.zero ())
804       let plus u v = Wrapped.plus u v
805       let run u = Wrapped.run u
806       let run_exn u = Wrapped.run_exn u
807     end
808     include Monad.MakeT(BaseT)
809     let tell entries = Wrapped.unit ((), entries)
810     let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w))
811     let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w))
812     (* rest are derived in same way as before *)
813     let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w)
814     let censor f u = pass (u >>= fun a -> unit (a, f))
815   end
816 end
817
818 (* pre-define simple Writer *)
819 module Writer1 = Writer_monad(struct
820   type log = string
821   let zero = ""
822   let plus s1 s2 = s1 ^ "\n" ^ s2
823 end)
824
825 (* slightly more efficient Writer *)
826 module Writer2 = struct
827   include Writer_monad(struct
828     type log = string list
829     let zero = []
830     let plus w w' = Util.append w' w
831   end)
832   let tell_string s = tell [s]
833   let tell entries = tell (Util.reverse entries)
834   let run u = let (a, w) = run u in (a, Util.reverse w)
835   let run_exn = run
836 end
837
838
839 (* TODO needs a T *)
840 module IO_monad : sig
841   (* declare additional operation, while still hiding implementation of type m *)
842   type ('x,'a) result = 'a
843   type ('x,'a) result_exn = 'a
844   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
845   val printf : ('a, unit, string, ('x,unit) m) format4 -> 'a
846   val print_string : string -> ('x,unit) m
847   val print_int : int -> ('x,unit) m
848   val print_hex : int -> ('x,unit) m
849   val print_bool : bool -> ('x,unit) m
850 end = struct
851   module Base = struct
852     type ('x,'a) m = { run : unit -> unit; value : 'a }
853     type ('x,'a) result = 'a
854     type ('x,'a) result_exn = 'a
855     let unit a = { run = (fun () -> ()); value = a }
856     let bind (a : ('x,'a) m) (f: 'a -> ('x,'b) m) : ('x,'b) m =
857      let fres = f a.value in
858        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
859     let run a = let () = a.run () in a.value
860     let run_exn = run
861     let zero () = Util.undef
862     let plus u v = u
863   end
864   include Monad.Make(Base)
865   let printf fmt =
866     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
867   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
868   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
869   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
870   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
871 end
872
873
874 module Continuation_monad : sig
875   (* expose only the implementation of type `('r,'a) result` *)
876   type ('r,'a) m
877   type ('r,'a) result = ('r,'a) m
878   type ('r,'a) result_exn = ('a -> 'r) -> 'r
879   include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
880   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
881   val reset : ('a,'a) m -> ('r,'a) m
882   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
883   (* val abort : ('a,'a) m -> ('a,'b) m *)
884   val abort : 'a -> ('a,'b) m
885   val run0 : ('a,'a) m -> 'a
886   (* ContinuationT transformer *)
887   module T : functor (Wrapped : Monad.S) -> sig
888     type ('r,'a) m
889     type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
890     type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
891     include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
892     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
893     val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
894     (* TODO: reset,shift,abort,run0 *)
895   end
896 end = struct
897   let id = fun i -> i
898   module Base = struct
899     (* 'r is result type of whole computation *)
900     type ('r,'a) m = ('a -> 'r) -> 'r
901     type ('r,'a) result = ('a -> 'r) -> 'r
902     type ('r,'a) result_exn = ('r,'a) result
903     let unit a = (fun k -> k a)
904     let bind u f = (fun k -> (u) (fun a -> (f a) k))
905     let run u k = (u) k
906     let run_exn = run
907     let zero () = Util.undef
908     let plus u v = u
909   end
910   include Monad.Make(Base)
911   let callcc f = (fun k ->
912     let usek a = (fun _ -> k a)
913     in (f usek) k)
914   (*
915   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
916   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
917   let callcc f = fun k -> f k k
918   let throw k a = fun _ -> k a
919   *)
920
921   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
922    *
923    *  reset :: (Monad m) => ContT a m a -> ContT r m a
924    *  reset e = ContT $ \k -> runContT e return >>= k
925    *
926    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
927    *  shift e = ContT $ \k ->
928    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
929   let reset u = unit ((u) id)
930   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
931   (* let abort a = shift (fun _ -> a) *)
932   let abort a = shift (fun _ -> unit a)
933   let run0 (u : ('a,'a) m) = (u) id
934   module T(Wrapped : Monad.S) = struct
935     module BaseT = struct
936       module Wrapped = Wrapped
937       type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m
938       type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
939       type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
940       let elevate w = fun k -> Wrapped.bind w k
941       let bind u f = fun k -> u (fun a -> f a k)
942       let run u k = Wrapped.run (u k)
943       let run_exn u k = Wrapped.run_exn (u k)
944       let zero () = Util.undef
945       let plus u v = u
946     end
947     include Monad.MakeT(BaseT)
948     let callcc f = (fun k ->
949       let usek a = (fun _ -> k a)
950       in (f usek) k)
951   end
952 end
953
954
955 (*
956  * Scheme:
957  * (define (example n)
958  *    (let ([u (let/cc k ; type int -> int pair
959  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
960  *                 (+ 1 (car v))))]) ; int
961  *      (cons u 0))) ; int pair
962  * ; (example 10) ~~> '(111 . 0)
963  * ; (example -10) ~~> '(0 . 0)
964  *
965  * OCaml monads:
966  * let example n : (int * int) =
967  *   Continuation_monad.(let u = callcc (fun k ->
968  *       (if n < 0 then k 0 else unit [n + 100])
969  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
970  *       >>= fun [x] -> unit (x + 1)
971  *   )
972  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
973  *   >>= fun x -> unit (x, 0)
974  *   in run u)
975  *
976  *
977  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
978  * let example1 () : int =
979  *   Continuation_monad.(let v = reset (
980  *       let u = shift (fun k -> unit (10 + 1))
981  *       in u >>= fun x -> unit (100 + x)
982  *     ) in let w = v >>= fun x -> unit (1000 + x)
983  *     in run w)
984  *
985  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
986  * let example2 () =
987  *   Continuation_monad.(let v = reset (
988  *       let u = shift (fun k -> k (10 :: [1]))
989  *       in u >>= fun x -> unit (100 :: x)
990  *     ) in let w = v >>= fun x -> unit (1000 :: x)
991  *     in run w)
992  *
993  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
994  * let example3 () =
995  *   Continuation_monad.(let v = reset (
996  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
997  *       in u >>= fun x -> unit (100 :: x)
998  *     ) in let w = v >>= fun x -> unit (1000 :: x)
999  *     in run w)
1000  *
1001  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1002  * (* not sure if this example can be typed without a sum-type *)
1003  *
1004  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1005  * let example5 () : int =
1006  *   Continuation_monad.(let v = reset (
1007  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
1008  *       in u >>= fun x -> unit (10 + x)
1009  *     ) in let w = v >>= fun x -> unit (100 + x)
1010  *     in run w)
1011  *
1012  *)
1013
1014
1015 module Leaf_monad : sig
1016   (* We implement the type as `'a tree option` because it has a natural`plus`,
1017    * and the rest of the library expects that `plus` and `zero` will come together. *)
1018   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1019   type ('x,'a) result = 'a tree option
1020   type ('x,'a) result_exn = 'a tree
1021   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1022   (* LeafT transformer *)
1023   module T : functor (Wrapped : Monad.S) -> sig
1024     type ('x,'a) result = ('x,'a tree option) Wrapped.result
1025     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1026     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1027     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1028     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1029     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1030     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
1031   end
1032 end = struct
1033   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1034   (* uses supplied plus and zero to copy t to its image under f *)
1035   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1036       | None -> zero ()
1037       | Some ts -> let rec loop ts = (match ts with
1038                      | Leaf a -> f a
1039                      | Node (l, r) ->
1040                          (* recursive application of f may delete a branch *)
1041                          plus (loop l) (loop r)
1042                    ) in loop ts
1043   module Base = struct
1044     type ('x,'a) m = 'a tree option
1045     type ('x,'a) result = 'a tree option
1046     type ('x,'a) result_exn = 'a tree
1047     let unit a = Some (Leaf a)
1048     let zero () = None
1049     (* satisfies Distrib *)
1050     let plus u v = match (u, v) with
1051       | None, _ -> v
1052       | _, None -> u
1053       | Some us, Some vs -> Some (Node (us, vs))
1054     let bind u f = mapT f u zero plus
1055     let run u = u
1056     let run_exn u = match u with
1057       | None -> failwith "no values"
1058       (*
1059       | Some (Leaf a) -> a
1060       | many -> failwith "multiple values"
1061       *)
1062       | Some us -> us
1063   end
1064   include Monad.Make(Base)
1065   module T(Wrapped : Monad.S) = struct
1066     module BaseT = struct
1067       include Monad.MakeT(struct
1068         module Wrapped = Wrapped
1069         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1070         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1071         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1072         let zero () = Wrapped.unit None
1073         let plus u v =
1074           Wrapped.bind u (fun us ->
1075           Wrapped.bind v (fun vs ->
1076           Wrapped.unit (Base.plus us vs)))
1077         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1078         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1079         let run u = Wrapped.run u
1080         let run_exn u =
1081             let w = Wrapped.bind u (fun t -> match t with
1082               | None -> Wrapped.zero ()
1083               | Some ts -> Wrapped.unit ts
1084             ) in Wrapped.run_exn w
1085       end)
1086     end
1087     include BaseT
1088     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1089   end
1090 end
1091
1092
1093 module L = List_monad;;
1094 module R = Reader_monad(struct type env = int -> int end);;
1095 module S = State_monad(struct type store = int end);;
1096 module T = Leaf_monad;;
1097 module LR = L.T(R);;
1098 module LS = L.T(S);;
1099 module TL = T.T(L);;
1100 module TR = T.T(R);;
1101 module TS = T.T(S);;
1102 module C = Continuation_monad
1103 module TC = T.T(C);;
1104
1105
1106 print_endline "=== test Leaf(...).distribute ==================";;
1107
1108 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1109
1110 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1111 TS.run ts 0;;
1112 (*
1113 - : int T.tree option * S.store =
1114 (Some
1115   (T.Node
1116     (T.Node (T.Leaf 2, T.Leaf 3),
1117      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1118  5)
1119 *)
1120
1121 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1122 TS.run_exn ts2 0;;
1123 (*
1124 - : (int * S.store) T.tree option * S.store =
1125 (Some
1126   (T.Node
1127     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1128      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1129  5)
1130 *)
1131
1132 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1133 TR.run_exn tr (fun i -> i+i);;
1134 (*
1135 - : int T.tree option =
1136 Some
1137  (T.Node
1138    (T.Node (T.Leaf 4, T.Leaf 6),
1139     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1140 *)
1141
1142 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1143 TL.run_exn tl;;
1144 (*
1145 - : (int * int) TL.result =
1146 [Some
1147   (T.Node
1148     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1149      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1150 *)
1151
1152 let l2 = [1;2;3;4;5];;
1153 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1154
1155 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1156 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1157
1158 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1159 (*
1160 int T.tree option =
1161 Some
1162  (T.Node
1163    (T.Node (T.Leaf 10, T.Leaf 11),
1164     T.Node
1165      (T.Node
1166        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1167         T.Node (T.Leaf 40, T.Leaf 41)),
1168       T.Node (T.Leaf 50, T.Leaf 51))))
1169  *)
1170
1171 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1172 (*
1173 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1174 *)
1175
1176 print_endline "=== test Leaf(Continuation).distribute ==================";;
1177
1178 let id : 'z. 'z -> 'z = fun x -> x
1179
1180 let example n : (int * int) =
1181   Continuation_monad.(let u = callcc (fun k ->
1182       (if n < 0 then k 0 else unit [n + 100])
1183       (* all of the following is skipped by k 0; the end type int is k's input type *)
1184       >>= fun [x] -> unit (x + 1)
1185   )
1186   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1187   >>= fun x -> unit (x, 0)
1188   in run0 u)
1189
1190
1191 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1192 let example1 () : int =
1193   Continuation_monad.(let v = reset (
1194       let u = shift (fun k -> unit (10 + 1))
1195       in u >>= fun x -> unit (100 + x)
1196     ) in let w = v >>= fun x -> unit (1000 + x)
1197     in run0 w)
1198
1199 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1200 let example2 () =
1201   Continuation_monad.(let v = reset (
1202       let u = shift (fun k -> k (10 :: [1]))
1203       in u >>= fun x -> unit (100 :: x)
1204     ) in let w = v >>= fun x -> unit (1000 :: x)
1205     in run0 w)
1206
1207 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1208 let example3 () =
1209   Continuation_monad.(let v = reset (
1210       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1211       in u >>= fun x -> unit (100 :: x)
1212     ) in let w = v >>= fun x -> unit (1000 :: x)
1213     in run0 w)
1214
1215 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1216 (* not sure if this example can be typed without a sum-type *)
1217
1218 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1219 let example5 () : int =
1220   Continuation_monad.(let v = reset (
1221       let u = shift (fun k -> k 1 >>= k)
1222       in u >>= fun x -> unit (10 + x)
1223     ) in let w = v >>= fun x -> unit (100 + x)
1224     in run0 w)
1225
1226 ;;
1227
1228 print_endline "=== test bare Continuation ============";;
1229
1230 (1011, 1111, 1111, 121);;
1231 (example1(), example2(), example3(), example5());;
1232 ((111,0), (0,0));;
1233 (example ~+10, example ~-10);;
1234
1235 let testc df ic =
1236     C.run_exn TC.(run (distribute df t1)) ic;;
1237
1238
1239 (*
1240 (* do nothing *)
1241 let initial_continuation = fun t -> t in
1242 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1243 *)
1244 testc (C.unit) id;;
1245
1246 (*
1247 (* count leaves, using continuation *)
1248 let initial_continuation = fun t -> 0 in
1249 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1250 *)
1251
1252 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1253
1254 (*
1255 (* convert tree to list of leaves *)
1256 let initial_continuation = fun t -> [] in
1257 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1258 *)
1259
1260 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1261
1262 (*
1263 (* square each leaf using continuation *)
1264 let initial_continuation = fun t -> t in
1265 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1266 *)
1267
1268 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1269
1270
1271 (*
1272 (* replace leaves with list, using continuation *)
1273 let initial_continuation = fun t -> t in
1274 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1275 *)
1276
1277 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1278
1279 print_endline "=== pa_monad's Continuation Tests ============";;
1280
1281 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1282 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1283 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1284 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1285 (5, 27 = C.(run0 (
1286               let c = reset(abort 5 >>= fun y -> unit (y+6))
1287               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1288
1289 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1290
1291 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1292
1293 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1294
1295
1296 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1297