4db605c89987577377754af50b3a3632bc85c2c7
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  *)
46
47
48 (* Some library functions used below. *)
49 module Util = struct
50   let fold_right = List.fold_right
51   let map = List.map
52   let append = List.append
53   let reverse = List.rev
54   let concat = List.concat
55   let concat_map f lst = List.concat (List.map f lst)
56   (* let zip = List.combine *)
57   let unzip = List.split
58   let zip_with = List.map2
59   let replicate len fill =
60     let rec loop n accu =
61       if n == 0 then accu else loop (pred n) (fill :: accu)
62     in loop len []
63 end
64
65
66
67 (*
68  * This module contains factories that extend a base set of
69  * monadic definitions with a larger family of standard derived values.
70  *)
71
72 module Monad = struct
73   (*
74    * Signature extenders:
75    *   Make :: BASE -> S
76    *     MakeCatch, MakeDistrib :: PLUSBASE -> PLUS
77    *                                           which merges into S
78    *                                           (P is merged sig)
79    *   MakeT :: TRANS (with Wrapped : S or P) -> custom sig
80    *
81    *   Make2 :: BASE2 -> S2
82    *     MakeCatch2, MakeDistrib2 :: PLUSBASE2 -> PLUS2 (P2 is merged sig)
83    *   to wrap double-typed inner monads:
84    *   MakeT2 :: TRANS2 (with Wrapped : S2 or P2) -> custom sig
85    *
86    *)
87
88
89
90   (* type of base definitions *)
91   module type BASE = sig
92     (* The only constraints we impose here on how the monadic type
93      * is implemented is that it have a single type parameter 'a. *)
94     type 'a m
95     type 'a result
96     type 'a result_exn
97     val unit : 'a -> 'a m
98     val bind : 'a m -> ('a -> 'b m) -> 'b m
99     val run : 'a m -> 'a result
100     (* run_exn tries to provide a more ground-level result, but may fail *)
101     val run_exn : 'a m -> 'a result_exn
102   end
103   module type S = sig
104     include BASE
105     val (>>=) : 'a m -> ('a -> 'b m) -> 'b m
106     val (>>) : 'a m -> 'b m -> 'b m
107     val join : ('a m) m -> 'a m
108     val apply : ('a -> 'b) m -> 'a m -> 'b m
109     val lift : ('a -> 'b) -> 'a m -> 'b m
110     val lift2 :  ('a -> 'b -> 'c) -> 'a m -> 'b m -> 'c m
111     val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m
112     val do_when :  bool -> unit m -> unit m
113     val do_unless :  bool -> unit m -> unit m
114     val forever : (unit -> 'a m) -> 'b m
115     val sequence : 'a m list -> 'a list m
116     val sequence_ : 'a m list -> unit m
117   end
118
119   (* Standard, single-type-parameter monads. *)
120   module Make(B : BASE) : S with type 'a m = 'a B.m and type 'a result = 'a B.result and type 'a result_exn = 'a B.result_exn = struct
121     include B
122     let (>>=) = bind
123     let (>>) u v = u >>= fun _ -> v
124     let lift f u = u >>= fun a -> unit (f a)
125     (* lift is called listM, fmap, and <$> in Haskell *)
126     let join uu = uu >>= fun u -> u
127     (* u >>= f === join (lift f u) *)
128     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
129     (* [f] <*> [x1,x2] = [f x1,f x2] *)
130     (* let apply u v = u >>= fun f -> lift f v *)
131     (* let apply = lift2 id *)
132     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
133     (* let lift f u === apply (unit f) u *)
134     (* let lift2 f u v = apply (lift f u) v *)
135     let (>=>) f g = fun a -> f a >>= g
136     let do_when test u = if test then u else unit ()
137     let do_unless test u = if test then unit () else u
138     let forever uthunk =
139         let rec loop () = uthunk () >>= fun _ -> loop ()
140         in loop ()
141     let sequence ms =
142       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
143         Util.fold_right op ms (unit [])
144     let sequence_ ms =
145       Util.fold_right (>>) ms (unit ())
146
147     (* Haskell defines these other operations combining lists and monads.
148      * We don't, but notice that M.mapM == ListT(M).distribute
149      * There's also a parallel TreeT(M).distribute *)
150     (*
151     let mapM f alist = sequence (Util.map f alist)
152     let mapM_ f alist = sequence_ (Util.map f alist)
153     let rec filterM f lst = match lst with
154       | [] -> unit []
155       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
156     let forM alist f = mapM f alist
157     let forM_ alist f = mapM_ f alist
158     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
159     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
160     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
161     let rec foldM f z lst = match lst with
162       | [] -> unit z
163       | x::xs -> f z x >>= fun z' -> foldM f z' xs
164     let foldM_ f z xs = foldM f z xs >> unit ()
165     let replicateM n x = sequence (Util.replicate n x)
166     let replicateM_ n x = sequence_ (Util.replicate n x)
167     *)
168   end
169
170   (* Single-type-parameter monads that also define `plus` and `zero`
171    * operations. These obey the following laws:
172    *     zero >>= f   ===  zero
173    *     plus zero u  ===  u
174    *     plus u zero  ===  u
175    * Additionally, these monads will obey one of the following laws:
176    *     (Catch)   plus (unit a) v  ===  unit a
177    *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
178    *)
179   module type PLUSBASE = sig
180     include BASE
181     val zero : unit -> 'a m
182     val plus : 'a m -> 'a m -> 'a m
183   end
184   module type PLUS = sig
185     type 'a m
186     val zero : unit -> 'a m
187     val plus : 'a m -> 'a m -> 'a m
188     val guard : bool -> unit m
189     val sum : 'a m list -> 'a m
190   end
191   (* MakeCatch and MakeDistrib have the same implementation; we just declare
192    * them twice to document which laws the client code is promising to honor. *)
193   module MakeCatch(B : PLUSBASE) : PLUS with type 'a m = 'a B.m = struct
194     type 'a m = 'a B.m
195     let zero = B.zero
196     let plus = B.plus
197     let guard test = if test then B.unit () else zero ()
198     let sum ms = Util.fold_right plus ms (zero ())
199   end
200   module MakeDistrib = MakeCatch
201
202   (* Signatures for MonadT *)
203   (* sig for Wrapped that include S and PLUS *)
204   module type P = sig
205     include S
206     include PLUS with type 'a m := 'a m
207   end
208   module type TRANS = sig
209     module Wrapped : S
210     type 'a m
211     type 'a result
212     type 'a result_exn
213     val bind : 'a m -> ('a -> 'b m) -> 'b m
214     val run : 'a m -> 'a result
215     val run_exn : 'a m -> 'a result_exn
216     val elevate : 'a Wrapped.m -> 'a m
217     (* lift/elevate laws:
218      *     elevate (W.unit a) == unit a
219      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
220      *)
221   end
222   module MakeT(T : TRANS) = struct
223     include Make(struct
224         include T
225         let unit a = elevate (Wrapped.unit a)
226     end)
227     let elevate = T.elevate
228   end
229
230
231   (* We have to define BASE, S, and Make again for double-type-parameter monads. *)
232   module type BASE2 = sig
233     type ('x,'a) m
234     type ('x,'a) result
235     type ('x,'a) result_exn
236     val unit : 'a -> ('x,'a) m
237     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
238     val run : ('x,'a) m -> ('x,'a) result
239     val run_exn : ('x,'a) m -> ('x,'a) result_exn
240   end
241   module type S2 = sig
242     include BASE2
243     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
244     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
245     val join : ('x,('x,'a) m) m -> ('x,'a) m
246     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
247     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
248     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
249     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
250     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
251     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
252     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
253     val sequence : ('x,'a) m list -> ('x,'a list) m
254     val sequence_ : ('x,'a) m list -> ('x,unit) m
255   end
256   module Make2(B : BASE2) : S2 with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
257     (* code repetition, ugh *)
258     include B
259     let (>>=) = bind
260     let (>>) u v = u >>= fun _ -> v
261     let lift f u = u >>= fun a -> unit (f a)
262     let join uu = uu >>= fun u -> u
263     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
264     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
265     let (>=>) f g = fun a -> f a >>= g
266     let do_when test u = if test then u else unit ()
267     let do_unless test u = if test then unit () else u
268     let forever uthunk =
269         let rec loop () = uthunk () >>= fun _ -> loop ()
270         in loop ()
271     let sequence ms =
272       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
273         Util.fold_right op ms (unit [])
274     let sequence_ ms =
275       Util.fold_right (>>) ms (unit ())
276   end
277
278   module type PLUSBASE2 = sig
279     include BASE2
280     val zero : unit -> ('x,'a) m
281     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
282   end
283   module type PLUS2 = sig
284     type ('x,'a) m
285     val zero : unit -> ('x,'a) m
286     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
287     val guard : bool -> ('x,unit) m
288     val sum : ('x,'a) m list -> ('x,'a) m
289   end
290   module MakeCatch2(B : PLUSBASE2) : PLUS2 with type ('x,'a) m = ('x,'a) B.m = struct
291     type ('x,'a) m = ('x,'a) B.m
292     (* code repetition, ugh *)
293     let zero = B.zero
294     let plus = B.plus
295     let guard test = if test then B.unit () else zero ()
296     let sum ms = Util.fold_right plus ms (zero ())
297   end
298   module MakeDistrib2 = MakeCatch2
299
300   (* Signatures for MonadT *)
301   (* sig for Wrapped that include S and PLUS *)
302   module type P2 = sig
303     include S2
304     include PLUS2 with type ('x,'a) m := ('x,'a) m
305   end
306   module type TRANS2 = sig
307     module Wrapped : S2
308     type ('x,'a) m
309     type ('x,'a) result
310     type ('x,'a) result_exn
311     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
312     val run : ('x,'a) m -> ('x,'a) result
313     val run_exn : ('x,'a) m -> ('x,'a) result_exn
314     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
315   end
316   module MakeT2(T : TRANS2) = struct
317     (* code repetition, ugh *)
318     include Make2(struct
319         include T
320         let unit a = elevate (Wrapped.unit a)
321     end)
322     let elevate = T.elevate
323   end
324
325 end
326
327
328
329
330
331 module Identity_monad : sig
332   (* expose only the implementation of type `'a result` *)
333   type 'a result = 'a
334   type 'a result_exn = 'a
335   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
336 end = struct
337   module Base = struct
338     type 'a m = 'a
339     type 'a result = 'a
340     type 'a result_exn = 'a
341     let unit a = a
342     let bind a f = f a
343     let run a = a
344     let run_exn a = a
345   end
346   include Monad.Make(Base)
347 end
348
349
350 module Maybe_monad : sig
351   (* expose only the implementation of type `'a result` *)
352   type 'a result = 'a option
353   type 'a result_exn = 'a
354   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
355   include Monad.PLUS with type 'a m := 'a m
356   (* MaybeT transformer *)
357   module T : functor (Wrapped : Monad.S) -> sig
358     type 'a result = 'a option Wrapped.result
359     type 'a result_exn = 'a Wrapped.result_exn
360     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
361     include Monad.PLUS with type 'a m := 'a m
362     val elevate : 'a Wrapped.m -> 'a m
363   end
364   module T2 : functor (Wrapped : Monad.S2) -> sig
365     type ('x,'a) result = ('x,'a option) Wrapped.result
366     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
367     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
368     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
369     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
370   end
371 end = struct
372   module Base = struct
373    type 'a m = 'a option
374    type 'a result = 'a option
375    type 'a result_exn = 'a
376    let unit a = Some a
377    let bind u f = match u with Some a -> f a | None -> None
378    let run u = u
379    let run_exn u = match u with
380      | Some a -> a
381      | None -> failwith "no value"
382    let zero () = None
383    let plus u v = match u with None -> v | _ -> u
384   end
385   include Monad.Make(Base)
386   include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m)
387   module T(Wrapped : Monad.S) = struct
388     module Trans = struct
389       include Monad.MakeT(struct
390         module Wrapped = Wrapped
391         type 'a m = 'a option Wrapped.m
392         type 'a result = 'a option Wrapped.result
393         type 'a result_exn = 'a Wrapped.result_exn
394         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
395         let bind u f = Wrapped.bind u (fun t -> match t with
396           | Some a -> f a
397           | None -> Wrapped.unit None)
398         let run u = Wrapped.run u
399         let run_exn u =
400           let w = Wrapped.bind u (fun t -> match t with
401             | Some a -> Wrapped.unit a
402             | None -> failwith "no value")
403           in Wrapped.run_exn w
404       end)
405       let zero () = Wrapped.unit None
406       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
407     end
408     include Trans
409     include (Monad.MakeCatch(Trans) : Monad.PLUS with type 'a m := 'a m)
410   end
411   module T2(Wrapped : Monad.S2) = struct
412     module Trans = struct
413       include Monad.MakeT2(struct
414         module Wrapped = Wrapped
415         type ('x,'a) m = ('x,'a option) Wrapped.m
416         type ('x,'a) result = ('x,'a option) Wrapped.result
417         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
418         (* code repetition, ugh *)
419         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
420         let bind u f = Wrapped.bind u (fun t -> match t with
421           | Some a -> f a
422           | None -> Wrapped.unit None)
423         let run u = Wrapped.run u
424         let run_exn u =
425           let w = Wrapped.bind u (fun t -> match t with
426             | Some a -> Wrapped.unit a
427             | None -> failwith "no value")
428           in Wrapped.run_exn w
429       end)
430       let zero () = Wrapped.unit None
431       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
432     end
433     include Trans
434     include (Monad.MakeCatch2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
435   end
436 end
437
438
439 module List_monad : sig
440   (* declare additional operation, while still hiding implementation of type m *)
441   type 'a result = 'a list
442   type 'a result_exn = 'a
443   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
444   include Monad.PLUS with type 'a m := 'a m
445   val permute : 'a m -> 'a m m
446   val select : 'a m -> ('a * 'a m) m
447   (* ListT transformer *)
448   module T : functor (Wrapped : Monad.S) -> sig
449     type 'a result = 'a list Wrapped.result
450     type 'a result_exn = 'a Wrapped.result_exn
451     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
452     include Monad.PLUS with type 'a m := 'a m
453     val elevate : 'a Wrapped.m -> 'a m
454     (* note that second argument is an 'a list, not the more abstract 'a m *)
455     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
456     val distribute : ('a -> 'b Wrapped.m) -> 'a list -> 'b m
457 (* TODO
458     val permute : 'a m -> 'a m m
459     val select : 'a m -> ('a * 'a m) m
460 *)
461   end
462   module T2 : functor (Wrapped : Monad.S2) -> sig
463     type ('x,'a) result = ('x,'a list) Wrapped.result
464     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
465     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
466     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
467     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
468     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
469   end
470 end = struct
471   module Base = struct
472    type 'a m = 'a list
473    type 'a result = 'a list
474    type 'a result_exn = 'a
475    let unit a = [a]
476    let bind u f = Util.concat_map f u
477    let run u = u
478    let run_exn u = match u with
479      | [] -> failwith "no values"
480      | [a] -> a
481      | many -> failwith "multiple values"
482    let zero () = []
483    let plus = Util.append
484   end
485   include Monad.Make(Base)
486   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
487   (* let either u v = plus u v *)
488   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
489   let rec insert a u =
490     plus (unit (a :: u)) (match u with
491         | [] -> zero ()
492         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
493     )
494   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
495   let rec permute u = match u with
496       | [] -> unit []
497       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
498   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
499   let rec select u = match u with
500     | [] -> zero ()
501     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
502   let base_plus = plus
503   module T(Wrapped : Monad.S) = struct
504     module Trans = struct
505       (* Wrapped.sequence ms  ===  
506            let plus1 u v =
507              Wrapped.bind u (fun x ->
508              Wrapped.bind v (fun xs ->
509              Wrapped.unit (x :: xs)))
510            in Util.fold_right plus1 ms (Wrapped.unit []) *)
511       (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
512       let distribute f alist = Wrapped.sequence (Util.map f alist)
513       include Monad.MakeT(struct
514         module Wrapped = Wrapped
515         type 'a m = 'a list Wrapped.m
516         type 'a result = 'a list Wrapped.result
517         type 'a result_exn = 'a Wrapped.result_exn
518         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
519         let bind u f =
520           Wrapped.bind u (fun ts ->
521           Wrapped.bind (distribute f ts) (fun tts ->
522           Wrapped.unit (Util.concat tts)))
523         let run u = Wrapped.run u
524         let run_exn u =
525           let w = Wrapped.bind u (fun ts -> match ts with
526             | [] -> failwith "no values"
527             | [a] -> Wrapped.unit a
528             | many -> failwith "multiple values"
529           ) in Wrapped.run_exn w
530       end)
531       let zero () = Wrapped.unit []
532       let plus u v =
533         Wrapped.bind u (fun us ->
534         Wrapped.bind v (fun vs ->
535         Wrapped.unit (base_plus us vs)))
536     end
537     include Trans
538     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
539 (*
540     let permute : 'a m -> 'a m m
541     let select : 'a m -> ('a * 'a m) m
542 *)
543   end
544   module T2(Wrapped : Monad.S2) = struct
545     module Trans = struct
546       let distribute f alist = Wrapped.sequence (Util.map f alist)
547       include Monad.MakeT2(struct
548         module Wrapped = Wrapped
549         type ('x,'a) m = ('x,'a list) Wrapped.m
550         type ('x,'a) result = ('x,'a list) Wrapped.result
551         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
552         (* code repetition, ugh *)
553         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
554         let bind u f =
555           Wrapped.bind u (fun ts ->
556           Wrapped.bind (distribute f ts) (fun tts ->
557           Wrapped.unit (Util.concat tts)))
558         let run u = Wrapped.run u
559         let run_exn u =
560           let w = Wrapped.bind u (fun ts -> match ts with
561             | [] -> failwith "no values"
562             | [a] -> Wrapped.unit a
563             | many -> failwith "multiple values"
564           ) in Wrapped.run_exn w
565       end)
566       let zero () = Wrapped.unit []
567       let plus u v =
568         Wrapped.bind u (fun us ->
569         Wrapped.bind v (fun vs ->
570         Wrapped.unit (base_plus us vs)))
571     end
572     include Trans
573     include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
574   end
575 end
576
577
578 (* must be parameterized on (struct type err = ... end) *)
579 module Error_monad(Err : sig
580   type err
581   exception Exc of err
582   (*
583   val zero : unit -> err
584   val plus : err -> err -> err
585   *)
586 end) : sig
587   (* declare additional operations, while still hiding implementation of type m *)
588   type err = Err.err
589   type 'a error = Error of err | Success of 'a
590   type 'a result = 'a error
591   type 'a result_exn = 'a
592   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
593   (* include Monad.PLUS with type 'a m := 'a m *)
594   val throw : err -> 'a m
595   val catch : 'a m -> (err -> 'a m) -> 'a m
596   (* ErrorT transformer *)
597   module T : functor (Wrapped : Monad.S) -> sig
598     type 'a result = 'a error Wrapped.result
599     type 'a result_exn = 'a Wrapped.result_exn
600     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
601     val elevate : 'a Wrapped.m -> 'a m
602     val throw : err -> 'a m
603     val catch : 'a m -> (err -> 'a m) -> 'a m
604   end
605   (* ErrorT transformer when wrapped monad has plus, zero *)
606   module TP : functor (Wrapped : Monad.P) -> sig
607     include module type of T(Wrapped)
608     include Monad.PLUS with type 'a m := 'a m
609   end
610   module T2 : functor (Wrapped : Monad.S2) -> sig
611     type ('x,'a) result = ('x,'a error) Wrapped.result
612     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
613     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
614     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
615     val throw : err -> ('x,'a) m
616     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
617   end
618   module TP2 : functor (Wrapped : Monad.P2) -> sig
619     include module type of T2(Wrapped)
620     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
621   end
622 end = struct
623   type err = Err.err
624   type 'a error = Error of err | Success of 'a
625   module Base = struct
626     type 'a m = 'a error
627     type 'a result = 'a error
628     type 'a result_exn = 'a
629     let unit a = Success a
630     let bind u f = match u with
631       | Success a -> f a
632       | Error e -> Error e (* input and output may be of different 'a types *)
633     let run u = u
634     let run_exn u = match u with
635       | Success a -> a
636       | Error e -> raise (Err.Exc e)
637     (*
638     let zero () = Error Err.zero
639     let plus u v = match (u, v) with
640       | Success _, _ -> u
641       (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
642        * otherwise, plus (Error _) v = v *)
643       | Error _, _ when v = zero -> u
644       (* combine errors *)
645       | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
646       | Error _, _ -> v
647     *)
648   end
649   include Monad.Make(Base)
650   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
651   let throw e = Error e
652   let catch u handler = match u with
653     | Success _ -> u
654     | Error e -> handler e
655   module T(Wrapped : Monad.S) = struct
656     module Trans = struct
657       module Wrapped = Wrapped
658       type 'a m = 'a error Wrapped.m
659       type 'a result = 'a error Wrapped.result
660       type 'a result_exn = 'a Wrapped.result_exn
661       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
662       let bind u f = Wrapped.bind u (fun t -> match t with
663         | Success a -> f a
664         | Error e -> Wrapped.unit (Error e))
665       let run u = Wrapped.run u
666       let run_exn u =
667         let w = Wrapped.bind u (fun t -> match t with
668           | Success a -> Wrapped.unit a
669           (* | _ -> Wrapped.fail () *)
670           | Error e -> raise (Err.Exc e))
671         in Wrapped.run_exn w
672     end
673     include Monad.MakeT(Trans)
674     let throw e = Wrapped.unit (Error e)
675     let catch u handler = Wrapped.bind u (fun t -> match t with
676       | Success _ -> Wrapped.unit t
677       | Error e -> handler e)
678   end
679   module TP(Wrapped : Monad.P) = struct
680     module TransP = struct
681       include T(Wrapped)
682       let plus u v = Wrapped.plus u v
683       let zero () = elevate (Wrapped.zero ())
684     end
685     include TransP
686     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
687   end
688   module T2(Wrapped : Monad.S2) = struct
689     module Trans = struct
690       module Wrapped = Wrapped
691       type ('x,'a) m = ('x,'a error) Wrapped.m
692       type ('x,'a) result = ('x,'a error) Wrapped.result
693       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
694       (* code repetition, ugh *)
695       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
696       let bind u f = Wrapped.bind u (fun t -> match t with
697         | Success a -> f a
698         | Error e -> Wrapped.unit (Error e))
699       let run u = Wrapped.run u
700       let run_exn u =
701         let w = Wrapped.bind u (fun t -> match t with
702           | Success a -> Wrapped.unit a
703           | Error e -> raise (Err.Exc e))
704         in Wrapped.run_exn w
705     end
706     include Monad.MakeT2(Trans)
707     let throw e = Wrapped.unit (Error e)
708     let catch u handler = Wrapped.bind u (fun t -> match t with
709       | Success _ -> Wrapped.unit t
710       | Error e -> handler e)
711   end
712   module TP2(Wrapped : Monad.P2) = struct
713     module TransP = struct
714       include T2(Wrapped)
715       (* code repetition, ugh *)
716       let plus u v = Wrapped.plus u v
717       let zero () = elevate (Wrapped.zero ())
718     end
719     include TransP
720     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
721   end
722 end
723
724 (* pre-define common instance of Error_monad *)
725 module Failure = Error_monad(struct
726   type err = string
727   exception Exc = Failure
728   (*
729   let zero = ""
730   let plus s1 s2 = s1 ^ "\n" ^ s2
731   *)
732 end)
733
734 (*
735 # EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
736 - : int EL.result = [Failure.Error "bye"; Failure.Success 30]
737 # LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
738 - : int LE.result = Failure.Error "bye"
739 # EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
740 Exception: Failure "bye".
741 # LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
742 Exception: Failure "bye".
743
744 # ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
745 - : int Failure.error * S.store = (Failure.Error "bye", 1)
746 # SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
747 - : (int * S.store) Failure.result = Failure.Error "bye"
748 # ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
749 Exception: Failure "bye".
750 # SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
751 Exception: Failure "bye".
752  *)
753
754
755 (* must be parameterized on (struct type env = ... end) *)
756 module Reader_monad(Env : sig type env end) : sig
757   (* declare additional operations, while still hiding implementation of type m *)
758   type env = Env.env
759   type 'a result = env -> 'a
760   type 'a result_exn = env -> 'a
761   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
762   val ask : env m
763   val asks : (env -> 'a) -> 'a m
764   val local : (env -> env) -> 'a m -> 'a m
765   (* ReaderT transformer *)
766   module T : functor (Wrapped : Monad.S) -> sig
767     type 'a result = env -> 'a Wrapped.result
768     type 'a result_exn = env -> 'a Wrapped.result_exn
769     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
770     val elevate : 'a Wrapped.m -> 'a m
771     val ask : env m
772     val asks : (env -> 'a) -> 'a m
773     val local : (env -> env) -> 'a m -> 'a m
774   end
775   (* ReaderT transformer when wrapped monad has plus, zero *)
776   module TP : functor (Wrapped : Monad.P) -> sig
777     include module type of T(Wrapped)
778     include Monad.PLUS with type 'a m := 'a m
779   end
780   module T2 : functor (Wrapped : Monad.S2) -> sig
781     type ('x,'a) result = env -> ('x,'a) Wrapped.result
782     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
783     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
784     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
785     val ask : ('x,env) m
786     val asks : (env -> 'a) -> ('x,'a) m
787     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
788   end
789   module TP2 : functor (Wrapped : Monad.P2) -> sig
790     include module type of T2(Wrapped)
791     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
792   end
793 end = struct
794   type env = Env.env
795   module Base = struct
796     type 'a m = env -> 'a
797     type 'a result = env -> 'a
798     type 'a result_exn = env -> 'a
799     let unit a = fun e -> a
800     let bind u f = fun e -> let a = u e in let u' = f a in u' e
801     let run u = fun e -> u e
802     let run_exn = run
803   end
804   include Monad.Make(Base)
805   let ask = fun e -> e
806   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
807   let local modifier u = fun e -> u (modifier e)
808   module T(Wrapped : Monad.S) = struct
809     module Trans = struct
810       module Wrapped = Wrapped
811       type 'a m = env -> 'a Wrapped.m
812       type 'a result = env -> 'a Wrapped.result
813       type 'a result_exn = env -> 'a Wrapped.result_exn
814       let elevate w = fun e -> w
815       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
816       let run u = fun e -> Wrapped.run (u e)
817       let run_exn u = fun e -> Wrapped.run_exn (u e)
818     end
819     include Monad.MakeT(Trans)
820     let ask = fun e -> Wrapped.unit e
821     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
822     let local modifier u = fun e -> u (modifier e)
823   end
824   module TP(Wrapped : Monad.P) = struct
825     module TransP = struct
826       include T(Wrapped)
827       let plus u v = fun s -> Wrapped.plus (u s) (v s)
828       let zero () = elevate (Wrapped.zero ())
829       let asks selector = ask >>= (fun e ->
830         try unit (selector e)
831         with Not_found -> fun e -> Wrapped.zero ())
832     end
833     include TransP
834     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
835   end
836   module T2(Wrapped : Monad.S2) = struct
837     module Trans = struct
838       module Wrapped = Wrapped
839       type ('x,'a) m = env -> ('x,'a) Wrapped.m
840       type ('x,'a) result = env -> ('x,'a) Wrapped.result
841       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
842       (* code repetition, ugh *)
843       let elevate w = fun e -> w
844       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
845       let run u = fun e -> Wrapped.run (u e)
846       let run_exn u = fun e -> Wrapped.run_exn (u e)
847     end
848     include Monad.MakeT2(Trans)
849     let ask = fun e -> Wrapped.unit e
850     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
851     let local modifier u = fun e -> u (modifier e)
852   end
853   module TP2(Wrapped : Monad.P2) = struct
854     module TransP = struct
855       include T2(Wrapped)
856       (* code repetition, ugh *)
857       let plus u v = fun s -> Wrapped.plus (u s) (v s)
858       let zero () = elevate (Wrapped.zero ())
859       let asks selector = ask >>= (fun e ->
860         try unit (selector e)
861         with Not_found -> fun e -> Wrapped.zero ())
862     end
863     include TransP
864     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
865   end
866 end
867
868
869 (* must be parameterized on (struct type store = ... end) *)
870 module State_monad(Store : sig type store end) : sig
871   (* declare additional operations, while still hiding implementation of type m *)
872   type store = Store.store
873   type 'a result = store -> 'a * store
874   type 'a result_exn = store -> 'a
875   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
876   val get : store m
877   val gets : (store -> 'a) -> 'a m
878   val put : store -> unit m
879   val puts : (store -> store) -> unit m
880   (* StateT transformer *)
881   module T : functor (Wrapped : Monad.S) -> sig
882     type 'a result = store -> ('a * store) Wrapped.result
883     type 'a result_exn = store -> 'a Wrapped.result_exn
884     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
885     val elevate : 'a Wrapped.m -> 'a m
886     val get : store m
887     val gets : (store -> 'a) -> 'a m
888     val put : store -> unit m
889     val puts : (store -> store) -> unit m
890   end
891   (* StateT transformer when wrapped monad has plus, zero *)
892   module TP : functor (Wrapped : Monad.P) -> sig
893     include module type of T(Wrapped)
894     include Monad.PLUS with type 'a m := 'a m
895   end
896   module T2 : functor (Wrapped : Monad.S2) -> sig
897     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
898     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
899     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
900     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
901     val get : ('x,store) m
902     val gets : (store -> 'a) -> ('x,'a) m
903     val put : store -> ('x,unit) m
904     val puts : (store -> store) -> ('x,unit) m
905   end
906   module TP2 : functor (Wrapped : Monad.P2) -> sig
907     include module type of T2(Wrapped)
908     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
909   end
910 end = struct
911   type store = Store.store
912   module Base = struct
913     type 'a m = store -> 'a * store
914     type 'a result = store -> 'a * store
915     type 'a result_exn = store -> 'a
916     let unit a = fun s -> (a, s)
917     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
918     let run u = fun s -> (u s)
919     let run_exn u = fun s -> fst (u s)
920   end
921   include Monad.Make(Base)
922   let get = fun s -> (s, s)
923   let gets viewer = fun s -> (viewer s, s) (* may fail *)
924   let put s = fun _ -> ((), s)
925   let puts modifier = fun s -> ((), modifier s)
926   module T(Wrapped : Monad.S) = struct
927     module Trans = struct
928       module Wrapped = Wrapped
929       type 'a m = store -> ('a * store) Wrapped.m
930       type 'a result = store -> ('a * store) Wrapped.result
931       type 'a result_exn = store -> 'a Wrapped.result_exn
932       let elevate w = fun s ->
933         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
934       let bind u f = fun s ->
935         Wrapped.bind (u s) (fun (a, s') -> f a s')
936       let run u = fun s -> Wrapped.run (u s)
937       let run_exn u = fun s ->
938         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
939         in Wrapped.run_exn w
940     end
941     include Monad.MakeT(Trans)
942     let get = fun s -> Wrapped.unit (s, s)
943     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
944     let put s = fun _ -> Wrapped.unit ((), s)
945     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
946   end
947   module TP(Wrapped : Monad.P) = struct
948     module TransP = struct
949       include T(Wrapped)
950       let plus u v = fun s -> Wrapped.plus (u s) (v s)
951       let zero () = elevate (Wrapped.zero ())
952     end
953     let gets viewer = fun s ->
954       try Wrapped.unit (viewer s, s)
955       with Not_found -> Wrapped.zero ()
956     include TransP
957     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
958   end
959   module T2(Wrapped : Monad.S2) = struct
960     module Trans = struct
961       module Wrapped = Wrapped
962       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
963       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
964       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
965       (* code repetition, ugh *)
966       let elevate w = fun s ->
967         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
968       let bind u f = fun s ->
969         Wrapped.bind (u s) (fun (a, s') -> f a s')
970       let run u = fun s -> Wrapped.run (u s)
971       let run_exn u = fun s ->
972         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
973         in Wrapped.run_exn w
974     end
975     include Monad.MakeT2(Trans)
976     let get = fun s -> Wrapped.unit (s, s)
977     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
978     let put s = fun _ -> Wrapped.unit ((), s)
979     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
980   end
981   module TP2(Wrapped : Monad.P2) = struct
982     module TransP = struct
983       include T2(Wrapped)
984       (* code repetition, ugh *)
985       let plus u v = fun s -> Wrapped.plus (u s) (v s)
986       let zero () = elevate (Wrapped.zero ())
987     end
988     let gets viewer = fun s ->
989       try Wrapped.unit (viewer s, s)
990       with Not_found -> Wrapped.zero ()
991     include TransP
992     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
993   end
994 end
995
996 (* State monad with different interface (structured store) *)
997 module Ref_monad(V : sig
998   type value
999 end) : sig
1000   type ref
1001   type value = V.value
1002   type 'a result = 'a
1003   type 'a result_exn = 'a
1004   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1005   val newref : value -> ref m
1006   val deref : ref -> value m
1007   val change : ref -> value -> unit m
1008   (* RefT transformer *)
1009   module T : functor (Wrapped : Monad.S) -> sig
1010     type 'a result = 'a Wrapped.result
1011     type 'a result_exn = 'a Wrapped.result_exn
1012     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1013     val elevate : 'a Wrapped.m -> 'a m
1014     val newref : value -> ref m
1015     val deref : ref -> value m
1016     val change : ref -> value -> unit m
1017   end
1018   (* RefT transformer when wrapped monad has plus, zero *)
1019   module TP : functor (Wrapped : Monad.P) -> sig
1020     include module type of T(Wrapped)
1021     include Monad.PLUS with type 'a m := 'a m
1022   end
1023   module T2 : functor (Wrapped : Monad.S2) -> sig
1024     type ('x,'a) result = ('x,'a) Wrapped.result
1025     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
1026     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1027     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1028     val newref : value -> ('x,ref) m
1029     val deref : ref -> ('x,value) m
1030     val change : ref -> value -> ('x,unit) m
1031   end
1032   module TP2 : functor (Wrapped : Monad.P2) -> sig
1033     include module type of T2(Wrapped)
1034     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
1035   end
1036 end = struct
1037   type ref = int
1038   type value = V.value
1039   module D = Map.Make(struct type t = ref let compare = compare end)
1040   type dict = { next: ref; tree : value D.t }
1041   let empty = { next = 0; tree = D.empty }
1042   let alloc (value : value) (d : dict) =
1043     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
1044   let read (key : ref) (d : dict) =
1045     D.find key d.tree
1046   let write (key : ref) (value : value) (d : dict) =
1047     { next = d.next; tree = D.add key value d.tree }
1048   module Base = struct
1049     type 'a m = dict -> 'a * dict
1050     type 'a result = 'a
1051     type 'a result_exn = 'a
1052     let unit a = fun s -> (a, s)
1053     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
1054     let run u = fst (u empty)
1055     let run_exn = run
1056   end
1057   include Monad.Make(Base)
1058   let newref value = fun s -> alloc value s
1059   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
1060   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
1061   module T(Wrapped : Monad.S) = struct
1062     module Trans = struct
1063       module Wrapped = Wrapped
1064       type 'a m = dict -> ('a * dict) Wrapped.m
1065       type 'a result = 'a Wrapped.result
1066       type 'a result_exn = 'a Wrapped.result_exn
1067       let elevate w = fun s ->
1068         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1069       let bind u f = fun s ->
1070         Wrapped.bind (u s) (fun (a, s') -> f a s')
1071       let run u =
1072         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1073         in Wrapped.run w
1074       let run_exn u =
1075         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1076         in Wrapped.run_exn w
1077     end
1078     include Monad.MakeT(Trans)
1079     let newref value = fun s -> Wrapped.unit (alloc value s)
1080     let deref key = fun s -> Wrapped.unit (read key s, s)
1081     let change key value = fun s -> Wrapped.unit ((), write key value s)
1082   end
1083   module TP(Wrapped : Monad.P) = struct
1084     module TransP = struct
1085       include T(Wrapped)
1086       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1087       let zero () = elevate (Wrapped.zero ())
1088     end
1089     include TransP
1090     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
1091   end
1092   module T2(Wrapped : Monad.S2) = struct
1093     module Trans = struct
1094       module Wrapped = Wrapped
1095       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
1096       type ('x,'a) result = ('x,'a) Wrapped.result
1097       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
1098       (* code repetition, ugh *)
1099       let elevate w = fun s ->
1100         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1101       let bind u f = fun s ->
1102         Wrapped.bind (u s) (fun (a, s') -> f a s')
1103       let run u =
1104         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1105         in Wrapped.run w
1106       let run_exn u =
1107         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1108         in Wrapped.run_exn w
1109     end
1110     include Monad.MakeT2(Trans)
1111     let newref value = fun s -> Wrapped.unit (alloc value s)
1112     let deref key = fun s -> Wrapped.unit (read key s, s)
1113     let change key value = fun s -> Wrapped.unit ((), write key value s)
1114   end
1115   module TP2(Wrapped : Monad.P2) = struct
1116     module TransP = struct
1117       include T2(Wrapped)
1118       (* code repetition, ugh *)
1119       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1120       let zero () = elevate (Wrapped.zero ())
1121     end
1122     include TransP
1123     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1124   end
1125 end
1126
1127
1128 (* must be parameterized on (struct type log = ... end) *)
1129 module Writer_monad(Log : sig
1130   type log
1131   val zero : log
1132   val plus : log -> log -> log
1133 end) : sig
1134   (* declare additional operations, while still hiding implementation of type m *)
1135   type log = Log.log
1136   type 'a result = 'a * log
1137   type 'a result_exn = 'a * log
1138   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1139   val tell : log -> unit m
1140   val listen : 'a m -> ('a * log) m
1141   val listens : (log -> 'b) -> 'a m -> ('a * 'b) m
1142   (* val pass : ('a * (log -> log)) m -> 'a m *)
1143   val censor : (log -> log) -> 'a m -> 'a m
1144 end = struct
1145   type log = Log.log
1146   module Base = struct
1147     type 'a m = 'a * log
1148     type 'a result = 'a * log
1149     type 'a result_exn = 'a * log
1150     let unit a = (a, Log.zero)
1151     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
1152     let run u = u
1153     let run_exn = run
1154   end
1155   include Monad.Make(Base)
1156   let tell entries = ((), entries) (* add entries to log *)
1157   let listen (a, w) = ((a, w), w)
1158   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
1159   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
1160   let censor f u = pass (u >>= fun a -> unit (a, f))
1161 end
1162
1163 (* pre-define simple Writer *)
1164 module Writer1 = Writer_monad(struct
1165   type log = string
1166   let zero = ""
1167   let plus s1 s2 = s1 ^ "\n" ^ s2
1168 end)
1169
1170 (* slightly more efficient Writer *)
1171 module Writer2 = struct
1172   include Writer_monad(struct
1173     type log = string list
1174     let zero = []
1175     let plus w w' = Util.append w' w
1176   end)
1177   let tell_string s = tell [s]
1178   let tell entries = tell (Util.reverse entries)
1179   let run u = let (a, w) = run u in (a, Util.reverse w)
1180   let run_exn = run
1181 end
1182
1183
1184 module IO_monad : sig
1185   (* declare additional operation, while still hiding implementation of type m *)
1186   type 'a result = 'a
1187   type 'a result_exn = 'a
1188   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1189   val printf : ('a, unit, string, unit m) format4 -> 'a
1190   val print_string : string -> unit m
1191   val print_int : int -> unit m
1192   val print_hex : int -> unit m
1193   val print_bool : bool -> unit m
1194 end = struct
1195   module Base = struct
1196     type 'a m = { run : unit -> unit; value : 'a }
1197     type 'a result = 'a
1198     type 'a result_exn = 'a
1199     let unit a = { run = (fun () -> ()); value = a }
1200     let bind (a : 'a m) (f: 'a -> 'b m) : 'b m =
1201      let fres = f a.value in
1202        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
1203     let run a = let () = a.run () in a.value
1204     let run_exn = run
1205   end
1206   include Monad.Make(Base)
1207   let printf fmt =
1208     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
1209   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
1210   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
1211   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
1212   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
1213 end
1214
1215 (*
1216 module Continuation_monad : sig
1217   (* expose only the implementation of type `('r,'a) result` *)
1218   type 'a m
1219   type 'a result = 'a m
1220   type 'a result_exn = 'a m
1221   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn and type 'a m := 'a m
1222   (* val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m *)
1223   (* misses that the answer types of all the cont's must be the same *)
1224   val callcc : (('a -> 'b m) -> 'a m) -> 'a m
1225   (* val reset : ('a,'a) m -> ('r,'a) m *)
1226   val reset : 'a m -> 'a m
1227   (* val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m *)
1228   (* misses that the answer types of second and third continuations must be b *)
1229   val shift : (('a -> 'b m) -> 'b m) -> 'a m
1230   (* overwrite the run declaration in S, because I can't declare 'a result =
1231    * this polymorphic type (complains that 'r is unbound *)
1232   val runk : 'a m -> ('a -> 'r) -> 'r
1233   val run0 : 'a m -> 'a
1234 end = struct
1235   let id = fun i -> i
1236   module Base = struct
1237     (* 'r is result type of whole computation *)
1238     type 'a m = { cont : 'r. ('a -> 'r) -> 'r }
1239     type 'a result = 'a m
1240     type 'a result_exn = 'a m
1241     let unit a =
1242       let cont : 'r. ('a -> 'r) -> 'r =
1243         fun k -> k a
1244       in { cont }
1245     let bind u f =
1246       let cont : 'r. ('a -> 'r) -> 'r =
1247         fun k -> u.cont (fun a -> (f a).cont k)
1248       in { cont }
1249     let run (u : 'a m) : 'a result = u
1250     let run_exn (u : 'a m) : 'a result_exn = u
1251     let callcc f =
1252       let cont : 'r. ('a -> 'r) -> 'r =
1253           (* Can't figure out how to make the type polymorphic enough
1254            * to satisfy the OCaml type-checker (it's ('a -> 'r) -> 'r
1255            * instead of 'r. ('a -> 'r) -> 'r); so we have to fudge
1256            * with Obj.magic... which tells OCaml's type checker to
1257            * relax, the supplied value has whatever type the context
1258            * needs it to have. *)
1259           fun k ->
1260           let usek a = { cont = Obj.magic (fun _ -> k a) }
1261           in (f usek).cont k
1262       in { cont }
1263     let reset u = unit (u.cont id)
1264     let shift (f : ('a -> 'b m) -> 'b m) : 'a m =
1265       let cont = fun k ->
1266         (f (fun a -> unit (k a))).cont id
1267       in { cont = Obj.magic cont }
1268     let runk u k = (u.cont : ('a -> 'r) -> 'r) k
1269     let run0 u = runk u id
1270   end
1271   include Monad.Make(Base)
1272   let callcc = Base.callcc
1273   let reset = Base.reset
1274   let shift = Base.shift
1275   let runk = Base.runk
1276   let run0 = Base.run0
1277 end
1278  *)
1279
1280 (* This two-type parameter version works without Obj.magic *)
1281 module Continuation_monad : sig
1282   (* expose only the implementation of type `('r,'a) result` *)
1283   type ('r,'a) m
1284   type ('r,'a) result = ('r,'a) m
1285   type ('r,'a) result_exn = ('a -> 'r) -> 'r
1286   include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
1287   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
1288   val reset : ('a,'a) m -> ('r,'a) m
1289   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
1290   (* val abort : ('a,'a) m -> ('a,'b) m *)
1291   val abort : 'a -> ('a,'b) m
1292   val run0 : ('a,'a) m -> 'a
1293 end = struct
1294   let id = fun i -> i
1295   module Base = struct
1296     (* 'r is result type of whole computation *)
1297     type ('r,'a) m = ('a -> 'r) -> 'r
1298     type ('r,'a) result = ('a -> 'r) -> 'r
1299     type ('r,'a) result_exn = ('r,'a) result
1300     let unit a = (fun k -> k a)
1301     let bind u f = (fun k -> (u) (fun a -> (f a) k))
1302     let run u k = (u) k
1303     let run_exn = run
1304   end
1305   include Monad.Make2(Base)
1306   let callcc f = (fun k ->
1307     let usek a = (fun _ -> k a)
1308     in (f usek) k)
1309   (*
1310   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
1311   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
1312   let callcc f = fun k -> f k k
1313   let throw k a = fun _ -> k a
1314   *)
1315
1316   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
1317    *
1318    *  reset :: (Monad m) => ContT a m a -> ContT r m a
1319    *  reset e = ContT $ \k -> runContT e return >>= k
1320    *
1321    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
1322    *  shift e = ContT $ \k ->
1323    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
1324   let reset u = unit ((u) id)
1325   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
1326   (* let abort a = shift (fun _ -> a) *)
1327   let abort a = shift (fun _ -> unit a)
1328   let run0 (u : ('a,'a) m) = (u) id
1329 end
1330
1331
1332 (*
1333  * Scheme:
1334  * (define (example n)
1335  *    (let ([u (let/cc k ; type int -> int pair
1336  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
1337  *                 (+ 1 (car v))))]) ; int
1338  *      (cons u 0))) ; int pair
1339  * ; (example 10) ~~> '(111 . 0)
1340  * ; (example -10) ~~> '(0 . 0)
1341  *
1342  * OCaml monads:
1343  * let example n : (int * int) =
1344  *   Continuation_monad.(let u = callcc (fun k ->
1345  *       (if n < 0 then k 0 else unit [n + 100])
1346  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
1347  *       >>= fun [x] -> unit (x + 1)
1348  *   )
1349  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1350  *   >>= fun x -> unit (x, 0)
1351  *   in run u)
1352  *
1353  *
1354  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1355  * let example1 () : int =
1356  *   Continuation_monad.(let v = reset (
1357  *       let u = shift (fun k -> unit (10 + 1))
1358  *       in u >>= fun x -> unit (100 + x)
1359  *     ) in let w = v >>= fun x -> unit (1000 + x)
1360  *     in run w)
1361  *
1362  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1363  * let example2 () =
1364  *   Continuation_monad.(let v = reset (
1365  *       let u = shift (fun k -> k (10 :: [1]))
1366  *       in u >>= fun x -> unit (100 :: x)
1367  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1368  *     in run w)
1369  *
1370  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1371  * let example3 () =
1372  *   Continuation_monad.(let v = reset (
1373  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1374  *       in u >>= fun x -> unit (100 :: x)
1375  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1376  *     in run w)
1377  *
1378  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1379  * (* not sure if this example can be typed without a sum-type *)
1380  *
1381  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1382  * let example5 () : int =
1383  *   Continuation_monad.(let v = reset (
1384  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
1385  *       in u >>= fun x -> unit (10 + x)
1386  *     ) in let w = v >>= fun x -> unit (100 + x)
1387  *     in run w)
1388  *
1389  *)
1390
1391
1392 module Leaf_monad : sig
1393   (* We implement the type as `'a tree option` because it has a natural`plus`,
1394    * and the rest of the library expects that `plus` and `zero` will come together. *)
1395   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1396   type 'a result = 'a tree option
1397   type 'a result_exn = 'a tree
1398   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1399   include Monad.PLUS with type 'a m := 'a m
1400   (* LeafT transformer *)
1401   module T : functor (Wrapped : Monad.S) -> sig
1402     type 'a result = 'a tree option Wrapped.result
1403     type 'a result_exn = 'a tree Wrapped.result_exn
1404     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1405     include Monad.PLUS with type 'a m := 'a m
1406     val elevate : 'a Wrapped.m -> 'a m
1407     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1408     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1409     val distribute : ('a -> 'b Wrapped.m) -> 'a tree option -> 'b m
1410   end
1411   module T2 : functor (Wrapped : Monad.S2) -> sig
1412     type ('x,'a) result = ('x,'a tree option) Wrapped.result
1413     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1414     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1415     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
1416     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1417     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
1418   end
1419 end = struct
1420   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1421   (* uses supplied plus and zero to copy t to its image under f *)
1422   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1423       | None -> zero ()
1424       | Some ts -> let rec loop ts = (match ts with
1425                      | Leaf a -> f a
1426                      | Node (l, r) ->
1427                          (* recursive application of f may delete a branch *)
1428                          plus (loop l) (loop r)
1429                    ) in loop ts
1430   module Base = struct
1431     type 'a m = 'a tree option
1432     type 'a result = 'a tree option
1433     type 'a result_exn = 'a tree
1434     let unit a = Some (Leaf a)
1435     let zero () = None
1436     let plus u v = match (u, v) with
1437       | None, _ -> v
1438       | _, None -> u
1439       | Some us, Some vs -> Some (Node (us, vs))
1440     let bind u f = mapT f u zero plus
1441     let run u = u
1442     let run_exn u = match u with
1443       | None -> failwith "no values"
1444       (*
1445       | Some (Leaf a) -> a
1446       | many -> failwith "multiple values"
1447       *)
1448       | Some us -> us
1449   end
1450   include Monad.Make(Base)
1451   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
1452   let base_plus = plus
1453   let base_lift = lift
1454   module T(Wrapped : Monad.S) = struct
1455     module Trans = struct
1456       let zero () = Wrapped.unit None
1457       let plus u v =
1458         Wrapped.bind u (fun us ->
1459         Wrapped.bind v (fun vs ->
1460         Wrapped.unit (base_plus us vs)))
1461       include Monad.MakeT(struct
1462         module Wrapped = Wrapped
1463         type 'a m = 'a tree option Wrapped.m
1464         type 'a result = 'a tree option Wrapped.result
1465         type 'a result_exn = 'a tree Wrapped.result_exn
1466         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1467         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1468         let run u = Wrapped.run u
1469         let run_exn u =
1470             let w = Wrapped.bind u (fun t -> match t with
1471               | None -> failwith "no values"
1472               | Some ts -> Wrapped.unit ts)
1473             in Wrapped.run_exn w
1474       end)
1475     end
1476     include Trans
1477     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
1478     (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
1479     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1480   end
1481   module T2(Wrapped : Monad.S2) = struct
1482     module Trans = struct
1483       let zero () = Wrapped.unit None
1484       let plus u v =
1485         Wrapped.bind u (fun us ->
1486         Wrapped.bind v (fun vs ->
1487         Wrapped.unit (base_plus us vs)))
1488       include Monad.MakeT2(struct
1489         module Wrapped = Wrapped
1490         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1491         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1492         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1493         (* code repetition, ugh *)
1494         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1495         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1496         let run u = Wrapped.run u
1497         let run_exn u =
1498             let w = Wrapped.bind u (fun t -> match t with
1499               | None -> failwith "no values"
1500               | Some ts -> Wrapped.unit ts)
1501             in Wrapped.run_exn w
1502       end)
1503     end
1504     include Trans
1505     include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1506     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1507   end
1508 end
1509
1510
1511 module L = List_monad;;
1512 module R = Reader_monad(struct type env = int -> int end);;
1513 module S = State_monad(struct type store = int end);;
1514 module T = Leaf_monad;;
1515 module LR = L.T(R);;
1516 module LS = L.T(S);;
1517 module TL = T.T(L);;
1518 module TR = T.T(R);;
1519 module TS = T.T(S);;
1520 module C = Continuation_monad
1521 module TC = T.T2(C);;
1522
1523
1524 print_endline "=== test Leaf(...).distribute ==================";;
1525
1526 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1527
1528 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1529 TS.run ts 0;;
1530 (*
1531 - : int T.tree option * S.store =
1532 (Some
1533   (T.Node
1534     (T.Node (T.Leaf 2, T.Leaf 3),
1535      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1536  5)
1537 *)
1538
1539 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1540 TS.run_exn ts2 0;;
1541 (*
1542 - : (int * S.store) T.tree option * S.store =
1543 (Some
1544   (T.Node
1545     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1546      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1547  5)
1548 *)
1549
1550 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1551 TR.run_exn tr (fun i -> i+i);;
1552 (*
1553 - : int T.tree option =
1554 Some
1555  (T.Node
1556    (T.Node (T.Leaf 4, T.Leaf 6),
1557     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1558 *)
1559
1560 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1561 TL.run_exn tl;;
1562 (*
1563 - : (int * int) TL.result =
1564 [Some
1565   (T.Node
1566     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1567      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1568 *)
1569
1570 let l2 = [1;2;3;4;5];;
1571 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1572
1573 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1574 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1575
1576 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1577 (*
1578 int T.tree option =
1579 Some
1580  (T.Node
1581    (T.Node (T.Leaf 10, T.Leaf 11),
1582     T.Node
1583      (T.Node
1584        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1585         T.Node (T.Leaf 40, T.Leaf 41)),
1586       T.Node (T.Leaf 50, T.Leaf 51))))
1587  *)
1588
1589 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1590 (*
1591 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1592 *)
1593
1594 print_endline "=== test Leaf(Continuation).distribute ==================";;
1595
1596 let id : 'z. 'z -> 'z = fun x -> x
1597
1598 let example n : (int * int) =
1599   Continuation_monad.(let u = callcc (fun k ->
1600       (if n < 0 then k 0 else unit [n + 100])
1601       (* all of the following is skipped by k 0; the end type int is k's input type *)
1602       >>= fun [x] -> unit (x + 1)
1603   )
1604   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1605   >>= fun x -> unit (x, 0)
1606   in run0 u)
1607
1608
1609 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1610 let example1 () : int =
1611   Continuation_monad.(let v = reset (
1612       let u = shift (fun k -> unit (10 + 1))
1613       in u >>= fun x -> unit (100 + x)
1614     ) in let w = v >>= fun x -> unit (1000 + x)
1615     in run0 w)
1616
1617 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1618 let example2 () =
1619   Continuation_monad.(let v = reset (
1620       let u = shift (fun k -> k (10 :: [1]))
1621       in u >>= fun x -> unit (100 :: x)
1622     ) in let w = v >>= fun x -> unit (1000 :: x)
1623     in run0 w)
1624
1625 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1626 let example3 () =
1627   Continuation_monad.(let v = reset (
1628       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1629       in u >>= fun x -> unit (100 :: x)
1630     ) in let w = v >>= fun x -> unit (1000 :: x)
1631     in run0 w)
1632
1633 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1634 (* not sure if this example can be typed without a sum-type *)
1635
1636 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1637 let example5 () : int =
1638   Continuation_monad.(let v = reset (
1639       let u = shift (fun k -> k 1 >>= k)
1640       in u >>= fun x -> unit (10 + x)
1641     ) in let w = v >>= fun x -> unit (100 + x)
1642     in run0 w)
1643
1644 ;;
1645
1646 print_endline "=== test bare Continuation ============";;
1647
1648 (1011, 1111, 1111, 121);;
1649 (example1(), example2(), example3(), example5());;
1650 ((111,0), (0,0));;
1651 (example ~+10, example ~-10);;
1652
1653 let testc df ic =
1654     C.run_exn TC.(run (distribute df t1)) ic;;
1655
1656
1657 (*
1658 (* do nothing *)
1659 let initial_continuation = fun t -> t in
1660 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1661 *)
1662 testc (C.unit) id;;
1663
1664 (*
1665 (* count leaves, using continuation *)
1666 let initial_continuation = fun t -> 0 in
1667 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1668 *)
1669
1670 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1671
1672 (*
1673 (* convert tree to list of leaves *)
1674 let initial_continuation = fun t -> [] in
1675 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1676 *)
1677
1678 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1679
1680 (*
1681 (* square each leaf using continuation *)
1682 let initial_continuation = fun t -> t in
1683 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1684 *)
1685
1686 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1687
1688
1689 (*
1690 (* replace leaves with list, using continuation *)
1691 let initial_continuation = fun t -> t in
1692 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1693 *)
1694
1695 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1696
1697 print_endline "=== pa_monad's Continuation Tests ============";;
1698
1699 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1700 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1701 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1702 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1703 (5, 27 = C.(run0 (
1704               let c = reset(abort 5 >>= fun y -> unit (y+6))
1705               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1706
1707 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1708
1709 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1710
1711 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1712
1713
1714 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1715