tweak monads-lib
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  *)
46
47
48 (* Some library functions used below. *)
49 module Util = struct
50   let fold_right = List.fold_right
51   let map = List.map
52   let append = List.append
53   let reverse = List.rev
54   let concat = List.concat
55   let concat_map f lst = List.concat (List.map f lst)
56   (* let zip = List.combine *)
57   let unzip = List.split
58   let zip_with = List.map2
59   let replicate len fill =
60     let rec loop n accu =
61       if n == 0 then accu else loop (pred n) (fill :: accu)
62     in loop len []
63 end
64
65
66
67 (*
68  * This module contains factories that extend a base set of
69  * monadic definitions with a larger family of standard derived values.
70  *)
71
72 module Monad = struct
73   (*
74    * Signature extenders:
75    *   Make :: BASE -> S
76    *     MakeCatch, MakeDistrib :: PLUSBASE -> PLUS
77    *                                           which merges into S
78    *                                           (P is merged sig)
79    *   MakeT :: TRANS (with Wrapped : S or P) -> custom sig
80    *
81    *   Make2 :: BASE2 -> S2
82    *     MakeCatch2, MakeDistrib2 :: PLUSBASE2 -> PLUS2 (P2 is merged sig)
83    *   to wrap double-typed inner monads:
84    *   MakeT2 :: TRANS2 (with Wrapped : S2 or P2) -> custom sig
85    *
86    *)
87
88
89
90   (* type of base definitions *)
91   module type BASE = sig
92     (* The only constraints we impose here on how the monadic type
93      * is implemented is that it have a single type parameter 'a. *)
94     type 'a m
95     type 'a result
96     type 'a result_exn
97     val unit : 'a -> 'a m
98     val bind : 'a m -> ('a -> 'b m) -> 'b m
99     val run : 'a m -> 'a result
100     (* run_exn tries to provide a more ground-level result, but may fail *)
101     val run_exn : 'a m -> 'a result_exn
102   end
103   module type S = sig
104     include BASE
105     val (>>=) : 'a m -> ('a -> 'b m) -> 'b m
106     val (>>) : 'a m -> 'b m -> 'b m
107     val join : ('a m) m -> 'a m
108     val apply : ('a -> 'b) m -> 'a m -> 'b m
109     val lift : ('a -> 'b) -> 'a m -> 'b m
110     val lift2 :  ('a -> 'b -> 'c) -> 'a m -> 'b m -> 'c m
111     val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m
112     val do_when :  bool -> unit m -> unit m
113     val do_unless :  bool -> unit m -> unit m
114     val forever : (unit -> 'a m) -> 'b m
115     val sequence : 'a m list -> 'a list m
116     val sequence_ : 'a m list -> unit m
117   end
118
119   (* Standard, single-type-parameter monads. *)
120   module Make(B : BASE) : S with type 'a m = 'a B.m and type 'a result = 'a B.result and type 'a result_exn = 'a B.result_exn = struct
121     include B
122     let (>>=) = bind
123     let (>>) u v = u >>= fun _ -> v
124     let lift f u = u >>= fun a -> unit (f a)
125     (* lift is called listM, fmap, and <$> in Haskell *)
126     let join uu = uu >>= fun u -> u
127     (* u >>= f === join (lift f u) *)
128     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
129     (* [f] <*> [x1,x2] = [f x1,f x2] *)
130     (* let apply u v = u >>= fun f -> lift f v *)
131     (* let apply = lift2 id *)
132     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
133     (* let lift f u === apply (unit f) u *)
134     (* let lift2 f u v = apply (lift f u) v *)
135     let (>=>) f g = fun a -> f a >>= g
136     let do_when test u = if test then u else unit ()
137     let do_unless test u = if test then unit () else u
138     let forever uthunk =
139         let rec loop () = uthunk () >>= fun _ -> loop ()
140         in loop ()
141     let sequence ms =
142       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
143         Util.fold_right op ms (unit [])
144     let sequence_ ms =
145       Util.fold_right (>>) ms (unit ())
146
147     (* Haskell defines these other operations combining lists and monads.
148      * We don't, but notice that M.mapM == ListT(M).distribute
149      * There's also a parallel TreeT(M).distribute *)
150     (*
151     let mapM f alist = sequence (Util.map f alist)
152     let mapM_ f alist = sequence_ (Util.map f alist)
153     let rec filterM f lst = match lst with
154       | [] -> unit []
155       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
156     let forM alist f = mapM f alist
157     let forM_ alist f = mapM_ f alist
158     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
159     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
160     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
161     let rec foldM f z lst = match lst with
162       | [] -> unit z
163       | x::xs -> f z x >>= fun z' -> foldM f z' xs
164     let foldM_ f z xs = foldM f z xs >> unit ()
165     let replicateM n x = sequence (Util.replicate n x)
166     let replicateM_ n x = sequence_ (Util.replicate n x)
167     *)
168   end
169
170   (* Single-type-parameter monads that also define `plus` and `zero`
171    * operations. These obey the following laws:
172    *     zero >>= f   ===  zero
173    *     plus zero u  ===  u
174    *     plus u zero  ===  u
175    * Additionally, these monads will obey one of the following laws:
176    *     (Catch)   plus (unit a) v  ===  unit a
177    *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
178    *)
179   module type PLUSBASE = sig
180     include BASE
181     val zero : unit -> 'a m
182     val plus : 'a m -> 'a m -> 'a m
183   end
184   module type PLUS = sig
185     type 'a m
186     val zero : unit -> 'a m
187     val plus : 'a m -> 'a m -> 'a m
188     val guard : bool -> unit m
189     val sum : 'a m list -> 'a m
190   end
191   (* MakeCatch and MakeDistrib have the same implementation; we just declare
192    * them twice to document which laws the client code is promising to honor. *)
193   module MakeCatch(B : PLUSBASE) : PLUS with type 'a m = 'a B.m = struct
194     type 'a m = 'a B.m
195     let zero = B.zero
196     let plus = B.plus
197     let guard test = if test then B.unit () else zero ()
198     let sum ms = Util.fold_right plus ms (zero ())
199   end
200   module MakeDistrib = MakeCatch
201
202   (* Signatures for MonadT *)
203   (* sig for Wrapped that include S and PLUS *)
204   module type P = sig
205     include S
206     include PLUS with type 'a m := 'a m
207   end
208   module type TRANS = sig
209     module Wrapped : S
210     type 'a m
211     type 'a result
212     type 'a result_exn
213     val bind : 'a m -> ('a -> 'b m) -> 'b m
214     val run : 'a m -> 'a result
215     val run_exn : 'a m -> 'a result_exn
216     val elevate : 'a Wrapped.m -> 'a m
217     (* lift/elevate laws:
218      *     elevate (W.unit a) == unit a
219      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
220      *)
221   end
222   module MakeT(T : TRANS) = struct
223     include Make(struct
224         include T
225         let unit a = elevate (Wrapped.unit a)
226     end)
227     let elevate = T.elevate
228   end
229
230
231   (* We have to define BASE, S, and Make again for double-type-parameter monads. *)
232   module type BASE2 = sig
233     type ('x,'a) m
234     type ('x,'a) result
235     type ('x,'a) result_exn
236     val unit : 'a -> ('x,'a) m
237     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
238     val run : ('x,'a) m -> ('x,'a) result
239     val run_exn : ('x,'a) m -> ('x,'a) result_exn
240   end
241   module type S2 = sig
242     include BASE2
243     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
244     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
245     val join : ('x,('x,'a) m) m -> ('x,'a) m
246     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
247     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
248     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
249     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
250     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
251     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
252     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
253     val sequence : ('x,'a) m list -> ('x,'a list) m
254     val sequence_ : ('x,'a) m list -> ('x,unit) m
255   end
256   module Make2(B : BASE2) : S2 with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
257     (* code repetition, ugh *)
258     include B
259     let (>>=) = bind
260     let (>>) u v = u >>= fun _ -> v
261     let lift f u = u >>= fun a -> unit (f a)
262     let join uu = uu >>= fun u -> u
263     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
264     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
265     let (>=>) f g = fun a -> f a >>= g
266     let do_when test u = if test then u else unit ()
267     let do_unless test u = if test then unit () else u
268     let forever uthunk =
269         let rec loop () = uthunk () >>= fun _ -> loop ()
270         in loop ()
271     let sequence ms =
272       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
273         Util.fold_right op ms (unit [])
274     let sequence_ ms =
275       Util.fold_right (>>) ms (unit ())
276   end
277
278   module type PLUSBASE2 = sig
279     include BASE2
280     val zero : unit -> ('x,'a) m
281     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
282   end
283   module type PLUS2 = sig
284     type ('x,'a) m
285     val zero : unit -> ('x,'a) m
286     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
287     val guard : bool -> ('x,unit) m
288     val sum : ('x,'a) m list -> ('x,'a) m
289   end
290   module MakeCatch2(B : PLUSBASE2) : PLUS2 with type ('x,'a) m = ('x,'a) B.m = struct
291     type ('x,'a) m = ('x,'a) B.m
292     (* code repetition, ugh *)
293     let zero = B.zero
294     let plus = B.plus
295     let guard test = if test then B.unit () else zero ()
296     let sum ms = Util.fold_right plus ms (zero ())
297   end
298   module MakeDistrib2 = MakeCatch2
299
300   (* Signatures for MonadT *)
301   (* sig for Wrapped that include S and PLUS *)
302   module type P2 = sig
303     include S2
304     include PLUS2 with type ('x,'a) m := ('x,'a) m
305   end
306   module type TRANS2 = sig
307     module Wrapped : S2
308     type ('x,'a) m
309     type ('x,'a) result
310     type ('x,'a) result_exn
311     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
312     val run : ('x,'a) m -> ('x,'a) result
313     val run_exn : ('x,'a) m -> ('x,'a) result_exn
314     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
315   end
316   module MakeT2(T : TRANS2) = struct
317     (* code repetition, ugh *)
318     include Make2(struct
319         include T
320         let unit a = elevate (Wrapped.unit a)
321     end)
322     let elevate = T.elevate
323   end
324
325 end
326
327
328
329
330
331 module Identity_monad : sig
332   (* expose only the implementation of type `'a result` *)
333   type 'a result = 'a
334   type 'a result_exn = 'a
335   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
336 end = struct
337   module Base = struct
338     type 'a m = 'a
339     type 'a result = 'a
340     type 'a result_exn = 'a
341     let unit a = a
342     let bind a f = f a
343     let run a = a
344     let run_exn a = a
345   end
346   include Monad.Make(Base)
347 end
348
349
350 module Maybe_monad : sig
351   (* expose only the implementation of type `'a result` *)
352   type 'a result = 'a option
353   type 'a result_exn = 'a
354   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
355   include Monad.PLUS with type 'a m := 'a m
356   (* MaybeT transformer *)
357   module T : functor (Wrapped : Monad.S) -> sig
358     type 'a result = 'a option Wrapped.result
359     type 'a result_exn = 'a Wrapped.result_exn
360     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
361     include Monad.PLUS with type 'a m := 'a m
362     val elevate : 'a Wrapped.m -> 'a m
363   end
364   module T2 : functor (Wrapped : Monad.S2) -> sig
365     type ('x,'a) result = ('x,'a option) Wrapped.result
366     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
367     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
368     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
369     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
370   end
371 end = struct
372   module Base = struct
373    type 'a m = 'a option
374    type 'a result = 'a option
375    type 'a result_exn = 'a
376    let unit a = Some a
377    let bind u f = match u with Some a -> f a | None -> None
378    let run u = u
379    let run_exn u = match u with
380      | Some a -> a
381      | None -> failwith "no value"
382    let zero () = None
383    let plus u v = match u with None -> v | _ -> u
384   end
385   include Monad.Make(Base)
386   include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m)
387   module T(Wrapped : Monad.S) = struct
388     module Trans = struct
389       include Monad.MakeT(struct
390         module Wrapped = Wrapped
391         type 'a m = 'a option Wrapped.m
392         type 'a result = 'a option Wrapped.result
393         type 'a result_exn = 'a Wrapped.result_exn
394         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
395         let bind u f = Wrapped.bind u (fun t -> match t with
396           | Some a -> f a
397           | None -> Wrapped.unit None)
398         let run u = Wrapped.run u
399         let run_exn u =
400           let w = Wrapped.bind u (fun t -> match t with
401             | Some a -> Wrapped.unit a
402             | None -> failwith "no value")
403           in Wrapped.run_exn w
404       end)
405       let zero () = Wrapped.unit None
406       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
407     end
408     include Trans
409     include (Monad.MakeCatch(Trans) : Monad.PLUS with type 'a m := 'a m)
410   end
411   module T2(Wrapped : Monad.S2) = struct
412     module Trans = struct
413       include Monad.MakeT2(struct
414         module Wrapped = Wrapped
415         type ('x,'a) m = ('x,'a option) Wrapped.m
416         type ('x,'a) result = ('x,'a option) Wrapped.result
417         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
418         (* code repetition, ugh *)
419         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
420         let bind u f = Wrapped.bind u (fun t -> match t with
421           | Some a -> f a
422           | None -> Wrapped.unit None)
423         let run u = Wrapped.run u
424         let run_exn u =
425           let w = Wrapped.bind u (fun t -> match t with
426             | Some a -> Wrapped.unit a
427             | None -> failwith "no value")
428           in Wrapped.run_exn w
429       end)
430       let zero () = Wrapped.unit None
431       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
432     end
433     include Trans
434     include (Monad.MakeCatch2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
435   end
436 end
437
438
439 module List_monad : sig
440   (* declare additional operation, while still hiding implementation of type m *)
441   type 'a result = 'a list
442   type 'a result_exn = 'a
443   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
444   include Monad.PLUS with type 'a m := 'a m
445   val permute : 'a m -> 'a m m
446   val select : 'a m -> ('a * 'a m) m
447   (* ListT transformer *)
448   module T : functor (Wrapped : Monad.S) -> sig
449     type 'a result = 'a list Wrapped.result
450     type 'a result_exn = 'a Wrapped.result_exn
451     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
452     include Monad.PLUS with type 'a m := 'a m
453     val elevate : 'a Wrapped.m -> 'a m
454     (* note that second argument is an 'a list, not the more abstract 'a m *)
455     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
456     val distribute : ('a -> 'b Wrapped.m) -> 'a list -> 'b m
457 (* TODO
458     val permute : 'a m -> 'a m m
459     val select : 'a m -> ('a * 'a m) m
460 *)
461   end
462   module T2 : functor (Wrapped : Monad.S2) -> sig
463     type ('x,'a) result = ('x,'a list) Wrapped.result
464     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
465     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
466     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
467     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
468     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
469   end
470 end = struct
471   module Base = struct
472    type 'a m = 'a list
473    type 'a result = 'a list
474    type 'a result_exn = 'a
475    let unit a = [a]
476    let bind u f = Util.concat_map f u
477    let run u = u
478    let run_exn u = match u with
479      | [] -> failwith "no values"
480      | [a] -> a
481      | many -> failwith "multiple values"
482    let zero () = []
483    let plus = Util.append
484   end
485   include Monad.Make(Base)
486   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
487   (* let either u v = plus u v *)
488   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
489   let rec insert a u =
490     plus (unit (a :: u)) (match u with
491         | [] -> zero ()
492         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
493     )
494   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
495   let rec permute u = match u with
496       | [] -> unit []
497       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
498   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
499   let rec select u = match u with
500     | [] -> zero ()
501     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
502   let base_plus = plus
503   module T(Wrapped : Monad.S) = struct
504     module Trans = struct
505       (* Wrapped.sequence ms  ===  
506            let plus1 u v =
507              Wrapped.bind u (fun x ->
508              Wrapped.bind v (fun xs ->
509              Wrapped.unit (x :: xs)))
510            in Util.fold_right plus1 ms (Wrapped.unit []) *)
511       (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
512       let distribute f alist = Wrapped.sequence (Util.map f alist)
513       include Monad.MakeT(struct
514         module Wrapped = Wrapped
515         type 'a m = 'a list Wrapped.m
516         type 'a result = 'a list Wrapped.result
517         type 'a result_exn = 'a Wrapped.result_exn
518         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
519         let bind u f =
520           Wrapped.bind u (fun ts ->
521           Wrapped.bind (distribute f ts) (fun tts ->
522           Wrapped.unit (Util.concat tts)))
523         let run u = Wrapped.run u
524         let run_exn u =
525           let w = Wrapped.bind u (fun ts -> match ts with
526             | [] -> failwith "no values"
527             | [a] -> Wrapped.unit a
528             | many -> failwith "multiple values"
529           ) in Wrapped.run_exn w
530       end)
531       let zero () = Wrapped.unit []
532       let plus u v =
533         Wrapped.bind u (fun us ->
534         Wrapped.bind v (fun vs ->
535         Wrapped.unit (base_plus us vs)))
536     end
537     include Trans
538     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
539 (*
540     let permute : 'a m -> 'a m m
541     let select : 'a m -> ('a * 'a m) m
542 *)
543   end
544   module T2(Wrapped : Monad.S2) = struct
545     module Trans = struct
546       let distribute f alist = Wrapped.sequence (Util.map f alist)
547       include Monad.MakeT2(struct
548         module Wrapped = Wrapped
549         type ('x,'a) m = ('x,'a list) Wrapped.m
550         type ('x,'a) result = ('x,'a list) Wrapped.result
551         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
552         (* code repetition, ugh *)
553         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
554         let bind u f =
555           Wrapped.bind u (fun ts ->
556           Wrapped.bind (distribute f ts) (fun tts ->
557           Wrapped.unit (Util.concat tts)))
558         let run u = Wrapped.run u
559         let run_exn u =
560           let w = Wrapped.bind u (fun ts -> match ts with
561             | [] -> failwith "no values"
562             | [a] -> Wrapped.unit a
563             | many -> failwith "multiple values"
564           ) in Wrapped.run_exn w
565       end)
566       let zero () = Wrapped.unit []
567       let plus u v =
568         Wrapped.bind u (fun us ->
569         Wrapped.bind v (fun vs ->
570         Wrapped.unit (base_plus us vs)))
571     end
572     include Trans
573     include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
574   end
575 end
576
577
578 (* must be parameterized on (struct type err = ... end) *)
579 module Error_monad(Err : sig
580   type err
581   exception Exc of err
582   (*
583   val zero : unit -> err
584   val plus : err -> err -> err
585   *)
586 end) : sig
587   (* declare additional operations, while still hiding implementation of type m *)
588   type err = Err.err
589   type 'a error = Error of err | Success of 'a
590   type 'a result = 'a error
591   type 'a result_exn = 'a
592   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
593   (* include Monad.PLUS with type 'a m := 'a m *)
594   val throw : err -> 'a m
595   val catch : 'a m -> (err -> 'a m) -> 'a m
596   (* ErrorT transformer *)
597   module T : functor (Wrapped : Monad.S) -> sig
598     type 'a result = 'a error Wrapped.result
599     type 'a result_exn = 'a Wrapped.result_exn
600     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
601     val elevate : 'a Wrapped.m -> 'a m
602     val throw : err -> 'a m
603     val catch : 'a m -> (err -> 'a m) -> 'a m
604   end
605   (* ErrorT transformer when wrapped monad has plus, zero *)
606   module TP : functor (Wrapped : Monad.P) -> sig
607     type 'a result = 'a Wrapped.result
608     type 'a result_exn = 'a Wrapped.result_exn
609     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
610     val elevate : 'a Wrapped.m -> 'a m
611     val throw : err -> 'a m
612     val catch : 'a m -> (err -> 'a m) -> 'a m
613     include Monad.PLUS with type 'a m := 'a m
614   end
615   module T2 : functor (Wrapped : Monad.S2) -> sig
616     type ('x,'a) result = ('x,'a error) Wrapped.result
617     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
618     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
619     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
620     val throw : err -> ('x,'a) m
621     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
622   end
623   module TP2 : functor (Wrapped : Monad.P2) -> sig
624     type ('x,'a) result = ('x,'a) Wrapped.result
625     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
626     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
627     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
628     val throw : err -> ('x,'a) m
629     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
630     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
631   end
632 end = struct
633   type err = Err.err
634   type 'a error = Error of err | Success of 'a
635   module Base = struct
636     type 'a m = 'a error
637     type 'a result = 'a error
638     type 'a result_exn = 'a
639     let unit a = Success a
640     let bind u f = match u with
641       | Success a -> f a
642       | Error e -> Error e (* input and output may be of different 'a types *)
643     let run u = u
644     let run_exn u = match u with
645       | Success a -> a
646       | Error e -> raise (Err.Exc e)
647     (*
648     let zero () = Error Err.zero
649     let plus u v = match (u, v) with
650       | Success _, _ -> u
651       (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
652        * otherwise, plus (Error _) v = v *)
653       | Error _, _ when v = zero -> u
654       (* combine errors *)
655       | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
656       | Error _, _ -> v
657     *)
658   end
659   include Monad.Make(Base)
660   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
661   let throw e = Error e
662   let catch u handler = match u with
663     | Success _ -> u
664     | Error e -> handler e
665   module T(Wrapped : Monad.S) = struct
666     module Trans = struct
667       module Wrapped = Wrapped
668       type 'a m = 'a error Wrapped.m
669       type 'a result = 'a error Wrapped.result
670       type 'a result_exn = 'a Wrapped.result_exn
671       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
672       let bind u f = Wrapped.bind u (fun t -> match t with
673         | Success a -> f a
674         | Error e -> Wrapped.unit (Error e))
675       let run u = Wrapped.run u
676       let run_exn u =
677         let w = Wrapped.bind u (fun t -> match t with
678           | Success a -> Wrapped.unit a
679           | Error e -> raise (Err.Exc e))
680         in Wrapped.run_exn w
681     end
682     include Monad.MakeT(Trans)
683     let throw e = Wrapped.unit (Error e)
684     let catch u handler = Wrapped.bind u (fun t -> match t with
685       | Success _ -> Wrapped.unit t
686       | Error e -> handler e)
687   end
688   module TP(Wrapped : Monad.P) = struct
689     (* code repetition, ugh *)
690     module TransP = struct
691       include Monad.MakeT(struct
692         module Wrapped = Wrapped
693         type 'a m = 'a error Wrapped.m
694         type 'a result = 'a Wrapped.result
695         type 'a result_exn = 'a Wrapped.result_exn
696         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
697         let bind u f = Wrapped.bind u (fun t -> match t with
698           | Success a -> f a
699           | Error e -> Wrapped.unit (Error e))
700         let run u =
701           let w = Wrapped.bind u (fun t -> match t with
702             | Success a -> Wrapped.unit a
703             | Error e -> Wrapped.zero ())
704           in Wrapped.run w
705         let run_exn u =
706           let w = Wrapped.bind u (fun t -> match t with
707             | Success a -> Wrapped.unit a
708             | Error e -> raise (Err.Exc e))
709           in Wrapped.run_exn w
710       end)
711       let throw e = Wrapped.unit (Error e)
712       let catch u handler = Wrapped.bind u (fun t -> match t with
713         | Success _ -> Wrapped.unit t
714         | Error e -> handler e)
715       let plus u v = Wrapped.plus u v
716       let zero () = elevate (Wrapped.zero ())
717     end
718     include TransP
719     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
720   end
721   module T2(Wrapped : Monad.S2) = struct
722     module Trans = struct
723       module Wrapped = Wrapped
724       type ('x,'a) m = ('x,'a error) Wrapped.m
725       type ('x,'a) result = ('x,'a error) Wrapped.result
726       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
727       (* code repetition, ugh *)
728       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
729       let bind u f = Wrapped.bind u (fun t -> match t with
730         | Success a -> f a
731         | Error e -> Wrapped.unit (Error e))
732       let run u = Wrapped.run u
733       let run_exn u =
734         let w = Wrapped.bind u (fun t -> match t with
735           | Success a -> Wrapped.unit a
736           | Error e -> raise (Err.Exc e))
737         in Wrapped.run_exn w
738     end
739     include Monad.MakeT2(Trans)
740     let throw e = Wrapped.unit (Error e)
741     let catch u handler = Wrapped.bind u (fun t -> match t with
742       | Success _ -> Wrapped.unit t
743       | Error e -> handler e)
744   end
745   module TP2(Wrapped : Monad.P2) = struct
746     (* code repetition, ugh *)
747     module TransP = struct
748       include Monad.MakeT2(struct
749         module Wrapped = Wrapped
750         type ('x,'a) m = ('x,'a error) Wrapped.m
751         type ('x,'a) result = ('x,'a) Wrapped.result
752         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
753         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
754         let bind u f = Wrapped.bind u (fun t -> match t with
755           | Success a -> f a
756           | Error e -> Wrapped.unit (Error e))
757         let run u =
758           let w = Wrapped.bind u (fun t -> match t with
759             | Success a -> Wrapped.unit a
760             | Error e -> Wrapped.zero ())
761           in Wrapped.run w
762         let run_exn u =
763           let w = Wrapped.bind u (fun t -> match t with
764             | Success a -> Wrapped.unit a
765             | Error e -> raise (Err.Exc e))
766           in Wrapped.run_exn w
767       end)
768       let throw e = Wrapped.unit (Error e)
769       let catch u handler = Wrapped.bind u (fun t -> match t with
770         | Success _ -> Wrapped.unit t
771         | Error e -> handler e)
772       let plus u v = Wrapped.plus u v
773       let zero () = elevate (Wrapped.zero ())
774     end
775     include TransP
776     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
777   end
778 end
779
780 (* pre-define common instance of Error_monad *)
781 module Failure = Error_monad(struct
782   type err = string
783   exception Exc = Failure
784   (*
785   let zero = ""
786   let plus s1 s2 = s1 ^ "\n" ^ s2
787   *)
788 end)
789
790 (*
791 # EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
792 - : int EL.result = [Failure.Error "bye"; Failure.Success 30]
793 # LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
794 - : int LE.result = Failure.Error "bye"
795 # EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
796 Exception: Failure "bye".
797 # LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
798 Exception: Failure "bye".
799
800 # ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
801 - : int Failure.error * S.store = (Failure.Error "bye", 1)
802 # SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
803 - : (int * S.store) Failure.result = Failure.Error "bye"
804 # ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
805 Exception: Failure "bye".
806 # SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
807 Exception: Failure "bye".
808  *)
809
810
811 (* must be parameterized on (struct type env = ... end) *)
812 module Reader_monad(Env : sig type env end) : sig
813   (* declare additional operations, while still hiding implementation of type m *)
814   type env = Env.env
815   type 'a result = env -> 'a
816   type 'a result_exn = env -> 'a
817   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
818   val ask : env m
819   val asks : (env -> 'a) -> 'a m
820   val local : (env -> env) -> 'a m -> 'a m
821   (* ReaderT transformer *)
822   module T : functor (Wrapped : Monad.S) -> sig
823     type 'a result = env -> 'a Wrapped.result
824     type 'a result_exn = env -> 'a Wrapped.result_exn
825     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
826     val elevate : 'a Wrapped.m -> 'a m
827     val ask : env m
828     val asks : (env -> 'a) -> 'a m
829     val local : (env -> env) -> 'a m -> 'a m
830   end
831   (* ReaderT transformer when wrapped monad has plus, zero *)
832   module TP : functor (Wrapped : Monad.P) -> sig
833     include module type of T(Wrapped)
834     include Monad.PLUS with type 'a m := 'a m
835   end
836   module T2 : functor (Wrapped : Monad.S2) -> sig
837     type ('x,'a) result = env -> ('x,'a) Wrapped.result
838     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
839     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
840     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
841     val ask : ('x,env) m
842     val asks : (env -> 'a) -> ('x,'a) m
843     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
844   end
845   module TP2 : functor (Wrapped : Monad.P2) -> sig
846     include module type of T2(Wrapped)
847     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
848   end
849 end = struct
850   type env = Env.env
851   module Base = struct
852     type 'a m = env -> 'a
853     type 'a result = env -> 'a
854     type 'a result_exn = env -> 'a
855     let unit a = fun e -> a
856     let bind u f = fun e -> let a = u e in let u' = f a in u' e
857     let run u = fun e -> u e
858     let run_exn = run
859   end
860   include Monad.Make(Base)
861   let ask = fun e -> e
862   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
863   let local modifier u = fun e -> u (modifier e)
864   module T(Wrapped : Monad.S) = struct
865     module Trans = struct
866       module Wrapped = Wrapped
867       type 'a m = env -> 'a Wrapped.m
868       type 'a result = env -> 'a Wrapped.result
869       type 'a result_exn = env -> 'a Wrapped.result_exn
870       let elevate w = fun e -> w
871       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
872       let run u = fun e -> Wrapped.run (u e)
873       let run_exn u = fun e -> Wrapped.run_exn (u e)
874     end
875     include Monad.MakeT(Trans)
876     let ask = fun e -> Wrapped.unit e
877     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
878     let local modifier u = fun e -> u (modifier e)
879   end
880   module TP(Wrapped : Monad.P) = struct
881     module TransP = struct
882       include T(Wrapped)
883       let plus u v = fun s -> Wrapped.plus (u s) (v s)
884       let zero () = elevate (Wrapped.zero ())
885       let asks selector = ask >>= (fun e ->
886         try unit (selector e)
887         with Not_found -> fun e -> Wrapped.zero ())
888     end
889     include TransP
890     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
891   end
892   module T2(Wrapped : Monad.S2) = struct
893     module Trans = struct
894       module Wrapped = Wrapped
895       type ('x,'a) m = env -> ('x,'a) Wrapped.m
896       type ('x,'a) result = env -> ('x,'a) Wrapped.result
897       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
898       (* code repetition, ugh *)
899       let elevate w = fun e -> w
900       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
901       let run u = fun e -> Wrapped.run (u e)
902       let run_exn u = fun e -> Wrapped.run_exn (u e)
903     end
904     include Monad.MakeT2(Trans)
905     let ask = fun e -> Wrapped.unit e
906     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
907     let local modifier u = fun e -> u (modifier e)
908   end
909   module TP2(Wrapped : Monad.P2) = struct
910     module TransP = struct
911       include T2(Wrapped)
912       (* code repetition, ugh *)
913       let plus u v = fun s -> Wrapped.plus (u s) (v s)
914       let zero () = elevate (Wrapped.zero ())
915       let asks selector = ask >>= (fun e ->
916         try unit (selector e)
917         with Not_found -> fun e -> Wrapped.zero ())
918     end
919     include TransP
920     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
921   end
922 end
923
924
925 (* must be parameterized on (struct type store = ... end) *)
926 module State_monad(Store : sig type store end) : sig
927   (* declare additional operations, while still hiding implementation of type m *)
928   type store = Store.store
929   type 'a result = store -> 'a * store
930   type 'a result_exn = store -> 'a
931   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
932   val get : store m
933   val gets : (store -> 'a) -> 'a m
934   val put : store -> unit m
935   val puts : (store -> store) -> unit m
936   (* StateT transformer *)
937   module T : functor (Wrapped : Monad.S) -> sig
938     type 'a result = store -> ('a * store) Wrapped.result
939     type 'a result_exn = store -> 'a Wrapped.result_exn
940     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
941     val elevate : 'a Wrapped.m -> 'a m
942     val get : store m
943     val gets : (store -> 'a) -> 'a m
944     val put : store -> unit m
945     val puts : (store -> store) -> unit m
946   end
947   (* StateT transformer when wrapped monad has plus, zero *)
948   module TP : functor (Wrapped : Monad.P) -> sig
949     include module type of T(Wrapped)
950     include Monad.PLUS with type 'a m := 'a m
951   end
952   module T2 : functor (Wrapped : Monad.S2) -> sig
953     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
954     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
955     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
956     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
957     val get : ('x,store) m
958     val gets : (store -> 'a) -> ('x,'a) m
959     val put : store -> ('x,unit) m
960     val puts : (store -> store) -> ('x,unit) m
961   end
962   module TP2 : functor (Wrapped : Monad.P2) -> sig
963     include module type of T2(Wrapped)
964     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
965   end
966 end = struct
967   type store = Store.store
968   module Base = struct
969     type 'a m = store -> 'a * store
970     type 'a result = store -> 'a * store
971     type 'a result_exn = store -> 'a
972     let unit a = fun s -> (a, s)
973     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
974     let run u = fun s -> (u s)
975     let run_exn u = fun s -> fst (u s)
976   end
977   include Monad.Make(Base)
978   let get = fun s -> (s, s)
979   let gets viewer = fun s -> (viewer s, s) (* may fail *)
980   let put s = fun _ -> ((), s)
981   let puts modifier = fun s -> ((), modifier s)
982   module T(Wrapped : Monad.S) = struct
983     module Trans = struct
984       module Wrapped = Wrapped
985       type 'a m = store -> ('a * store) Wrapped.m
986       type 'a result = store -> ('a * store) Wrapped.result
987       type 'a result_exn = store -> 'a Wrapped.result_exn
988       let elevate w = fun s ->
989         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
990       let bind u f = fun s ->
991         Wrapped.bind (u s) (fun (a, s') -> f a s')
992       let run u = fun s -> Wrapped.run (u s)
993       let run_exn u = fun s ->
994         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
995         in Wrapped.run_exn w
996     end
997     include Monad.MakeT(Trans)
998     let get = fun s -> Wrapped.unit (s, s)
999     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
1000     let put s = fun _ -> Wrapped.unit ((), s)
1001     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
1002   end
1003   module TP(Wrapped : Monad.P) = struct
1004     module TransP = struct
1005       include T(Wrapped)
1006       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1007       let zero () = elevate (Wrapped.zero ())
1008     end
1009     let gets viewer = fun s ->
1010       try Wrapped.unit (viewer s, s)
1011       with Not_found -> Wrapped.zero ()
1012     include TransP
1013     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
1014   end
1015   module T2(Wrapped : Monad.S2) = struct
1016     module Trans = struct
1017       module Wrapped = Wrapped
1018       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
1019       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
1020       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
1021       (* code repetition, ugh *)
1022       let elevate w = fun s ->
1023         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1024       let bind u f = fun s ->
1025         Wrapped.bind (u s) (fun (a, s') -> f a s')
1026       let run u = fun s -> Wrapped.run (u s)
1027       let run_exn u = fun s ->
1028         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
1029         in Wrapped.run_exn w
1030     end
1031     include Monad.MakeT2(Trans)
1032     let get = fun s -> Wrapped.unit (s, s)
1033     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
1034     let put s = fun _ -> Wrapped.unit ((), s)
1035     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
1036   end
1037   module TP2(Wrapped : Monad.P2) = struct
1038     module TransP = struct
1039       include T2(Wrapped)
1040       (* code repetition, ugh *)
1041       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1042       let zero () = elevate (Wrapped.zero ())
1043     end
1044     let gets viewer = fun s ->
1045       try Wrapped.unit (viewer s, s)
1046       with Not_found -> Wrapped.zero ()
1047     include TransP
1048     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1049   end
1050 end
1051
1052 (* State monad with different interface (structured store) *)
1053 module Ref_monad(V : sig
1054   type value
1055 end) : sig
1056   type ref
1057   type value = V.value
1058   type 'a result = 'a
1059   type 'a result_exn = 'a
1060   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1061   val newref : value -> ref m
1062   val deref : ref -> value m
1063   val change : ref -> value -> unit m
1064   (* RefT transformer *)
1065   module T : functor (Wrapped : Monad.S) -> sig
1066     type 'a result = 'a Wrapped.result
1067     type 'a result_exn = 'a Wrapped.result_exn
1068     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1069     val elevate : 'a Wrapped.m -> 'a m
1070     val newref : value -> ref m
1071     val deref : ref -> value m
1072     val change : ref -> value -> unit m
1073   end
1074   (* RefT transformer when wrapped monad has plus, zero *)
1075   module TP : functor (Wrapped : Monad.P) -> sig
1076     include module type of T(Wrapped)
1077     include Monad.PLUS with type 'a m := 'a m
1078   end
1079   module T2 : functor (Wrapped : Monad.S2) -> sig
1080     type ('x,'a) result = ('x,'a) Wrapped.result
1081     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
1082     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1083     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1084     val newref : value -> ('x,ref) m
1085     val deref : ref -> ('x,value) m
1086     val change : ref -> value -> ('x,unit) m
1087   end
1088   module TP2 : functor (Wrapped : Monad.P2) -> sig
1089     include module type of T2(Wrapped)
1090     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
1091   end
1092 end = struct
1093   type ref = int
1094   type value = V.value
1095   module D = Map.Make(struct type t = ref let compare = compare end)
1096   type dict = { next: ref; tree : value D.t }
1097   let empty = { next = 0; tree = D.empty }
1098   let alloc (value : value) (d : dict) =
1099     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
1100   let read (key : ref) (d : dict) =
1101     D.find key d.tree
1102   let write (key : ref) (value : value) (d : dict) =
1103     { next = d.next; tree = D.add key value d.tree }
1104   module Base = struct
1105     type 'a m = dict -> 'a * dict
1106     type 'a result = 'a
1107     type 'a result_exn = 'a
1108     let unit a = fun s -> (a, s)
1109     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
1110     let run u = fst (u empty)
1111     let run_exn = run
1112   end
1113   include Monad.Make(Base)
1114   let newref value = fun s -> alloc value s
1115   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
1116   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
1117   module T(Wrapped : Monad.S) = struct
1118     module Trans = struct
1119       module Wrapped = Wrapped
1120       type 'a m = dict -> ('a * dict) Wrapped.m
1121       type 'a result = 'a Wrapped.result
1122       type 'a result_exn = 'a Wrapped.result_exn
1123       let elevate w = fun s ->
1124         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1125       let bind u f = fun s ->
1126         Wrapped.bind (u s) (fun (a, s') -> f a s')
1127       let run u =
1128         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1129         in Wrapped.run w
1130       let run_exn u =
1131         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1132         in Wrapped.run_exn w
1133     end
1134     include Monad.MakeT(Trans)
1135     let newref value = fun s -> Wrapped.unit (alloc value s)
1136     let deref key = fun s -> Wrapped.unit (read key s, s)
1137     let change key value = fun s -> Wrapped.unit ((), write key value s)
1138   end
1139   module TP(Wrapped : Monad.P) = struct
1140     module TransP = struct
1141       include T(Wrapped)
1142       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1143       let zero () = elevate (Wrapped.zero ())
1144     end
1145     include TransP
1146     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
1147   end
1148   module T2(Wrapped : Monad.S2) = struct
1149     module Trans = struct
1150       module Wrapped = Wrapped
1151       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
1152       type ('x,'a) result = ('x,'a) Wrapped.result
1153       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
1154       (* code repetition, ugh *)
1155       let elevate w = fun s ->
1156         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1157       let bind u f = fun s ->
1158         Wrapped.bind (u s) (fun (a, s') -> f a s')
1159       let run u =
1160         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1161         in Wrapped.run w
1162       let run_exn u =
1163         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1164         in Wrapped.run_exn w
1165     end
1166     include Monad.MakeT2(Trans)
1167     let newref value = fun s -> Wrapped.unit (alloc value s)
1168     let deref key = fun s -> Wrapped.unit (read key s, s)
1169     let change key value = fun s -> Wrapped.unit ((), write key value s)
1170   end
1171   module TP2(Wrapped : Monad.P2) = struct
1172     module TransP = struct
1173       include T2(Wrapped)
1174       (* code repetition, ugh *)
1175       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1176       let zero () = elevate (Wrapped.zero ())
1177     end
1178     include TransP
1179     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1180   end
1181 end
1182
1183
1184 (* must be parameterized on (struct type log = ... end) *)
1185 module Writer_monad(Log : sig
1186   type log
1187   val zero : log
1188   val plus : log -> log -> log
1189 end) : sig
1190   (* declare additional operations, while still hiding implementation of type m *)
1191   type log = Log.log
1192   type 'a result = 'a * log
1193   type 'a result_exn = 'a * log
1194   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1195   val tell : log -> unit m
1196   val listen : 'a m -> ('a * log) m
1197   val listens : (log -> 'b) -> 'a m -> ('a * 'b) m
1198   (* val pass : ('a * (log -> log)) m -> 'a m *)
1199   val censor : (log -> log) -> 'a m -> 'a m
1200 end = struct
1201   type log = Log.log
1202   module Base = struct
1203     type 'a m = 'a * log
1204     type 'a result = 'a * log
1205     type 'a result_exn = 'a * log
1206     let unit a = (a, Log.zero)
1207     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
1208     let run u = u
1209     let run_exn = run
1210   end
1211   include Monad.Make(Base)
1212   let tell entries = ((), entries) (* add entries to log *)
1213   let listen (a, w) = ((a, w), w)
1214   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
1215   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
1216   let censor f u = pass (u >>= fun a -> unit (a, f))
1217 end
1218
1219 (* pre-define simple Writer *)
1220 module Writer1 = Writer_monad(struct
1221   type log = string
1222   let zero = ""
1223   let plus s1 s2 = s1 ^ "\n" ^ s2
1224 end)
1225
1226 (* slightly more efficient Writer *)
1227 module Writer2 = struct
1228   include Writer_monad(struct
1229     type log = string list
1230     let zero = []
1231     let plus w w' = Util.append w' w
1232   end)
1233   let tell_string s = tell [s]
1234   let tell entries = tell (Util.reverse entries)
1235   let run u = let (a, w) = run u in (a, Util.reverse w)
1236   let run_exn = run
1237 end
1238
1239
1240 module IO_monad : sig
1241   (* declare additional operation, while still hiding implementation of type m *)
1242   type 'a result = 'a
1243   type 'a result_exn = 'a
1244   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1245   val printf : ('a, unit, string, unit m) format4 -> 'a
1246   val print_string : string -> unit m
1247   val print_int : int -> unit m
1248   val print_hex : int -> unit m
1249   val print_bool : bool -> unit m
1250 end = struct
1251   module Base = struct
1252     type 'a m = { run : unit -> unit; value : 'a }
1253     type 'a result = 'a
1254     type 'a result_exn = 'a
1255     let unit a = { run = (fun () -> ()); value = a }
1256     let bind (a : 'a m) (f: 'a -> 'b m) : 'b m =
1257      let fres = f a.value in
1258        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
1259     let run a = let () = a.run () in a.value
1260     let run_exn = run
1261   end
1262   include Monad.Make(Base)
1263   let printf fmt =
1264     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
1265   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
1266   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
1267   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
1268   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
1269 end
1270
1271 (*
1272 module Continuation_monad : sig
1273   (* expose only the implementation of type `('r,'a) result` *)
1274   type 'a m
1275   type 'a result = 'a m
1276   type 'a result_exn = 'a m
1277   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn and type 'a m := 'a m
1278   (* val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m *)
1279   (* misses that the answer types of all the cont's must be the same *)
1280   val callcc : (('a -> 'b m) -> 'a m) -> 'a m
1281   (* val reset : ('a,'a) m -> ('r,'a) m *)
1282   val reset : 'a m -> 'a m
1283   (* val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m *)
1284   (* misses that the answer types of second and third continuations must be b *)
1285   val shift : (('a -> 'b m) -> 'b m) -> 'a m
1286   (* overwrite the run declaration in S, because I can't declare 'a result =
1287    * this polymorphic type (complains that 'r is unbound *)
1288   val runk : 'a m -> ('a -> 'r) -> 'r
1289   val run0 : 'a m -> 'a
1290 end = struct
1291   let id = fun i -> i
1292   module Base = struct
1293     (* 'r is result type of whole computation *)
1294     type 'a m = { cont : 'r. ('a -> 'r) -> 'r }
1295     type 'a result = 'a m
1296     type 'a result_exn = 'a m
1297     let unit a =
1298       let cont : 'r. ('a -> 'r) -> 'r =
1299         fun k -> k a
1300       in { cont }
1301     let bind u f =
1302       let cont : 'r. ('a -> 'r) -> 'r =
1303         fun k -> u.cont (fun a -> (f a).cont k)
1304       in { cont }
1305     let run (u : 'a m) : 'a result = u
1306     let run_exn (u : 'a m) : 'a result_exn = u
1307     let callcc f =
1308       let cont : 'r. ('a -> 'r) -> 'r =
1309           (* Can't figure out how to make the type polymorphic enough
1310            * to satisfy the OCaml type-checker (it's ('a -> 'r) -> 'r
1311            * instead of 'r. ('a -> 'r) -> 'r); so we have to fudge
1312            * with Obj.magic... which tells OCaml's type checker to
1313            * relax, the supplied value has whatever type the context
1314            * needs it to have. *)
1315           fun k ->
1316           let usek a = { cont = Obj.magic (fun _ -> k a) }
1317           in (f usek).cont k
1318       in { cont }
1319     let reset u = unit (u.cont id)
1320     let shift (f : ('a -> 'b m) -> 'b m) : 'a m =
1321       let cont = fun k ->
1322         (f (fun a -> unit (k a))).cont id
1323       in { cont = Obj.magic cont }
1324     let runk u k = (u.cont : ('a -> 'r) -> 'r) k
1325     let run0 u = runk u id
1326   end
1327   include Monad.Make(Base)
1328   let callcc = Base.callcc
1329   let reset = Base.reset
1330   let shift = Base.shift
1331   let runk = Base.runk
1332   let run0 = Base.run0
1333 end
1334  *)
1335
1336 (* This two-type parameter version works without Obj.magic *)
1337 module Continuation_monad : sig
1338   (* expose only the implementation of type `('r,'a) result` *)
1339   type ('r,'a) m
1340   type ('r,'a) result = ('r,'a) m
1341   type ('r,'a) result_exn = ('a -> 'r) -> 'r
1342   include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
1343   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
1344   val reset : ('a,'a) m -> ('r,'a) m
1345   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
1346   (* val abort : ('a,'a) m -> ('a,'b) m *)
1347   val abort : 'a -> ('a,'b) m
1348   val run0 : ('a,'a) m -> 'a
1349 end = struct
1350   let id = fun i -> i
1351   module Base = struct
1352     (* 'r is result type of whole computation *)
1353     type ('r,'a) m = ('a -> 'r) -> 'r
1354     type ('r,'a) result = ('a -> 'r) -> 'r
1355     type ('r,'a) result_exn = ('r,'a) result
1356     let unit a = (fun k -> k a)
1357     let bind u f = (fun k -> (u) (fun a -> (f a) k))
1358     let run u k = (u) k
1359     let run_exn = run
1360   end
1361   include Monad.Make2(Base)
1362   let callcc f = (fun k ->
1363     let usek a = (fun _ -> k a)
1364     in (f usek) k)
1365   (*
1366   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
1367   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
1368   let callcc f = fun k -> f k k
1369   let throw k a = fun _ -> k a
1370   *)
1371
1372   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
1373    *
1374    *  reset :: (Monad m) => ContT a m a -> ContT r m a
1375    *  reset e = ContT $ \k -> runContT e return >>= k
1376    *
1377    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
1378    *  shift e = ContT $ \k ->
1379    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
1380   let reset u = unit ((u) id)
1381   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
1382   (* let abort a = shift (fun _ -> a) *)
1383   let abort a = shift (fun _ -> unit a)
1384   let run0 (u : ('a,'a) m) = (u) id
1385 end
1386
1387
1388 (*
1389  * Scheme:
1390  * (define (example n)
1391  *    (let ([u (let/cc k ; type int -> int pair
1392  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
1393  *                 (+ 1 (car v))))]) ; int
1394  *      (cons u 0))) ; int pair
1395  * ; (example 10) ~~> '(111 . 0)
1396  * ; (example -10) ~~> '(0 . 0)
1397  *
1398  * OCaml monads:
1399  * let example n : (int * int) =
1400  *   Continuation_monad.(let u = callcc (fun k ->
1401  *       (if n < 0 then k 0 else unit [n + 100])
1402  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
1403  *       >>= fun [x] -> unit (x + 1)
1404  *   )
1405  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1406  *   >>= fun x -> unit (x, 0)
1407  *   in run u)
1408  *
1409  *
1410  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1411  * let example1 () : int =
1412  *   Continuation_monad.(let v = reset (
1413  *       let u = shift (fun k -> unit (10 + 1))
1414  *       in u >>= fun x -> unit (100 + x)
1415  *     ) in let w = v >>= fun x -> unit (1000 + x)
1416  *     in run w)
1417  *
1418  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1419  * let example2 () =
1420  *   Continuation_monad.(let v = reset (
1421  *       let u = shift (fun k -> k (10 :: [1]))
1422  *       in u >>= fun x -> unit (100 :: x)
1423  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1424  *     in run w)
1425  *
1426  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1427  * let example3 () =
1428  *   Continuation_monad.(let v = reset (
1429  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1430  *       in u >>= fun x -> unit (100 :: x)
1431  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1432  *     in run w)
1433  *
1434  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1435  * (* not sure if this example can be typed without a sum-type *)
1436  *
1437  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1438  * let example5 () : int =
1439  *   Continuation_monad.(let v = reset (
1440  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
1441  *       in u >>= fun x -> unit (10 + x)
1442  *     ) in let w = v >>= fun x -> unit (100 + x)
1443  *     in run w)
1444  *
1445  *)
1446
1447
1448 module Leaf_monad : sig
1449   (* We implement the type as `'a tree option` because it has a natural`plus`,
1450    * and the rest of the library expects that `plus` and `zero` will come together. *)
1451   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1452   type 'a result = 'a tree option
1453   type 'a result_exn = 'a tree
1454   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1455   include Monad.PLUS with type 'a m := 'a m
1456   (* LeafT transformer *)
1457   module T : functor (Wrapped : Monad.S) -> sig
1458     type 'a result = 'a tree option Wrapped.result
1459     type 'a result_exn = 'a tree Wrapped.result_exn
1460     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1461     include Monad.PLUS with type 'a m := 'a m
1462     val elevate : 'a Wrapped.m -> 'a m
1463     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1464     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1465     val distribute : ('a -> 'b Wrapped.m) -> 'a tree option -> 'b m
1466   end
1467   module T2 : functor (Wrapped : Monad.S2) -> sig
1468     type ('x,'a) result = ('x,'a tree option) Wrapped.result
1469     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1470     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1471     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
1472     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1473     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
1474   end
1475 end = struct
1476   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1477   (* uses supplied plus and zero to copy t to its image under f *)
1478   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1479       | None -> zero ()
1480       | Some ts -> let rec loop ts = (match ts with
1481                      | Leaf a -> f a
1482                      | Node (l, r) ->
1483                          (* recursive application of f may delete a branch *)
1484                          plus (loop l) (loop r)
1485                    ) in loop ts
1486   module Base = struct
1487     type 'a m = 'a tree option
1488     type 'a result = 'a tree option
1489     type 'a result_exn = 'a tree
1490     let unit a = Some (Leaf a)
1491     let zero () = None
1492     let plus u v = match (u, v) with
1493       | None, _ -> v
1494       | _, None -> u
1495       | Some us, Some vs -> Some (Node (us, vs))
1496     let bind u f = mapT f u zero plus
1497     let run u = u
1498     let run_exn u = match u with
1499       | None -> failwith "no values"
1500       (*
1501       | Some (Leaf a) -> a
1502       | many -> failwith "multiple values"
1503       *)
1504       | Some us -> us
1505   end
1506   include Monad.Make(Base)
1507   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
1508   let base_plus = plus
1509   let base_lift = lift
1510   module T(Wrapped : Monad.S) = struct
1511     module Trans = struct
1512       let zero () = Wrapped.unit None
1513       let plus u v =
1514         Wrapped.bind u (fun us ->
1515         Wrapped.bind v (fun vs ->
1516         Wrapped.unit (base_plus us vs)))
1517       include Monad.MakeT(struct
1518         module Wrapped = Wrapped
1519         type 'a m = 'a tree option Wrapped.m
1520         type 'a result = 'a tree option Wrapped.result
1521         type 'a result_exn = 'a tree Wrapped.result_exn
1522         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1523         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1524         let run u = Wrapped.run u
1525         let run_exn u =
1526             let w = Wrapped.bind u (fun t -> match t with
1527               | None -> failwith "no values"
1528               | Some ts -> Wrapped.unit ts)
1529             in Wrapped.run_exn w
1530       end)
1531     end
1532     include Trans
1533     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
1534     (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
1535     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1536   end
1537   module T2(Wrapped : Monad.S2) = struct
1538     module Trans = struct
1539       let zero () = Wrapped.unit None
1540       let plus u v =
1541         Wrapped.bind u (fun us ->
1542         Wrapped.bind v (fun vs ->
1543         Wrapped.unit (base_plus us vs)))
1544       include Monad.MakeT2(struct
1545         module Wrapped = Wrapped
1546         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1547         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1548         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1549         (* code repetition, ugh *)
1550         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1551         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1552         let run u = Wrapped.run u
1553         let run_exn u =
1554             let w = Wrapped.bind u (fun t -> match t with
1555               | None -> failwith "no values"
1556               | Some ts -> Wrapped.unit ts)
1557             in Wrapped.run_exn w
1558       end)
1559     end
1560     include Trans
1561     include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1562     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1563   end
1564 end
1565
1566
1567 module L = List_monad;;
1568 module R = Reader_monad(struct type env = int -> int end);;
1569 module S = State_monad(struct type store = int end);;
1570 module T = Leaf_monad;;
1571 module LR = L.T(R);;
1572 module LS = L.T(S);;
1573 module TL = T.T(L);;
1574 module TR = T.T(R);;
1575 module TS = T.T(S);;
1576 module C = Continuation_monad
1577 module TC = T.T2(C);;
1578
1579
1580 print_endline "=== test Leaf(...).distribute ==================";;
1581
1582 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1583
1584 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1585 TS.run ts 0;;
1586 (*
1587 - : int T.tree option * S.store =
1588 (Some
1589   (T.Node
1590     (T.Node (T.Leaf 2, T.Leaf 3),
1591      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1592  5)
1593 *)
1594
1595 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1596 TS.run_exn ts2 0;;
1597 (*
1598 - : (int * S.store) T.tree option * S.store =
1599 (Some
1600   (T.Node
1601     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1602      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1603  5)
1604 *)
1605
1606 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1607 TR.run_exn tr (fun i -> i+i);;
1608 (*
1609 - : int T.tree option =
1610 Some
1611  (T.Node
1612    (T.Node (T.Leaf 4, T.Leaf 6),
1613     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1614 *)
1615
1616 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1617 TL.run_exn tl;;
1618 (*
1619 - : (int * int) TL.result =
1620 [Some
1621   (T.Node
1622     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1623      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1624 *)
1625
1626 let l2 = [1;2;3;4;5];;
1627 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1628
1629 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1630 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1631
1632 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1633 (*
1634 int T.tree option =
1635 Some
1636  (T.Node
1637    (T.Node (T.Leaf 10, T.Leaf 11),
1638     T.Node
1639      (T.Node
1640        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1641         T.Node (T.Leaf 40, T.Leaf 41)),
1642       T.Node (T.Leaf 50, T.Leaf 51))))
1643  *)
1644
1645 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1646 (*
1647 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1648 *)
1649
1650 print_endline "=== test Leaf(Continuation).distribute ==================";;
1651
1652 let id : 'z. 'z -> 'z = fun x -> x
1653
1654 let example n : (int * int) =
1655   Continuation_monad.(let u = callcc (fun k ->
1656       (if n < 0 then k 0 else unit [n + 100])
1657       (* all of the following is skipped by k 0; the end type int is k's input type *)
1658       >>= fun [x] -> unit (x + 1)
1659   )
1660   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1661   >>= fun x -> unit (x, 0)
1662   in run0 u)
1663
1664
1665 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1666 let example1 () : int =
1667   Continuation_monad.(let v = reset (
1668       let u = shift (fun k -> unit (10 + 1))
1669       in u >>= fun x -> unit (100 + x)
1670     ) in let w = v >>= fun x -> unit (1000 + x)
1671     in run0 w)
1672
1673 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1674 let example2 () =
1675   Continuation_monad.(let v = reset (
1676       let u = shift (fun k -> k (10 :: [1]))
1677       in u >>= fun x -> unit (100 :: x)
1678     ) in let w = v >>= fun x -> unit (1000 :: x)
1679     in run0 w)
1680
1681 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1682 let example3 () =
1683   Continuation_monad.(let v = reset (
1684       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1685       in u >>= fun x -> unit (100 :: x)
1686     ) in let w = v >>= fun x -> unit (1000 :: x)
1687     in run0 w)
1688
1689 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1690 (* not sure if this example can be typed without a sum-type *)
1691
1692 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1693 let example5 () : int =
1694   Continuation_monad.(let v = reset (
1695       let u = shift (fun k -> k 1 >>= k)
1696       in u >>= fun x -> unit (10 + x)
1697     ) in let w = v >>= fun x -> unit (100 + x)
1698     in run0 w)
1699
1700 ;;
1701
1702 print_endline "=== test bare Continuation ============";;
1703
1704 (1011, 1111, 1111, 121);;
1705 (example1(), example2(), example3(), example5());;
1706 ((111,0), (0,0));;
1707 (example ~+10, example ~-10);;
1708
1709 let testc df ic =
1710     C.run_exn TC.(run (distribute df t1)) ic;;
1711
1712
1713 (*
1714 (* do nothing *)
1715 let initial_continuation = fun t -> t in
1716 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1717 *)
1718 testc (C.unit) id;;
1719
1720 (*
1721 (* count leaves, using continuation *)
1722 let initial_continuation = fun t -> 0 in
1723 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1724 *)
1725
1726 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1727
1728 (*
1729 (* convert tree to list of leaves *)
1730 let initial_continuation = fun t -> [] in
1731 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1732 *)
1733
1734 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1735
1736 (*
1737 (* square each leaf using continuation *)
1738 let initial_continuation = fun t -> t in
1739 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1740 *)
1741
1742 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1743
1744
1745 (*
1746 (* replace leaves with list, using continuation *)
1747 let initial_continuation = fun t -> t in
1748 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1749 *)
1750
1751 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1752
1753 print_endline "=== pa_monad's Continuation Tests ============";;
1754
1755 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1756 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1757 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1758 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1759 (5, 27 = C.(run0 (
1760               let c = reset(abort 5 >>= fun y -> unit (y+6))
1761               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1762
1763 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1764
1765 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1766
1767 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1768
1769
1770 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1771