monads.ml: add TP to Error
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  *)
46
47
48 (* Some library functions used below. *)
49 module Util = struct
50   let fold_right = List.fold_right
51   let map = List.map
52   let append = List.append
53   let reverse = List.rev
54   let concat = List.concat
55   let concat_map f lst = List.concat (List.map f lst)
56   (* let zip = List.combine *)
57   let unzip = List.split
58   let zip_with = List.map2
59   let replicate len fill =
60     let rec loop n accu =
61       if n == 0 then accu else loop (pred n) (fill :: accu)
62     in loop len []
63 end
64
65
66
67 (*
68  * This module contains factories that extend a base set of
69  * monadic definitions with a larger family of standard derived values.
70  *)
71
72 module Monad = struct
73   (*
74    * Signature extenders:
75    *   Make :: BASE -> S
76    *     MakeCatch, MakeDistrib :: PLUSBASE -> PLUS
77    *                                           which merges into S
78    *                                           (P is merged sig)
79    *   MakeT :: TRANS (with Wrapped : S or P) -> custom sig
80    *
81    *   Make2 :: BASE2 -> S2
82    *     MakeCatch2, MakeDistrib2 :: PLUSBASE2 -> PLUS2 (P2 is merged sig)
83    *   to wrap double-typed inner monads:
84    *   MakeT2 :: TRANS2 (with Wrapped : S2 or P2) -> custom sig
85    *
86    *)
87
88
89
90   (* type of base definitions *)
91   module type BASE = sig
92     (* The only constraints we impose here on how the monadic type
93      * is implemented is that it have a single type parameter 'a. *)
94     type 'a m
95     type 'a result
96     type 'a result_exn
97     val unit : 'a -> 'a m
98     val bind : 'a m -> ('a -> 'b m) -> 'b m
99     val run : 'a m -> 'a result
100     (* run_exn tries to provide a more ground-level result, but may fail *)
101     val run_exn : 'a m -> 'a result_exn
102   end
103   module type S = sig
104     include BASE
105     val (>>=) : 'a m -> ('a -> 'b m) -> 'b m
106     val (>>) : 'a m -> 'b m -> 'b m
107     val join : ('a m) m -> 'a m
108     val apply : ('a -> 'b) m -> 'a m -> 'b m
109     val lift : ('a -> 'b) -> 'a m -> 'b m
110     val lift2 :  ('a -> 'b -> 'c) -> 'a m -> 'b m -> 'c m
111     val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m
112     val do_when :  bool -> unit m -> unit m
113     val do_unless :  bool -> unit m -> unit m
114     val forever : (unit -> 'a m) -> 'b m
115     val sequence : 'a m list -> 'a list m
116     val sequence_ : 'a m list -> unit m
117   end
118
119   (* Standard, single-type-parameter monads. *)
120   module Make(B : BASE) : S with type 'a m = 'a B.m and type 'a result = 'a B.result and type 'a result_exn = 'a B.result_exn = struct
121     include B
122     let (>>=) = bind
123     let (>>) u v = u >>= fun _ -> v
124     let lift f u = u >>= fun a -> unit (f a)
125     (* lift is called listM, fmap, and <$> in Haskell *)
126     let join uu = uu >>= fun u -> u
127     (* u >>= f === join (lift f u) *)
128     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
129     (* [f] <*> [x1,x2] = [f x1,f x2] *)
130     (* let apply u v = u >>= fun f -> lift f v *)
131     (* let apply = lift2 id *)
132     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
133     (* let lift f u === apply (unit f) u *)
134     (* let lift2 f u v = apply (lift f u) v *)
135     let (>=>) f g = fun a -> f a >>= g
136     let do_when test u = if test then u else unit ()
137     let do_unless test u = if test then unit () else u
138     let forever uthunk =
139         let rec loop () = uthunk () >>= fun _ -> loop ()
140         in loop ()
141     let sequence ms =
142       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
143         Util.fold_right op ms (unit [])
144     let sequence_ ms =
145       Util.fold_right (>>) ms (unit ())
146
147     (* Haskell defines these other operations combining lists and monads.
148      * We don't, but notice that M.mapM == ListT(M).distribute
149      * There's also a parallel TreeT(M).distribute *)
150     (*
151     let mapM f alist = sequence (Util.map f alist)
152     let mapM_ f alist = sequence_ (Util.map f alist)
153     let rec filterM f lst = match lst with
154       | [] -> unit []
155       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
156     let forM alist f = mapM f alist
157     let forM_ alist f = mapM_ f alist
158     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
159     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
160     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
161     let rec foldM f z lst = match lst with
162       | [] -> unit z
163       | x::xs -> f z x >>= fun z' -> foldM f z' xs
164     let foldM_ f z xs = foldM f z xs >> unit ()
165     let replicateM n x = sequence (Util.replicate n x)
166     let replicateM_ n x = sequence_ (Util.replicate n x)
167     *)
168   end
169
170   (* Single-type-parameter monads that also define `plus` and `zero`
171    * operations. These obey the following laws:
172    *     zero >>= f   ===  zero
173    *     plus zero u  ===  u
174    *     plus u zero  ===  u
175    * Additionally, these monads will obey one of the following laws:
176    *     (Catch)   plus (unit a) v  ===  unit a
177    *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
178    *)
179   module type PLUSBASE = sig
180     include BASE
181     val zero : unit -> 'a m
182     val plus : 'a m -> 'a m -> 'a m
183   end
184   module type PLUS = sig
185     type 'a m
186     val zero : unit -> 'a m
187     val plus : 'a m -> 'a m -> 'a m
188     val guard : bool -> unit m
189     val sum : 'a m list -> 'a m
190   end
191   (* MakeCatch and MakeDistrib have the same implementation; we just declare
192    * them twice to document which laws the client code is promising to honor. *)
193   module MakeCatch(B : PLUSBASE) : PLUS with type 'a m = 'a B.m = struct
194     type 'a m = 'a B.m
195     let zero = B.zero
196     let plus = B.plus
197     let guard test = if test then B.unit () else zero ()
198     let sum ms = Util.fold_right plus ms (zero ())
199   end
200   module MakeDistrib = MakeCatch
201
202   (* Signatures for MonadT *)
203   (* sig for Wrapped that include S and PLUS *)
204   module type P = sig
205     include S
206     include PLUS with type 'a m := 'a m
207   end
208   module type TRANS = sig
209     module Wrapped : S
210     type 'a m
211     type 'a result
212     type 'a result_exn
213     val bind : 'a m -> ('a -> 'b m) -> 'b m
214     val run : 'a m -> 'a result
215     val run_exn : 'a m -> 'a result_exn
216     val elevate : 'a Wrapped.m -> 'a m
217     (* lift/elevate laws:
218      *     elevate (W.unit a) == unit a
219      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
220      *)
221   end
222   module MakeT(T : TRANS) = struct
223     include Make(struct
224         include T
225         let unit a = elevate (Wrapped.unit a)
226     end)
227     let elevate = T.elevate
228   end
229
230
231   (* We have to define BASE, S, and Make again for double-type-parameter monads. *)
232   module type BASE2 = sig
233     type ('x,'a) m
234     type ('x,'a) result
235     type ('x,'a) result_exn
236     val unit : 'a -> ('x,'a) m
237     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
238     val run : ('x,'a) m -> ('x,'a) result
239     val run_exn : ('x,'a) m -> ('x,'a) result_exn
240   end
241   module type S2 = sig
242     include BASE2
243     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
244     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
245     val join : ('x,('x,'a) m) m -> ('x,'a) m
246     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
247     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
248     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
249     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
250     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
251     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
252     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
253     val sequence : ('x,'a) m list -> ('x,'a list) m
254     val sequence_ : ('x,'a) m list -> ('x,unit) m
255   end
256   module Make2(B : BASE2) : S2 with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
257     (* code repetition, ugh *)
258     include B
259     let (>>=) = bind
260     let (>>) u v = u >>= fun _ -> v
261     let lift f u = u >>= fun a -> unit (f a)
262     let join uu = uu >>= fun u -> u
263     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
264     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
265     let (>=>) f g = fun a -> f a >>= g
266     let do_when test u = if test then u else unit ()
267     let do_unless test u = if test then unit () else u
268     let forever uthunk =
269         let rec loop () = uthunk () >>= fun _ -> loop ()
270         in loop ()
271     let sequence ms =
272       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
273         Util.fold_right op ms (unit [])
274     let sequence_ ms =
275       Util.fold_right (>>) ms (unit ())
276   end
277
278   module type PLUSBASE2 = sig
279     include BASE2
280     val zero : unit -> ('x,'a) m
281     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
282   end
283   module type PLUS2 = sig
284     type ('x,'a) m
285     val zero : unit -> ('x,'a) m
286     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
287     val guard : bool -> ('x,unit) m
288     val sum : ('x,'a) m list -> ('x,'a) m
289   end
290   module MakeCatch2(B : PLUSBASE2) : PLUS2 with type ('x,'a) m = ('x,'a) B.m = struct
291     type ('x,'a) m = ('x,'a) B.m
292     (* code repetition, ugh *)
293     let zero = B.zero
294     let plus = B.plus
295     let guard test = if test then B.unit () else zero ()
296     let sum ms = Util.fold_right plus ms (zero ())
297   end
298   module MakeDistrib2 = MakeCatch2
299
300   (* Signatures for MonadT *)
301   (* sig for Wrapped that include S and PLUS *)
302   module type P2 = sig
303     include S2
304     include PLUS2 with type ('x,'a) m := ('x,'a) m
305   end
306   module type TRANS2 = sig
307     module Wrapped : S2
308     type ('x,'a) m
309     type ('x,'a) result
310     type ('x,'a) result_exn
311     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
312     val run : ('x,'a) m -> ('x,'a) result
313     val run_exn : ('x,'a) m -> ('x,'a) result_exn
314     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
315   end
316   module MakeT2(T : TRANS2) = struct
317     (* code repetition, ugh *)
318     include Make2(struct
319         include T
320         let unit a = elevate (Wrapped.unit a)
321     end)
322     let elevate = T.elevate
323   end
324
325 end
326
327
328
329
330
331 module Identity_monad : sig
332   (* expose only the implementation of type `'a result` *)
333   type 'a result = 'a
334   type 'a result_exn = 'a
335   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
336 end = struct
337   module Base = struct
338     type 'a m = 'a
339     type 'a result = 'a
340     type 'a result_exn = 'a
341     let unit a = a
342     let bind a f = f a
343     let run a = a
344     let run_exn a = a
345   end
346   include Monad.Make(Base)
347 end
348
349
350 module Maybe_monad : sig
351   (* expose only the implementation of type `'a result` *)
352   type 'a result = 'a option
353   type 'a result_exn = 'a
354   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
355   include Monad.PLUS with type 'a m := 'a m
356   (* MaybeT transformer *)
357   module T : functor (Wrapped : Monad.S) -> sig
358     type 'a result = 'a option Wrapped.result
359     type 'a result_exn = 'a Wrapped.result_exn
360     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
361     include Monad.PLUS with type 'a m := 'a m
362     val elevate : 'a Wrapped.m -> 'a m
363   end
364   module T2 : functor (Wrapped : Monad.S2) -> sig
365     type ('x,'a) result = ('x,'a option) Wrapped.result
366     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
367     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
368     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
369     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
370   end
371 end = struct
372   module Base = struct
373    type 'a m = 'a option
374    type 'a result = 'a option
375    type 'a result_exn = 'a
376    let unit a = Some a
377    let bind u f = match u with Some a -> f a | None -> None
378    let run u = u
379    let run_exn u = match u with
380      | Some a -> a
381      | None -> failwith "no value"
382    let zero () = None
383    let plus u v = match u with None -> v | _ -> u
384   end
385   include Monad.Make(Base)
386   include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m)
387   module T(Wrapped : Monad.S) = struct
388     module Trans = struct
389       include Monad.MakeT(struct
390         module Wrapped = Wrapped
391         type 'a m = 'a option Wrapped.m
392         type 'a result = 'a option Wrapped.result
393         type 'a result_exn = 'a Wrapped.result_exn
394         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
395         let bind u f = Wrapped.bind u (fun t -> match t with
396           | Some a -> f a
397           | None -> Wrapped.unit None)
398         let run u = Wrapped.run u
399         let run_exn u =
400           let w = Wrapped.bind u (fun t -> match t with
401             | Some a -> Wrapped.unit a
402             | None -> failwith "no value")
403           in Wrapped.run_exn w
404       end)
405       let zero () = Wrapped.unit None
406       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
407     end
408     include Trans
409     include (Monad.MakeCatch(Trans) : Monad.PLUS with type 'a m := 'a m)
410   end
411   module T2(Wrapped : Monad.S2) = struct
412     module Trans = struct
413       include Monad.MakeT2(struct
414         module Wrapped = Wrapped
415         type ('x,'a) m = ('x,'a option) Wrapped.m
416         type ('x,'a) result = ('x,'a option) Wrapped.result
417         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
418         (* code repetition, ugh *)
419         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
420         let bind u f = Wrapped.bind u (fun t -> match t with
421           | Some a -> f a
422           | None -> Wrapped.unit None)
423         let run u = Wrapped.run u
424         let run_exn u =
425           let w = Wrapped.bind u (fun t -> match t with
426             | Some a -> Wrapped.unit a
427             | None -> failwith "no value")
428           in Wrapped.run_exn w
429       end)
430       let zero () = Wrapped.unit None
431       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
432     end
433     include Trans
434     include (Monad.MakeCatch2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
435   end
436 end
437
438
439 module List_monad : sig
440   (* declare additional operation, while still hiding implementation of type m *)
441   type 'a result = 'a list
442   type 'a result_exn = 'a
443   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
444   include Monad.PLUS with type 'a m := 'a m
445   val permute : 'a m -> 'a m m
446   val select : 'a m -> ('a * 'a m) m
447   (* ListT transformer *)
448   module T : functor (Wrapped : Monad.S) -> sig
449     type 'a result = 'a list Wrapped.result
450     type 'a result_exn = 'a Wrapped.result_exn
451     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
452     include Monad.PLUS with type 'a m := 'a m
453     val elevate : 'a Wrapped.m -> 'a m
454     (* note that second argument is an 'a list, not the more abstract 'a m *)
455     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
456     val distribute : ('a -> 'b Wrapped.m) -> 'a list -> 'b m
457 (* TODO
458     val permute : 'a m -> 'a m m
459     val select : 'a m -> ('a * 'a m) m
460 *)
461   end
462   module T2 : functor (Wrapped : Monad.S2) -> sig
463     type ('x,'a) result = ('x,'a list) Wrapped.result
464     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
465     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
466     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
467     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
468     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
469   end
470 end = struct
471   module Base = struct
472    type 'a m = 'a list
473    type 'a result = 'a list
474    type 'a result_exn = 'a
475    let unit a = [a]
476    let bind u f = Util.concat_map f u
477    let run u = u
478    let run_exn u = match u with
479      | [] -> failwith "no values"
480      | [a] -> a
481      | many -> failwith "multiple values"
482    let zero () = []
483    let plus = Util.append
484   end
485   include Monad.Make(Base)
486   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
487   (* let either u v = plus u v *)
488   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
489   let rec insert a u =
490     plus (unit (a :: u)) (match u with
491         | [] -> zero ()
492         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
493     )
494   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
495   let rec permute u = match u with
496       | [] -> unit []
497       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
498   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
499   let rec select u = match u with
500     | [] -> zero ()
501     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
502   let base_plus = plus
503   module T(Wrapped : Monad.S) = struct
504     module Trans = struct
505       (* Wrapped.sequence ms  ===  
506            let plus1 u v =
507              Wrapped.bind u (fun x ->
508              Wrapped.bind v (fun xs ->
509              Wrapped.unit (x :: xs)))
510            in Util.fold_right plus1 ms (Wrapped.unit []) *)
511       (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
512       let distribute f alist = Wrapped.sequence (Util.map f alist)
513       include Monad.MakeT(struct
514         module Wrapped = Wrapped
515         type 'a m = 'a list Wrapped.m
516         type 'a result = 'a list Wrapped.result
517         type 'a result_exn = 'a Wrapped.result_exn
518         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
519         let bind u f =
520           Wrapped.bind u (fun ts ->
521           Wrapped.bind (distribute f ts) (fun tts ->
522           Wrapped.unit (Util.concat tts)))
523         let run u = Wrapped.run u
524         let run_exn u =
525           let w = Wrapped.bind u (fun ts -> match ts with
526             | [] -> failwith "no values"
527             | [a] -> Wrapped.unit a
528             | many -> failwith "multiple values"
529           ) in Wrapped.run_exn w
530       end)
531       let zero () = Wrapped.unit []
532       let plus u v =
533         Wrapped.bind u (fun us ->
534         Wrapped.bind v (fun vs ->
535         Wrapped.unit (base_plus us vs)))
536     end
537     include Trans
538     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
539 (*
540     let permute : 'a m -> 'a m m
541     let select : 'a m -> ('a * 'a m) m
542 *)
543   end
544   module T2(Wrapped : Monad.S2) = struct
545     module Trans = struct
546       let distribute f alist = Wrapped.sequence (Util.map f alist)
547       include Monad.MakeT2(struct
548         module Wrapped = Wrapped
549         type ('x,'a) m = ('x,'a list) Wrapped.m
550         type ('x,'a) result = ('x,'a list) Wrapped.result
551         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
552         (* code repetition, ugh *)
553         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
554         let bind u f =
555           Wrapped.bind u (fun ts ->
556           Wrapped.bind (distribute f ts) (fun tts ->
557           Wrapped.unit (Util.concat tts)))
558         let run u = Wrapped.run u
559         let run_exn u =
560           let w = Wrapped.bind u (fun ts -> match ts with
561             | [] -> failwith "no values"
562             | [a] -> Wrapped.unit a
563             | many -> failwith "multiple values"
564           ) in Wrapped.run_exn w
565       end)
566       let zero () = Wrapped.unit []
567       let plus u v =
568         Wrapped.bind u (fun us ->
569         Wrapped.bind v (fun vs ->
570         Wrapped.unit (base_plus us vs)))
571     end
572     include Trans
573     include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
574   end
575 end
576
577
578 (* must be parameterized on (struct type err = ... end) *)
579 module Error_monad(Err : sig
580   type err
581   exception Exc of err
582   (*
583   val zero : unit -> err
584   val plus : err -> err -> err
585   *)
586 end) : sig
587   (* declare additional operations, while still hiding implementation of type m *)
588   type err = Err.err
589   type 'a error = Error of err | Success of 'a
590   type 'a result = 'a
591   type 'a result_exn = 'a
592   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
593   (* include Monad.PLUS with type 'a m := 'a m *)
594   val throw : err -> 'a m
595   val catch : 'a m -> (err -> 'a m) -> 'a m
596   (* ErrorT transformer *)
597   module T : functor (Wrapped : Monad.S) -> sig
598     type 'a result = 'a Wrapped.result
599     type 'a result_exn = 'a Wrapped.result_exn
600     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
601     val elevate : 'a Wrapped.m -> 'a m
602     val throw : err -> 'a m
603     val catch : 'a m -> (err -> 'a m) -> 'a m
604   end
605   (* ErrorT transformer when wrapped monad has plus, zero *)
606   module TP : functor (Wrapped : Monad.P) -> sig
607     include module type of T(Wrapped)
608     include Monad.PLUS with type 'a m := 'a m
609   end
610   module T2 : functor (Wrapped : Monad.S2) -> sig
611     type ('x,'a) result = ('x,'a) Wrapped.result
612     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
613     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
614     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
615     val throw : err -> ('x,'a) m
616     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
617   end
618 end = struct
619   type err = Err.err
620   type 'a error = Error of err | Success of 'a
621   module Base = struct
622     type 'a m = 'a error
623     type 'a result = 'a
624     type 'a result_exn = 'a
625     let unit a = Success a
626     let bind u f = match u with
627       | Success a -> f a
628       | Error e -> Error e (* input and output may be of different 'a types *)
629     (* TODO: should run refrain from failing? *)
630     let run u = match u with
631       | Success a -> a
632       | Error e -> raise (Err.Exc e)
633     let run_exn = run
634     (*
635     let zero () = Error Err.zero
636     let plus u v = match (u, v) with
637       | Success _, _ -> u
638       (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
639        * otherwise, plus (Error _) v = v *)
640       | Error _, _ when v = zero -> u
641       (* combine errors *)
642       | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
643       | Error _, _ -> v
644     *)
645   end
646   include Monad.Make(Base)
647   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
648   let throw e = Error e
649   let catch u handler = match u with
650     | Success _ -> u
651     | Error e -> handler e
652   module T(Wrapped : Monad.S) = struct
653     module Trans = struct
654       module Wrapped = Wrapped
655       type 'a m = 'a error Wrapped.m
656       type 'a result = 'a Wrapped.result
657       type 'a result_exn = 'a Wrapped.result_exn
658       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
659       let bind u f = Wrapped.bind u (fun t -> match t with
660         | Success a -> f a
661         | Error e -> Wrapped.unit (Error e))
662       (* TODO: should run refrain from failing? *)
663       let run u =
664         let w = Wrapped.bind u (fun t -> match t with
665           | Success a -> Wrapped.unit a
666           (* | _ -> Wrapped.fail () *)
667           | Error e -> raise (Err.Exc e))
668         in Wrapped.run w
669       let run_exn u =
670         let w = Wrapped.bind u (fun t -> match t with
671           | Success a -> Wrapped.unit a
672           (* | _ -> Wrapped.fail () *)
673           | Error e -> raise (Err.Exc e))
674         in Wrapped.run_exn w
675     end
676     include Monad.MakeT(Trans)
677     let throw e = Wrapped.unit (Error e)
678     let catch u handler = Wrapped.bind u (fun t -> match t with
679       | Success _ -> Wrapped.unit t
680       | Error e -> handler e)
681   end
682   module TP(Wrapped : Monad.P) = struct
683     module TransP = struct
684       include T(Wrapped)
685       let plus u v = Wrapped.plus u v
686       let zero () = elevate (Wrapped.zero ())
687     end
688     include TransP
689     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
690   end
691   module T2(Wrapped : Monad.S2) = struct
692     module Trans = struct
693       module Wrapped = Wrapped
694       type ('x,'a) m = ('x,'a error) Wrapped.m
695       type ('x,'a) result = ('x,'a) Wrapped.result
696       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
697       (* code repetition, ugh *)
698       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
699       let bind u f = Wrapped.bind u (fun t -> match t with
700         | Success a -> f a
701         | Error e -> Wrapped.unit (Error e))
702       let run u =
703         let w = Wrapped.bind u (fun t -> match t with
704           | Success a -> Wrapped.unit a
705           | Error e -> raise (Err.Exc e))
706         in Wrapped.run w
707       let run_exn u =
708         let w = Wrapped.bind u (fun t -> match t with
709           | Success a -> Wrapped.unit a
710           | Error e -> raise (Err.Exc e))
711         in Wrapped.run_exn w
712     end
713     include Monad.MakeT2(Trans)
714     let throw e = Wrapped.unit (Error e)
715     let catch u handler = Wrapped.bind u (fun t -> match t with
716       | Success _ -> Wrapped.unit t
717       | Error e -> handler e)
718   end
719 end
720
721 (* pre-define common instance of Error_monad *)
722 module Failure = Error_monad(struct
723   type err = string
724   exception Exc = Failure
725   (*
726   let zero = ""
727   let plus s1 s2 = s1 ^ "\n" ^ s2
728   *)
729 end)
730
731 (* must be parameterized on (struct type env = ... end) *)
732 module Reader_monad(Env : sig type env end) : sig
733   (* declare additional operations, while still hiding implementation of type m *)
734   type env = Env.env
735   type 'a result = env -> 'a
736   type 'a result_exn = env -> 'a
737   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
738   val ask : env m
739   val asks : (env -> 'a) -> 'a m
740   val local : (env -> env) -> 'a m -> 'a m
741   (* ReaderT transformer *)
742   module T : functor (Wrapped : Monad.S) -> sig
743     type 'a result = env -> 'a Wrapped.result
744     type 'a result_exn = env -> 'a Wrapped.result_exn
745     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
746     val elevate : 'a Wrapped.m -> 'a m
747     val ask : env m
748     val asks : (env -> 'a) -> 'a m
749     val local : (env -> env) -> 'a m -> 'a m
750   end
751   (* ReaderT transformer when wrapped monad has plus, zero *)
752   module TP : functor (Wrapped : Monad.P) -> sig
753     include module type of T(Wrapped)
754     include Monad.PLUS with type 'a m := 'a m
755   end
756   module T2 : functor (Wrapped : Monad.S2) -> sig
757     type ('x,'a) result = env -> ('x,'a) Wrapped.result
758     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
759     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
760     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
761     val ask : ('x,env) m
762     val asks : (env -> 'a) -> ('x,'a) m
763     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
764   end
765   module TP2 : functor (Wrapped : Monad.P2) -> sig
766     include module type of T2(Wrapped)
767     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
768   end
769 end = struct
770   type env = Env.env
771   module Base = struct
772     type 'a m = env -> 'a
773     type 'a result = env -> 'a
774     type 'a result_exn = env -> 'a
775     let unit a = fun e -> a
776     let bind u f = fun e -> let a = u e in let u' = f a in u' e
777     let run u = fun e -> u e
778     let run_exn = run
779   end
780   include Monad.Make(Base)
781   let ask = fun e -> e
782   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
783   let local modifier u = fun e -> u (modifier e)
784   module T(Wrapped : Monad.S) = struct
785     module Trans = struct
786       module Wrapped = Wrapped
787       type 'a m = env -> 'a Wrapped.m
788       type 'a result = env -> 'a Wrapped.result
789       type 'a result_exn = env -> 'a Wrapped.result_exn
790       let elevate w = fun e -> w
791       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
792       let run u = fun e -> Wrapped.run (u e)
793       let run_exn u = fun e -> Wrapped.run_exn (u e)
794     end
795     include Monad.MakeT(Trans)
796     let ask = fun e -> Wrapped.unit e
797     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
798     let local modifier u = fun e -> u (modifier e)
799   end
800   module TP(Wrapped : Monad.P) = struct
801     module TransP = struct
802       include T(Wrapped)
803       let plus u v = fun s -> Wrapped.plus (u s) (v s)
804       let zero () = elevate (Wrapped.zero ())
805       let asks selector = ask >>= (fun e ->
806         try unit (selector e)
807         with Not_found -> fun e -> Wrapped.zero ())
808     end
809     include TransP
810     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
811   end
812   module T2(Wrapped : Monad.S2) = struct
813     module Trans = struct
814       module Wrapped = Wrapped
815       type ('x,'a) m = env -> ('x,'a) Wrapped.m
816       type ('x,'a) result = env -> ('x,'a) Wrapped.result
817       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
818       (* code repetition, ugh *)
819       let elevate w = fun e -> w
820       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
821       let run u = fun e -> Wrapped.run (u e)
822       let run_exn u = fun e -> Wrapped.run_exn (u e)
823     end
824     include Monad.MakeT2(Trans)
825     let ask = fun e -> Wrapped.unit e
826     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
827     let local modifier u = fun e -> u (modifier e)
828   end
829   module TP2(Wrapped : Monad.P2) = struct
830     module TransP = struct
831       (* code repetition, ugh *)
832       include T2(Wrapped)
833       let plus u v = fun s -> Wrapped.plus (u s) (v s)
834       let zero () = elevate (Wrapped.zero ())
835       let asks selector = ask >>= (fun e ->
836         try unit (selector e)
837         with Not_found -> fun e -> Wrapped.zero ())
838     end
839     include TransP
840     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
841   end
842 end
843
844
845 (* must be parameterized on (struct type store = ... end) *)
846 module State_monad(Store : sig type store end) : sig
847   (* declare additional operations, while still hiding implementation of type m *)
848   type store = Store.store
849   type 'a result = store -> 'a * store
850   type 'a result_exn = store -> 'a
851   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
852   val get : store m
853   val gets : (store -> 'a) -> 'a m
854   val put : store -> unit m
855   val puts : (store -> store) -> unit m
856   (* StateT transformer *)
857   module T : functor (Wrapped : Monad.S) -> sig
858     type 'a result = store -> ('a * store) Wrapped.result
859     type 'a result_exn = store -> 'a Wrapped.result_exn
860     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
861     val elevate : 'a Wrapped.m -> 'a m
862     val get : store m
863     val gets : (store -> 'a) -> 'a m
864     val put : store -> unit m
865     val puts : (store -> store) -> unit m
866   end
867   (* StateT transformer when wrapped monad has plus, zero *)
868   module TP : functor (Wrapped : Monad.P) -> sig
869     include module type of T(Wrapped)
870     include Monad.PLUS with type 'a m := 'a m
871   end
872   module T2 : functor (Wrapped : Monad.S2) -> sig
873     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
874     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
875     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
876     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
877     val get : ('x,store) m
878     val gets : (store -> 'a) -> ('x,'a) m
879     val put : store -> ('x,unit) m
880     val puts : (store -> store) -> ('x,unit) m
881   end
882   module TP2 : functor (Wrapped : Monad.P2) -> sig
883     include module type of T2(Wrapped)
884     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
885   end
886 end = struct
887   type store = Store.store
888   module Base = struct
889     type 'a m = store -> 'a * store
890     type 'a result = store -> 'a * store
891     type 'a result_exn = store -> 'a
892     let unit a = fun s -> (a, s)
893     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
894     let run u = fun s -> (u s)
895     let run_exn u = fun s -> fst (u s)
896   end
897   include Monad.Make(Base)
898   let get = fun s -> (s, s)
899   let gets viewer = fun s -> (viewer s, s) (* may fail *)
900   let put s = fun _ -> ((), s)
901   let puts modifier = fun s -> ((), modifier s)
902   module T(Wrapped : Monad.S) = struct
903     module Trans = struct
904       module Wrapped = Wrapped
905       type 'a m = store -> ('a * store) Wrapped.m
906       type 'a result = store -> ('a * store) Wrapped.result
907       type 'a result_exn = store -> 'a Wrapped.result_exn
908       let elevate w = fun s ->
909         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
910       let bind u f = fun s ->
911         Wrapped.bind (u s) (fun (a, s') -> f a s')
912       let run u = fun s -> Wrapped.run (u s)
913       let run_exn u = fun s ->
914         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
915         in Wrapped.run_exn w
916     end
917     include Monad.MakeT(Trans)
918     let get = fun s -> Wrapped.unit (s, s)
919     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
920     let put s = fun _ -> Wrapped.unit ((), s)
921     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
922   end
923   module TP(Wrapped : Monad.P) = struct
924     module TransP = struct
925       include T(Wrapped)
926       let plus u v = fun s -> Wrapped.plus (u s) (v s)
927       let zero () = elevate (Wrapped.zero ())
928     end
929     let gets viewer = fun s ->
930       try Wrapped.unit (viewer s, s)
931       with Not_found -> Wrapped.zero ()
932     include TransP
933     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
934   end
935   module T2(Wrapped : Monad.S2) = struct
936     module Trans = struct
937       module Wrapped = Wrapped
938       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
939       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
940       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
941       (* code repetition, ugh *)
942       let elevate w = fun s ->
943         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
944       let bind u f = fun s ->
945         Wrapped.bind (u s) (fun (a, s') -> f a s')
946       let run u = fun s -> Wrapped.run (u s)
947       let run_exn u = fun s ->
948         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
949         in Wrapped.run_exn w
950     end
951     include Monad.MakeT2(Trans)
952     let get = fun s -> Wrapped.unit (s, s)
953     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
954     let put s = fun _ -> Wrapped.unit ((), s)
955     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
956   end
957   module TP2(Wrapped : Monad.P2) = struct
958     module TransP = struct
959       include T2(Wrapped)
960       let plus u v = fun s -> Wrapped.plus (u s) (v s)
961       let zero () = elevate (Wrapped.zero ())
962     end
963     let gets viewer = fun s ->
964       try Wrapped.unit (viewer s, s)
965       with Not_found -> Wrapped.zero ()
966     include TransP
967     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
968   end
969 end
970
971 (* State monad with different interface (structured store) *)
972 module Ref_monad(V : sig
973   type value
974 end) : sig
975   type ref
976   type value = V.value
977   type 'a result = 'a
978   type 'a result_exn = 'a
979   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
980   val newref : value -> ref m
981   val deref : ref -> value m
982   val change : ref -> value -> unit m
983   (* RefT transformer *)
984   module T : functor (Wrapped : Monad.S) -> sig
985     type 'a result = 'a Wrapped.result
986     type 'a result_exn = 'a Wrapped.result_exn
987     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
988     val elevate : 'a Wrapped.m -> 'a m
989     val newref : value -> ref m
990     val deref : ref -> value m
991     val change : ref -> value -> unit m
992   end
993   (* RefT transformer when wrapped monad has plus, zero *)
994   module TP : functor (Wrapped : Monad.P) -> sig
995     include module type of T(Wrapped)
996     include Monad.PLUS with type 'a m := 'a m
997   end
998   module T2 : functor (Wrapped : Monad.S2) -> sig
999     type ('x,'a) result = ('x,'a) Wrapped.result
1000     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
1001     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1002     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1003     val newref : value -> ('x,ref) m
1004     val deref : ref -> ('x,value) m
1005     val change : ref -> value -> ('x,unit) m
1006   end
1007   module TP2 : functor (Wrapped : Monad.P2) -> sig
1008     include module type of T2(Wrapped)
1009     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
1010   end
1011 end = struct
1012   type ref = int
1013   type value = V.value
1014   module D = Map.Make(struct type t = ref let compare = compare end)
1015   type dict = { next: ref; tree : value D.t }
1016   let empty = { next = 0; tree = D.empty }
1017   let alloc (value : value) (d : dict) =
1018     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
1019   let read (key : ref) (d : dict) =
1020     D.find key d.tree
1021   let write (key : ref) (value : value) (d : dict) =
1022     { next = d.next; tree = D.add key value d.tree }
1023   module Base = struct
1024     type 'a m = dict -> 'a * dict
1025     type 'a result = 'a
1026     type 'a result_exn = 'a
1027     let unit a = fun s -> (a, s)
1028     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
1029     let run u = fst (u empty)
1030     let run_exn = run
1031   end
1032   include Monad.Make(Base)
1033   let newref value = fun s -> alloc value s
1034   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
1035   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
1036   module T(Wrapped : Monad.S) = struct
1037     module Trans = struct
1038       module Wrapped = Wrapped
1039       type 'a m = dict -> ('a * dict) Wrapped.m
1040       type 'a result = 'a Wrapped.result
1041       type 'a result_exn = 'a Wrapped.result_exn
1042       let elevate w = fun s ->
1043         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1044       let bind u f = fun s ->
1045         Wrapped.bind (u s) (fun (a, s') -> f a s')
1046       let run u =
1047         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1048         in Wrapped.run w
1049       let run_exn u =
1050         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1051         in Wrapped.run_exn w
1052     end
1053     include Monad.MakeT(Trans)
1054     let newref value = fun s -> Wrapped.unit (alloc value s)
1055     let deref key = fun s -> Wrapped.unit (read key s, s)
1056     let change key value = fun s -> Wrapped.unit ((), write key value s)
1057   end
1058   module TP(Wrapped : Monad.P) = struct
1059     module TransP = struct
1060       include T(Wrapped)
1061       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1062       let zero () = elevate (Wrapped.zero ())
1063     end
1064     include TransP
1065     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
1066   end
1067   module T2(Wrapped : Monad.S2) = struct
1068     module Trans = struct
1069       module Wrapped = Wrapped
1070       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
1071       type ('x,'a) result = ('x,'a) Wrapped.result
1072       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
1073       (* code repetition, ugh *)
1074       let elevate w = fun s ->
1075         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1076       let bind u f = fun s ->
1077         Wrapped.bind (u s) (fun (a, s') -> f a s')
1078       let run u =
1079         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1080         in Wrapped.run w
1081       let run_exn u =
1082         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1083         in Wrapped.run_exn w
1084     end
1085     include Monad.MakeT2(Trans)
1086     let newref value = fun s -> Wrapped.unit (alloc value s)
1087     let deref key = fun s -> Wrapped.unit (read key s, s)
1088     let change key value = fun s -> Wrapped.unit ((), write key value s)
1089   end
1090   module TP2(Wrapped : Monad.P2) = struct
1091     module TransP = struct
1092       include T2(Wrapped)
1093       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1094       let zero () = elevate (Wrapped.zero ())
1095     end
1096     include TransP
1097     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1098   end
1099 end
1100
1101
1102 (* must be parameterized on (struct type log = ... end) *)
1103 module Writer_monad(Log : sig
1104   type log
1105   val zero : log
1106   val plus : log -> log -> log
1107 end) : sig
1108   (* declare additional operations, while still hiding implementation of type m *)
1109   type log = Log.log
1110   type 'a result = 'a * log
1111   type 'a result_exn = 'a * log
1112   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1113   val tell : log -> unit m
1114   val listen : 'a m -> ('a * log) m
1115   val listens : (log -> 'b) -> 'a m -> ('a * 'b) m
1116   (* val pass : ('a * (log -> log)) m -> 'a m *)
1117   val censor : (log -> log) -> 'a m -> 'a m
1118 end = struct
1119   type log = Log.log
1120   module Base = struct
1121     type 'a m = 'a * log
1122     type 'a result = 'a * log
1123     type 'a result_exn = 'a * log
1124     let unit a = (a, Log.zero)
1125     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
1126     let run u = u
1127     let run_exn = run
1128   end
1129   include Monad.Make(Base)
1130   let tell entries = ((), entries) (* add entries to log *)
1131   let listen (a, w) = ((a, w), w)
1132   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
1133   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
1134   let censor f u = pass (u >>= fun a -> unit (a, f))
1135 end
1136
1137 (* pre-define simple Writer *)
1138 module Writer1 = Writer_monad(struct
1139   type log = string
1140   let zero = ""
1141   let plus s1 s2 = s1 ^ "\n" ^ s2
1142 end)
1143
1144 (* slightly more efficient Writer *)
1145 module Writer2 = struct
1146   include Writer_monad(struct
1147     type log = string list
1148     let zero = []
1149     let plus w w' = Util.append w' w
1150   end)
1151   let tell_string s = tell [s]
1152   let tell entries = tell (Util.reverse entries)
1153   let run u = let (a, w) = run u in (a, Util.reverse w)
1154   let run_exn = run
1155 end
1156
1157
1158 module IO_monad : sig
1159   (* declare additional operation, while still hiding implementation of type m *)
1160   type 'a result = 'a
1161   type 'a result_exn = 'a
1162   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1163   val printf : ('a, unit, string, unit m) format4 -> 'a
1164   val print_string : string -> unit m
1165   val print_int : int -> unit m
1166   val print_hex : int -> unit m
1167   val print_bool : bool -> unit m
1168 end = struct
1169   module Base = struct
1170     type 'a m = { run : unit -> unit; value : 'a }
1171     type 'a result = 'a
1172     type 'a result_exn = 'a
1173     let unit a = { run = (fun () -> ()); value = a }
1174     let bind (a : 'a m) (f: 'a -> 'b m) : 'b m =
1175      let fres = f a.value in
1176        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
1177     let run a = let () = a.run () in a.value
1178     let run_exn = run
1179   end
1180   include Monad.Make(Base)
1181   let printf fmt =
1182     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
1183   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
1184   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
1185   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
1186   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
1187 end
1188
1189 (*
1190 module Continuation_monad : sig
1191   (* expose only the implementation of type `('r,'a) result` *)
1192   type 'a m
1193   type 'a result = 'a m
1194   type 'a result_exn = 'a m
1195   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn and type 'a m := 'a m
1196   (* val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m *)
1197   (* misses that the answer types of all the cont's must be the same *)
1198   val callcc : (('a -> 'b m) -> 'a m) -> 'a m
1199   (* val reset : ('a,'a) m -> ('r,'a) m *)
1200   val reset : 'a m -> 'a m
1201   (* val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m *)
1202   (* misses that the answer types of second and third continuations must be b *)
1203   val shift : (('a -> 'b m) -> 'b m) -> 'a m
1204   (* overwrite the run declaration in S, because I can't declare 'a result =
1205    * this polymorphic type (complains that 'r is unbound *)
1206   val runk : 'a m -> ('a -> 'r) -> 'r
1207   val run0 : 'a m -> 'a
1208 end = struct
1209   let id = fun i -> i
1210   module Base = struct
1211     (* 'r is result type of whole computation *)
1212     type 'a m = { cont : 'r. ('a -> 'r) -> 'r }
1213     type 'a result = 'a m
1214     type 'a result_exn = 'a m
1215     let unit a =
1216       let cont : 'r. ('a -> 'r) -> 'r =
1217         fun k -> k a
1218       in { cont }
1219     let bind u f =
1220       let cont : 'r. ('a -> 'r) -> 'r =
1221         fun k -> u.cont (fun a -> (f a).cont k)
1222       in { cont }
1223     let run (u : 'a m) : 'a result = u
1224     let run_exn (u : 'a m) : 'a result_exn = u
1225     let callcc f =
1226       let cont : 'r. ('a -> 'r) -> 'r =
1227           (* Can't figure out how to make the type polymorphic enough
1228            * to satisfy the OCaml type-checker (it's ('a -> 'r) -> 'r
1229            * instead of 'r. ('a -> 'r) -> 'r); so we have to fudge
1230            * with Obj.magic... which tells OCaml's type checker to
1231            * relax, the supplied value has whatever type the context
1232            * needs it to have. *)
1233           fun k ->
1234           let usek a = { cont = Obj.magic (fun _ -> k a) }
1235           in (f usek).cont k
1236       in { cont }
1237     let reset u = unit (u.cont id)
1238     let shift (f : ('a -> 'b m) -> 'b m) : 'a m =
1239       let cont = fun k ->
1240         (f (fun a -> unit (k a))).cont id
1241       in { cont = Obj.magic cont }
1242     let runk u k = (u.cont : ('a -> 'r) -> 'r) k
1243     let run0 u = runk u id
1244   end
1245   include Monad.Make(Base)
1246   let callcc = Base.callcc
1247   let reset = Base.reset
1248   let shift = Base.shift
1249   let runk = Base.runk
1250   let run0 = Base.run0
1251 end
1252  *)
1253
1254 (* This two-type parameter version works without Obj.magic *)
1255 module Continuation_monad : sig
1256   (* expose only the implementation of type `('r,'a) result` *)
1257   type ('r,'a) m
1258   type ('r,'a) result = ('r,'a) m
1259   type ('r,'a) result_exn = ('a -> 'r) -> 'r
1260   include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
1261   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
1262   val reset : ('a,'a) m -> ('r,'a) m
1263   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
1264   (* val abort : ('a,'a) m -> ('a,'b) m *)
1265   val abort : 'a -> ('a,'b) m
1266   val run0 : ('a,'a) m -> 'a
1267 end = struct
1268   let id = fun i -> i
1269   module Base = struct
1270     (* 'r is result type of whole computation *)
1271     type ('r,'a) m = ('a -> 'r) -> 'r
1272     type ('r,'a) result = ('a -> 'r) -> 'r
1273     type ('r,'a) result_exn = ('r,'a) result
1274     let unit a = (fun k -> k a)
1275     let bind u f = (fun k -> (u) (fun a -> (f a) k))
1276     let run u k = (u) k
1277     let run_exn = run
1278   end
1279   include Monad.Make2(Base)
1280   let callcc f = (fun k ->
1281     let usek a = (fun _ -> k a)
1282     in (f usek) k)
1283   (*
1284   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
1285   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
1286   let callcc f = fun k -> f k k
1287   let throw k a = fun _ -> k a
1288   *)
1289
1290   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
1291    *
1292    *  reset :: (Monad m) => ContT a m a -> ContT r m a
1293    *  reset e = ContT $ \k -> runContT e return >>= k
1294    *
1295    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
1296    *  shift e = ContT $ \k ->
1297    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
1298   let reset u = unit ((u) id)
1299   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
1300   (* let abort a = shift (fun _ -> a) *)
1301   let abort a = shift (fun _ -> unit a)
1302   let run0 (u : ('a,'a) m) = (u) id
1303 end
1304
1305
1306 (*
1307  * Scheme:
1308  * (define (example n)
1309  *    (let ([u (let/cc k ; type int -> int pair
1310  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
1311  *                 (+ 1 (car v))))]) ; int
1312  *      (cons u 0))) ; int pair
1313  * ; (example 10) ~~> '(111 . 0)
1314  * ; (example -10) ~~> '(0 . 0)
1315  *
1316  * OCaml monads:
1317  * let example n : (int * int) =
1318  *   Continuation_monad.(let u = callcc (fun k ->
1319  *       (if n < 0 then k 0 else unit [n + 100])
1320  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
1321  *       >>= fun [x] -> unit (x + 1)
1322  *   )
1323  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1324  *   >>= fun x -> unit (x, 0)
1325  *   in run u)
1326  *
1327  *
1328  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1329  * let example1 () : int =
1330  *   Continuation_monad.(let v = reset (
1331  *       let u = shift (fun k -> unit (10 + 1))
1332  *       in u >>= fun x -> unit (100 + x)
1333  *     ) in let w = v >>= fun x -> unit (1000 + x)
1334  *     in run w)
1335  *
1336  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1337  * let example2 () =
1338  *   Continuation_monad.(let v = reset (
1339  *       let u = shift (fun k -> k (10 :: [1]))
1340  *       in u >>= fun x -> unit (100 :: x)
1341  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1342  *     in run w)
1343  *
1344  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1345  * let example3 () =
1346  *   Continuation_monad.(let v = reset (
1347  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1348  *       in u >>= fun x -> unit (100 :: x)
1349  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1350  *     in run w)
1351  *
1352  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1353  * (* not sure if this example can be typed without a sum-type *)
1354  *
1355  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1356  * let example5 () : int =
1357  *   Continuation_monad.(let v = reset (
1358  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
1359  *       in u >>= fun x -> unit (10 + x)
1360  *     ) in let w = v >>= fun x -> unit (100 + x)
1361  *     in run w)
1362  *
1363  *)
1364
1365
1366 module Leaf_monad : sig
1367   (* We implement the type as `'a tree option` because it has a natural`plus`,
1368    * and the rest of the library expects that `plus` and `zero` will come together. *)
1369   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1370   type 'a result = 'a tree option
1371   type 'a result_exn = 'a tree
1372   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1373   include Monad.PLUS with type 'a m := 'a m
1374   (* LeafT transformer *)
1375   module T : functor (Wrapped : Monad.S) -> sig
1376     type 'a result = 'a tree option Wrapped.result
1377     type 'a result_exn = 'a tree Wrapped.result_exn
1378     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1379     include Monad.PLUS with type 'a m := 'a m
1380     val elevate : 'a Wrapped.m -> 'a m
1381     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1382     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1383     val distribute : ('a -> 'b Wrapped.m) -> 'a tree option -> 'b m
1384   end
1385   module T2 : functor (Wrapped : Monad.S2) -> sig
1386     type ('x,'a) result = ('x,'a tree option) Wrapped.result
1387     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1388     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1389     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
1390     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1391     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
1392   end
1393 end = struct
1394   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1395   (* uses supplied plus and zero to copy t to its image under f *)
1396   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1397       | None -> zero ()
1398       | Some ts -> let rec loop ts = (match ts with
1399                      | Leaf a -> f a
1400                      | Node (l, r) ->
1401                          (* recursive application of f may delete a branch *)
1402                          plus (loop l) (loop r)
1403                    ) in loop ts
1404   module Base = struct
1405     type 'a m = 'a tree option
1406     type 'a result = 'a tree option
1407     type 'a result_exn = 'a tree
1408     let unit a = Some (Leaf a)
1409     let zero () = None
1410     let plus u v = match (u, v) with
1411       | None, _ -> v
1412       | _, None -> u
1413       | Some us, Some vs -> Some (Node (us, vs))
1414     let bind u f = mapT f u zero plus
1415     let run u = u
1416     let run_exn u = match u with
1417       | None -> failwith "no values"
1418       (*
1419       | Some (Leaf a) -> a
1420       | many -> failwith "multiple values"
1421       *)
1422       | Some us -> us
1423   end
1424   include Monad.Make(Base)
1425   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
1426   let base_plus = plus
1427   let base_lift = lift
1428   module T(Wrapped : Monad.S) = struct
1429     module Trans = struct
1430       let zero () = Wrapped.unit None
1431       let plus u v =
1432         Wrapped.bind u (fun us ->
1433         Wrapped.bind v (fun vs ->
1434         Wrapped.unit (base_plus us vs)))
1435       include Monad.MakeT(struct
1436         module Wrapped = Wrapped
1437         type 'a m = 'a tree option Wrapped.m
1438         type 'a result = 'a tree option Wrapped.result
1439         type 'a result_exn = 'a tree Wrapped.result_exn
1440         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1441         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1442         let run u = Wrapped.run u
1443         let run_exn u =
1444             let w = Wrapped.bind u (fun t -> match t with
1445               | None -> failwith "no values"
1446               | Some ts -> Wrapped.unit ts)
1447             in Wrapped.run_exn w
1448       end)
1449     end
1450     include Trans
1451     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
1452     (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
1453     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1454   end
1455   module T2(Wrapped : Monad.S2) = struct
1456     module Trans = struct
1457       let zero () = Wrapped.unit None
1458       let plus u v =
1459         Wrapped.bind u (fun us ->
1460         Wrapped.bind v (fun vs ->
1461         Wrapped.unit (base_plus us vs)))
1462       include Monad.MakeT2(struct
1463         module Wrapped = Wrapped
1464         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1465         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1466         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1467         (* code repetition, ugh *)
1468         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1469         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1470         let run u = Wrapped.run u
1471         let run_exn u =
1472             let w = Wrapped.bind u (fun t -> match t with
1473               | None -> failwith "no values"
1474               | Some ts -> Wrapped.unit ts)
1475             in Wrapped.run_exn w
1476       end)
1477     end
1478     include Trans
1479     include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1480     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1481   end
1482 end
1483
1484
1485 module L = List_monad;;
1486 module R = Reader_monad(struct type env = int -> int end);;
1487 module S = State_monad(struct type store = int end);;
1488 module T = Leaf_monad;;
1489 module LR = L.T(R);;
1490 module LS = L.T(S);;
1491 module TL = T.T(L);;
1492 module TR = T.T(R);;
1493 module TS = T.T(S);;
1494 module C = Continuation_monad
1495 module TC = T.T2(C);;
1496
1497
1498 print_endline "=== test Leaf(...).distribute ==================";;
1499
1500 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1501
1502 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1503 TS.run ts 0;;
1504 (*
1505 - : int T.tree option * S.store =
1506 (Some
1507   (T.Node
1508     (T.Node (T.Leaf 2, T.Leaf 3),
1509      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1510  5)
1511 *)
1512
1513 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1514 TS.run_exn ts2 0;;
1515 (*
1516 - : (int * S.store) T.tree option * S.store =
1517 (Some
1518   (T.Node
1519     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1520      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1521  5)
1522 *)
1523
1524 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1525 TR.run_exn tr (fun i -> i+i);;
1526 (*
1527 - : int T.tree option =
1528 Some
1529  (T.Node
1530    (T.Node (T.Leaf 4, T.Leaf 6),
1531     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1532 *)
1533
1534 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1535 TL.run_exn tl;;
1536 (*
1537 - : (int * int) TL.result =
1538 [Some
1539   (T.Node
1540     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1541      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1542 *)
1543
1544 let l2 = [1;2;3;4;5];;
1545 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1546
1547 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1548 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1549
1550 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1551 (*
1552 int T.tree option =
1553 Some
1554  (T.Node
1555    (T.Node (T.Leaf 10, T.Leaf 11),
1556     T.Node
1557      (T.Node
1558        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1559         T.Node (T.Leaf 40, T.Leaf 41)),
1560       T.Node (T.Leaf 50, T.Leaf 51))))
1561  *)
1562
1563 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1564 (*
1565 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1566 *)
1567
1568 print_endline "=== test Leaf(Continuation).distribute ==================";;
1569
1570 let id : 'z. 'z -> 'z = fun x -> x
1571
1572 let example n : (int * int) =
1573   Continuation_monad.(let u = callcc (fun k ->
1574       (if n < 0 then k 0 else unit [n + 100])
1575       (* all of the following is skipped by k 0; the end type int is k's input type *)
1576       >>= fun [x] -> unit (x + 1)
1577   )
1578   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1579   >>= fun x -> unit (x, 0)
1580   in run0 u)
1581
1582
1583 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1584 let example1 () : int =
1585   Continuation_monad.(let v = reset (
1586       let u = shift (fun k -> unit (10 + 1))
1587       in u >>= fun x -> unit (100 + x)
1588     ) in let w = v >>= fun x -> unit (1000 + x)
1589     in run0 w)
1590
1591 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1592 let example2 () =
1593   Continuation_monad.(let v = reset (
1594       let u = shift (fun k -> k (10 :: [1]))
1595       in u >>= fun x -> unit (100 :: x)
1596     ) in let w = v >>= fun x -> unit (1000 :: x)
1597     in run0 w)
1598
1599 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1600 let example3 () =
1601   Continuation_monad.(let v = reset (
1602       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1603       in u >>= fun x -> unit (100 :: x)
1604     ) in let w = v >>= fun x -> unit (1000 :: x)
1605     in run0 w)
1606
1607 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1608 (* not sure if this example can be typed without a sum-type *)
1609
1610 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1611 let example5 () : int =
1612   Continuation_monad.(let v = reset (
1613       let u = shift (fun k -> k 1 >>= k)
1614       in u >>= fun x -> unit (10 + x)
1615     ) in let w = v >>= fun x -> unit (100 + x)
1616     in run0 w)
1617
1618 ;;
1619
1620 print_endline "=== test bare Continuation ============";;
1621
1622 (1011, 1111, 1111, 121);;
1623 (example1(), example2(), example3(), example5());;
1624 ((111,0), (0,0));;
1625 (example ~+10, example ~-10);;
1626
1627 let testc df ic =
1628     C.run_exn TC.(run (distribute df t1)) ic;;
1629
1630
1631 (*
1632 (* do nothing *)
1633 let initial_continuation = fun t -> t in
1634 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1635 *)
1636 testc (C.unit) id;;
1637
1638 (*
1639 (* count leaves, using continuation *)
1640 let initial_continuation = fun t -> 0 in
1641 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1642 *)
1643
1644 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1645
1646 (*
1647 (* convert tree to list of leaves *)
1648 let initial_continuation = fun t -> [] in
1649 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1650 *)
1651
1652 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1653
1654 (*
1655 (* square each leaf using continuation *)
1656 let initial_continuation = fun t -> t in
1657 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1658 *)
1659
1660 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1661
1662
1663 (*
1664 (* replace leaves with list, using continuation *)
1665 let initial_continuation = fun t -> t in
1666 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1667 *)
1668
1669 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1670
1671 print_endline "=== pa_monad's Continuation Tests ============";;
1672
1673 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1674 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1675 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1676 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1677 (5, 27 = C.(run0 (
1678               let c = reset(abort 5 >>= fun y -> unit (y+6))
1679               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1680
1681 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1682
1683 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1684
1685 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1686
1687
1688 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1689