tweak monads-lib, start T2
authorJim Pryor <profjim@jimpryor.net>
Sat, 11 Dec 2010 03:39:46 +0000 (22:39 -0500)
committerJim Pryor <profjim@jimpryor.net>
Sat, 11 Dec 2010 03:39:46 +0000 (22:39 -0500)
Signed-off-by: Jim Pryor <profjim@jimpryor.net>
code/monads.ml

index a334b0f..0dab871 100644 (file)
@@ -70,20 +70,34 @@ end
  *)
 
 module Monad = struct
+  (*
+   * Signature extenders:
+   *   Make :: BASE -> S
+   *     MakeCatch, MakeDistrib :: PLUSBASE -> PLUS
+   *                                           which merges into S
+   *                                           (P is merged sig)
+   *   MakeT :: TRANS (with Wrapped : S or P) -> custom sig
+   *
+   *   Make2 :: BASE2 -> S2
+   *     MakeCatch2, MakeDistrib2 :: PLUSBASE2 -> PLUS2 (P2 is merged sig)
+   *   to wrap double-typed inner monads:
+   *   MakeT2 :: TRANS2 (with Wrapped : S2 or P2) -> custom sig
+   *
+   *)
+
+
 
   (* type of base definitions *)
   module type BASE = sig
-    (*
-     * The only constraints we impose here on how the monadic type
-     * is implemented is that it have a single type parameter 'a.
-     *)
+    (* The only constraints we impose here on how the monadic type
+     * is implemented is that it have a single type parameter 'a. *)
     type 'a m
+    type 'a result
+    type 'a result_exn
     val unit : 'a -> 'a m
     val bind : 'a m -> ('a -> 'b m) -> 'b m
-    type 'a result
     val run : 'a m -> 'a result
     (* run_exn tries to provide a more ground-level result, but may fail *)
-    type 'a result_exn
     val run_exn : 'a m -> 'a result_exn
   end
   module type S = sig
@@ -118,7 +132,6 @@ module Monad = struct
     let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a')
     (* let lift f u === apply (unit f) u *)
     (* let lift2 f u v = apply (lift f u) v *)
-
     let (>=>) f g = fun a -> f a >>= g
     let do_when test u = if test then u else unit ()
     let do_unless test u = if test then unit () else u
@@ -184,15 +197,44 @@ module Monad = struct
   end
   module MakeDistrib = MakeCatch
 
+  (* Signatures for MonadT *)
+  (* sig for Wrapped that include S and PLUS *)
+  module type P = sig
+    include S
+    include PLUS with type 'a m := 'a m
+  end
+  module type TRANS = sig
+    module Wrapped : S
+    type 'a m
+    type 'a result
+    type 'a result_exn
+    val bind : 'a m -> ('a -> 'b m) -> 'b m
+    val run : 'a m -> 'a result
+    val run_exn : 'a m -> 'a result_exn
+    val elevate : 'a Wrapped.m -> 'a m
+    (* lift/elevate laws:
+     *     elevate (W.unit a) == unit a
+     *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
+     *)
+  end
+  module MakeT(T : TRANS) = struct
+    include Make(struct
+        include T
+        let unit a = elevate (Wrapped.unit a)
+    end)
+    let elevate = T.elevate
+  end
+
+
   (* We have to define BASE, S, and Make again for double-type-parameter monads. *)
   module type BASE2 = sig
     type ('x,'a) m
+    type ('x,'a) result
+    type ('x,'a) result_exn
     val unit : 'a -> ('x,'a) m
     val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
-    type ('x,'a) result
     val run : ('x,'a) m -> ('x,'a) result
-    type ('x,'a) result_exn
-    val run_exn : ('x,'a) m -> ('x,'a) result
+    val run_exn : ('x,'a) m -> ('x,'a) result_exn
   end
   module type S2 = sig
     include BASE2
@@ -210,6 +252,7 @@ module Monad = struct
     val sequence_ : ('x,'a) m list -> ('x,unit) m
   end
   module Make2(B : BASE2) : S2 with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
+    (* code repetition, ugh *)
     include B
     let (>>=) = bind
     let (>>) u v = u >>= fun _ -> v
@@ -228,31 +271,47 @@ module Monad = struct
       Util.fold_right (>>) ms (unit ())
   end
 
-  (* Signatures for MonadT *)
-  module type W = sig
-    include S
+  module type PLUSBASE2 = sig
+    include BASE2
+    val zero : unit -> ('x,'a) m
+    val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
   end
-  module type WP = sig
-    include W
-    val zero : unit -> 'a m
-    val plus : 'a m -> 'a m -> 'a m
+  module type PLUS2 = sig
+    type ('x,'a) m
+    val zero : unit -> ('x,'a) m
+    val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
+    val guard : bool -> ('x,unit) m
+    val sum : ('x,'a) m list -> ('x,'a) m
   end
-  module type TRANS = sig
-    type 'a m
-    val bind : 'a m -> ('a -> 'b m) -> 'b m
-    module Wrapped : W
-    type 'a result
-    val run : 'a m -> 'a result
-    type 'a result_exn
-    val run_exn : 'a m -> 'a result_exn
-    val elevate : 'a Wrapped.m -> 'a m
-    (* lift/elevate laws:
-     *     elevate (W.unit a) == unit a
-     *     elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a)
-     *)
+  module MakeCatch2(B : PLUSBASE2) : PLUS2 with type ('x,'a) m = ('x,'a) B.m = struct
+    type ('x,'a) m = ('x,'a) B.m
+    (* code repetition, ugh *)
+    let zero = B.zero
+    let plus = B.plus
+    let guard test = if test then B.unit () else zero ()
+    let sum ms = Util.fold_right plus ms (zero ())
   end
-  module MakeT(T : TRANS) = struct
-    include Make(struct
+  module MakeDistrib2 = MakeCatch2
+
+  (* Signatures for MonadT *)
+  (* sig for Wrapped that include S and PLUS *)
+  module type P2 = sig
+    include S2
+    include PLUS2 with type ('x,'a) m := ('x,'a) m
+  end
+  module type TRANS2 = sig
+    module Wrapped : S2
+    type ('x,'a) m
+    type ('x,'a) result
+    type ('x,'a) result_exn
+    val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m
+    val run : ('x,'a) m -> ('x,'a) result
+    val run_exn : ('x,'a) m -> ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+  end
+  module MakeT2(T : TRANS2) = struct
+    (* code repetition, ugh *)
+    include Make2(struct
         include T
         let unit a = elevate (Wrapped.unit a)
     end)
@@ -291,13 +350,20 @@ module Maybe_monad : sig
   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
   include Monad.PLUS with type 'a m := 'a m
   (* MaybeT transformer *)
-  module T : functor (Wrapped : Monad.W) -> sig
+  module T : functor (Wrapped : Monad.S) -> sig
     type 'a result = 'a option Wrapped.result
     type 'a result_exn = 'a Wrapped.result_exn
     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
     include Monad.PLUS with type 'a m := 'a m
     val elevate : 'a Wrapped.m -> 'a m
   end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = ('x,'a option) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+  end
 end = struct
   module Base = struct
    type 'a m = 'a option
@@ -314,18 +380,18 @@ end = struct
   end
   include Monad.Make(Base)
   include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m)
-  module T(Wrapped : Monad.W) = struct
+  module T(Wrapped : Monad.S) = struct
     module Trans = struct
       include Monad.MakeT(struct
         module Wrapped = Wrapped
         type 'a m = 'a option Wrapped.m
+        type 'a result = 'a option Wrapped.result
+        type 'a result_exn = 'a Wrapped.result_exn
         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
         let bind u f = Wrapped.bind u (fun t -> match t with
           | Some a -> f a
           | None -> Wrapped.unit None)
-        type 'a result = 'a option Wrapped.result
         let run u = Wrapped.run u
-        type 'a result_exn = 'a Wrapped.result_exn
         let run_exn u =
           let w = Wrapped.bind u (fun t -> match t with
             | Some a -> Wrapped.unit a
@@ -338,6 +404,31 @@ end = struct
     include Trans
     include (Monad.MakeCatch(Trans) : Monad.PLUS with type 'a m := 'a m)
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      include Monad.MakeT2(struct
+        module Wrapped = Wrapped
+        type ('x,'a) m = ('x,'a option) Wrapped.m
+        type ('x,'a) result = ('x,'a option) Wrapped.result
+        type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+        (* code repetition, ugh *)
+        let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a))
+        let bind u f = Wrapped.bind u (fun t -> match t with
+          | Some a -> f a
+          | None -> Wrapped.unit None)
+        let run u = Wrapped.run u
+        let run_exn u =
+          let w = Wrapped.bind u (fun t -> match t with
+            | Some a -> Wrapped.unit a
+            | None -> failwith "no value")
+          in Wrapped.run_exn w
+      end)
+      let zero () = Wrapped.unit None
+      let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
+    end
+    include Trans
+    include (Monad.MakeCatch2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+  end
 end
 
 
@@ -350,7 +441,7 @@ module List_monad : sig
   val permute : 'a m -> 'a m m
   val select : 'a m -> ('a * 'a m) m
   (* ListT transformer *)
-  module T : functor (Wrapped : Monad.W) -> sig
+  module T : functor (Wrapped : Monad.S) -> sig
     type 'a result = 'a list Wrapped.result
     type 'a result_exn = 'a Wrapped.result_exn
     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
@@ -364,6 +455,16 @@ module List_monad : sig
     val select : 'a m -> ('a * 'a m) m
 *)
   end
+(*
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = ('x,'a list) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
+  end
+ *)
 end = struct
   module Base = struct
    type 'a m = 'a list
@@ -397,7 +498,7 @@ end = struct
     | [] -> zero ()
     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
   let base_plus = plus
-  module T(Wrapped : Monad.W) = struct
+  module T(Wrapped : Monad.S) = struct
     module Trans = struct
       let zero () = Wrapped.unit []
       let plus u v =
@@ -415,14 +516,14 @@ end = struct
       include Monad.MakeT(struct
         module Wrapped = Wrapped
         type 'a m = 'a list Wrapped.m
+        type 'a result = 'a list Wrapped.result
+        type 'a result_exn = 'a Wrapped.result_exn
         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
         let bind u f =
           Wrapped.bind u (fun ts ->
           Wrapped.bind (distribute f ts) (fun tts ->
           Wrapped.unit (Util.concat tts)))
-        type 'a result = 'a list Wrapped.result
         let run u = Wrapped.run u
-        type 'a result_exn = 'a Wrapped.result_exn
         let run_exn u =
           let w = Wrapped.bind u (fun ts -> match ts with
             | [] -> failwith "no values"
@@ -460,7 +561,7 @@ end) : sig
   val throw : err -> 'a m
   val catch : 'a m -> (err -> 'a m) -> 'a m
   (* ErrorT transformer *)
-  module T : functor (Wrapped : Monad.W) -> sig
+  module T : functor (Wrapped : Monad.S) -> sig
     type 'a result = 'a Wrapped.result
     type 'a result_exn = 'a Wrapped.result_exn
     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
@@ -502,7 +603,7 @@ end = struct
   let catch u handler = match u with
     | Success _ -> u
     | Error e -> handler e
-  module T(Wrapped : Monad.W) = struct
+  module T(Wrapped : Monad.S) = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = 'a Base.m Wrapped.m
@@ -555,7 +656,7 @@ module Reader_monad(Env : sig type env end) : sig
   val asks : (env -> 'a) -> 'a m
   val local : (env -> env) -> 'a m -> 'a m
   (* ReaderT transformer *)
-  module T : functor (Wrapped : Monad.W) -> sig
+  module T : functor (Wrapped : Monad.S) -> sig
     type 'a result = env -> 'a Wrapped.result
     type 'a result_exn = env -> 'a Wrapped.result_exn
     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
@@ -565,7 +666,7 @@ module Reader_monad(Env : sig type env end) : sig
     val local : (env -> env) -> 'a m -> 'a m
   end
   (* ReaderT transformer when wrapped monad has plus, zero *)
-  module TP : functor (Wrapped : Monad.WP) -> sig
+  module TP : functor (Wrapped : Monad.P) -> sig
     include module type of T(Wrapped)
     include Monad.PLUS with type 'a m := 'a m
   end
@@ -584,7 +685,7 @@ end = struct
   let ask = fun e -> e
   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
   let local modifier u = fun e -> u (modifier e)
-  module T(Wrapped : Monad.W) = struct
+  module T(Wrapped : Monad.S) = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = env -> 'a Wrapped.m
@@ -600,7 +701,7 @@ end = struct
     let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
     let local modifier u = fun e -> u (modifier e)
   end
-  module TP(Wrapped : Monad.WP) = struct
+  module TP(Wrapped : Monad.P) = struct
     module TransP = struct
       include T(Wrapped)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
@@ -627,7 +728,7 @@ module State_monad(Store : sig type store end) : sig
   val put : store -> unit m
   val puts : (store -> store) -> unit m
   (* StateT transformer *)
-  module T : functor (Wrapped : Monad.W) -> sig
+  module T : functor (Wrapped : Monad.S) -> sig
     type 'a result = store -> ('a * store) Wrapped.result
     type 'a result_exn = store -> 'a Wrapped.result_exn
     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
@@ -638,7 +739,7 @@ module State_monad(Store : sig type store end) : sig
     val puts : (store -> store) -> unit m
   end
   (* StateT transformer when wrapped monad has plus, zero *)
-  module TP : functor (Wrapped : Monad.WP) -> sig
+  module TP : functor (Wrapped : Monad.P) -> sig
     include module type of T(Wrapped)
     include Monad.PLUS with type 'a m := 'a m
   end
@@ -658,7 +759,7 @@ end = struct
   let gets viewer = fun s -> (viewer s, s) (* may fail *)
   let put s = fun _ -> ((), s)
   let puts modifier = fun s -> ((), modifier s)
-  module T(Wrapped : Monad.W) = struct
+  module T(Wrapped : Monad.S) = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = store -> ('a * store) Wrapped.m
@@ -679,7 +780,7 @@ end = struct
     let put s = fun _ -> Wrapped.unit ((), s)
     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
   end
-  module TP(Wrapped : Monad.WP) = struct
+  module TP(Wrapped : Monad.P) = struct
     module TransP = struct
       include T(Wrapped)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
@@ -706,7 +807,7 @@ end) : sig
   val deref : ref -> value m
   val change : ref -> value -> unit m
   (* RefT transformer *)
-  module T : functor (Wrapped : Monad.W) -> sig
+  module T : functor (Wrapped : Monad.S) -> sig
     type 'a result = 'a Wrapped.result
     type 'a result_exn = 'a Wrapped.result_exn
     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
@@ -716,7 +817,7 @@ end) : sig
     val change : ref -> value -> unit m
   end
   (* RefT transformer when wrapped monad has plus, zero *)
-  module TP : functor (Wrapped : Monad.WP) -> sig
+  module TP : functor (Wrapped : Monad.P) -> sig
     include module type of T(Wrapped)
     include Monad.PLUS with type 'a m := 'a m
   end
@@ -745,7 +846,7 @@ end = struct
   let newref value = fun s -> alloc value s
   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
-  module T(Wrapped : Monad.W) = struct
+  module T(Wrapped : Monad.S) = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = dict -> ('a * dict) Wrapped.m
@@ -767,7 +868,7 @@ end = struct
     let deref key = fun s -> Wrapped.unit (read key s, s)
     let change key value = fun s -> Wrapped.unit ((), write key value s)
   end
-  module TP(Wrapped : Monad.WP) = struct
+  module TP(Wrapped : Monad.P) = struct
     module TransP = struct
       include T(Wrapped)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
@@ -872,9 +973,12 @@ module Continuation_monad : sig
   type 'a result = 'a m
   type 'a result_exn = 'a m
   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn and type 'a m := 'a m
+  (* val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m *)
   (* misses that the answer types of all the cont's must be the same *)
   val callcc : (('a -> 'b m) -> 'a m) -> 'a m
+  (* val reset : ('a,'a) m -> ('r,'a) m *)
   val reset : 'a m -> 'a m
+  (* val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m *)
   (* misses that the answer types of second and third continuations must be b *)
   val shift : (('a -> 'b m) -> 'b m) -> 'a m
   (* overwrite the run declaration in S, because I can't declare 'a result =
@@ -899,23 +1003,24 @@ end = struct
     type 'a result_exn = 'a m
     let run_exn (u : 'a m) : 'a result_exn = u
     let callcc f =
-      let cont : 'r. ('a -> 'r) -> 'r = fun k ->
+      let cont : 'r. ('a -> 'r) -> 'r =
           (* Can't figure out how to make the type polymorphic enough
            * to satisfy the OCaml type-checker (it's ('a -> 'r) -> 'r
            * instead of 'r. ('a -> 'r) -> 'r); so we have to fudge
            * with Obj.magic... which tells OCaml's type checker to
            * relax, the supplied value has whatever type the context
            * needs it to have. *)
+          fun k ->
           let usek a = { cont = Obj.magic (fun _ -> k a) }
           in (f usek).cont k
       in { cont }
     let reset u = unit (u.cont id)
     let shift (f : ('a -> 'b m) -> 'b m) : 'a m =
-      let cont =
-        fun k -> (f (fun a -> unit (k a))).cont id
+      let cont = fun k ->
+        (f (fun a -> unit (k a))).cont id
       in { cont = Obj.magic cont }
-    let runk u k = u.cont k
-    let run0 u = u.cont id
+    let runk u k = (u.cont : ('a -> 'r) -> 'r) k
+    let run0 u = runk u id
   end
   include Monad.Make(Base)
   let callcc = Base.callcc
@@ -1035,7 +1140,7 @@ module Leaf_monad : sig
   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
   include Monad.PLUS with type 'a m := 'a m
   (* LeafT transformer *)
-  module T : functor (Wrapped : Monad.W) -> sig
+  module T : functor (Wrapped : Monad.S) -> sig
     type 'a result = 'a tree option Wrapped.result
     type 'a result_exn = 'a tree Wrapped.result_exn
     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
@@ -1080,7 +1185,7 @@ end = struct
   include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m)
   let base_plus = plus
   let base_lift = lift
-  module T(Wrapped : Monad.W) = struct
+  module T(Wrapped : Monad.S) = struct
     module Trans = struct
       let zero () = Wrapped.unit None
       let plus u v =