tweak monads-lib
authorJim Pryor <profjim@jimpryor.net>
Sat, 11 Dec 2010 19:41:13 +0000 (14:41 -0500)
committerJim Pryor <profjim@jimpryor.net>
Sat, 11 Dec 2010 19:41:13 +0000 (14:41 -0500)
Signed-off-by: Jim Pryor <profjim@jimpryor.net>
code/monads.ml

index 0a205f6..34ad1ce 100644 (file)
@@ -44,6 +44,7 @@
  *
  *)
 
+exception Undefined
 
 (* Some library functions used below. *)
 module Util = struct
@@ -60,7 +61,10 @@ module Util = struct
     let rec loop n accu =
       if n == 0 then accu else loop (pred n) (fill :: accu)
     in loop len []
-  let undefined = Obj.magic ""
+  (* Dirty hack to be a default polymorphic zero.
+   * To implement this cleanly, monads without a natural zero
+   * should always wrap themselves in an option layer (see Leaf_monad). *)
+  let undef = Obj.magic (fun () -> raise Undefined)
 end
 
 
@@ -74,14 +78,15 @@ module Monad = struct
   (*
    * Signature extenders:
    *   Make :: BASE -> S
-   *   MakeT :: TRANS (with Wrapped : S) -> custom sig
+   *   MakeT :: BASET (with Wrapped : S) -> result sig not declared
    *)
 
 
   (* type of base definitions *)
   module type BASE = sig
