0b6af2ed38f013fd492d727257d8b23f1041a67b
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  *)
46
47
48 (* Some library functions used below. *)
49 module Util = struct
50   let fold_right = List.fold_right
51   let map = List.map
52   let append = List.append
53   let reverse = List.rev
54   let concat = List.concat
55   let concat_map f lst = List.concat (List.map f lst)
56   (* let zip = List.combine *)
57   let unzip = List.split
58   let zip_with = List.map2
59   let replicate len fill =
60     let rec loop n accu =
61       if n == 0 then accu else loop (pred n) (fill :: accu)
62     in loop len []
63 end
64
65
66
67 (*
68  * This module contains factories that extend a base set of
69  * monadic definitions with a larger family of standard derived values.
70  *)
71
72 module Monad = struct
73   (*
74    * Signature extenders:
75    *   Make :: BASE -> S
76    *     MakeCatch, MakeDistrib :: PLUSBASE -> PLUS
77    *                                           which merges into S
78    *                                           (P is merged sig)
79    *   MakeT :: TRANS (with Wrapped : S or P) -> custom sig
80    *
81    *   Make2 :: BASE2 -> S2
82    *     MakeCatch2, MakeDistrib2 :: PLUSBASE2 -> PLUS2 (P2 is merged sig)
83    *   to wrap double-typed inner monads:
84    *   MakeT2 :: TRANS2 (with Wrapped : S2 or P2) -> custom sig
85    *
86    *)
87
88
89
90   (* type of base definitions *)
91   module type BASE = sig
92     (* The only constraints we impose here on how the monadic type
93      * is implemented is that it have a single type parameter 'a. *)
94     type 'a m
95     type 'a result
96     type 'a result_exn
97     val unit : 'a -> 'a m
98     val bind : 'a m -> ('a -> 'b m) -> 'b m
99     val run : 'a m -> 'a result
100     (* run_exn tries to provide a more ground-level result, but may fail *)
101     val run_exn : 'a m -> 'a result_exn
102   end
103   module type S = sig
104     include BASE
105     val (>>=) : 'a m -> ('a -> 'b m) -> 'b m
106     val (>>) : 'a m -> 'b m -> 'b m
107     val join : ('a m) m -> 'a m
108     val apply : ('a -> 'b) m -> 'a m -> 'b m
109     val lift : ('a -> 'b) -> 'a m -> 'b m
110     val lift2 :  ('a -> 'b -> 'c) -> 'a m -> 'b m -> 'c m
111     val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m
112     val do_when :  bool -> unit m -> unit m
113     val do_unless :  bool -> unit m -> unit m
114     val forever : (unit -> 'a m) -> 'b m
115     val sequence : 'a m list -> 'a list m
116     val sequence_ : 'a m list -> unit m
117   end
118
119   (* Standard, single-type-parameter monads. *)
120   module Make(B : BASE) : S with type 'a m = 'a B.m and type 'a result = 'a B.result and type 'a result_exn = 'a B.result_exn = struct
121     include B
122     let (>>=) = bind
123     let (>>) u v = u >>= fun _ -> v
124     let lift f u = u >>= fun a -> unit (f a)
125     (* lift is called listM, fmap, and <$> in Haskell *)
126     let join uu = uu >>= fun u -> u
127     (* u >>= f === join (lift f u) *)
128     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
129     (* [f] <*> [x1,x2] = [f x1,f x2] *)
130     (* let apply u v = u >>= fun f -> lift f v *)
131     (* let apply = lift2 id *)
132     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
133     (* let lift f u === apply (unit f) u *)
134     (* let lift2 f u v = apply (lift f u) v *)
135     let (>=>) f g = fun a -> f a >>= g
136     let do_when test u = if test then u else unit ()
137     let do_unless test u = if test then unit () else u
138     let forever uthunk =
139         let rec loop () = uthunk () >>= fun _ -> loop ()
140         in loop ()
141     let sequence ms =
142       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
143         Util.fold_right op ms (unit [])
144     let sequence_ ms =
145       Util.fold_right (>>) ms (unit ())
146
147     (* Haskell defines these other operations combining lists and monads.
148      * We don't, but notice that M.mapM == ListT(M).distribute
149      * There's also a parallel TreeT(M).distribute *)
150     (*
151     let mapM f alist = sequence (Util.map f alist)
152     let mapM_ f alist = sequence_ (Util.map f alist)
153     let rec filterM f lst = match lst with
154       | [] -> unit []
155       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
156     let forM alist f = mapM f alist
157     let forM_ alist f = mapM_ f alist
158     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
159     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
160     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
161     let rec foldM f z lst = match lst with
162       | [] -> unit z
163       | x::xs -> f z x >>= fun z' -> foldM f z' xs
164     let foldM_ f z xs = foldM f z xs >> unit ()
165     let replicateM n x = sequence (Util.replicate n x)
166     let replicateM_ n x = sequence_ (Util.replicate n x)
167     *)
168   end
169
170   (* Single-type-parameter monads that also define `plus` and `zero`
171    * operations. These obey the following laws:
172    *     zero >>= f   ===  zero
173    *     plus zero u  ===  u
174    *     plus u zero  ===  u
175    * Additionally, these monads will obey one of the following laws:
176    *     (Catch)   plus (unit a) v  ===  unit a
177    *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
178    *)
179   module type PLUSBASE = sig
180     include BASE
181     val zero : unit -> 'a m
182     val plus : 'a m -> 'a m -> 'a m
183   end
184   module type PLUS = sig
185     type 'a m
186     val zero : unit -> 'a m
187     val plus : 'a m -> 'a m -> 'a m
188     val guard : bool -> unit m
189     val sum : 'a m list -> 'a m
190   end
191   (* MakeCatch and MakeDistrib have the same implementation; we just declare
192    * them twice to document which laws the client code is promising to honor. *)
193   module MakeCatch(B : PLUSBASE) : PLUS with type 'a m = 'a B.m = struct
194     type 'a m = 'a B.m
195     let zero = B.zero
196     let plus = B.plus
197     let guard test = if test then B.unit () else zero ()
198     let sum ms = Util.fold_right plus ms (zero ())
199   end
200   module MakeDistrib = MakeCatch
201
202   (* Signatures for MonadT *)
203   (* sig for Wrapped that include S and PLUS *)
204   module type P = sig
205     include S
206     include PLUS with type 'a m := 'a m
207   end
208   module type TRANS = sig
209     module Wrapped : S
210     type 'a m
211     type 'a result
212     type 'a result_exn
213     val bind : 'a m -> ('a -> 'b m) -> 'b m
214     val run : 'a m -> 'a result
215     val run_exn : 'a m -> 'a result_exn
216     val elevate : 'a Wrapped.m -> 'a m
217     (* lift/elevate laws:
218      *     elevate (W.unit a) == unit a
219      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
220      *)
221   end
222   module MakeT(T : TRANS) = struct
223     include Make(struct
224         include T
225         let unit a = elevate (Wrapped.unit a)
226     end)
227     let elevate = T.elevate
228   end
229
230
231   (* We have to define BASE, S, and Make again for double-type-parameter monads. *)
232   module type BASE2 = sig
233     type ('x,'a) m
234     type ('x,'a) result
235     type ('x,'a) result_exn
236     val unit : 'a -> ('x,'a) m
237     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
238     val run : ('x,'a) m -> ('x,'a) result
239     val run_exn : ('x,'a) m -> ('x,'a) result_exn
240   end
241   module type S2 = sig
242     include BASE2
243     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
244     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
245     val join : ('x,('x,'a) m) m -> ('x,'a) m
246     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
247     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
248     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
249     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
250     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
251     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
252     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
253     val sequence : ('x,'a) m list -> ('x,'a list) m
254     val sequence_ : ('x,'a) m list -> ('x,unit) m
255   end
256   module Make2(B : BASE2) : S2 with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
257     (* code repetition, ugh *)
258     include B
259     let (>>=) = bind
260     let (>>) u v = u >>= fun _ -> v
261     let lift f u = u >>= fun a -> unit (f a)
262     let join uu = uu >>= fun u -> u
263     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
264     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
265     let (>=>) f g = fun a -> f a >>= g
266     let do_when test u = if test then u else unit ()
267     let do_unless test u = if test then unit () else u
268     let forever uthunk =
269         let rec loop () = uthunk () >>= fun _ -> loop ()
270         in loop ()
271     let sequence ms =
272       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
273         Util.fold_right op ms (unit [])
274     let sequence_ ms =
275       Util.fold_right (>>) ms (unit ())
276   end
277
278   module type PLUSBASE2 = sig
279     include BASE2
280     val zero : unit -> ('x,'a) m
281     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
282   end
283   module type PLUS2 = sig
284     type ('x,'a) m
285     val zero : unit -> ('x,'a) m
286     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
287     val guard : bool -> ('x,unit) m
288     val sum : ('x,'a) m list -> ('x,'a) m
289   end
290   module MakeCatch2(B : PLUSBASE2) : PLUS2 with type ('x,'a) m = ('x,'a) B.m = struct
291     type ('x,'a) m = ('x,'a) B.m
292     (* code repetition, ugh *)
293     let zero = B.zero
294     let plus = B.plus
295     let guard test = if test then B.unit () else zero ()
296     let sum ms = Util.fold_right plus ms (zero ())
297   end
298   module MakeDistrib2 = MakeCatch2
299
300   (* Signatures for MonadT *)
301   (* sig for Wrapped that include S and PLUS *)
302   module type P2 = sig
303     include S2
304     include PLUS2 with type ('x,'a) m := ('x,'a) m
305   end
306   module type TRANS2 = sig
307     module Wrapped : S2
308     type ('x,'a) m
309     type ('x,'a) result
310     type ('x,'a) result_exn
311     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
312     val run : ('x,'a) m -> ('x,'a) result
313     val run_exn : ('x,'a) m -> ('x,'a) result_exn
314     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
315   end
316   module MakeT2(T : TRANS2) = struct
317     (* code repetition, ugh *)
318     include Make2(struct
319         include T
320         let unit a = elevate (Wrapped.unit a)
321     end)
322     let elevate = T.elevate
323   end
324
325 end
326
327
328
329
330
331 module Identity_monad : sig
332   (* expose only the implementation of type `'a result` *)
333   type 'a result = 'a
334   type 'a result_exn = 'a
335   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
336 end = struct
337   module Base = struct
338     type 'a m = 'a
339     type 'a result = 'a
340     type 'a result_exn = 'a
341     let unit a = a
342     let bind a f = f a
343     let run a = a
344     let run_exn a = a
345   end
346   include Monad.Make(Base)
347 end
348
349
350 module Maybe_monad : sig
351   (* expose only the implementation of type `'a result` *)
352   type 'a result = 'a option
353   type 'a result_exn = 'a
354   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
355   include Monad.PLUS with type 'a m := 'a m
356   (* MaybeT transformer *)
357   module T : functor (Wrapped : Monad.S) -> sig
358     type 'a result = 'a option Wrapped.result
359     type 'a result_exn = 'a Wrapped.result_exn
360     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
361     include Monad.PLUS with type 'a m := 'a m
362     val elevate : 'a Wrapped.m -> 'a m
363   end
364   module T2 : functor (Wrapped : Monad.S2) -> sig
365     type ('x,'a) result = ('x,'a option) Wrapped.result
366     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
367     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
368     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
369     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
370   end
371 end = struct
372   module Base = struct
373    type 'a m = 'a option
374    type 'a result = 'a option
375    type 'a result_exn = 'a
376    let unit a = Some a
377    let bind u f = match u with Some a -> f a | None -> None
378    let run u = u
379    let run_exn u = match u with
380      | Some a -> a
381      | None -> failwith "no value"
382    let zero () = None
383    let plus u v = match u with None -> v | _ -> u
384   end
385   include Monad.Make(Base)
386   include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m)
387   module T(Wrapped : Monad.S) = struct
388     module Trans = struct
389       include Monad.MakeT(struct
390         module Wrapped = Wrapped
391         type 'a m = 'a option Wrapped.m
392         type 'a result = 'a option Wrapped.result
393         type 'a result_exn = 'a Wrapped.result_exn
394         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
395         let bind u f = Wrapped.bind u (fun t -> match t with
396           | Some a -> f a
397           | None -> Wrapped.unit None)
398         let run u = Wrapped.run u
399         let run_exn u =
400           let w = Wrapped.bind u (fun t -> match t with
401             | Some a -> Wrapped.unit a
402             | None -> failwith "no value")
403           in Wrapped.run_exn w
404       end)
405       let zero () = Wrapped.unit None
406       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
407     end
408     include Trans
409     include (Monad.MakeCatch(Trans) : Monad.PLUS with type 'a m := 'a m)
410   end
411   module T2(Wrapped : Monad.S2) = struct
412     module Trans = struct
413       include Monad.MakeT2(struct
414         module Wrapped = Wrapped
415         type ('x,'a) m = ('x,'a option) Wrapped.m
416         type ('x,'a) result = ('x,'a option) Wrapped.result
417         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
418         (* code repetition, ugh *)
419         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
420         let bind u f = Wrapped.bind u (fun t -> match t with
421           | Some a -> f a
422           | None -> Wrapped.unit None)
423         let run u = Wrapped.run u
424         let run_exn u =
425           let w = Wrapped.bind u (fun t -> match t with
426             | Some a -> Wrapped.unit a
427             | None -> failwith "no value")
428           in Wrapped.run_exn w
429       end)
430       let zero () = Wrapped.unit None
431       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
432     end
433     include Trans
434     include (Monad.MakeCatch2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
435   end
436 end
437
438
439 module List_monad : sig
440   (* declare additional operation, while still hiding implementation of type m *)
441   type 'a result = 'a list
442   type 'a result_exn = 'a
443   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
444   include Monad.PLUS with type 'a m := 'a m
445   val permute : 'a m -> 'a m m
446   val select : 'a m -> ('a * 'a m) m
447   (* ListT transformer *)
448   module T : functor (Wrapped : Monad.S) -> sig
449     type 'a result = 'a list Wrapped.result
450     type 'a result_exn = 'a Wrapped.result_exn
451     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
452     include Monad.PLUS with type 'a m := 'a m
453     val elevate : 'a Wrapped.m -> 'a m
454     (* note that second argument is an 'a list, not the more abstract 'a m *)
455     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
456     val distribute : ('a -> 'b Wrapped.m) -> 'a list -> 'b m
457 (* TODO
458     val permute : 'a m -> 'a m m
459     val select : 'a m -> ('a * 'a m) m
460 *)
461   end
462   module T2 : functor (Wrapped : Monad.S2) -> sig
463     type ('x,'a) result = ('x,'a list) Wrapped.result
464     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
465     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
466     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
467     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
468     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
469   end
470 end = struct
471   module Base = struct
472    type 'a m = 'a list
473    type 'a result = 'a list
474    type 'a result_exn = 'a
475    let unit a = [a]
476    let bind u f = Util.concat_map f u
477    let run u = u
478    let run_exn u = match u with
479      | [] -> failwith "no values"
480      | [a] -> a
481      | many -> failwith "multiple values"
482    let zero () = []
483    let plus = Util.append
484   end
485   include Monad.Make(Base)
486   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
487   (* let either u v = plus u v *)
488   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
489   let rec insert a u =
490     plus (unit (a :: u)) (match u with
491         | [] -> zero ()
492         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
493     )
494   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
495   let rec permute u = match u with
496       | [] -> unit []
497       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
498   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
499   let rec select u = match u with
500     | [] -> zero ()
501     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
502   let base_plus = plus
503   module T(Wrapped : Monad.S) = struct
504     module Trans = struct
505       (* Wrapped.sequence ms  ===  
506            let plus1 u v =
507              Wrapped.bind u (fun x ->
508              Wrapped.bind v (fun xs ->
509              Wrapped.unit (x :: xs)))
510            in Util.fold_right plus1 ms (Wrapped.unit []) *)
511       (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
512       let distribute f alist = Wrapped.sequence (Util.map f alist)
513       include Monad.MakeT(struct
514         module Wrapped = Wrapped
515         type 'a m = 'a list Wrapped.m
516         type 'a result = 'a list Wrapped.result
517         type 'a result_exn = 'a Wrapped.result_exn
518         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
519         let bind u f =
520           Wrapped.bind u (fun ts ->
521           Wrapped.bind (distribute f ts) (fun tts ->
522           Wrapped.unit (Util.concat tts)))
523         let run u = Wrapped.run u
524         let run_exn u =
525           let w = Wrapped.bind u (fun ts -> match ts with
526             | [] -> failwith "no values"
527             | [a] -> Wrapped.unit a
528             | many -> failwith "multiple values"
529           ) in Wrapped.run_exn w
530       end)
531       let zero () = Wrapped.unit []
532       let plus u v =
533         Wrapped.bind u (fun us ->
534         Wrapped.bind v (fun vs ->
535         Wrapped.unit (base_plus us vs)))
536     end
537     include Trans
538     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
539 (*
540     let permute : 'a m -> 'a m m
541     let select : 'a m -> ('a * 'a m) m
542 *)
543   end
544   module T2(Wrapped : Monad.S2) = struct
545     module Trans = struct
546       let distribute f alist = Wrapped.sequence (Util.map f alist)
547       include Monad.MakeT2(struct
548         module Wrapped = Wrapped
549         type ('x,'a) m = ('x,'a list) Wrapped.m
550         type ('x,'a) result = ('x,'a list) Wrapped.result
551         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
552         (* code repetition, ugh *)
553         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
554         let bind u f =
555           Wrapped.bind u (fun ts ->
556           Wrapped.bind (distribute f ts) (fun tts ->
557           Wrapped.unit (Util.concat tts)))
558         let run u = Wrapped.run u
559         let run_exn u =
560           let w = Wrapped.bind u (fun ts -> match ts with
561             | [] -> failwith "no values"
562             | [a] -> Wrapped.unit a
563             | many -> failwith "multiple values"
564           ) in Wrapped.run_exn w
565       end)
566       let zero () = Wrapped.unit []
567       let plus u v =
568         Wrapped.bind u (fun us ->
569         Wrapped.bind v (fun vs ->
570         Wrapped.unit (base_plus us vs)))
571     end
572     include Trans
573     include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
574   end
575 end
576
577
578 (* must be parameterized on (struct type err = ... end) *)
579 module Error_monad(Err : sig
580   type err
581   exception Exc of err
582   (*
583   val zero : unit -> err
584   val plus : err -> err -> err
585   *)
586 end) : sig
587   (* declare additional operations, while still hiding implementation of type m *)
588   type err = Err.err
589   type 'a error = Error of err | Success of 'a
590   type 'a result = 'a error
591   type 'a result_exn = 'a
592   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
593   (* include Monad.PLUS with type 'a m := 'a m *)
594   val throw : err -> 'a m
595   val catch : 'a m -> (err -> 'a m) -> 'a m
596   (* ErrorT transformer *)
597   module T : functor (Wrapped : Monad.S) -> sig
598     type 'a result = 'a error Wrapped.result
599     type 'a result_exn = 'a Wrapped.result_exn
600     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
601     val elevate : 'a Wrapped.m -> 'a m
602     val throw : err -> 'a m
603     val catch : 'a m -> (err -> 'a m) -> 'a m
604   end
605   (* ErrorT transformer when wrapped monad has plus, zero *)
606   module TP : functor (Wrapped : Monad.P) -> sig
607     type 'a result = 'a Wrapped.result
608     type 'a result_exn = 'a Wrapped.result_exn
609     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
610     val elevate : 'a Wrapped.m -> 'a m
611     val throw : err -> 'a m
612     val catch : 'a m -> (err -> 'a m) -> 'a m
613     include Monad.PLUS with type 'a m := 'a m
614   end
615   module T2 : functor (Wrapped : Monad.S2) -> sig
616     type ('x,'a) result = ('x,'a error) Wrapped.result
617     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
618     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
619     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
620     val throw : err -> ('x,'a) m
621     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
622   end
623   module TP2 : functor (Wrapped : Monad.P2) -> sig
624     type ('x,'a) result = ('x,'a) Wrapped.result
625     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
626     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
627     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
628     val throw : err -> ('x,'a) m
629     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
630     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
631   end
632 end = struct
633   type err = Err.err
634   type 'a error = Error of err | Success of 'a
635   module Base = struct
636     type 'a m = 'a error
637     type 'a result = 'a error
638     type 'a result_exn = 'a
639     let unit a = Success a
640     let bind u f = match u with
641       | Success a -> f a
642       | Error e -> Error e (* input and output may be of different 'a types *)
643     let run u = u
644     let run_exn u = match u with
645       | Success a -> a
646       | Error e -> raise (Err.Exc e)
647     (*
648     let zero () = Error Err.zero
649     let plus u v = match (u, v) with
650       | Success _, _ -> u
651       (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
652        * otherwise, plus (Error _) v = v *)
653       | Error _, _ when v = zero -> u
654       (* combine errors *)
655       | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
656       | Error _, _ -> v
657     *)
658   end
659   include Monad.Make(Base)
660   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
661   let throw e = Error e
662   let catch u handler = match u with
663     | Success _ -> u
664     | Error e -> handler e
665   module T(Wrapped : Monad.S) = struct
666     module Trans = struct
667       module Wrapped = Wrapped
668       type 'a m = 'a error Wrapped.m
669       type 'a result = 'a error Wrapped.result
670       type 'a result_exn = 'a Wrapped.result_exn
671       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
672       let bind u f = Wrapped.bind u (fun t -> match t with
673         | Success a -> f a
674         | Error e -> Wrapped.unit (Error e))
675       let run u = Wrapped.run u
676       let run_exn u =
677         let w = Wrapped.bind u (fun t -> match t with
678           | Success a -> Wrapped.unit a
679           (* | _ -> Wrapped.fail () *)
680           | Error e -> raise (Err.Exc e))
681         in Wrapped.run_exn w
682     end
683     include Monad.MakeT(Trans)
684     let throw e = Wrapped.unit (Error e)
685     let catch u handler = Wrapped.bind u (fun t -> match t with
686       | Success _ -> Wrapped.unit t
687       | Error e -> handler e)
688   end
689   module TP(Wrapped : Monad.P) = struct
690     (* code repetition, ugh *)
691     module TransP = struct
692       include Monad.MakeT(struct
693         module Wrapped = Wrapped
694         type 'a m = 'a error Wrapped.m
695         type 'a result = 'a Wrapped.result
696         type 'a result_exn = 'a Wrapped.result_exn
697         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
698         let bind u f = Wrapped.bind u (fun t -> match t with
699           | Success a -> f a
700           | Error e -> Wrapped.unit (Error e))
701         let run u =
702           let w = Wrapped.bind u (fun t -> match t with
703             | Success a -> Wrapped.unit a
704             | Error e -> Wrapped.zero ())
705           in Wrapped.run w
706         let run_exn u =
707           let w = Wrapped.bind u (fun t -> match t with
708             | Success a -> Wrapped.unit a
709             (* | _ -> Wrapped.fail () *)
710             | Error e -> raise (Err.Exc e))
711           in Wrapped.run_exn w
712       end)
713       let throw e = Wrapped.unit (Error e)
714       let catch u handler = Wrapped.bind u (fun t -> match t with
715         | Success _ -> Wrapped.unit t
716         | Error e -> handler e)
717       let plus u v = Wrapped.plus u v
718       let zero () = elevate (Wrapped.zero ())
719     end
720     include TransP
721     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
722   end
723   module T2(Wrapped : Monad.S2) = struct
724     module Trans = struct
725       module Wrapped = Wrapped
726       type ('x,'a) m = ('x,'a error) Wrapped.m
727       type ('x,'a) result = ('x,'a error) Wrapped.result
728       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
729       (* code repetition, ugh *)
730       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
731       let bind u f = Wrapped.bind u (fun t -> match t with
732         | Success a -> f a
733         | Error e -> Wrapped.unit (Error e))
734       let run u = Wrapped.run u
735       let run_exn u =
736         let w = Wrapped.bind u (fun t -> match t with
737           | Success a -> Wrapped.unit a
738           | Error e -> raise (Err.Exc e))
739         in Wrapped.run_exn w
740     end
741     include Monad.MakeT2(Trans)
742     let throw e = Wrapped.unit (Error e)
743     let catch u handler = Wrapped.bind u (fun t -> match t with
744       | Success _ -> Wrapped.unit t
745       | Error e -> handler e)
746   end
747   module TP2(Wrapped : Monad.P2) = struct
748     (* code repetition, ugh *)
749     module TransP = struct
750       include Monad.MakeT2(struct
751         module Wrapped = Wrapped
752         type ('x,'a) m = ('x,'a error) Wrapped.m
753         type ('x,'a) result = ('x,'a) Wrapped.result
754         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
755         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
756         let bind u f = Wrapped.bind u (fun t -> match t with
757           | Success a -> f a
758           | Error e -> Wrapped.unit (Error e))
759         let run u =
760           let w = Wrapped.bind u (fun t -> match t with
761             | Success a -> Wrapped.unit a
762             | Error e -> Wrapped.zero ())
763           in Wrapped.run w
764         let run_exn u =
765           let w = Wrapped.bind u (fun t -> match t with
766             | Success a -> Wrapped.unit a
767             (* | _ -> Wrapped.fail () *)
768             | Error e -> raise (Err.Exc e))
769           in Wrapped.run_exn w
770       end)
771       let throw e = Wrapped.unit (Error e)
772       let catch u handler = Wrapped.bind u (fun t -> match t with
773         | Success _ -> Wrapped.unit t
774         | Error e -> handler e)
775       let plus u v = Wrapped.plus u v
776       let zero () = elevate (Wrapped.zero ())
777     end
778     include TransP
779     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
780   end
781 end
782
783 (* pre-define common instance of Error_monad *)
784 module Failure = Error_monad(struct
785   type err = string
786   exception Exc = Failure
787   (*
788   let zero = ""
789   let plus s1 s2 = s1 ^ "\n" ^ s2
790   *)
791 end)
792
793 (*
794 # EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
795 - : int EL.result = [Failure.Error "bye"; Failure.Success 30]
796 # LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
797 - : int LE.result = Failure.Error "bye"
798 # EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
799 Exception: Failure "bye".
800 # LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
801 Exception: Failure "bye".
802
803 # ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
804 - : int Failure.error * S.store = (Failure.Error "bye", 1)
805 # SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
806 - : (int * S.store) Failure.result = Failure.Error "bye"
807 # ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
808 Exception: Failure "bye".
809 # SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
810 Exception: Failure "bye".
811  *)
812
813
814 (* must be parameterized on (struct type env = ... end) *)
815 module Reader_monad(Env : sig type env end) : sig
816   (* declare additional operations, while still hiding implementation of type m *)
817   type env = Env.env
818   type 'a result = env -> 'a
819   type 'a result_exn = env -> 'a
820   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
821   val ask : env m
822   val asks : (env -> 'a) -> 'a m
823   val local : (env -> env) -> 'a m -> 'a m
824   (* ReaderT transformer *)
825   module T : functor (Wrapped : Monad.S) -> sig
826     type 'a result = env -> 'a Wrapped.result
827     type 'a result_exn = env -> 'a Wrapped.result_exn
828     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
829     val elevate : 'a Wrapped.m -> 'a m
830     val ask : env m
831     val asks : (env -> 'a) -> 'a m
832     val local : (env -> env) -> 'a m -> 'a m
833   end
834   (* ReaderT transformer when wrapped monad has plus, zero *)
835   module TP : functor (Wrapped : Monad.P) -> sig
836     include module type of T(Wrapped)
837     include Monad.PLUS with type 'a m := 'a m
838   end
839   module T2 : functor (Wrapped : Monad.S2) -> sig
840     type ('x,'a) result = env -> ('x,'a) Wrapped.result
841     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
842     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
843     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
844     val ask : ('x,env) m
845     val asks : (env -> 'a) -> ('x,'a) m
846     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
847   end
848   module TP2 : functor (Wrapped : Monad.P2) -> sig
849     include module type of T2(Wrapped)
850     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
851   end
852 end = struct
853   type env = Env.env
854   module Base = struct
855     type 'a m = env -> 'a
856     type 'a result = env -> 'a
857     type 'a result_exn = env -> 'a
858     let unit a = fun e -> a
859     let bind u f = fun e -> let a = u e in let u' = f a in u' e
860     let run u = fun e -> u e
861     let run_exn = run
862   end
863   include Monad.Make(Base)
864   let ask = fun e -> e
865   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
866   let local modifier u = fun e -> u (modifier e)
867   module T(Wrapped : Monad.S) = struct
868     module Trans = struct
869       module Wrapped = Wrapped
870       type 'a m = env -> 'a Wrapped.m
871       type 'a result = env -> 'a Wrapped.result
872       type 'a result_exn = env -> 'a Wrapped.result_exn
873       let elevate w = fun e -> w
874       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
875       let run u = fun e -> Wrapped.run (u e)
876       let run_exn u = fun e -> Wrapped.run_exn (u e)
877     end
878     include Monad.MakeT(Trans)
879     let ask = fun e -> Wrapped.unit e
880     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
881     let local modifier u = fun e -> u (modifier e)
882   end
883   module TP(Wrapped : Monad.P) = struct
884     module TransP = struct
885       include T(Wrapped)
886       let plus u v = fun s -> Wrapped.plus (u s) (v s)
887       let zero () = elevate (Wrapped.zero ())
888       let asks selector = ask >>= (fun e ->
889         try unit (selector e)
890         with Not_found -> fun e -> Wrapped.zero ())
891     end
892     include TransP
893     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
894   end
895   module T2(Wrapped : Monad.S2) = struct
896     module Trans = struct
897       module Wrapped = Wrapped
898       type ('x,'a) m = env -> ('x,'a) Wrapped.m
899       type ('x,'a) result = env -> ('x,'a) Wrapped.result
900       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
901       (* code repetition, ugh *)
902       let elevate w = fun e -> w
903       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
904       let run u = fun e -> Wrapped.run (u e)
905       let run_exn u = fun e -> Wrapped.run_exn (u e)
906     end
907     include Monad.MakeT2(Trans)
908     let ask = fun e -> Wrapped.unit e
909     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
910     let local modifier u = fun e -> u (modifier e)
911   end
912   module TP2(Wrapped : Monad.P2) = struct
913     module TransP = struct
914       include T2(Wrapped)
915       (* code repetition, ugh *)
916       let plus u v = fun s -> Wrapped.plus (u s) (v s)
917       let zero () = elevate (Wrapped.zero ())
918       let asks selector = ask >>= (fun e ->
919         try unit (selector e)
920         with Not_found -> fun e -> Wrapped.zero ())
921     end
922     include TransP
923     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
924   end
925 end
926
927
928 (* must be parameterized on (struct type store = ... end) *)
929 module State_monad(Store : sig type store end) : sig
930   (* declare additional operations, while still hiding implementation of type m *)
931   type store = Store.store
932   type 'a result = store -> 'a * store
933   type 'a result_exn = store -> 'a
934   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
935   val get : store m
936   val gets : (store -> 'a) -> 'a m
937   val put : store -> unit m
938   val puts : (store -> store) -> unit m
939   (* StateT transformer *)
940   module T : functor (Wrapped : Monad.S) -> sig
941     type 'a result = store -> ('a * store) Wrapped.result
942     type 'a result_exn = store -> 'a Wrapped.result_exn
943     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
944     val elevate : 'a Wrapped.m -> 'a m
945     val get : store m
946     val gets : (store -> 'a) -> 'a m
947     val put : store -> unit m
948     val puts : (store -> store) -> unit m
949   end
950   (* StateT transformer when wrapped monad has plus, zero *)
951   module TP : functor (Wrapped : Monad.P) -> sig
952     include module type of T(Wrapped)
953     include Monad.PLUS with type 'a m := 'a m
954   end
955   module T2 : functor (Wrapped : Monad.S2) -> sig
956     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
957     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
958     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
959     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
960     val get : ('x,store) m
961     val gets : (store -> 'a) -> ('x,'a) m
962     val put : store -> ('x,unit) m
963     val puts : (store -> store) -> ('x,unit) m
964   end
965   module TP2 : functor (Wrapped : Monad.P2) -> sig
966     include module type of T2(Wrapped)
967     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
968   end
969 end = struct
970   type store = Store.store
971   module Base = struct
972     type 'a m = store -> 'a * store
973     type 'a result = store -> 'a * store
974     type 'a result_exn = store -> 'a
975     let unit a = fun s -> (a, s)
976     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
977     let run u = fun s -> (u s)
978     let run_exn u = fun s -> fst (u s)
979   end
980   include Monad.Make(Base)
981   let get = fun s -> (s, s)
982   let gets viewer = fun s -> (viewer s, s) (* may fail *)
983   let put s = fun _ -> ((), s)
984   let puts modifier = fun s -> ((), modifier s)
985   module T(Wrapped : Monad.S) = struct
986     module Trans = struct
987       module Wrapped = Wrapped
988       type 'a m = store -> ('a * store) Wrapped.m
989       type 'a result = store -> ('a * store) Wrapped.result
990       type 'a result_exn = store -> 'a Wrapped.result_exn
991       let elevate w = fun s ->
992         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
993       let bind u f = fun s ->
994         Wrapped.bind (u s) (fun (a, s') -> f a s')
995       let run u = fun s -> Wrapped.run (u s)
996       let run_exn u = fun s ->
997         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
998         in Wrapped.run_exn w
999     end
1000     include Monad.MakeT(Trans)
1001     let get = fun s -> Wrapped.unit (s, s)
1002     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
1003     let put s = fun _ -> Wrapped.unit ((), s)
1004     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
1005   end
1006   module TP(Wrapped : Monad.P) = struct
1007     module TransP = struct
1008       include T(Wrapped)
1009       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1010       let zero () = elevate (Wrapped.zero ())
1011     end
1012     let gets viewer = fun s ->
1013       try Wrapped.unit (viewer s, s)
1014       with Not_found -> Wrapped.zero ()
1015     include TransP
1016     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
1017   end
1018   module T2(Wrapped : Monad.S2) = struct
1019     module Trans = struct
1020       module Wrapped = Wrapped
1021       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
1022       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
1023       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
1024       (* code repetition, ugh *)
1025       let elevate w = fun s ->
1026         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1027       let bind u f = fun s ->
1028         Wrapped.bind (u s) (fun (a, s') -> f a s')
1029       let run u = fun s -> Wrapped.run (u s)
1030       let run_exn u = fun s ->
1031         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
1032         in Wrapped.run_exn w
1033     end
1034     include Monad.MakeT2(Trans)
1035     let get = fun s -> Wrapped.unit (s, s)
1036     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
1037     let put s = fun _ -> Wrapped.unit ((), s)
1038     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
1039   end
1040   module TP2(Wrapped : Monad.P2) = struct
1041     module TransP = struct
1042       include T2(Wrapped)
1043       (* code repetition, ugh *)
1044       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1045       let zero () = elevate (Wrapped.zero ())
1046     end
1047     let gets viewer = fun s ->
1048       try Wrapped.unit (viewer s, s)
1049       with Not_found -> Wrapped.zero ()
1050     include TransP
1051     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1052   end
1053 end
1054
1055 (* State monad with different interface (structured store) *)
1056 module Ref_monad(V : sig
1057   type value
1058 end) : sig
1059   type ref
1060   type value = V.value
1061   type 'a result = 'a
1062   type 'a result_exn = 'a
1063   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1064   val newref : value -> ref m
1065   val deref : ref -> value m
1066   val change : ref -> value -> unit m
1067   (* RefT transformer *)
1068   module T : functor (Wrapped : Monad.S) -> sig
1069     type 'a result = 'a Wrapped.result
1070     type 'a result_exn = 'a Wrapped.result_exn
1071     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1072     val elevate : 'a Wrapped.m -> 'a m
1073     val newref : value -> ref m
1074     val deref : ref -> value m
1075     val change : ref -> value -> unit m
1076   end
1077   (* RefT transformer when wrapped monad has plus, zero *)
1078   module TP : functor (Wrapped : Monad.P) -> sig
1079     include module type of T(Wrapped)
1080     include Monad.PLUS with type 'a m := 'a m
1081   end
1082   module T2 : functor (Wrapped : Monad.S2) -> sig
1083     type ('x,'a) result = ('x,'a) Wrapped.result
1084     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
1085     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1086     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1087     val newref : value -> ('x,ref) m
1088     val deref : ref -> ('x,value) m
1089     val change : ref -> value -> ('x,unit) m
1090   end
1091   module TP2 : functor (Wrapped : Monad.P2) -> sig
1092     include module type of T2(Wrapped)
1093     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
1094   end
1095 end = struct
1096   type ref = int
1097   type value = V.value
1098   module D = Map.Make(struct type t = ref let compare = compare end)
1099   type dict = { next: ref; tree : value D.t }
1100   let empty = { next = 0; tree = D.empty }
1101   let alloc (value : value) (d : dict) =
1102     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
1103   let read (key : ref) (d : dict) =
1104     D.find key d.tree
1105   let write (key : ref) (value : value) (d : dict) =
1106     { next = d.next; tree = D.add key value d.tree }
1107   module Base = struct
1108     type 'a m = dict -> 'a * dict
1109     type 'a result = 'a
1110     type 'a result_exn = 'a
1111     let unit a = fun s -> (a, s)
1112     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
1113     let run u = fst (u empty)
1114     let run_exn = run
1115   end
1116   include Monad.Make(Base)
1117   let newref value = fun s -> alloc value s
1118   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
1119   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
1120   module T(Wrapped : Monad.S) = struct
1121     module Trans = struct
1122       module Wrapped = Wrapped
1123       type 'a m = dict -> ('a * dict) Wrapped.m
1124       type 'a result = 'a Wrapped.result
1125       type 'a result_exn = 'a Wrapped.result_exn
1126       let elevate w = fun s ->
1127         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1128       let bind u f = fun s ->
1129         Wrapped.bind (u s) (fun (a, s') -> f a s')
1130       let run u =
1131         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1132         in Wrapped.run w
1133       let run_exn u =
1134         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1135         in Wrapped.run_exn w
1136     end
1137     include Monad.MakeT(Trans)
1138     let newref value = fun s -> Wrapped.unit (alloc value s)
1139     let deref key = fun s -> Wrapped.unit (read key s, s)
1140     let change key value = fun s -> Wrapped.unit ((), write key value s)
1141   end
1142   module TP(Wrapped : Monad.P) = struct
1143     module TransP = struct
1144       include T(Wrapped)
1145       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1146       let zero () = elevate (Wrapped.zero ())
1147     end
1148     include TransP
1149     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
1150   end
1151   module T2(Wrapped : Monad.S2) = struct
1152     module Trans = struct
1153       module Wrapped = Wrapped
1154       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
1155       type ('x,'a) result = ('x,'a) Wrapped.result
1156       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
1157       (* code repetition, ugh *)
1158       let elevate w = fun s ->
1159         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
1160       let bind u f = fun s ->
1161         Wrapped.bind (u s) (fun (a, s') -> f a s')
1162       let run u =
1163         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1164         in Wrapped.run w
1165       let run_exn u =
1166         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
1167         in Wrapped.run_exn w
1168     end
1169     include Monad.MakeT2(Trans)
1170     let newref value = fun s -> Wrapped.unit (alloc value s)
1171     let deref key = fun s -> Wrapped.unit (read key s, s)
1172     let change key value = fun s -> Wrapped.unit ((), write key value s)
1173   end
1174   module TP2(Wrapped : Monad.P2) = struct
1175     module TransP = struct
1176       include T2(Wrapped)
1177       (* code repetition, ugh *)
1178       let plus u v = fun s -> Wrapped.plus (u s) (v s)
1179       let zero () = elevate (Wrapped.zero ())
1180     end
1181     include TransP
1182     include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1183   end
1184 end
1185
1186
1187 (* must be parameterized on (struct type log = ... end) *)
1188 module Writer_monad(Log : sig
1189   type log
1190   val zero : log
1191   val plus : log -> log -> log
1192 end) : sig
1193   (* declare additional operations, while still hiding implementation of type m *)
1194   type log = Log.log
1195   type 'a result = 'a * log
1196   type 'a result_exn = 'a * log
1197   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1198   val tell : log -> unit m
1199   val listen : 'a m -> ('a * log) m
1200   val listens : (log -> 'b) -> 'a m -> ('a * 'b) m
1201   (* val pass : ('a * (log -> log)) m -> 'a m *)
1202   val censor : (log -> log) -> 'a m -> 'a m
1203 end = struct
1204   type log = Log.log
1205   module Base = struct
1206     type 'a m = 'a * log
1207     type 'a result = 'a * log
1208     type 'a result_exn = 'a * log
1209     let unit a = (a, Log.zero)
1210     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
1211     let run u = u
1212     let run_exn = run
1213   end
1214   include Monad.Make(Base)
1215   let tell entries = ((), entries) (* add entries to log *)
1216   let listen (a, w) = ((a, w), w)
1217   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
1218   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
1219   let censor f u = pass (u >>= fun a -> unit (a, f))
1220 end
1221
1222 (* pre-define simple Writer *)
1223 module Writer1 = Writer_monad(struct
1224   type log = string
1225   let zero = ""
1226   let plus s1 s2 = s1 ^ "\n" ^ s2
1227 end)
1228
1229 (* slightly more efficient Writer *)
1230 module Writer2 = struct
1231   include Writer_monad(struct
1232     type log = string list
1233     let zero = []
1234     let plus w w' = Util.append w' w
1235   end)
1236   let tell_string s = tell [s]
1237   let tell entries = tell (Util.reverse entries)
1238   let run u = let (a, w) = run u in (a, Util.reverse w)
1239   let run_exn = run
1240 end
1241
1242
1243 module IO_monad : sig
1244   (* declare additional operation, while still hiding implementation of type m *)
1245   type 'a result = 'a
1246   type 'a result_exn = 'a
1247   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1248   val printf : ('a, unit, string, unit m) format4 -> 'a
1249   val print_string : string -> unit m
1250   val print_int : int -> unit m
1251   val print_hex : int -> unit m
1252   val print_bool : bool -> unit m
1253 end = struct
1254   module Base = struct
1255     type 'a m = { run : unit -> unit; value : 'a }
1256     type 'a result = 'a
1257     type 'a result_exn = 'a
1258     let unit a = { run = (fun () -> ()); value = a }
1259     let bind (a : 'a m) (f: 'a -> 'b m) : 'b m =
1260      let fres = f a.value in
1261        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
1262     let run a = let () = a.run () in a.value
1263     let run_exn = run
1264   end
1265   include Monad.Make(Base)
1266   let printf fmt =
1267     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
1268   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
1269   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
1270   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
1271   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
1272 end
1273
1274 (*
1275 module Continuation_monad : sig
1276   (* expose only the implementation of type `('r,'a) result` *)
1277   type 'a m
1278   type 'a result = 'a m
1279   type 'a result_exn = 'a m
1280   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn and type 'a m := 'a m
1281   (* val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m *)
1282   (* misses that the answer types of all the cont's must be the same *)
1283   val callcc : (('a -> 'b m) -> 'a m) -> 'a m
1284   (* val reset : ('a,'a) m -> ('r,'a) m *)
1285   val reset : 'a m -> 'a m
1286   (* val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m *)
1287   (* misses that the answer types of second and third continuations must be b *)
1288   val shift : (('a -> 'b m) -> 'b m) -> 'a m
1289   (* overwrite the run declaration in S, because I can't declare 'a result =
1290    * this polymorphic type (complains that 'r is unbound *)
1291   val runk : 'a m -> ('a -> 'r) -> 'r
1292   val run0 : 'a m -> 'a
1293 end = struct
1294   let id = fun i -> i
1295   module Base = struct
1296     (* 'r is result type of whole computation *)
1297     type 'a m = { cont : 'r. ('a -> 'r) -> 'r }
1298     type 'a result = 'a m
1299     type 'a result_exn = 'a m
1300     let unit a =
1301       let cont : 'r. ('a -> 'r) -> 'r =
1302         fun k -> k a
1303       in { cont }
1304     let bind u f =
1305       let cont : 'r. ('a -> 'r) -> 'r =
1306         fun k -> u.cont (fun a -> (f a).cont k)
1307       in { cont }
1308     let run (u : 'a m) : 'a result = u
1309     let run_exn (u : 'a m) : 'a result_exn = u
1310     let callcc f =
1311       let cont : 'r. ('a -> 'r) -> 'r =
1312           (* Can't figure out how to make the type polymorphic enough
1313            * to satisfy the OCaml type-checker (it's ('a -> 'r) -> 'r
1314            * instead of 'r. ('a -> 'r) -> 'r); so we have to fudge
1315            * with Obj.magic... which tells OCaml's type checker to
1316            * relax, the supplied value has whatever type the context
1317            * needs it to have. *)
1318           fun k ->
1319           let usek a = { cont = Obj.magic (fun _ -> k a) }
1320           in (f usek).cont k
1321       in { cont }
1322     let reset u = unit (u.cont id)
1323     let shift (f : ('a -> 'b m) -> 'b m) : 'a m =
1324       let cont = fun k ->
1325         (f (fun a -> unit (k a))).cont id
1326       in { cont = Obj.magic cont }
1327     let runk u k = (u.cont : ('a -> 'r) -> 'r) k
1328     let run0 u = runk u id
1329   end
1330   include Monad.Make(Base)
1331   let callcc = Base.callcc
1332   let reset = Base.reset
1333   let shift = Base.shift
1334   let runk = Base.runk
1335   let run0 = Base.run0
1336 end
1337  *)
1338
1339 (* This two-type parameter version works without Obj.magic *)
1340 module Continuation_monad : sig
1341   (* expose only the implementation of type `('r,'a) result` *)
1342   type ('r,'a) m
1343   type ('r,'a) result = ('r,'a) m
1344   type ('r,'a) result_exn = ('a -> 'r) -> 'r
1345   include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
1346   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
1347   val reset : ('a,'a) m -> ('r,'a) m
1348   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
1349   (* val abort : ('a,'a) m -> ('a,'b) m *)
1350   val abort : 'a -> ('a,'b) m
1351   val run0 : ('a,'a) m -> 'a
1352 end = struct
1353   let id = fun i -> i
1354   module Base = struct
1355     (* 'r is result type of whole computation *)
1356     type ('r,'a) m = ('a -> 'r) -> 'r
1357     type ('r,'a) result = ('a -> 'r) -> 'r
1358     type ('r,'a) result_exn = ('r,'a) result
1359     let unit a = (fun k -> k a)
1360     let bind u f = (fun k -> (u) (fun a -> (f a) k))
1361     let run u k = (u) k
1362     let run_exn = run
1363   end
1364   include Monad.Make2(Base)
1365   let callcc f = (fun k ->
1366     let usek a = (fun _ -> k a)
1367     in (f usek) k)
1368   (*
1369   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
1370   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
1371   let callcc f = fun k -> f k k
1372   let throw k a = fun _ -> k a
1373   *)
1374
1375   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
1376    *
1377    *  reset :: (Monad m) => ContT a m a -> ContT r m a
1378    *  reset e = ContT $ \k -> runContT e return >>= k
1379    *
1380    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
1381    *  shift e = ContT $ \k ->
1382    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
1383   let reset u = unit ((u) id)
1384   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
1385   (* let abort a = shift (fun _ -> a) *)
1386   let abort a = shift (fun _ -> unit a)
1387   let run0 (u : ('a,'a) m) = (u) id
1388 end
1389
1390
1391 (*
1392  * Scheme:
1393  * (define (example n)
1394  *    (let ([u (let/cc k ; type int -> int pair
1395  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
1396  *                 (+ 1 (car v))))]) ; int
1397  *      (cons u 0))) ; int pair
1398  * ; (example 10) ~~> '(111 . 0)
1399  * ; (example -10) ~~> '(0 . 0)
1400  *
1401  * OCaml monads:
1402  * let example n : (int * int) =
1403  *   Continuation_monad.(let u = callcc (fun k ->
1404  *       (if n < 0 then k 0 else unit [n + 100])
1405  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
1406  *       >>= fun [x] -> unit (x + 1)
1407  *   )
1408  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1409  *   >>= fun x -> unit (x, 0)
1410  *   in run u)
1411  *
1412  *
1413  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1414  * let example1 () : int =
1415  *   Continuation_monad.(let v = reset (
1416  *       let u = shift (fun k -> unit (10 + 1))
1417  *       in u >>= fun x -> unit (100 + x)
1418  *     ) in let w = v >>= fun x -> unit (1000 + x)
1419  *     in run w)
1420  *
1421  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1422  * let example2 () =
1423  *   Continuation_monad.(let v = reset (
1424  *       let u = shift (fun k -> k (10 :: [1]))
1425  *       in u >>= fun x -> unit (100 :: x)
1426  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1427  *     in run w)
1428  *
1429  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1430  * let example3 () =
1431  *   Continuation_monad.(let v = reset (
1432  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1433  *       in u >>= fun x -> unit (100 :: x)
1434  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1435  *     in run w)
1436  *
1437  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1438  * (* not sure if this example can be typed without a sum-type *)
1439  *
1440  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1441  * let example5 () : int =
1442  *   Continuation_monad.(let v = reset (
1443  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
1444  *       in u >>= fun x -> unit (10 + x)
1445  *     ) in let w = v >>= fun x -> unit (100 + x)
1446  *     in run w)
1447  *
1448  *)
1449
1450
1451 module Leaf_monad : sig
1452   (* We implement the type as `'a tree option` because it has a natural`plus`,
1453    * and the rest of the library expects that `plus` and `zero` will come together. *)
1454   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1455   type 'a result = 'a tree option
1456   type 'a result_exn = 'a tree
1457   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1458   include Monad.PLUS with type 'a m := 'a m
1459   (* LeafT transformer *)
1460   module T : functor (Wrapped : Monad.S) -> sig
1461     type 'a result = 'a tree option Wrapped.result
1462     type 'a result_exn = 'a tree Wrapped.result_exn
1463     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1464     include Monad.PLUS with type 'a m := 'a m
1465     val elevate : 'a Wrapped.m -> 'a m
1466     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1467     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1468     val distribute : ('a -> 'b Wrapped.m) -> 'a tree option -> 'b m
1469   end
1470   module T2 : functor (Wrapped : Monad.S2) -> sig
1471     type ('x,'a) result = ('x,'a tree option) Wrapped.result
1472     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1473     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
1474     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
1475     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
1476     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
1477   end
1478 end = struct
1479   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1480   (* uses supplied plus and zero to copy t to its image under f *)
1481   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1482       | None -> zero ()
1483       | Some ts -> let rec loop ts = (match ts with
1484                      | Leaf a -> f a
1485                      | Node (l, r) ->
1486                          (* recursive application of f may delete a branch *)
1487                          plus (loop l) (loop r)
1488                    ) in loop ts
1489   module Base = struct
1490     type 'a m = 'a tree option
1491     type 'a result = 'a tree option
1492     type 'a result_exn = 'a tree
1493     let unit a = Some (Leaf a)
1494     let zero () = None
1495     let plus u v = match (u, v) with
1496       | None, _ -> v
1497       | _, None -> u
1498       | Some us, Some vs -> Some (Node (us, vs))
1499     let bind u f = mapT f u zero plus
1500     let run u = u
1501     let run_exn u = match u with
1502       | None -> failwith "no values"
1503       (*
1504       | Some (Leaf a) -> a
1505       | many -> failwith "multiple values"
1506       *)
1507       | Some us -> us
1508   end
1509   include Monad.Make(Base)
1510   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
1511   let base_plus = plus
1512   let base_lift = lift
1513   module T(Wrapped : Monad.S) = struct
1514     module Trans = struct
1515       let zero () = Wrapped.unit None
1516       let plus u v =
1517         Wrapped.bind u (fun us ->
1518         Wrapped.bind v (fun vs ->
1519         Wrapped.unit (base_plus us vs)))
1520       include Monad.MakeT(struct
1521         module Wrapped = Wrapped
1522         type 'a m = 'a tree option Wrapped.m
1523         type 'a result = 'a tree option Wrapped.result
1524         type 'a result_exn = 'a tree Wrapped.result_exn
1525         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1526         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1527         let run u = Wrapped.run u
1528         let run_exn u =
1529             let w = Wrapped.bind u (fun t -> match t with
1530               | None -> failwith "no values"
1531               | Some ts -> Wrapped.unit ts)
1532             in Wrapped.run_exn w
1533       end)
1534     end
1535     include Trans
1536     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
1537     (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
1538     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1539   end
1540   module T2(Wrapped : Monad.S2) = struct
1541     module Trans = struct
1542       let zero () = Wrapped.unit None
1543       let plus u v =
1544         Wrapped.bind u (fun us ->
1545         Wrapped.bind v (fun vs ->
1546         Wrapped.unit (base_plus us vs)))
1547       include Monad.MakeT2(struct
1548         module Wrapped = Wrapped
1549         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1550         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1551         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1552         (* code repetition, ugh *)
1553         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1554         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1555         let run u = Wrapped.run u
1556         let run_exn u =
1557             let w = Wrapped.bind u (fun t -> match t with
1558               | None -> failwith "no values"
1559               | Some ts -> Wrapped.unit ts)
1560             in Wrapped.run_exn w
1561       end)
1562     end
1563     include Trans
1564     include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
1565     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1566   end
1567 end
1568
1569
1570 module L = List_monad;;
1571 module R = Reader_monad(struct type env = int -> int end);;
1572 module S = State_monad(struct type store = int end);;
1573 module T = Leaf_monad;;
1574 module LR = L.T(R);;
1575 module LS = L.T(S);;
1576 module TL = T.T(L);;
1577 module TR = T.T(R);;
1578 module TS = T.T(S);;
1579 module C = Continuation_monad
1580 module TC = T.T2(C);;
1581
1582
1583 print_endline "=== test Leaf(...).distribute ==================";;
1584
1585 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1586
1587 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1588 TS.run ts 0;;
1589 (*
1590 - : int T.tree option * S.store =
1591 (Some
1592   (T.Node
1593     (T.Node (T.Leaf 2, T.Leaf 3),
1594      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1595  5)
1596 *)
1597
1598 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1599 TS.run_exn ts2 0;;
1600 (*
1601 - : (int * S.store) T.tree option * S.store =
1602 (Some
1603   (T.Node
1604     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1605      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1606  5)
1607 *)
1608
1609 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1610 TR.run_exn tr (fun i -> i+i);;
1611 (*
1612 - : int T.tree option =
1613 Some
1614  (T.Node
1615    (T.Node (T.Leaf 4, T.Leaf 6),
1616     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1617 *)
1618
1619 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1620 TL.run_exn tl;;
1621 (*
1622 - : (int * int) TL.result =
1623 [Some
1624   (T.Node
1625     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1626      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1627 *)
1628
1629 let l2 = [1;2;3;4;5];;
1630 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1631
1632 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1633 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1634
1635 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1636 (*
1637 int T.tree option =
1638 Some
1639  (T.Node
1640    (T.Node (T.Leaf 10, T.Leaf 11),
1641     T.Node
1642      (T.Node
1643        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1644         T.Node (T.Leaf 40, T.Leaf 41)),
1645       T.Node (T.Leaf 50, T.Leaf 51))))
1646  *)
1647
1648 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1649 (*
1650 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1651 *)
1652
1653 print_endline "=== test Leaf(Continuation).distribute ==================";;
1654
1655 let id : 'z. 'z -> 'z = fun x -> x
1656
1657 let example n : (int * int) =
1658   Continuation_monad.(let u = callcc (fun k ->
1659       (if n < 0 then k 0 else unit [n + 100])
1660       (* all of the following is skipped by k 0; the end type int is k's input type *)
1661       >>= fun [x] -> unit (x + 1)
1662   )
1663   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1664   >>= fun x -> unit (x, 0)
1665   in run0 u)
1666
1667
1668 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1669 let example1 () : int =
1670   Continuation_monad.(let v = reset (
1671       let u = shift (fun k -> unit (10 + 1))
1672       in u >>= fun x -> unit (100 + x)
1673     ) in let w = v >>= fun x -> unit (1000 + x)
1674     in run0 w)
1675
1676 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1677 let example2 () =
1678   Continuation_monad.(let v = reset (
1679       let u = shift (fun k -> k (10 :: [1]))
1680       in u >>= fun x -> unit (100 :: x)
1681     ) in let w = v >>= fun x -> unit (1000 :: x)
1682     in run0 w)
1683
1684 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1685 let example3 () =
1686   Continuation_monad.(let v = reset (
1687       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1688       in u >>= fun x -> unit (100 :: x)
1689     ) in let w = v >>= fun x -> unit (1000 :: x)
1690     in run0 w)
1691
1692 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1693 (* not sure if this example can be typed without a sum-type *)
1694
1695 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1696 let example5 () : int =
1697   Continuation_monad.(let v = reset (
1698       let u = shift (fun k -> k 1 >>= k)
1699       in u >>= fun x -> unit (10 + x)
1700     ) in let w = v >>= fun x -> unit (100 + x)
1701     in run0 w)
1702
1703 ;;
1704
1705 print_endline "=== test bare Continuation ============";;
1706
1707 (1011, 1111, 1111, 121);;
1708 (example1(), example2(), example3(), example5());;
1709 ((111,0), (0,0));;
1710 (example ~+10, example ~-10);;
1711
1712 let testc df ic =
1713     C.run_exn TC.(run (distribute df t1)) ic;;
1714
1715
1716 (*
1717 (* do nothing *)
1718 let initial_continuation = fun t -> t in
1719 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1720 *)
1721 testc (C.unit) id;;
1722
1723 (*
1724 (* count leaves, using continuation *)
1725 let initial_continuation = fun t -> 0 in
1726 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1727 *)
1728
1729 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1730
1731 (*
1732 (* convert tree to list of leaves *)
1733 let initial_continuation = fun t -> [] in
1734 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1735 *)
1736
1737 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1738
1739 (*
1740 (* square each leaf using continuation *)
1741 let initial_continuation = fun t -> t in
1742 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1743 *)
1744
1745 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1746
1747
1748 (*
1749 (* replace leaves with list, using continuation *)
1750 let initial_continuation = fun t -> t in
1751 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1752 *)
1753
1754 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1755
1756 print_endline "=== pa_monad's Continuation Tests ============";;
1757
1758 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1759 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1760 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1761 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1762 (5, 27 = C.(run0 (
1763               let c = reset(abort 5 >>= fun y -> unit (y+6))
1764               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1765
1766 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1767
1768 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1769
1770 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1771
1772
1773 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1774