tweak monads-lib, start T2
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  *)
46
47
48 (* Some library functions used below. *)
49 module Util = struct
50   let fold_right = List.fold_right
51   let map = List.map
52   let append = List.append
53   let reverse = List.rev
54   let concat = List.concat
55   let concat_map f lst = List.concat (List.map f lst)
56   (* let zip = List.combine *)
57   let unzip = List.split
58   let zip_with = List.map2
59   let replicate len fill =
60     let rec loop n accu =
61       if n == 0 then accu else loop (pred n) (fill :: accu)
62     in loop len []
63 end
64
65
66
67 (*
68  * This module contains factories that extend a base set of
69  * monadic definitions with a larger family of standard derived values.
70  *)
71
72 module Monad = struct
73   (*
74    * Signature extenders:
75    *   Make :: BASE -> S
76    *     MakeCatch, MakeDistrib :: PLUSBASE -> PLUS
77    *                                           which merges into S
78    *                                           (P is merged sig)
79    *   MakeT :: TRANS (with Wrapped : S or P) -> custom sig
80    *
81    *   Make2 :: BASE2 -> S2
82    *     MakeCatch2, MakeDistrib2 :: PLUSBASE2 -> PLUS2 (P2 is merged sig)
83    *   to wrap double-typed inner monads:
84    *   MakeT2 :: TRANS2 (with Wrapped : S2 or P2) -> custom sig
85    *
86    *)
87
88
89
90   (* type of base definitions *)
91   module type BASE = sig
92     (* The only constraints we impose here on how the monadic type
93      * is implemented is that it have a single type parameter 'a. *)
94     type 'a m
95     type 'a result
96     type 'a result_exn
97     val unit : 'a -> 'a m
98     val bind : 'a m -> ('a -> 'b m) -> 'b m
99     val run : 'a m -> 'a result
100     (* run_exn tries to provide a more ground-level result, but may fail *)
101     val run_exn : 'a m -> 'a result_exn
102   end
103   module type S = sig
104     include BASE
105     val (>>=) : 'a m -> ('a -> 'b m) -> 'b m
106     val (>>) : 'a m -> 'b m -> 'b m
107     val join : ('a m) m -> 'a m
108     val apply : ('a -> 'b) m -> 'a m -> 'b m
109     val lift : ('a -> 'b) -> 'a m -> 'b m
110     val lift2 :  ('a -> 'b -> 'c) -> 'a m -> 'b m -> 'c m
111     val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m
112     val do_when :  bool -> unit m -> unit m
113     val do_unless :  bool -> unit m -> unit m
114     val forever : 'a m -> 'b m
115     val sequence : 'a m list -> 'a list m
116     val sequence_ : 'a m list -> unit m
117   end
118
119   (* Standard, single-type-parameter monads. *)
120   module Make(B : BASE) : S with type 'a m = 'a B.m and type 'a result = 'a B.result and type 'a result_exn = 'a B.result_exn = struct
121     include B
122     let (>>=) = bind
123     let (>>) u v = u >>= fun _ -> v
124     let lift f u = u >>= fun a -> unit (f a)
125     (* lift is called listM, fmap, and <$> in Haskell *)
126     let join uu = uu >>= fun u -> u
127     (* u >>= f === join (lift f u) *)
128     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
129     (* [f] <*> [x1,x2] = [f x1,f x2] *)
130     (* let apply u v = u >>= fun f -> lift f v *)
131     (* let apply = lift2 id *)
132     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
133     (* let lift f u === apply (unit f) u *)
134     (* let lift2 f u v = apply (lift f u) v *)
135     let (>=>) f g = fun a -> f a >>= g
136     let do_when test u = if test then u else unit ()
137     let do_unless test u = if test then unit () else u
138     let rec forever u = u >> forever u
139     let sequence ms =
140       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
141         Util.fold_right op ms (unit [])
142     let sequence_ ms =
143       Util.fold_right (>>) ms (unit ())
144
145     (* Haskell defines these other operations combining lists and monads.
146      * We don't, but notice that M.mapM == ListT(M).distribute
147      * There's also a parallel TreeT(M).distribute *)
148     (*
149     let mapM f alist = sequence (Util.map f alist)
150     let mapM_ f alist = sequence_ (Util.map f alist)
151     let rec filterM f lst = match lst with
152       | [] -> unit []
153       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
154     let forM alist f = mapM f alist
155     let forM_ alist f = mapM_ f alist
156     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
157     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
158     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
159     let rec foldM f z lst = match lst with
160       | [] -> unit z
161       | x::xs -> f z x >>= fun z' -> foldM f z' xs
162     let foldM_ f z xs = foldM f z xs >> unit ()
163     let replicateM n x = sequence (Util.replicate n x)
164     let replicateM_ n x = sequence_ (Util.replicate n x)
165     *)
166   end
167
168   (* Single-type-parameter monads that also define `plus` and `zero`
169    * operations. These obey the following laws:
170    *     zero >>= f   ===  zero
171    *     plus zero u  ===  u
172    *     plus u zero  ===  u
173    * Additionally, these monads will obey one of the following laws:
174    *     (Catch)   plus (unit a) v  ===  unit a
175    *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
176    *)
177   module type PLUSBASE = sig
178     include BASE
179     val zero : unit -> 'a m
180     val plus : 'a m -> 'a m -> 'a m
181   end
182   module type PLUS = sig
183     type 'a m
184     val zero : unit -> 'a m
185     val plus : 'a m -> 'a m -> 'a m
186     val guard : bool -> unit m
187     val sum : 'a m list -> 'a m
188   end
189   (* MakeCatch and MakeDistrib have the same implementation; we just declare
190    * them twice to document which laws the client code is promising to honor. *)
191   module MakeCatch(B : PLUSBASE) : PLUS with type 'a m = 'a B.m = struct
192     type 'a m = 'a B.m
193     let zero = B.zero
194     let plus = B.plus
195     let guard test = if test then B.unit () else zero ()
196     let sum ms = Util.fold_right plus ms (zero ())
197   end
198   module MakeDistrib = MakeCatch
199
200   (* Signatures for MonadT *)
201   (* sig for Wrapped that include S and PLUS *)
202   module type P = sig
203     include S
204     include PLUS with type 'a m := 'a m
205   end
206   module type TRANS = sig
207     module Wrapped : S
208     type 'a m
209     type 'a result
210     type 'a result_exn
211     val bind : 'a m -> ('a -> 'b m) -> 'b m
212     val run : 'a m -> 'a result
213     val run_exn : 'a m -> 'a result_exn
214     val elevate : 'a Wrapped.m -> 'a m
215     (* lift/elevate laws:
216      *     elevate (W.unit a) == unit a
217      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
218      *)
219   end
220   module MakeT(T : TRANS) = struct
221     include Make(struct
222         include T
223         let unit a = elevate (Wrapped.unit a)
224     end)
225     let elevate = T.elevate
226   end
227
228
229   (* We have to define BASE, S, and Make again for double-type-parameter monads. *)
230   module type BASE2 = sig
231     type ('x,'a) m
232     type ('x,'a) result
233     type ('x,'a) result_exn
234     val unit : 'a -> ('x,'a) m
235     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
236     val run : ('x,'a) m -> ('x,'a) result
237     val run_exn : ('x,'a) m -> ('x,'a) result_exn
238   end
239   module type S2 = sig
240     include BASE2
241     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
242     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
243     val join : ('x,('x,'a) m) m -> ('x,'a) m
244     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
245     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
246     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
247     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
248     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
249     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
250     val forever : ('x,'a) m -> ('x,'b) m
251     val sequence : ('x,'a) m list -> ('x,'a list) m
252     val sequence_ : ('x,'a) m list -> ('x,unit) m
253   end
254   module Make2(B : BASE2) : S2 with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
255     (* code repetition, ugh *)
256     include B
257     let (>>=) = bind
258     let (>>) u v = u >>= fun _ -> v
259     let lift f u = u >>= fun a -> unit (f a)
260     let join uu = uu >>= fun u -> u
261     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
262     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
263     let (>=>) f g = fun a -> f a >>= g
264     let do_when test u = if test then u else unit ()
265     let do_unless test u = if test then unit () else u
266     let rec forever u = u >> forever u
267     let sequence ms =
268       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
269         Util.fold_right op ms (unit [])
270     let sequence_ ms =
271       Util.fold_right (>>) ms (unit ())
272   end
273
274   module type PLUSBASE2 = sig
275     include BASE2
276     val zero : unit -> ('x,'a) m
277     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
278   end
279   module type PLUS2 = sig
280     type ('x,'a) m
281     val zero : unit -> ('x,'a) m
282     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
283     val guard : bool -> ('x,unit) m
284     val sum : ('x,'a) m list -> ('x,'a) m
285   end
286   module MakeCatch2(B : PLUSBASE2) : PLUS2 with type ('x,'a) m = ('x,'a) B.m = struct
287     type ('x,'a) m = ('x,'a) B.m
288     (* code repetition, ugh *)
289     let zero = B.zero
290     let plus = B.plus
291     let guard test = if test then B.unit () else zero ()
292     let sum ms = Util.fold_right plus ms (zero ())
293   end
294   module MakeDistrib2 = MakeCatch2
295
296   (* Signatures for MonadT *)
297   (* sig for Wrapped that include S and PLUS *)
298   module type P2 = sig
299     include S2
300     include PLUS2 with type ('x,'a) m := ('x,'a) m
301   end
302   module type TRANS2 = sig
303     module Wrapped : S2
304     type ('x,'a) m
305     type ('x,'a) result
306     type ('x,'a) result_exn
307     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
308     val run : ('x,'a) m -> ('x,'a) result
309     val run_exn : ('x,'a) m -> ('x,'a) result_exn
310     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
311   end
312   module MakeT2(T : TRANS2) = struct
313     (* code repetition, ugh *)
314     include Make2(struct
315         include T
316         let unit a = elevate (Wrapped.unit a)
317     end)
318     let elevate = T.elevate
319   end
320
321 end
322
323
324
325
326
327 module Identity_monad : sig
328   (* expose only the implementation of type `'a result` *)
329   type 'a result = 'a
330   type 'a result_exn = 'a
331   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
332 end = struct
333   module Base = struct
334     type 'a m = 'a
335     let unit a = a
336     let bind a f = f a
337     type 'a result = 'a
338     let run a = a
339     type 'a result_exn = 'a
340     let run_exn a = a
341   end
342   include Monad.Make(Base)
343 end
344
345
346 module Maybe_monad : sig
347   (* expose only the implementation of type `'a result` *)
348   type 'a result = 'a option
349   type 'a result_exn = 'a
350   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
351   include Monad.PLUS with type 'a m := 'a m
352   (* MaybeT transformer *)
353   module T : functor (Wrapped : Monad.S) -> sig
354     type 'a result = 'a option Wrapped.result
355     type 'a result_exn = 'a Wrapped.result_exn
356     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
357     include Monad.PLUS with type 'a m := 'a m
358     val elevate : 'a Wrapped.m -> 'a m
359   end
360   module T2 : functor (Wrapped : Monad.S2) -> sig
361     type ('x,'a) result = ('x,'a option) Wrapped.result
362     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
363     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
364     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
365     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
366   end
367 end = struct
368   module Base = struct
369    type 'a m = 'a option
370    let unit a = Some a
371    let bind u f = match u with Some a -> f a | None -> None
372    type 'a result = 'a option
373    let run u = u
374    type 'a result_exn = 'a
375    let run_exn u = match u with
376      | Some a -> a
377      | None -> failwith "no value"
378    let zero () = None
379    let plus u v = match u with None -> v | _ -> u
380   end
381   include Monad.Make(Base)
382   include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m)
383   module T(Wrapped : Monad.S) = struct
384     module Trans = struct
385       include Monad.MakeT(struct
386         module Wrapped = Wrapped
387         type 'a m = 'a option Wrapped.m
388         type 'a result = 'a option Wrapped.result
389         type 'a result_exn = 'a Wrapped.result_exn
390         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
391         let bind u f = Wrapped.bind u (fun t -> match t with
392           | Some a -> f a
393           | None -> Wrapped.unit None)
394         let run u = Wrapped.run u
395         let run_exn u =
396           let w = Wrapped.bind u (fun t -> match t with
397             | Some a -> Wrapped.unit a
398             | None -> failwith "no value")
399           in Wrapped.run_exn w
400       end)
401       let zero () = Wrapped.unit None
402       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
403     end
404     include Trans
405     include (Monad.MakeCatch(Trans) : Monad.PLUS with type 'a m := 'a m)
406   end
407   module T2(Wrapped : Monad.S2) = struct
408     module Trans = struct
409       include Monad.MakeT2(struct
410         module Wrapped = Wrapped
411         type ('x,'a) m = ('x,'a option) Wrapped.m
412         type ('x,'a) result = ('x,'a option) Wrapped.result
413         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
414         (* code repetition, ugh *)
415         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
416         let bind u f = Wrapped.bind u (fun t -> match t with
417           | Some a -> f a
418           | None -> Wrapped.unit None)
419         let run u = Wrapped.run u
420         let run_exn u =
421           let w = Wrapped.bind u (fun t -> match t with
422             | Some a -> Wrapped.unit a
423             | None -> failwith "no value")
424           in Wrapped.run_exn w
425       end)
426       let zero () = Wrapped.unit None
427       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
428     end
429     include Trans
430     include (Monad.MakeCatch2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
431   end
432 end
433
434
435 module List_monad : sig
436   (* declare additional operation, while still hiding implementation of type m *)
437   type 'a result = 'a list
438   type 'a result_exn = 'a
439   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
440   include Monad.PLUS with type 'a m := 'a m
441   val permute : 'a m -> 'a m m
442   val select : 'a m -> ('a * 'a m) m
443   (* ListT transformer *)
444   module T : functor (Wrapped : Monad.S) -> sig
445     type 'a result = 'a list Wrapped.result
446     type 'a result_exn = 'a Wrapped.result_exn
447     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
448     include Monad.PLUS with type 'a m := 'a m
449     val elevate : 'a Wrapped.m -> 'a m
450     (* note that second argument is an 'a list, not the more abstract 'a m *)
451     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
452     val distribute : ('a -> 'b Wrapped.m) -> 'a list -> 'b m
453 (* TODO
454     val permute : 'a m -> 'a m m
455     val select : 'a m -> ('a * 'a m) m
456 *)
457   end
458 (*
459   module T2 : functor (Wrapped : Monad.S2) -> sig
460     type ('x,'a) result = ('x,'a list) Wrapped.result
461     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
462     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
463     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
464     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
465     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
466   end
467  *)
468 end = struct
469   module Base = struct
470    type 'a m = 'a list
471    let unit a = [a]
472    let bind u f = Util.concat_map f u
473    type 'a result = 'a list
474    let run u = u
475    type 'a result_exn = 'a
476    let run_exn u = match u with
477      | [] -> failwith "no values"
478      | [a] -> a
479      | many -> failwith "multiple values"
480    let zero () = []
481    let plus = Util.append
482   end
483   include Monad.Make(Base)
484   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
485   (* let either u v = plus u v *)
486   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
487   let rec insert a u =
488     plus (unit (a :: u)) (match u with
489         | [] -> zero ()
490         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
491     )
492   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
493   let rec permute u = match u with
494       | [] -> unit []
495       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
496   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
497   let rec select u = match u with
498     | [] -> zero ()
499     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
500   let base_plus = plus
501   module T(Wrapped : Monad.S) = struct
502     module Trans = struct
503       let zero () = Wrapped.unit []
504       let plus u v =
505         Wrapped.bind u (fun us ->
506         Wrapped.bind v (fun vs ->
507         Wrapped.unit (base_plus us vs)))
508       (* Wrapped.sequence ms  ===  
509            let plus1 u v =
510              Wrapped.bind u (fun x ->
511              Wrapped.bind v (fun xs ->
512              Wrapped.unit (x :: xs)))
513            in Util.fold_right plus1 ms (Wrapped.unit []) *)
514       (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
515       let distribute f alist = Wrapped.sequence (Util.map f alist)
516       include Monad.MakeT(struct
517         module Wrapped = Wrapped
518         type 'a m = 'a list Wrapped.m
519         type 'a result = 'a list Wrapped.result
520         type 'a result_exn = 'a Wrapped.result_exn
521         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
522         let bind u f =
523           Wrapped.bind u (fun ts ->
524           Wrapped.bind (distribute f ts) (fun tts ->
525           Wrapped.unit (Util.concat tts)))
526         let run u = Wrapped.run u
527         let run_exn u =
528           let w = Wrapped.bind u (fun ts -> match ts with
529             | [] -> failwith "no values"
530             | [a] -> Wrapped.unit a
531             | many -> failwith "multiple values"
532           ) in Wrapped.run_exn w
533       end)
534     end
535     include Trans
536     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
537 (*
538     let permute : 'a m -> 'a m m
539     let select : 'a m -> ('a * 'a m) m
540 *)
541   end
542 end
543
544
545 (* must be parameterized on (struct type err = ... end) *)
546 module Error_monad(Err : sig
547   type err
548   exception Exc of err
549   (*
550   val zero : unit -> err
551   val plus : err -> err -> err
552   *)
553 end) : sig
554   (* declare additional operations, while still hiding implementation of type m *)
555   type err = Err.err
556   type 'a error = Error of err | Success of 'a
557   type 'a result = 'a
558   type 'a result_exn = 'a
559   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
560   (* include Monad.PLUS with type 'a m := 'a m *)
561   val throw : err -> 'a m
562   val catch : 'a m -> (err -> 'a m) -> 'a m
563   (* ErrorT transformer *)
564   module T : functor (Wrapped : Monad.S) -> sig
565     type 'a result = 'a Wrapped.result
566     type 'a result_exn = 'a Wrapped.result_exn
567     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
568     val elevate : 'a Wrapped.m -> 'a m
569     val throw : err -> 'a m
570     val catch : 'a m -> (err -> 'a m) -> 'a m
571   end
572 end = struct
573   type err = Err.err
574   type 'a error = Error of err | Success of 'a
575   module Base = struct
576     type 'a m = 'a error
577     let unit a = Success a
578     let bind u f = match u with
579       | Success a -> f a
580       | Error e -> Error e (* input and output may be of different 'a types *)
581     type 'a result = 'a
582     (* TODO: should run refrain from failing? *)
583     let run u = match u with
584       | Success a -> a
585       | Error e -> raise (Err.Exc e)
586     type 'a result_exn = 'a
587     let run_exn = run
588     (*
589     let zero () = Error Err.zero
590     let plus u v = match (u, v) with
591       | Success _, _ -> u
592       (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
593        * otherwise, plus (Error _) v = v *)
594       | Error _, _ when v = zero -> u
595       (* combine errors *)
596       | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
597       | Error _, _ -> v
598     *)
599   end
600   include Monad.Make(Base)
601   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
602   let throw e = Error e
603   let catch u handler = match u with
604     | Success _ -> u
605     | Error e -> handler e
606   module T(Wrapped : Monad.S) = struct
607     module Trans = struct
608       module Wrapped = Wrapped
609       type 'a m = 'a Base.m Wrapped.m
610       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
611       let bind u f = Wrapped.bind u (fun t -> match t with
612         | Success a -> f a
613         | Error e -> Wrapped.unit (Error e))
614       type 'a result = 'a Wrapped.result
615       (* TODO: should run refrain from failing? *)
616       let run u =
617         let w = Wrapped.bind u (fun t -> match t with
618           | Success a -> Wrapped.unit a
619           (* | _ -> Wrapped.fail () *)
620           | Error e -> raise (Err.Exc e))
621         in Wrapped.run w
622       type 'a result_exn = 'a Wrapped.result_exn
623       let run_exn u =
624         let w = Wrapped.bind u (fun t -> match t with
625           | Success a -> Wrapped.unit a
626           (* | _ -> Wrapped.fail () *)
627           | Error e -> raise (Err.Exc e))
628         in Wrapped.run_exn w
629     end
630     include Monad.MakeT(Trans)
631     let throw e = Wrapped.unit (Error e)
632     let catch u handler = Wrapped.bind u (fun t -> match t with
633       | Success _ -> Wrapped.unit t
634       | Error e -> handler e)
635   end
636 end
637
638 (* pre-define common instance of Error_monad *)
639 module Failure = Error_monad(struct
640   type err = string
641   exception Exc = Failure
642   (*
643   let zero = ""
644   let plus s1 s2 = s1 ^ "\n" ^ s2
645   *)
646 end)
647
648 (* must be parameterized on (struct type env = ... end) *)
649 module Reader_monad(Env : sig type env end) : sig
650   (* declare additional operations, while still hiding implementation of type m *)
651   type env = Env.env
652   type 'a result = env -> 'a
653   type 'a result_exn = env -> 'a
654   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
655   val ask : env m
656   val asks : (env -> 'a) -> 'a m
657   val local : (env -> env) -> 'a m -> 'a m
658   (* ReaderT transformer *)
659   module T : functor (Wrapped : Monad.S) -> sig
660     type 'a result = env -> 'a Wrapped.result
661     type 'a result_exn = env -> 'a Wrapped.result_exn
662     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
663     val elevate : 'a Wrapped.m -> 'a m
664     val ask : env m
665     val asks : (env -> 'a) -> 'a m
666     val local : (env -> env) -> 'a m -> 'a m
667   end
668   (* ReaderT transformer when wrapped monad has plus, zero *)
669   module TP : functor (Wrapped : Monad.P) -> sig
670     include module type of T(Wrapped)
671     include Monad.PLUS with type 'a m := 'a m
672   end
673 end = struct
674   type env = Env.env
675   module Base = struct
676     type 'a m = env -> 'a
677     let unit a = fun e -> a
678     let bind u f = fun e -> let a = u e in let u' = f a in u' e
679     type 'a result = env -> 'a
680     let run u = fun e -> u e
681     type 'a result_exn = env -> 'a
682     let run_exn = run
683   end
684   include Monad.Make(Base)
685   let ask = fun e -> e
686   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
687   let local modifier u = fun e -> u (modifier e)
688   module T(Wrapped : Monad.S) = struct
689     module Trans = struct
690       module Wrapped = Wrapped
691       type 'a m = env -> 'a Wrapped.m
692       let elevate w = fun e -> w
693       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
694       type 'a result = env -> 'a Wrapped.result
695       let run u = fun e -> Wrapped.run (u e)
696       type 'a result_exn = env -> 'a Wrapped.result_exn
697       let run_exn u = fun e -> Wrapped.run_exn (u e)
698     end
699     include Monad.MakeT(Trans)
700     let ask = fun e -> Wrapped.unit e
701     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
702     let local modifier u = fun e -> u (modifier e)
703   end
704   module TP(Wrapped : Monad.P) = struct
705     module TransP = struct
706       include T(Wrapped)
707       let plus u v = fun s -> Wrapped.plus (u s) (v s)
708       let zero () = elevate (Wrapped.zero ())
709       let asks selector = ask >>= (fun e ->
710         try unit (selector e)
711         with Not_found -> fun e -> Wrapped.zero ())
712     end
713     include TransP
714     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
715   end
716 end
717
718
719 (* must be parameterized on (struct type store = ... end) *)
720 module State_monad(Store : sig type store end) : sig
721   (* declare additional operations, while still hiding implementation of type m *)
722   type store = Store.store
723   type 'a result = store -> 'a * store
724   type 'a result_exn = store -> 'a
725   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
726   val get : store m
727   val gets : (store -> 'a) -> 'a m
728   val put : store -> unit m
729   val puts : (store -> store) -> unit m
730   (* StateT transformer *)
731   module T : functor (Wrapped : Monad.S) -> sig
732     type 'a result = store -> ('a * store) Wrapped.result
733     type 'a result_exn = store -> 'a Wrapped.result_exn
734     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
735     val elevate : 'a Wrapped.m -> 'a m
736     val get : store m
737     val gets : (store -> 'a) -> 'a m
738     val put : store -> unit m
739     val puts : (store -> store) -> unit m
740   end
741   (* StateT transformer when wrapped monad has plus, zero *)
742   module TP : functor (Wrapped : Monad.P) -> sig
743     include module type of T(Wrapped)
744     include Monad.PLUS with type 'a m := 'a m
745   end
746 end = struct
747   type store = Store.store
748   module Base = struct
749     type 'a m = store -> 'a * store
750     let unit a = fun s -> (a, s)
751     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
752     type 'a result = store -> 'a * store
753     let run u = fun s -> (u s)
754     type 'a result_exn = store -> 'a
755     let run_exn u = fun s -> fst (u s)
756   end
757   include Monad.Make(Base)
758   let get = fun s -> (s, s)
759   let gets viewer = fun s -> (viewer s, s) (* may fail *)
760   let put s = fun _ -> ((), s)
761   let puts modifier = fun s -> ((), modifier s)
762   module T(Wrapped : Monad.S) = struct
763     module Trans = struct
764       module Wrapped = Wrapped
765       type 'a m = store -> ('a * store) Wrapped.m
766       let elevate w = fun s ->
767         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
768       let bind u f = fun s ->
769         Wrapped.bind (u s) (fun (a, s') -> f a s')
770       type 'a result = store -> ('a * store) Wrapped.result
771       let run u = fun s -> Wrapped.run (u s)
772       type 'a result_exn = store -> 'a Wrapped.result_exn
773       let run_exn u = fun s ->
774         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
775         in Wrapped.run_exn w
776     end
777     include Monad.MakeT(Trans)
778     let get = fun s -> Wrapped.unit (s, s)
779     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
780     let put s = fun _ -> Wrapped.unit ((), s)
781     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
782   end
783   module TP(Wrapped : Monad.P) = struct
784     module TransP = struct
785       include T(Wrapped)
786       let plus u v = fun s -> Wrapped.plus (u s) (v s)
787       let zero () = elevate (Wrapped.zero ())
788     end
789     let gets viewer = fun s ->
790       try Wrapped.unit (viewer s, s)
791       with Not_found -> Wrapped.zero ()
792     include TransP
793     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
794   end
795 end
796
797 (* State monad with different interface (structured store) *)
798 module Ref_monad(V : sig
799   type value
800 end) : sig
801   type ref
802   type value = V.value
803   type 'a result = 'a
804   type 'a result_exn = 'a
805   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
806   val newref : value -> ref m
807   val deref : ref -> value m
808   val change : ref -> value -> unit m
809   (* RefT transformer *)
810   module T : functor (Wrapped : Monad.S) -> sig
811     type 'a result = 'a Wrapped.result
812     type 'a result_exn = 'a Wrapped.result_exn
813     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
814     val elevate : 'a Wrapped.m -> 'a m
815     val newref : value -> ref m
816     val deref : ref -> value m
817     val change : ref -> value -> unit m
818   end
819   (* RefT transformer when wrapped monad has plus, zero *)
820   module TP : functor (Wrapped : Monad.P) -> sig
821     include module type of T(Wrapped)
822     include Monad.PLUS with type 'a m := 'a m
823   end
824 end = struct
825   type ref = int
826   type value = V.value
827   module D = Map.Make(struct type t = ref let compare = compare end)
828   type dict = { next: ref; tree : value D.t }
829   let empty = { next = 0; tree = D.empty }
830   let alloc (value : value) (d : dict) =
831     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
832   let read (key : ref) (d : dict) =
833     D.find key d.tree
834   let write (key : ref) (value : value) (d : dict) =
835     { next = d.next; tree = D.add key value d.tree }
836   module Base = struct
837     type 'a m = dict -> 'a * dict
838     let unit a = fun s -> (a, s)
839     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
840     type 'a result = 'a
841     let run u = fst (u empty)
842     type 'a result_exn = 'a
843     let run_exn = run
844   end
845   include Monad.Make(Base)
846   let newref value = fun s -> alloc value s
847   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
848   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
849   module T(Wrapped : Monad.S) = struct
850     module Trans = struct
851       module Wrapped = Wrapped
852       type 'a m = dict -> ('a * dict) Wrapped.m
853       let elevate w = fun s ->
854         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
855       let bind u f = fun s ->
856         Wrapped.bind (u s) (fun (a, s') -> f a s')
857       type 'a result = 'a Wrapped.result
858       let run u =
859         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
860         in Wrapped.run w
861       type 'a result_exn = 'a Wrapped.result_exn
862       let run_exn u =
863         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
864         in Wrapped.run_exn w
865     end
866     include Monad.MakeT(Trans)
867     let newref value = fun s -> Wrapped.unit (alloc value s)
868     let deref key = fun s -> Wrapped.unit (read key s, s)
869     let change key value = fun s -> Wrapped.unit ((), write key value s)
870   end
871   module TP(Wrapped : Monad.P) = struct
872     module TransP = struct
873       include T(Wrapped)
874       let plus u v = fun s -> Wrapped.plus (u s) (v s)
875       let zero () = elevate (Wrapped.zero ())
876     end
877     include TransP
878     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
879   end
880 end
881
882
883 (* must be parameterized on (struct type log = ... end) *)
884 module Writer_monad(Log : sig
885   type log
886   val zero : log
887   val plus : log -> log -> log
888 end) : sig
889   (* declare additional operations, while still hiding implementation of type m *)
890   type log = Log.log
891   type 'a result = 'a * log
892   type 'a result_exn = 'a * log
893   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
894   val tell : log -> unit m
895   val listen : 'a m -> ('a * log) m
896   val listens : (log -> 'b) -> 'a m -> ('a * 'b) m
897   (* val pass : ('a * (log -> log)) m -> 'a m *)
898   val censor : (log -> log) -> 'a m -> 'a m
899 end = struct
900   type log = Log.log
901   module Base = struct
902     type 'a m = 'a * log
903     let unit a = (a, Log.zero)
904     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
905     type 'a result = 'a * log
906     let run u = u
907     type 'a result_exn = 'a * log
908     let run_exn = run
909   end
910   include Monad.Make(Base)
911   let tell entries = ((), entries) (* add entries to log *)
912   let listen (a, w) = ((a, w), w)
913   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
914   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
915   let censor f u = pass (u >>= fun a -> unit (a, f))
916 end
917
918 (* pre-define simple Writer *)
919 module Writer1 = Writer_monad(struct
920   type log = string
921   let zero = ""
922   let plus s1 s2 = s1 ^ "\n" ^ s2
923 end)
924
925 (* slightly more efficient Writer *)
926 module Writer2 = struct
927   include Writer_monad(struct
928     type log = string list
929     let zero = []
930     let plus w w' = Util.append w' w
931   end)
932   let tell_string s = tell [s]
933   let tell entries = tell (Util.reverse entries)
934   let run u = let (a, w) = run u in (a, Util.reverse w)
935   let run_exn = run
936 end
937
938
939 module IO_monad : sig
940   (* declare additional operation, while still hiding implementation of type m *)
941   type 'a result = 'a
942   type 'a result_exn = 'a
943   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
944   val printf : ('a, unit, string, unit m) format4 -> 'a
945   val print_string : string -> unit m
946   val print_int : int -> unit m
947   val print_hex : int -> unit m
948   val print_bool : bool -> unit m
949 end = struct
950   module Base = struct
951     type 'a m = { run : unit -> unit; value : 'a }
952     let unit a = { run = (fun () -> ()); value = a }
953     let bind (a : 'a m) (f: 'a -> 'b m) : 'b m =
954      let fres = f a.value in
955        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
956     type 'a result = 'a
957     let run a = let () = a.run () in a.value
958     type 'a result_exn = 'a
959     let run_exn = run
960   end
961   include Monad.Make(Base)
962   let printf fmt =
963     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
964   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
965   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
966   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
967   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
968 end
969
970 module Continuation_monad : sig
971   (* expose only the implementation of type `('r,'a) result` *)
972   type 'a m
973   type 'a result = 'a m
974   type 'a result_exn = 'a m
975   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn and type 'a m := 'a m
976   (* val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m *)
977   (* misses that the answer types of all the cont's must be the same *)
978   val callcc : (('a -> 'b m) -> 'a m) -> 'a m
979   (* val reset : ('a,'a) m -> ('r,'a) m *)
980   val reset : 'a m -> 'a m
981   (* val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m *)
982   (* misses that the answer types of second and third continuations must be b *)
983   val shift : (('a -> 'b m) -> 'b m) -> 'a m
984   (* overwrite the run declaration in S, because I can't declare 'a result =
985    * this polymorphic type (complains that 'r is unbound *)
986   val runk : 'a m -> ('a -> 'r) -> 'r
987   val run0 : 'a m -> 'a
988 end = struct
989   let id = fun i -> i
990   module Base = struct
991     (* 'r is result type of whole computation *)
992     type 'a m = { cont : 'r. ('a -> 'r) -> 'r }
993     let unit a =
994       let cont : 'r. ('a -> 'r) -> 'r =
995         fun k -> k a
996       in { cont }
997     let bind u f =
998       let cont : 'r. ('a -> 'r) -> 'r =
999         fun k -> u.cont (fun a -> (f a).cont k)
1000       in { cont }
1001     type 'a result = 'a m
1002     let run (u : 'a m) : 'a result = u
1003     type 'a result_exn = 'a m
1004     let run_exn (u : 'a m) : 'a result_exn = u
1005     let callcc f =
1006       let cont : 'r. ('a -> 'r) -> 'r =
1007           (* Can't figure out how to make the type polymorphic enough
1008            * to satisfy the OCaml type-checker (it's ('a -> 'r) -> 'r
1009            * instead of 'r. ('a -> 'r) -> 'r); so we have to fudge
1010            * with Obj.magic... which tells OCaml's type checker to
1011            * relax, the supplied value has whatever type the context
1012            * needs it to have. *)
1013           fun k ->
1014           let usek a = { cont = Obj.magic (fun _ -> k a) }
1015           in (f usek).cont k
1016       in { cont }
1017     let reset u = unit (u.cont id)
1018     let shift (f : ('a -> 'b m) -> 'b m) : 'a m =
1019       let cont = fun k ->
1020         (f (fun a -> unit (k a))).cont id
1021       in { cont = Obj.magic cont }
1022     let runk u k = (u.cont : ('a -> 'r) -> 'r) k
1023     let run0 u = runk u id
1024   end
1025   include Monad.Make(Base)
1026   let callcc = Base.callcc
1027   let reset = Base.reset
1028   let shift = Base.shift
1029   let runk = Base.runk
1030   let run0 = Base.run0
1031 end
1032
1033 (*
1034 (* This two-type parameter version works without Obj.magic *)
1035
1036 module Continuation_monad2 : sig
1037   (* expose only the implementation of type `('r,'a) result` *)
1038   type ('r,'a) result = ('a -> 'r) -> 'r
1039   type ('r,'a) result_exn = ('a -> 'r) -> 'r
1040   include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn
1041   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
1042   val reset : ('a,'a) m -> ('r,'a) m
1043   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
1044
1045 end = struct
1046   let id = fun i -> i
1047   module Base = struct
1048     (* 'r is result type of whole computation *)
1049     type ('r,'a) m = ('a -> 'r) -> 'r
1050     let unit a = fun k -> k a
1051     let bind u f = fun k -> u (fun a -> (f a) k)
1052     type ('r,'a) result = ('a -> 'r) -> 'r
1053     let run u = u
1054     type ('r,'a) result_exn = ('a -> 'r) -> 'r
1055     let run_exn = run
1056   end
1057   include Monad.Make2(Base)
1058   let callcc f = fun k ->
1059     let usek a = fun _ -> k a
1060     in f usek k
1061   (*
1062   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
1063   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
1064   let callcc f = fun k -> f k k
1065   let throw k a = fun _ -> k a
1066   *)
1067   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right *)
1068   let reset u = unit (u id)
1069   let shift u = fun k -> u (fun a -> unit (k a)) id
1070 end
1071  *)
1072
1073
1074 (*
1075  * Scheme:
1076  * (define (example n)
1077  *    (let ([u (let/cc k ; type int -> int pair
1078  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
1079  *                 (+ 1 (car v))))]) ; int
1080  *      (cons u 0))) ; int pair
1081  * ; (example 10) ~~> '(111 . 0)
1082  * ; (example -10) ~~> '(0 . 0)
1083  *
1084  * OCaml monads:
1085  * let example n : (int * int) =
1086  *   Continuation_monad.(let u = callcc (fun k ->
1087  *       (if n < 0 then k 0 else unit [n + 100])
1088  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
1089  *       >>= fun [x] -> unit (x + 1)
1090  *   )
1091  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1092  *   >>= fun x -> unit (x, 0)
1093  *   in run u)
1094  *
1095  *
1096  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1097  * let example1 () : int =
1098  *   Continuation_monad.(let v = reset (
1099  *       let u = shift (fun k -> unit (10 + 1))
1100  *       in u >>= fun x -> unit (100 + x)
1101  *     ) in let w = v >>= fun x -> unit (1000 + x)
1102  *     in run w)
1103  *
1104  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1105  * let example2 () =
1106  *   Continuation_monad.(let v = reset (
1107  *       let u = shift (fun k -> k (10 :: [1]))
1108  *       in u >>= fun x -> unit (100 :: x)
1109  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1110  *     in run w)
1111  *
1112  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1113  * let example3 () =
1114  *   Continuation_monad.(let v = reset (
1115  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1116  *       in u >>= fun x -> unit (100 :: x)
1117  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1118  *     in run w)
1119  *
1120  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1121  * (* not sure if this example can be typed without a sum-type *)
1122  *
1123  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1124  * let example5 () : int =
1125  *   Continuation_monad.(let v = reset (
1126  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
1127  *       in u >>= fun x -> unit (10 + x)
1128  *     ) in let w = v >>= fun x -> unit (100 + x)
1129  *     in run w)
1130  *
1131  *)
1132
1133
1134 module Leaf_monad : sig
1135   (* We implement the type as `'a tree option` because it has a natural`plus`,
1136    * and the rest of the library expects that `plus` and `zero` will come together. *)
1137   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1138   type 'a result = 'a tree option
1139   type 'a result_exn = 'a tree
1140   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1141   include Monad.PLUS with type 'a m := 'a m
1142   (* LeafT transformer *)
1143   module T : functor (Wrapped : Monad.S) -> sig
1144     type 'a result = 'a tree option Wrapped.result
1145     type 'a result_exn = 'a tree Wrapped.result_exn
1146     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1147     include Monad.PLUS with type 'a m := 'a m
1148     val elevate : 'a Wrapped.m -> 'a m
1149     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1150     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1151     val distribute : ('a -> 'b Wrapped.m) -> 'a tree option -> 'b m
1152   end
1153 end = struct
1154   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1155   (* uses supplied plus and zero to copy t to its image under f *)
1156   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1157       | None -> zero ()
1158       | Some ts -> let rec loop ts = (match ts with
1159                      | Leaf a -> f a
1160                      | Node (l, r) ->
1161                          (* recursive application of f may delete a branch *)
1162                          plus (loop l) (loop r)
1163                    ) in loop ts
1164   module Base = struct
1165     type 'a m = 'a tree option
1166     let unit a = Some (Leaf a)
1167     let zero () = None
1168     let plus u v = match (u, v) with
1169       | None, _ -> v
1170       | _, None -> u
1171       | Some us, Some vs -> Some (Node (us, vs))
1172     let bind u f = mapT f u zero plus
1173     type 'a result = 'a tree option
1174     let run u = u
1175     type 'a result_exn = 'a tree
1176     let run_exn u = match u with
1177       | None -> failwith "no values"
1178       (*
1179       | Some (Leaf a) -> a
1180       | many -> failwith "multiple values"
1181       *)
1182       | Some us -> us
1183   end
1184   include Monad.Make(Base)
1185   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
1186   let base_plus = plus
1187   let base_lift = lift
1188   module T(Wrapped : Monad.S) = struct
1189     module Trans = struct
1190       let zero () = Wrapped.unit None
1191       let plus u v =
1192         Wrapped.bind u (fun us ->
1193         Wrapped.bind v (fun vs ->
1194         Wrapped.unit (base_plus us vs)))
1195       include Monad.MakeT(struct
1196         module Wrapped = Wrapped
1197         type 'a m = 'a Base.m Wrapped.m
1198         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1199         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1200         type 'a result = 'a tree option Wrapped.result
1201         let run u = Wrapped.run u
1202         type 'a result_exn = 'a tree Wrapped.result_exn
1203         let run_exn u =
1204             let w = Wrapped.bind u (fun t -> match t with
1205               | None -> failwith "no values"
1206               | Some ts -> Wrapped.unit ts)
1207             in Wrapped.run_exn w
1208       end)
1209     end
1210     include Trans
1211     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
1212     (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
1213     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1214   end
1215 end
1216
1217
1218 module L = List_monad;;
1219 module R = Reader_monad(struct type env = int -> int end);;
1220 module S = State_monad(struct type store = int end);;
1221 module T = Leaf_monad;;
1222 module LR = L.T(R);;
1223 module LS = L.T(S);;
1224 module TL = T.T(L);;
1225 module TR = T.T(R);;
1226 module TS = T.T(S);;
1227
1228 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1229
1230 (*
1231 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1232 TS.run ts 0;;
1233 (*
1234 - : int T.tree option * S.store =
1235 (Some
1236   (T.Node
1237     (T.Node (T.Leaf 2, T.Leaf 3),
1238      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1239  5)
1240 *)
1241
1242 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1243 TS.run_exn ts2 0;;
1244 (*
1245 - : (int * S.store) T.tree option * S.store =
1246 (Some
1247   (T.Node
1248     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1249      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1250  5)
1251 *)
1252
1253 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1254 TR.run_exn tr (fun i -> i+i);;
1255 (*
1256 - : int T.tree option =
1257 Some
1258  (T.Node
1259    (T.Node (T.Leaf 4, T.Leaf 6),
1260     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1261 *)
1262
1263 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1264 TL.run_exn tl;;
1265 (*
1266 - : (int * int) TL.result =
1267 [Some
1268   (T.Node
1269     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1270      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1271 *)
1272
1273 let l2 = [1;2;3;4;5];;
1274 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1275
1276 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1277 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1278
1279 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1280 (*
1281 int T.tree option =
1282 Some
1283  (T.Node
1284    (T.Node (T.Leaf 10, T.Leaf 11),
1285     T.Node
1286      (T.Node
1287        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1288         T.Node (T.Leaf 40, T.Leaf 41)),
1289       T.Node (T.Leaf 50, T.Leaf 51))))
1290  *)
1291
1292 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1293 (*
1294 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1295 *)
1296
1297 *)
1298
1299 let id : 'z. 'z -> 'z = fun x -> x
1300
1301 let example n : (int * int) =
1302   Continuation_monad.(let u = callcc (fun k ->
1303       (if n < 0 then k 0 else unit [n + 100])
1304       (* all of the following is skipped by k 0; the end type int is k's input type *)
1305       >>= fun [x] -> unit (x + 1)
1306   )
1307   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1308   >>= fun x -> unit (x, 0)
1309   in run0 u)
1310
1311
1312 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1313 let example1 () : int =
1314   Continuation_monad.(let v = reset (
1315       let u = shift (fun k -> unit (10 + 1))
1316       in u >>= fun x -> unit (100 + x)
1317     ) in let w = v >>= fun x -> unit (1000 + x)
1318     in run0 w)
1319
1320 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1321 let example2 () =
1322   Continuation_monad.(let v = reset (
1323       let u = shift (fun k -> k (10 :: [1]))
1324       in u >>= fun x -> unit (100 :: x)
1325     ) in let w = v >>= fun x -> unit (1000 :: x)
1326     in run0 w)
1327
1328 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1329 let example3 () =
1330   Continuation_monad.(let v = reset (
1331       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1332       in u >>= fun x -> unit (100 :: x)
1333     ) in let w = v >>= fun x -> unit (1000 :: x)
1334     in run0 w)
1335
1336 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1337 (* not sure if this example can be typed without a sum-type *)
1338
1339 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1340 let example5 () : int =
1341   Continuation_monad.(let v = reset (
1342       let u = shift (fun k -> k 1 >>= fun x -> k x)
1343       in u >>= fun x -> unit (10 + x)
1344     ) in let w = v >>= fun x -> unit (100 + x)
1345     in run0 w)
1346
1347
1348 ;;
1349
1350 (1011, 1111, 1111, 121);;
1351 (example1(), example2(), example3(), example5());;
1352 ((111,0), (0,0));;
1353 (example ~+10, example ~-10);;
1354
1355 module C = Continuation_monad
1356 module TC = T.T(C)
1357
1358 let testc df ic =
1359     C.runk TC.(run_exn (distribute df t1)) ic;;
1360
1361
1362 (*
1363 (* do nothing *)
1364 let initial_continuation = fun t -> t in
1365 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1366 *)
1367 testc (C.unit) id;;
1368
1369 (*
1370 (* count leaves, using continuation *)
1371 let initial_continuation = fun t -> 0 in
1372 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1373 *)
1374
1375 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1376
1377 (*
1378 (* convert tree to list of leaves *)
1379 let initial_continuation = fun t -> [] in
1380 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1381 *)
1382
1383 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1384
1385 (*
1386 (* square each leaf using continuation *)
1387 let initial_continuation = fun t -> t in
1388 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1389 *)
1390
1391 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1392
1393
1394 (*
1395 (* replace leaves with list, using continuation *)
1396 let initial_continuation = fun t -> t in
1397 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1398 *)
1399
1400 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1401