Expand monad_transformers re elevate, layering
[lambda.git] / code / monads.ml
index 34ad1ce..b678a92 100644 (file)
  * have to use operations like `run` to convert the abstract monadic types
  * to types whose internals you have free access to.
  *
+ * Acknowledgements: This is largely based on the mtl library distributed
+ * with the Glasgow Haskell Compiler. I've also been helped in
+ * various ways by posts and direct feedback from Oleg Kiselyov and
+ * Chung-chieh Shan. The following were also useful:
+ * - <http://pauillac.inria.fr/~xleroy/mpri/progfunc/>
+ * - Ken Shan "Monads for natural language semantics" <http://arxiv.org/abs/cs/0205026v1>
+ * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf
+ * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers
+ *
+ * Licensing: MIT (if that's compatible with the ghc sources this is partly
+ * derived from)
  *)
 
 exception Undefined
@@ -63,7 +74,7 @@ module Util = struct
     in loop len []
   (* Dirty hack to be a default polymorphic zero.
    * To implement this cleanly, monads without a natural zero
-   * should always wrap themselves in an option layer (see Leaf_monad). *)
+   * should always wrap themselves in an option layer (see Tree_monad). *)
   let undef = Obj.magic (fun () -> raise Undefined)
 end
 
@@ -140,6 +151,10 @@ module Monad = struct
     let run_exn u =
       if u == Util.undef then raise Undefined else B.run_exn u
     let (>>=) = bind
+    (* expressions after >> will be evaluated before they're passed to
+     * bind, so you can't do `zero () >> assert false`
+     * this works though: `zero () >>= fun _ -> assert false`
+     *)
     let (>>) u v = u >>= fun _ -> v
     let lift f u = u >>= fun a -> unit (f a)
     (* lift is called listM, fmap, and <$> in Haskell *)
@@ -155,10 +170,21 @@ module Monad = struct
     let (>=>) f g = fun a -> f a >>= g
     let do_when test u = if test then u else unit ()
     let do_unless test u = if test then unit () else u
-    (* not in tail position, will Stack overflow *)
+    (* A Haskell-like version works:
+         let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk
+     * but the recursive call is not in tail position so this can stack overflow. *)
     let forever uthunk =
-        let rec loop () = uthunk () >>= fun _ -> loop ()
-        in loop ()
+        let z = zero () in
+        let id result = result in
+        let kcell = ref id in
+        let rec loop _ =
+            let result = uthunk (kcell := id) >>= chained
+            in !kcell result
+        and chained _ =
+            kcell := loop; z (* we use z only for its polymorphism *)
+        in loop z
+    (* Reimplementations of the preceding using a hand-rolled State or StateT
+can also stack overflow. *)
     let sequence ms =
       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
         Util.fold_right op ms (unit [])
@@ -432,18 +458,10 @@ end = struct
       | Success a -> a
       | Error e -> raise (Err.Exc e)
     let zero () = Util.undef
-    let plus u v = u
-    (*
-    let zero () = Error Err.zero
-    let plus u v = match (u, v) with
-      | Success _, _ -> u
-      (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
-       * otherwise, plus (Error _) v = v *)
-      | Error _, _ when v = zero -> u
-      (* combine errors *)
-      | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
-      | Error _, _ -> v
-    *)
+    (* satisfies Catch *)
+    let plus u v = match u with
+      | Success _ -> u
+      | Error _ -> if v == Util.undef then u else v
   end
   include Monad.Make(Base)
   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
@@ -472,7 +490,7 @@ end = struct
           | Error e -> raise (Err.Exc e))
         in Wrapped.run_exn w
       let plus u v = Wrapped.plus u v
-      let zero () = elevate (Wrapped.zero ())
+      let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *)
     end)
     let throw e = Wrapped.unit (Error e)
     let catch u handler = Wrapped.bind u (fun t -> match t with
@@ -491,26 +509,6 @@ module Failure = Error_monad(struct
   *)
 end)
 
-(*
-# EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
-- : int EL.result = [Failure.Error "bye"; Failure.Success 30]
-# LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
-- : int LE.result = Failure.Error "bye"
-# EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
-Exception: Failure "bye".
-# LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
-Exception: Failure "bye".
-
-# ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
-- : int Failure.error * S.store = (Failure.Error "bye", 1)
-# SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
-- : (int * S.store) Failure.result = Failure.Error "bye"
-# ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
-Exception: Failure "bye".
-# SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
-Exception: Failure "bye".
- *)
-
 
 (* must be parameterized on (struct type env = ... end) *)
 module Reader_monad(Env : sig type env end) : sig
@@ -557,15 +555,15 @@ end = struct
       type ('x,'a) result = env -> ('x,'a) Wrapped.result
       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
       let elevate w = fun e -> w
-      let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
+      let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e)
       let run u = fun e -> Wrapped.run (u e)
       let run_exn u = fun e -> Wrapped.run_exn (u e)
       (* satisfies Distrib *)
-      let plus u v = fun s -> Wrapped.plus (u s) (v s)
-      let zero () = elevate (Wrapped.zero ())
+      let plus u v = fun e -> Wrapped.plus (u e) (v e)
+      let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
     end
     include Monad.MakeT(BaseT)
