tweak monads-lib
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  *)
46
47
48 (* Some library functions used below. *)
49 module Util = struct
50   let fold_right = List.fold_right
51   let map = List.map
52   let append = List.append
53   let reverse = List.rev
54   let concat = List.concat
55   let concat_map f lst = List.concat (List.map f lst)
56   (* let zip = List.combine *)
57   let unzip = List.split
58   let zip_with = List.map2
59   let replicate len fill =
60     let rec loop n accu =
61       if n == 0 then accu else loop (pred n) (fill :: accu)
62     in loop len []
63 end
64
65
66
67 (*
68  * This module contains factories that extend a base set of
69  * monadic definitions with a larger family of standard derived values.
70  *)
71
72 module Monad = struct
73   (*
74    * Signature extenders:
75    *   Make :: BASE -> S
76    *     MakeCatch, MakeDistrib :: PLUSBASE -> PLUS
77    *                                           which merges into S
78    *                                           (P is merged sig)
79    *   MakeT :: TRANS (with Wrapped : S or P) -> custom sig
80    *
81    *   Make2 :: BASE2 -> S2
82    *     MakeCatch2, MakeDistrib2 :: PLUSBASE2 -> PLUS2 (P2 is merged sig)
83    *   to wrap double-typed inner monads:
84    *   MakeT2 :: TRANS2 (with Wrapped : S2 or P2) -> custom sig
85    *
86    *)
87
88
89
90   (* type of base definitions *)
91   module type BASE = sig
92     (* The only constraints we impose here on how the monadic type
93      * is implemented is that it have a single type parameter 'a. *)
94     type 'a m
95     type 'a result
96     type 'a result_exn
97     val unit : 'a -> 'a m
98     val bind : 'a m -> ('a -> 'b m) -> 'b m
99     val run : 'a m -> 'a result
100     (* run_exn tries to provide a more ground-level result, but may fail *)
101     val run_exn : 'a m -> 'a result_exn
102   end
103   module type S = sig
104     include BASE
105     val (>>=) : 'a m -> ('a -> 'b m) -> 'b m
106     val (>>) : 'a m -> 'b m -> 'b m
107     val join : ('a m) m -> 'a m
108     val apply : ('a -> 'b) m -> 'a m -> 'b m
109     val lift : ('a -> 'b) -> 'a m -> 'b m
110     val lift2 :  ('a -> 'b -> 'c) -> 'a m -> 'b m -> 'c m
111     val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m
112     val do_when :  bool -> unit m -> unit m
113     val do_unless :  bool -> unit m -> unit m
114     val forever : (unit -> 'a m) -> 'b m
115     val sequence : 'a m list -> 'a list m
116     val sequence_ : 'a m list -> unit m
117   end
118
119   (* Standard, single-type-parameter monads. *)
120   module Make(B : BASE) : S with type 'a m = 'a B.m and type 'a result = 'a B.result and type 'a result_exn = 'a B.result_exn = struct
121     include B
122     let (>>=) = bind
123     let (>>) u v = u >>= fun _ -> v
124     let lift f u = u >>= fun a -> unit (f a)
125     (* lift is called listM, fmap, and <$> in Haskell *)
126     let join uu = uu >>= fun u -> u
127     (* u >>= f === join (lift f u) *)
128     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
129     (* [f] <*> [x1,x2] = [f x1,f x2] *)
130     (* let apply u v = u >>= fun f -> lift f v *)
131     (* let apply = lift2 id *)
132     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
133     (* let lift f u === apply (unit f) u *)
134     (* let lift2 f u v = apply (lift f u) v *)
135     let (>=>) f g = fun a -> f a >>= g
136     let do_when test u = if test then u else unit ()
137     let do_unless test u = if test then unit () else u
138     let forever uthunk =
139         let rec loop () = uthunk () >>= fun _ -> loop ()
140         in loop ()
141     let sequence ms =
142       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
143         Util.fold_right op ms (unit [])
144     let sequence_ ms =
145       Util.fold_right (>>) ms (unit ())
146
147     (* Haskell defines these other operations combining lists and monads.
148      * We don't, but notice that M.mapM == ListT(M).distribute
149      * There's also a parallel TreeT(M).distribute *)
150     (*
151     let mapM f alist = sequence (Util.map f alist)
152     let mapM_ f alist = sequence_ (Util.map f alist)
153     let rec filterM f lst = match lst with
154       | [] -> unit []
155       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
156     let forM alist f = mapM f alist
157     let forM_ alist f = mapM_ f alist
158     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
159     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
160     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
161     let rec foldM f z lst = match lst with
162       | [] -> unit z
163       | x::xs -> f z x >>= fun z' -> foldM f z' xs
164     let foldM_ f z xs = foldM f z xs >> unit ()
165     let replicateM n x = sequence (Util.replicate n x)
166     let replicateM_ n x = sequence_ (Util.replicate n x)
167     *)
168   end
169
170   (* Single-type-parameter monads that also define `plus` and `zero`
171    * operations. These obey the following laws:
172    *     zero >>= f   ===  zero
173    *     plus zero u  ===  u
174    *     plus u zero  ===  u
175    * Additionally, these monads will obey one of the following laws:
176    *     (Catch)   plus (unit a) v  ===  unit a
177    *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
178    *)
179   module type PLUSBASE = sig
180     include BASE
181     val zero : unit -> 'a m
182     val plus : 'a m -> 'a m -> 'a m
183   end
184   module type PLUS = sig
185     type 'a m
186     val zero : unit -> 'a m
187     val plus : 'a m -> 'a m -> 'a m
188     val guard : bool -> unit m
189     val sum : 'a m list -> 'a m
190   end
191   (* MakeCatch and MakeDistrib have the same implementation; we just declare
192    * them twice to document which laws the client code is promising to honor. *)
193   module MakeCatch(B : PLUSBASE) : PLUS with type 'a m = 'a B.m = struct
194     type 'a m = 'a B.m
195     let zero = B.zero
196     let plus = B.plus
197     let guard test = if test then B.unit () else zero ()
198     let sum ms = Util.fold_right plus ms (zero ())
199   end
200   module MakeDistrib = MakeCatch
201
202   (* Signatures for MonadT *)
203   (* sig for Wrapped that include S and PLUS *)
204   module type P = sig
205     include S
206     include PLUS with type 'a m := 'a m
207   end
208   module type TRANS = sig
209     module Wrapped : S
210     type 'a m
211     type 'a result
212     type 'a result_exn
213     val bind : 'a m -> ('a -> 'b m) -> 'b m
214     val run : 'a m -> 'a result
215     val run_exn : 'a m -> 'a result_exn
216     val elevate : 'a Wrapped.m -> 'a m
217     (* lift/elevate laws:
218      *     elevate (W.unit a) == unit a
219      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
220      *)
221   end
222   module MakeT(T : TRANS) = struct
223     include Make(struct
224         include T
225         let unit a = elevate (Wrapped.unit a)
226     end)
227     let elevate = T.elevate
228   end
229
230
231   (* We have to define BASE, S, and Make again for double-type-parameter monads. *)
232   module type BASE2 = sig
233     type ('x,'a) m
234     type ('x,'a) result
235     type ('x,'a) result_exn
236     val unit : 'a -> ('x,'a) m
237     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
238     val run : ('x,'a) m -> ('x,'a) result
239     val run_exn : ('x,'a) m -> ('x,'a) result_exn
240   end
241   module type S2 = sig
242     include BASE2
243     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
244     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
245     val join : ('x,('x,'a) m) m -> ('x,'a) m
246     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
247     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
248     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
249     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
250     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
251     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
252     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
253     val sequence : ('x,'a) m list -> ('x,'a list) m
254     val sequence_ : ('x,'a) m list -> ('x,unit) m
255   end
256   module Make2(B : BASE2) : S2 with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
257     (* code repetition, ugh *)
258     include B
259     let (>>=) = bind
260     let (>>) u v = u >>= fun _ -> v
261     let lift f u = u >>= fun a -> unit (f a)
262     let join uu = uu >>= fun u -> u
263     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
264     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
265     let (>=>) f g = fun a -> f a >>= g
266     let do_when test u = if test then u else unit ()
267     let do_unless test u = if test then unit () else u
268     let forever uthunk =
269         let rec loop () = uthunk () >>= fun _ -> loop ()
270         in loop ()
271     let sequence ms =
272       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
273         Util.fold_right op ms (unit [])
274     let sequence_ ms =
275       Util.fold_right (>>) ms (unit ())
276   end
277
278   module type PLUSBASE2 = sig
279     include BASE2
280     val zero : unit -> ('x,'a) m
281     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
282   end
283   module type PLUS2 = sig
284     type ('x,'a) m
285     val zero : unit -> ('x,'a) m
286     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
287     val guard : bool -> ('x,unit) m
288     val sum : ('x,'a) m list -> ('x,'a) m
289   end
290   module MakeCatch2(B : PLUSBASE2) : PLUS2 with type ('x,'a) m = ('x,'a) B.m = struct
291     type ('x,'a) m = ('x,'a) B.m
292     (* code repetition, ugh *)
293     let zero = B.zero
294     let plus = B.plus
295     let guard test = if test then B.unit () else zero ()
296     let sum ms = Util.fold_right plus ms (zero ())
297   end
298   module MakeDistrib2 = MakeCatch2
299
300   (* Signatures for MonadT *)
301   (* sig for Wrapped that include S and PLUS *)
302   module type P2 = sig
303     include S2
304     include PLUS2 with type ('x,'a) m := ('x,'a) m
305   end
306   module type TRANS2 = sig
307     module Wrapped : S2
308     type ('x,'a) m
309     type ('x,'a) result
310     type ('x,'a) result_exn
311     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
312     val run : ('x,'a) m -> ('x,'a) result
313     val run_exn : ('x,'a) m -> ('x,'a) result_exn
314     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
315   end
316   module MakeT2(T : TRANS2) = struct
317     (* code repetition, ugh *)
318     include Make2(struct
319         include T
320         let unit a = elevate (Wrapped.unit a)
321     end)
322     let elevate = T.elevate
323   end
324
325 end
326
327
328
329
330
331 module Identity_monad : sig
332   (* expose only the implementation of type `'a result` *)
333   type 'a result = 'a
334   type 'a result_exn = 'a
335   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
336 end = struct
337   module Base = struct
338     type 'a m = 'a
339     type 'a result = 'a
340     type 'a result_exn = 'a
341     let unit a = a
342     let bind a f = f a
343     let run a = a
344     let run_exn a = a
345   end
346   include Monad.Make(Base)
347 end
348
349
350 module Maybe_monad : sig
351   (* expose only the implementation of type `'a result` *)
352   type 'a result = 'a option
353   type 'a result_exn = 'a
354   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
355   include Monad.PLUS with type 'a m := 'a m
356   (* MaybeT transformer *)
357   module T : functor (Wrapped : Monad.S) -> sig
358     type 'a result = 'a option Wrapped.result
359     type 'a result_exn = 'a Wrapped.result_exn
360     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
361     include Monad.PLUS with type 'a m := 'a m
362     val elevate : 'a Wrapped.m -> 'a m
363   end
364   module T2 : functor (Wrapped : Monad.S2) -> sig
365     type ('x,'a) result = ('x,'a option) Wrapped.result
366     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
367     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
368     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
369     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
370   end
371 end = struct
372   module Base = struct
373    type 'a m = 'a option
374    type 'a result = 'a option
375    type 'a result_exn = 'a
376    let unit a = Some a
377    let bind u f = match u with Some a -> f a | None -> None
378    let run u = u
379    let run_exn u = match u with
380      | Some a -> a
381      | None -> failwith "no value"
382    let zero () = None
383    let plus u v = match u with None -> v | _ -> u
384   end
385   include Monad.Make(Base)
386   include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m)
387   module T(Wrapped : Monad.S) = struct
388     module Trans = struct
389       include Monad.MakeT(struct
390         module Wrapped = Wrapped
391         type 'a m = 'a option Wrapped.m
392         type 'a result = 'a option Wrapped.result
393         type 'a result_exn = 'a Wrapped.result_exn
394         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
395         let bind u f = Wrapped.bind u (fun t -> match t with
396           | Some a -> f a
397           | None -> Wrapped.unit None)
398         let run u = Wrapped.run u
399         let run_exn u =
400           let w = Wrapped.bind u (fun t -> match t with
401             | Some a -> Wrapped.unit a
402             | None -> failwith "no value")
403           in Wrapped.run_exn w
404       end)
405       let zero () = Wrapped.unit None
406       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
407     end
408     include Trans
409     include (Monad.MakeCatch(Trans) : Monad.PLUS with type 'a m := 'a m)
410   end
411   module T2(Wrapped : Monad.S2) = struct
412     module Trans = struct
413       include Monad.MakeT2(struct
414         module Wrapped = Wrapped
415         type ('x,'a) m = ('x,'a option) Wrapped.m
416         type ('x,'a) result = ('x,'a option) Wrapped.result
417         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
418         (* code repetition, ugh *)
419         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
420         let bind u f = Wrapped.bind u (fun t -> match t with
421           | Some a -> f a
422           | None -> Wrapped.unit None)
423         let run u = Wrapped.run u
424         let run_exn u =
425           let w = Wrapped.bind u (fun t -> match t with
426             | Some a -> Wrapped.unit a
427             | None -> failwith "no value")
428           in Wrapped.run_exn w
429       end)
430       let zero () = Wrapped.unit None
431       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
432     end
433     include Trans
434     include (Monad.MakeCatch2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
435   end
436 end
437
438
439 module List_monad : sig
440   (* declare additional operation, while still hiding implementation of type m *)
441   type 'a result = 'a list
442   type 'a result_exn = 'a
443   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
444   include Monad.PLUS with type 'a m := 'a m
445   val permute : 'a m -> 'a m m
446   val select : 'a m -> ('a * 'a m) m
447   (* ListT transformer *)
448   module T : functor (Wrapped : Monad.S) -> sig
449     type 'a result = 'a list Wrapped.result
450     type 'a result_exn = 'a Wrapped.result_exn
451     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
452     include Monad.PLUS with type 'a m := 'a m
453     val elevate : 'a Wrapped.m -> 'a m
454     (* note that second argument is an 'a list, not the more abstract 'a m *)
455     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
456     val distribute : ('a -> 'b Wrapped.m) -> 'a list -> 'b m
457 (* TODO
458     val permute : 'a m -> 'a m m
459     val select : 'a m -> ('a * 'a m) m
460 *)
461   end
462   module T2 : functor (Wrapped : Monad.S2) -> sig
463     type ('x,'a) result = ('x,'a list) Wrapped.result
464     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
465     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
466     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
467     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
468     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
469   end
470 end = struct
471   module Base = struct
472    type 'a m = 'a list
473    type 'a result = 'a list
474    type 'a result_exn = 'a
475    let unit a = [a]
476    let bind u f = Util.concat_map f u
477    let run u = u
478    let run_exn u = match u with
479      | [] -> failwith "no values"
480      | [a] -> a
481      | many -> failwith "multiple values"
482    let zero () = []
483    let plus = Util.append
484   end
485   include Monad.Make(Base)
486   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
487   (* let either u v = plus u v *)
488   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
489   let rec insert a u =
490     plus (unit (a :: u)) (match u with
491         | [] -> zero ()
492         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
493     )
494   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
495   let rec permute u = match u with
496       | [] -> unit []
497       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
498   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
499   let rec select u = match u with
500     | [] -> zero ()
501     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
502   let base_plus = plus
503   module T(Wrapped : Monad.S) = struct
504     module Trans = struct
505       (* Wrapped.sequence ms  ===  
506            let plus1 u v =
507              Wrapped.bind u (fun x ->
508              Wrapped.bind v (fun xs ->
509              Wrapped.unit (x :: xs)))
510            in Util.fold_right plus1 ms (Wrapped.unit []) *)
511       (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
512       let distribute f alist = Wrapped.sequence (Util.map f alist)
513       include Monad.MakeT(struct
514         module Wrapped = Wrapped
515         type 'a m = 'a list Wrapped.m
516         type 'a result = 'a list Wrapped.result
517         type 'a result_exn = 'a Wrapped.result_exn
518         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
519         let bind u f =
520           Wrapped.bind u (fun ts ->
521           Wrapped.bind (distribute f ts) (fun tts ->
522           Wrapped.unit (Util.concat tts)))
523         let run u = Wrapped.run u
524         let run_exn u =
525           let w = Wrapped.bind u (fun ts -> match ts with
526             | [] -> failwith "no values"
527             | [a] -> Wrapped.unit a
528             | many -> failwith "multiple values"
529           ) in Wrapped.run_exn w
530       end)
531       let zero () = Wrapped.unit []
532       let plus u v =
533         Wrapped.bind u (fun us ->
534         Wrapped.bind v (fun vs ->
535         Wrapped.unit (base_plus us vs)))
536     end
537     include Trans
538     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
539 (*
540     let permute : 'a m -> 'a m m
541     let select : 'a m -> ('a * 'a m) m
542 *)
543   end
544   module T2(Wrapped : Monad.S2) = struct
545     module Trans = struct
546       let distribute f alist = Wrapped.sequence (Util.map f alist)
547       include Monad.MakeT2(struct
548         module Wrapped = Wrapped
549         type ('x,'a) m = ('x,'a list) Wrapped.m
550         type ('x,'a) result = ('x,'a list) Wrapped.result
551         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
552         (* code repetition, ugh *)
553         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
554         let bind u f =
555           Wrapped.bind u (fun ts ->
556           Wrapped.bind (distribute f ts) (fun tts ->
557           Wrapped.unit (Util.concat tts)))
558         let run u = Wrapped.run u
559         let run_exn u =
560           let w = Wrapped.bind u (fun ts -> match ts with
561             | [] -> failwith "no values"
562             | [a] -> Wrapped.unit a
563             | many -> failwith "multiple values"
564           ) in Wrapped.run_exn w
565       end)
566       let zero () = Wrapped.unit []
567       let plus u v =
568         Wrapped.bind u (fun us ->
569         Wrapped.bind v (fun vs ->
570         Wrapped.unit (base_plus us vs)))
571     end
572     include Trans
573     include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
574   end
575 end
576
577
578 (* must be parameterized on (struct type err = ... end) *)
579 module Error_monad(Err : sig
580   type err
581   exception Exc of err
582   (*
583   val zero : unit -> err
584   val plus : err -> err -> err
585   *)
586 end) : sig
587   (* declare additional operations, while still hiding implementation of type m *)
588   type err = Err.err
589   type 'a error = Error of err | Success of 'a
590   type 'a result = 'a
591   type 'a result_exn = 'a
592   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
593   (* include Monad.PLUS with type 'a m := 'a m *)
594   val throw : err -> 'a m
595   val catch : 'a m -> (err -> 'a m) -> 'a m
596   (* ErrorT transformer *)
597   module T : functor (Wrapped : Monad.S) -> sig
598     type 'a result = 'a Wrapped.result
599     type 'a result_exn = 'a Wrapped.result_exn
600     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
601     val elevate : 'a Wrapped.m -> 'a m
602     val throw : err -> 'a m
603     val catch : 'a m -> (err -> 'a m) -> 'a m
604   end
605   module T2 : functor (Wrapped : Monad.S2) -> sig
606     type ('x,'a) result = ('x,'a) Wrapped.result
607     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
608     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
609     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
610     val throw : err -> ('x,'a) m
611     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
612   end
613 end = struct
614   type err = Err.err
615   type 'a error = Error of err | Success of 'a
616   module Base = struct
617     type 'a m = 'a error
618     type 'a result = 'a
619     type 'a result_exn = 'a
620     let unit a = Success a
621     let bind u f = match u with
622       | Success a -> f a
623       | Error e -> Error e (* input and output may be of different 'a types *)
624     (* TODO: should run refrain from failing? *)
625     let run u = match u with
626       | Success a -> a
627       | Error e -> raise (Err.Exc e)
628     let run_exn = run
629     (*
630     let zero () = Error Err.zero
631     let plus u v = match (u, v) with
632       | Success _, _ -> u
633       (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
634        * otherwise, plus (Error _) v = v *)
635       | Error _, _ when v = zero -> u
636       (* combine errors *)
637       | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
638       | Error _, _ -> v
639     *)
640   end
641   include Monad.Make(Base)
642   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
643   let throw e = Error e
644   let catch u handler = match u with
645     | Success _ -> u
646     | Error e -> handler e
647   module T(Wrapped : Monad.S) = struct
648     module Trans = struct
649       module Wrapped = Wrapped
650       type 'a m = 'a error Wrapped.m
651       type 'a result = 'a Wrapped.result
652       type 'a result_exn = 'a Wrapped.result_exn
653       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
654       let bind u f = Wrapped.bind u (fun t -> match t with
655         | Success a -> f a
656         | Error e -> Wrapped.unit (Error e))
657       (* TODO: should run refrain from failing? *)
658       let run u =
659         let w = Wrapped.bind u (fun t -> match t with
660           | Success a -> Wrapped.unit a
661           (* | _ -> Wrapped.fail () *)
662           | Error e -> raise (Err.Exc e))
663         in Wrapped.run w
664       let run_exn u =
665         let w = Wrapped.bind u (fun t -> match t with
666           | Success a -> Wrapped.unit a
667           (* | _ -> Wrapped.fail () *)
668           | Error e -> raise (Err.Exc e))
669         in Wrapped.run_exn w
670     end
671     include Monad.MakeT(Trans)
672     let throw e = Wrapped.unit (Error e)
673     let catch u handler = Wrapped.bind u (fun t -> match t with
674       | Success _ -> Wrapped.unit t
675       | Error e -> handler e)
676   end
677   module T2(Wrapped : Monad.S2) = struct
678     module Trans = struct
679       module Wrapped = Wrapped
680       type ('x,'a) m = ('x,'a error) Wrapped.m
681       type ('x,'a) result = ('x,'a) Wrapped.result
682       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
683       (* code repetition, ugh *)
684       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
685       let bind u f = Wrapped.bind u (fun t -> match t with
686         | Success a -> f a
687         | Error e -> Wrapped.unit (Error e))
688       let run u =
689         let w = Wrapped.bind u (fun t -> match t with
690           | Success a -> Wrapped.unit a
691           | Error e -> raise (Err.Exc e))
692         in Wrapped.run w
693       let run_exn u =
694         let w = Wrapped.bind u (fun t -> match t with
695           | Success a -> Wrapped.unit a
696           | Error e -> raise (Err.Exc e))
697         in Wrapped.run_exn w
698     end
699     include Monad.MakeT2(Trans)
700     let throw e = Wrapped.unit (Error e)
701     let catch u handler = Wrapped.bind u (fun t -> match t with
702       | Success _ -> Wrapped.unit t
703       | Error e -> handler e)
704   end
705 end
706
707 (* pre-define common instance of Error_monad *)
708 module Failure = Error_monad(struct
709   type err = string
710   exception Exc = Failure
711   (*
712   let zero = ""
713   let plus s1 s2 = s1 ^ "\n" ^ s2
714   *)
715 end)
716
717 (* must be parameterized on (struct type env = ... end) *)
718 module Reader_monad(Env : sig type env end) : sig
719   (* declare additional operations, while still hiding implementation of type m *)
720   type env = Env.env
721   type 'a result = env -> 'a
722   type 'a result_exn = env -> 'a
723   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
724   val ask : env m
725   val asks : (env -> 'a) -> 'a m
726   val local : (env -> env) -> 'a m -> 'a m
727   (* ReaderT transformer *)
728   module T : functor (Wrapped : Monad.S) -> sig
729     type 'a result = env -> 'a Wrapped.result
730     type 'a result_exn = env -> 'a Wrapped.result_exn
731     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
732     val elevate : 'a Wrapped.m -> 'a m
733     val ask : env m
734     val asks : (env -> 'a) -> 'a m
735     val local : (env -> env) -> 'a m -> 'a m
736   end
737   (* ReaderT transformer when wrapped monad has plus, zero *)
738   module TP : functor (Wrapped : Monad.P) -> sig
739     include module type of T(Wrapped)
740     include Monad.PLUS with type 'a m := 'a m
741   end
742   module T2 : functor (Wrapped : Monad.S2) -> sig
743     type ('x,'a) result = env -> ('x,'a) Wrapped.result
744     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
745     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
746     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
747     val ask : ('x,env) m
748     val asks : (env -> 'a) -> ('x,'a) m
749     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
750   end
751   module TP2 : functor (Wrapped : Monad.P2) -> sig
752     include module type of T2(Wrapped)
753     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
754   end
755 end = struct
756   type env = Env.env
757   module Base = struct
758     type 'a m = env -> 'a
759     type 'a result = env -> 'a
760     type 'a result_exn = env -> 'a
761     let unit a = fun e -> a
762     let bind u f = fun e -> let a = u e in let u' = f a in u' e
763     let run u = fun e -> u e
764     let run_exn = run
765   end
766   include Monad.Make(Base)
767   let ask = fun e -> e
768   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
769   let local modifier u = fun e -> u (modifier e)
770   module T(Wrapped : Monad.S) = struct
771     module Trans = struct
772       module Wrapped = Wrapped
773       type 'a m = env -> 'a Wrapped.m
774       type 'a result = env -> 'a Wrapped.result
775       type 'a result_exn = env -> 'a Wrapped.result_exn
776       let elevate w = fun e -> w
777       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
778       let run u = fun e -> Wrapped.run (u e)
779       let run_exn u = fun e -> Wrapped.run_exn (u e)
780     end
781     include Monad.MakeT(Trans)
782     let ask = fun e -> Wrapped.unit e
783     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
784     let local modifier u = fun e -> u (modifier e)
785   end
786   module TP(Wrapped : Monad.P) = struct
787     module TransP = struct
788       include T(Wrapped)
789       let plus u v = fun s -> Wrapped.plus (u s) (v s)
790       let zero () = elevate (Wrapped.zero ())
791       let asks selector = ask >>= (fun e ->
792         try unit (selector e)
793         with Not_found -> fun e -> Wrapped.zero ())
794     end
795     include TransP
796     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
797   end
798   module T2(Wrapped : Monad.S2) = struct
799     module Trans = struct
800       module Wrapped = Wrapped
801       type ('x,'a) m = env -> ('x,'a) Wrapped.m
802       type ('x,'a) result = env -> ('x,'a) Wrapped.result
803       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
804       (* code repetition, ugh *)
805       let elevate w = fun e -> w
806       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
807       let run u = fun e -> Wrapped.run (u e)
808       let run_exn u = fun e -> Wrapped.run_exn (u e)
809     end
810     include Monad.MakeT2(Trans)
811     let ask = fun e -> Wrapped.unit e
812     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
813     let local modifier u = fun e -> u (modifier e)
814   end
815   module TP2(Wrapped : Monad.P2) = struct
816     module TransP = struct
817       (* code repetition, ugh *)
818       include T2(Wrapped)
819       let plus u v = fun s -> Wrapped.plus (u s) (v s)
820       let zero () = elevate (Wrapped.zero ())
821       let asks selector = ask >>= (fun e ->
822         try unit (selector e)
823         with Not_found -> fun e -> Wrapped.zero ())
824     end
825     include TransP
826     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
827   end
828 end
829
830
831 (* must be parameterized on (struct type store = ... end) *)
832 module State_monad(Store : sig type store end) : sig
833   (* declare additional operations, while still hiding implementation of type m *)
834   type store = Store.store
835   type 'a result = store -> 'a * store
836   type 'a result_exn = store -> 'a
837   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
838   val get : store m
839   val gets : (store -> 'a) -> 'a m
840   val put : store -> unit m
841   val puts : (store -> store) -> unit m
842   (* StateT transformer *)
843   module T : functor (Wrapped : Monad.S) -> sig
844     type 'a result = store -> ('a * store) Wrapped.result
845     type 'a result_exn = store -> 'a Wrapped.result_exn
846     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
847     val elevate : 'a Wrapped.m -> 'a m
848     val get : store m
849     val gets : (store -> 'a) -> 'a m
850     val put : store -> unit m
851     val puts : (store -> store) -> unit m
852   end
853   (* StateT transformer when wrapped monad has plus, zero *)
854   module TP : functor (Wrapped : Monad.P) -> sig
855     include module type of T(Wrapped)
856     include Monad.PLUS with type 'a m := 'a m
857   end
858   module T2 : functor (Wrapped : Monad.S2) -> sig
859     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
860     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
861     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
862     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
863     val get : ('x,store) m
864     val gets : (store -> 'a) -> ('x,'a) m
865     val put : store -> ('x,unit) m
866     val puts : (store -> store) -> ('x,unit) m
867   end
868   module TP2 : functor (Wrapped : Monad.P2) -> sig
869     include module type of T2(Wrapped)
870     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
871   end
872 end = struct
873   type store = Store.store
874   module Base = struct
875     type 'a m = store -> 'a * store
876     type 'a result = store -> 'a * store
877     type 'a result_exn = store -> 'a
878     let unit a = fun s -> (a, s)
879     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
880     let run u = fun s -> (u s)
881     let run_exn u = fun s -> fst (u s)
882   end
883   include Monad.Make(Base)
884   let get = fun s -> (s, s)
885   let gets viewer = fun s -> (viewer s, s) (* may fail *)
886   let put s = fun _ -> ((), s)
887   let puts modifier = fun s -> ((), modifier s)
888   module T(Wrapped : Monad.S) = struct
889     module Trans = struct
890       module Wrapped = Wrapped
891       type 'a m = store -> ('a * store) Wrapped.m
892       type 'a result = store -> ('a * store) Wrapped.result
893       type 'a result_exn = store -> 'a Wrapped.result_exn
894       let elevate w = fun s ->
895         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
896       let bind u f = fun s ->
897         Wrapped.bind (u s) (fun (a, s') -> f a s')
898       let run u = fun s -> Wrapped.run (u s)
899       let run_exn u = fun s ->
900         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
901         in Wrapped.run_exn w
902     end
903     include Monad.MakeT(Trans)
904     let get = fun s -> Wrapped.unit (s, s)
905     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
906     let put s = fun _ -> Wrapped.unit ((), s)
907     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
908   end
909   module TP(Wrapped : Monad.P) = struct
910     module TransP = struct
911       include T(Wrapped)
912       let plus u v = fun s -> Wrapped.plus (u s) (v s)
913       let zero () = elevate (Wrapped.zero ())
914     end
915     let gets viewer = fun s ->
916       try Wrapped.unit (viewer s, s)
917       with Not_found -> Wrapped.zero ()
918     include TransP
919     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
920   end
921   module T2(Wrapped : Monad.S2) = struct
922     module Trans = struct
923       module Wrapped = Wrapped
924       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
925       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
926       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
927       (* code repetition, ugh *)
928       let elevate w = fun s ->
929         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
930       let bind u f = fun s ->
931         Wrapped.bind (u s) (fun (a, s') -> f a s')
932       let run u = fun s -> Wrapped.run (u s)
933       let run_exn u = fun s ->
934         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
935         in Wrapped.run_exn w
936     end
937     include Monad.MakeT2(Trans)
938     let get = fun s -> Wrapped.unit (s, s)
939     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
940     let put s = fun _ -> Wrapped.unit ((), s)
941     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
942   end
943   module TP2(Wrapped : Monad.P2) = struct
944     module TransP = struct
945       include T2(Wrapped)
946       let plus u v = fun s -> Wrapped.plus (u s) (v s)
947       let zero () = elevate (Wrapped.zero ())
948     end
949     let gets viewer = fun s ->
950       try Wrapped.unit (viewer s, s)
951       with Not_found -> Wrapped.zero ()
952     include TransP
953     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
954   end
955 end
956
957 (* State monad with different interface (structured store) *)
958 module Ref_monad(V : sig
959   type value
960 end) : sig
961   type ref
962   type value = V.value
963   type 'a result = 'a
964   type 'a result_exn = 'a
965   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
966   val newref : value -> ref m
967   val deref : ref -> value m
968   val change : ref -> value -> unit m
969   (* RefT transformer *)
970   module T : functor (Wrapped : Monad.S) -> sig
971     type 'a result = 'a Wrapped.result
972     type 'a result_exn = 'a Wrapped.result_exn
973     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
974     val elevate : 'a Wrapped.m -> 'a m
975     val newref : value -> ref m
976     val deref : ref -> value m
977     val change : ref -> value -> unit m
978   end
979   (* RefT transformer when wrapped monad has plus, zero *)
980   module TP : functor (Wrapped : Monad.P) -> sig
981     include module type of T(Wrapped)
982     include Monad.PLUS with type 'a m := 'a m
983   end
984   module T2 : functor (Wrapped : Monad.S2) -> sig
985     type ('x,'a) result = ('x,'a) Wrapped.result
986     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
987     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
988     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
989     val newref : value -> ('x,ref) m
990     val deref : ref -> ('x,value) m
991     val change : ref -> value -> ('x,unit) m
992   end
993   module TP2 : functor (Wrapped : Monad.P2) -> sig
994     include module type of T2(Wrapped)
995     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
996   end
997 end = struct
998   type ref = int
999   type value = V.value
1000   module D = Map.Make(struct type t = ref let compare = compare end)
1001   type dict = { next: ref; tree : value D.t }
1002   let empty = { next = 0; tree = D.empty }
1003   let alloc (value : value) (d : dict) =
1004     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
1005   let read (key : ref) (d : dict) =
1006     D.find key d.tree
1007   let write (key : ref) (value : value) (d : dict) =
1008     { next = d.next; tree = D.add key value d.tree }
1009   module Base = struct
1010     type 'a m = dict -> 'a * dict
1011     type 'a result = 'a
1012     type 'a result_exn = 'a
1013     let unit a = fun s -> (a, s)
1014     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
1015     let run u = fst (u empty)
1016     let run_exn = run
1017   end
1018   include Monad.Make(Base)
1019   let newref value = fun s -> alloc value s
1020   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
1021   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
1022   module T(Wrapped : Monad.S) = struct
1023     module Trans = struct
1024       module Wrapped = Wrapped
1025       type 'a m = dict -> ('a * dict) Wrapped.m
1026       type 'a result = 'a Wrapped.result
1027       type 'a result_exn = 'a Wrapped.result_exn
1028       let elevate w = fun s ->
1029         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1030       let bind u f = fun s ->
1031         Wrapped.bind (u s) (fun (a, s') -> f a s')
1032       let run u =
1033         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1034         in Wrapped.run w
1035       let run_exn u =
1036         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1037         in Wrapped.run_exn w
1038     end
1039     include Monad.MakeT(Trans)
1040     let newref value = fun s -> Wrapped.unit (alloc value s)
1041     let deref key = fun s -> Wrapped.unit (read key s, s)
1042     let change key value = fun s -> Wrapped.unit ((), write key value s)
1043   end
1044   module TP(Wrapped : Monad.P) = struct
1045     module TransP = struct
1046       include T(Wrapped)
1047       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1048       let zero () = elevate (Wrapped.zero ())
1049     end
1050     include TransP
1051     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
1052   end
1053   module T2(Wrapped : Monad.S2) = struct
1054     module Trans = struct
1055       module Wrapped = Wrapped
1056       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
1057       type ('x,'a) result = ('x,'a) Wrapped.result
1058       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
1059       (* code repetition, ugh *)
1060       let elevate w = fun s ->
1061         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1062       let bind u f = fun s ->
1063         Wrapped.bind (u s) (fun (a, s') -> f a s')
1064       let run u =
1065         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1066         in Wrapped.run w
1067       let run_exn u =
1068         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1069         in Wrapped.run_exn w
1070     end
1071     include Monad.MakeT2(Trans)
1072     let newref value = fun s -> Wrapped.unit (alloc value s)
1073     let deref key = fun s -> Wrapped.unit (read key s, s)
1074     let change key value = fun s -> Wrapped.unit ((), write key value s)
1075   end
1076   module TP2(Wrapped : Monad.P2) = struct
1077     module TransP = struct
1078       include T2(Wrapped)
1079       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1080       let zero () = elevate (Wrapped.zero ())
1081     end
1082     include TransP
1083     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1084   end
1085 end
1086
1087
1088 (* must be parameterized on (struct type log = ... end) *)
1089 module Writer_monad(Log : sig
1090   type log
1091   val zero : log
1092   val plus : log -> log -> log
1093 end) : sig
1094   (* declare additional operations, while still hiding implementation of type m *)
1095   type log = Log.log
1096   type 'a result = 'a * log
1097   type 'a result_exn = 'a * log
1098   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1099   val tell : log -> unit m
1100   val listen : 'a m -> ('a * log) m
1101   val listens : (log -> 'b) -> 'a m -> ('a * 'b) m
1102   (* val pass : ('a * (log -> log)) m -> 'a m *)
1103   val censor : (log -> log) -> 'a m -> 'a m
1104 end = struct
1105   type log = Log.log
1106   module Base = struct
1107     type 'a m = 'a * log
1108     type 'a result = 'a * log
1109     type 'a result_exn = 'a * log
1110     let unit a = (a, Log.zero)
1111     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
1112     let run u = u
1113     let run_exn = run
1114   end
1115   include Monad.Make(Base)
1116   let tell entries = ((), entries) (* add entries to log *)
1117   let listen (a, w) = ((a, w), w)
1118   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
1119   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
1120   let censor f u = pass (u >>= fun a -> unit (a, f))
1121 end
1122
1123 (* pre-define simple Writer *)
1124 module Writer1 = Writer_monad(struct
1125   type log = string
1126   let zero = ""
1127   let plus s1 s2 = s1 ^ "\n" ^ s2
1128 end)
1129
1130 (* slightly more efficient Writer *)
1131 module Writer2 = struct
1132   include Writer_monad(struct
1133     type log = string list
1134     let zero = []
1135     let plus w w' = Util.append w' w
1136   end)
1137   let tell_string s = tell [s]
1138   let tell entries = tell (Util.reverse entries)
1139   let run u = let (a, w) = run u in (a, Util.reverse w)
1140   let run_exn = run
1141 end
1142
1143
1144 module IO_monad : sig
1145   (* declare additional operation, while still hiding implementation of type m *)
1146   type 'a result = 'a
1147   type 'a result_exn = 'a
1148   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1149   val printf : ('a, unit, string, unit m) format4 -> 'a
1150   val print_string : string -> unit m
1151   val print_int : int -> unit m
1152   val print_hex : int -> unit m
1153   val print_bool : bool -> unit m
1154 end = struct
1155   module Base = struct
1156     type 'a m = { run : unit -> unit; value : 'a }
1157     type 'a result = 'a
1158     type 'a result_exn = 'a
1159     let unit a = { run = (fun () -> ()); value = a }
1160     let bind (a : 'a m) (f: 'a -> 'b m) : 'b m =
1161      let fres = f a.value in
1162        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
1163     let run a = let () = a.run () in a.value
1164     let run_exn = run
1165   end
1166   include Monad.Make(Base)
1167   let printf fmt =
1168     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
1169   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
1170   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
1171   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
1172   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
1173 end
1174
1175 (*
1176 module Continuation_monad : sig
1177   (* expose only the implementation of type `('r,'a) result` *)
1178   type 'a m
1179   type 'a result = 'a m
1180   type 'a result_exn = 'a m
1181   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn and type 'a m := 'a m
1182   (* val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m *)
1183   (* misses that the answer types of all the cont's must be the same *)
1184   val callcc : (('a -> 'b m) -> 'a m) -> 'a m
1185   (* val reset : ('a,'a) m -> ('r,'a) m *)
1186   val reset : 'a m -> 'a m
1187   (* val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m *)
1188   (* misses that the answer types of second and third continuations must be b *)
1189   val shift : (('a -> 'b m) -> 'b m) -> 'a m
1190   (* overwrite the run declaration in S, because I can't declare 'a result =
1191    * this polymorphic type (complains that 'r is unbound *)
1192   val runk : 'a m -> ('a -> 'r) -> 'r
1193   val run0 : 'a m -> 'a
1194 end = struct
1195   let id = fun i -> i
1196   module Base = struct
1197     (* 'r is result type of whole computation *)
1198     type 'a m = { cont : 'r. ('a -> 'r) -> 'r }
1199     type 'a result = 'a m
1200     type 'a result_exn = 'a m
1201     let unit a =
1202       let cont : 'r. ('a -> 'r) -> 'r =
1203         fun k -> k a
1204       in { cont }
1205     let bind u f =
1206       let cont : 'r. ('a -> 'r) -> 'r =
1207         fun k -> u.cont (fun a -> (f a).cont k)
1208       in { cont }
1209     let run (u : 'a m) : 'a result = u
1210     let run_exn (u : 'a m) : 'a result_exn = u
1211     let callcc f =
1212       let cont : 'r. ('a -> 'r) -> 'r =
1213           (* Can't figure out how to make the type polymorphic enough
1214            * to satisfy the OCaml type-checker (it's ('a -> 'r) -> 'r
1215            * instead of 'r. ('a -> 'r) -> 'r); so we have to fudge
1216            * with Obj.magic... which tells OCaml's type checker to
1217            * relax, the supplied value has whatever type the context
1218            * needs it to have. *)
1219           fun k ->
1220           let usek a = { cont = Obj.magic (fun _ -> k a) }
1221           in (f usek).cont k
1222       in { cont }
1223     let reset u = unit (u.cont id)
1224     let shift (f : ('a -> 'b m) -> 'b m) : 'a m =
1225       let cont = fun k ->
1226         (f (fun a -> unit (k a))).cont id
1227       in { cont = Obj.magic cont }
1228     let runk u k = (u.cont : ('a -> 'r) -> 'r) k
1229     let run0 u = runk u id
1230   end
1231   include Monad.Make(Base)
1232   let callcc = Base.callcc
1233   let reset = Base.reset
1234   let shift = Base.shift
1235   let runk = Base.runk
1236   let run0 = Base.run0
1237 end
1238  *)
1239
1240 (* This two-type parameter version works without Obj.magic *)
1241 module Continuation_monad : sig
1242   (* expose only the implementation of type `('r,'a) result` *)
1243   type ('r,'a) m
1244   type ('r,'a) result = ('r,'a) m
1245   type ('r,'a) result_exn = ('a -> 'r) -> 'r
1246   include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
1247   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
1248   val reset : ('a,'a) m -> ('r,'a) m
1249   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
1250   (* val abort : ('a,'a) m -> ('a,'b) m *)
1251   val abort : 'a -> ('a,'b) m
1252   val run0 : ('a,'a) m -> 'a
1253 end = struct
1254   let id = fun i -> i
1255   module Base = struct
1256     (* 'r is result type of whole computation *)
1257     type ('r,'a) m = ('a -> 'r) -> 'r
1258     type ('r,'a) result = ('a -> 'r) -> 'r
1259     type ('r,'a) result_exn = ('r,'a) result
1260     let unit a = (fun k -> k a)
1261     let bind u f = (fun k -> (u) (fun a -> (f a) k))
1262     let run u k = (u) k
1263     let run_exn = run
1264   end
1265   include Monad.Make2(Base)
1266   let callcc f = (fun k ->
1267     let usek a = (fun _ -> k a)
1268     in (f usek) k)
1269   (*
1270   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
1271   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
1272   let callcc f = fun k -> f k k
1273   let throw k a = fun _ -> k a
1274   *)
1275
1276   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
1277    *
1278    *  reset :: (Monad m) => ContT a m a -> ContT r m a
1279    *  reset e = ContT $ \k -> runContT e return >>= k
1280    *
1281    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
1282    *  shift e = ContT $ \k ->
1283    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
1284   let reset u = unit ((u) id)
1285   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
1286   (* let abort a = shift (fun _ -> a) *)
1287   let abort a = shift (fun _ -> unit a)
1288   let run0 (u : ('a,'a) m) = (u) id
1289 end
1290
1291
1292 (*
1293  * Scheme:
1294  * (define (example n)
1295  *    (let ([u (let/cc k ; type int -> int pair
1296  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
1297  *                 (+ 1 (car v))))]) ; int
1298  *      (cons u 0))) ; int pair
1299  * ; (example 10) ~~> '(111 . 0)
1300  * ; (example -10) ~~> '(0 . 0)
1301  *
1302  * OCaml monads:
1303  * let example n : (int * int) =
1304  *   Continuation_monad.(let u = callcc (fun k ->
1305  *       (if n < 0 then k 0 else unit [n + 100])
1306  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
1307  *       >>= fun [x] -> unit (x + 1)
1308  *   )
1309  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1310  *   >>= fun x -> unit (x, 0)
1311  *   in run u)
1312  *
1313  *
1314  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1315  * let example1 () : int =
1316  *   Continuation_monad.(let v = reset (
1317  *       let u = shift (fun k -> unit (10 + 1))
1318  *       in u >>= fun x -> unit (100 + x)
1319  *     ) in let w = v >>= fun x -> unit (1000 + x)
1320  *     in run w)
1321  *
1322  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1323  * let example2 () =
1324  *   Continuation_monad.(let v = reset (
1325  *       let u = shift (fun k -> k (10 :: [1]))
1326  *       in u >>= fun x -> unit (100 :: x)
1327  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1328  *     in run w)
1329  *
1330  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1331  * let example3 () =
1332  *   Continuation_monad.(let v = reset (
1333  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1334  *       in u >>= fun x -> unit (100 :: x)
1335  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1336  *     in run w)
1337  *
1338  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1339  * (* not sure if this example can be typed without a sum-type *)
1340  *
1341  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1342  * let example5 () : int =
1343  *   Continuation_monad.(let v = reset (
1344  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
1345  *       in u >>= fun x -> unit (10 + x)
1346  *     ) in let w = v >>= fun x -> unit (100 + x)
1347  *     in run w)
1348  *
1349  *)
1350
1351
1352 module Leaf_monad : sig
1353   (* We implement the type as `'a tree option` because it has a natural`plus`,
1354    * and the rest of the library expects that `plus` and `zero` will come together. *)
1355   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1356   type 'a result = 'a tree option
1357   type 'a result_exn = 'a tree
1358   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1359   include Monad.PLUS with type 'a m := 'a m
1360   (* LeafT transformer *)
1361   module T : functor (Wrapped : Monad.S) -> sig
1362     type 'a result = 'a tree option Wrapped.result
1363     type 'a result_exn = 'a tree Wrapped.result_exn
1364     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1365     include Monad.PLUS with type 'a m := 'a m
1366     val elevate : 'a Wrapped.m -> 'a m
1367     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1368     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1369     val distribute : ('a -> 'b Wrapped.m) -> 'a tree option -> 'b m
1370   end
1371   module T2 : functor (Wrapped : Monad.S2) -> sig
1372     type ('x,'a) result = ('x,'a tree option) Wrapped.result
1373     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1374     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1375     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
1376     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1377     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
1378   end
1379 end = struct
1380   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1381   (* uses supplied plus and zero to copy t to its image under f *)
1382   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1383       | None -> zero ()
1384       | Some ts -> let rec loop ts = (match ts with
1385                      | Leaf a -> f a
1386                      | Node (l, r) ->
1387                          (* recursive application of f may delete a branch *)
1388                          plus (loop l) (loop r)
1389                    ) in loop ts
1390   module Base = struct
1391     type 'a m = 'a tree option
1392     type 'a result = 'a tree option
1393     type 'a result_exn = 'a tree
1394     let unit a = Some (Leaf a)
1395     let zero () = None
1396     let plus u v = match (u, v) with
1397       | None, _ -> v
1398       | _, None -> u
1399       | Some us, Some vs -> Some (Node (us, vs))
1400     let bind u f = mapT f u zero plus
1401     let run u = u
1402     let run_exn u = match u with
1403       | None -> failwith "no values"
1404       (*
1405       | Some (Leaf a) -> a
1406       | many -> failwith "multiple values"
1407       *)
1408       | Some us -> us
1409   end
1410   include Monad.Make(Base)
1411   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
1412   let base_plus = plus
1413   let base_lift = lift
1414   module T(Wrapped : Monad.S) = struct
1415     module Trans = struct
1416       let zero () = Wrapped.unit None
1417       let plus u v =
1418         Wrapped.bind u (fun us ->
1419         Wrapped.bind v (fun vs ->
1420         Wrapped.unit (base_plus us vs)))
1421       include Monad.MakeT(struct
1422         module Wrapped = Wrapped
1423         type 'a m = 'a tree option Wrapped.m
1424         type 'a result = 'a tree option Wrapped.result
1425         type 'a result_exn = 'a tree Wrapped.result_exn
1426         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1427         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1428         let run u = Wrapped.run u
1429         let run_exn u =
1430             let w = Wrapped.bind u (fun t -> match t with
1431               | None -> failwith "no values"
1432               | Some ts -> Wrapped.unit ts)
1433             in Wrapped.run_exn w
1434       end)
1435     end
1436     include Trans
1437     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
1438     (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
1439     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1440   end
1441   module T2(Wrapped : Monad.S2) = struct
1442     module Trans = struct
1443       let zero () = Wrapped.unit None
1444       let plus u v =
1445         Wrapped.bind u (fun us ->
1446         Wrapped.bind v (fun vs ->
1447         Wrapped.unit (base_plus us vs)))
1448       include Monad.MakeT2(struct
1449         module Wrapped = Wrapped
1450         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1451         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1452         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1453         (* code repetition, ugh *)
1454         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1455         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1456         let run u = Wrapped.run u
1457         let run_exn u =
1458             let w = Wrapped.bind u (fun t -> match t with
1459               | None -> failwith "no values"
1460               | Some ts -> Wrapped.unit ts)
1461             in Wrapped.run_exn w
1462       end)
1463     end
1464     include Trans
1465     include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1466     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1467   end
1468 end
1469
1470
1471 module L = List_monad;;
1472 module R = Reader_monad(struct type env = int -> int end);;
1473 module S = State_monad(struct type store = int end);;
1474 module T = Leaf_monad;;
1475 module LR = L.T(R);;
1476 module LS = L.T(S);;
1477 module TL = T.T(L);;
1478 module TR = T.T(R);;
1479 module TS = T.T(S);;
1480 module C = Continuation_monad
1481 module TC = T.T2(C);;
1482
1483
1484 print_endline "=== test Leaf(...).distribute ==================";;
1485
1486 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1487
1488 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1489 TS.run ts 0;;
1490 (*
1491 - : int T.tree option * S.store =
1492 (Some
1493   (T.Node
1494     (T.Node (T.Leaf 2, T.Leaf 3),
1495      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1496  5)
1497 *)
1498
1499 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1500 TS.run_exn ts2 0;;
1501 (*
1502 - : (int * S.store) T.tree option * S.store =
1503 (Some
1504   (T.Node
1505     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1506      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1507  5)
1508 *)
1509
1510 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1511 TR.run_exn tr (fun i -> i+i);;
1512 (*
1513 - : int T.tree option =
1514 Some
1515  (T.Node
1516    (T.Node (T.Leaf 4, T.Leaf 6),
1517     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1518 *)
1519
1520 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1521 TL.run_exn tl;;
1522 (*
1523 - : (int * int) TL.result =
1524 [Some
1525   (T.Node
1526     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1527      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1528 *)
1529
1530 let l2 = [1;2;3;4;5];;
1531 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1532
1533 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1534 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1535
1536 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1537 (*
1538 int T.tree option =
1539 Some
1540  (T.Node
1541    (T.Node (T.Leaf 10, T.Leaf 11),
1542     T.Node
1543      (T.Node
1544        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1545         T.Node (T.Leaf 40, T.Leaf 41)),
1546       T.Node (T.Leaf 50, T.Leaf 51))))
1547  *)
1548
1549 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1550 (*
1551 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1552 *)
1553
1554 print_endline "=== test Leaf(Continuation).distribute ==================";;
1555
1556 let id : 'z. 'z -> 'z = fun x -> x
1557
1558 let example n : (int * int) =
1559   Continuation_monad.(let u = callcc (fun k ->
1560       (if n < 0 then k 0 else unit [n + 100])
1561       (* all of the following is skipped by k 0; the end type int is k's input type *)
1562       >>= fun [x] -> unit (x + 1)
1563   )
1564   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1565   >>= fun x -> unit (x, 0)
1566   in run0 u)
1567
1568
1569 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1570 let example1 () : int =
1571   Continuation_monad.(let v = reset (
1572       let u = shift (fun k -> unit (10 + 1))
1573       in u >>= fun x -> unit (100 + x)
1574     ) in let w = v >>= fun x -> unit (1000 + x)
1575     in run0 w)
1576
1577 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1578 let example2 () =
1579   Continuation_monad.(let v = reset (
1580       let u = shift (fun k -> k (10 :: [1]))
1581       in u >>= fun x -> unit (100 :: x)
1582     ) in let w = v >>= fun x -> unit (1000 :: x)
1583     in run0 w)
1584
1585 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1586 let example3 () =
1587   Continuation_monad.(let v = reset (
1588       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1589       in u >>= fun x -> unit (100 :: x)
1590     ) in let w = v >>= fun x -> unit (1000 :: x)
1591     in run0 w)
1592
1593 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1594 (* not sure if this example can be typed without a sum-type *)
1595
1596 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1597 let example5 () : int =
1598   Continuation_monad.(let v = reset (
1599       let u = shift (fun k -> k 1 >>= k)
1600       in u >>= fun x -> unit (10 + x)
1601     ) in let w = v >>= fun x -> unit (100 + x)
1602     in run0 w)
1603
1604 ;;
1605
1606 print_endline "=== test bare Continuation ============";;
1607
1608 (1011, 1111, 1111, 121);;
1609 (example1(), example2(), example3(), example5());;
1610 ((111,0), (0,0));;
1611 (example ~+10, example ~-10);;
1612
1613 let testc df ic =
1614     C.run_exn TC.(run (distribute df t1)) ic;;
1615
1616
1617 (*
1618 (* do nothing *)
1619 let initial_continuation = fun t -> t in
1620 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1621 *)
1622 testc (C.unit) id;;
1623
1624 (*
1625 (* count leaves, using continuation *)
1626 let initial_continuation = fun t -> 0 in
1627 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1628 *)
1629
1630 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1631
1632 (*
1633 (* convert tree to list of leaves *)
1634 let initial_continuation = fun t -> [] in
1635 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1636 *)
1637
1638 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1639
1640 (*
1641 (* square each leaf using continuation *)
1642 let initial_continuation = fun t -> t in
1643 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1644 *)
1645
1646 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1647
1648
1649 (*
1650 (* replace leaves with list, using continuation *)
1651 let initial_continuation = fun t -> t in
1652 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1653 *)
1654
1655 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1656
1657 print_endline "=== pa_monad's Continuation Tests ============";;
1658
1659 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1660 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1661 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1662 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1663 (5, 27 = C.(run0 (
1664               let c = reset(abort 5 >>= fun y -> unit (y+6))
1665               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1666
1667 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1668
1669 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1670
1671 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1672
1673
1674 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1675