tweak monads-lib
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  * Acknowledgements: This is largely based on the mtl library distributed
46  * with the Glasgow Haskell Compiler. I've also been helped in
47  * various ways by posts and direct feedback from Oleg Kiselyov and
48  * Chung-chieh Shan. The following were also useful:
49  * - <http://pauillac.inria.fr/~xleroy/mpri/progfunc/>
50  * - Ken Shan "Monads for natural language semantics" <http://arxiv.org/abs/cs/0205026v1>
51  * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf
52  * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers
53  *
54  * Licensing: MIT (if that's compatible with the ghc sources).
55  *)
56
57 exception Undefined
58
59 (* Some library functions used below. *)
60 module Util = struct
61   let fold_right = List.fold_right
62   let map = List.map
63   let append = List.append
64   let reverse = List.rev
65   let concat = List.concat
66   let concat_map f lst = List.concat (List.map f lst)
67   (* let zip = List.combine *)
68   let unzip = List.split
69   let zip_with = List.map2
70   let replicate len fill =
71     let rec loop n accu =
72       if n == 0 then accu else loop (pred n) (fill :: accu)
73     in loop len []
74   (* Dirty hack to be a default polymorphic zero.
75    * To implement this cleanly, monads without a natural zero
76    * should always wrap themselves in an option layer (see Leaf_monad). *)
77   let undef = Obj.magic (fun () -> raise Undefined)
78 end
79
80
81
82 (*
83  * This module contains factories that extend a base set of
84  * monadic definitions with a larger family of standard derived values.
85  *)
86
87 module Monad = struct
88   (*
89    * Signature extenders:
90    *   Make :: BASE -> S
91    *   MakeT :: BASET (with Wrapped : S) -> result sig not declared
92    *)
93
94
95   (* type of base definitions *)
96   module type BASE = sig
97     (* We make all monadic types doubly-parameterized so that they
98      * can layer nicely with Continuation, which needs the second
99      * type parameter. *)
100     type ('x,'a) m
101     type ('x,'a) result
102     type ('x,'a) result_exn
103     val unit : 'a -> ('x,'a) m
104     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
105     val run : ('x,'a) m -> ('x,'a) result
106     (* run_exn tries to provide a more ground-level result, but may fail *)
107     val run_exn : ('x,'a) m -> ('x,'a) result_exn
108     (* To simplify the library, we require every monad to supply a plus and zero. These obey the following laws:
109      *     zero >>= f   ===  zero
110      *     plus zero u  ===  u
111      *     plus u zero  ===  u
112      * Additionally, they will obey one of the following laws:
113      *     (Catch)   plus (unit a) v  ===  unit a
114      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
115      * When no natural zero is available, use `let zero () = Util.undef`.
116      * The Make functor automatically detects for zero >>= ..., and 
117      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
118      *)
119     val zero : unit -> ('x,'a) m
120     (* zero has to be thunked to ensure results are always poly enough *)
121     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
122   end
123   module type S = sig
124     include BASE
125     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
126     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
127     val join : ('x,('x,'a) m) m -> ('x,'a) m
128     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
129     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
130     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
131     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
132     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
133     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
134     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
135     val sequence : ('x,'a) m list -> ('x,'a list) m
136     val sequence_ : ('x,'a) m list -> ('x,unit) m
137     val guard : bool -> ('x,unit) m
138     val sum : ('x,'a) m list -> ('x,'a) m
139   end
140
141   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
142     include B
143     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
144       if u == Util.undef then Util.undef
145       else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
146     let plus u v =
147       if u == Util.undef then v else if v == Util.undef then u else B.plus u v
148     let run u =
149       if u == Util.undef then raise Undefined else B.run u
150     let run_exn u =
151       if u == Util.undef then raise Undefined else B.run_exn u
152     let (>>=) = bind
153     (* expressions after >> will be evaluated before they're passed to
154      * bind, so you can't do `zero () >> assert false`
155      * this works though: `zero () >>= fun _ -> assert false`
156      *)
157     let (>>) u v = u >>= fun _ -> v
158     let lift f u = u >>= fun a -> unit (f a)
159     (* lift is called listM, fmap, and <$> in Haskell *)
160     let join uu = uu >>= fun u -> u
161     (* u >>= f === join (lift f u) *)
162     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
163     (* [f] <*> [x1,x2] = [f x1,f x2] *)
164     (* let apply u v = u >>= fun f -> lift f v *)
165     (* let apply = lift2 id *)
166     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
167     (* let lift f u === apply (unit f) u *)
168     (* let lift2 f u v = apply (lift f u) v *)
169     let (>=>) f g = fun a -> f a >>= g
170     let do_when test u = if test then u else unit ()
171     let do_unless test u = if test then unit () else u
172     (* A Haskell-like version works:
173          let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk
174      * but the recursive call is not in tail position so this can stack overflow. *)
175     let forever uthunk =
176         let z = zero () in
177         let id result = result in
178         let kcell = ref id in
179         let rec loop _ =
180             let result = uthunk (kcell := id) >>= chained
181             in !kcell result
182         and chained _ =
183             kcell := loop; z (* we use z only for its polymorphism *)
184         in loop z
185     (* Reimplementations of the preceding using a hand-rolled State or StateT
186 can also stack overflow. *)
187     let sequence ms =
188       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
189         Util.fold_right op ms (unit [])
190     let sequence_ ms =
191       Util.fold_right (>>) ms (unit ())
192
193     (* Haskell defines these other operations combining lists and monads.
194      * We don't, but notice that M.mapM == ListT(M).distribute
195      * There's also a parallel TreeT(M).distribute *)
196     (*
197     let mapM f alist = sequence (Util.map f alist)
198     let mapM_ f alist = sequence_ (Util.map f alist)
199     let rec filterM f lst = match lst with
200       | [] -> unit []
201       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
202     let forM alist f = mapM f alist
203     let forM_ alist f = mapM_ f alist
204     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
205     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
206     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
207     let rec foldM f z lst = match lst with
208       | [] -> unit z
209       | x::xs -> f z x >>= fun z' -> foldM f z' xs
210     let foldM_ f z xs = foldM f z xs >> unit ()
211     let replicateM n x = sequence (Util.replicate n x)
212     let replicateM_ n x = sequence_ (Util.replicate n x)
213     *)
214     let guard test = if test then B.unit () else zero ()
215     let sum ms = Util.fold_right plus ms (zero ())
216   end
217
218   (* Signatures for MonadT *)
219   module type BASET = sig
220     module Wrapped : S
221     type ('x,'a) m
222     type ('x,'a) result
223     type ('x,'a) result_exn
224     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
225     val run : ('x,'a) m -> ('x,'a) result
226     val run_exn : ('x,'a) m -> ('x,'a) result_exn
227     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
228     (* lift/elevate laws:
229      *     elevate (W.unit a) == unit a
230      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
231      *)
232     val zero : unit -> ('x,'a) m
233     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
234   end
235   module MakeT(T : BASET) = struct
236     include Make(struct
237         include T
238         let unit a = elevate (Wrapped.unit a)
239     end)
240     let elevate = T.elevate
241   end
242
243 end
244
245
246
247
248
249 module Identity_monad : sig
250   (* expose only the implementation of type `'a result` *)
251   type ('x,'a) result = 'a
252   type ('x,'a) result_exn = 'a
253   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
254 end = struct
255   module Base = struct
256     type ('x,'a) m = 'a
257     type ('x,'a) result = 'a
258     type ('x,'a) result_exn = 'a
259     let unit a = a
260     let bind a f = f a
261     let run a = a
262     let run_exn a = a
263     let zero () = Util.undef
264     let plus u v = u
265   end
266   include Monad.Make(Base)
267 end
268
269
270 module Maybe_monad : sig
271   (* expose only the implementation of type `'a result` *)
272   type ('x,'a) result = 'a option
273   type ('x,'a) result_exn = 'a
274   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
275   (* MaybeT transformer *)
276   module T : functor (Wrapped : Monad.S) -> sig
277     type ('x,'a) result = ('x,'a option) Wrapped.result
278     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
279     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
280     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
281   end
282 end = struct
283   module Base = struct
284     type ('x,'a) m = 'a option
285     type ('x,'a) result = 'a option
286     type ('x,'a) result_exn = 'a
287     let unit a = Some a
288     let bind u f = match u with Some a -> f a | None -> None
289     let run u = u
290     let run_exn u = match u with
291       | Some a -> a
292       | None -> failwith "no value"
293     let zero () = None
294     (* satisfies Catch *)
295     let plus u v = match u with None -> v | _ -> u
296   end
297   include Monad.Make(Base)
298   module T(Wrapped : Monad.S) = struct
299     module BaseT = struct
300       include Monad.MakeT(struct
301         module Wrapped = Wrapped
302         type ('x,'a) m = ('x,'a option) Wrapped.m
303         type ('x,'a) result = ('x,'a option) Wrapped.result
304         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
305         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
306         let bind u f = Wrapped.bind u (fun t -> match t with
307           | Some a -> f a
308           | None -> Wrapped.unit None)
309         let run u = Wrapped.run u
310         let run_exn u =
311           let w = Wrapped.bind u (fun t -> match t with
312             | Some a -> Wrapped.unit a
313             | None -> Wrapped.zero ()
314           ) in Wrapped.run_exn w
315         let zero () = Wrapped.unit None
316         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
317       end)
318     end
319     include BaseT
320   end
321 end
322
323
324 module List_monad : sig
325   (* declare additional operation, while still hiding implementation of type m *)
326   type ('x,'a) result = 'a list
327   type ('x,'a) result_exn = 'a
328   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
329   val permute : ('x,'a) m -> ('x,('x,'a) m) m
330   val select : ('x,'a) m -> ('x,'a * ('x,'a) m) m
331   (* ListT transformer *)
332   module T : functor (Wrapped : Monad.S) -> sig
333     type ('x,'a) result = ('x,'a list) Wrapped.result
334     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
335     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
336     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
337     (* note that second argument is an 'a list, not the more abstract 'a m *)
338     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
339     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
340 (* TODO
341     val permute : 'a m -> 'a m m
342     val select : 'a m -> ('a * 'a m) m
343 *)
344   end
345 end = struct
346   module Base = struct
347    type ('x,'a) m = 'a list
348    type ('x,'a) result = 'a list
349    type ('x,'a) result_exn = 'a
350    let unit a = [a]
351    let bind u f = Util.concat_map f u
352    let run u = u
353    let run_exn u = match u with
354      | [] -> failwith "no values"
355      | [a] -> a
356      | many -> failwith "multiple values"
357    let zero () = []
358    (* satisfies Distrib *)
359    let plus = Util.append
360   end
361   include Monad.Make(Base)
362   (* let either u v = plus u v *)
363   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
364   let rec insert a u =
365     plus (unit (a :: u)) (match u with
366         | [] -> zero ()
367         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
368     )
369   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
370   let rec permute u = match u with
371       | [] -> unit []
372       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
373   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
374   let rec select u = match u with
375     | [] -> zero ()
376     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
377   module T(Wrapped : Monad.S) = struct
378     (* Wrapped.sequence ms  ===  
379          let plus1 u v =
380            Wrapped.bind u (fun x ->
381            Wrapped.bind v (fun xs ->
382            Wrapped.unit (x :: xs)))
383          in Util.fold_right plus1 ms (Wrapped.unit []) *)
384     (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
385     let distribute f alist = Wrapped.sequence (Util.map f alist)
386
387     include Monad.MakeT(struct
388       module Wrapped = Wrapped
389       type ('x,'a) m = ('x,'a list) Wrapped.m
390       type ('x,'a) result = ('x,'a list) Wrapped.result
391       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
392       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
393       let bind u f =
394         Wrapped.bind u (fun ts ->
395         Wrapped.bind (distribute f ts) (fun tts ->
396         Wrapped.unit (Util.concat tts)))
397       let run u = Wrapped.run u
398       let run_exn u =
399         let w = Wrapped.bind u (fun ts -> match ts with
400           | [] -> Wrapped.zero ()
401           | [a] -> Wrapped.unit a
402           | many -> Wrapped.zero ()
403         ) in Wrapped.run_exn w
404       let zero () = Wrapped.unit []
405       let plus u v =
406         Wrapped.bind u (fun us ->
407         Wrapped.bind v (fun vs ->
408         Wrapped.unit (Base.plus us vs)))
409     end)
410 (*
411     let permute : 'a m -> 'a m m
412     let select : 'a m -> ('a * 'a m) m
413 *)
414   end
415 end
416
417 (*
418 # LL.(run(plus (unit 1) (unit 2) >>= fun i -> plus (unit i) (unit(10*i)) ));;
419 - : ('_a, int) LL.result = [[1; 10; 2; 20]]
420 # LL.(run(plus (unit 1) (unit 2) >>= fun i -> elevate L.(plus (unit i) (unit(10*i)) )));;
421 - : ('_a, int) LL.result = [[1; 2]; [1; 20]; [10; 2]; [10; 20]]
422 *)
423
424
425 (* must be parameterized on (struct type err = ... end) *)
426 module Error_monad(Err : sig
427   type err
428   exception Exc of err
429   (*
430   val zero : unit -> err
431   val plus : err -> err -> err
432   *)
433 end) : sig
434   (* declare additional operations, while still hiding implementation of type m *)
435   type err = Err.err
436   type 'a error = Error of err | Success of 'a
437   type ('x,'a) result = 'a error
438   type ('x,'a) result_exn = 'a
439   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
440   val throw : err -> ('x,'a) m
441   val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
442   (* ErrorT transformer *)
443   module T : functor (Wrapped : Monad.S) -> sig
444     type ('x,'a) result = ('x,'a) Wrapped.result
445     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
446     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
447     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
448     val throw : err -> ('x,'a) m
449     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
450   end
451 end = struct
452   type err = Err.err
453   type 'a error = Error of err | Success of 'a
454   module Base = struct
455     type ('x,'a) m = 'a error
456     type ('x,'a) result = 'a error
457     type ('x,'a) result_exn = 'a
458     let unit a = Success a
459     let bind u f = match u with
460       | Success a -> f a
461       | Error e -> Error e (* input and output may be of different 'a types *)
462     let run u = u
463     let run_exn u = match u with
464       | Success a -> a
465       | Error e -> raise (Err.Exc e)
466     let zero () = Util.undef
467     (* satisfies Catch *)
468     let plus u v = match u with
469       | Success _ -> u
470       | Error _ -> if v == Util.undef then u else v
471   end
472   include Monad.Make(Base)
473   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
474   let throw e = Error e
475   let catch u handler = match u with
476     | Success _ -> u
477     | Error e -> handler e
478   module T(Wrapped : Monad.S) = struct
479     include Monad.MakeT(struct
480       module Wrapped = Wrapped
481       type ('x,'a) m = ('x,'a error) Wrapped.m
482       type ('x,'a) result = ('x,'a) Wrapped.result
483       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
484       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
485       let bind u f = Wrapped.bind u (fun t -> match t with
486         | Success a -> f a
487         | Error e -> Wrapped.unit (Error e))
488       let run u =
489         let w = Wrapped.bind u (fun t -> match t with
490           | Success a -> Wrapped.unit a
491           | Error e -> Wrapped.zero ()
492         ) in Wrapped.run w
493       let run_exn u =
494         let w = Wrapped.bind u (fun t -> match t with
495           | Success a -> Wrapped.unit a
496           | Error e -> raise (Err.Exc e))
497         in Wrapped.run_exn w
498       let plus u v = Wrapped.plus u v
499       let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *)
500     end)
501     let throw e = Wrapped.unit (Error e)
502     let catch u handler = Wrapped.bind u (fun t -> match t with
503       | Success _ -> Wrapped.unit t
504       | Error e -> handler e)
505   end
506 end
507
508 (* pre-define common instance of Error_monad *)
509 module Failure = Error_monad(struct
510   type err = string
511   exception Exc = Failure
512   (*
513   let zero = ""
514   let plus s1 s2 = s1 ^ "\n" ^ s2
515   *)
516 end)
517
518 (*
519 # EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
520 - : int EL.result = [Failure.Error "bye"; Failure.Success 30]
521 # LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
522 - : int LE.result = Failure.Error "bye"
523 # EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
524 Exception: Failure "bye".
525 # LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
526 Exception: Failure "bye".
527
528 # ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
529 - : int Failure.error * S.store = (Failure.Error "bye", 1)
530 # SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
531 - : (int * S.store) Failure.result = Failure.Error "bye"
532 # ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
533 Exception: Failure "bye".
534 # SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
535 Exception: Failure "bye".
536  *)
537
538
539 (* must be parameterized on (struct type env = ... end) *)
540 module Reader_monad(Env : sig type env end) : sig
541   (* declare additional operations, while still hiding implementation of type m *)
542   type env = Env.env
543   type ('x,'a) result = env -> 'a
544   type ('x,'a) result_exn = env -> 'a
545   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
546   val ask : ('x,env) m
547   val asks : (env -> 'a) -> ('x,'a) m
548   (* lookup i == `fun e -> e i` would assume env is a functional type *)
549   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
550   (* ReaderT transformer *)
551   module T : functor (Wrapped : Monad.S) -> sig
552     type ('x,'a) result = env -> ('x,'a) Wrapped.result
553     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
554     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
555     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
556     val ask : ('x,env) m
557     val asks : (env -> 'a) -> ('x,'a) m
558     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
559   end
560 end = struct
561   type env = Env.env
562   module Base = struct
563     type ('x,'a) m = env -> 'a
564     type ('x,'a) result = env -> 'a
565     type ('x,'a) result_exn = env -> 'a
566     let unit a = fun e -> a
567     let bind u f = fun e -> let a = u e in let u' = f a in u' e
568     let run u = fun e -> u e
569     let run_exn = run
570     let zero () = Util.undef
571     let plus u v = u
572   end
573   include Monad.Make(Base)
574   let ask = fun e -> e
575   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
576   let local modifier u = fun e -> u (modifier e)
577   module T(Wrapped : Monad.S) = struct
578     module BaseT = struct
579       module Wrapped = Wrapped
580       type ('x,'a) m = env -> ('x,'a) Wrapped.m
581       type ('x,'a) result = env -> ('x,'a) Wrapped.result
582       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
583       let elevate w = fun e -> w
584       let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e)
585       let run u = fun e -> Wrapped.run (u e)
586       let run_exn u = fun e -> Wrapped.run_exn (u e)
587       (* satisfies Distrib *)
588       let plus u v = fun e -> Wrapped.plus (u e) (v e)
589       let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
590     end
591     include Monad.MakeT(BaseT)
592     let ask = Wrapped.unit
593     let local modifier u = fun e -> u (modifier e)
594     let asks selector = ask >>= (fun e ->
595       try unit (selector e)
596       with Not_found -> fun e -> Wrapped.zero ())
597   end
598 end
599
600
601 (* must be parameterized on (struct type store = ... end) *)
602 module State_monad(Store : sig type store end) : sig
603   (* declare additional operations, while still hiding implementation of type m *)
604   type store = Store.store
605   type ('x,'a) result =  store -> 'a * store
606   type ('x,'a) result_exn = store -> 'a
607   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
608   val get : ('x,store) m
609   val gets : (store -> 'a) -> ('x,'a) m
610   val put : store -> ('x,unit) m
611   val puts : (store -> store) -> ('x,unit) m
612   (* StateT transformer *)
613   module T : functor (Wrapped : Monad.S) -> sig
614     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
615     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
616     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
617     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
618     val get : ('x,store) m
619     val gets : (store -> 'a) -> ('x,'a) m
620     val put : store -> ('x,unit) m
621     val puts : (store -> store) -> ('x,unit) m
622   end
623 end = struct
624   type store = Store.store
625   module Base = struct
626     type ('x,'a) m =  store -> 'a * store
627     type ('x,'a) result =  store -> 'a * store
628     type ('x,'a) result_exn = store -> 'a
629     let unit a = fun s -> (a, s)
630     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
631     let run u = fun s -> (u s)
632     let run_exn u = fun s -> fst (u s)
633     let zero () = Util.undef
634     let plus u v = u
635   end
636   include Monad.Make(Base)
637   let get = fun s -> (s, s)
638   let gets viewer = fun s -> (viewer s, s) (* may fail *)
639   let put s = fun _ -> ((), s)
640   let puts modifier = fun s -> ((), modifier s)
641   module T(Wrapped : Monad.S) = struct
642     module BaseT = struct
643       module Wrapped = Wrapped
644       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
645       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
646       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
647       let elevate w = fun s ->
648         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
649       let bind u f = fun s ->
650         Wrapped.bind (u s) (fun (a, s') -> f a s')
651       let run u = fun s -> Wrapped.run (u s)
652       let run_exn u = fun s ->
653         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
654         in Wrapped.run_exn w
655       (* satisfies Distrib *)
656       let plus u v = fun s -> Wrapped.plus (u s) (v s)
657       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
658     end
659     include Monad.MakeT(BaseT)
660     let get = fun s -> Wrapped.unit (s, s)
661     let gets viewer = fun s ->
662       try Wrapped.unit (viewer s, s)
663       with Not_found -> Wrapped.zero ()
664     let put s = fun _ -> Wrapped.unit ((), s)
665     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
666   end
667 end
668
669 (* State monad with different interface (structured store) *)
670 module Ref_monad(V : sig
671   type value
672 end) : sig
673   type ref
674   type value = V.value
675   type ('x,'a) result = 'a
676   type ('x,'a) result_exn = 'a
677   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
678   val newref : value -> ('x,ref) m
679   val deref : ref -> ('x,value) m
680   val change : ref -> value -> ('x,unit) m
681   (* RefT transformer *)
682   module T : functor (Wrapped : Monad.S) -> sig
683     type ('x,'a) result = ('x,'a) Wrapped.result
684     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
685     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
686     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
687     val newref : value -> ('x,ref) m
688     val deref : ref -> ('x,value) m
689     val change : ref -> value -> ('x,unit) m
690   end
691 end = struct
692   type ref = int
693   type value = V.value
694   module D = Map.Make(struct type t = ref let compare = compare end)
695   type dict = { next: ref; tree : value D.t }
696   let empty = { next = 0; tree = D.empty }
697   let alloc (value : value) (d : dict) =
698     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
699   let read (key : ref) (d : dict) =
700     D.find key d.tree
701   let write (key : ref) (value : value) (d : dict) =
702     { next = d.next; tree = D.add key value d.tree }
703   module Base = struct
704     type ('x,'a) m = dict -> 'a * dict
705     type ('x,'a) result = 'a
706     type ('x,'a) result_exn = 'a
707     let unit a = fun s -> (a, s)
708     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
709     let run u = fst (u empty)
710     let run_exn = run
711     let zero () = Util.undef
712     let plus u v = u
713   end
714   include Monad.Make(Base)
715   let newref value = fun s -> alloc value s
716   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
717   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
718   module T(Wrapped : Monad.S) = struct
719     module BaseT = struct
720       module Wrapped = Wrapped
721       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
722       type ('x,'a) result = ('x,'a) Wrapped.result
723       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
724       let elevate w = fun s ->
725         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
726       let bind u f = fun s ->
727         Wrapped.bind (u s) (fun (a, s') -> f a s')
728       let run u =
729         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
730         in Wrapped.run w
731       let run_exn u =
732         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
733         in Wrapped.run_exn w
734       (* satisfies Distrib *)
735       let plus u v = fun s -> Wrapped.plus (u s) (v s)
736       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
737     end
738     include Monad.MakeT(BaseT)
739     let newref value = fun s -> Wrapped.unit (alloc value s)
740     let deref key = fun s -> Wrapped.unit (read key s, s)
741     let change key value = fun s -> Wrapped.unit ((), write key value s)
742   end
743 end
744
745
746 (* must be parameterized on (struct type log = ... end) *)
747 module Writer_monad(Log : sig
748   type log
749   val zero : log
750   val plus : log -> log -> log
751 end) : sig
752   (* declare additional operations, while still hiding implementation of type m *)
753   type log = Log.log
754   type ('x,'a) result = 'a * log
755   type ('x,'a) result_exn = 'a * log
756   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
757   val tell : log -> ('x,unit) m
758   val listen : ('x,'a) m -> ('x,'a * log) m
759   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
760   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
761   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
762   (* WriterT transformer *)
763   module T : functor (Wrapped : Monad.S) -> sig
764     type ('x,'a) result = ('x,'a * log) Wrapped.result
765     type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
766     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
767     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
768     val tell : log -> ('x,unit) m
769     val listen : ('x,'a) m -> ('x,'a * log) m
770     val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
771     val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
772   end
773 end = struct
774   type log = Log.log
775   module Base = struct
776     type ('x,'a) m = 'a * log
777     type ('x,'a) result = 'a * log
778     type ('x,'a) result_exn = 'a * log
779     let unit a = (a, Log.zero)
780     let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w')
781     let run u = u
782     let run_exn = run
783     let zero () = Util.undef
784     let plus u v = u
785   end
786   include Monad.Make(Base)
787   let tell entries = ((), entries) (* add entries to log *)
788   let listen (a, w) = ((a, w), w)
789   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
790   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
791   let censor f u = pass (u >>= fun a -> unit (a, f))
792   module T(Wrapped : Monad.S) = struct
793     module BaseT = struct
794       module Wrapped = Wrapped
795       type ('x,'a) m = ('x,'a * log) Wrapped.m
796       type ('x,'a) result = ('x,'a * log) Wrapped.result
797       type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
798       let elevate w =
799         Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero))
800       let bind u f =
801         Wrapped.bind u (fun (a, w) ->
802         Wrapped.bind (f a) (fun (b, w') ->
803         Wrapped.unit (b, Log.plus w w')))
804       let zero () = elevate (Wrapped.zero ())
805       let plus u v = Wrapped.plus u v
806       let run u = Wrapped.run u
807       let run_exn u = Wrapped.run_exn u
808     end
809     include Monad.MakeT(BaseT)
810     let tell entries = Wrapped.unit ((), entries)
811     let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w))
812     let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w))
813     (* rest are derived in same way as before *)
814     let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w)
815     let censor f u = pass (u >>= fun a -> unit (a, f))
816   end
817 end
818
819 (* pre-define simple Writer *)
820 module Writer1 = Writer_monad(struct
821   type log = string
822   let zero = ""
823   let plus s1 s2 = s1 ^ "\n" ^ s2
824 end)
825
826 (* slightly more efficient Writer *)
827 module Writer2 = struct
828   include Writer_monad(struct
829     type log = string list
830     let zero = []
831     let plus w w' = Util.append w' w
832   end)
833   let tell_string s = tell [s]
834   let tell entries = tell (Util.reverse entries)
835   let run u = let (a, w) = run u in (a, Util.reverse w)
836   let run_exn = run
837 end
838
839
840 (* TODO needs a T *)
841 module IO_monad : sig
842   (* declare additional operation, while still hiding implementation of type m *)
843   type ('x,'a) result = 'a
844   type ('x,'a) result_exn = 'a
845   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
846   val printf : ('a, unit, string, ('x,unit) m) format4 -> 'a
847   val print_string : string -> ('x,unit) m
848   val print_int : int -> ('x,unit) m
849   val print_hex : int -> ('x,unit) m
850   val print_bool : bool -> ('x,unit) m
851 end = struct
852   module Base = struct
853     type ('x,'a) m = { run : unit -> unit; value : 'a }
854     type ('x,'a) result = 'a
855     type ('x,'a) result_exn = 'a
856     let unit a = { run = (fun () -> ()); value = a }
857     let bind (a : ('x,'a) m) (f: 'a -> ('x,'b) m) : ('x,'b) m =
858      let fres = f a.value in
859        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
860     let run a = let () = a.run () in a.value
861     let run_exn = run
862     let zero () = Util.undef
863     let plus u v = u
864   end
865   include Monad.Make(Base)
866   let printf fmt =
867     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
868   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
869   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
870   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
871   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
872 end
873
874
875 module Continuation_monad : sig
876   (* expose only the implementation of type `('r,'a) result` *)
877   type ('r,'a) m
878   type ('r,'a) result = ('r,'a) m
879   type ('r,'a) result_exn = ('a -> 'r) -> 'r
880   include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
881   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
882   val reset : ('a,'a) m -> ('r,'a) m
883   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
884   (* val abort : ('a,'a) m -> ('a,'b) m *)
885   val abort : 'a -> ('a,'b) m
886   val run0 : ('a,'a) m -> 'a
887   (* ContinuationT transformer *)
888   module T : functor (Wrapped : Monad.S) -> sig
889     type ('r,'a) m
890     type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
891     type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
892     include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
893     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
894     val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
895     (* TODO: reset,shift,abort,run0 *)
896   end
897 end = struct
898   let id = fun i -> i
899   module Base = struct
900     (* 'r is result type of whole computation *)
901     type ('r,'a) m = ('a -> 'r) -> 'r
902     type ('r,'a) result = ('a -> 'r) -> 'r
903     type ('r,'a) result_exn = ('r,'a) result
904     let unit a = (fun k -> k a)
905     let bind u f = (fun k -> (u) (fun a -> (f a) k))
906     let run u k = (u) k
907     let run_exn = run
908     let zero () = Util.undef
909     let plus u v = u
910   end
911   include Monad.Make(Base)
912   let callcc f = (fun k ->
913     let usek a = (fun _ -> k a)
914     in (f usek) k)
915   (*
916   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
917   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
918   let callcc f = fun k -> f k k
919   let throw k a = fun _ -> k a
920   *)
921
922   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
923    *
924    *  reset :: (Monad m) => ContT a m a -> ContT r m a
925    *  reset e = ContT $ \k -> runContT e return >>= k
926    *
927    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
928    *  shift e = ContT $ \k ->
929    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
930   let reset u = unit ((u) id)
931   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
932   (* let abort a = shift (fun _ -> a) *)
933   let abort a = shift (fun _ -> unit a)
934   let run0 (u : ('a,'a) m) = (u) id
935   module T(Wrapped : Monad.S) = struct
936     module BaseT = struct
937       module Wrapped = Wrapped
938       type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m
939       type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
940       type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
941       let elevate w = fun k -> Wrapped.bind w k
942       let bind u f = fun k -> u (fun a -> f a k)
943       let run u k = Wrapped.run (u k)
944       let run_exn u k = Wrapped.run_exn (u k)
945       let zero () = Util.undef
946       let plus u v = u
947     end
948     include Monad.MakeT(BaseT)
949     let callcc f = (fun k ->
950       let usek a = (fun _ -> k a)
951       in (f usek) k)
952   end
953 end
954
955
956 (*
957  * Scheme:
958  * (define (example n)
959  *    (let ([u (let/cc k ; type int -> int pair
960  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
961  *                 (+ 1 (car v))))]) ; int
962  *      (cons u 0))) ; int pair
963  * ; (example 10) ~~> '(111 . 0)
964  * ; (example -10) ~~> '(0 . 0)
965  *
966  * OCaml monads:
967  * let example n : (int * int) =
968  *   Continuation_monad.(let u = callcc (fun k ->
969  *       (if n < 0 then k 0 else unit [n + 100])
970  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
971  *       >>= fun [x] -> unit (x + 1)
972  *   )
973  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
974  *   >>= fun x -> unit (x, 0)
975  *   in run u)
976  *
977  *
978  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
979  * let example1 () : int =
980  *   Continuation_monad.(let v = reset (
981  *       let u = shift (fun k -> unit (10 + 1))
982  *       in u >>= fun x -> unit (100 + x)
983  *     ) in let w = v >>= fun x -> unit (1000 + x)
984  *     in run w)
985  *
986  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
987  * let example2 () =
988  *   Continuation_monad.(let v = reset (
989  *       let u = shift (fun k -> k (10 :: [1]))
990  *       in u >>= fun x -> unit (100 :: x)
991  *     ) in let w = v >>= fun x -> unit (1000 :: x)
992  *     in run w)
993  *
994  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
995  * let example3 () =
996  *   Continuation_monad.(let v = reset (
997  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
998  *       in u >>= fun x -> unit (100 :: x)
999  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1000  *     in run w)
1001  *
1002  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1003  * (* not sure if this example can be typed without a sum-type *)
1004  *
1005  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1006  * let example5 () : int =
1007  *   Continuation_monad.(let v = reset (
1008  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
1009  *       in u >>= fun x -> unit (10 + x)
1010  *     ) in let w = v >>= fun x -> unit (100 + x)
1011  *     in run w)
1012  *
1013  *)
1014
1015
1016 module Leaf_monad : sig
1017   (* We implement the type as `'a tree option` because it has a natural`plus`,
1018    * and the rest of the library expects that `plus` and `zero` will come together. *)
1019   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1020   type ('x,'a) result = 'a tree option
1021   type ('x,'a) result_exn = 'a tree
1022   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1023   (* LeafT transformer *)
1024   module T : functor (Wrapped : Monad.S) -> sig
1025     type ('x,'a) result = ('x,'a tree option) Wrapped.result
1026     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1027     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1028     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1029     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1030     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1031     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
1032   end
1033 end = struct
1034   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1035   (* uses supplied plus and zero to copy t to its image under f *)
1036   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1037       | None -> zero ()
1038       | Some ts -> let rec loop ts = (match ts with
1039                      | Leaf a -> f a
1040                      | Node (l, r) ->
1041                          (* recursive application of f may delete a branch *)
1042                          plus (loop l) (loop r)
1043                    ) in loop ts
1044   module Base = struct
1045     type ('x,'a) m = 'a tree option
1046     type ('x,'a) result = 'a tree option
1047     type ('x,'a) result_exn = 'a tree
1048     let unit a = Some (Leaf a)
1049     let zero () = None
1050     (* satisfies Distrib *)
1051     let plus u v = match (u, v) with
1052       | None, _ -> v
1053       | _, None -> u
1054       | Some us, Some vs -> Some (Node (us, vs))
1055     let bind u f = mapT f u zero plus
1056     let run u = u
1057     let run_exn u = match u with
1058       | None -> failwith "no values"
1059       (*
1060       | Some (Leaf a) -> a
1061       | many -> failwith "multiple values"
1062       *)
1063       | Some us -> us
1064   end
1065   include Monad.Make(Base)
1066   module T(Wrapped : Monad.S) = struct
1067     module BaseT = struct
1068       include Monad.MakeT(struct
1069         module Wrapped = Wrapped
1070         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1071         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1072         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1073         let zero () = Wrapped.unit None
1074         let plus u v =
1075           Wrapped.bind u (fun us ->
1076           Wrapped.bind v (fun vs ->
1077           Wrapped.unit (Base.plus us vs)))
1078         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1079         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1080         let run u = Wrapped.run u
1081         let run_exn u =
1082             let w = Wrapped.bind u (fun t -> match t with
1083               | None -> Wrapped.zero ()
1084               | Some ts -> Wrapped.unit ts
1085             ) in Wrapped.run_exn w
1086       end)
1087     end
1088     include BaseT
1089     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1090   end
1091 end
1092
1093
1094 module L = List_monad;;
1095 module R = Reader_monad(struct type env = int -> int end);;
1096 module S = State_monad(struct type store = int end);;
1097 module T = Leaf_monad;;
1098 module LR = L.T(R);;
1099 module LS = L.T(S);;
1100 module TL = T.T(L);;
1101 module TR = T.T(R);;
1102 module TS = T.T(S);;
1103 module C = Continuation_monad
1104 module TC = T.T(C);;
1105
1106
1107 print_endline "=== test Leaf(...).distribute ==================";;
1108
1109 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1110
1111 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1112 TS.run ts 0;;
1113 (*
1114 - : int T.tree option * S.store =
1115 (Some
1116   (T.Node
1117     (T.Node (T.Leaf 2, T.Leaf 3),
1118      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1119  5)
1120 *)
1121
1122 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1123 TS.run_exn ts2 0;;
1124 (*
1125 - : (int * S.store) T.tree option * S.store =
1126 (Some
1127   (T.Node
1128     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1129      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1130  5)
1131 *)
1132
1133 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1134 TR.run_exn tr (fun i -> i+i);;
1135 (*
1136 - : int T.tree option =
1137 Some
1138  (T.Node
1139    (T.Node (T.Leaf 4, T.Leaf 6),
1140     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1141 *)
1142
1143 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1144 TL.run_exn tl;;
1145 (*
1146 - : (int * int) TL.result =
1147 [Some
1148   (T.Node
1149     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1150      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1151 *)
1152
1153 let l2 = [1;2;3;4;5];;
1154 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1155
1156 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1157 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1158
1159 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1160 (*
1161 int T.tree option =
1162 Some
1163  (T.Node
1164    (T.Node (T.Leaf 10, T.Leaf 11),
1165     T.Node
1166      (T.Node
1167        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1168         T.Node (T.Leaf 40, T.Leaf 41)),
1169       T.Node (T.Leaf 50, T.Leaf 51))))
1170  *)
1171
1172 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1173 (*
1174 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1175 *)
1176
1177 print_endline "=== test Leaf(Continuation).distribute ==================";;
1178
1179 let id : 'z. 'z -> 'z = fun x -> x
1180
1181 let example n : (int * int) =
1182   Continuation_monad.(let u = callcc (fun k ->
1183       (if n < 0 then k 0 else unit [n + 100])
1184       (* all of the following is skipped by k 0; the end type int is k's input type *)
1185       >>= fun [x] -> unit (x + 1)
1186   )
1187   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1188   >>= fun x -> unit (x, 0)
1189   in run0 u)
1190
1191
1192 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1193 let example1 () : int =
1194   Continuation_monad.(let v = reset (
1195       let u = shift (fun k -> unit (10 + 1))
1196       in u >>= fun x -> unit (100 + x)
1197     ) in let w = v >>= fun x -> unit (1000 + x)
1198     in run0 w)
1199
1200 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1201 let example2 () =
1202   Continuation_monad.(let v = reset (
1203       let u = shift (fun k -> k (10 :: [1]))
1204       in u >>= fun x -> unit (100 :: x)
1205     ) in let w = v >>= fun x -> unit (1000 :: x)
1206     in run0 w)
1207
1208 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1209 let example3 () =
1210   Continuation_monad.(let v = reset (
1211       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1212       in u >>= fun x -> unit (100 :: x)
1213     ) in let w = v >>= fun x -> unit (1000 :: x)
1214     in run0 w)
1215
1216 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1217 (* not sure if this example can be typed without a sum-type *)
1218
1219 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1220 let example5 () : int =
1221   Continuation_monad.(let v = reset (
1222       let u = shift (fun k -> k 1 >>= k)
1223       in u >>= fun x -> unit (10 + x)
1224     ) in let w = v >>= fun x -> unit (100 + x)
1225     in run0 w)
1226
1227 ;;
1228
1229 print_endline "=== test bare Continuation ============";;
1230
1231 (1011, 1111, 1111, 121);;
1232 (example1(), example2(), example3(), example5());;
1233 ((111,0), (0,0));;
1234 (example ~+10, example ~-10);;
1235
1236 let testc df ic =
1237     C.run_exn TC.(run (distribute df t1)) ic;;
1238
1239
1240 (*
1241 (* do nothing *)
1242 let initial_continuation = fun t -> t in
1243 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1244 *)
1245 testc (C.unit) id;;
1246
1247 (*
1248 (* count leaves, using continuation *)
1249 let initial_continuation = fun t -> 0 in
1250 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1251 *)
1252
1253 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1254
1255 (*
1256 (* convert tree to list of leaves *)
1257 let initial_continuation = fun t -> [] in
1258 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1259 *)
1260
1261 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1262
1263 (*
1264 (* square each leaf using continuation *)
1265 let initial_continuation = fun t -> t in
1266 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1267 *)
1268
1269 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1270
1271
1272 (*
1273 (* replace leaves with list, using continuation *)
1274 let initial_continuation = fun t -> t in
1275 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1276 *)
1277
1278 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1279
1280 print_endline "=== pa_monad's Continuation Tests ============";;
1281
1282 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1283 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1284 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1285 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1286 (5, 27 = C.(run0 (
1287               let c = reset(abort 5 >>= fun y -> unit (y+6))
1288               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1289
1290 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1291
1292 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1293
1294 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1295
1296
1297 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1298