X-Git-Url: http://lambda.jimpryor.net/git/gitweb.cgi?p=lambda.git;a=blobdiff_plain;f=code%2Fmonads.ml;h=b678a92c28a75c78e257e55a9979268477b09e3e;hp=34ad1cef5771d3ea8caf3c1dc7d1f27bfdc8756a;hb=6620761d57706f681ccf431c19d6d6fd77fa2942;hpb=ef6e90fdda821a1d071c93dc50587d1a0fd207b9 diff --git a/code/monads.ml b/code/monads.ml index 34ad1cef..b678a92c 100644 --- a/code/monads.ml +++ b/code/monads.ml @@ -42,6 +42,17 @@ * have to use operations like `run` to convert the abstract monadic types * to types whose internals you have free access to. * + * Acknowledgements: This is largely based on the mtl library distributed + * with the Glasgow Haskell Compiler. I've also been helped in + * various ways by posts and direct feedback from Oleg Kiselyov and + * Chung-chieh Shan. The following were also useful: + * - + * - Ken Shan "Monads for natural language semantics" + * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf + * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers + * + * Licensing: MIT (if that's compatible with the ghc sources this is partly + * derived from) *) exception Undefined @@ -63,7 +74,7 @@ module Util = struct in loop len [] (* Dirty hack to be a default polymorphic zero. * To implement this cleanly, monads without a natural zero - * should always wrap themselves in an option layer (see Leaf_monad). *) + * should always wrap themselves in an option layer (see Tree_monad). *) let undef = Obj.magic (fun () -> raise Undefined) end @@ -140,6 +151,10 @@ module Monad = struct let run_exn u = if u == Util.undef then raise Undefined else B.run_exn u let (>>=) = bind + (* expressions after >> will be evaluated before they're passed to + * bind, so you can't do `zero () >> assert false` + * this works though: `zero () >>= fun _ -> assert false` + *) let (>>) u v = u >>= fun _ -> v let lift f u = u >>= fun a -> unit (f a) (* lift is called listM, fmap, and <$> in Haskell *) @@ -155,10 +170,21 @@ module Monad = struct let (>=>) f g = fun a -> f a >>= g let do_when test u = if test then u else unit () let do_unless test u = if test then unit () else u - (* not in tail position, will Stack overflow *) + (* A Haskell-like version works: + let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk + * but the recursive call is not in tail position so this can stack overflow. *) let forever uthunk = - let rec loop () = uthunk () >>= fun _ -> loop () - in loop () + let z = zero () in + let id result = result in + let kcell = ref id in + let rec loop _ = + let result = uthunk (kcell := id) >>= chained + in !kcell result + and chained _ = + kcell := loop; z (* we use z only for its polymorphism *) + in loop z + (* Reimplementations of the preceding using a hand-rolled State or StateT +can also stack overflow. *) let sequence ms = let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in Util.fold_right op ms (unit []) @@ -432,18 +458,10 @@ end = struct | Success a -> a | Error e -> raise (Err.Exc e) let zero () = Util.undef - let plus u v = u - (* - let zero () = Error Err.zero - let plus u v = match (u, v) with - | Success _, _ -> u - (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _ - * otherwise, plus (Error _) v = v *) - | Error _, _ when v = zero -> u - (* combine errors *) - | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2) - | Error _, _ -> v - *) + (* satisfies Catch *) + let plus u v = match u with + | Success _ -> u + | Error _ -> if v == Util.undef then u else v end include Monad.Make(Base) (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *) @@ -472,7 +490,7 @@ end = struct | Error e -> raise (Err.Exc e)) in Wrapped.run_exn w let plus u v = Wrapped.plus u v - let zero () = elevate (Wrapped.zero ()) + let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *) end) let throw e = Wrapped.unit (Error e) let catch u handler = Wrapped.bind u (fun t -> match t with @@ -491,26 +509,6 @@ module Failure = Error_monad(struct *) end) -(* -# EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));; -- : int EL.result = [Failure.Error "bye"; Failure.Success 30] -# LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));; -- : int LE.result = Failure.Error "bye" -# EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));; -Exception: Failure "bye". -# LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));; -Exception: Failure "bye". - -# ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;; -- : int Failure.error * S.store = (Failure.Error "bye", 1) -# SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;; -- : (int * S.store) Failure.result = Failure.Error "bye" -# ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;; -Exception: Failure "bye". -# SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;; -Exception: Failure "bye". - *) - (* must be parameterized on (struct type env = ... end) *) module Reader_monad(Env : sig type env end) : sig @@ -557,15 +555,15 @@ end = struct type ('x,'a) result = env -> ('x,'a) Wrapped.result type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn let elevate w = fun e -> w - let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e) + let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e) let run u = fun e -> Wrapped.run (u e) let run_exn u = fun e -> Wrapped.run_exn (u e) (* satisfies Distrib *) - let plus u v = fun s -> Wrapped.plus (u s) (v s) - let zero () = elevate (Wrapped.zero ()) + let plus u v = fun e -> Wrapped.plus (u e) (v e) + let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *) end include Monad.MakeT(BaseT) - let ask = fun e -> Wrapped.unit e + let ask = Wrapped.unit let local modifier u = fun e -> u (modifier e) let asks selector = ask >>= (fun e -> try unit (selector e) @@ -630,7 +628,7 @@ end = struct in Wrapped.run_exn w (* satisfies Distrib *) let plus u v = fun s -> Wrapped.plus (u s) (v s) - let zero () = elevate (Wrapped.zero ()) + let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *) end include Monad.MakeT(BaseT) let get = fun s -> Wrapped.unit (s, s) @@ -642,6 +640,7 @@ end = struct end end + (* State monad with different interface (structured store) *) module Ref_monad(V : sig type value @@ -709,7 +708,7 @@ end = struct in Wrapped.run_exn w (* satisfies Distrib *) let plus u v = fun s -> Wrapped.plus (u s) (v s) - let zero () = elevate (Wrapped.zero ()) + let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *) end include Monad.MakeT(BaseT) let newref value = fun s -> Wrapped.unit (alloc value s) @@ -718,7 +717,7 @@ end = struct end end -(* TODO needs a T *) + (* must be parameterized on (struct type log = ... end) *) module Writer_monad(Log : sig type log @@ -735,6 +734,17 @@ end) : sig val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *) val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m + (* WriterT transformer *) + module T : functor (Wrapped : Monad.S) -> sig + type ('x,'a) result = ('x,'a * log) Wrapped.result + type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn + include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn + val elevate : ('x,'a) Wrapped.m -> ('x,'a) m + val tell : log -> ('x,unit) m + val listen : ('x,'a) m -> ('x,'a * log) m + val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m + val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m + end end = struct type log = Log.log module Base = struct @@ -742,7 +752,7 @@ end = struct type ('x,'a) result = 'a * log type ('x,'a) result_exn = 'a * log let unit a = (a, Log.zero) - let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w') + let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w') let run u = u let run_exn = run let zero () = Util.undef @@ -754,6 +764,31 @@ end = struct let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *) let pass ((a, f), w) = (a, f w) (* usually use censor helper *) let censor f u = pass (u >>= fun a -> unit (a, f)) + module T(Wrapped : Monad.S) = struct + module BaseT = struct + module Wrapped = Wrapped + type ('x,'a) m = ('x,'a * log) Wrapped.m + type ('x,'a) result = ('x,'a * log) Wrapped.result + type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn + let elevate w = + Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero)) + let bind u f = + Wrapped.bind u (fun (a, w) -> + Wrapped.bind (f a) (fun (b, w') -> + Wrapped.unit (b, Log.plus w w'))) + let zero () = elevate (Wrapped.zero ()) + let plus u v = Wrapped.plus u v + let run u = Wrapped.run u + let run_exn u = Wrapped.run_exn u + end + include Monad.MakeT(BaseT) + let tell entries = Wrapped.unit ((), entries) + let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w)) + let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w)) + (* rest are derived in same way as before *) + let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) + let censor f u = pass (u >>= fun a -> unit (a, f)) + end end (* pre-define simple Writer *) @@ -812,7 +847,6 @@ end = struct end -(* TODO needs a T *) module Continuation_monad : sig (* expose only the implementation of type `('r,'a) result` *) type ('r,'a) m @@ -825,6 +859,16 @@ module Continuation_monad : sig (* val abort : ('a,'a) m -> ('a,'b) m *) val abort : 'a -> ('a,'b) m val run0 : ('a,'a) m -> 'a + (* ContinuationT transformer *) + module T : functor (Wrapped : Monad.S) -> sig + type ('r,'a) m + type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result + type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn + include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m + val elevate : ('x,'a) Wrapped.m -> ('x,'a) m + val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m + (* TODO: reset,shift,abort,run0 *) + end end = struct let id = fun i -> i module Base = struct @@ -863,6 +907,24 @@ end = struct (* let abort a = shift (fun _ -> a) *) let abort a = shift (fun _ -> unit a) let run0 (u : ('a,'a) m) = (u) id + module T(Wrapped : Monad.S) = struct + module BaseT = struct + module Wrapped = Wrapped + type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m + type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result + type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn + let elevate w = fun k -> Wrapped.bind w k + let bind u f = fun k -> u (fun a -> f a k) + let run u k = Wrapped.run (u k) + let run_exn u k = Wrapped.run_exn (u k) + let zero () = Util.undef + let plus u v = u + end + include Monad.MakeT(BaseT) + let callcc f = (fun k -> + let usek a = (fun _ -> k a) + in (f usek) k) + end end @@ -926,14 +988,14 @@ end *) -module Leaf_monad : sig +module Tree_monad : sig (* We implement the type as `'a tree option` because it has a natural`plus`, * and the rest of the library expects that `plus` and `zero` will come together. *) type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree) type ('x,'a) result = 'a tree option type ('x,'a) result_exn = 'a tree include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn - (* LeafT transformer *) + (* TreeT transformer *) module T : functor (Wrapped : Monad.S) -> sig type ('x,'a) result = ('x,'a tree option) Wrapped.result type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn @@ -1007,7 +1069,7 @@ end module L = List_monad;; module R = Reader_monad(struct type env = int -> int end);; module S = State_monad(struct type store = int end);; -module T = Leaf_monad;; +module T = Tree_monad;; module LR = L.T(R);; module LS = L.T(S);; module TL = T.T(L);; @@ -1017,7 +1079,7 @@ module C = Continuation_monad module TC = T.T(C);; -print_endline "=== test Leaf(...).distribute ==================";; +print_endline "=== test TreeT(...).distribute ==================";; let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));; @@ -1087,7 +1149,7 @@ LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts - : S.store list * S.store = ([10; 0; 0; 1; 20], 1) *) -print_endline "=== test Leaf(Continuation).distribute ==================";; +print_endline "=== test TreeT(Continuation).distribute ==================";; let id : 'z. 'z -> 'z = fun x -> x