index,new_stuff
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  * Acknowledgements: This is largely based on the mtl library distributed
46  * with the Glasgow Haskell Compiler. I've also been helped in
47  * various ways by posts and direct feedback from Oleg Kiselyov and
48  * Chung-chieh Shan. The following were also useful:
49  * - <http://pauillac.inria.fr/~xleroy/mpri/progfunc/>
50  * - Ken Shan "Monads for natural language semantics" <http://arxiv.org/abs/cs/0205026v1>
51  * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf
52  * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers
53  *
54  * Licensing: MIT (if that's compatible with the ghc sources this is partly
55  * derived from)
56  *)
57
58 exception Undefined
59
60 (* Some library functions used below. *)
61 module Util = struct
62   let fold_right = List.fold_right
63   let map = List.map
64   let append = List.append
65   let reverse = List.rev
66   let concat = List.concat
67   let concat_map f lst = List.concat (List.map f lst)
68   (* let zip = List.combine *)
69   let unzip = List.split
70   let zip_with = List.map2
71   let replicate len fill =
72     let rec loop n accu =
73       if n == 0 then accu else loop (pred n) (fill :: accu)
74     in loop len []
75   (* Dirty hack to be a default polymorphic zero.
76    * To implement this cleanly, monads without a natural zero
77    * should always wrap themselves in an option layer (see Tree_monad). *)
78   let undef = Obj.magic (fun () -> raise Undefined)
79 end
80
81
82
83 (*
84  * This module contains factories that extend a base set of
85  * monadic definitions with a larger family of standard derived values.
86  *)
87
88 module Monad = struct
89   (*
90    * Signature extenders:
91    *   Make :: BASE -> S
92    *   MakeT :: BASET (with Wrapped : S) -> result sig not declared
93    *)
94
95
96   (* type of base definitions *)
97   module type BASE = sig
98     (* We make all monadic types doubly-parameterized so that they
99      * can layer nicely with Continuation, which needs the second
100      * type parameter. *)
101     type ('x,'a) m
102     type ('x,'a) result
103     type ('x,'a) result_exn
104     val unit : 'a -> ('x,'a) m
105     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
106     val run : ('x,'a) m -> ('x,'a) result
107     (* run_exn tries to provide a more ground-level result, but may fail *)
108     val run_exn : ('x,'a) m -> ('x,'a) result_exn
109     (* To simplify the library, we require every monad to supply a plus and zero. These obey the following laws:
110      *     zero >>= f   ===  zero
111      *     plus zero u  ===  u
112      *     plus u zero  ===  u
113      * Additionally, they will obey one of the following laws:
114      *     (Catch)   plus (unit a) v  ===  unit a
115      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
116      * When no natural zero is available, use `let zero () = Util.undef`.
117      * The Make functor automatically detects for zero >>= ..., and 
118      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
119      *)
120     val zero : unit -> ('x,'a) m
121     (* zero has to be thunked to ensure results are always poly enough *)
122     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
123   end
124   module type S = sig
125     include BASE
126     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
127     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
128     val join : ('x,('x,'a) m) m -> ('x,'a) m
129     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
130     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
131     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
132     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
133     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
134     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
135     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
136     val sequence : ('x,'a) m list -> ('x,'a list) m
137     val sequence_ : ('x,'a) m list -> ('x,unit) m
138     val guard : bool -> ('x,unit) m
139     val sum : ('x,'a) m list -> ('x,'a) m
140   end
141
142   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
143     include B
144     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
145       if u == Util.undef then Util.undef
146       else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
147     let plus u v =
148       if u == Util.undef then v else if v == Util.undef then u else B.plus u v
149     let run u =
150       if u == Util.undef then raise Undefined else B.run u
151     let run_exn u =
152       if u == Util.undef then raise Undefined else B.run_exn u
153     let (>>=) = bind
154     (* expressions after >> will be evaluated before they're passed to
155      * bind, so you can't do `zero () >> assert false`
156      * this works though: `zero () >>= fun _ -> assert false`
157      *)
158     let (>>) u v = u >>= fun _ -> v
159     let lift f u = u >>= fun a -> unit (f a)
160     (* lift is called listM, fmap, and <$> in Haskell *)
161     let join uu = uu >>= fun u -> u
162     (* u >>= f === join (lift f u) *)
163     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
164     (* [f] <*> [x1,x2] = [f x1,f x2] *)
165     (* let apply u v = u >>= fun f -> lift f v *)
166     (* let apply = lift2 id *)
167     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
168     (* let lift f u === apply (unit f) u *)
169     (* let lift2 f u v = apply (lift f u) v *)
170     let (>=>) f g = fun a -> f a >>= g
171     let do_when test u = if test then u else unit ()
172     let do_unless test u = if test then unit () else u
173     (* A Haskell-like version works:
174          let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk
175      * but the recursive call is not in tail position so this can stack overflow. *)
176     let forever uthunk =
177         let z = zero () in
178         let id result = result in
179         let kcell = ref id in
180         let rec loop _ =
181             let result = uthunk (kcell := id) >>= chained
182             in !kcell result
183         and chained _ =
184             kcell := loop; z (* we use z only for its polymorphism *)
185         in loop z
186     (* Reimplementations of the preceding using a hand-rolled State or StateT
187 can also stack overflow. *)
188     let sequence ms =
189       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
190         Util.fold_right op ms (unit [])
191     let sequence_ ms =
192       Util.fold_right (>>) ms (unit ())
193
194     (* Haskell defines these other operations combining lists and monads.
195      * We don't, but notice that M.mapM == ListT(M).distribute
196      * There's also a parallel TreeT(M).distribute *)
197     (*
198     let mapM f alist = sequence (Util.map f alist)
199     let mapM_ f alist = sequence_ (Util.map f alist)
200     let rec filterM f lst = match lst with
201       | [] -> unit []
202       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
203     let forM alist f = mapM f alist
204     let forM_ alist f = mapM_ f alist
205     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
206     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
207     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
208     let rec foldM f z lst = match lst with
209       | [] -> unit z
210       | x::xs -> f z x >>= fun z' -> foldM f z' xs
211     let foldM_ f z xs = foldM f z xs >> unit ()
212     let replicateM n x = sequence (Util.replicate n x)
213     let replicateM_ n x = sequence_ (Util.replicate n x)
214     *)
215     let guard test = if test then B.unit () else zero ()
216     let sum ms = Util.fold_right plus ms (zero ())
217   end
218
219   (* Signatures for MonadT *)
220   module type BASET = sig
221     module Wrapped : S
222     type ('x,'a) m
223     type ('x,'a) result
224     type ('x,'a) result_exn
225     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
226     val run : ('x,'a) m -> ('x,'a) result
227     val run_exn : ('x,'a) m -> ('x,'a) result_exn
228     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
229     (* lift/elevate laws:
230      *     elevate (W.unit a) == unit a
231      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
232      *)
233     val zero : unit -> ('x,'a) m
234     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
235   end
236   module MakeT(T : BASET) = struct
237     include Make(struct
238         include T
239         let unit a = elevate (Wrapped.unit a)
240     end)
241     let elevate = T.elevate
242   end
243
244 end
245
246
247
248
249
250 module Identity_monad : sig
251   (* expose only the implementation of type `'a result` *)
252   type ('x,'a) result = 'a
253   type ('x,'a) result_exn = 'a
254   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
255 end = struct
256   module Base = struct
257     type ('x,'a) m = 'a
258     type ('x,'a) result = 'a
259     type ('x,'a) result_exn = 'a
260     let unit a = a
261     let bind a f = f a
262     let run a = a
263     let run_exn a = a
264     let zero () = Util.undef
265     let plus u v = u
266   end
267   include Monad.Make(Base)
268 end
269
270
271 module Maybe_monad : sig
272   (* expose only the implementation of type `'a result` *)
273   type ('x,'a) result = 'a option
274   type ('x,'a) result_exn = 'a
275   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
276   (* MaybeT transformer *)
277   module T : functor (Wrapped : Monad.S) -> sig
278     type ('x,'a) result = ('x,'a option) Wrapped.result
279     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
280     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
281     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
282   end
283 end = struct
284   module Base = struct
285     type ('x,'a) m = 'a option
286     type ('x,'a) result = 'a option
287     type ('x,'a) result_exn = 'a
288     let unit a = Some a
289     let bind u f = match u with Some a -> f a | None -> None
290     let run u = u
291     let run_exn u = match u with
292       | Some a -> a
293       | None -> failwith "no value"
294     let zero () = None
295     (* satisfies Catch *)
296     let plus u v = match u with None -> v | _ -> u
297   end
298   include Monad.Make(Base)
299   module T(Wrapped : Monad.S) = struct
300     module BaseT = struct
301       include Monad.MakeT(struct
302         module Wrapped = Wrapped
303         type ('x,'a) m = ('x,'a option) Wrapped.m
304         type ('x,'a) result = ('x,'a option) Wrapped.result
305         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
306         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
307         let bind u f = Wrapped.bind u (fun t -> match t with
308           | Some a -> f a
309           | None -> Wrapped.unit None)
310         let run u = Wrapped.run u
311         let run_exn u =
312           let w = Wrapped.bind u (fun t -> match t with
313             | Some a -> Wrapped.unit a
314             | None -> Wrapped.zero ()
315           ) in Wrapped.run_exn w
316         let zero () = Wrapped.unit None
317         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
318       end)
319     end
320     include BaseT
321   end
322 end
323
324
325 module List_monad : sig
326   (* declare additional operation, while still hiding implementation of type m *)
327   type ('x,'a) result = 'a list
328   type ('x,'a) result_exn = 'a
329   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
330   val permute : ('x,'a) m -> ('x,('x,'a) m) m
331   val select : ('x,'a) m -> ('x,'a * ('x,'a) m) m
332   (* ListT transformer *)
333   module T : functor (Wrapped : Monad.S) -> sig
334     type ('x,'a) result = ('x,'a list) Wrapped.result
335     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
336     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
337     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
338     (* note that second argument is an 'a list, not the more abstract 'a m *)
339     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
340     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
341 (* TODO
342     val permute : 'a m -> 'a m m
343     val select : 'a m -> ('a * 'a m) m
344 *)
345   end
346 end = struct
347   module Base = struct
348    type ('x,'a) m = 'a list
349    type ('x,'a) result = 'a list
350    type ('x,'a) result_exn = 'a
351    let unit a = [a]
352    let bind u f = Util.concat_map f u
353    let run u = u
354    let run_exn u = match u with
355      | [] -> failwith "no values"
356      | [a] -> a
357      | many -> failwith "multiple values"
358    let zero () = []
359    (* satisfies Distrib *)
360    let plus = Util.append
361   end
362   include Monad.Make(Base)
363   (* let either u v = plus u v *)
364   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
365   let rec insert a u =
366     plus (unit (a :: u)) (match u with
367         | [] -> zero ()
368         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
369     )
370   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
371   let rec permute u = match u with
372       | [] -> unit []
373       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
374   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
375   let rec select u = match u with
376     | [] -> zero ()
377     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
378   module T(Wrapped : Monad.S) = struct
379     (* Wrapped.sequence ms  ===  
380          let plus1 u v =
381            Wrapped.bind u (fun x ->
382            Wrapped.bind v (fun xs ->
383            Wrapped.unit (x :: xs)))
384          in Util.fold_right plus1 ms (Wrapped.unit []) *)
385     (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
386     let distribute f alist = Wrapped.sequence (Util.map f alist)
387
388     include Monad.MakeT(struct
389       module Wrapped = Wrapped
390       type ('x,'a) m = ('x,'a list) Wrapped.m
391       type ('x,'a) result = ('x,'a list) Wrapped.result
392       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
393       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
394       let bind u f =
395         Wrapped.bind u (fun ts ->
396         Wrapped.bind (distribute f ts) (fun tts ->
397         Wrapped.unit (Util.concat tts)))
398       let run u = Wrapped.run u
399       let run_exn u =
400         let w = Wrapped.bind u (fun ts -> match ts with
401           | [] -> Wrapped.zero ()
402           | [a] -> Wrapped.unit a
403           | many -> Wrapped.zero ()
404         ) in Wrapped.run_exn w
405       let zero () = Wrapped.unit []
406       let plus u v =
407         Wrapped.bind u (fun us ->
408         Wrapped.bind v (fun vs ->
409         Wrapped.unit (Base.plus us vs)))
410     end)
411 (*
412     let permute : 'a m -> 'a m m
413     let select : 'a m -> ('a * 'a m) m
414 *)
415   end
416 end
417
418
419 (* must be parameterized on (struct type err = ... end) *)
420 module Error_monad(Err : sig
421   type err
422   exception Exc of err
423   (*
424   val zero : unit -> err
425   val plus : err -> err -> err
426   *)
427 end) : sig
428   (* declare additional operations, while still hiding implementation of type m *)
429   type err = Err.err
430   type 'a error = Error of err | Success of 'a
431   type ('x,'a) result = 'a error
432   type ('x,'a) result_exn = 'a
433   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
434   val throw : err -> ('x,'a) m
435   val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
436   (* ErrorT transformer *)
437   module T : functor (Wrapped : Monad.S) -> sig
438     type ('x,'a) result = ('x,'a) Wrapped.result
439     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
440     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
441     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
442     val throw : err -> ('x,'a) m
443     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
444   end
445 end = struct
446   type err = Err.err
447   type 'a error = Error of err | Success of 'a
448   module Base = struct
449     type ('x,'a) m = 'a error
450     type ('x,'a) result = 'a error
451     type ('x,'a) result_exn = 'a
452     let unit a = Success a
453     let bind u f = match u with
454       | Success a -> f a
455       | Error e -> Error e (* input and output may be of different 'a types *)
456     let run u = u
457     let run_exn u = match u with
458       | Success a -> a
459       | Error e -> raise (Err.Exc e)
460     let zero () = Util.undef
461     (* satisfies Catch *)
462     let plus u v = match u with
463       | Success _ -> u
464       | Error _ -> if v == Util.undef then u else v
465   end
466   include Monad.Make(Base)
467   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
468   let throw e = Error e
469   let catch u handler = match u with
470     | Success _ -> u
471     | Error e -> handler e
472   module T(Wrapped : Monad.S) = struct
473     include Monad.MakeT(struct
474       module Wrapped = Wrapped
475       type ('x,'a) m = ('x,'a error) Wrapped.m
476       type ('x,'a) result = ('x,'a) Wrapped.result
477       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
478       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
479       let bind u f = Wrapped.bind u (fun t -> match t with
480         | Success a -> f a
481         | Error e -> Wrapped.unit (Error e))
482       let run u =
483         let w = Wrapped.bind u (fun t -> match t with
484           | Success a -> Wrapped.unit a
485           | Error e -> Wrapped.zero ()
486         ) in Wrapped.run w
487       let run_exn u =
488         let w = Wrapped.bind u (fun t -> match t with
489           | Success a -> Wrapped.unit a
490           | Error e -> raise (Err.Exc e))
491         in Wrapped.run_exn w
492       let plus u v = Wrapped.plus u v
493       let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *)
494     end)
495     let throw e = Wrapped.unit (Error e)
496     let catch u handler = Wrapped.bind u (fun t -> match t with
497       | Success _ -> Wrapped.unit t
498       | Error e -> handler e)
499   end
500 end
501
502 (* pre-define common instance of Error_monad *)
503 module Failure = Error_monad(struct
504   type err = string
505   exception Exc = Failure
506   (*
507   let zero = ""
508   let plus s1 s2 = s1 ^ "\n" ^ s2
509   *)
510 end)
511
512
513 (* must be parameterized on (struct type env = ... end) *)
514 module Reader_monad(Env : sig type env end) : sig
515   (* declare additional operations, while still hiding implementation of type m *)
516   type env = Env.env
517   type ('x,'a) result = env -> 'a
518   type ('x,'a) result_exn = env -> 'a
519   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
520   val ask : ('x,env) m
521   val asks : (env -> 'a) -> ('x,'a) m
522   (* lookup i == `fun e -> e i` would assume env is a functional type *)
523   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
524   (* ReaderT transformer *)
525   module T : functor (Wrapped : Monad.S) -> sig
526     type ('x,'a) result = env -> ('x,'a) Wrapped.result
527     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
528     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
529     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
530     val ask : ('x,env) m
531     val asks : (env -> 'a) -> ('x,'a) m
532     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
533   end
534 end = struct
535   type env = Env.env
536   module Base = struct
537     type ('x,'a) m = env -> 'a
538     type ('x,'a) result = env -> 'a
539     type ('x,'a) result_exn = env -> 'a
540     let unit a = fun e -> a
541     let bind u f = fun e -> let a = u e in let u' = f a in u' e
542     let run u = fun e -> u e
543     let run_exn = run
544     let zero () = Util.undef
545     let plus u v = u
546   end
547   include Monad.Make(Base)
548   let ask = fun e -> e
549   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
550   let local modifier u = fun e -> u (modifier e)
551   module T(Wrapped : Monad.S) = struct
552     module BaseT = struct
553       module Wrapped = Wrapped
554       type ('x,'a) m = env -> ('x,'a) Wrapped.m
555       type ('x,'a) result = env -> ('x,'a) Wrapped.result
556       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
557       let elevate w = fun e -> w
558       let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e)
559       let run u = fun e -> Wrapped.run (u e)
560       let run_exn u = fun e -> Wrapped.run_exn (u e)
561       (* satisfies Distrib *)
562       let plus u v = fun e -> Wrapped.plus (u e) (v e)
563       let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
564     end
565     include Monad.MakeT(BaseT)
566     let ask = Wrapped.unit
567     let local modifier u = fun e -> u (modifier e)
568     let asks selector = ask >>= (fun e ->
569       try unit (selector e)
570       with Not_found -> fun e -> Wrapped.zero ())
571   end
572 end
573
574
575 (* must be parameterized on (struct type store = ... end) *)
576 module State_monad(Store : sig type store end) : sig
577   (* declare additional operations, while still hiding implementation of type m *)
578   type store = Store.store
579   type ('x,'a) result =  store -> 'a * store
580   type ('x,'a) result_exn = store -> 'a
581   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
582   val get : ('x,store) m
583   val gets : (store -> 'a) -> ('x,'a) m
584   val put : store -> ('x,unit) m
585   val puts : (store -> store) -> ('x,unit) m
586   (* StateT transformer *)
587   module T : functor (Wrapped : Monad.S) -> sig
588     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
589     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
590     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
591     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
592     val get : ('x,store) m
593     val gets : (store -> 'a) -> ('x,'a) m
594     val put : store -> ('x,unit) m
595     val puts : (store -> store) -> ('x,unit) m
596   end
597 end = struct
598   type store = Store.store
599   module Base = struct
600     type ('x,'a) m =  store -> 'a * store
601     type ('x,'a) result =  store -> 'a * store
602     type ('x,'a) result_exn = store -> 'a
603     let unit a = fun s -> (a, s)
604     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
605     let run u = fun s -> (u s)
606     let run_exn u = fun s -> fst (u s)
607     let zero () = Util.undef
608     let plus u v = u
609   end
610   include Monad.Make(Base)
611   let get = fun s -> (s, s)
612   let gets viewer = fun s -> (viewer s, s) (* may fail *)
613   let put s = fun _ -> ((), s)
614   let puts modifier = fun s -> ((), modifier s)
615   module T(Wrapped : Monad.S) = struct
616     module BaseT = struct
617       module Wrapped = Wrapped
618       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
619       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
620       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
621       let elevate w = fun s ->
622         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
623       let bind u f = fun s ->
624         Wrapped.bind (u s) (fun (a, s') -> f a s')
625       let run u = fun s -> Wrapped.run (u s)
626       let run_exn u = fun s ->
627         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
628         in Wrapped.run_exn w
629       (* satisfies Distrib *)
630       let plus u v = fun s -> Wrapped.plus (u s) (v s)
631       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
632     end
633     include Monad.MakeT(BaseT)
634     let get = fun s -> Wrapped.unit (s, s)
635     let gets viewer = fun s ->
636       try Wrapped.unit (viewer s, s)
637       with Not_found -> Wrapped.zero ()
638     let put s = fun _ -> Wrapped.unit ((), s)
639     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
640   end
641 end
642
643
644 (* State monad with different interface (structured store) *)
645 module Ref_monad(V : sig
646   type value
647 end) : sig
648   type ref
649   type value = V.value
650   type ('x,'a) result = 'a
651   type ('x,'a) result_exn = 'a
652   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
653   val newref : value -> ('x,ref) m
654   val deref : ref -> ('x,value) m
655   val change : ref -> value -> ('x,unit) m
656   (* RefT transformer *)
657   module T : functor (Wrapped : Monad.S) -> sig
658     type ('x,'a) result = ('x,'a) Wrapped.result
659     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
660     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
661     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
662     val newref : value -> ('x,ref) m
663     val deref : ref -> ('x,value) m
664     val change : ref -> value -> ('x,unit) m
665   end
666 end = struct
667   type ref = int
668   type value = V.value
669   module D = Map.Make(struct type t = ref let compare = compare end)
670   type dict = { next: ref; tree : value D.t }
671   let empty = { next = 0; tree = D.empty }
672   let alloc (value : value) (d : dict) =
673     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
674   let read (key : ref) (d : dict) =
675     D.find key d.tree
676   let write (key : ref) (value : value) (d : dict) =
677     { next = d.next; tree = D.add key value d.tree }
678   module Base = struct
679     type ('x,'a) m = dict -> 'a * dict
680     type ('x,'a) result = 'a
681     type ('x,'a) result_exn = 'a
682     let unit a = fun s -> (a, s)
683     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
684     let run u = fst (u empty)
685     let run_exn = run
686     let zero () = Util.undef
687     let plus u v = u
688   end
689   include Monad.Make(Base)
690   let newref value = fun s -> alloc value s
691   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
692   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
693   module T(Wrapped : Monad.S) = struct
694     module BaseT = struct
695       module Wrapped = Wrapped
696       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
697       type ('x,'a) result = ('x,'a) Wrapped.result
698       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
699       let elevate w = fun s ->
700         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
701       let bind u f = fun s ->
702         Wrapped.bind (u s) (fun (a, s') -> f a s')
703       let run u =
704         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
705         in Wrapped.run w
706       let run_exn u =
707         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
708         in Wrapped.run_exn w
709       (* satisfies Distrib *)
710       let plus u v = fun s -> Wrapped.plus (u s) (v s)
711       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
712     end
713     include Monad.MakeT(BaseT)
714     let newref value = fun s -> Wrapped.unit (alloc value s)
715     let deref key = fun s -> Wrapped.unit (read key s, s)
716     let change key value = fun s -> Wrapped.unit ((), write key value s)
717   end
718 end
719
720
721 (* must be parameterized on (struct type log = ... end) *)
722 module Writer_monad(Log : sig
723   type log
724   val zero : log
725   val plus : log -> log -> log
726 end) : sig
727   (* declare additional operations, while still hiding implementation of type m *)
728   type log = Log.log
729   type ('x,'a) result = 'a * log
730   type ('x,'a) result_exn = 'a * log
731   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
732   val tell : log -> ('x,unit) m
733   val listen : ('x,'a) m -> ('x,'a * log) m
734   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
735   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
736   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
737   (* WriterT transformer *)
738   module T : functor (Wrapped : Monad.S) -> sig
739     type ('x,'a) result = ('x,'a * log) Wrapped.result
740     type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
741     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
742     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
743     val tell : log -> ('x,unit) m
744     val listen : ('x,'a) m -> ('x,'a * log) m
745     val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
746     val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
747   end
748 end = struct
749   type log = Log.log
750   module Base = struct
751     type ('x,'a) m = 'a * log
752     type ('x,'a) result = 'a * log
753     type ('x,'a) result_exn = 'a * log
754     let unit a = (a, Log.zero)
755     let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w')
756     let run u = u
757     let run_exn = run
758     let zero () = Util.undef
759     let plus u v = u
760   end
761   include Monad.Make(Base)
762   let tell entries = ((), entries) (* add entries to log *)
763   let listen (a, w) = ((a, w), w)
764   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
765   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
766   let censor f u = pass (u >>= fun a -> unit (a, f))
767   module T(Wrapped : Monad.S) = struct
768     module BaseT = struct
769       module Wrapped = Wrapped
770       type ('x,'a) m = ('x,'a * log) Wrapped.m
771       type ('x,'a) result = ('x,'a * log) Wrapped.result
772       type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
773       let elevate w =
774         Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero))
775       let bind u f =
776         Wrapped.bind u (fun (a, w) ->
777         Wrapped.bind (f a) (fun (b, w') ->
778         Wrapped.unit (b, Log.plus w w')))
779       let zero () = elevate (Wrapped.zero ())
780       let plus u v = Wrapped.plus u v
781       let run u = Wrapped.run u
782       let run_exn u = Wrapped.run_exn u
783     end
784     include Monad.MakeT(BaseT)
785     let tell entries = Wrapped.unit ((), entries)
786     let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w))
787     let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w))
788     (* rest are derived in same way as before *)
789     let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w)
790     let censor f u = pass (u >>= fun a -> unit (a, f))
791   end
792 end
793
794 (* pre-define simple Writer *)
795 module Writer1 = Writer_monad(struct
796   type log = string
797   let zero = ""
798   let plus s1 s2 = s1 ^ "\n" ^ s2
799 end)
800
801 (* slightly more efficient Writer *)
802 module Writer2 = struct
803   include Writer_monad(struct
804     type log = string list
805     let zero = []
806     let plus w w' = Util.append w' w
807   end)
808   let tell_string s = tell [s]
809   let tell entries = tell (Util.reverse entries)
810   let run u = let (a, w) = run u in (a, Util.reverse w)
811   let run_exn = run
812 end
813
814
815 (* TODO needs a T *)
816 module IO_monad : sig
817   (* declare additional operation, while still hiding implementation of type m *)
818   type ('x,'a) result = 'a
819   type ('x,'a) result_exn = 'a
820   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
821   val printf : ('a, unit, string, ('x,unit) m) format4 -> 'a
822   val print_string : string -> ('x,unit) m
823   val print_int : int -> ('x,unit) m
824   val print_hex : int -> ('x,unit) m
825   val print_bool : bool -> ('x,unit) m
826 end = struct
827   module Base = struct
828     type ('x,'a) m = { run : unit -> unit; value : 'a }
829     type ('x,'a) result = 'a
830     type ('x,'a) result_exn = 'a
831     let unit a = { run = (fun () -> ()); value = a }
832     let bind (a : ('x,'a) m) (f: 'a -> ('x,'b) m) : ('x,'b) m =
833      let fres = f a.value in
834        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
835     let run a = let () = a.run () in a.value
836     let run_exn = run
837     let zero () = Util.undef
838     let plus u v = u
839   end
840   include Monad.Make(Base)
841   let printf fmt =
842     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
843   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
844   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
845   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
846   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
847 end
848
849
850 module Continuation_monad : sig
851   (* expose only the implementation of type `('r,'a) result` *)
852   type ('r,'a) m
853   type ('r,'a) result = ('r,'a) m
854   type ('r,'a) result_exn = ('a -> 'r) -> 'r
855   include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
856   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
857   val reset : ('a,'a) m -> ('r,'a) m
858   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
859   (* val abort : ('a,'a) m -> ('a,'b) m *)
860   val abort : 'a -> ('a,'b) m
861   val run0 : ('a,'a) m -> 'a
862   (* ContinuationT transformer *)
863   module T : functor (Wrapped : Monad.S) -> sig
864     type ('r,'a) m
865     type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
866     type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
867     include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
868     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
869     val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
870     (* TODO: reset,shift,abort,run0 *)
871   end
872 end = struct
873   let id = fun i -> i
874   module Base = struct
875     (* 'r is result type of whole computation *)
876     type ('r,'a) m = ('a -> 'r) -> 'r
877     type ('r,'a) result = ('a -> 'r) -> 'r
878     type ('r,'a) result_exn = ('r,'a) result
879     let unit a = (fun k -> k a)
880     let bind u f = (fun k -> (u) (fun a -> (f a) k))
881     let run u k = (u) k
882     let run_exn = run
883     let zero () = Util.undef
884     let plus u v = u
885   end
886   include Monad.Make(Base)
887   let callcc f = (fun k ->
888     let usek a = (fun _ -> k a)
889     in (f usek) k)
890   (*
891   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
892   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
893   let callcc f = fun k -> f k k
894   let throw k a = fun _ -> k a
895   *)
896
897   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
898    *
899    *  reset :: (Monad m) => ContT a m a -> ContT r m a
900    *  reset e = ContT $ \k -> runContT e return >>= k
901    *
902    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
903    *  shift e = ContT $ \k ->
904    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
905   let reset u = unit ((u) id)
906   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
907   (* let abort a = shift (fun _ -> a) *)
908   let abort a = shift (fun _ -> unit a)
909   let run0 (u : ('a,'a) m) = (u) id
910   module T(Wrapped : Monad.S) = struct
911     module BaseT = struct
912       module Wrapped = Wrapped
913       type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m
914       type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
915       type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
916       let elevate w = fun k -> Wrapped.bind w k
917       let bind u f = fun k -> u (fun a -> f a k)
918       let run u k = Wrapped.run (u k)
919       let run_exn u k = Wrapped.run_exn (u k)
920       let zero () = Util.undef
921       let plus u v = u
922     end
923     include Monad.MakeT(BaseT)
924     let callcc f = (fun k ->
925       let usek a = (fun _ -> k a)
926       in (f usek) k)
927   end
928 end
929
930
931 (*
932  * Scheme:
933  * (define (example n)
934  *    (let ([u (let/cc k ; type int -> int pair
935  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
936  *                 (+ 1 (car v))))]) ; int
937  *      (cons u 0))) ; int pair
938  * ; (example 10) ~~> '(111 . 0)
939  * ; (example -10) ~~> '(0 . 0)
940  *
941  * OCaml monads:
942  * let example n : (int * int) =
943  *   Continuation_monad.(let u = callcc (fun k ->
944  *       (if n < 0 then k 0 else unit [n + 100])
945  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
946  *       >>= fun [x] -> unit (x + 1)
947  *   )
948  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
949  *   >>= fun x -> unit (x, 0)
950  *   in run u)
951  *
952  *)
953
954
955 module Tree_monad : sig
956   (* We implement the type as `'a tree option` because it has a natural`plus`,
957    * and the rest of the library expects that `plus` and `zero` will come together. *)
958   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
959   type ('x,'a) result = 'a tree option
960   type ('x,'a) result_exn = 'a tree
961   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
962   (* TreeT transformer *)
963   module T : functor (Wrapped : Monad.S) -> sig
964     type ('x,'a) result = ('x,'a tree option) Wrapped.result
965     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
966     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
967     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
968     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
969     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
970     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
971   end
972 end = struct
973   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
974   (* uses supplied plus and zero to copy t to its image under f *)
975   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
976       | None -> zero ()
977       | Some ts -> let rec loop ts = (match ts with
978                      | Leaf a -> f a
979                      | Node (l, r) ->
980                          (* recursive application of f may delete a branch *)
981                          plus (loop l) (loop r)
982                    ) in loop ts
983   module Base = struct
984     type ('x,'a) m = 'a tree option
985     type ('x,'a) result = 'a tree option
986     type ('x,'a) result_exn = 'a tree
987     let unit a = Some (Leaf a)
988     let zero () = None
989     (* satisfies Distrib *)
990     let plus u v = match (u, v) with
991       | None, _ -> v
992       | _, None -> u
993       | Some us, Some vs -> Some (Node (us, vs))
994     let bind u f = mapT f u zero plus
995     let run u = u
996     let run_exn u = match u with
997       | None -> failwith "no values"
998       (*
999       | Some (Leaf a) -> a
1000       | many -> failwith "multiple values"
1001       *)
1002       | Some us -> us
1003   end
1004   include Monad.Make(Base)
1005   module T(Wrapped : Monad.S) = struct
1006     module BaseT = struct
1007       include Monad.MakeT(struct
1008         module Wrapped = Wrapped
1009         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1010         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1011         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1012         let zero () = Wrapped.unit None
1013         let plus u v =
1014           Wrapped.bind u (fun us ->
1015           Wrapped.bind v (fun vs ->
1016           Wrapped.unit (Base.plus us vs)))
1017         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1018         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1019         let run u = Wrapped.run u
1020         let run_exn u =
1021             let w = Wrapped.bind u (fun t -> match t with
1022               | None -> Wrapped.zero ()
1023               | Some ts -> Wrapped.unit ts
1024             ) in Wrapped.run_exn w
1025       end)
1026     end
1027     include BaseT
1028     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1029   end
1030 end;;
1031
1032