push monads library
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  *)
46
47
48 (* Some library functions used below. *)
49 module Util = struct
50   let fold_right = List.fold_right
51   let map = List.map
52   let append = List.append
53   let reverse = List.rev
54   let concat = List.concat
55   let concat_map f lst = List.concat (List.map f lst)
56   (* let zip = List.combine *)
57   let unzip = List.split
58   let zip_with = List.map2
59   let replicate len fill =
60     let rec loop n accu =
61       if n == 0 then accu else loop (pred n) (fill :: accu)
62     in loop len []
63 end
64
65
66
67 (*
68  * This module contains factories that extend a base set of
69  * monadic definitions with a larger family of standard derived values.
70  *)
71
72 module Monad = struct
73
74   (* type of base definitions *)
75   module type BASE = sig
76     (*
77      * The only constraints we impose here on how the monadic type
78      * is implemented is that it have a single type parameter 'a.
79      *)
80     type 'a m
81     val unit : 'a -> 'a m
82     val bind : 'a m -> ('a -> 'b m) -> 'b m
83     type 'a result
84     val run : 'a m -> 'a result
85     (* run_exn tries to provide a more ground-level result, but may fail *)
86     type 'a result_exn
87     val run_exn : 'a m -> 'a result_exn
88   end
89   module type S = sig
90     include BASE
91     val (>>=) : 'a m -> ('a -> 'b m) -> 'b m
92     val (>>) : 'a m -> 'b m -> 'b m
93     val join : ('a m) m -> 'a m
94     val apply : ('a -> 'b) m -> 'a m -> 'b m
95     val lift : ('a -> 'b) -> 'a m -> 'b m
96     val lift2 :  ('a -> 'b -> 'c) -> 'a m -> 'b m -> 'c m
97     val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m
98     val do_when :  bool -> unit m -> unit m
99     val do_unless :  bool -> unit m -> unit m
100     val forever : 'a m -> 'b m
101     val sequence : 'a m list -> 'a list m
102     val sequence_ : 'a m list -> unit m
103   end
104
105   (* Standard, single-type-parameter monads. *)
106   module Make(B : BASE) : S with type 'a m = 'a B.m and type 'a result = 'a B.result and type 'a result_exn = 'a B.result_exn = struct
107     include B
108     let (>>=) = bind
109     let (>>) u v = u >>= fun _ -> v
110     let lift f u = u >>= fun a -> unit (f a)
111     (* lift is called listM, fmap, and <$> in Haskell *)
112     let join uu = uu >>= fun u -> u
113     (* u >>= f === join (lift f u) *)
114     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
115     (* [f] <*> [x1,x2] = [f x1,f x2] *)
116     (* let apply u v = u >>= fun f -> lift f v *)
117     (* let apply = lift2 id *)
118     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
119     (* let lift f u === apply (unit f) u *)
120     (* let lift2 f u v = apply (lift f u) v *)
121
122     let (>=>) f g = fun a -> f a >>= g
123     let do_when test u = if test then u else unit ()
124     let do_unless test u = if test then unit () else u
125     let rec forever u = u >> forever u
126     let sequence ms =
127       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
128         Util.fold_right op ms (unit [])
129     let sequence_ ms =
130       Util.fold_right (>>) ms (unit ())
131
132     (* Haskell defines these other operations combining lists and monads.
133      * We don't, but notice that M.mapM == ListT(M).distribute
134      * There's also a parallel TreeT(M).distribute *)
135     (*
136     let mapM f alist = sequence (Util.map f alist)
137     let mapM_ f alist = sequence_ (Util.map f alist)
138     let rec filterM f lst = match lst with
139       | [] -> unit []
140       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
141     let forM alist f = mapM f alist
142     let forM_ alist f = mapM_ f alist
143     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
144     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
145     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
146     let rec foldM f z lst = match lst with
147       | [] -> unit z
148       | x::xs -> f z x >>= fun z' -> foldM f z' xs
149     let foldM_ f z xs = foldM f z xs >> unit ()
150     let replicateM n x = sequence (Util.replicate n x)
151     let replicateM_ n x = sequence_ (Util.replicate n x)
152     *)
153   end
154
155   (* Single-type-parameter monads that also define `plus` and `zero`
156    * operations. These obey the following laws:
157    *     zero >>= f   ===  zero
158    *     plus zero u  ===  u
159    *     plus u zero  ===  u
160    * Additionally, these monads will obey one of the following laws:
161    *     (Catch)   plus (unit a) v  ===  unit a
162    *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
163    *)
164   module type PLUSBASE = sig
165     include BASE
166     val zero : unit -> 'a m
167     val plus : 'a m -> 'a m -> 'a m
168   end
169   module type PLUS = sig
170     type 'a m
171     val zero : unit -> 'a m
172     val plus : 'a m -> 'a m -> 'a m
173     val guard : bool -> unit m
174     val sum : 'a m list -> 'a m
175   end
176   (* MakeCatch and MakeDistrib have the same implementation; we just declare
177    * them twice to document which laws the client code is promising to honor. *)
178   module MakeCatch(B : PLUSBASE) : PLUS with type 'a m = 'a B.m = struct
179     type 'a m = 'a B.m
180     let zero = B.zero
181     let plus = B.plus
182     let guard test = if test then B.unit () else zero ()
183     let sum ms = Util.fold_right plus ms (zero ())
184   end
185   module MakeDistrib = MakeCatch
186
187   (* We have to define BASE, S, and Make again for double-type-parameter monads. *)
188   module type BASE2 = sig
189     type ('x,'a) m
190     val unit : 'a -> ('x,'a) m
191     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
192     type ('x,'a) result
193     val run : ('x,'a) m -> ('x,'a) result
194     type ('x,'a) result_exn
195     val run_exn : ('x,'a) m -> ('x,'a) result
196   end
197   module type S2 = sig
198     include BASE2
199     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
200     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
201     val join : ('x,('x,'a) m) m -> ('x,'a) m
202     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
203     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
204     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
205     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
206     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
207     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
208     val forever : ('x,'a) m -> ('x,'b) m
209     val sequence : ('x,'a) m list -> ('x,'a list) m
210     val sequence_ : ('x,'a) m list -> ('x,unit) m
211   end
212   module Make2(B : BASE2) : S2 with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
213     include B
214     let (>>=) = bind
215     let (>>) u v = u >>= fun _ -> v
216     let lift f u = u >>= fun a -> unit (f a)
217     let join uu = uu >>= fun u -> u
218     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
219     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
220     let (>=>) f g = fun a -> f a >>= g
221     let do_when test u = if test then u else unit ()
222     let do_unless test u = if test then unit () else u
223     let rec forever u = u >> forever u
224     let sequence ms =
225       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
226         Util.fold_right op ms (unit [])
227     let sequence_ ms =
228       Util.fold_right (>>) ms (unit ())
229   end
230
231   (* Signatures for MonadT *)
232   module type W = sig
233     include S
234   end
235   module type WP = sig
236     include W
237     val zero : unit -> 'a m
238     val plus : 'a m -> 'a m -> 'a m
239   end
240   module type TRANS = sig
241     type 'a m
242     val bind : 'a m -> ('a -> 'b m) -> 'b m
243     module Wrapped : W
244     type 'a result
245     val run : 'a m -> 'a result
246     type 'a result_exn
247     val run_exn : 'a m -> 'a result_exn
248     val elevate : 'a Wrapped.m -> 'a m
249     (* lift/elevate laws:
250      *     elevate (W.unit a) == unit a
251      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
252      *)
253   end
254   module MakeT(T : TRANS) = struct
255     include Make(struct
256         include T
257         let unit a = elevate (Wrapped.unit a)
258     end)
259     let elevate = T.elevate
260   end
261
262 end
263
264
265
266
267
268 module Identity_monad : sig
269   (* expose only the implementation of type `'a result` *)
270   type 'a result = 'a
271   type 'a result_exn = 'a
272   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
273 end = struct
274   module Base = struct
275     type 'a m = 'a
276     let unit a = a
277     let bind a f = f a
278     type 'a result = 'a
279     let run a = a
280     type 'a result_exn = 'a
281     let run_exn a = a
282   end
283   include Monad.Make(Base)
284 end
285
286
287 module Maybe_monad : sig
288   (* expose only the implementation of type `'a result` *)
289   type 'a result = 'a option
290   type 'a result_exn = 'a
291   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
292   include Monad.PLUS with type 'a m := 'a m
293   (* MaybeT transformer *)
294   module T : functor (Wrapped : Monad.W) -> sig
295     type 'a result = 'a option Wrapped.result
296     type 'a result_exn = 'a Wrapped.result_exn
297     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
298     include Monad.PLUS with type 'a m := 'a m
299     val elevate : 'a Wrapped.m -> 'a m
300   end
301 end = struct
302   module Base = struct
303    type 'a m = 'a option
304    let unit a = Some a
305    let bind u f = match u with Some a -> f a | None -> None
306    type 'a result = 'a option
307    let run u = u
308    type 'a result_exn = 'a
309    let run_exn u = match u with
310      | Some a -> a
311      | None -> failwith "no value"
312    let zero () = None
313    let plus u v = match u with None -> v | _ -> u
314   end
315   include Monad.Make(Base)
316   include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m)
317   module T(Wrapped : Monad.W) = struct
318     module Trans = struct
319       include Monad.MakeT(struct
320         module Wrapped = Wrapped
321         type 'a m = 'a option Wrapped.m
322         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
323         let bind u f = Wrapped.bind u (fun t -> match t with
324           | Some a -> f a
325           | None -> Wrapped.unit None)
326         type 'a result = 'a option Wrapped.result
327         let run u = Wrapped.run u
328         type 'a result_exn = 'a Wrapped.result_exn
329         let run_exn u =
330           let w = Wrapped.bind u (fun t -> match t with
331             | Some a -> Wrapped.unit a
332             | None -> failwith "no value")
333           in Wrapped.run_exn w
334       end)
335       let zero () = Wrapped.unit None
336       let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
337     end
338     include Trans
339     include (Monad.MakeCatch(Trans) : Monad.PLUS with type 'a m := 'a m)
340   end
341 end
342
343
344 module List_monad : sig
345   (* declare additional operation, while still hiding implementation of type m *)
346   type 'a result = 'a list
347   type 'a result_exn = 'a
348   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
349   include Monad.PLUS with type 'a m := 'a m
350   val permute : 'a m -> 'a m m
351   val select : 'a m -> ('a * 'a m) m
352   (* ListT transformer *)
353   module T : functor (Wrapped : Monad.W) -> sig
354     type 'a result = 'a list Wrapped.result
355     type 'a result_exn = 'a Wrapped.result_exn
356     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
357     include Monad.PLUS with type 'a m := 'a m
358     val elevate : 'a Wrapped.m -> 'a m
359     (* note that second argument is an 'a list, not the more abstract 'a m *)
360     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
361     val distribute : ('a -> 'b Wrapped.m) -> 'a list -> 'b m
362 (* TODO
363     val permute : 'a m -> 'a m m
364     val select : 'a m -> ('a * 'a m) m
365 *)
366   end
367 end = struct
368   module Base = struct
369    type 'a m = 'a list
370    let unit a = [a]
371    let bind u f = Util.concat_map f u
372    type 'a result = 'a list
373    let run u = u
374    type 'a result_exn = 'a
375    let run_exn u = match u with
376      | [] -> failwith "no values"
377      | [a] -> a
378      | many -> failwith "multiple values"
379    let zero () = []
380    let plus = Util.append
381   end
382   include Monad.Make(Base)
383   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
384   (* let either u v = plus u v *)
385   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
386   let rec insert a u =
387     plus (unit (a :: u)) (match u with
388         | [] -> zero ()
389         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
390     )
391   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
392   let rec permute u = match u with
393       | [] -> unit []
394       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
395   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
396   let rec select u = match u with
397     | [] -> zero ()
398     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
399   let base_plus = plus
400   module T(Wrapped : Monad.W) = struct
401     module Trans = struct
402       let zero () = Wrapped.unit []
403       let plus u v =
404         Wrapped.bind u (fun us ->
405         Wrapped.bind v (fun vs ->
406         Wrapped.unit (base_plus us vs)))
407       (* Wrapped.sequence ms  ===  
408            let plus1 u v =
409              Wrapped.bind u (fun x ->
410              Wrapped.bind v (fun xs ->
411              Wrapped.unit (x :: xs)))
412            in Util.fold_right plus1 ms (Wrapped.unit []) *)
413       (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
414       let distribute f alist = Wrapped.sequence (Util.map f alist)
415       include Monad.MakeT(struct
416         module Wrapped = Wrapped
417         type 'a m = 'a list Wrapped.m
418         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
419         let bind u f =
420           Wrapped.bind u (fun ts ->
421           Wrapped.bind (distribute f ts) (fun tts ->
422           Wrapped.unit (Util.concat tts)))
423         type 'a result = 'a list Wrapped.result
424         let run u = Wrapped.run u
425         type 'a result_exn = 'a Wrapped.result_exn
426         let run_exn u =
427           let w = Wrapped.bind u (fun ts -> match ts with
428             | [] -> failwith "no values"
429             | [a] -> Wrapped.unit a
430             | many -> failwith "multiple values"
431           ) in Wrapped.run_exn w
432       end)
433     end
434     include Trans
435     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
436 (*
437     let permute : 'a m -> 'a m m
438     let select : 'a m -> ('a * 'a m) m
439 *)
440   end
441 end
442
443
444 (* must be parameterized on (struct type err = ... end) *)
445 module Error_monad(Err : sig
446   type err
447   exception Exc of err
448   (*
449   val zero : unit -> err
450   val plus : err -> err -> err
451   *)
452 end) : sig
453   (* declare additional operations, while still hiding implementation of type m *)
454   type err = Err.err
455   type 'a error = Error of err | Success of 'a
456   type 'a result = 'a
457   type 'a result_exn = 'a
458   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
459   (* include Monad.PLUS with type 'a m := 'a m *)
460   val throw : err -> 'a m
461   val catch : 'a m -> (err -> 'a m) -> 'a m
462   (* ErrorT transformer *)
463   module T : functor (Wrapped : Monad.W) -> sig
464     type 'a result = 'a Wrapped.result
465     type 'a result_exn = 'a Wrapped.result_exn
466     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
467     val elevate : 'a Wrapped.m -> 'a m
468     val throw : err -> 'a m
469     val catch : 'a m -> (err -> 'a m) -> 'a m
470   end
471 end = struct
472   type err = Err.err
473   type 'a error = Error of err | Success of 'a
474   module Base = struct
475     type 'a m = 'a error
476     let unit a = Success a
477     let bind u f = match u with
478       | Success a -> f a
479       | Error e -> Error e (* input and output may be of different 'a types *)
480     type 'a result = 'a
481     (* TODO: should run refrain from failing? *)
482     let run u = match u with
483       | Success a -> a
484       | Error e -> raise (Err.Exc e)
485     type 'a result_exn = 'a
486     let run_exn = run
487     (*
488     let zero () = Error Err.zero
489     let plus u v = match (u, v) with
490       | Success _, _ -> u
491       (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
492        * otherwise, plus (Error _) v = v *)
493       | Error _, _ when v = zero -> u
494       (* combine errors *)
495       | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
496       | Error _, _ -> v
497     *)
498   end
499   include Monad.Make(Base)
500   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
501   let throw e = Error e
502   let catch u handler = match u with
503     | Success _ -> u
504     | Error e -> handler e
505   module T(Wrapped : Monad.W) = struct
506     module Trans = struct
507       module Wrapped = Wrapped
508       type 'a m = 'a Base.m Wrapped.m
509       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
510       let bind u f = Wrapped.bind u (fun t -> match t with
511         | Success a -> f a
512         | Error e -> Wrapped.unit (Error e))
513       type 'a result = 'a Wrapped.result
514       (* TODO: should run refrain from failing? *)
515       let run u =
516         let w = Wrapped.bind u (fun t -> match t with
517           | Success a -> Wrapped.unit a
518           (* | _ -> Wrapped.fail () *)
519           | Error e -> raise (Err.Exc e))
520         in Wrapped.run w
521       type 'a result_exn = 'a Wrapped.result_exn
522       let run_exn u =
523         let w = Wrapped.bind u (fun t -> match t with
524           | Success a -> Wrapped.unit a
525           (* | _ -> Wrapped.fail () *)
526           | Error e -> raise (Err.Exc e))
527         in Wrapped.run_exn w
528     end
529     include Monad.MakeT(Trans)
530     let throw e = Wrapped.unit (Error e)
531     let catch u handler = Wrapped.bind u (fun t -> match t with
532       | Success _ -> Wrapped.unit t
533       | Error e -> handler e)
534   end
535 end
536
537 (* pre-define common instance of Error_monad *)
538 module Failure = Error_monad(struct
539   type err = string
540   exception Exc = Failure
541   (*
542   let zero = ""
543   let plus s1 s2 = s1 ^ "\n" ^ s2
544   *)
545 end)
546
547 (* must be parameterized on (struct type env = ... end) *)
548 module Reader_monad(Env : sig type env end) : sig
549   (* declare additional operations, while still hiding implementation of type m *)
550   type env = Env.env
551   type 'a result = env -> 'a
552   type 'a result_exn = env -> 'a
553   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
554   val ask : env m
555   val asks : (env -> 'a) -> 'a m
556   val local : (env -> env) -> 'a m -> 'a m
557   (* ReaderT transformer *)
558   module T : functor (Wrapped : Monad.W) -> sig
559     type 'a result = env -> 'a Wrapped.result
560     type 'a result_exn = env -> 'a Wrapped.result_exn
561     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
562     val elevate : 'a Wrapped.m -> 'a m
563     val ask : env m
564     val asks : (env -> 'a) -> 'a m
565     val local : (env -> env) -> 'a m -> 'a m
566   end
567   (* ReaderT transformer when wrapped monad has plus, zero *)
568   module TP : functor (Wrapped : Monad.WP) -> sig
569     include module type of T(Wrapped)
570     include Monad.PLUS with type 'a m := 'a m
571   end
572 end = struct
573   type env = Env.env
574   module Base = struct
575     type 'a m = env -> 'a
576     let unit a = fun e -> a
577     let bind u f = fun e -> let a = u e in let u' = f a in u' e
578     type 'a result = env -> 'a
579     let run u = fun e -> u e
580     type 'a result_exn = env -> 'a
581     let run_exn = run
582   end
583   include Monad.Make(Base)
584   let ask = fun e -> e
585   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
586   let local modifier u = fun e -> u (modifier e)
587   module T(Wrapped : Monad.W) = struct
588     module Trans = struct
589       module Wrapped = Wrapped
590       type 'a m = env -> 'a Wrapped.m
591       let elevate w = fun e -> w
592       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
593       type 'a result = env -> 'a Wrapped.result
594       let run u = fun e -> Wrapped.run (u e)
595       type 'a result_exn = env -> 'a Wrapped.result_exn
596       let run_exn u = fun e -> Wrapped.run_exn (u e)
597     end
598     include Monad.MakeT(Trans)
599     let ask = fun e -> Wrapped.unit e
600     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
601     let local modifier u = fun e -> u (modifier e)
602   end
603   module TP(Wrapped : Monad.WP) = struct
604     module TransP = struct
605       include T(Wrapped)
606       let plus u v = fun s -> Wrapped.plus (u s) (v s)
607       let zero () = elevate (Wrapped.zero ())
608       let asks selector = ask >>= (fun e ->
609         try unit (selector e)
610         with Not_found -> fun e -> Wrapped.zero ())
611     end
612     include TransP
613     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
614   end
615 end
616
617
618 (* must be parameterized on (struct type store = ... end) *)
619 module State_monad(Store : sig type store end) : sig
620   (* declare additional operations, while still hiding implementation of type m *)
621   type store = Store.store
622   type 'a result = store -> 'a * store
623   type 'a result_exn = store -> 'a
624   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
625   val get : store m
626   val gets : (store -> 'a) -> 'a m
627   val put : store -> unit m
628   val puts : (store -> store) -> unit m
629   (* StateT transformer *)
630   module T : functor (Wrapped : Monad.W) -> sig
631     type 'a result = store -> ('a * store) Wrapped.result
632     type 'a result_exn = store -> 'a Wrapped.result_exn
633     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
634     val elevate : 'a Wrapped.m -> 'a m
635     val get : store m
636     val gets : (store -> 'a) -> 'a m
637     val put : store -> unit m
638     val puts : (store -> store) -> unit m
639   end
640   (* StateT transformer when wrapped monad has plus, zero *)
641   module TP : functor (Wrapped : Monad.WP) -> sig
642     include module type of T(Wrapped)
643     include Monad.PLUS with type 'a m := 'a m
644   end
645 end = struct
646   type store = Store.store
647   module Base = struct
648     type 'a m = store -> 'a * store
649     let unit a = fun s -> (a, s)
650     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
651     type 'a result = store -> 'a * store
652     let run u = fun s -> (u s)
653     type 'a result_exn = store -> 'a
654     let run_exn u = fun s -> fst (u s)
655   end
656   include Monad.Make(Base)
657   let get = fun s -> (s, s)
658   let gets viewer = fun s -> (viewer s, s) (* may fail *)
659   let put s = fun _ -> ((), s)
660   let puts modifier = fun s -> ((), modifier s)
661   module T(Wrapped : Monad.W) = struct
662     module Trans = struct
663       module Wrapped = Wrapped
664       type 'a m = store -> ('a * store) Wrapped.m
665       let elevate w = fun s ->
666         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
667       let bind u f = fun s ->
668         Wrapped.bind (u s) (fun (a, s') -> f a s')
669       type 'a result = store -> ('a * store) Wrapped.result
670       let run u = fun s -> Wrapped.run (u s)
671       type 'a result_exn = store -> 'a Wrapped.result_exn
672       let run_exn u = fun s ->
673         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
674         in Wrapped.run_exn w
675     end
676     include Monad.MakeT(Trans)
677     let get = fun s -> Wrapped.unit (s, s)
678     let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
679     let put s = fun _ -> Wrapped.unit ((), s)
680     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
681   end
682   module TP(Wrapped : Monad.WP) = struct
683     module TransP = struct
684       include T(Wrapped)
685       let plus u v = fun s -> Wrapped.plus (u s) (v s)
686       let zero () = elevate (Wrapped.zero ())
687     end
688     let gets viewer = fun s ->
689       try Wrapped.unit (viewer s, s)
690       with Not_found -> Wrapped.zero ()
691     include TransP
692     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
693   end
694 end
695
696 (* State monad with different interface (structured store) *)
697 module Ref_monad(V : sig
698   type value
699 end) : sig
700   type ref
701   type value = V.value
702   type 'a result = 'a
703   type 'a result_exn = 'a
704   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
705   val newref : value -> ref m
706   val deref : ref -> value m
707   val change : ref -> value -> unit m
708   (* RefT transformer *)
709   module T : functor (Wrapped : Monad.W) -> sig
710     type 'a result = 'a Wrapped.result
711     type 'a result_exn = 'a Wrapped.result_exn
712     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
713     val elevate : 'a Wrapped.m -> 'a m
714     val newref : value -> ref m
715     val deref : ref -> value m
716     val change : ref -> value -> unit m
717   end
718   (* RefT transformer when wrapped monad has plus, zero *)
719   module TP : functor (Wrapped : Monad.WP) -> sig
720     include module type of T(Wrapped)
721     include Monad.PLUS with type 'a m := 'a m
722   end
723 end = struct
724   type ref = int
725   type value = V.value
726   module D = Map.Make(struct type t = ref let compare = compare end)
727   type dict = { next: ref; tree : value D.t }
728   let empty = { next = 0; tree = D.empty }
729   let alloc (value : value) (d : dict) =
730     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
731   let read (key : ref) (d : dict) =
732     D.find key d.tree
733   let write (key : ref) (value : value) (d : dict) =
734     { next = d.next; tree = D.add key value d.tree }
735   module Base = struct
736     type 'a m = dict -> 'a * dict
737     let unit a = fun s -> (a, s)
738     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
739     type 'a result = 'a
740     let run u = fst (u empty)
741     type 'a result_exn = 'a
742     let run_exn = run
743   end
744   include Monad.Make(Base)
745   let newref value = fun s -> alloc value s
746   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
747   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
748   module T(Wrapped : Monad.W) = struct
749     module Trans = struct
750       module Wrapped = Wrapped
751       type 'a m = dict -> ('a * dict) Wrapped.m
752       let elevate w = fun s ->
753         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
754       let bind u f = fun s ->
755         Wrapped.bind (u s) (fun (a, s') -> f a s')
756       type 'a result = 'a Wrapped.result
757       let run u =
758         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
759         in Wrapped.run w
760       type 'a result_exn = 'a Wrapped.result_exn
761       let run_exn u =
762         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
763         in Wrapped.run_exn w
764     end
765     include Monad.MakeT(Trans)
766     let newref value = fun s -> Wrapped.unit (alloc value s)
767     let deref key = fun s -> Wrapped.unit (read key s, s)
768     let change key value = fun s -> Wrapped.unit ((), write key value s)
769   end
770   module TP(Wrapped : Monad.WP) = struct
771     module TransP = struct
772       include T(Wrapped)
773       let plus u v = fun s -> Wrapped.plus (u s) (v s)
774       let zero () = elevate (Wrapped.zero ())
775     end
776     include TransP
777     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
778   end
779 end
780
781
782 (* must be parameterized on (struct type log = ... end) *)
783 module Writer_monad(Log : sig
784   type log
785   val zero : log
786   val plus : log -> log -> log
787 end) : sig
788   (* declare additional operations, while still hiding implementation of type m *)
789   type log = Log.log
790   type 'a result = 'a * log
791   type 'a result_exn = 'a * log
792   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
793   val tell : log -> unit m
794   val listen : 'a m -> ('a * log) m
795   val listens : (log -> 'b) -> 'a m -> ('a * 'b) m
796   (* val pass : ('a * (log -> log)) m -> 'a m *)
797   val censor : (log -> log) -> 'a m -> 'a m
798 end = struct
799   type log = Log.log
800   module Base = struct
801     type 'a m = 'a * log
802     let unit a = (a, Log.zero)
803     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
804     type 'a result = 'a * log
805     let run u = u
806     type 'a result_exn = 'a * log
807     let run_exn = run
808   end
809   include Monad.Make(Base)
810   let tell entries = ((), entries) (* add entries to log *)
811   let listen (a, w) = ((a, w), w)
812   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
813   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
814   let censor f u = pass (u >>= fun a -> unit (a, f))
815 end
816
817 (* pre-define simple Writer *)
818 module Writer1 = Writer_monad(struct
819   type log = string
820   let zero = ""
821   let plus s1 s2 = s1 ^ "\n" ^ s2
822 end)
823
824 (* slightly more efficient Writer *)
825 module Writer2 = struct
826   include Writer_monad(struct
827     type log = string list
828     let zero = []
829     let plus w w' = Util.append w' w
830   end)
831   let tell_string s = tell [s]
832   let tell entries = tell (Util.reverse entries)
833   let run u = let (a, w) = run u in (a, Util.reverse w)
834   let run_exn = run
835 end
836
837
838 module IO_monad : sig
839   (* declare additional operation, while still hiding implementation of type m *)
840   type 'a result = 'a
841   type 'a result_exn = 'a
842   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
843   val printf : ('a, unit, string, unit m) format4 -> 'a
844   val print_string : string -> unit m
845   val print_int : int -> unit m
846   val print_hex : int -> unit m
847   val print_bool : bool -> unit m
848 end = struct
849   module Base = struct
850     type 'a m = { run : unit -> unit; value : 'a }
851     let unit a = { run = (fun () -> ()); value = a }
852     let bind (a : 'a m) (f: 'a -> 'b m) : 'b m =
853      let fres = f a.value in
854        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
855     type 'a result = 'a
856     let run a = let () = a.run () in a.value
857     type 'a result_exn = 'a
858     let run_exn = run
859   end
860   include Monad.Make(Base)
861   let printf fmt =
862     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
863   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
864   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
865   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
866   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
867 end
868
869 module Continuation_monad : sig
870   (* expose only the implementation of type `('r,'a) result` *)
871   type 'a m
872   type 'a result = 'a m
873   type 'a result_exn = 'a m
874   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn and type 'a m := 'a m
875   (* misses that the answer types of all the cont's must be the same *)
876   val callcc : (('a -> 'b m) -> 'a m) -> 'a m
877   val reset : 'a m -> 'a m
878   (* misses that the answer types of second and third continuations must be b *)
879   val shift : (('a -> 'b m) -> 'b m) -> 'a m
880   (* overwrite the run declaration in S, because I can't declare 'a result =
881    * this polymorphic type (complains that 'r is unbound *)
882   val runk : 'a m -> ('a -> 'r) -> 'r
883   val run0 : 'a m -> 'a
884 end = struct
885   let id = fun i -> i
886   module Base = struct
887     (* 'r is result type of whole computation *)
888     type 'a m = { cont : 'r. ('a -> 'r) -> 'r }
889     let unit a =
890       let cont : 'r. ('a -> 'r) -> 'r =
891         fun k -> k a
892       in { cont }
893     let bind u f =
894       let cont : 'r. ('a -> 'r) -> 'r =
895         fun k -> u.cont (fun a -> (f a).cont k)
896       in { cont }
897     type 'a result = 'a m
898     let run (u : 'a m) : 'a result = u
899     type 'a result_exn = 'a m
900     let run_exn (u : 'a m) : 'a result_exn = u
901     let callcc f =
902       let cont : 'r. ('a -> 'r) -> 'r = fun k ->
903           (* Can't figure out how to make the type polymorphic enough
904            * to satisfy the OCaml type-checker (it's ('a -> 'r) -> 'r
905            * instead of 'r. ('a -> 'r) -> 'r); so we have to fudge
906            * with Obj.magic... which tells OCaml's type checker to
907            * relax, the supplied value has whatever type the context
908            * needs it to have. *)
909           let usek a = { cont = Obj.magic (fun _ -> k a) }
910           in (f usek).cont k
911       in { cont }
912     let reset u = unit (u.cont id)
913     let shift (f : ('a -> 'b m) -> 'b m) : 'a m =
914       let cont =
915         fun k -> (f (fun a -> unit (k a))).cont id
916       in { cont = Obj.magic cont }
917     let runk u k = u.cont k
918     let run0 u = u.cont id
919   end
920   include Monad.Make(Base)
921   let callcc = Base.callcc
922   let reset = Base.reset
923   let shift = Base.shift
924   let runk = Base.runk
925   let run0 = Base.run0
926 end
927
928 (*
929 (* This two-type parameter version works without Obj.magic *)
930
931 module Continuation_monad2 : sig
932   (* expose only the implementation of type `('r,'a) result` *)
933   type ('r,'a) result = ('a -> 'r) -> 'r
934   type ('r,'a) result_exn = ('a -> 'r) -> 'r
935   include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn
936   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
937   val reset : ('a,'a) m -> ('r,'a) m
938   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
939
940 end = struct
941   let id = fun i -> i
942   module Base = struct
943     (* 'r is result type of whole computation *)
944     type ('r,'a) m = ('a -> 'r) -> 'r
945     let unit a = fun k -> k a
946     let bind u f = fun k -> u (fun a -> (f a) k)
947     type ('r,'a) result = ('a -> 'r) -> 'r
948     let run u = u
949     type ('r,'a) result_exn = ('a -> 'r) -> 'r
950     let run_exn = run
951   end
952   include Monad.Make2(Base)
953   let callcc f = fun k ->
954     let usek a = fun _ -> k a
955     in f usek k
956   (*
957   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
958   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
959   let callcc f = fun k -> f k k
960   let throw k a = fun _ -> k a
961   *)
962   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right *)
963   let reset u = unit (u id)
964   let shift u = fun k -> u (fun a -> unit (k a)) id
965 end
966  *)
967
968
969 (*
970  * Scheme:
971  * (define (example n)
972  *    (let ([u (let/cc k ; type int -> int pair
973  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
974  *                 (+ 1 (car v))))]) ; int
975  *      (cons u 0))) ; int pair
976  * ; (example 10) ~~> '(111 . 0)
977  * ; (example -10) ~~> '(0 . 0)
978  *
979  * OCaml monads:
980  * let example n : (int * int) =
981  *   Continuation_monad.(let u = callcc (fun k ->
982  *       (if n < 0 then k 0 else unit [n + 100])
983  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
984  *       >>= fun [x] -> unit (x + 1)
985  *   )
986  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
987  *   >>= fun x -> unit (x, 0)
988  *   in run u)
989  *
990  *
991  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
992  * let example1 () : int =
993  *   Continuation_monad.(let v = reset (
994  *       let u = shift (fun k -> unit (10 + 1))
995  *       in u >>= fun x -> unit (100 + x)
996  *     ) in let w = v >>= fun x -> unit (1000 + x)
997  *     in run w)
998  *
999  * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1000  * let example2 () =
1001  *   Continuation_monad.(let v = reset (
1002  *       let u = shift (fun k -> k (10 :: [1]))
1003  *       in u >>= fun x -> unit (100 :: x)
1004  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1005  *     in run w)
1006  *
1007  * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1008  * let example3 () =
1009  *   Continuation_monad.(let v = reset (
1010  *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1011  *       in u >>= fun x -> unit (100 :: x)
1012  *     ) in let w = v >>= fun x -> unit (1000 :: x)
1013  *     in run w)
1014  *
1015  * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1016  * (* not sure if this example can be typed without a sum-type *)
1017  *
1018  * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1019  * let example5 () : int =
1020  *   Continuation_monad.(let v = reset (
1021  *       let u = shift (fun k -> k 1 >>= fun x -> k x)
1022  *       in u >>= fun x -> unit (10 + x)
1023  *     ) in let w = v >>= fun x -> unit (100 + x)
1024  *     in run w)
1025  *
1026  *)
1027
1028
1029 module Leaf_monad : sig
1030   (* We implement the type as `'a tree option` because it has a natural`plus`,
1031    * and the rest of the library expects that `plus` and `zero` will come together. *)
1032   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1033   type 'a result = 'a tree option
1034   type 'a result_exn = 'a tree
1035   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1036   include Monad.PLUS with type 'a m := 'a m
1037   (* LeafT transformer *)
1038   module T : functor (Wrapped : Monad.W) -> sig
1039     type 'a result = 'a tree option Wrapped.result
1040     type 'a result_exn = 'a tree Wrapped.result_exn
1041     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
1042     include Monad.PLUS with type 'a m := 'a m
1043     val elevate : 'a Wrapped.m -> 'a m
1044     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
1045     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
1046     val distribute : ('a -> 'b Wrapped.m) -> 'a tree option -> 'b m
1047   end
1048 end = struct
1049   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
1050   (* uses supplied plus and zero to copy t to its image under f *)
1051   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
1052       | None -> zero ()
1053       | Some ts -> let rec loop ts = (match ts with
1054                      | Leaf a -> f a
1055                      | Node (l, r) ->
1056                          (* recursive application of f may delete a branch *)
1057                          plus (loop l) (loop r)
1058                    ) in loop ts
1059   module Base = struct
1060     type 'a m = 'a tree option
1061     let unit a = Some (Leaf a)
1062     let zero () = None
1063     let plus u v = match (u, v) with
1064       | None, _ -> v
1065       | _, None -> u
1066       | Some us, Some vs -> Some (Node (us, vs))
1067     let bind u f = mapT f u zero plus
1068     type 'a result = 'a tree option
1069     let run u = u
1070     type 'a result_exn = 'a tree
1071     let run_exn u = match u with
1072       | None -> failwith "no values"
1073       (*
1074       | Some (Leaf a) -> a
1075       | many -> failwith "multiple values"
1076       *)
1077       | Some us -> us
1078   end
1079   include Monad.Make(Base)
1080   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
1081   let base_plus = plus
1082   let base_lift = lift
1083   module T(Wrapped : Monad.W) = struct
1084     module Trans = struct
1085       let zero () = Wrapped.unit None
1086       let plus u v =
1087         Wrapped.bind u (fun us ->
1088         Wrapped.bind v (fun vs ->
1089         Wrapped.unit (base_plus us vs)))
1090       include Monad.MakeT(struct
1091         module Wrapped = Wrapped
1092         type 'a m = 'a Base.m Wrapped.m
1093         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1094         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1095         type 'a result = 'a tree option Wrapped.result
1096         let run u = Wrapped.run u
1097         type 'a result_exn = 'a tree Wrapped.result_exn
1098         let run_exn u =
1099             let w = Wrapped.bind u (fun t -> match t with
1100               | None -> failwith "no values"
1101               | Some ts -> Wrapped.unit ts)
1102             in Wrapped.run_exn w
1103       end)
1104     end
1105     include Trans
1106     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
1107     (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
1108     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1109   end
1110 end
1111
1112
1113 module L = List_monad;;
1114 module R = Reader_monad(struct type env = int -> int end);;
1115 module S = State_monad(struct type store = int end);;
1116 module T = Leaf_monad;;
1117 module LR = L.T(R);;
1118 module LS = L.T(S);;
1119 module TL = T.T(L);;
1120 module TR = T.T(R);;
1121 module TS = T.T(S);;
1122
1123 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
1124
1125 (*
1126 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
1127 TS.run ts 0;;
1128 (*
1129 - : int T.tree option * S.store =
1130 (Some
1131   (T.Node
1132     (T.Node (T.Leaf 2, T.Leaf 3),
1133      T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
1134  5)
1135 *)
1136
1137 let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
1138 TS.run_exn ts2 0;;
1139 (*
1140 - : (int * S.store) T.tree option * S.store =
1141 (Some
1142   (T.Node
1143     (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
1144      T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
1145  5)
1146 *)
1147
1148 let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
1149 TR.run_exn tr (fun i -> i+i);;
1150 (*
1151 - : int T.tree option =
1152 Some
1153  (T.Node
1154    (T.Node (T.Leaf 4, T.Leaf 6),
1155     T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
1156 *)
1157
1158 let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
1159 TL.run_exn tl;;
1160 (*
1161 - : (int * int) TL.result =
1162 [Some
1163   (T.Node
1164     (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
1165      T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
1166 *)
1167
1168 let l2 = [1;2;3;4;5];;
1169 let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
1170
1171 LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1172 (* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
1173
1174 TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
1175 (*
1176 int T.tree option =
1177 Some
1178  (T.Node
1179    (T.Node (T.Leaf 10, T.Leaf 11),
1180     T.Node
1181      (T.Node
1182        (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
1183         T.Node (T.Leaf 40, T.Leaf 41)),
1184       T.Node (T.Leaf 50, T.Leaf 51))))
1185  *)
1186
1187 LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
1188 (*
1189 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
1190 *)
1191
1192 *)
1193
1194 let id : 'z. 'z -> 'z = fun x -> x
1195
1196 let example n : (int * int) =
1197   Continuation_monad.(let u = callcc (fun k ->
1198       (if n < 0 then k 0 else unit [n + 100])
1199       (* all of the following is skipped by k 0; the end type int is k's input type *)
1200       >>= fun [x] -> unit (x + 1)
1201   )
1202   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
1203   >>= fun x -> unit (x, 0)
1204   in run0 u)
1205
1206
1207 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
1208 let example1 () : int =
1209   Continuation_monad.(let v = reset (
1210       let u = shift (fun k -> unit (10 + 1))
1211       in u >>= fun x -> unit (100 + x)
1212     ) in let w = v >>= fun x -> unit (1000 + x)
1213     in run0 w)
1214
1215 (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
1216 let example2 () =
1217   Continuation_monad.(let v = reset (
1218       let u = shift (fun k -> k (10 :: [1]))
1219       in u >>= fun x -> unit (100 :: x)
1220     ) in let w = v >>= fun x -> unit (1000 :: x)
1221     in run0 w)
1222
1223 (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
1224 let example3 () =
1225   Continuation_monad.(let v = reset (
1226       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
1227       in u >>= fun x -> unit (100 :: x)
1228     ) in let w = v >>= fun x -> unit (1000 :: x)
1229     in run0 w)
1230
1231 (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
1232 (* not sure if this example can be typed without a sum-type *)
1233
1234 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
1235 let example5 () : int =
1236   Continuation_monad.(let v = reset (
1237       let u = shift (fun k -> k 1 >>= fun x -> k x)
1238       in u >>= fun x -> unit (10 + x)
1239     ) in let w = v >>= fun x -> unit (100 + x)
1240     in run0 w)
1241
1242
1243 ;;
1244
1245 (1011, 1111, 1111, 121);;
1246 (example1(), example2(), example3(), example5());;
1247 ((111,0), (0,0));;
1248 (example ~+10, example ~-10);;
1249
1250 module C = Continuation_monad
1251 module TC = T.T(C)
1252
1253 let testc df ic =
1254     C.runk TC.(run_exn (distribute df t1)) ic;;
1255
1256
1257 (*
1258 (* do nothing *)
1259 let initial_continuation = fun t -> t in
1260 TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
1261 *)
1262 testc (C.unit) id;;
1263
1264 (*
1265 (* count leaves, using continuation *)
1266 let initial_continuation = fun t -> 0 in
1267 TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
1268 *)
1269
1270 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
1271
1272 (*
1273 (* convert tree to list of leaves *)
1274 let initial_continuation = fun t -> [] in
1275 TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
1276 *)
1277
1278 testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
1279
1280 (*
1281 (* square each leaf using continuation *)
1282 let initial_continuation = fun t -> t in
1283 TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
1284 *)
1285
1286 testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
1287
1288
1289 (*
1290 (* replace leaves with list, using continuation *)
1291 let initial_continuation = fun t -> t in
1292 TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
1293 *)
1294
1295 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
1296