From: Jim Pryor Date: Sat, 11 Dec 2010 03:39:46 +0000 (-0500) Subject: tweak monads-lib, start T2 X-Git-Url: http://lambda.jimpryor.net/git/gitweb.cgi?p=lambda.git;a=commitdiff_plain;h=d3ea4212bbca7a47f8aae537023852fd9a214389 tweak monads-lib, start T2 Signed-off-by: Jim Pryor --- diff --git a/code/monads.ml b/code/monads.ml index a334b0f0..0dab8712 100644 --- a/code/monads.ml +++ b/code/monads.ml @@ -70,20 +70,34 @@ end *) module Monad = struct + (* + * Signature extenders: + * Make :: BASE -> S + * MakeCatch, MakeDistrib :: PLUSBASE -> PLUS + * which merges into S + * (P is merged sig) + * MakeT :: TRANS (with Wrapped : S or P) -> custom sig + * + * Make2 :: BASE2 -> S2 + * MakeCatch2, MakeDistrib2 :: PLUSBASE2 -> PLUS2 (P2 is merged sig) + * to wrap double-typed inner monads: + * MakeT2 :: TRANS2 (with Wrapped : S2 or P2) -> custom sig + * + *) + + (* type of base definitions *) module type BASE = sig - (* - * The only constraints we impose here on how the monadic type - * is implemented is that it have a single type parameter 'a. - *) + (* The only constraints we impose here on how the monadic type + * is implemented is that it have a single type parameter 'a. *) type 'a m + type 'a result + type 'a result_exn val unit : 'a -> 'a m val bind : 'a m -> ('a -> 'b m) -> 'b m - type 'a result val run : 'a m -> 'a result (* run_exn tries to provide a more ground-level result, but may fail *) - type 'a result_exn val run_exn : 'a m -> 'a result_exn end module type S = sig @@ -118,7 +132,6 @@ module Monad = struct let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a') (* let lift f u === apply (unit f) u *) (* let lift2 f u v = apply (lift f u) v *) - let (>=>) f g = fun a -> f a >>= g let do_when test u = if test then u else unit () let do_unless test u = if test then unit () else u @@ -184,15 +197,44 @@ module Monad = struct end module MakeDistrib = MakeCatch + (* Signatures for MonadT *) + (* sig for Wrapped that include S and PLUS *) + module type P = sig + include S + include PLUS with type 'a m := 'a m + end + module type TRANS = sig + module Wrapped : S + type 'a m + type 'a result + type 'a result_exn + val bind : 'a m -> ('a -> 'b m) -> 'b m + val run : 'a m -> 'a result + val run_exn : 'a m -> 'a result_exn + val elevate : 'a Wrapped.m -> 'a m + (* lift/elevate laws: + * elevate (W.unit a) == unit a + * elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a) + *) + end + module MakeT(T : TRANS) = struct + include Make(struct + include T + let unit a = elevate (Wrapped.unit a) + end) + let elevate = T.elevate + end + + (* We have to define BASE, S, and Make again for double-type-parameter monads. *) module type BASE2 = sig type ('x,'a) m + type ('x,'a) result + type ('x,'a) result_exn val unit : 'a -> ('x,'a) m val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m - type ('x,'a) result val run : ('x,'a) m -> ('x,'a) result - type ('x,'a) result_exn - val run_exn : ('x,'a) m -> ('x,'a) result + val run_exn : ('x,'a) m -> ('x,'a) result_exn end module type S2 = sig include BASE2 @@ -210,6 +252,7 @@ module Monad = struct val sequence_ : ('x,'a) m list -> ('x,unit) m end module Make2(B : BASE2) : S2 with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct + (* code repetition, ugh *) include B let (>>=) = bind let (>>) u v = u >>= fun _ -> v @@ -228,31 +271,47 @@ module Monad = struct Util.fold_right (>>) ms (unit ()) end - (* Signatures for MonadT *) - module type W = sig - include S + module type PLUSBASE2 = sig + include BASE2 + val zero : unit -> ('x,'a) m + val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m end - module type WP = sig - include W - val zero : unit -> 'a m - val plus : 'a m -> 'a m -> 'a m + module type PLUS2 = sig + type ('x,'a) m + val zero : unit -> ('x,'a) m + val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m + val guard : bool -> ('x,unit) m + val sum : ('x,'a) m list -> ('x,'a) m end - module type TRANS = sig - type 'a m - val bind : 'a m -> ('a -> 'b m) -> 'b m - module Wrapped : W - type 'a result - val run : 'a m -> 'a result - type 'a result_exn - val run_exn : 'a m -> 'a result_exn - val elevate : 'a Wrapped.m -> 'a m - (* lift/elevate laws: - * elevate (W.unit a) == unit a - * elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a) - *) + module MakeCatch2(B : PLUSBASE2) : PLUS2 with type ('x,'a) m = ('x,'a) B.m = struct + type ('x,'a) m = ('x,'a) B.m + (* code repetition, ugh *) + let zero = B.zero + let plus = B.plus + let guard test = if test then B.unit () else zero () + let sum ms = Util.fold_right plus ms (zero ()) end - module MakeT(T : TRANS) = struct - include Make(struct + module MakeDistrib2 = MakeCatch2 + + (* Signatures for MonadT *) + (* sig for Wrapped that include S and PLUS *) + module type P2 = sig + include S2 + include PLUS2 with type ('x,'a) m := ('x,'a) m + end + module type TRANS2 = sig + module Wrapped : S2 + type ('x,'a) m + type ('x,'a) result + type ('x,'a) result_exn + val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m + val run : ('x,'a) m -> ('x,'a) result + val run_exn : ('x,'a) m -> ('x,'a) result_exn + val elevate : ('x,'a) Wrapped.m -> ('x,'a) m + end + module MakeT2(T : TRANS2) = struct + (* code repetition, ugh *) + include Make2(struct include T let unit a = elevate (Wrapped.unit a) end) @@ -291,13 +350,20 @@ module Maybe_monad : sig include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn include Monad.PLUS with type 'a m := 'a m (* MaybeT transformer *) - module T : functor (Wrapped : Monad.W) -> sig + module T : functor (Wrapped : Monad.S) -> sig type 'a result = 'a option Wrapped.result type 'a result_exn = 'a Wrapped.result_exn include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn include Monad.PLUS with type 'a m := 'a m val elevate : 'a Wrapped.m -> 'a m end + module T2 : functor (Wrapped : Monad.S2) -> sig + type ('x,'a) result = ('x,'a option) Wrapped.result + type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn + include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn + include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m + val elevate : ('x,'a) Wrapped.m -> ('x,'a) m + end end = struct module Base = struct type 'a m = 'a option @@ -314,18 +380,18 @@ end = struct end include Monad.Make(Base) include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) - module T(Wrapped : Monad.W) = struct + module T(Wrapped : Monad.S) = struct module Trans = struct include Monad.MakeT(struct module Wrapped = Wrapped type 'a m = 'a option Wrapped.m + type 'a result = 'a option Wrapped.result + type 'a result_exn = 'a Wrapped.result_exn let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a)) let bind u f = Wrapped.bind u (fun t -> match t with | Some a -> f a | None -> Wrapped.unit None) - type 'a result = 'a option Wrapped.result let run u = Wrapped.run u - type 'a result_exn = 'a Wrapped.result_exn let run_exn u = let w = Wrapped.bind u (fun t -> match t with | Some a -> Wrapped.unit a @@ -338,6 +404,31 @@ end = struct include Trans include (Monad.MakeCatch(Trans) : Monad.PLUS with type 'a m := 'a m) end + module T2(Wrapped : Monad.S2) = struct + module Trans = struct + include Monad.MakeT2(struct + module Wrapped = Wrapped + type ('x,'a) m = ('x,'a option) Wrapped.m + type ('x,'a) result = ('x,'a option) Wrapped.result + type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn + (* code repetition, ugh *) + let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a)) + let bind u f = Wrapped.bind u (fun t -> match t with + | Some a -> f a + | None -> Wrapped.unit None) + let run u = Wrapped.run u + let run_exn u = + let w = Wrapped.bind u (fun t -> match t with + | Some a -> Wrapped.unit a + | None -> failwith "no value") + in Wrapped.run_exn w + end) + let zero () = Wrapped.unit None + let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u) + end + include Trans + include (Monad.MakeCatch2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m) + end end @@ -350,7 +441,7 @@ module List_monad : sig val permute : 'a m -> 'a m m val select : 'a m -> ('a * 'a m) m (* ListT transformer *) - module T : functor (Wrapped : Monad.W) -> sig + module T : functor (Wrapped : Monad.S) -> sig type 'a result = 'a list Wrapped.result type 'a result_exn = 'a Wrapped.result_exn include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn @@ -364,6 +455,16 @@ module List_monad : sig val select : 'a m -> ('a * 'a m) m *) end +(* + module T2 : functor (Wrapped : Monad.S2) -> sig + type ('x,'a) result = ('x,'a list) Wrapped.result + type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn + include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn + include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m + val elevate : ('x,'a) Wrapped.m -> ('x,'a) m + val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m + end + *) end = struct module Base = struct type 'a m = 'a list @@ -397,7 +498,7 @@ end = struct | [] -> zero () | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs')) let base_plus = plus - module T(Wrapped : Monad.W) = struct + module T(Wrapped : Monad.S) = struct module Trans = struct let zero () = Wrapped.unit [] let plus u v = @@ -415,14 +516,14 @@ end = struct include Monad.MakeT(struct module Wrapped = Wrapped type 'a m = 'a list Wrapped.m + type 'a result = 'a list Wrapped.result + type 'a result_exn = 'a Wrapped.result_exn let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a]) let bind u f = Wrapped.bind u (fun ts -> Wrapped.bind (distribute f ts) (fun tts -> Wrapped.unit (Util.concat tts))) - type 'a result = 'a list Wrapped.result let run u = Wrapped.run u - type 'a result_exn = 'a Wrapped.result_exn let run_exn u = let w = Wrapped.bind u (fun ts -> match ts with | [] -> failwith "no values" @@ -460,7 +561,7 @@ end) : sig val throw : err -> 'a m val catch : 'a m -> (err -> 'a m) -> 'a m (* ErrorT transformer *) - module T : functor (Wrapped : Monad.W) -> sig + module T : functor (Wrapped : Monad.S) -> sig type 'a result = 'a Wrapped.result type 'a result_exn = 'a Wrapped.result_exn include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn @@ -502,7 +603,7 @@ end = struct let catch u handler = match u with | Success _ -> u | Error e -> handler e - module T(Wrapped : Monad.W) = struct + module T(Wrapped : Monad.S) = struct module Trans = struct module Wrapped = Wrapped type 'a m = 'a Base.m Wrapped.m @@ -555,7 +656,7 @@ module Reader_monad(Env : sig type env end) : sig val asks : (env -> 'a) -> 'a m val local : (env -> env) -> 'a m -> 'a m (* ReaderT transformer *) - module T : functor (Wrapped : Monad.W) -> sig + module T : functor (Wrapped : Monad.S) -> sig type 'a result = env -> 'a Wrapped.result type 'a result_exn = env -> 'a Wrapped.result_exn include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn @@ -565,7 +666,7 @@ module Reader_monad(Env : sig type env end) : sig val local : (env -> env) -> 'a m -> 'a m end (* ReaderT transformer when wrapped monad has plus, zero *) - module TP : functor (Wrapped : Monad.WP) -> sig + module TP : functor (Wrapped : Monad.P) -> sig include module type of T(Wrapped) include Monad.PLUS with type 'a m := 'a m end @@ -584,7 +685,7 @@ end = struct let ask = fun e -> e let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *) let local modifier u = fun e -> u (modifier e) - module T(Wrapped : Monad.W) = struct + module T(Wrapped : Monad.S) = struct module Trans = struct module Wrapped = Wrapped type 'a m = env -> 'a Wrapped.m @@ -600,7 +701,7 @@ end = struct let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *) let local modifier u = fun e -> u (modifier e) end - module TP(Wrapped : Monad.WP) = struct + module TP(Wrapped : Monad.P) = struct module TransP = struct include T(Wrapped) let plus u v = fun s -> Wrapped.plus (u s) (v s) @@ -627,7 +728,7 @@ module State_monad(Store : sig type store end) : sig val put : store -> unit m val puts : (store -> store) -> unit m (* StateT transformer *) - module T : functor (Wrapped : Monad.W) -> sig + module T : functor (Wrapped : Monad.S) -> sig type 'a result = store -> ('a * store) Wrapped.result type 'a result_exn = store -> 'a Wrapped.result_exn include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn @@ -638,7 +739,7 @@ module State_monad(Store : sig type store end) : sig val puts : (store -> store) -> unit m end (* StateT transformer when wrapped monad has plus, zero *) - module TP : functor (Wrapped : Monad.WP) -> sig + module TP : functor (Wrapped : Monad.P) -> sig include module type of T(Wrapped) include Monad.PLUS with type 'a m := 'a m end @@ -658,7 +759,7 @@ end = struct let gets viewer = fun s -> (viewer s, s) (* may fail *) let put s = fun _ -> ((), s) let puts modifier = fun s -> ((), modifier s) - module T(Wrapped : Monad.W) = struct + module T(Wrapped : Monad.S) = struct module Trans = struct module Wrapped = Wrapped type 'a m = store -> ('a * store) Wrapped.m @@ -679,7 +780,7 @@ end = struct let put s = fun _ -> Wrapped.unit ((), s) let puts modifier = fun s -> Wrapped.unit ((), modifier s) end - module TP(Wrapped : Monad.WP) = struct + module TP(Wrapped : Monad.P) = struct module TransP = struct include T(Wrapped) let plus u v = fun s -> Wrapped.plus (u s) (v s) @@ -706,7 +807,7 @@ end) : sig val deref : ref -> value m val change : ref -> value -> unit m (* RefT transformer *) - module T : functor (Wrapped : Monad.W) -> sig + module T : functor (Wrapped : Monad.S) -> sig type 'a result = 'a Wrapped.result type 'a result_exn = 'a Wrapped.result_exn include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn @@ -716,7 +817,7 @@ end) : sig val change : ref -> value -> unit m end (* RefT transformer when wrapped monad has plus, zero *) - module TP : functor (Wrapped : Monad.WP) -> sig + module TP : functor (Wrapped : Monad.P) -> sig include module type of T(Wrapped) include Monad.PLUS with type 'a m := 'a m end @@ -745,7 +846,7 @@ end = struct let newref value = fun s -> alloc value s let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *) let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *) - module T(Wrapped : Monad.W) = struct + module T(Wrapped : Monad.S) = struct module Trans = struct module Wrapped = Wrapped type 'a m = dict -> ('a * dict) Wrapped.m @@ -767,7 +868,7 @@ end = struct let deref key = fun s -> Wrapped.unit (read key s, s) let change key value = fun s -> Wrapped.unit ((), write key value s) end - module TP(Wrapped : Monad.WP) = struct + module TP(Wrapped : Monad.P) = struct module TransP = struct include T(Wrapped) let plus u v = fun s -> Wrapped.plus (u s) (v s) @@ -872,9 +973,12 @@ module Continuation_monad : sig type 'a result = 'a m type 'a result_exn = 'a m include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn and type 'a m := 'a m + (* val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m *) (* misses that the answer types of all the cont's must be the same *) val callcc : (('a -> 'b m) -> 'a m) -> 'a m + (* val reset : ('a,'a) m -> ('r,'a) m *) val reset : 'a m -> 'a m + (* val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m *) (* misses that the answer types of second and third continuations must be b *) val shift : (('a -> 'b m) -> 'b m) -> 'a m (* overwrite the run declaration in S, because I can't declare 'a result = @@ -899,23 +1003,24 @@ end = struct type 'a result_exn = 'a m let run_exn (u : 'a m) : 'a result_exn = u let callcc f = - let cont : 'r. ('a -> 'r) -> 'r = fun k -> + let cont : 'r. ('a -> 'r) -> 'r = (* Can't figure out how to make the type polymorphic enough * to satisfy the OCaml type-checker (it's ('a -> 'r) -> 'r * instead of 'r. ('a -> 'r) -> 'r); so we have to fudge * with Obj.magic... which tells OCaml's type checker to * relax, the supplied value has whatever type the context * needs it to have. *) + fun k -> let usek a = { cont = Obj.magic (fun _ -> k a) } in (f usek).cont k in { cont } let reset u = unit (u.cont id) let shift (f : ('a -> 'b m) -> 'b m) : 'a m = - let cont = - fun k -> (f (fun a -> unit (k a))).cont id + let cont = fun k -> + (f (fun a -> unit (k a))).cont id in { cont = Obj.magic cont } - let runk u k = u.cont k - let run0 u = u.cont id + let runk u k = (u.cont : ('a -> 'r) -> 'r) k + let run0 u = runk u id end include Monad.Make(Base) let callcc = Base.callcc @@ -1035,7 +1140,7 @@ module Leaf_monad : sig include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn include Monad.PLUS with type 'a m := 'a m (* LeafT transformer *) - module T : functor (Wrapped : Monad.W) -> sig + module T : functor (Wrapped : Monad.S) -> sig type 'a result = 'a tree option Wrapped.result type 'a result_exn = 'a tree Wrapped.result_exn include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn @@ -1080,7 +1185,7 @@ end = struct include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m) let base_plus = plus let base_lift = lift - module T(Wrapped : Monad.W) = struct + module T(Wrapped : Monad.S) = struct module Trans = struct let zero () = Wrapped.unit None let plus u v =