-    let ask = fun e -> Wrapped.unit e
+    let ask = Wrapped.unit
     let local modifier u = fun e -> u (modifier e)
     let asks selector = ask >>= (fun e ->
       try unit (selector e)
@@ -630,7 +628,7 @@ end = struct
         in Wrapped.run_exn w
       (* satisfies Distrib *)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
-      let zero () = elevate (Wrapped.zero ())
+      let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
     end
     include Monad.MakeT(BaseT)
     let get = fun s -> Wrapped.unit (s, s)
@@ -642,6 +640,7 @@ end = struct
   end
 end
 
+
 (* State monad with different interface (structured store) *)
 module Ref_monad(V : sig
   type value
@@ -709,7 +708,7 @@ end = struct
         in Wrapped.run_exn w
       (* satisfies Distrib *)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
-      let zero () = elevate (Wrapped.zero ())
+      let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
     end
     include Monad.MakeT(BaseT)
     let newref value = fun s -> Wrapped.unit (alloc value s)
@@ -718,7 +717,7 @@ end = struct
   end
 end
 
-(* TODO needs a T *)
+
 (* must be parameterized on (struct type log = ... end) *)
 module Writer_monad(Log : sig
   type log
@@ -735,6 +734,17 @@ end) : sig
   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
+  (* WriterT transformer *)
+  module T : functor (Wrapped : Monad.S) -> sig
+    type ('x,'a) result = ('x,'a * log) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
+    include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val tell : log -> ('x,unit) m
+    val listen : ('x,'a) m -> ('x,'a * log) m
+    val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
+    val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
+  end
 end = struct
   type log = Log.log
   module Base = struct
@@ -742,7 +752,7 @@ end = struct
     type ('x,'a) result = 'a * log
     type ('x,'a) result_exn = 'a * log
     let unit a = (a, Log.zero)
-    let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
+    let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w')
     let run u = u
     let run_exn = run
     let zero () = Util.undef
@@ -754,6 +764,31 @@ end = struct
   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
   let censor f u = pass (u >>= fun a -> unit (a, f))
+  module T(Wrapped : Monad.S) = struct
+    module BaseT = struct
+      module Wrapped = Wrapped
+      type ('x,'a) m = ('x,'a * log) Wrapped.m
+      type ('x,'a) result = ('x,'a * log) Wrapped.result
+      type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
+      let elevate w =
+        Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero))
+      let bind u f =
+        Wrapped.bind u (fun (a, w) ->
+        Wrapped.bind (f a) (fun (b, w') ->
+        Wrapped.unit (b, Log.plus w w')))
+      let zero () = elevate (Wrapped.zero ())
+      let plus u v = Wrapped.plus u v
+      let run u = Wrapped.run u
+      let run_exn u = Wrapped.run_exn u
+    end
+    include Monad.MakeT(BaseT)
+    let tell entries = Wrapped.unit ((), entries)
+    let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w))
+    let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w))
+    (* rest are derived in same way as before *)
+    let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w)
+    let censor f u = pass (u >>= fun a -> unit (a, f))
+  end
 end
 
 (* pre-define simple Writer *)
@@ -812,7 +847,6 @@ end = struct
 end
 
 
-(* TODO needs a T *)
 module Continuation_monad : sig
   (* expose only the implementation of type `('r,'a) result` *)
   type ('r,'a) m
@@ -825,6 +859,16 @@ module Continuation_monad : sig
   (* val abort : ('a,'a) m -> ('a,'b) m *)
   val abort : 'a -> ('a,'b) m
   val run0 : ('a,'a) m -> 'a
+  (* ContinuationT transformer *)
+  module T : functor (Wrapped : Monad.S) -> sig
+    type ('r,'a) m
+    type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
+    type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
+    include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
+    (* TODO: reset,shift,abort,run0 *)
+  end
 end = struct
   let id = fun i -> i
   module Base = struct
@@ -863,6 +907,24 @@ end = struct
   (* let abort a = shift (fun _ -> a) *)
   let abort a = shift (fun _ -> unit a)
   let run0 (u : ('a,'a) m) = (u) id
+  module T(Wrapped : Monad.S) = struct
+    module BaseT = struct
+      module Wrapped = Wrapped
+      type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m
+      type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
+      type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
+      let elevate w = fun k -> Wrapped.bind w k
+      let bind u f = fun k -> u (fun a -> f a k)
+      let run u k = Wrapped.run (u k)
+      let run_exn u k = Wrapped.run_exn (u k)
+      let zero () = Util.undef
+      let plus u v = u
+    end
+    include Monad.MakeT(BaseT)
+    let callcc f = (fun k ->
+      let usek a = (fun _ -> k a)
+      in (f usek) k)
+  end
 end
 
 
@@ -926,14 +988,14 @@ end
  *)
 
 
-module Leaf_monad : sig
+module Tree_monad : sig
   (* We implement the type as `'a tree option` because it has a natural`plus`,
    * and the rest of the library expects that `plus` and `zero` will come together. *)
   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
   type ('x,'a) result = 'a tree option
   type ('x,'a) result_exn = 'a tree
   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
-  (* LeafT transformer *)
+  (* TreeT transformer *)
   module T : functor (Wrapped : Monad.S) -> sig
     type ('x,'a) result = ('x,'a tree option) Wrapped.result
     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
@@ -1007,7 +1069,7 @@ end
 module L = List_monad;;
 module R = Reader_monad(struct type env = int -> int end);;
 module S = State_monad(struct type store = int end);;
-module T = Leaf_monad;;
+module T = Tree_monad;;
 module LR = L.T(R);;
 module LS = L.T(S);;
 module TL = T.T(L);;
@@ -1017,7 +1079,7 @@ module C = Continuation_monad
 module TC = T.T(C);;
 
 
-print_endline "=== test Leaf(...).distribute ==================";;
+print_endline "=== test TreeT(...).distribute ==================";;
 
 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
 
@@ -1087,7 +1149,7 @@ LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts
 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
 *)
 
-print_endline "=== test Leaf(Continuation).distribute ==================";;
+print_endline "=== test TreeT(Continuation).distribute ==================";;
 
 let id : 'z. 'z -> 'z = fun x -> x