update tree_monadize.ml
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  * Acknowledgements: This is largely based on the mtl library distributed
46  * with the Glasgow Haskell Compiler. I've also been helped in
47  * various ways by posts and direct feedback from Oleg Kiselyov and
48  * Chung-chieh Shan. The following were also useful:
49  * - <http://pauillac.inria.fr/~xleroy/mpri/progfunc/>
50  * - Ken Shan "Monads for natural language semantics" <http://arxiv.org/abs/cs/0205026v1>
51  * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf
52  * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers
53  *
54  * Licensing: MIT (if that's compatible with the ghc sources this is partly
55  * derived from)
56  *)
57
58 exception Undefined
59
60 (* Some library functions used below. *)
61 module Util = struct
62   let fold_right = List.fold_right
63   let map = List.map
64   let append = List.append
65   let reverse = List.rev
66   let concat = List.concat
67   let concat_map f lst = List.concat (List.map f lst)
68   (* let zip = List.combine *)
69   let unzip = List.split
70   let zip_with = List.map2
71   let replicate len fill =
72     let rec loop n accu =
73       if n == 0 then accu else loop (pred n) (fill :: accu)
74     in loop len []
75   (* Dirty hack to be a default polymorphic zero.
76    * To implement this cleanly, monads without a natural zero
77    * should always wrap themselves in an option layer (see Tree_monad). *)
78   let undef = Obj.magic (fun () -> raise Undefined)
79 end
80
81
82
83 (*
84  * This module contains factories that extend a base set of
85  * monadic definitions with a larger family of standard derived values.
86  *)
87
88 module Monad = struct
89   (*
90    * Signature extenders:
91    *   Make :: BASE -> S
92    *   MakeT :: BASET (with Wrapped : S) -> result sig not declared
93    *)
94
95
96   (* type of base definitions *)
97   module type BASE = sig
98     (* We make all monadic types doubly-parameterized so that they
99      * can layer nicely with Continuation, which needs the second
100      * type parameter. *)
101     type ('x,'a) m
102     type ('x,'a) result
103     type ('x,'a) result_exn
104     val unit : 'a -> ('x,'a) m
105     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
106     val run : ('x,'a) m -> ('x,'a) result
107     (* run_exn tries to provide a more ground-level result, but may fail *)
108     val run_exn : ('x,'a) m -> ('x,'a) result_exn
109     (* To simplify the library, we require every monad to supply a plus and zero. These obey the following laws:
110      *     zero >>= f   ===  zero
111      *     plus zero u  ===  u
112      *     plus u zero  ===  u
113      * Additionally, they will obey one of the following laws:
114      *     (Catch)   plus (unit a) v  ===  unit a
115      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
116      * When no natural zero is available, use `let zero () = Util.undef`.
117      * The Make functor automatically detects for zero >>= ..., and 
118      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
119      *)
120     val zero : unit -> ('x,'a) m
121     (* zero has to be thunked to ensure results are always poly enough *)
122     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
123   end
124   module type S = sig
125     include BASE
126     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
127     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
128     val join : ('x,('x,'a) m) m -> ('x,'a) m
129     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
130     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
131     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
132     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
133     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
134     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
135     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
136     val sequence : ('x,'a) m list -> ('x,'a list) m
137     val sequence_ : ('x,'a) m list -> ('x,unit) m
138     val guard : bool -> ('x,unit) m
139     val sum : ('x,'a) m list -> ('x,'a) m
140   end
141
142   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
143     include B
144     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
145       if u == Util.undef then Util.undef
146       else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
147     let plus u v =
148       if u == Util.undef then v else if v == Util.undef then u else B.plus u v
149     let run u =
150       if u == Util.undef then raise Undefined else B.run u
151     let run_exn u =
152       if u == Util.undef then raise Undefined else B.run_exn u
153     let (>>=) = bind
154     (* expressions after >> will be evaluated before they're passed to
155      * bind, so you can't do `zero () >> assert false`
156      * this works though: `zero () >>= fun _ -> assert false`
157      *)
158     let (>>) u v = u >>= fun _ -> v
159     let lift f u = u >>= fun a -> unit (f a)
160     (* lift is called listM, fmap, and <$> in Haskell *)
161     let join uu = uu >>= fun u -> u
162     (* u >>= f === join (lift f u) *)
163     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
164     (* [f] <*> [x1,x2] = [f x1,f x2] *)
165     (* let apply u v = u >>= fun f -> lift f v *)
166     (* let apply = lift2 id *)
167     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
168     (* let lift f u === apply (unit f) u *)
169     (* let lift2 f u v = apply (lift f u) v *)
170     let (>=>) f g = fun a -> f a >>= g
171     let do_when test u = if test then u else unit ()
172     let do_unless test u = if test then unit () else u
173     (* A Haskell-like version works:
174          let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk
175      * but the recursive call is not in tail position so this can stack overflow. *)
176     let forever uthunk =
177         let z = zero () in
178         let id result = result in
179         let kcell = ref id in
180         let rec loop _ =
181             let result = uthunk (kcell := id) >>= chained
182             in !kcell result
183         and chained _ =
184             kcell := loop; z (* we use z only for its polymorphism *)
185         in loop z
186     (* Reimplementations of the preceding using a hand-rolled State or StateT
187 can also stack overflow. *)
188     let sequence ms =
189       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
190         Util.fold_right op ms (unit [])
191     let sequence_ ms =
192       Util.fold_right (>>) ms (unit ())
193
194     (* Haskell defines these other operations combining lists and monads.
195      * We don't, but notice that M.mapM == ListT(M).distribute
196      * There's also a parallel TreeT(M).distribute *)
197     (*
198     let mapM f alist = sequence (Util.map f alist)
199     let mapM_ f alist = sequence_ (Util.map f alist)
200     let rec filterM f lst = match lst with
201       | [] -> unit []
202       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
203     let forM alist f = mapM f alist
204     let forM_ alist f = mapM_ f alist
205     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
206     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
207     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
208     let rec foldM f z lst = match lst with
209       | [] -> unit z
210       | x::xs -> f z x >>= fun z' -> foldM f z' xs
211     let foldM_ f z xs = foldM f z xs >> unit ()
212     let replicateM n x = sequence (Util.replicate n x)
213     let replicateM_ n x = sequence_ (Util.replicate n x)
214     *)
215     let guard test = if test then B.unit () else zero ()
216     let sum ms = Util.fold_right plus ms (zero ())
217   end
218
219   (* Signatures for MonadT *)
220   module type BASET = sig
221     module Wrapped : S
222     type ('x,'a) m
223     type ('x,'a) result
224     type ('x,'a) result_exn
225     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
226     val run : ('x,'a) m -> ('x,'a) result
227     val run_exn : ('x,'a) m -> ('x,'a) result_exn
228     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
229     (* lift/elevate laws:
230      *     elevate (W.unit a) == unit a
231      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
232      *)
233     val zero : unit -> ('x,'a) m
234     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
235   end
236   module MakeT(T : BASET) = struct
237     include Make(struct
238         include T
239         let unit a = elevate (Wrapped.unit a)
240     end)
241     let elevate = T.elevate
242   end
243
244 end
245
246
247
248
249
250 module Identity_monad : sig
251   (* expose only the implementation of type `'a result` *)
252   type ('x,'a) result = 'a
253   type ('x,'a) result_exn = 'a
254   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
255 end = struct
256   module Base = struct
257     type ('x,'a) m = 'a
258     type ('x,'a) result = 'a
259     type ('x,'a) result_exn = 'a
260     let unit a = a
261     let bind a f = f a
262     let run a = a
263     let run_exn a = a
264     let zero () = Util.undef
265     let plus u v = u
266   end
267   include Monad.Make(Base)
268 end
269
270
271 module Maybe_monad : sig
272   (* expose only the implementation of type `'a result` *)
273   type ('x,'a) result = 'a option
274   type ('x,'a) result_exn = 'a
275   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
276   (* MaybeT transformer *)
277   module T : functor (Wrapped : Monad.S) -> sig
278     type ('x,'a) result = ('x,'a option) Wrapped.result
279     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
280     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
281     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
282   end
283 end = struct
284   module Base = struct
285     type ('x,'a) m = 'a option
286     type ('x,'a) result = 'a option
287     type ('x,'a) result_exn = 'a
288     let unit a = Some a
289     let bind u f = match u with Some a -> f a | None -> None
290     let run u = u
291     let run_exn u = match u with
292       | Some a -> a
293       | None -> failwith "no value"
294     let zero () = None
295     (* satisfies Catch *)
296     let plus u v = match u with None -> v | _ -> u
297   end
298   include Monad.Make(Base)
299   module T(Wrapped : Monad.S) = struct
300     module BaseT = struct
301       include Monad.MakeT(struct
302         module Wrapped = Wrapped
303         type ('x,'a) m = ('x,'a option) Wrapped.m
304         type ('x,'a) result = ('x,'a option) Wrapped.result
305         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
306         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
307         let bind u f = Wrapped.bind u (fun t -> match t with
308           | Some a -> f a
309           | None -> Wrapped.unit None)
310         let run u = Wrapped.run u
311         let run_exn u =
312           let w = Wrapped.bind u (fun t -> match t with
313             | Some a -> Wrapped.unit a
314             | None -> Wrapped.zero ()
315           ) in Wrapped.run_exn w
316         let zero () = Wrapped.unit None
317         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
318       end)
319     end
320     include BaseT
321   end
322 end
323
324
325 module List_monad : sig
326   (* declare additional operation, while still hiding implementation of type m *)
327   type ('x,'a) result = 'a list
328   type ('x,'a) result_exn = 'a
329   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
330   val permute : ('x,'a) m -> ('x,('x,'a) m) m
331   val select : ('x,'a) m -> ('x,'a * ('x,'a) m) m
332   (* ListT transformer *)
333   module T : functor (Wrapped : Monad.S) -> sig
334     type ('x,'a) result = ('x,'a list) Wrapped.result
335     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
336     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
337     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
338     (* note that second argument is an 'a list, not the more abstract 'a m *)
339     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
340     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
341 (* TODO
342     val permute : 'a m -> 'a m m
343     val select : 'a m -> ('a * 'a m) m
344 *)
345   end
346 end = struct
347   module Base = struct
348    type ('x,'a) m = 'a list
349    type ('x,'a) result = 'a list
350    type ('x,'a) result_exn = 'a
351    let unit a = [a]
352    let bind u f = Util.concat_map f u
353    let run u = u
354    let run_exn u = match u with
355      | [] -> failwith "no values"
356      | [a] -> a
357      | many -> failwith "multiple values"
358    let zero () = []
359    (* satisfies Distrib *)
360    let plus = Util.append
361   end
362   include Monad.Make(Base)
363   (* let either u v = plus u v *)
364   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
365   let rec insert a u =
366     plus (unit (a :: u)) (match u with
367         | [] -> zero ()
368         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
369     )
370   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
371   let rec permute u = match u with
372       | [] -> unit []
373       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
374   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
375   let rec select u = match u with
376     | [] -> zero ()
377     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
378   module T(Wrapped : Monad.S) = struct
379     (* Wrapped.sequence ms  ===  
380          let plus1 u v =
381            Wrapped.bind u (fun x ->
382            Wrapped.bind v (fun xs ->
383            Wrapped.unit (x :: xs)))
384          in Util.fold_right plus1 ms (Wrapped.unit []) *)
385     (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
386     let distribute f alist = Wrapped.sequence (Util.map f alist)
387
388     include Monad.MakeT(struct
389       module Wrapped = Wrapped
390       type ('x,'a) m = ('x,'a list) Wrapped.m
391       type ('x,'a) result = ('x,'a list) Wrapped.result
392       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
393       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
394       let bind u f =
395         Wrapped.bind u (fun ts ->
396         Wrapped.bind (distribute f ts) (fun tts ->
397         Wrapped.unit (Util.concat tts)))
398       let run u = Wrapped.run u
399       let run_exn u =
400         let w = Wrapped.bind u (fun ts -> match ts with
401           | [] -> Wrapped.zero ()
402           | [a] -> Wrapped.unit a
403           | many -> Wrapped.zero ()
404         ) in Wrapped.run_exn w
405       let zero () = Wrapped.unit []
406       let plus u v =
407         Wrapped.bind u (fun us ->
408         Wrapped.bind v (fun vs ->
409         Wrapped.unit (Base.plus us vs)))
410     end)
411 (*
412     let permute : 'a m -> 'a m m
413     let select : 'a m -> ('a * 'a m) m
414 *)
415   end
416 end
417
418
419 (* must be parameterized on (struct type err = ... end) *)
420 module Error_monad(Err : sig
421   type err
422   exception Exc of err
423   (*
424   val zero : unit -> err
425   val plus : err -> err -> err
426   *)
427 end) : sig
428   (* declare additional operations, while still hiding implementation of type m *)
429   type err = Err.err
430   type 'a error = Error of err | Success of 'a
431   type ('x,'a) result = 'a error
432   type ('x,'a) result_exn = 'a
433   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
434   val throw : err -> ('x,'a) m
435   val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
436   (* ErrorT transformer *)
437   module T : functor (Wrapped : Monad.S) -> sig
438     type ('x,'a) result = ('x,'a) Wrapped.result
439     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
440     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
441     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
442     val throw : err -> ('x,'a) m
443     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
444   end
445 end = struct
446   type err = Err.err
447   type 'a error = Error of err | Success of 'a
448   module Base = struct
449     type ('x,'a) m = 'a error
450     type ('x,'a) result = 'a error
451     type ('x,'a) result_exn = 'a
452     let unit a = Success a
453     let bind u f = match u with
454       | Success a -> f a
455       | Error e -> Error e (* input and output may be of different 'a types *)
456     let run u = u
457     let run_exn u = match u with
458       | Success a -> a
459       | Error e -> raise (Err.Exc e)
460     let zero () = Util.undef
461     (* satisfies Catch *)
462     let plus u v = match u with
463       | Success _ -> u
464       | Error _ -> if v == Util.undef then u else v
465   end
466   include Monad.Make(Base)
467   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
468   let throw e = Error e
469   let catch u handler = match u with
470     | Success _ -> u
471     | Error e -> handler e
472   module T(Wrapped : Monad.S) = struct
473     include Monad.MakeT(struct
474       module Wrapped = Wrapped
475       type ('x,'a) m = ('x,'a error) Wrapped.m
476       type ('x,'a) result = ('x,'a) Wrapped.result
477       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
478       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
479       let bind u f = Wrapped.bind u (fun t -> match t with
480         | Success a -> f a
481         | Error e -> Wrapped.unit (Error e))
482       let run u =
483         let w = Wrapped.bind u (fun t -> match t with
484           | Success a -> Wrapped.unit a
485           | Error e -> Wrapped.zero ()
486         ) in Wrapped.run w
487       let run_exn u =
488         let w = Wrapped.bind u (fun t -> match t with
489           | Success a -> Wrapped.unit a
490           | Error e -> raise (Err.Exc e))
491         in Wrapped.run_exn w
492       let plus u v = Wrapped.plus u v
493       let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *)
494     end)
495     let throw e = Wrapped.unit (Error e)
496     let catch u handler = Wrapped.bind u (fun t -> match t with
497       | Success _ -> Wrapped.unit t
498       | Error e -> handler e)
499   end
500 end
501
502 (* pre-define common instance of Error_monad *)
503 module Failure = Error_monad(struct
504   type err = string
505   exception Exc = Failure
506   (*
507   let zero = ""
508   let plus s1 s2 = s1 ^ "\n" ^ s2
509   *)
510 end)
511
512
513 (* must be parameterized on (struct type env = ... end) *)
514 module Reader_monad(Env : sig type env end) : sig
515   (* declare additional operations, while still hiding implementation of type m *)
516   type env = Env.env
517   type ('x,'a) result = env -> 'a
518   type ('x,'a) result_exn = env -> 'a
519   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
520   val ask : ('x,env) m
521   val asks : (env -> 'a) -> ('x,'a) m
522   (* lookup i == `fun e -> e i` would assume env is a functional type *)
523   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
524   (* ReaderT transformer *)
525   module T : functor (Wrapped : Monad.S) -> sig
526     type ('x,'a) result = env -> ('x,'a) Wrapped.result
527     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
528     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
529     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
530     val ask : ('x,env) m
531     val asks : (env -> 'a) -> ('x,'a) m
532     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
533   end
534 end = struct
535   type env = Env.env
536   module Base = struct
537     type ('x,'a) m = env -> 'a
538     type ('x,'a) result = env -> 'a
539     type ('x,'a) result_exn = env -> 'a
540     let unit a = fun e -> a
541     let bind u f = fun e -> let a = u e in let u' = f a in u' e
542     let run u = fun e -> u e
543     let run_exn = run
544     let zero () = Util.undef
545     let plus u v = u
546   end
547   include Monad.Make(Base)
548   let ask = fun e -> e
549   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
550   let local modifier u = fun e -> u (modifier e)
551   module T(Wrapped : Monad.S) = struct
552     module BaseT = struct
553       module Wrapped = Wrapped
554       type ('x,'a) m = env -> ('x,'a) Wrapped.m
555       type ('x,'a) result = env -> ('x,'a) Wrapped.result
556       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
557       let elevate w = fun e -> w
558       let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e)
559       let run u = fun e -> Wrapped.run (u e)
560       let run_exn u = fun e -> Wrapped.run_exn (u e)
561       (* satisfies Distrib *)
562       let plus u v = fun e -> Wrapped.plus (u e) (v e)
563       let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
564     end
565     include Monad.MakeT(BaseT)
566     let ask = Wrapped.unit
567     let local modifier u = fun e -> u (modifier e)
568     let asks selector = ask >>= (fun e ->
569       try unit (selector e)
570       with Not_found -> fun e -> Wrapped.zero ())
571   end
572 end
573
574
575 (* must be parameterized on (struct type store = ... end) *)
576 module State_monad(Store : sig type store end) : sig
577   (* declare additional operations, while still hiding implementation of type m *)
578   type store = Store.store
579   type ('x,'a) result =  store -> 'a * store
580   type ('x,'a) result_exn = store -> 'a
581   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
582   val get : ('x,store) m
583   val gets : (store -> 'a) -> ('x,'a) m
584   val put : store -> ('x,unit) m
585   val puts : (store -> store) -> ('x,unit) m
586   (* StateT transformer *)
587   module T : functor (Wrapped : Monad.S) -> sig
588     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
589     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
590     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
591     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
592     val get : ('x,store) m
593     val gets : (store -> 'a) -> ('x,'a) m
594     val put : store -> ('x,unit) m
595     val puts : (store -> store) -> ('x,unit) m
596   end
597 end = struct
598   type store = Store.store
599   module Base = struct
600     type ('x,'a) m =  store -> 'a * store
601     type ('x,'a) result =  store -> 'a * store
602     type ('x,'a) result_exn = store -> 'a
603     let unit a = fun s -> (a, s)
604     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
605     let run u = fun s -> (u s)
606     let run_exn u = fun s -> fst (u s)
607     let zero () = Util.undef
608     let plus u v = u
609   end
610   include Monad.Make(Base)
611   let get = fun s -> (s, s)
612   let gets viewer = fun s -> (viewer s, s) (* may fail *)
613   let put s = fun _ -> ((), s)
614   let puts modifier = fun s -> ((), modifier s)
615   module T(Wrapped : Monad.S) = struct
616     module BaseT = struct
617       module Wrapped = Wrapped
618       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
619       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
620       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
621       let elevate w = fun s ->
622         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
623       let bind u f = fun s ->
624         Wrapped.bind (u s) (fun (a, s') -> f a s')
625       let run u = fun s -> Wrapped.run (u s)
626       let run_exn u = fun s ->
627         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
628         in Wrapped.run_exn w
629       (* satisfies Distrib *)
630       let plus u v = fun s -> Wrapped.plus (u s) (v s)
631       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
632     end
633     include Monad.MakeT(BaseT)
634     let get = fun s -> Wrapped.unit (s, s)
635     let gets viewer = fun s ->
636       try Wrapped.unit (viewer s, s)
637       with Not_found -> Wrapped.zero ()
638     let put s = fun _ -> Wrapped.unit ((), s)
639     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
640   end
641 end
642
643
644 (* State monad with different interface (structured store) *)
645 module Ref_monad(V : sig
646   type value
647 end) : sig
648   type ref
649   type value = V.value
650   type ('x,'a) result = 'a
651   type ('x,'a) result_exn = 'a
652   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
653   val newref : value -> ('x,ref) m
654   val deref : ref -> ('x,value) m
655   val change : ref -> value -> ('x,unit) m
656   (* RefT transformer *)
657   module T : functor (Wrapped : Monad.S) -> sig
658     type ('x,'a) result = ('x,'a) Wrapped.result
659     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
660     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
661     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
662     val newref : value -> ('x,ref) m
663     val deref : ref -> ('x,value) m
664     val change : ref -> value -> ('x,unit) m
665   end
666 end = struct
667   type ref = int
668   type value = V.value
669   module D = Map.Make(struct type t = ref let compare = compare end)
670   type dict = { next: ref; tree : value D.t }
671   let empty = { next = 0; tree = D.empty }
672   let alloc (value : value) (d : dict) =
673     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
674   let read (key : ref) (d : dict) =
675     D.find key d.tree
676   let write (key : ref) (value : value) (d : dict) =
677     { next = d.next; tree = D.add key value d.tree }
678   module Base = struct
679     type ('x,'a) m = dict -> 'a * dict
680     type ('x,'a) result = 'a
681     type ('x,'a) result_exn = 'a
682     let unit a = fun s -> (a, s)
683     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
684     let run u = fst (u empty)
685     let run_exn = run
686     let zero () = Util.undef
687     let plus u v = u
688   end
689   include Monad.Make(Base)
690   let newref value = fun s -> alloc value s
691   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
692   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
693   module T(Wrapped : Monad.S) = struct
694     module BaseT = struct
695       module Wrapped = Wrapped
696       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
697       type ('x,'a) result = ('x,'a) Wrapped.result
698       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
699       let elevate w = fun s ->
700         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
701       let bind u f = fun s ->
702         Wrapped.bind (u s) (fun (a, s') -> f a s')
703       let run u =
704         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
705         in Wrapped.run w
706       let run_exn u =
707         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
708         in Wrapped.run_exn w
709       (* satisfies Distrib *)
710       let plus u v = fun s -> Wrapped.plus (u s) (v s)
711       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
712     end
713     include Monad.MakeT(BaseT)
714     let newref value = fun s -> Wrapped.unit (alloc value s)
715     let deref key = fun s -> Wrapped.unit (read key s, s)
716     let change key value = fun s -> Wrapped.unit ((), write key value s)
717   end
718 end
719
720
721 (* must be parameterized on (struct type log = ... end) *)
722 module Writer_monad(Log : sig
723   type log
724   val zero : log
725   val plus : log -> log -> log
726 end) : sig
727   (* declare additional operations, while still hiding implementation of type m *)
728   type log = Log.log
729   type ('x,'a) result = 'a * log
730   type ('x,'a) result_exn = 'a * log
731   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
732   val tell : log -> ('x,unit) m
733   val listen : ('x,'a) m -> ('x,'a * log) m
734   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
735   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
736   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
737   (* WriterT transformer *)
738   module T : functor (Wrapped : Monad.S) -> sig
739     type ('x,'a) result = ('x,'a * log) Wrapped.result
740     type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
741     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
742     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
743     val tell : log -> ('x,unit) m
744     val listen : ('x,'a) m -> ('x,'a * log) m
745     val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
746     val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
747   end
748 end = struct
749   type log = Log.log
750   module Base = struct
751     type ('x,'a) m = 'a * log
752     type ('x,'a) result = 'a * log
753     type ('x,'a) result_exn = 'a * log
754     let unit a = (a, Log.zero)
755     let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w')
756     let run u = u
757     let run_exn = run
758     let zero () = Util.undef
759     let plus u v = u
760   end
761   include Monad.Make(Base)
762   let tell entries = ((), entries) (* add entries to log *)
763   let listen (a, w) = ((a, w), w)
764   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
765   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
766   let censor f u = pass (u >>= fun a -> unit (a, f))
767   module T(Wrapped : Monad.S) = struct
768     module BaseT = struct
769       module Wrapped = Wrapped
770       type ('x,'a) m = ('x,'a * log) Wrapped.m
771       type ('x,'a) result = ('x,'a * log) Wrapped.result
772       type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
773       let elevate w =
774         Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero))
775       let bind u f =
776         Wrapped.bind u (fun (a, w) ->
777         Wrapped.bind (f a) (fun (b, w') ->
778         Wrapped.unit (b, Log.plus w w')))
779       let zero () = elevate (Wrapped.zero ())
780       let plus u v = Wrapped.plus u v
781       let run u = Wrapped.run u
782       let run_exn u = Wrapped.run_exn u
783     end
784     include Monad.MakeT(BaseT)
785     let tell entries = Wrapped.unit ((), entries)
786     let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w))
787     let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w))
788     (* rest are derived in same way as before *)
789     let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w)
790     let censor f u = pass (u >>= fun a -> unit (a, f))
791   end
792 end
793
794 (* pre-define simple Writer *)
795 module Writer1 = Writer_monad(struct
796   type log = string
797   let zero = ""
798   let plus s1 s2 = s1 ^ "\n" ^ s2
799 end)
800
801 (* slightly more efficient Writer *)
802 module Writer2 = struct
803   include Writer_monad(struct
804     type log = string list
805     let zero = []
806     let plus w w' = Util.append w' w
807   end)
808   let tell_string s = tell [s]
809   let tell entries = tell (Util.reverse entries)
810   let run u = let (a, w) = run u in (a, Util.reverse w)
811   let run_exn = run
812 end
813
814
815 (* TODO needs a T *)
816 module IO_monad : sig
817   (* declare additional operation, while still hiding implementation of type m *)
818   type ('x,'a) result = 'a
819   type ('x,'a) result_exn = 'a
820   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
821   val printf : ('a, unit, string, ('x,unit) m) format4 -> 'a
822   val print_string : string -> ('x,unit) m
823   val print_int : int -> ('x,unit) m
824   val print_hex : int -> ('x,unit) m
825   val print_bool : bool -> ('x,unit) m
826 end = struct
827   module Base = struct
828     type ('x,'a) m = { run : unit -> unit; value : 'a }
829     type ('x,'a) result = 'a
830     type ('x,'a) result_exn = 'a
831     let unit a = { run = (fun () -> ()); value = a }
832     let bind (a : ('x,'a) m) (f: 'a -> ('x,'b) m) : ('x,'b) m =
833      let fres = f a.value in
834        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
835     let run a = let () = a.run () in a.value
836     let run_exn = run
837     let zero () = Util.undef
838     let plus u v = u
839   end
840   include Monad.Make(Base)
841   let printf fmt =
842     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
843   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
844   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
845   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
846   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
847 end
848
849
850 module Continuation_monad : sig
851   (* expose only the implementation of type `('r,'a) result` *)
852   type ('r,'a) m
853   type ('r,'a) result = ('r,'a) m
854   type ('r,'a) result_exn = ('a -> 'r) -> 'r
855   include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
856   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
857   val reset : ('a,'a) m -> ('r,'a) m
858   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
859   (* val abort : ('a,'a) m -> ('a,'b) m *)
860   val abort : 'a -> ('a,'b) m
861   val run0 : ('a,'a) m -> 'a
862   (* ContinuationT transformer *)
863   module T : functor (Wrapped : Monad.S) -> sig
864     type ('r,'a) m
865     type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
866     type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
867     include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
868     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
869     val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
870     (* TODO: reset,shift,abort,run0 *)
871   end
872 end = struct
873   let id = fun i -> i
874   module Base = struct
875     (* 'r is result type of whole computation *)
876     type ('r,'a) m = ('a -> 'r) -> 'r
877     type ('r,'a) result = ('a -> 'r) -> 'r
878     type ('r,'a) result_exn = ('r,'a) result
879     let unit a = (fun k -> k a)
880     let bind u f = (fun k -> (u) (fun a -> (f a) k))
881     let run u k = (u) k
882     let run_exn = run
883     let zero () = Util.undef
884     let plus u v = u
885   end
886   include Monad.Make(Base)
887   let callcc f = (fun k ->
888     let usek a = (fun _ -> k a)
889     in (f usek) k)
890   (*
891   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
892   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
893   let callcc f = fun k -> f k k
894   let throw k a = fun _ -> k a
895   *)
896
897   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
898    *
899    *  reset :: (Monad m) => ContT a m a -> ContT r m a
900    *  reset e = ContT $ \k -> runContT e return >>= k
901    *
902    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
903    *  shift e = ContT $ \k ->
904    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
905   let reset u = unit ((u) id)
906   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
907   (* let abort a = shift (fun _ -> a) *)
908   let abort a = shift (fun _ -> unit a)
909   let run0 (u : ('a,'a) m) = (u) id
910   module T(Wrapped : Monad.S) = struct
911     module BaseT = struct
912       module Wrapped = Wrapped
913       type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m
914       type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
915       type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
916       let elevate w = fun k -> Wrapped.bind w k
917       let bind u f = fun k -> u (fun a -> f a k)
918       let run u k = Wrapped.run (u k)
919       let run_exn u k = Wrapped.run_exn (u k)
920       let zero () = Util.undef
921       let plus u v = u
922     end
923     include Monad.MakeT(BaseT)
924     let callcc f = (fun k ->
925       let usek a = (fun _ -> k a)
926       in (f usek) k)
927   end
928 end
929
930
931 (*
932  * Scheme:
933  * (define (example n)
934  *    (let ([u (let/cc k ; type int -> int pair
935  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
936  *                 (+ 1 (car v))))]) ; int
937  *      (cons u 0))) ; int pair
938  * ; (example 10) ~~> '(111 . 0)
939  * ; (example -10) ~~> '(0 . 0)
940  *
941  * OCaml monads:
942  * let example n : (int * int) =
943  *   Continuation_monad.(let u = callcc (fun k ->
944  *       (if n < 0 then k 0 else unit [n + 100])
945  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
946  *       >>= fun [x] -> unit (x + 1)
947  *   )
948  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
949  *   >>= fun x -> unit (x, 0)
950  *   in run u)
951  *
952  *
953  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
954  * let example1 () : int =
955  *   Continuation_monad.(let v = reset (
956  *       let u = shift (fun k -> unit (10 + 1))
957  *       in u >>= fun x -> unit (100 + x)
958  *     ) in let w = v >>= fun x -> unit (1000 + x)
959  *     in run w)
960  *
961  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
962  * let example2 () =
963  *   Continuation_monad.(let v = reset (
964  *       let u = shift (fun k -> k (10 :: [1]))
965  *       in u >>= fun x -> unit (100 :: x)
966  *     ) in let w = v >>= fun x -> unit (1000 :: x)
967  *     in run w)
968  *
969  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
970  * let example3 () =
971  *   Continuation_monad.(let v = reset (
972  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
973  *       in u >>= fun x -> unit (100 :: x)
974  *     ) in let w = v >>= fun x -> unit (1000 :: x)
975  *     in run w)
976  *
977  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
978  * (* not sure if this example can be typed without a sum-type *)
979  *
980  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
981  * let example5 () : int =
982  *   Continuation_monad.(let v = reset (
983  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
984  *       in u >>= fun x -> unit (10 + x)
985  *     ) in let w = v >>= fun x -> unit (100 + x)
986  *     in run w)
987  *
988  *)
989
990
991 module Tree_monad : sig
992   (* We implement the type as `'a tree option` because it has a natural`plus`,
993    * and the rest of the library expects that `plus` and `zero` will come together. *)
994   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
995   type ('x,'a) result = 'a tree option
996   type ('x,'a) result_exn = 'a tree
997   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
998   (* TreeT transformer *)
999   module T : functor (Wrapped : Monad.S) -> sig
1000     type ('x,'a) result = ('x,'a tree option) Wrapped.result
1001     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1002     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1003     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1004     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1005     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1006     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
1007   end
1008 end = struct
1009   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1010   (* uses supplied plus and zero to copy t to its image under f *)
1011   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1012       | None -> zero ()
1013       | Some ts -> let rec loop ts = (match ts with
1014                      | Leaf a -> f a
1015                      | Node (l, r) ->
1016                          (* recursive application of f may delete a branch *)
1017                          plus (loop l) (loop r)
1018                    ) in loop ts
1019   module Base = struct
1020     type ('x,'a) m = 'a tree option
1021     type ('x,'a) result = 'a tree option
1022     type ('x,'a) result_exn = 'a tree
1023     let unit a = Some (Leaf a)
1024     let zero () = None
1025     (* satisfies Distrib *)
1026     let plus u v = match (u, v) with
1027       | None, _ -> v
1028       | _, None -> u
1029       | Some us, Some vs -> Some (Node (us, vs))
1030     let bind u f = mapT f u zero plus
1031     let run u = u
1032     let run_exn u = match u with
1033       | None -> failwith "no values"
1034       (*
1035       | Some (Leaf a) -> a
1036       | many -> failwith "multiple values"
1037       *)
1038       | Some us -> us
1039   end
1040   include Monad.Make(Base)
1041   module T(Wrapped : Monad.S) = struct
1042     module BaseT = struct
1043       include Monad.MakeT(struct
1044         module Wrapped = Wrapped
1045         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1046         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1047         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1048         let zero () = Wrapped.unit None
1049         let plus u v =
1050           Wrapped.bind u (fun us ->
1051           Wrapped.bind v (fun vs ->
1052           Wrapped.unit (Base.plus us vs)))
1053         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1054         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1055         let run u = Wrapped.run u
1056         let run_exn u =
1057             let w = Wrapped.bind u (fun t -> match t with
1058               | None -> Wrapped.zero ()
1059               | Some ts -> Wrapped.unit ts
1060             ) in Wrapped.run_exn w
1061       end)
1062     end
1063     include BaseT
1064     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1065   end
1066 end
1067
1068
1069 module L = List_monad;;
1070 module R = Reader_monad(struct type env = int -> int end);;
1071 module S = State_monad(struct type store = int end);;
1072 module T = Tree_monad;;
1073 module LR = L.T(R);;
1074 module LS = L.T(S);;
1075 module TL = T.T(L);;
1076 module TR = T.T(R);;
1077 module TS = T.T(S);;
1078 module C = Continuation_monad
1079 module TC = T.T(C);;
1080
1081
1082 print_endline "=== test TreeT(...).distribute ==================";;
1083
1084 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1085
1086 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1087 TS.run ts 0;;
1088 (*
1089 - : int T.tree option * S.store =
1090 (Some
1091   (T.Node
1092     (T.Node (T.Leaf 2, T.Leaf 3),
1093      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1094  5)
1095 *)
1096
1097 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1098 TS.run_exn ts2 0;;
1099 (*
1100 - : (int * S.store) T.tree option * S.store =
1101 (Some
1102   (T.Node
1103     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1104      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1105  5)
1106 *)
1107
1108 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1109 TR.run_exn tr (fun i -> i+i);;
1110 (*
1111 - : int T.tree option =
1112 Some
1113  (T.Node
1114    (T.Node (T.Leaf 4, T.Leaf 6),
1115     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1116 *)
1117
1118 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1119 TL.run_exn tl;;
1120 (*
1121 - : (int * int) TL.result =
1122 [Some
1123   (T.Node
1124     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1125      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1126 *)
1127
1128 let l2 = [1;2;3;4;5];;
1129 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1130
1131 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1132 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1133
1134 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1135 (*
1136 int T.tree option =
1137 Some
1138  (T.Node
1139    (T.Node (T.Leaf 10, T.Leaf 11),
1140     T.Node
1141      (T.Node
1142        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1143         T.Node (T.Leaf 40, T.Leaf 41)),
1144       T.Node (T.Leaf 50, T.Leaf 51))))
1145  *)
1146
1147 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1148 (*
1149 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1150 *)
1151
1152 print_endline "=== test TreeT(Continuation).distribute ==================";;
1153
1154 let id : 'z. 'z -> 'z = fun x -> x
1155
1156 let example n : (int * int) =
1157   Continuation_monad.(let u = callcc (fun k ->
1158       (if n < 0 then k 0 else unit [n + 100])
1159       (* all of the following is skipped by k 0; the end type int is k's input type *)
1160       >>= fun [x] -> unit (x + 1)
1161   )
1162   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1163   >>= fun x -> unit (x, 0)
1164   in run0 u)
1165
1166
1167 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1168 let example1 () : int =
1169   Continuation_monad.(let v = reset (
1170       let u = shift (fun k -> unit (10 + 1))
1171       in u >>= fun x -> unit (100 + x)
1172     ) in let w = v >>= fun x -> unit (1000 + x)
1173     in run0 w)
1174
1175 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1176 let example2 () =
1177   Continuation_monad.(let v = reset (
1178       let u = shift (fun k -> k (10 :: [1]))
1179       in u >>= fun x -> unit (100 :: x)
1180     ) in let w = v >>= fun x -> unit (1000 :: x)
1181     in run0 w)
1182
1183 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1184 let example3 () =
1185   Continuation_monad.(let v = reset (
1186       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1187       in u >>= fun x -> unit (100 :: x)
1188     ) in let w = v >>= fun x -> unit (1000 :: x)
1189     in run0 w)
1190
1191 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1192 (* not sure if this example can be typed without a sum-type *)
1193
1194 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1195 let example5 () : int =
1196   Continuation_monad.(let v = reset (
1197       let u = shift (fun k -> k 1 >>= k)
1198       in u >>= fun x -> unit (10 + x)
1199     ) in let w = v >>= fun x -> unit (100 + x)
1200     in run0 w)
1201
1202 ;;
1203
1204 print_endline "=== test bare Continuation ============";;
1205
1206 (1011, 1111, 1111, 121);;
1207 (example1(), example2(), example3(), example5());;
1208 ((111,0), (0,0));;
1209 (example ~+10, example ~-10);;
1210
1211 let testc df ic =
1212     C.run_exn TC.(run (distribute df t1)) ic;;
1213
1214
1215 (*
1216 (* do nothing *)
1217 let initial_continuation = fun t -> t in
1218 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1219 *)
1220 testc (C.unit) id;;
1221
1222 (*
1223 (* count leaves, using continuation *)
1224 let initial_continuation = fun t -> 0 in
1225 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1226 *)
1227
1228 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1229
1230 (*
1231 (* convert tree to list of leaves *)
1232 let initial_continuation = fun t -> [] in
1233 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1234 *)
1235
1236 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1237
1238 (*
1239 (* square each leaf using continuation *)
1240 let initial_continuation = fun t -> t in
1241 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1242 *)
1243
1244 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1245
1246
1247 (*
1248 (* replace leaves with list, using continuation *)
1249 let initial_continuation = fun t -> t in
1250 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1251 *)
1252
1253 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1254
1255 print_endline "=== pa_monad's Continuation Tests ============";;
1256
1257 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1258 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1259 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1260 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1261 (5, 27 = C.(run0 (
1262               let c = reset(abort 5 >>= fun y -> unit (y+6))
1263               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1264
1265 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1266
1267 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1268
1269 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1270
1271
1272 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1273