8b69ec1d1a0a909b62ae7353b3487e6c642169d5
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  * Acknowledgements: This is largely based on the mtl library distributed
46  * with the Glasgow Haskell Compiler. I've also been helped in
47  * various ways by posts and direct feedback from Oleg Kiselyov and
48  * Chung-chieh Shan. The following were also useful:
49  * - <http://pauillac.inria.fr/~xleroy/mpri/progfunc/>
50  * - Ken Shan "Monads for natural language semantics" <http://arxiv.org/abs/cs/0205026v1>
51  * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf
52  * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers
53  *
54  * Licensing: MIT (if that's compatible with the ghc sources this is partly
55  * derived from)
56  *)
57
58 exception Undefined
59
60 (* Some library functions used below. *)
61 module Util = struct
62   let fold_right = List.fold_right
63   let map = List.map
64   let append = List.append
65   let reverse = List.rev
66   let concat = List.concat
67   let concat_map f lst = List.concat (List.map f lst)
68   (* let zip = List.combine *)
69   let unzip = List.split
70   let zip_with = List.map2
71   let replicate len fill =
72     let rec loop n accu =
73       if n == 0 then accu else loop (pred n) (fill :: accu)
74     in loop len []
75   (* Dirty hack to be a default polymorphic zero.
76    * To implement this cleanly, monads without a natural zero
77    * should always wrap themselves in an option layer (see Tree_monad). *)
78   let undef = Obj.magic (fun () -> raise Undefined)
79 end
80
81
82
83 (*
84  * This module contains factories that extend a base set of
85  * monadic definitions with a larger family of standard derived values.
86  *)
87
88 module Monad = struct
89   (*
90    * Signature extenders:
91    *   Make :: BASE -> S
92    *   MakeT :: BASET (with Wrapped : S) -> result sig not declared
93    *)
94
95
96   (* type of base definitions *)
97   module type BASE = sig
98     (* We make all monadic types doubly-parameterized so that they
99      * can layer nicely with Continuation, which needs the second
100      * type parameter. *)
101     type ('x,'a) m
102     type ('x,'a) result
103     type ('x,'a) result_exn
104     val unit : 'a -> ('x,'a) m
105     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
106     val run : ('x,'a) m -> ('x,'a) result
107     (* run_exn tries to provide a more ground-level result, but may fail *)
108     val run_exn : ('x,'a) m -> ('x,'a) result_exn
109     (* To simplify the library, we require every monad to supply a plus and zero. These obey the following laws:
110      *     zero >>= f   ===  zero
111      *     plus zero u  ===  u
112      *     plus u zero  ===  u
113      * Additionally, they will obey one of the following laws:
114      *     (Catch)   plus (unit a) v  ===  unit a
115      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
116      * When no natural zero is available, use `let zero () = Util.undef`.
117      * The Make functor automatically detects for zero >>= ..., and 
118      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
119      *)
120     val zero : unit -> ('x,'a) m
121     (* zero has to be thunked to ensure results are always poly enough *)
122     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
123   end
124   module type S = sig
125     include BASE
126     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
127     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
128     val join : ('x,('x,'a) m) m -> ('x,'a) m
129     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
130     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
131     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
132     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
133     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
134     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
135     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
136     val sequence : ('x,'a) m list -> ('x,'a list) m
137     val sequence_ : ('x,'a) m list -> ('x,unit) m
138     val guard : bool -> ('x,unit) m
139     val sum : ('x,'a) m list -> ('x,'a) m
140   end
141
142   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
143     include B
144     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
145       if u == Util.undef then Util.undef
146       else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
147     let plus u v =
148       if u == Util.undef then v else if v == Util.undef then u else B.plus u v
149     let run u =
150       if u == Util.undef then raise Undefined else B.run u
151     let run_exn u =
152       if u == Util.undef then raise Undefined else B.run_exn u
153     let (>>=) = bind
154     (* expressions after >> will be evaluated before they're passed to
155      * bind, so you can't do `zero () >> assert false`
156      * this works though: `zero () >>= fun _ -> assert false`
157      *)
158     let (>>) u v = u >>= fun _ -> v
159     let lift f u = u >>= fun a -> unit (f a)
160     (* lift is called listM, fmap, and <$> in Haskell *)
161     let join uu = uu >>= fun u -> u
162     (* u >>= f === join (lift f u) *)
163     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
164     (* [f] <*> [x1,x2] = [f x1,f x2] *)
165     (* let apply u v = u >>= fun f -> lift f v *)
166     (* let apply = lift2 id *)
167     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
168     (* let lift f u === apply (unit f) u *)
169     (* let lift2 f u v = apply (lift f u) v *)
170     let (>=>) f g = fun a -> f a >>= g
171     let do_when test u = if test then u else unit ()
172     let do_unless test u = if test then unit () else u
173     (* A Haskell-like version works:
174          let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk
175      * but the recursive call is not in tail position so this can stack overflow. *)
176     let forever uthunk =
177         let z = zero () in
178         let id result = result in
179         let kcell = ref id in
180         let rec loop _ =
181             let result = uthunk (kcell := id) >>= chained
182             in !kcell result
183         and chained _ =
184             kcell := loop; z (* we use z only for its polymorphism *)
185         in loop z
186     (* Reimplementations of the preceding using a hand-rolled State or StateT
187 can also stack overflow. *)
188     let sequence ms =
189       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
190         Util.fold_right op ms (unit [])
191     let sequence_ ms =
192       Util.fold_right (>>) ms (unit ())
193
194     (* Haskell defines these other operations combining lists and monads.
195      * We don't, but notice that M.mapM == ListT(M).distribute
196      * There's also a parallel TreeT(M).distribute *)
197     (*
198     let mapM f alist = sequence (Util.map f alist)
199     let mapM_ f alist = sequence_ (Util.map f alist)
200     let rec filterM f lst = match lst with
201       | [] -> unit []
202       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
203     let forM alist f = mapM f alist
204     let forM_ alist f = mapM_ f alist
205     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
206     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
207     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
208     let rec foldM f z lst = match lst with
209       | [] -> unit z
210       | x::xs -> f z x >>= fun z' -> foldM f z' xs
211     let foldM_ f z xs = foldM f z xs >> unit ()
212     let replicateM n x = sequence (Util.replicate n x)
213     let replicateM_ n x = sequence_ (Util.replicate n x)
214     *)
215     let guard test = if test then B.unit () else zero ()
216     let sum ms = Util.fold_right plus ms (zero ())
217   end
218
219   (* Signatures for MonadT *)
220   module type BASET = sig
221     module Wrapped : S
222     type ('x,'a) m
223     type ('x,'a) result
224     type ('x,'a) result_exn
225     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
226     val run : ('x,'a) m -> ('x,'a) result
227     val run_exn : ('x,'a) m -> ('x,'a) result_exn
228     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
229     (* lift/elevate laws:
230      *     elevate (W.unit a) == unit a
231      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
232      *)
233     val zero : unit -> ('x,'a) m
234     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
235   end
236   module MakeT(T : BASET) = struct
237     include Make(struct
238         include T
239         let unit a = elevate (Wrapped.unit a)
240     end)
241     let elevate = T.elevate
242   end
243
244 end
245
246
247
248
249
250 module Identity_monad : sig
251   (* expose only the implementation of type `'a result` *)
252   type ('x,'a) result = 'a
253   type ('x,'a) result_exn = 'a
254   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
255 end = struct
256   module Base = struct
257     type ('x,'a) m = 'a
258     type ('x,'a) result = 'a
259     type ('x,'a) result_exn = 'a
260     let unit a = a
261     let bind a f = f a
262     let run a = a
263     let run_exn a = a
264     let zero () = Util.undef
265     let plus u v = u
266   end
267   include Monad.Make(Base)
268 end
269
270
271 module Maybe_monad : sig
272   (* expose only the implementation of type `'a result` *)
273   type ('x,'a) result = 'a option
274   type ('x,'a) result_exn = 'a
275   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
276   (* MaybeT transformer *)
277   module T : functor (Wrapped : Monad.S) -> sig
278     type ('x,'a) result = ('x,'a option) Wrapped.result
279     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
280     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
281     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
282   end
283 end = struct
284   module Base = struct
285     type ('x,'a) m = 'a option
286     type ('x,'a) result = 'a option
287     type ('x,'a) result_exn = 'a
288     let unit a = Some a
289     let bind u f = match u with Some a -> f a | None -> None
290     let run u = u
291     let run_exn u = match u with
292       | Some a -> a
293       | None -> failwith "no value"
294     let zero () = None
295     (* satisfies Catch *)
296     let plus u v = match u with None -> v | _ -> u
297   end
298   include Monad.Make(Base)
299   module T(Wrapped : Monad.S) = struct
300     module BaseT = struct
301       include Monad.MakeT(struct
302         module Wrapped = Wrapped
303         type ('x,'a) m = ('x,'a option) Wrapped.m
304         type ('x,'a) result = ('x,'a option) Wrapped.result
305         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
306         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
307         let bind u f = Wrapped.bind u (fun t -> match t with
308           | Some a -> f a
309           | None -> Wrapped.unit None)
310         let run u = Wrapped.run u
311         let run_exn u =
312           let w = Wrapped.bind u (fun t -> match t with
313             | Some a -> Wrapped.unit a
314             | None -> Wrapped.zero ()
315           ) in Wrapped.run_exn w
316         let zero () = Wrapped.unit None
317         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
318       end)
319     end
320     include BaseT
321   end
322 end
323
324
325 module List_monad : sig
326   (* declare additional operation, while still hiding implementation of type m *)
327   type ('x,'a) result = 'a list
328   type ('x,'a) result_exn = 'a
329   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
330   val permute : ('x,'a) m -> ('x,('x,'a) m) m
331   val select : ('x,'a) m -> ('x,'a * ('x,'a) m) m
332   (* ListT transformer *)
333   module T : functor (Wrapped : Monad.S) -> sig
334     type ('x,'a) result = ('x,'a list) Wrapped.result
335     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
336     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
337     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
338     (* note that second argument is an 'a list, not the more abstract 'a m *)
339     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
340     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
341 (* TODO
342     val permute : 'a m -> 'a m m
343     val select : 'a m -> ('a * 'a m) m
344 *)
345     val expose : ('x,'a) m -> ('x,'a list) Wrapped.m
346   end
347 end = struct
348   module Base = struct
349    type ('x,'a) m = 'a list
350    type ('x,'a) result = 'a list
351    type ('x,'a) result_exn = 'a
352    let unit a = [a]
353    let bind u f = Util.concat_map f u
354    let run u = u
355    let run_exn u = match u with
356      | [] -> failwith "no values"
357      | [a] -> a
358      | many -> failwith "multiple values"
359    let zero () = []
360    (* satisfies Distrib *)
361    let plus = Util.append
362   end
363   include Monad.Make(Base)
364   (* let either u v = plus u v *)
365   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
366   let rec insert a u =
367     plus (unit (a :: u)) (match u with
368         | [] -> zero ()
369         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
370     )
371   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
372   let rec permute u = match u with
373       | [] -> unit []
374       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
375   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
376   let rec select u = match u with
377     | [] -> zero ()
378     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
379   module T(Wrapped : Monad.S) = struct
380     (* Wrapped.sequence ms  ===  
381          let plus1 u v =
382            Wrapped.bind u (fun x ->
383            Wrapped.bind v (fun xs ->
384            Wrapped.unit (x :: xs)))
385          in Util.fold_right plus1 ms (Wrapped.unit []) *)
386     (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
387     let distribute f alist = Wrapped.sequence (Util.map f alist)
388
389     include Monad.MakeT(struct
390       module Wrapped = Wrapped
391       type ('x,'a) m = ('x,'a list) Wrapped.m
392       type ('x,'a) result = ('x,'a list) Wrapped.result
393       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
394       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
395       let bind u f =
396         Wrapped.bind u (fun ts ->
397         Wrapped.bind (distribute f ts) (fun tts ->
398         Wrapped.unit (Util.concat tts)))
399       let run u = Wrapped.run u
400       let run_exn u =
401         let w = Wrapped.bind u (fun ts -> match ts with
402           | [] -> Wrapped.zero ()
403           | [a] -> Wrapped.unit a
404           | many -> Wrapped.zero ()
405         ) in Wrapped.run_exn w
406       let zero () = Wrapped.unit []
407       let plus u v =
408         Wrapped.bind u (fun us ->
409         Wrapped.bind v (fun vs ->
410         Wrapped.unit (Base.plus us vs)))
411     end)
412 (*
413     let permute : 'a m -> 'a m m
414     let select : 'a m -> ('a * 'a m) m
415 *)
416     let expose u = u
417   end
418 end
419
420
421 (* must be parameterized on (struct type err = ... end) *)
422 module Error_monad(Err : sig
423   type err
424   exception Exc of err
425   (*
426   val zero : unit -> err
427   val plus : err -> err -> err
428   *)
429 end) : sig
430   (* declare additional operations, while still hiding implementation of type m *)
431   type err = Err.err
432   type 'a error = Error of err | Success of 'a
433   type ('x,'a) result = 'a error
434   type ('x,'a) result_exn = 'a
435   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
436   val throw : err -> ('x,'a) m
437   val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
438   (* ErrorT transformer *)
439   module T : functor (Wrapped : Monad.S) -> sig
440     type ('x,'a) result = ('x,'a) Wrapped.result
441     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
442     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
443     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
444     val throw : err -> ('x,'a) m
445     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
446   end
447 end = struct
448   type err = Err.err
449   type 'a error = Error of err | Success of 'a
450   module Base = struct
451     type ('x,'a) m = 'a error
452     type ('x,'a) result = 'a error
453     type ('x,'a) result_exn = 'a
454     let unit a = Success a
455     let bind u f = match u with
456       | Success a -> f a
457       | Error e -> Error e (* input and output may be of different 'a types *)
458     let run u = u
459     let run_exn u = match u with
460       | Success a -> a
461       | Error e -> raise (Err.Exc e)
462     let zero () = Util.undef
463     (* satisfies Catch *)
464     let plus u v = match u with
465       | Success _ -> u
466       | Error _ -> if v == Util.undef then u else v
467   end
468   include Monad.Make(Base)
469   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
470   let throw e = Error e
471   let catch u handler = match u with
472     | Success _ -> u
473     | Error e -> handler e
474   module T(Wrapped : Monad.S) = struct
475     include Monad.MakeT(struct
476       module Wrapped = Wrapped
477       type ('x,'a) m = ('x,'a error) Wrapped.m
478       type ('x,'a) result = ('x,'a) Wrapped.result
479       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
480       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
481       let bind u f = Wrapped.bind u (fun t -> match t with
482         | Success a -> f a
483         | Error e -> Wrapped.unit (Error e))
484       let run u =
485         let w = Wrapped.bind u (fun t -> match t with
486           | Success a -> Wrapped.unit a
487           | Error e -> Wrapped.zero ()
488         ) in Wrapped.run w
489       let run_exn u =
490         let w = Wrapped.bind u (fun t -> match t with
491           | Success a -> Wrapped.unit a
492           | Error e -> raise (Err.Exc e))
493         in Wrapped.run_exn w
494       let plus u v = Wrapped.plus u v
495       let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *)
496     end)
497     let throw e = Wrapped.unit (Error e)
498     let catch u handler = Wrapped.bind u (fun t -> match t with
499       | Success _ -> Wrapped.unit t
500       | Error e -> handler e)
501   end
502 end
503
504 (* pre-define common instance of Error_monad *)
505 module Failure = Error_monad(struct
506   type err = string
507   exception Exc = Failure
508   (*
509   let zero = ""
510   let plus s1 s2 = s1 ^ "\n" ^ s2
511   *)
512 end)
513
514
515 (* must be parameterized on (struct type env = ... end) *)
516 module Reader_monad(Env : sig type env end) : sig
517   (* declare additional operations, while still hiding implementation of type m *)
518   type env = Env.env
519   type ('x,'a) result = env -> 'a
520   type ('x,'a) result_exn = env -> 'a
521   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
522   val ask : ('x,env) m
523   val asks : (env -> 'a) -> ('x,'a) m
524   (* lookup i == `fun e -> e i` would assume env is a functional type *)
525   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
526   (* ReaderT transformer *)
527   module T : functor (Wrapped : Monad.S) -> sig
528     type ('x,'a) result = env -> ('x,'a) Wrapped.result
529     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
530     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
531     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
532     val ask : ('x,env) m
533     val asks : (env -> 'a) -> ('x,'a) m
534     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
535     val expose : ('x,'a) m -> env -> ('x,'a) Wrapped.m
536   end
537 end = struct
538   type env = Env.env
539   module Base = struct
540     type ('x,'a) m = env -> 'a
541     type ('x,'a) result = env -> 'a
542     type ('x,'a) result_exn = env -> 'a
543     let unit a = fun e -> a
544     let bind u f = fun e -> let a = u e in let u' = f a in u' e
545     let run u = fun e -> u e
546     let run_exn = run
547     let zero () = Util.undef
548     let plus u v = u
549   end
550   include Monad.Make(Base)
551   let ask = fun e -> e
552   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
553   let local modifier u = fun e -> u (modifier e)
554   module T(Wrapped : Monad.S) = struct
555     module BaseT = struct
556       module Wrapped = Wrapped
557       type ('x,'a) m = env -> ('x,'a) Wrapped.m
558       type ('x,'a) result = env -> ('x,'a) Wrapped.result
559       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
560       let elevate w = fun e -> w
561       let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e)
562       let run u = fun e -> Wrapped.run (u e)
563       let run_exn u = fun e -> Wrapped.run_exn (u e)
564       (* satisfies Distrib *)
565       let plus u v = fun e -> Wrapped.plus (u e) (v e)
566       let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
567     end
568     include Monad.MakeT(BaseT)
569     let ask = Wrapped.unit
570     let local modifier u = fun e -> u (modifier e)
571     let asks selector = ask >>= (fun e ->
572       try unit (selector e)
573       with Not_found -> fun e -> Wrapped.zero ())
574     let expose u = u
575   end
576 end
577
578
579 (* must be parameterized on (struct type store = ... end) *)
580 module State_monad(Store : sig type store end) : sig
581   (* declare additional operations, while still hiding implementation of type m *)
582   type store = Store.store
583   type ('x,'a) result =  store -> 'a * store
584   type ('x,'a) result_exn = store -> 'a
585   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
586   val get : ('x,store) m
587   val gets : (store -> 'a) -> ('x,'a) m
588   val put : store -> ('x,unit) m
589   val puts : (store -> store) -> ('x,unit) m
590   (* StateT transformer *)
591   module T : functor (Wrapped : Monad.S) -> sig
592     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
593     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
594     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
595     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
596     val get : ('x,store) m
597     val gets : (store -> 'a) -> ('x,'a) m
598     val put : store -> ('x,unit) m
599     val puts : (store -> store) -> ('x,unit) m
600     (* val passthru : ('x,'a) m -> (('x,'a * store) Wrapped.result * store -> 'b) -> ('x,'b) m *)
601     val expose : ('x,'a) m -> store -> ('x,'a * store) Wrapped.m
602   end
603 end = struct
604   type store = Store.store
605   module Base = struct
606     type ('x,'a) m =  store -> 'a * store
607     type ('x,'a) result =  store -> 'a * store
608     type ('x,'a) result_exn = store -> 'a
609     let unit a = fun s -> (a, s)
610     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
611     let run u = fun s -> (u s)
612     let run_exn u = fun s -> fst (u s)
613     let zero () = Util.undef
614     let plus u v = u
615   end
616   include Monad.Make(Base)
617   let get = fun s -> (s, s)
618   let gets viewer = fun s -> (viewer s, s) (* may fail *)
619   let put s = fun _ -> ((), s)
620   let puts modifier = fun s -> ((), modifier s)
621   module T(Wrapped : Monad.S) = struct
622     module BaseT = struct
623       module Wrapped = Wrapped
624       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
625       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
626       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
627       let elevate w = fun s ->
628         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
629       let bind u f = fun s ->
630         Wrapped.bind (u s) (fun (a, s') -> f a s')
631       let run u = fun s -> Wrapped.run (u s)
632       let run_exn u = fun s ->
633         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
634         in Wrapped.run_exn w
635       (* satisfies Distrib *)
636       let plus u v = fun s -> Wrapped.plus (u s) (v s)
637       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
638     end
639     include Monad.MakeT(BaseT)
640     let get = fun s -> Wrapped.unit (s, s)
641     let gets viewer = fun s ->
642       try Wrapped.unit (viewer s, s)
643       with Not_found -> Wrapped.zero ()
644     let put s = fun _ -> Wrapped.unit ((), s)
645     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
646     (* let passthru u f = fun s -> Wrapped.unit (f (Wrapped.run (u s), s), s) *)
647     let expose u = u
648   end
649 end
650
651
652 (* State monad with different interface (structured store) *)
653 module Ref_monad(V : sig
654   type value
655 end) : sig
656   type ref
657   type value = V.value
658   type ('x,'a) result = 'a
659   type ('x,'a) result_exn = 'a
660   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
661   val newref : value -> ('x,ref) m
662   val deref : ref -> ('x,value) m
663   val change : ref -> value -> ('x,unit) m
664   (* RefT transformer *)
665   module T : functor (Wrapped : Monad.S) -> sig
666     type ('x,'a) result = ('x,'a) Wrapped.result
667     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
668     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
669     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
670     val newref : value -> ('x,ref) m
671     val deref : ref -> ('x,value) m
672     val change : ref -> value -> ('x,unit) m
673   end
674 end = struct
675   type ref = int
676   type value = V.value
677   module D = Map.Make(struct type t = ref let compare = compare end)
678   type dict = { next: ref; tree : value D.t }
679   let empty = { next = 0; tree = D.empty }
680   let alloc (value : value) (d : dict) =
681     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
682   let read (key : ref) (d : dict) =
683     D.find key d.tree
684   let write (key : ref) (value : value) (d : dict) =
685     { next = d.next; tree = D.add key value d.tree }
686   module Base = struct
687     type ('x,'a) m = dict -> 'a * dict
688     type ('x,'a) result = 'a
689     type ('x,'a) result_exn = 'a
690     let unit a = fun s -> (a, s)
691     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
692     let run u = fst (u empty)
693     let run_exn = run
694     let zero () = Util.undef
695     let plus u v = u
696   end
697   include Monad.Make(Base)
698   let newref value = fun s -> alloc value s
699   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
700   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
701   module T(Wrapped : Monad.S) = struct
702     module BaseT = struct
703       module Wrapped = Wrapped
704       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
705       type ('x,'a) result = ('x,'a) Wrapped.result
706       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
707       let elevate w = fun s ->
708         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
709       let bind u f = fun s ->
710         Wrapped.bind (u s) (fun (a, s') -> f a s')
711       let run u =
712         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
713         in Wrapped.run w
714       let run_exn u =
715         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
716         in Wrapped.run_exn w
717       (* satisfies Distrib *)
718       let plus u v = fun s -> Wrapped.plus (u s) (v s)
719       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
720     end
721     include Monad.MakeT(BaseT)
722     let newref value = fun s -> Wrapped.unit (alloc value s)
723     let deref key = fun s -> Wrapped.unit (read key s, s)
724     let change key value = fun s -> Wrapped.unit ((), write key value s)
725   end
726 end
727
728
729 (* must be parameterized on (struct type log = ... end) *)
730 module Writer_monad(Log : sig
731   type log
732   val zero : log
733   val plus : log -> log -> log
734 end) : sig
735   (* declare additional operations, while still hiding implementation of type m *)
736   type log = Log.log
737   type ('x,'a) result = 'a * log
738   type ('x,'a) result_exn = 'a * log
739   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
740   val tell : log -> ('x,unit) m
741   val listen : ('x,'a) m -> ('x,'a * log) m
742   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
743   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
744   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
745   (* WriterT transformer *)
746   module T : functor (Wrapped : Monad.S) -> sig
747     type ('x,'a) result = ('x,'a * log) Wrapped.result
748     type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
749     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
750     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
751     val tell : log -> ('x,unit) m
752     val listen : ('x,'a) m -> ('x,'a * log) m
753     val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
754     val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
755   end
756 end = struct
757   type log = Log.log
758   module Base = struct
759     type ('x,'a) m = 'a * log
760     type ('x,'a) result = 'a * log
761     type ('x,'a) result_exn = 'a * log
762     let unit a = (a, Log.zero)
763     let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w')
764     let run u = u
765     let run_exn = run
766     let zero () = Util.undef
767     let plus u v = u
768   end
769   include Monad.Make(Base)
770   let tell entries = ((), entries) (* add entries to log *)
771   let listen (a, w) = ((a, w), w)
772   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
773   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
774   let censor f u = pass (u >>= fun a -> unit (a, f))
775   module T(Wrapped : Monad.S) = struct
776     module BaseT = struct
777       module Wrapped = Wrapped
778       type ('x,'a) m = ('x,'a * log) Wrapped.m
779       type ('x,'a) result = ('x,'a * log) Wrapped.result
780       type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
781       let elevate w =
782         Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero))
783       let bind u f =
784         Wrapped.bind u (fun (a, w) ->
785         Wrapped.bind (f a) (fun (b, w') ->
786         Wrapped.unit (b, Log.plus w w')))
787       let zero () = elevate (Wrapped.zero ())
788       let plus u v = Wrapped.plus u v
789       let run u = Wrapped.run u
790       let run_exn u = Wrapped.run_exn u
791     end
792     include Monad.MakeT(BaseT)
793     let tell entries = Wrapped.unit ((), entries)
794     let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w))
795     let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w))
796     (* rest are derived in same way as before *)
797     let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w)
798     let censor f u = pass (u >>= fun a -> unit (a, f))
799   end
800 end
801
802 (* pre-define simple Writer *)
803 module Writer1 = Writer_monad(struct
804   type log = string
805   let zero = ""
806   let plus s1 s2 = s1 ^ "\n" ^ s2
807 end)
808
809 (* slightly more efficient Writer *)
810 module Writer2 = struct
811   include Writer_monad(struct
812     type log = string list
813     let zero = []
814     let plus w w' = Util.append w' w
815   end)
816   let tell_string s = tell [s]
817   let tell entries = tell (Util.reverse entries)
818   let run u = let (a, w) = run u in (a, Util.reverse w)
819   let run_exn = run
820 end
821
822
823 (* TODO needs a T *)
824 module IO_monad : sig
825   (* declare additional operation, while still hiding implementation of type m *)
826   type ('x,'a) result = 'a
827   type ('x,'a) result_exn = 'a
828   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
829   val printf : ('a, unit, string, ('x,unit) m) format4 -> 'a
830   val print_string : string -> ('x,unit) m
831   val print_int : int -> ('x,unit) m
832   val print_hex : int -> ('x,unit) m
833   val print_bool : bool -> ('x,unit) m
834 end = struct
835   module Base = struct
836     type ('x,'a) m = { run : unit -> unit; value : 'a }
837     type ('x,'a) result = 'a
838     type ('x,'a) result_exn = 'a
839     let unit a = { run = (fun () -> ()); value = a }
840     let bind (a : ('x,'a) m) (f: 'a -> ('x,'b) m) : ('x,'b) m =
841      let fres = f a.value in
842        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
843     let run a = let () = a.run () in a.value
844     let run_exn = run
845     let zero () = Util.undef
846     let plus u v = u
847   end
848   include Monad.Make(Base)
849   let printf fmt =
850     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
851   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
852   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
853   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
854   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
855 end
856
857
858 module Continuation_monad : sig
859   (* expose only the implementation of type `('r,'a) result` *)
860   type ('r,'a) m
861   type ('r,'a) result = ('r,'a) m
862   type ('r,'a) result_exn = ('a -> 'r) -> 'r
863   include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
864   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
865   val reset : ('a,'a) m -> ('r,'a) m
866   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
867   (* val abort : ('a,'a) m -> ('a,'b) m *)
868   val abort : 'a -> ('a,'b) m
869   val run0 : ('a,'a) m -> 'a
870   (* ContinuationT transformer *)
871   module T : functor (Wrapped : Monad.S) -> sig
872     type ('r,'a) m
873     type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
874     type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
875     include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
876     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
877     val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
878     (* TODO: reset,shift,abort,run0 *)
879   end
880 end = struct
881   let id = fun i -> i
882   module Base = struct
883     (* 'r is result type of whole computation *)
884     type ('r,'a) m = ('a -> 'r) -> 'r
885     type ('r,'a) result = ('a -> 'r) -> 'r
886     type ('r,'a) result_exn = ('r,'a) result
887     let unit a = (fun k -> k a)
888     let bind u f = (fun k -> (u) (fun a -> (f a) k))
889     let run u k = (u) k
890     let run_exn = run
891     let zero () = Util.undef
892     let plus u v = u
893   end
894   include Monad.Make(Base)
895   let callcc f = (fun k ->
896     let usek a = (fun _ -> k a)
897     in (f usek) k)
898   (*
899   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
900   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
901   let callcc f = fun k -> f k k
902   let throw k a = fun _ -> k a
903   *)
904
905   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
906    *
907    *  reset :: (Monad m) => ContT a m a -> ContT r m a
908    *  reset e = ContT $ \k -> runContT e return >>= k
909    *
910    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
911    *  shift e = ContT $ \k ->
912    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
913   let reset u = unit ((u) id)
914   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
915   (* let abort a = shift (fun _ -> a) *)
916   let abort a = shift (fun _ -> unit a)
917   let run0 (u : ('a,'a) m) = (u) id
918   module T(Wrapped : Monad.S) = struct
919     module BaseT = struct
920       module Wrapped = Wrapped
921       type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m
922       type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
923       type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
924       let elevate w = fun k -> Wrapped.bind w k
925       let bind u f = fun k -> u (fun a -> f a k)
926       let run u k = Wrapped.run (u k)
927       let run_exn u k = Wrapped.run_exn (u k)
928       let zero () = Util.undef
929       let plus u v = u
930     end
931     include Monad.MakeT(BaseT)
932     let callcc f = (fun k ->
933       let usek a = (fun _ -> k a)
934       in (f usek) k)
935   end
936 end
937
938
939 (*
940  * Scheme:
941  * (define (example n)
942  *    (let ([u (let/cc k ; type int -> int pair
943  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
944  *                 (+ 1 (car v))))]) ; int
945  *      (cons u 0))) ; int pair
946  * ; (example 10) ~~> '(111 . 0)
947  * ; (example -10) ~~> '(0 . 0)
948  *
949  * OCaml monads:
950  * let example n : (int * int) =
951  *   Continuation_monad.(let u = callcc (fun k ->
952  *       (if n < 0 then k 0 else unit [n + 100])
953  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
954  *       >>= fun [x] -> unit (x + 1)
955  *   )
956  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
957  *   >>= fun x -> unit (x, 0)
958  *   in run u)
959  *
960  *)
961
962
963 module Tree_monad : sig
964   (* We implement the type as `'a tree option` because it has a natural`plus`,
965    * and the rest of the library expects that `plus` and `zero` will come together. *)
966   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
967   type ('x,'a) result = 'a tree option
968   type ('x,'a) result_exn = 'a tree
969   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
970   (* TreeT transformer *)
971   module T : functor (Wrapped : Monad.S) -> sig
972     type ('x,'a) result = ('x,'a tree option) Wrapped.result
973     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
974     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
975     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
976     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
977     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
978     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
979     val expose : ('x,'a) m -> ('x,'a tree option) Wrapped.m
980   end
981 end = struct
982   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
983   (* uses supplied plus and zero to copy t to its image under f *)
984   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
985       | None -> zero ()
986       | Some ts -> let rec loop ts = (match ts with
987                      | Leaf a -> f a
988                      | Node (l, r) ->
989                          (* recursive application of f may delete a branch *)
990                          plus (loop l) (loop r)
991                    ) in loop ts
992   module Base = struct
993     type ('x,'a) m = 'a tree option
994     type ('x,'a) result = 'a tree option
995     type ('x,'a) result_exn = 'a tree
996     let unit a = Some (Leaf a)
997     let zero () = None
998     (* satisfies Distrib *)
999     let plus u v = match (u, v) with
1000       | None, _ -> v
1001       | _, None -> u
1002       | Some us, Some vs -> Some (Node (us, vs))
1003     let bind u f = mapT f u zero plus
1004     let run u = u
1005     let run_exn u = match u with
1006       | None -> failwith "no values"
1007       (*
1008       | Some (Leaf a) -> a
1009       | many -> failwith "multiple values"
1010       *)
1011       | Some us -> us
1012   end
1013   include Monad.Make(Base)
1014   module T(Wrapped : Monad.S) = struct
1015     module BaseT = struct
1016       include Monad.MakeT(struct
1017         module Wrapped = Wrapped
1018         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1019         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1020         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1021         let zero () = Wrapped.unit None
1022         let plus u v =
1023           Wrapped.bind u (fun us ->
1024           Wrapped.bind v (fun vs ->
1025           Wrapped.unit (Base.plus us vs)))
1026         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1027         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1028         let run u = Wrapped.run u
1029         let run_exn u =
1030             let w = Wrapped.bind u (fun t -> match t with
1031               | None -> Wrapped.zero ()
1032               | Some ts -> Wrapped.unit ts
1033             ) in Wrapped.run_exn w
1034       end)
1035     end
1036     include BaseT
1037     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1038     let expose u = u
1039   end
1040 end;;
1041
1042