tweak monads-lib
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  *)
46
47 exception Undefined
48
49 (* Some library functions used below. *)
50 module Util = struct
51   let fold_right = List.fold_right
52   let map = List.map
53   let append = List.append
54   let reverse = List.rev
55   let concat = List.concat
56   let concat_map f lst = List.concat (List.map f lst)
57   (* let zip = List.combine *)
58   let unzip = List.split
59   let zip_with = List.map2
60   let replicate len fill =
61     let rec loop n accu =
62       if n == 0 then accu else loop (pred n) (fill :: accu)
63     in loop len []
64   (* Dirty hack to be a default polymorphic zero.
65    * To implement this cleanly, monads without a natural zero
66    * should always wrap themselves in an option layer (see Leaf_monad). *)
67   let undef = Obj.magic (fun () -> raise Undefined)
68 end
69
70
71
72 (*
73  * This module contains factories that extend a base set of
74  * monadic definitions with a larger family of standard derived values.
75  *)
76
77 module Monad = struct
78   (*
79    * Signature extenders:
80    *   Make :: BASE -> S
81    *   MakeT :: BASET (with Wrapped : S) -> result sig not declared
82    *)
83
84
85   (* type of base definitions *)
86   module type BASE = sig
87     (* We make all monadic types doubly-parameterized so that they
88      * can layer nicely with Continuation, which needs the second
89      * type parameter. *)
90     type ('x,'a) m
91     type ('x,'a) result
92     type ('x,'a) result_exn
93     val unit : 'a -> ('x,'a) m
94     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
95     val run : ('x,'a) m -> ('x,'a) result
96     (* run_exn tries to provide a more ground-level result, but may fail *)
97     val run_exn : ('x,'a) m -> ('x,'a) result_exn
98     (* To simplify the library, we require every monad to supply a plus and zero. These obey the following laws:
99      *     zero >>= f   ===  zero
100      *     plus zero u  ===  u
101      *     plus u zero  ===  u
102      * Additionally, they will obey one of the following laws:
103      *     (Catch)   plus (unit a) v  ===  unit a
104      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
105      * When no natural zero is available, use `let zero () = Util.undef`.
106      * The Make functor automatically detects for zero >>= ..., and 
107      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
108      *)
109     val zero : unit -> ('x,'a) m
110     (* zero has to be thunked to ensure results are always poly enough *)
111     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
112   end
113   module type S = sig
114     include BASE
115     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
116     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
117     val join : ('x,('x,'a) m) m -> ('x,'a) m
118     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
119     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
120     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
121     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
122     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
123     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
124     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
125     val sequence : ('x,'a) m list -> ('x,'a list) m
126     val sequence_ : ('x,'a) m list -> ('x,unit) m
127     val guard : bool -> ('x,unit) m
128     val sum : ('x,'a) m list -> ('x,'a) m
129   end
130
131   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
132     include B
133     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
134       if u == Util.undef then Util.undef
135       else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
136     let plus u v =
137       if u == Util.undef then v else if v == Util.undef then u else B.plus u v
138     let run u =
139       if u == Util.undef then raise Undefined else B.run u
140     let run_exn u =
141       if u == Util.undef then raise Undefined else B.run_exn u
142     let (>>=) = bind
143     let (>>) u v = u >>= fun _ -> v
144     let lift f u = u >>= fun a -> unit (f a)
145     (* lift is called listM, fmap, and <$> in Haskell *)
146     let join uu = uu >>= fun u -> u
147     (* u >>= f === join (lift f u) *)
148     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
149     (* [f] <*> [x1,x2] = [f x1,f x2] *)
150     (* let apply u v = u >>= fun f -> lift f v *)
151     (* let apply = lift2 id *)
152     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
153     (* let lift f u === apply (unit f) u *)
154     (* let lift2 f u v = apply (lift f u) v *)
155     let (>=>) f g = fun a -> f a >>= g
156     let do_when test u = if test then u else unit ()
157     let do_unless test u = if test then unit () else u
158     (* not in tail position, will Stack overflow *)
159     let forever uthunk =
160         let rec loop () = uthunk () >>= fun _ -> loop ()
161         in loop ()
162     let sequence ms =
163       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
164         Util.fold_right op ms (unit [])
165     let sequence_ ms =
166       Util.fold_right (>>) ms (unit ())
167
168     (* Haskell defines these other operations combining lists and monads.
169      * We don't, but notice that M.mapM == ListT(M).distribute
170      * There's also a parallel TreeT(M).distribute *)
171     (*
172     let mapM f alist = sequence (Util.map f alist)
173     let mapM_ f alist = sequence_ (Util.map f alist)
174     let rec filterM f lst = match lst with
175       | [] -> unit []
176       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
177     let forM alist f = mapM f alist
178     let forM_ alist f = mapM_ f alist
179     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
180     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
181     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
182     let rec foldM f z lst = match lst with
183       | [] -> unit z
184       | x::xs -> f z x >>= fun z' -> foldM f z' xs
185     let foldM_ f z xs = foldM f z xs >> unit ()
186     let replicateM n x = sequence (Util.replicate n x)
187     let replicateM_ n x = sequence_ (Util.replicate n x)
188     *)
189     let guard test = if test then B.unit () else zero ()
190     let sum ms = Util.fold_right plus ms (zero ())
191   end
192
193   (* Signatures for MonadT *)
194   module type BASET = sig
195     module Wrapped : S
196     type ('x,'a) m
197     type ('x,'a) result
198     type ('x,'a) result_exn
199     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
200     val run : ('x,'a) m -> ('x,'a) result
201     val run_exn : ('x,'a) m -> ('x,'a) result_exn
202     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
203     (* lift/elevate laws:
204      *     elevate (W.unit a) == unit a
205      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
206      *)
207     val zero : unit -> ('x,'a) m
208     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
209   end
210   module MakeT(T : BASET) = struct
211     include Make(struct
212         include T
213         let unit a = elevate (Wrapped.unit a)
214     end)
215     let elevate = T.elevate
216   end
217
218 end
219
220
221
222
223
224 module Identity_monad : sig
225   (* expose only the implementation of type `'a result` *)
226   type ('x,'a) result = 'a
227   type ('x,'a) result_exn = 'a
228   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
229 end = struct
230   module Base = struct
231     type ('x,'a) m = 'a
232     type ('x,'a) result = 'a
233     type ('x,'a) result_exn = 'a
234     let unit a = a
235     let bind a f = f a
236     let run a = a
237     let run_exn a = a
238     let zero () = Util.undef
239     let plus u v = u
240   end
241   include Monad.Make(Base)
242 end
243
244
245 module Maybe_monad : sig
246   (* expose only the implementation of type `'a result` *)
247   type ('x,'a) result = 'a option
248   type ('x,'a) result_exn = 'a
249   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
250   (* MaybeT transformer *)
251   module T : functor (Wrapped : Monad.S) -> sig
252     type ('x,'a) result = ('x,'a option) Wrapped.result
253     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
254     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
255     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
256   end
257 end = struct
258   module Base = struct
259     type ('x,'a) m = 'a option
260     type ('x,'a) result = 'a option
261     type ('x,'a) result_exn = 'a
262     let unit a = Some a
263     let bind u f = match u with Some a -> f a | None -> None
264     let run u = u
265     let run_exn u = match u with
266       | Some a -> a
267       | None -> failwith "no value"
268     let zero () = None
269     (* satisfies Catch *)
270     let plus u v = match u with None -> v | _ -> u
271   end
272   include Monad.Make(Base)
273   module T(Wrapped : Monad.S) = struct
274     module BaseT = struct
275       include Monad.MakeT(struct
276         module Wrapped = Wrapped
277         type ('x,'a) m = ('x,'a option) Wrapped.m
278         type ('x,'a) result = ('x,'a option) Wrapped.result
279         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
280         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
281         let bind u f = Wrapped.bind u (fun t -> match t with
282           | Some a -> f a
283           | None -> Wrapped.unit None)
284         let run u = Wrapped.run u
285         let run_exn u =
286           let w = Wrapped.bind u (fun t -> match t with
287             | Some a -> Wrapped.unit a
288             | None -> Wrapped.zero ()
289           ) in Wrapped.run_exn w
290         let zero () = Wrapped.unit None
291         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
292       end)
293     end
294     include BaseT
295   end
296 end
297
298
299 module List_monad : sig
300   (* declare additional operation, while still hiding implementation of type m *)
301   type ('x,'a) result = 'a list
302   type ('x,'a) result_exn = 'a
303   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
304   val permute : ('x,'a) m -> ('x,('x,'a) m) m
305   val select : ('x,'a) m -> ('x,'a * ('x,'a) m) m
306   (* ListT transformer *)
307   module T : functor (Wrapped : Monad.S) -> sig
308     type ('x,'a) result = ('x,'a list) Wrapped.result
309     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
310     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
311     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
312     (* note that second argument is an 'a list, not the more abstract 'a m *)
313     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
314     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
315 (* TODO
316     val permute : 'a m -> 'a m m
317     val select : 'a m -> ('a * 'a m) m
318 *)
319   end
320 end = struct
321   module Base = struct
322    type ('x,'a) m = 'a list
323    type ('x,'a) result = 'a list
324    type ('x,'a) result_exn = 'a
325    let unit a = [a]
326    let bind u f = Util.concat_map f u
327    let run u = u
328    let run_exn u = match u with
329      | [] -> failwith "no values"
330      | [a] -> a
331      | many -> failwith "multiple values"
332    let zero () = []
333    (* satisfies Distrib *)
334    let plus = Util.append
335   end
336   include Monad.Make(Base)
337   (* let either u v = plus u v *)
338   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
339   let rec insert a u =
340     plus (unit (a :: u)) (match u with
341         | [] -> zero ()
342         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
343     )
344   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
345   let rec permute u = match u with
346       | [] -> unit []
347       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
348   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
349   let rec select u = match u with
350     | [] -> zero ()
351     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
352   module T(Wrapped : Monad.S) = struct
353     (* Wrapped.sequence ms  ===  
354          let plus1 u v =
355            Wrapped.bind u (fun x ->
356            Wrapped.bind v (fun xs ->
357            Wrapped.unit (x :: xs)))
358          in Util.fold_right plus1 ms (Wrapped.unit []) *)
359     (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
360     let distribute f alist = Wrapped.sequence (Util.map f alist)
361
362     include Monad.MakeT(struct
363       module Wrapped = Wrapped
364       type ('x,'a) m = ('x,'a list) Wrapped.m
365       type ('x,'a) result = ('x,'a list) Wrapped.result
366       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
367       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
368       let bind u f =
369         Wrapped.bind u (fun ts ->
370         Wrapped.bind (distribute f ts) (fun tts ->
371         Wrapped.unit (Util.concat tts)))
372       let run u = Wrapped.run u
373       let run_exn u =
374         let w = Wrapped.bind u (fun ts -> match ts with
375           | [] -> Wrapped.zero ()
376           | [a] -> Wrapped.unit a
377           | many -> Wrapped.zero ()
378         ) in Wrapped.run_exn w
379       let zero () = Wrapped.unit []
380       let plus u v =
381         Wrapped.bind u (fun us ->
382         Wrapped.bind v (fun vs ->
383         Wrapped.unit (Base.plus us vs)))
384     end)
385 (*
386     let permute : 'a m -> 'a m m
387     let select : 'a m -> ('a * 'a m) m
388 *)
389   end
390 end
391
392
393 (* must be parameterized on (struct type err = ... end) *)
394 module Error_monad(Err : sig
395   type err
396   exception Exc of err
397   (*
398   val zero : unit -> err
399   val plus : err -> err -> err
400   *)
401 end) : sig
402   (* declare additional operations, while still hiding implementation of type m *)
403   type err = Err.err
404   type 'a error = Error of err | Success of 'a
405   type ('x,'a) result = 'a error
406   type ('x,'a) result_exn = 'a
407   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
408   val throw : err -> ('x,'a) m
409   val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
410   (* ErrorT transformer *)
411   module T : functor (Wrapped : Monad.S) -> sig
412     type ('x,'a) result = ('x,'a) Wrapped.result
413     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
414     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
415     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
416     val throw : err -> ('x,'a) m
417     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
418   end
419 end = struct
420   type err = Err.err
421   type 'a error = Error of err | Success of 'a
422   module Base = struct
423     type ('x,'a) m = 'a error
424     type ('x,'a) result = 'a error
425     type ('x,'a) result_exn = 'a
426     let unit a = Success a
427     let bind u f = match u with
428       | Success a -> f a
429       | Error e -> Error e (* input and output may be of different 'a types *)
430     let run u = u
431     let run_exn u = match u with
432       | Success a -> a
433       | Error e -> raise (Err.Exc e)
434     let zero () = Util.undef
435     let plus u v = u
436     (*
437     let zero () = Error Err.zero
438     let plus u v = match (u, v) with
439       | Success _, _ -> u
440       (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
441        * otherwise, plus (Error _) v = v *)
442       | Error _, _ when v = zero -> u
443       (* combine errors *)
444       | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
445       | Error _, _ -> v
446     *)
447   end
448   include Monad.Make(Base)
449   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
450   let throw e = Error e
451   let catch u handler = match u with
452     | Success _ -> u
453     | Error e -> handler e
454   module T(Wrapped : Monad.S) = struct
455     include Monad.MakeT(struct
456       module Wrapped = Wrapped
457       type ('x,'a) m = ('x,'a error) Wrapped.m
458       type ('x,'a) result = ('x,'a) Wrapped.result
459       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
460       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
461       let bind u f = Wrapped.bind u (fun t -> match t with
462         | Success a -> f a
463         | Error e -> Wrapped.unit (Error e))
464       let run u =
465         let w = Wrapped.bind u (fun t -> match t with
466           | Success a -> Wrapped.unit a
467           | Error e -> Wrapped.zero ()
468         ) in Wrapped.run w
469       let run_exn u =
470         let w = Wrapped.bind u (fun t -> match t with
471           | Success a -> Wrapped.unit a
472           | Error e -> raise (Err.Exc e))
473         in Wrapped.run_exn w
474       let plus u v = Wrapped.plus u v
475       let zero () = elevate (Wrapped.zero ())
476     end)
477     let throw e = Wrapped.unit (Error e)
478     let catch u handler = Wrapped.bind u (fun t -> match t with
479       | Success _ -> Wrapped.unit t
480       | Error e -> handler e)
481   end
482 end
483
484 (* pre-define common instance of Error_monad *)
485 module Failure = Error_monad(struct
486   type err = string
487   exception Exc = Failure
488   (*
489   let zero = ""
490   let plus s1 s2 = s1 ^ "\n" ^ s2
491   *)
492 end)
493
494 (*
495 # EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
496 - : int EL.result = [Failure.Error "bye"; Failure.Success 30]
497 # LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
498 - : int LE.result = Failure.Error "bye"
499 # EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
500 Exception: Failure "bye".
501 # LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
502 Exception: Failure "bye".
503
504 # ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
505 - : int Failure.error * S.store = (Failure.Error "bye", 1)
506 # SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
507 - : (int * S.store) Failure.result = Failure.Error "bye"
508 # ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
509 Exception: Failure "bye".
510 # SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
511 Exception: Failure "bye".
512  *)
513
514
515 (* must be parameterized on (struct type env = ... end) *)
516 module Reader_monad(Env : sig type env end) : sig
517   (* declare additional operations, while still hiding implementation of type m *)
518   type env = Env.env
519   type ('x,'a) result = env -> 'a
520   type ('x,'a) result_exn = env -> 'a
521   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
522   val ask : ('x,env) m
523   val asks : (env -> 'a) -> ('x,'a) m
524   (* lookup i == `fun e -> e i` would assume env is a functional type *)
525   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
526   (* ReaderT transformer *)
527   module T : functor (Wrapped : Monad.S) -> sig
528     type ('x,'a) result = env -> ('x,'a) Wrapped.result
529     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
530     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
531     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
532     val ask : ('x,env) m
533     val asks : (env -> 'a) -> ('x,'a) m
534     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
535   end
536 end = struct
537   type env = Env.env
538   module Base = struct
539     type ('x,'a) m = env -> 'a
540     type ('x,'a) result = env -> 'a
541     type ('x,'a) result_exn = env -> 'a
542     let unit a = fun e -> a
543     let bind u f = fun e -> let a = u e in let u' = f a in u' e
544     let run u = fun e -> u e
545     let run_exn = run
546     let zero () = Util.undef
547     let plus u v = u
548   end
549   include Monad.Make(Base)
550   let ask = fun e -> e
551   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
552   let local modifier u = fun e -> u (modifier e)
553   module T(Wrapped : Monad.S) = struct
554     module BaseT = struct
555       module Wrapped = Wrapped
556       type ('x,'a) m = env -> ('x,'a) Wrapped.m
557       type ('x,'a) result = env -> ('x,'a) Wrapped.result
558       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
559       let elevate w = fun e -> w
560       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
561       let run u = fun e -> Wrapped.run (u e)
562       let run_exn u = fun e -> Wrapped.run_exn (u e)
563       (* satisfies Distrib *)
564       let plus u v = fun s -> Wrapped.plus (u s) (v s)
565       let zero () = elevate (Wrapped.zero ())
566     end
567     include Monad.MakeT(BaseT)
568     let ask = fun e -> Wrapped.unit e
569     let local modifier u = fun e -> u (modifier e)
570     let asks selector = ask >>= (fun e ->
571       try unit (selector e)
572       with Not_found -> fun e -> Wrapped.zero ())
573   end
574 end
575
576
577 (* must be parameterized on (struct type store = ... end) *)
578 module State_monad(Store : sig type store end) : sig
579   (* declare additional operations, while still hiding implementation of type m *)
580   type store = Store.store
581   type ('x,'a) result =  store -> 'a * store
582   type ('x,'a) result_exn = store -> 'a
583   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
584   val get : ('x,store) m
585   val gets : (store -> 'a) -> ('x,'a) m
586   val put : store -> ('x,unit) m
587   val puts : (store -> store) -> ('x,unit) m
588   (* StateT transformer *)
589   module T : functor (Wrapped : Monad.S) -> sig
590     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
591     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
592     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
593     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
594     val get : ('x,store) m
595     val gets : (store -> 'a) -> ('x,'a) m
596     val put : store -> ('x,unit) m
597     val puts : (store -> store) -> ('x,unit) m
598   end
599 end = struct
600   type store = Store.store
601   module Base = struct
602     type ('x,'a) m =  store -> 'a * store
603     type ('x,'a) result =  store -> 'a * store
604     type ('x,'a) result_exn = store -> 'a
605     let unit a = fun s -> (a, s)
606     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
607     let run u = fun s -> (u s)
608     let run_exn u = fun s -> fst (u s)
609     let zero () = Util.undef
610     let plus u v = u
611   end
612   include Monad.Make(Base)
613   let get = fun s -> (s, s)
614   let gets viewer = fun s -> (viewer s, s) (* may fail *)
615   let put s = fun _ -> ((), s)
616   let puts modifier = fun s -> ((), modifier s)
617   module T(Wrapped : Monad.S) = struct
618     module BaseT = struct
619       module Wrapped = Wrapped
620       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
621       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
622       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
623       let elevate w = fun s ->
624         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
625       let bind u f = fun s ->
626         Wrapped.bind (u s) (fun (a, s') -> f a s')
627       let run u = fun s -> Wrapped.run (u s)
628       let run_exn u = fun s ->
629         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
630         in Wrapped.run_exn w
631       (* satisfies Distrib *)
632       let plus u v = fun s -> Wrapped.plus (u s) (v s)
633       let zero () = elevate (Wrapped.zero ())
634     end
635     include Monad.MakeT(BaseT)
636     let get = fun s -> Wrapped.unit (s, s)
637     let gets viewer = fun s ->
638       try Wrapped.unit (viewer s, s)
639       with Not_found -> Wrapped.zero ()
640     let put s = fun _ -> Wrapped.unit ((), s)
641     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
642   end
643 end
644
645 (* State monad with different interface (structured store) *)
646 module Ref_monad(V : sig
647   type value
648 end) : sig
649   type ref
650   type value = V.value
651   type ('x,'a) result = 'a
652   type ('x,'a) result_exn = 'a
653   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
654   val newref : value -> ('x,ref) m
655   val deref : ref -> ('x,value) m
656   val change : ref -> value -> ('x,unit) m
657   (* RefT transformer *)
658   module T : functor (Wrapped : Monad.S) -> sig
659     type ('x,'a) result = ('x,'a) Wrapped.result
660     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
661     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
662     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
663     val newref : value -> ('x,ref) m
664     val deref : ref -> ('x,value) m
665     val change : ref -> value -> ('x,unit) m
666   end
667 end = struct
668   type ref = int
669   type value = V.value
670   module D = Map.Make(struct type t = ref let compare = compare end)
671   type dict = { next: ref; tree : value D.t }
672   let empty = { next = 0; tree = D.empty }
673   let alloc (value : value) (d : dict) =
674     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
675   let read (key : ref) (d : dict) =
676     D.find key d.tree
677   let write (key : ref) (value : value) (d : dict) =
678     { next = d.next; tree = D.add key value d.tree }
679   module Base = struct
680     type ('x,'a) m = dict -> 'a * dict
681     type ('x,'a) result = 'a
682     type ('x,'a) result_exn = 'a
683     let unit a = fun s -> (a, s)
684     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
685     let run u = fst (u empty)
686     let run_exn = run
687     let zero () = Util.undef
688     let plus u v = u
689   end
690   include Monad.Make(Base)
691   let newref value = fun s -> alloc value s
692   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
693   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
694   module T(Wrapped : Monad.S) = struct
695     module BaseT = struct
696       module Wrapped = Wrapped
697       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
698       type ('x,'a) result = ('x,'a) Wrapped.result
699       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
700       let elevate w = fun s ->
701         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
702       let bind u f = fun s ->
703         Wrapped.bind (u s) (fun (a, s') -> f a s')
704       let run u =
705         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
706         in Wrapped.run w
707       let run_exn u =
708         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
709         in Wrapped.run_exn w
710       (* satisfies Distrib *)
711       let plus u v = fun s -> Wrapped.plus (u s) (v s)
712       let zero () = elevate (Wrapped.zero ())
713     end
714     include Monad.MakeT(BaseT)
715     let newref value = fun s -> Wrapped.unit (alloc value s)
716     let deref key = fun s -> Wrapped.unit (read key s, s)
717     let change key value = fun s -> Wrapped.unit ((), write key value s)
718   end
719 end
720
721 (* TODO needs a T *)
722 (* must be parameterized on (struct type log = ... end) *)
723 module Writer_monad(Log : sig
724   type log
725   val zero : log
726   val plus : log -> log -> log
727 end) : sig
728   (* declare additional operations, while still hiding implementation of type m *)
729   type log = Log.log
730   type ('x,'a) result = 'a * log
731   type ('x,'a) result_exn = 'a * log
732   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
733   val tell : log -> ('x,unit) m
734   val listen : ('x,'a) m -> ('x,'a * log) m
735   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
736   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
737   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
738 end = struct
739   type log = Log.log
740   module Base = struct
741     type ('x,'a) m = 'a * log
742     type ('x,'a) result = 'a * log
743     type ('x,'a) result_exn = 'a * log
744     let unit a = (a, Log.zero)
745     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
746     let run u = u
747     let run_exn = run
748     let zero () = Util.undef
749     let plus u v = u
750   end
751   include Monad.Make(Base)
752   let tell entries = ((), entries) (* add entries to log *)
753   let listen (a, w) = ((a, w), w)
754   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
755   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
756   let censor f u = pass (u >>= fun a -> unit (a, f))
757 end
758
759 (* pre-define simple Writer *)
760 module Writer1 = Writer_monad(struct
761   type log = string
762   let zero = ""
763   let plus s1 s2 = s1 ^ "\n" ^ s2
764 end)
765
766 (* slightly more efficient Writer *)
767 module Writer2 = struct
768   include Writer_monad(struct
769     type log = string list
770     let zero = []
771     let plus w w' = Util.append w' w
772   end)
773   let tell_string s = tell [s]
774   let tell entries = tell (Util.reverse entries)
775   let run u = let (a, w) = run u in (a, Util.reverse w)
776   let run_exn = run
777 end
778
779
780 (* TODO needs a T *)
781 module IO_monad : sig
782   (* declare additional operation, while still hiding implementation of type m *)
783   type ('x,'a) result = 'a
784   type ('x,'a) result_exn = 'a
785   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
786   val printf : ('a, unit, string, ('x,unit) m) format4 -> 'a
787   val print_string : string -> ('x,unit) m
788   val print_int : int -> ('x,unit) m
789   val print_hex : int -> ('x,unit) m
790   val print_bool : bool -> ('x,unit) m
791 end = struct
792   module Base = struct
793     type ('x,'a) m = { run : unit -> unit; value : 'a }
794     type ('x,'a) result = 'a
795     type ('x,'a) result_exn = 'a
796     let unit a = { run = (fun () -> ()); value = a }
797     let bind (a : ('x,'a) m) (f: 'a -> ('x,'b) m) : ('x,'b) m =
798      let fres = f a.value in
799        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
800     let run a = let () = a.run () in a.value
801     let run_exn = run
802     let zero () = Util.undef
803     let plus u v = u
804   end
805   include Monad.Make(Base)
806   let printf fmt =
807     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
808   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
809   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
810   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
811   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
812 end
813
814
815 (* TODO needs a T *)
816 module Continuation_monad : sig
817   (* expose only the implementation of type `('r,'a) result` *)
818   type ('r,'a) m
819   type ('r,'a) result = ('r,'a) m
820   type ('r,'a) result_exn = ('a -> 'r) -> 'r
821   include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
822   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
823   val reset : ('a,'a) m -> ('r,'a) m
824   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
825   (* val abort : ('a,'a) m -> ('a,'b) m *)
826   val abort : 'a -> ('a,'b) m
827   val run0 : ('a,'a) m -> 'a
828 end = struct
829   let id = fun i -> i
830   module Base = struct
831     (* 'r is result type of whole computation *)
832     type ('r,'a) m = ('a -> 'r) -> 'r
833     type ('r,'a) result = ('a -> 'r) -> 'r
834     type ('r,'a) result_exn = ('r,'a) result
835     let unit a = (fun k -> k a)
836     let bind u f = (fun k -> (u) (fun a -> (f a) k))
837     let run u k = (u) k
838     let run_exn = run
839     let zero () = Util.undef
840     let plus u v = u
841   end
842   include Monad.Make(Base)
843   let callcc f = (fun k ->
844     let usek a = (fun _ -> k a)
845     in (f usek) k)
846   (*
847   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
848   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
849   let callcc f = fun k -> f k k
850   let throw k a = fun _ -> k a
851   *)
852
853   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
854    *
855    *  reset :: (Monad m) => ContT a m a -> ContT r m a
856    *  reset e = ContT $ \k -> runContT e return >>= k
857    *
858    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
859    *  shift e = ContT $ \k ->
860    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
861   let reset u = unit ((u) id)
862   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
863   (* let abort a = shift (fun _ -> a) *)
864   let abort a = shift (fun _ -> unit a)
865   let run0 (u : ('a,'a) m) = (u) id
866 end
867
868
869 (*
870  * Scheme:
871  * (define (example n)
872  *    (let ([u (let/cc k ; type int -> int pair
873  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
874  *                 (+ 1 (car v))))]) ; int
875  *      (cons u 0))) ; int pair
876  * ; (example 10) ~~> '(111 . 0)
877  * ; (example -10) ~~> '(0 . 0)
878  *
879  * OCaml monads:
880  * let example n : (int * int) =
881  *   Continuation_monad.(let u = callcc (fun k ->
882  *       (if n < 0 then k 0 else unit [n + 100])
883  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
884  *       >>= fun [x] -> unit (x + 1)
885  *   )
886  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
887  *   >>= fun x -> unit (x, 0)
888  *   in run u)
889  *
890  *
891  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
892  * let example1 () : int =
893  *   Continuation_monad.(let v = reset (
894  *       let u = shift (fun k -> unit (10 + 1))
895  *       in u >>= fun x -> unit (100 + x)
896  *     ) in let w = v >>= fun x -> unit (1000 + x)
897  *     in run w)
898  *
899  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
900  * let example2 () =
901  *   Continuation_monad.(let v = reset (
902  *       let u = shift (fun k -> k (10 :: [1]))
903  *       in u >>= fun x -> unit (100 :: x)
904  *     ) in let w = v >>= fun x -> unit (1000 :: x)
905  *     in run w)
906  *
907  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
908  * let example3 () =
909  *   Continuation_monad.(let v = reset (
910  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
911  *       in u >>= fun x -> unit (100 :: x)
912  *     ) in let w = v >>= fun x -> unit (1000 :: x)
913  *     in run w)
914  *
915  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
916  * (* not sure if this example can be typed without a sum-type *)
917  *
918  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
919  * let example5 () : int =
920  *   Continuation_monad.(let v = reset (
921  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
922  *       in u >>= fun x -> unit (10 + x)
923  *     ) in let w = v >>= fun x -> unit (100 + x)
924  *     in run w)
925  *
926  *)
927
928
929 module Leaf_monad : sig
930   (* We implement the type as `'a tree option` because it has a natural`plus`,
931    * and the rest of the library expects that `plus` and `zero` will come together. *)
932   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
933   type ('x,'a) result = 'a tree option
934   type ('x,'a) result_exn = 'a tree
935   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
936   (* LeafT transformer *)
937   module T : functor (Wrapped : Monad.S) -> sig
938     type ('x,'a) result = ('x,'a tree option) Wrapped.result
939     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
940     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
941     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
942     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
943     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
944     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
945   end
946 end = struct
947   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
948   (* uses supplied plus and zero to copy t to its image under f *)
949   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
950       | None -> zero ()
951       | Some ts -> let rec loop ts = (match ts with
952                      | Leaf a -> f a
953                      | Node (l, r) ->
954                          (* recursive application of f may delete a branch *)
955                          plus (loop l) (loop r)
956                    ) in loop ts
957   module Base = struct
958     type ('x,'a) m = 'a tree option
959     type ('x,'a) result = 'a tree option
960     type ('x,'a) result_exn = 'a tree
961     let unit a = Some (Leaf a)
962     let zero () = None
963     (* satisfies Distrib *)
964     let plus u v = match (u, v) with
965       | None, _ -> v
966       | _, None -> u
967       | Some us, Some vs -> Some (Node (us, vs))
968     let bind u f = mapT f u zero plus
969     let run u = u
970     let run_exn u = match u with
971       | None -> failwith "no values"
972       (*
973       | Some (Leaf a) -> a
974       | many -> failwith "multiple values"
975       *)
976       | Some us -> us
977   end
978   include Monad.Make(Base)
979   module T(Wrapped : Monad.S) = struct
980     module BaseT = struct
981       include Monad.MakeT(struct
982         module Wrapped = Wrapped
983         type ('x,'a) m = ('x,'a tree option) Wrapped.m
984         type ('x,'a) result = ('x,'a tree option) Wrapped.result
985         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
986         let zero () = Wrapped.unit None
987         let plus u v =
988           Wrapped.bind u (fun us ->
989           Wrapped.bind v (fun vs ->
990           Wrapped.unit (Base.plus us vs)))
991         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
992         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
993         let run u = Wrapped.run u
994         let run_exn u =
995             let w = Wrapped.bind u (fun t -> match t with
996               | None -> Wrapped.zero ()
997               | Some ts -> Wrapped.unit ts
998             ) in Wrapped.run_exn w
999       end)
1000     end
1001     include BaseT
1002     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1003   end
1004 end
1005
1006
1007 module L = List_monad;;
1008 module R = Reader_monad(struct type env = int -> int end);;
1009 module S = State_monad(struct type store = int end);;
1010 module T = Leaf_monad;;
1011 module LR = L.T(R);;
1012 module LS = L.T(S);;
1013 module TL = T.T(L);;
1014 module TR = T.T(R);;
1015 module TS = T.T(S);;
1016 module C = Continuation_monad
1017 module TC = T.T(C);;
1018
1019
1020 print_endline "=== test Leaf(...).distribute ==================";;
1021
1022 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1023
1024 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1025 TS.run ts 0;;
1026 (*
1027 - : int T.tree option * S.store =
1028 (Some
1029   (T.Node
1030     (T.Node (T.Leaf 2, T.Leaf 3),
1031      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1032  5)
1033 *)
1034
1035 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1036 TS.run_exn ts2 0;;
1037 (*
1038 - : (int * S.store) T.tree option * S.store =
1039 (Some
1040   (T.Node
1041     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1042      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1043  5)
1044 *)
1045
1046 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1047 TR.run_exn tr (fun i -> i+i);;
1048 (*
1049 - : int T.tree option =
1050 Some
1051  (T.Node
1052    (T.Node (T.Leaf 4, T.Leaf 6),
1053     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1054 *)
1055
1056 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1057 TL.run_exn tl;;
1058 (*
1059 - : (int * int) TL.result =
1060 [Some
1061   (T.Node
1062     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1063      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1064 *)
1065
1066 let l2 = [1;2;3;4;5];;
1067 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1068
1069 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1070 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1071
1072 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1073 (*
1074 int T.tree option =
1075 Some
1076  (T.Node
1077    (T.Node (T.Leaf 10, T.Leaf 11),
1078     T.Node
1079      (T.Node
1080        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1081         T.Node (T.Leaf 40, T.Leaf 41)),
1082       T.Node (T.Leaf 50, T.Leaf 51))))
1083  *)
1084
1085 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1086 (*
1087 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1088 *)
1089
1090 print_endline "=== test Leaf(Continuation).distribute ==================";;
1091
1092 let id : 'z. 'z -> 'z = fun x -> x
1093
1094 let example n : (int * int) =
1095   Continuation_monad.(let u = callcc (fun k ->
1096       (if n < 0 then k 0 else unit [n + 100])
1097       (* all of the following is skipped by k 0; the end type int is k's input type *)
1098       >>= fun [x] -> unit (x + 1)
1099   )
1100   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1101   >>= fun x -> unit (x, 0)
1102   in run0 u)
1103
1104
1105 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1106 let example1 () : int =
1107   Continuation_monad.(let v = reset (
1108       let u = shift (fun k -> unit (10 + 1))
1109       in u >>= fun x -> unit (100 + x)
1110     ) in let w = v >>= fun x -> unit (1000 + x)
1111     in run0 w)
1112
1113 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1114 let example2 () =
1115   Continuation_monad.(let v = reset (
1116       let u = shift (fun k -> k (10 :: [1]))
1117       in u >>= fun x -> unit (100 :: x)
1118     ) in let w = v >>= fun x -> unit (1000 :: x)
1119     in run0 w)
1120
1121 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1122 let example3 () =
1123   Continuation_monad.(let v = reset (
1124       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1125       in u >>= fun x -> unit (100 :: x)
1126     ) in let w = v >>= fun x -> unit (1000 :: x)
1127     in run0 w)
1128
1129 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1130 (* not sure if this example can be typed without a sum-type *)
1131
1132 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1133 let example5 () : int =
1134   Continuation_monad.(let v = reset (
1135       let u = shift (fun k -> k 1 >>= k)
1136       in u >>= fun x -> unit (10 + x)
1137     ) in let w = v >>= fun x -> unit (100 + x)
1138     in run0 w)
1139
1140 ;;
1141
1142 print_endline "=== test bare Continuation ============";;
1143
1144 (1011, 1111, 1111, 121);;
1145 (example1(), example2(), example3(), example5());;
1146 ((111,0), (0,0));;
1147 (example ~+10, example ~-10);;
1148
1149 let testc df ic =
1150     C.run_exn TC.(run (distribute df t1)) ic;;
1151
1152
1153 (*
1154 (* do nothing *)
1155 let initial_continuation = fun t -> t in
1156 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1157 *)
1158 testc (C.unit) id;;
1159
1160 (*
1161 (* count leaves, using continuation *)
1162 let initial_continuation = fun t -> 0 in
1163 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1164 *)
1165
1166 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1167
1168 (*
1169 (* convert tree to list of leaves *)
1170 let initial_continuation = fun t -> [] in
1171 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1172 *)
1173
1174 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1175
1176 (*
1177 (* square each leaf using continuation *)
1178 let initial_continuation = fun t -> t in
1179 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1180 *)
1181
1182 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1183
1184
1185 (*
1186 (* replace leaves with list, using continuation *)
1187 let initial_continuation = fun t -> t in
1188 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1189 *)
1190
1191 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1192
1193 print_endline "=== pa_monad's Continuation Tests ============";;
1194
1195 (1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
1196 (2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
1197 (3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
1198 (4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
1199 (5, 27 = C.(run0 (
1200               let c = reset(abort 5 >>= fun y -> unit (y+6))
1201               in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
1202
1203 (7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1204
1205 (8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
1206
1207 (12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
1208
1209
1210 (0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
1211