-    (* The only constraints we impose here on how the monadic type
-     * is implemented is that it have a single type parameter 'a. *)
+    (* We make all monadic types doubly-parameterized so that they
+     * can layer nicely with Continuation, which needs the second
+     * type parameter. *)
     type ('x,'a) m
     type ('x,'a) result
     type ('x,'a) result_exn
@@ -97,11 +102,12 @@ module Monad = struct
      * Additionally, they will obey one of the following laws:
      *     (Catch)   plus (unit a) v  ===  unit a
      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
-     * When no natural zero is available, use `let zero () = Util.undefined
-     * The Make process automatically detects for zero >>= ..., and 
+     * When no natural zero is available, use `let zero () = Util.undef`.
+     * The Make functor automatically detects for zero >>= ..., and 
      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
      *)
     val zero : unit -> ('x,'a) m
+    (* zero has to be thunked to ensure results are always poly enough *)
     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
   end
   module type S = sig
@@ -125,14 +131,14 @@ module Monad = struct
   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
     include B
     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
-      if u == Util.undefined then Util.undefined
-      else bind u (fun a -> try f a with Match_failure _ -> zero ())
+      if u == Util.undef then Util.undef
+      else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
     let plus u v =
-      if u == Util.undefined then v else if v == Util.undefined then u else plus u v
+      if u == Util.undef then v else if v == Util.undef then u else B.plus u v
     let run u =
-      if u == Util.undefined then failwith "no zero" else run u
+      if u == Util.undef then raise Undefined else B.run u
     let run_exn u =
-      if u == Util.undefined then failwith "no zero" else run_exn u
+      if u == Util.undef then raise Undefined else B.run_exn u
     let (>>=) = bind
     let (>>) u v = u >>= fun _ -> v
     let lift f u = u >>= fun a -> unit (f a)
@@ -149,6 +155,7 @@ module Monad = struct
     let (>=>) f g = fun a -> f a >>= g
     let do_when test u = if test then u else unit ()
     let do_unless test u = if test then unit () else u
+    (* not in tail position, will Stack overflow *)
     let forever uthunk =
         let rec loop () = uthunk () >>= fun _ -> loop ()
         in loop ()
@@ -184,7 +191,7 @@ module Monad = struct
   end
 
   (* Signatures for MonadT *)
-  module type TRANS = sig
+  module type BASET = sig
     module Wrapped : S
     type ('x,'a) m
     type ('x,'a) result
@@ -200,7 +207,7 @@ module Monad = struct
     val zero : unit -> ('x,'a) m
     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
   end
-  module MakeT(T : TRANS) = struct
+  module MakeT(T : BASET) = struct
     include Make(struct
         include T
         let unit a = elevate (Wrapped.unit a)
@@ -228,7 +235,7 @@ end = struct
     let bind a f = f a
     let run a = a
     let run_exn a = a
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -264,7 +271,7 @@ end = struct
   end
   include Monad.Make(Base)
   module T(Wrapped : Monad.S) = struct
-    module Trans = struct
+    module BaseT = struct
       include Monad.MakeT(struct
         module Wrapped = Wrapped
         type ('x,'a) m = ('x,'a option) Wrapped.m
@@ -278,13 +285,13 @@ end = struct
         let run_exn u =
           let w = Wrapped.bind u (fun t -> match t with
             | Some a -> Wrapped.unit a
-            | None -> failwith "no value")
-          in Wrapped.run_exn w
+            | None -> Wrapped.zero ()
+          in Wrapped.run_exn w
         let zero () = Wrapped.unit None
         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
       end)
     end
-    include Trans
+    include BaseT
   end
 end
 
@@ -323,6 +330,7 @@ end = struct
      | [a] -> a
      | many -> failwith "multiple values"
    let zero () = []
+   (* satisfies Distrib *)
    let plus = Util.append
   end
   include Monad.Make(Base)
@@ -341,7 +349,6 @@ end = struct
   let rec select u = match u with
     | [] -> zero ()
     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
-  let base_plus = plus
   module T(Wrapped : Monad.S) = struct
     (* Wrapped.sequence ms  ===  
          let plus1 u v =
@@ -365,15 +372,15 @@ end = struct
       let run u = Wrapped.run u
       let run_exn u =
         let w = Wrapped.bind u (fun ts -> match ts with
-          | [] -> failwith "no values"
+          | [] -> Wrapped.zero ()
           | [a] -> Wrapped.unit a
-          | many -> failwith "multiple values"
+          | many -> Wrapped.zero ()
         ) in Wrapped.run_exn w
       let zero () = Wrapped.unit []
       let plus u v =
         Wrapped.bind u (fun us ->
         Wrapped.bind v (fun vs ->
-        Wrapped.unit (base_plus us vs)))
+        Wrapped.unit (Base.plus us vs)))
     end)
 (*
     let permute : 'a m -> 'a m m
@@ -424,7 +431,7 @@ end = struct
     let run_exn u = match u with
       | Success a -> a
       | Error e -> raise (Err.Exc e)
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
     (*
     let zero () = Error Err.zero
@@ -457,8 +464,8 @@ end = struct
       let run u =
         let w = Wrapped.bind u (fun t -> match t with
           | Success a -> Wrapped.unit a
-          | Error e -> Wrapped.zero ())
-        in Wrapped.run w
+          | Error e -> Wrapped.zero ()
+        in Wrapped.run w
       let run_exn u =
         let w = Wrapped.bind u (fun t -> match t with
           | Success a -> Wrapped.unit a
@@ -514,6 +521,7 @@ module Reader_monad(Env : sig type env end) : sig
   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
   val ask : ('x,env) m
   val asks : (env -> 'a) -> ('x,'a) m
+  (* lookup i == `fun e -> e i` would assume env is a functional type *)
   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
   (* ReaderT transformer *)
   module T : functor (Wrapped : Monad.S) -> sig
@@ -535,7 +543,7 @@ end = struct
     let bind u f = fun e -> let a = u e in let u' = f a in u' e
     let run u = fun e -> u e
     let run_exn = run
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -543,7 +551,7 @@ end = struct
   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
   let local modifier u = fun e -> u (modifier e)
   module T(Wrapped : Monad.S) = struct
-    module Trans = struct
+    module BaseT = struct
       module Wrapped = Wrapped
       type ('x,'a) m = env -> ('x,'a) Wrapped.m
       type ('x,'a) result = env -> ('x,'a) Wrapped.result
@@ -552,10 +560,11 @@ end = struct
       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
       let run u = fun e -> Wrapped.run (u e)
       let run_exn u = fun e -> Wrapped.run_exn (u e)
+      (* satisfies Distrib *)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
       let zero () = elevate (Wrapped.zero ())
     end
-    include Monad.MakeT(Trans)
+    include Monad.MakeT(BaseT)
     let ask = fun e -> Wrapped.unit e
     let local modifier u = fun e -> u (modifier e)
     let asks selector = ask >>= (fun e ->
@@ -597,7 +606,7 @@ end = struct
     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
     let run u = fun s -> (u s)
     let run_exn u = fun s -> fst (u s)
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -606,7 +615,7 @@ end = struct
   let put s = fun _ -> ((), s)
   let puts modifier = fun s -> ((), modifier s)
   module T(Wrapped : Monad.S) = struct
-    module Trans = struct
+    module BaseT = struct
       module Wrapped = Wrapped
       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
@@ -619,10 +628,11 @@ end = struct
       let run_exn u = fun s ->
         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
         in Wrapped.run_exn w
+      (* satisfies Distrib *)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
       let zero () = elevate (Wrapped.zero ())
     end
-    include Monad.MakeT(Trans)
+    include Monad.MakeT(BaseT)
     let get = fun s -> Wrapped.unit (s, s)
     let gets viewer = fun s ->
       try Wrapped.unit (viewer s, s)
@@ -674,7 +684,7 @@ end = struct
     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
     let run u = fst (u empty)
     let run_exn = run
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -682,7 +692,7 @@ end = struct
   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
   module T(Wrapped : Monad.S) = struct
-    module Trans = struct
+    module BaseT = struct
       module Wrapped = Wrapped
       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
       type ('x,'a) result = ('x,'a) Wrapped.result
@@ -697,17 +707,18 @@ end = struct
       let run_exn u =
         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
         in Wrapped.run_exn w
+      (* satisfies Distrib *)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
       let zero () = elevate (Wrapped.zero ())
     end
-    include Monad.MakeT(Trans)
+    include Monad.MakeT(BaseT)
     let newref value = fun s -> Wrapped.unit (alloc value s)
     let deref key = fun s -> Wrapped.unit (read key s, s)
     let change key value = fun s -> Wrapped.unit ((), write key value s)
   end
 end
 
-
+(* TODO needs a T *)
 (* must be parameterized on (struct type log = ... end) *)
 module Writer_monad(Log : sig
   type log
@@ -734,7 +745,7 @@ end = struct
     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
     let run u = u
     let run_exn = run
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -766,6 +777,7 @@ module Writer2 = struct
 end
 
 
+(* TODO needs a T *)
 module IO_monad : sig
   (* declare additional operation, while still hiding implementation of type m *)
   type ('x,'a) result = 'a
@@ -787,7 +799,7 @@ end = struct
        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
     let run a = let () = a.run () in a.value
     let run_exn = run
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -800,6 +812,7 @@ end = struct
 end
 
 
+(* TODO needs a T *)
 module Continuation_monad : sig
   (* expose only the implementation of type `('r,'a) result` *)
   type ('r,'a) m
@@ -823,7 +836,7 @@ end = struct
     let bind u f = (fun k -> (u) (fun a -> (f a) k))
     let run u k = (u) k
     let run_exn = run
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -947,6 +960,7 @@ end = struct
     type ('x,'a) result_exn = 'a tree
     let unit a = Some (Leaf a)
     let zero () = None
+    (* satisfies Distrib *)
     let plus u v = match (u, v) with
       | None, _ -> v
       | _, None -> u
@@ -962,10 +976,8 @@ end = struct
       | Some us -> us
   end
   include Monad.Make(Base)
-  let base_plus = plus
-  let base_lift = lift
   module T(Wrapped : Monad.S) = struct
-    module Trans = struct
+    module BaseT = struct
       include Monad.MakeT(struct
         module Wrapped = Wrapped
         type ('x,'a) m = ('x,'a tree option) Wrapped.m
@@ -975,19 +987,18 @@ end = struct
         let plus u v =
           Wrapped.bind u (fun us ->
           Wrapped.bind v (fun vs ->
-          Wrapped.unit (base_plus us vs)))
+          Wrapped.unit (Base.plus us vs)))
         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
         let run u = Wrapped.run u
         let run_exn u =
             let w = Wrapped.bind u (fun t -> match t with
-              | None -> failwith "no values"
-              | Some ts -> Wrapped.unit ts)
-            in Wrapped.run_exn w
+              | None -> Wrapped.zero ()
+              | Some ts -> Wrapped.unit ts
+            in Wrapped.run_exn w
       end)
     end
-    include Trans
-    (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
+    include BaseT
     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
   end
 end