monads.ml tweak
[lambda.git] / code / monads.ml
1 (*
2  * monads.ml
3  *
4  * Relies on features introduced in OCaml 3.12
5  *
6  * This library uses parameterized modules, see tree_monadize.ml for
7  * more examples and explanation.
8  *
9  * Some comparisons with the Haskell monadic libraries, which we mostly follow:
10  * In Haskell, the Reader 'a monadic type would be defined something like this:
11  *     newtype Reader a = Reader { runReader :: env -> a }
12  * (For simplicity, I'm suppressing the fact that Reader is also parameterized
13  * on the type of env.)
14  * This creates a type wrapper around `env -> a`, so that Haskell will
15  * distinguish between values that have been specifically designated as
16  * being of type `Reader a`, and common-garden values of type `env -> a`.
17  * To lift an aribtrary expression E of type `env -> a` into an `Reader a`,
18  * you do this:
19  *     Reader { runReader = E }
20  * or use any of the following equivalent shorthands:
21  *     Reader (E)
22  *     Reader $ E
23  * To drop an expression R of type `Reader a` back into an `env -> a`, you do
24  * one of these:
25  *     runReader (R)
26  *     runReader $ R
27  * The `newtype` in the type declaration ensures that Haskell does this all
28  * efficiently: though it regards E and R as type-distinct, their underlying
29  * machine implementation is identical and doesn't need to be transformed when
30  * lifting/dropping from one type to the other.
31  *
32  * Now, you _could_ also declare monads as record types in OCaml, too, _but_
33  * doing so would introduce an extra level of machine representation, and
34  * lifting/dropping from the one type to the other wouldn't be free like it is
35  * in Haskell.
36  *
37  * This library encapsulates the monadic types in another way: by
38  * making their implementations private. The interpreter won't let
39  * let you freely interchange the `'a Reader_monad.m`s defined below
40  * with `Reader_monad.env -> 'a`. The code in this library can see that
41  * those are equivalent, but code outside the library can't. Instead, you'll 
42  * have to use operations like `run` to convert the abstract monadic types
43  * to types whose internals you have free access to.
44  *
45  * Acknowledgements: This is largely based on the mtl library distributed
46  * with the Glasgow Haskell Compiler. I've also been helped in
47  * various ways by posts and direct feedback from Oleg Kiselyov and
48  * Chung-chieh Shan. The following were also useful:
49  * - <http://pauillac.inria.fr/~xleroy/mpri/progfunc/>
50  * - Ken Shan "Monads for natural language semantics" <http://arxiv.org/abs/cs/0205026v1>
51  * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf
52  * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers
53  *
54  * Licensing: MIT (if that's compatible with the ghc sources this is partly
55  * derived from)
56  *)
57
58
59 (* Some library functions used below. *)
60
61 exception Undefined
62
63 module Util = struct
64   let fold_right = List.fold_right
65   let map = List.map
66   let append = List.append
67   let reverse = List.rev
68   let concat = List.concat
69   let concat_map f lst = List.concat (List.map f lst)
70   (* let zip = List.combine *)
71   let unzip = List.split
72   let zip_with = List.map2
73   let replicate len fill =
74     let rec loop n accu =
75       if n == 0 then accu else loop (pred n) (fill :: accu)
76     in loop len []
77   (* Dirty hack to be a default polymorphic zero.
78    * To implement this cleanly, monads without a natural zero
79    * should always wrap themselves in an option layer (see Tree_monad). *)
80   let undef = Obj.magic (fun () -> raise Undefined)
81 end
82
83 (*
84  * This module contains factories that extend a base set of
85  * monadic definitions with a larger family of standard derived values.
86  *)
87
88 module Monad = struct
89
90   (*
91    * Signature extenders:
92    *   Make :: BASE -> S
93    *   MakeT :: BASET (with Wrapped : S) -> result sig not declared
94    *)
95
96
97   (* type of base definitions *)
98   module type BASE = sig
99     (* We make all monadic types doubly-parameterized so that they
100      * can layer nicely with Continuation, which needs the second
101      * type parameter. *)
102     type ('x,'a) m
103     type ('x,'a) result
104     type ('x,'a) result_exn
105     val unit : 'a -> ('x,'a) m
106     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
107     val run : ('x,'a) m -> ('x,'a) result
108     (* run_exn tries to provide a more ground-level result, but may fail *)
109     val run_exn : ('x,'a) m -> ('x,'a) result_exn
110     (* To simplify the library, we require every monad to supply a plus and zero. These obey the following laws:
111      *     zero >>= f   ===  zero
112      *     plus zero u  ===  u
113      *     plus u zero  ===  u
114      * Additionally, they will obey one of the following laws:
115      *     (Catch)   plus (unit a) v  ===  unit a
116      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
117      * When no natural zero is available, use `let zero () = Util.undef`.
118      * The Make functor automatically detects for zero >>= ..., and 
119      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
120      *)
121     val zero : unit -> ('x,'a) m
122     (* zero has to be thunked to ensure results are always poly enough *)
123     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
124   end
125   module type S = sig
126     include BASE
127     val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
128     val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m
129     val join : ('x,('x,'a) m) m -> ('x,'a) m
130     val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m
131     val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m
132     val lift2 :  ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m
133     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
134     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
135     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
136     val forever : (unit -> ('x,'a) m) -> ('x,'b) m
137     val sequence : ('x,'a) m list -> ('x,'a list) m
138     val sequence_ : ('x,'a) m list -> ('x,unit) m
139     val guard : bool -> ('x,unit) m
140     val sum : ('x,'a) m list -> ('x,'a) m
141   end
142
143   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
144     include B
145     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
146       if u == Util.undef then Util.undef
147       else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
148     let plus u v =
149       if u == Util.undef then v else if v == Util.undef then u else B.plus u v
150     let run u =
151       if u == Util.undef then raise Undefined else B.run u
152     let run_exn u =
153       if u == Util.undef then raise Undefined else B.run_exn u
154     let (>>=) = bind
155     (* expressions after >> will be evaluated before they're passed to
156      * bind, so you can't do `zero () >> assert false`
157      * this works though: `zero () >>= fun _ -> assert false`
158      *)
159     let (>>) u v = u >>= fun _ -> v
160     let lift f u = u >>= fun a -> unit (f a)
161     (* lift is called listM, fmap, and <$> in Haskell *)
162     let join uu = uu >>= fun u -> u
163     (* u >>= f === join (lift f u) *)
164     let apply u v = u >>= fun f -> v >>= fun a -> unit (f a)
165     (* [f] <*> [x1,x2] = [f x1,f x2] *)
166     (* let apply u v = u >>= fun f -> lift f v *)
167     (* let apply = lift2 id *)
168     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
169     (* let lift f u === apply (unit f) u *)
170     (* let lift2 f u v = apply (lift f u) v *)
171     let (>=>) f g = fun a -> f a >>= g
172     let do_when test u = if test then u else unit ()
173     let do_unless test u = if test then unit () else u
174     (* A Haskell-like version works:
175          let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk
176      * but the recursive call is not in tail position so this can stack overflow. *)
177     let forever uthunk =
178         let z = zero () in
179         let id result = result in
180         let kcell = ref id in
181         let rec loop _ =
182             let result = uthunk (kcell := id) >>= chained
183             in !kcell result
184         and chained _ =
185             kcell := loop; z (* we use z only for its polymorphism *)
186         in loop z
187     (* Reimplementations of the preceding using a hand-rolled State or StateT
188 can also stack overflow. *)
189     let sequence ms =
190       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
191         Util.fold_right op ms (unit [])
192     let sequence_ ms =
193       Util.fold_right (>>) ms (unit ())
194
195     (* Haskell defines these other operations combining lists and monads.
196      * We don't, but notice that M.mapM == ListT(M).distribute
197      * There's also a parallel TreeT(M).distribute *)
198     (*
199     let mapM f alist = sequence (Util.map f alist)
200     let mapM_ f alist = sequence_ (Util.map f alist)
201     let rec filterM f lst = match lst with
202       | [] -> unit []
203       | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys)
204     let forM alist f = mapM f alist
205     let forM_ alist f = mapM_ f alist
206     let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x)
207     let zip_withM f xs ys = sequence (Util.zip_with f xs ys)
208     let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys)
209     let rec foldM f z lst = match lst with
210       | [] -> unit z
211       | x::xs -> f z x >>= fun z' -> foldM f z' xs
212     let foldM_ f z xs = foldM f z xs >> unit ()
213     let replicateM n x = sequence (Util.replicate n x)
214     let replicateM_ n x = sequence_ (Util.replicate n x)
215     *)
216     let guard test = if test then B.unit () else zero ()
217     let sum ms = Util.fold_right plus ms (zero ())
218   end
219
220   (* Signatures for MonadT *)
221   module type BASET = sig
222     module Wrapped : S
223     type ('x,'a) m
224     type ('x,'a) result
225     type ('x,'a) result_exn
226     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
227     val run : ('x,'a) m -> ('x,'a) result
228     val run_exn : ('x,'a) m -> ('x,'a) result_exn
229     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
230     (* lift/elevate laws:
231      *     elevate (W.unit a) == unit a
232      *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
233      *)
234     val zero : unit -> ('x,'a) m
235     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
236   end
237   module MakeT(T : BASET) = struct
238     include Make(struct
239         include T
240         let unit a = elevate (Wrapped.unit a)
241     end)
242     let elevate = T.elevate
243   end
244
245 end
246
247
248
249
250
251 module Identity_monad : sig
252   (* expose only the implementation of type `'a result` *)
253   type ('x,'a) result = 'a
254   type ('x,'a) result_exn = 'a
255   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
256 end = struct
257   module Base = struct
258     type ('x,'a) m = 'a
259     type ('x,'a) result = 'a
260     type ('x,'a) result_exn = 'a
261     let unit a = a
262     let bind a f = f a
263     let run a = a
264     let run_exn a = a
265     let zero () = Util.undef
266     let plus u v = u
267   end
268   include Monad.Make(Base)
269 end
270
271
272 module Maybe_monad : sig
273   (* expose only the implementation of type `'a result` *)
274   type ('x,'a) result = 'a option
275   type ('x,'a) result_exn = 'a
276   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
277   (* MaybeT transformer *)
278   module T : functor (Wrapped : Monad.S) -> sig
279     type ('x,'a) result = ('x,'a option) Wrapped.result
280     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
281     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
282     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
283   end
284 end = struct
285   module Base = struct
286     type ('x,'a) m = 'a option
287     type ('x,'a) result = 'a option
288     type ('x,'a) result_exn = 'a
289     let unit a = Some a
290     let bind u f = match u with Some a -> f a | None -> None
291     let run u = u
292     let run_exn u = match u with
293       | Some a -> a
294       | None -> failwith "no value"
295     let zero () = None
296     (* satisfies Catch *)
297     let plus u v = match u with None -> v | _ -> u
298   end
299   include Monad.Make(Base)
300   module T(Wrapped : Monad.S) = struct
301     module BaseT = struct
302       include Monad.MakeT(struct
303         module Wrapped = Wrapped
304         type ('x,'a) m = ('x,'a option) Wrapped.m
305         type ('x,'a) result = ('x,'a option) Wrapped.result
306         type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
307         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
308         let bind u f = Wrapped.bind u (fun t -> match t with
309           | Some a -> f a
310           | None -> Wrapped.unit None)
311         let run u = Wrapped.run u
312         let run_exn u =
313           let w = Wrapped.bind u (fun t -> match t with
314             | Some a -> Wrapped.unit a
315             | None -> Wrapped.zero ()
316           ) in Wrapped.run_exn w
317         let zero () = Wrapped.unit None
318         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
319       end)
320     end
321     include BaseT
322   end
323 end
324
325
326 module List_monad : sig
327   (* declare additional operation, while still hiding implementation of type m *)
328   type ('x,'a) result = 'a list
329   type ('x,'a) result_exn = 'a
330   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
331   val permute : ('x,'a) m -> ('x,('x,'a) m) m
332   val select : ('x,'a) m -> ('x,'a * ('x,'a) m) m
333   (* ListT transformer *)
334   module T : functor (Wrapped : Monad.S) -> sig
335     type ('x,'a) result = ('x,'a list) Wrapped.result
336     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
337     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
338     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
339     (* note that second argument is an 'a list, not the more abstract 'a m *)
340     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
341     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
342 (* TODO
343     val permute : 'a m -> 'a m m
344     val select : 'a m -> ('a * 'a m) m
345 *)
346     val expose : ('x,'a) m -> ('x,'a list) Wrapped.m
347   end
348 end = struct
349   module Base = struct
350    type ('x,'a) m = 'a list
351    type ('x,'a) result = 'a list
352    type ('x,'a) result_exn = 'a
353    let unit a = [a]
354    let bind u f = Util.concat_map f u
355    let run u = u
356    let run_exn u = match u with
357      | [] -> failwith "no values"
358      | [a] -> a
359      | many -> failwith "multiple values"
360    let zero () = []
361    (* satisfies Distrib *)
362    let plus = Util.append
363   end
364   include Monad.Make(Base)
365   (* let either u v = plus u v *)
366   (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *)
367   let rec insert a u =
368     plus (unit (a :: u)) (match u with
369         | [] -> zero ()
370         | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v)
371     )
372   (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *)
373   let rec permute u = match u with
374       | [] -> unit []
375       | x :: xs -> (permute xs) >>= (fun v -> insert x v)
376   (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *)
377   let rec select u = match u with
378     | [] -> zero ()
379     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
380   module T(Wrapped : Monad.S) = struct
381     (* Wrapped.sequence ms  ===  
382          let plus1 u v =
383            Wrapped.bind u (fun x ->
384            Wrapped.bind v (fun xs ->
385            Wrapped.unit (x :: xs)))
386          in Util.fold_right plus1 ms (Wrapped.unit []) *)
387     (* distribute  ===  Wrapped.mapM; copies alist to its image under f *)
388     let distribute f alist = Wrapped.sequence (Util.map f alist)
389
390     include Monad.MakeT(struct
391       module Wrapped = Wrapped
392       type ('x,'a) m = ('x,'a list) Wrapped.m
393       type ('x,'a) result = ('x,'a list) Wrapped.result
394       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
395       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
396       let bind u f =
397         Wrapped.bind u (fun ts ->
398         Wrapped.bind (distribute f ts) (fun tts ->
399         Wrapped.unit (Util.concat tts)))
400       let run u = Wrapped.run u
401       let run_exn u =
402         let w = Wrapped.bind u (fun ts -> match ts with
403           | [] -> Wrapped.zero ()
404           | [a] -> Wrapped.unit a
405           | many -> Wrapped.zero ()
406         ) in Wrapped.run_exn w
407       let zero () = Wrapped.unit []
408       let plus u v =
409         Wrapped.bind u (fun us ->
410         Wrapped.bind v (fun vs ->
411         Wrapped.unit (Base.plus us vs)))
412     end)
413 (*
414     let permute : 'a m -> 'a m m
415     let select : 'a m -> ('a * 'a m) m
416 *)
417     let expose u = u
418   end
419 end
420
421
422 (* must be parameterized on (struct type err = ... end) *)
423 module Error_monad(Err : sig
424   type err
425   exception Exc of err
426   (*
427   val zero : unit -> err
428   val plus : err -> err -> err
429   *)
430 end) : sig
431   (* declare additional operations, while still hiding implementation of type m *)
432   type err = Err.err
433   type 'a error = Error of err | Success of 'a
434   type ('x,'a) result = 'a error
435   type ('x,'a) result_exn = 'a
436   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
437   val throw : err -> ('x,'a) m
438   val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
439   (* ErrorT transformer *)
440   module T : functor (Wrapped : Monad.S) -> sig
441     type ('x,'a) result = ('x,'a) Wrapped.result
442     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
443     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
444     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
445     val throw : err -> ('x,'a) m
446     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
447   end
448 end = struct
449   type err = Err.err
450   type 'a error = Error of err | Success of 'a
451   module Base = struct
452     type ('x,'a) m = 'a error
453     type ('x,'a) result = 'a error
454     type ('x,'a) result_exn = 'a
455     let unit a = Success a
456     let bind u f = match u with
457       | Success a -> f a
458       | Error e -> Error e (* input and output may be of different 'a types *)
459     let run u = u
460     let run_exn u = match u with
461       | Success a -> a
462       | Error e -> raise (Err.Exc e)
463     let zero () = Util.undef
464     (* satisfies Catch *)
465     let plus u v = match u with
466       | Success _ -> u
467       | Error _ -> if v == Util.undef then u else v
468   end
469   include Monad.Make(Base)
470   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
471   let throw e = Error e
472   let catch u handler = match u with
473     | Success _ -> u
474     | Error e -> handler e
475   module T(Wrapped : Monad.S) = struct
476     include Monad.MakeT(struct
477       module Wrapped = Wrapped
478       type ('x,'a) m = ('x,'a error) Wrapped.m
479       type ('x,'a) result = ('x,'a) Wrapped.result
480       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
481       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
482       let bind u f = Wrapped.bind u (fun t -> match t with
483         | Success a -> f a
484         | Error e -> Wrapped.unit (Error e))
485       let run u =
486         let w = Wrapped.bind u (fun t -> match t with
487           | Success a -> Wrapped.unit a
488           | Error e -> Wrapped.zero ()
489         ) in Wrapped.run w
490       let run_exn u =
491         let w = Wrapped.bind u (fun t -> match t with
492           | Success a -> Wrapped.unit a
493           | Error e -> raise (Err.Exc e))
494         in Wrapped.run_exn w
495       let plus u v = Wrapped.plus u v
496       let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *)
497     end)
498     let throw e = Wrapped.unit (Error e)
499     let catch u handler = Wrapped.bind u (fun t -> match t with
500       | Success _ -> Wrapped.unit t
501       | Error e -> handler e)
502   end
503 end
504
505 (* pre-define common instance of Error_monad *)
506 module Failure = Error_monad(struct
507   type err = string
508   exception Exc = Failure
509   (*
510   let zero = ""
511   let plus s1 s2 = s1 ^ "\n" ^ s2
512   *)
513 end)
514
515
516 (* must be parameterized on (struct type env = ... end) *)
517 module Reader_monad(Env : sig type env end) : sig
518   (* declare additional operations, while still hiding implementation of type m *)
519   type env = Env.env
520   type ('x,'a) result = env -> 'a
521   type ('x,'a) result_exn = env -> 'a
522   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
523   val ask : ('x,env) m
524   val asks : (env -> 'a) -> ('x,'a) m
525   (* lookup i == `fun e -> e i` would assume env is a functional type *)
526   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
527   (* ReaderT transformer *)
528   module T : functor (Wrapped : Monad.S) -> sig
529     type ('x,'a) result = env -> ('x,'a) Wrapped.result
530     type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
531     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
532     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
533     val ask : ('x,env) m
534     val asks : (env -> 'a) -> ('x,'a) m
535     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
536     val expose : ('x,'a) m -> env -> ('x,'a) Wrapped.m
537   end
538 end = struct
539   type env = Env.env
540   module Base = struct
541     type ('x,'a) m = env -> 'a
542     type ('x,'a) result = env -> 'a
543     type ('x,'a) result_exn = env -> 'a
544     let unit a = fun e -> a
545     let bind u f = fun e -> let a = u e in let u' = f a in u' e
546     let run u = fun e -> u e
547     let run_exn = run
548     let zero () = Util.undef
549     let plus u v = u
550   end
551   include Monad.Make(Base)
552   let ask = fun e -> e
553   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
554   let local modifier u = fun e -> u (modifier e)
555   module T(Wrapped : Monad.S) = struct
556     module BaseT = struct
557       module Wrapped = Wrapped
558       type ('x,'a) m = env -> ('x,'a) Wrapped.m
559       type ('x,'a) result = env -> ('x,'a) Wrapped.result
560       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
561       let elevate w = fun e -> w
562       let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e)
563       let run u = fun e -> Wrapped.run (u e)
564       let run_exn u = fun e -> Wrapped.run_exn (u e)
565       (* satisfies Distrib *)
566       let plus u v = fun e -> Wrapped.plus (u e) (v e)
567       let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
568     end
569     include Monad.MakeT(BaseT)
570     let ask = Wrapped.unit
571     let local modifier u = fun e -> u (modifier e)
572     let asks selector = ask >>= (fun e ->
573       try unit (selector e)
574       with Not_found -> fun e -> Wrapped.zero ())
575     let expose u = u
576   end
577 end
578
579
580 (* must be parameterized on (struct type store = ... end) *)
581 module State_monad(Store : sig type store end) : sig
582   (* declare additional operations, while still hiding implementation of type m *)
583   type store = Store.store
584   type ('x,'a) result =  store -> 'a * store
585   type ('x,'a) result_exn = store -> 'a
586   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
587   val get : ('x,store) m
588   val gets : (store -> 'a) -> ('x,'a) m
589   val put : store -> ('x,unit) m
590   val puts : (store -> store) -> ('x,unit) m
591   (* StateT transformer *)
592   module T : functor (Wrapped : Monad.S) -> sig
593     type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
594     type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
595     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
596     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
597     val get : ('x,store) m
598     val gets : (store -> 'a) -> ('x,'a) m
599     val put : store -> ('x,unit) m
600     val puts : (store -> store) -> ('x,unit) m
601     (* val passthru : ('x,'a) m -> (('x,'a * store) Wrapped.result * store -> 'b) -> ('x,'b) m *)
602     val expose : ('x,'a) m -> store -> ('x,'a * store) Wrapped.m
603   end
604 end = struct
605   type store = Store.store
606   module Base = struct
607     type ('x,'a) m =  store -> 'a * store
608     type ('x,'a) result =  store -> 'a * store
609     type ('x,'a) result_exn = store -> 'a
610     let unit a = fun s -> (a, s)
611     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
612     let run u = fun s -> (u s)
613     let run_exn u = fun s -> fst (u s)
614     let zero () = Util.undef
615     let plus u v = u
616   end
617   include Monad.Make(Base)
618   let get = fun s -> (s, s)
619   let gets viewer = fun s -> (viewer s, s) (* may fail *)
620   let put s = fun _ -> ((), s)
621   let puts modifier = fun s -> ((), modifier s)
622   module T(Wrapped : Monad.S) = struct
623     module BaseT = struct
624       module Wrapped = Wrapped
625       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
626       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
627       type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
628       let elevate w = fun s ->
629         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
630       let bind u f = fun s ->
631         Wrapped.bind (u s) (fun (a, s') -> f a s')
632       let run u = fun s -> Wrapped.run (u s)
633       let run_exn u = fun s ->
634         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
635         in Wrapped.run_exn w
636       (* satisfies Distrib *)
637       let plus u v = fun s -> Wrapped.plus (u s) (v s)
638       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
639     end
640     include Monad.MakeT(BaseT)
641     let get = fun s -> Wrapped.unit (s, s)
642     let gets viewer = fun s ->
643       try Wrapped.unit (viewer s, s)
644       with Not_found -> Wrapped.zero ()
645     let put s = fun _ -> Wrapped.unit ((), s)
646     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
647     (* let passthru u f = fun s -> Wrapped.unit (f (Wrapped.run (u s), s), s) *)
648     let expose u = u
649   end
650 end
651
652
653 (* State monad with different interface (structured store) *)
654 module Ref_monad(V : sig
655   type value
656 end) : sig
657   type ref
658   type value = V.value
659   type ('x,'a) result = 'a
660   type ('x,'a) result_exn = 'a
661   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
662   val newref : value -> ('x,ref) m
663   val deref : ref -> ('x,value) m
664   val change : ref -> value -> ('x,unit) m
665   (* RefT transformer *)
666   module T : functor (Wrapped : Monad.S) -> sig
667     type ('x,'a) result = ('x,'a) Wrapped.result
668     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
669     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
670     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
671     val newref : value -> ('x,ref) m
672     val deref : ref -> ('x,value) m
673     val change : ref -> value -> ('x,unit) m
674   end
675 end = struct
676   type ref = int
677   type value = V.value
678   module D = Map.Make(struct type t = ref let compare = compare end)
679   type dict = { next: ref; tree : value D.t }
680   let empty = { next = 0; tree = D.empty }
681   let alloc (value : value) (d : dict) =
682     (d.next, { next = succ d.next; tree = D.add d.next value d.tree })
683   let read (key : ref) (d : dict) =
684     D.find key d.tree
685   let write (key : ref) (value : value) (d : dict) =
686     { next = d.next; tree = D.add key value d.tree }
687   module Base = struct
688     type ('x,'a) m = dict -> 'a * dict
689     type ('x,'a) result = 'a
690     type ('x,'a) result_exn = 'a
691     let unit a = fun s -> (a, s)
692     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
693     let run u = fst (u empty)
694     let run_exn = run
695     let zero () = Util.undef
696     let plus u v = u
697   end
698   include Monad.Make(Base)
699   let newref value = fun s -> alloc value s
700   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
701   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
702   module T(Wrapped : Monad.S) = struct
703     module BaseT = struct
704       module Wrapped = Wrapped
705       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
706       type ('x,'a) result = ('x,'a) Wrapped.result
707       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
708       let elevate w = fun s ->
709         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
710       let bind u f = fun s ->
711         Wrapped.bind (u s) (fun (a, s') -> f a s')
712       let run u =
713         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
714         in Wrapped.run w
715       let run_exn u =
716         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
717         in Wrapped.run_exn w
718       (* satisfies Distrib *)
719       let plus u v = fun s -> Wrapped.plus (u s) (v s)
720       let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
721     end
722     include Monad.MakeT(BaseT)
723     let newref value = fun s -> Wrapped.unit (alloc value s)
724     let deref key = fun s -> Wrapped.unit (read key s, s)
725     let change key value = fun s -> Wrapped.unit ((), write key value s)
726   end
727 end
728
729
730 (* must be parameterized on (struct type log = ... end) *)
731 module Writer_monad(Log : sig
732   type log
733   val zero : log
734   val plus : log -> log -> log
735 end) : sig
736   (* declare additional operations, while still hiding implementation of type m *)
737   type log = Log.log
738   type ('x,'a) result = 'a * log
739   type ('x,'a) result_exn = 'a * log
740   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
741   val tell : log -> ('x,unit) m
742   val listen : ('x,'a) m -> ('x,'a * log) m
743   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
744   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
745   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
746   (* WriterT transformer *)
747   module T : functor (Wrapped : Monad.S) -> sig
748     type ('x,'a) result = ('x,'a * log) Wrapped.result
749     type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
750     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
751     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
752     val tell : log -> ('x,unit) m
753     val listen : ('x,'a) m -> ('x,'a * log) m
754     val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
755     val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
756   end
757 end = struct
758   type log = Log.log
759   module Base = struct
760     type ('x,'a) m = 'a * log
761     type ('x,'a) result = 'a * log
762     type ('x,'a) result_exn = 'a * log
763     let unit a = (a, Log.zero)
764     let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w')
765     let run u = u
766     let run_exn = run
767     let zero () = Util.undef
768     let plus u v = u
769   end
770   include Monad.Make(Base)
771   let tell entries = ((), entries) (* add entries to log *)
772   let listen (a, w) = ((a, w), w)
773   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
774   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
775   let censor f u = pass (u >>= fun a -> unit (a, f))
776   module T(Wrapped : Monad.S) = struct
777     module BaseT = struct
778       module Wrapped = Wrapped
779       type ('x,'a) m = ('x,'a * log) Wrapped.m
780       type ('x,'a) result = ('x,'a * log) Wrapped.result
781       type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
782       let elevate w =
783         Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero))
784       let bind u f =
785         Wrapped.bind u (fun (a, w) ->
786         Wrapped.bind (f a) (fun (b, w') ->
787         Wrapped.unit (b, Log.plus w w')))
788       let zero () = elevate (Wrapped.zero ())
789       let plus u v = Wrapped.plus u v
790       let run u = Wrapped.run u
791       let run_exn u = Wrapped.run_exn u
792     end
793     include Monad.MakeT(BaseT)
794     let tell entries = Wrapped.unit ((), entries)
795     let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w))
796     let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w))
797     (* rest are derived in same way as before *)
798     let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w)
799     let censor f u = pass (u >>= fun a -> unit (a, f))
800   end
801 end
802
803 (* pre-define simple Writer *)
804 module Writer1 = Writer_monad(struct
805   type log = string
806   let zero = ""
807   let plus s1 s2 = s1 ^ "\n" ^ s2
808 end)
809
810 (* slightly more efficient Writer *)
811 module Writer2 = struct
812   include Writer_monad(struct
813     type log = string list
814     let zero = []
815     let plus w w' = Util.append w' w
816   end)
817   let tell_string s = tell [s]
818   let tell entries = tell (Util.reverse entries)
819   let run u = let (a, w) = run u in (a, Util.reverse w)
820   let run_exn = run
821 end
822
823
824 (* TODO needs a T *)
825 module IO_monad : sig
826   (* declare additional operation, while still hiding implementation of type m *)
827   type ('x,'a) result = 'a
828   type ('x,'a) result_exn = 'a
829   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
830   val printf : ('a, unit, string, ('x,unit) m) format4 -> 'a
831   val print_string : string -> ('x,unit) m
832   val print_int : int -> ('x,unit) m
833   val print_hex : int -> ('x,unit) m
834   val print_bool : bool -> ('x,unit) m
835 end = struct
836   module Base = struct
837     type ('x,'a) m = { run : unit -> unit; value : 'a }
838     type ('x,'a) result = 'a
839     type ('x,'a) result_exn = 'a
840     let unit a = { run = (fun () -> ()); value = a }
841     let bind (a : ('x,'a) m) (f: 'a -> ('x,'b) m) : ('x,'b) m =
842      let fres = f a.value in
843        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
844     let run a = let () = a.run () in a.value
845     let run_exn = run
846     let zero () = Util.undef
847     let plus u v = u
848   end
849   include Monad.Make(Base)
850   let printf fmt =
851     Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt
852   let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () }
853   let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () }
854   let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () }
855   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
856 end
857
858
859 module Continuation_monad : sig
860   (* expose only the implementation of type `('r,'a) result` *)
861   type ('r,'a) m
862   type ('r,'a) result = ('r,'a) m
863   type ('r,'a) result_exn = ('a -> 'r) -> 'r
864   include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
865   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
866   val reset : ('a,'a) m -> ('r,'a) m
867   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
868   (* val abort : ('a,'a) m -> ('a,'b) m *)
869   val abort : 'a -> ('a,'b) m
870   val run0 : ('a,'a) m -> 'a
871   (* ContinuationT transformer *)
872   module T : functor (Wrapped : Monad.S) -> sig
873     type ('r,'a) m
874     type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
875     type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
876     include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
877     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
878     val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
879     (* TODO: reset,shift,abort,run0 *)
880   end
881 end = struct
882   let id = fun i -> i
883   module Base = struct
884     (* 'r is result type of whole computation *)
885     type ('r,'a) m = ('a -> 'r) -> 'r
886     type ('r,'a) result = ('a -> 'r) -> 'r
887     type ('r,'a) result_exn = ('r,'a) result
888     let unit a = (fun k -> k a)
889     let bind u f = (fun k -> (u) (fun a -> (f a) k))
890     let run u k = (u) k
891     let run_exn = run
892     let zero () = Util.undef
893     let plus u v = u
894   end
895   include Monad.Make(Base)
896   let callcc f = (fun k ->
897     let usek a = (fun _ -> k a)
898     in (f usek) k)
899   (*
900   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
901   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
902   let callcc f = fun k -> f k k
903   let throw k a = fun _ -> k a
904   *)
905
906   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
907    *
908    *  reset :: (Monad m) => ContT a m a -> ContT r m a
909    *  reset e = ContT $ \k -> runContT e return >>= k
910    *
911    *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
912    *  shift e = ContT $ \k ->
913    *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
914   let reset u = unit ((u) id)
915   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
916   (* let abort a = shift (fun _ -> a) *)
917   let abort a = shift (fun _ -> unit a)
918   let run0 (u : ('a,'a) m) = (u) id
919   module T(Wrapped : Monad.S) = struct
920     module BaseT = struct
921       module Wrapped = Wrapped
922       type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m
923       type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
924       type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
925       let elevate w = fun k -> Wrapped.bind w k
926       let bind u f = fun k -> u (fun a -> f a k)
927       let run u k = Wrapped.run (u k)
928       let run_exn u k = Wrapped.run_exn (u k)
929       let zero () = Util.undef
930       let plus u v = u
931     end
932     include Monad.MakeT(BaseT)
933     let callcc f = (fun k ->
934       let usek a = (fun _ -> k a)
935       in (f usek) k)
936   end
937 end
938
939
940 (*
941  * Scheme:
942  * (define (example n)
943  *    (let ([u (let/cc k ; type int -> int pair
944  *               (let ([v (if (< n 0) (k 0) (list (+ n 100)))])
945  *                 (+ 1 (car v))))]) ; int
946  *      (cons u 0))) ; int pair
947  * ; (example 10) ~~> '(111 . 0)
948  * ; (example -10) ~~> '(0 . 0)
949  *
950  * OCaml monads:
951  * let example n : (int * int) =
952  *   Continuation_monad.(let u = callcc (fun k ->
953  *       (if n < 0 then k 0 else unit [n + 100])
954  *       (* all of the following is skipped by k 0; the end type int is k's input type *)
955  *       >>= fun [x] -> unit (x + 1)
956  *   )
957  *   (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
958  *   >>= fun x -> unit (x, 0)
959  *   in run u)
960  *
961  *)
962
963
964 module Tree_monad : sig
965   (* We implement the type as `'a tree option` because it has a natural`plus`,
966    * and the rest of the library expects that `plus` and `zero` will come together. *)
967   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
968   type ('x,'a) result = 'a tree option
969   type ('x,'a) result_exn = 'a tree
970   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
971   (* TreeT transformer *)
972   module T : functor (Wrapped : Monad.S) -> sig
973     type ('x,'a) result = ('x,'a tree option) Wrapped.result
974     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
975     include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
976     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
977     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
978     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
979     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
980     val expose : ('x,'a) m -> ('x,'a tree option) Wrapped.m
981   end
982 end = struct
983   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
984   (* uses supplied plus and zero to copy t to its image under f *)
985   let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with
986       | None -> zero ()
987       | Some ts -> let rec loop ts = (match ts with
988                      | Leaf a -> f a
989                      | Node (l, r) ->
990                          (* recursive application of f may delete a branch *)
991                          plus (loop l) (loop r)
992                    ) in loop ts
993   module Base = struct
994     type ('x,'a) m = 'a tree option
995     type ('x,'a) result = 'a tree option
996     type ('x,'a) result_exn = 'a tree
997     let unit a = Some (Leaf a)
998     let zero () = None
999     (* satisfies Distrib *)
1000     let plus u v = match (u, v) with
1001       | None, _ -> v
1002       | _, None -> u
1003       | Some us, Some vs -> Some (Node (us, vs))
1004     let bind u f = mapT f u zero plus
1005     let run u = u
1006     let run_exn u = match u with
1007       | None -> failwith "no values"
1008       (*
1009       | Some (Leaf a) -> a
1010       | many -> failwith "multiple values"
1011       *)
1012       | Some us -> us
1013   end
1014   include Monad.Make(Base)
1015   module T(Wrapped : Monad.S) = struct
1016     module BaseT = struct
1017       include Monad.MakeT(struct
1018         module Wrapped = Wrapped
1019         type ('x,'a) m = ('x,'a tree option) Wrapped.m
1020         type ('x,'a) result = ('x,'a tree option) Wrapped.result
1021         type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
1022         let zero () = Wrapped.unit None
1023         let plus u v =
1024           Wrapped.bind u (fun us ->
1025           Wrapped.bind v (fun vs ->
1026           Wrapped.unit (Base.plus us vs)))
1027         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
1028         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
1029         let run u = Wrapped.run u
1030         let run_exn u =
1031             let w = Wrapped.bind u (fun t -> match t with
1032               | None -> Wrapped.zero ()
1033               | Some ts -> Wrapped.unit ts
1034             ) in Wrapped.run_exn w
1035       end)
1036     end
1037     include BaseT
1038     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
1039     let expose u = u
1040   end
1041
1042 end;;
1043
1044