(no commit message)
[lambda.git] / code / monads.ml
index 0a205f6..d872593 100644 (file)
  * making their implementations private. The interpreter won't let
  * let you freely interchange the `'a Reader_monad.m`s defined below
  * with `Reader_monad.env -> 'a`. The code in this library can see that
- * those are equivalent, but code outside the library can't. Instead, you'll 
+ * those are equivalent, but code outside the library can't. Instead, you'll
  * have to use operations like `run` to convert the abstract monadic types
  * to types whose internals you have free access to.
  *
+ * Acknowledgements: This is largely based on the mtl library distributed
+ * with the Glasgow Haskell Compiler. I've also been helped in
+ * various ways by posts and direct feedback from Oleg Kiselyov and
+ * Chung-chieh Shan. The following were also useful:
+ * - <http://pauillac.inria.fr/~xleroy/mpri/progfunc/>
+ * - Ken Shan "Monads for natural language semantics" <http://arxiv.org/abs/cs/0205026v1>
+ * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf
+ * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers
+ *
+ * Licensing: MIT (if that's compatible with the ghc sources this is partly
+ * derived from)
  *)
 
 
 (* Some library functions used below. *)
+
+exception Undefined
+
 module Util = struct
   let fold_right = List.fold_right
   let map = List.map
@@ -60,28 +74,31 @@ module Util = struct
     let rec loop n accu =
       if n == 0 then accu else loop (pred n) (fill :: accu)
     in loop len []
-  let undefined = Obj.magic ""
+  (* Dirty hack to be a default polymorphic zero.
+   * To implement this cleanly, monads without a natural zero
+   * should always wrap themselves in an option layer (see Tree_monad). *)
+  let undef = Obj.magic (fun () -> raise Undefined)
 end
 
-
-
 (*
  * This module contains factories that extend a base set of
  * monadic definitions with a larger family of standard derived values.
  *)
 
 module Monad = struct
+
   (*
    * Signature extenders:
    *   Make :: BASE -> S
-   *   MakeT :: TRANS (with Wrapped : S) -> custom sig
+   *   MakeT :: BASET (with Wrapped : S) -> result sig not declared
    *)
 
 
   (* type of base definitions *)
   module type BASE = sig
-    (* The only constraints we impose here on how the monadic type
-     * is implemented is that it have a single type parameter 'a. *)
+    (* We make all monadic types doubly-parameterized so that they
+     * can layer nicely with Continuation, which needs the second
+     * type parameter. *)
     type ('x,'a) m
     type ('x,'a) result
     type ('x,'a) result_exn
@@ -97,11 +114,12 @@ module Monad = struct
      * Additionally, they will obey one of the following laws:
      *     (Catch)   plus (unit a) v  ===  unit a
      *     (Distrib) plus u v >>= f   ===  plus (u >>= f) (v >>= f)
-     * When no natural zero is available, use `let zero () = Util.undefined
-     * The Make process automatically detects for zero >>= ..., and 
+     * When no natural zero is available, use `let zero () = Util.undef`.
+     * The Make functor automatically detects for zero >>= ..., and
      * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
      *)
     val zero : unit -> ('x,'a) m
+    (* zero has to be thunked to ensure results are always poly enough *)
     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
   end
   module type S = sig
@@ -125,15 +143,19 @@ module Monad = struct
   module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
     include B
     let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
-      if u == Util.undefined then Util.undefined
-      else bind u (fun a -> try f a with Match_failure _ -> zero ())
+      if u == Util.undef then Util.undef
+      else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
     let plus u v =
-      if u == Util.undefined then v else if v == Util.undefined then u else plus u v
+      if u == Util.undef then v else if v == Util.undef then u else B.plus u v
     let run u =
-      if u == Util.undefined then failwith "no zero" else run u
+      if u == Util.undef then raise Undefined else B.run u
     let run_exn u =
-      if u == Util.undefined then failwith "no zero" else run_exn u
+      if u == Util.undef then raise Undefined else B.run_exn u
     let (>>=) = bind
+    (* expressions after >> will be evaluated before they're passed to
+     * bind, so you can't do `zero () >> assert false`
+     * this works though: `zero () >>= fun _ -> assert false`
+     *)
     let (>>) u v = u >>= fun _ -> v
     let lift f u = u >>= fun a -> unit (f a)
     (* lift is called listM, fmap, and <$> in Haskell *)
@@ -149,9 +171,21 @@ module Monad = struct
     let (>=>) f g = fun a -> f a >>= g
     let do_when test u = if test then u else unit ()
     let do_unless test u = if test then unit () else u
+    (* A Haskell-like version works:
+         let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk
+     * but the recursive call is not in tail position so this can stack overflow. *)
     let forever uthunk =
-        let rec loop () = uthunk () >>= fun _ -> loop ()
-        in loop ()
+        let z = zero () in
+        let id result = result in
+        let kcell = ref id in
+        let rec loop _ =
+            let result = uthunk (kcell := id) >>= chained
+            in !kcell result
+        and chained _ =
+            kcell := loop; z (* we use z only for its polymorphism *)
+        in loop z
+    (* Reimplementations of the preceding using a hand-rolled State or StateT
+can also stack overflow. *)
     let sequence ms =
       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
         Util.fold_right op ms (unit [])
@@ -184,7 +218,7 @@ module Monad = struct
   end
 
   (* Signatures for MonadT *)
-  module type TRANS = sig
+  module type BASET = sig
     module Wrapped : S
     type ('x,'a) m
     type ('x,'a) result
@@ -200,7 +234,7 @@ module Monad = struct
     val zero : unit -> ('x,'a) m
     val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
   end
-  module MakeT(T : TRANS) = struct
+  module MakeT(T : BASET) = struct
     include Make(struct
         include T
         let unit a = elevate (Wrapped.unit a)
@@ -228,7 +262,7 @@ end = struct
     let bind a f = f a
     let run a = a
     let run_exn a = a
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -264,7 +298,7 @@ end = struct
   end
   include Monad.Make(Base)
   module T(Wrapped : Monad.S) = struct
-    module Trans = struct
+    module BaseT = struct
       include Monad.MakeT(struct
         module Wrapped = Wrapped
         type ('x,'a) m = ('x,'a option) Wrapped.m
@@ -278,13 +312,13 @@ end = struct
         let run_exn u =
           let w = Wrapped.bind u (fun t -> match t with
             | Some a -> Wrapped.unit a
-            | None -> failwith "no value")
-          in Wrapped.run_exn w
+            | None -> Wrapped.zero ()
+          in Wrapped.run_exn w
         let zero () = Wrapped.unit None
         let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
       end)
     end
-    include Trans
+    include BaseT
   end
 end
 
@@ -305,10 +339,9 @@ module List_monad : sig
     (* note that second argument is an 'a list, not the more abstract 'a m *)
     (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *)
     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
-(* TODO
-    val permute : 'a m -> 'a m m
-    val select : 'a m -> ('a * 'a m) m
-*)
+    val permute : ('x,'a) m -> ('x,('x,'a) m) m
+    val select : ('x,'a) m -> ('x,('a * ('x,'a) m)) m
+    val expose : ('x,'a) m -> ('x,'a list) Wrapped.m
   end
 end = struct
   module Base = struct
@@ -323,6 +356,7 @@ end = struct
      | [a] -> a
      | many -> failwith "multiple values"
    let zero () = []
+   (* satisfies Distrib *)
    let plus = Util.append
   end
   include Monad.Make(Base)
@@ -341,9 +375,8 @@ end = struct
   let rec select u = match u with
     | [] -> zero ()
     | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
-  let base_plus = plus
   module T(Wrapped : Monad.S) = struct
-    (* Wrapped.sequence ms  ===  
+    (* Wrapped.sequence ms  ===
          let plus1 u v =
            Wrapped.bind u (fun x ->
            Wrapped.bind v (fun xs ->
@@ -365,20 +398,40 @@ end = struct
       let run u = Wrapped.run u
       let run_exn u =
         let w = Wrapped.bind u (fun ts -> match ts with
-          | [] -> failwith "no values"
+          | [] -> Wrapped.zero ()
           | [a] -> Wrapped.unit a
-          | many -> failwith "multiple values"
+          | many -> Wrapped.zero ()
         ) in Wrapped.run_exn w
       let zero () = Wrapped.unit []
       let plus u v =
         Wrapped.bind u (fun us ->
         Wrapped.bind v (fun vs ->
-        Wrapped.unit (base_plus us vs)))
+        Wrapped.unit (Base.plus us vs)))
     end)
-(*
-    let permute : 'a m -> 'a m m
-    let select : 'a m -> ('a * 'a m) m
-*)
+
+   (* insert 3 {[1;2]} ~~> {[ {[3;1;2]}; {[1;3;2]}; {[1;2;3]} ]} *)
+   let rec insert a u =
+     plus
+     (unit (Wrapped.bind u (fun us -> Wrapped.unit (a :: us))))
+     (Wrapped.bind u (fun us -> match us with
+         | [] -> zero ()
+         | x::xs -> (insert a (Wrapped.unit xs)) >>= fun v -> unit (Wrapped.bind v (fun vs -> Wrapped.unit (x :: vs)))))
+
+   (* select {[1;2;3]} ~~> {[ (1,{[2;3]}); (2,{[1;3]}), (3;{[1;2]}) ]} *)
+   let rec select u =
+     Wrapped.bind u (fun us -> match us with
+         | [] -> zero ()
+         | x::xs -> plus (unit (x, Wrapped.unit xs))
+             (select (Wrapped.unit xs) >>= fun (x', xs') -> unit (x', Wrapped.bind xs' (fun ys -> Wrapped.unit (x :: ys)))))
+
+   (* permute {[1;2;3]} ~~> {[ {[1;2;3]}; {[2;1;3]}; {[2;3;1]}; {[1;3;2]}; {[3;1;2]}; {[3;2;1]} ]} *)
+
+   let rec permute u =
+     Wrapped.bind u (fun us -> match us with
+         | [] -> unit (zero ())
+         | x::xs -> permute (Wrapped.unit xs) >>= (fun v -> insert x v))
+
+    let expose u = u
   end
 end
 
@@ -424,19 +477,11 @@ end = struct
     let run_exn u = match u with
       | Success a -> a
       | Error e -> raise (Err.Exc e)
-    let zero () = Util.undefined
-    let plus u v = u
-    (*
-    let zero () = Error Err.zero
-    let plus u v = match (u, v) with
-      | Success _, _ -> u
-      (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
-       * otherwise, plus (Error _) v = v *)
-      | Error _, _ when v = zero -> u
-      (* combine errors *)
-      | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
-      | Error _, _ -> v
-    *)
+    let zero () = Util.undef
+    (* satisfies Catch *)
+    let plus u v = match u with
+      | Success _ -> u
+      | Error _ -> if v == Util.undef then u else v
   end
   include Monad.Make(Base)
   (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
@@ -457,15 +502,15 @@ end = struct
       let run u =
         let w = Wrapped.bind u (fun t -> match t with
           | Success a -> Wrapped.unit a
-          | Error e -> Wrapped.zero ())
-        in Wrapped.run w
+          | Error e -> Wrapped.zero ()
+        in Wrapped.run w
       let run_exn u =
         let w = Wrapped.bind u (fun t -> match t with
           | Success a -> Wrapped.unit a
           | Error e -> raise (Err.Exc e))
         in Wrapped.run_exn w
       let plus u v = Wrapped.plus u v
-      let zero () = elevate (Wrapped.zero ())
+      let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *)
     end)
     let throw e = Wrapped.unit (Error e)
     let catch u handler = Wrapped.bind u (fun t -> match t with
@@ -484,26 +529,6 @@ module Failure = Error_monad(struct
   *)
 end)
 
-(*
-# EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
-- : int EL.result = [Failure.Error "bye"; Failure.Success 30]
-# LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
-- : int LE.result = Failure.Error "bye"
-# EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
-Exception: Failure "bye".
-# LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
-Exception: Failure "bye".
-
-# ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
-- : int Failure.error * S.store = (Failure.Error "bye", 1)
-# SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
-- : (int * S.store) Failure.result = Failure.Error "bye"
-# ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
-Exception: Failure "bye".
-# SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
-Exception: Failure "bye".
- *)
-
 
 (* must be parameterized on (struct type env = ... end) *)
 module Reader_monad(Env : sig type env end) : sig
@@ -514,6 +539,7 @@ module Reader_monad(Env : sig type env end) : sig
   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
   val ask : ('x,env) m
   val asks : (env -> 'a) -> ('x,'a) m
+  (* lookup i == `fun e -> e i` would assume env is a functional type *)
   val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
   (* ReaderT transformer *)
   module T : functor (Wrapped : Monad.S) -> sig
@@ -524,6 +550,7 @@ module Reader_monad(Env : sig type env end) : sig
     val ask : ('x,env) m
     val asks : (env -> 'a) -> ('x,'a) m
     val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
+    val expose : ('x,'a) m -> env -> ('x,'a) Wrapped.m
   end
 end = struct
   type env = Env.env
@@ -535,7 +562,7 @@ end = struct
     let bind u f = fun e -> let a = u e in let u' = f a in u' e
     let run u = fun e -> u e
     let run_exn = run
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -543,24 +570,26 @@ end = struct
   let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
   let local modifier u = fun e -> u (modifier e)
   module T(Wrapped : Monad.S) = struct
-    module Trans = struct
+    module BaseT = struct
       module Wrapped = Wrapped
       type ('x,'a) m = env -> ('x,'a) Wrapped.m
       type ('x,'a) result = env -> ('x,'a) Wrapped.result
       type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
       let elevate w = fun e -> w
-      let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
+      let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e)
       let run u = fun e -> Wrapped.run (u e)
       let run_exn u = fun e -> Wrapped.run_exn (u e)
-      let plus u v = fun s -> Wrapped.plus (u s) (v s)
-      let zero () = elevate (Wrapped.zero ())
+      (* satisfies Distrib *)
+      let plus u v = fun e -> Wrapped.plus (u e) (v e)
+      let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
     end
-    include Monad.MakeT(Trans)
-    let ask = fun e -> Wrapped.unit e
+    include Monad.MakeT(BaseT)
+    let ask = Wrapped.unit
     let local modifier u = fun e -> u (modifier e)
     let asks selector = ask >>= (fun e ->
       try unit (selector e)
       with Not_found -> fun e -> Wrapped.zero ())
+    let expose u = u
   end
 end
 
@@ -586,6 +615,8 @@ module State_monad(Store : sig type store end) : sig
     val gets : (store -> 'a) -> ('x,'a) m
     val put : store -> ('x,unit) m
     val puts : (store -> store) -> ('x,unit) m
+    (* val passthru : ('x,'a) m -> (('x,'a * store) Wrapped.result * store -> 'b) -> ('x,'b) m *)
+    val expose : ('x,'a) m -> store -> ('x,'a * store) Wrapped.m
   end
 end = struct
   type store = Store.store
@@ -597,7 +628,7 @@ end = struct
     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
     let run u = fun s -> (u s)
     let run_exn u = fun s -> fst (u s)
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -606,7 +637,7 @@ end = struct
   let put s = fun _ -> ((), s)
   let puts modifier = fun s -> ((), modifier s)
   module T(Wrapped : Monad.S) = struct
-    module Trans = struct
+    module BaseT = struct
       module Wrapped = Wrapped
       type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
       type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
@@ -619,19 +650,23 @@ end = struct
       let run_exn u = fun s ->
         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
         in Wrapped.run_exn w
+      (* satisfies Distrib *)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
-      let zero () = elevate (Wrapped.zero ())
+      let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
     end
-    include Monad.MakeT(Trans)
+    include Monad.MakeT(BaseT)
     let get = fun s -> Wrapped.unit (s, s)
     let gets viewer = fun s ->
       try Wrapped.unit (viewer s, s)
       with Not_found -> Wrapped.zero ()
     let put s = fun _ -> Wrapped.unit ((), s)
     let puts modifier = fun s -> Wrapped.unit ((), modifier s)
+    (* let passthru u f = fun s -> Wrapped.unit (f (Wrapped.run (u s), s), s) *)
+    let expose u = u
   end
 end
 
+
 (* State monad with different interface (structured store) *)
 module Ref_monad(V : sig
   type value
@@ -674,7 +709,7 @@ end = struct
     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
     let run u = fst (u empty)
     let run_exn = run
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -682,7 +717,7 @@ end = struct
   let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
   let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
   module T(Wrapped : Monad.S) = struct
-    module Trans = struct
+    module BaseT = struct
       module Wrapped = Wrapped
       type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
       type ('x,'a) result = ('x,'a) Wrapped.result
@@ -697,10 +732,11 @@ end = struct
       let run_exn u =
         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
         in Wrapped.run_exn w
+      (* satisfies Distrib *)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
-      let zero () = elevate (Wrapped.zero ())
+      let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
     end
-    include Monad.MakeT(Trans)
+    include Monad.MakeT(BaseT)
     let newref value = fun s -> Wrapped.unit (alloc value s)
     let deref key = fun s -> Wrapped.unit (read key s, s)
     let change key value = fun s -> Wrapped.unit ((), write key value s)
@@ -724,6 +760,17 @@ end) : sig
   val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
   (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
   val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
+  (* WriterT transformer *)
+  module T : functor (Wrapped : Monad.S) -> sig
+    type ('x,'a) result = ('x,'a * log) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
+    include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val tell : log -> ('x,unit) m
+    val listen : ('x,'a) m -> ('x,'a * log) m
+    val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
+    val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
+  end
 end = struct
   type log = Log.log
   module Base = struct
@@ -731,10 +778,10 @@ end = struct
     type ('x,'a) result = 'a * log
     type ('x,'a) result_exn = 'a * log
     let unit a = (a, Log.zero)
-    let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
+    let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w')
     let run u = u
     let run_exn = run
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -743,6 +790,31 @@ end = struct
   let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
   let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
   let censor f u = pass (u >>= fun a -> unit (a, f))
+  module T(Wrapped : Monad.S) = struct
+    module BaseT = struct
+      module Wrapped = Wrapped
+      type ('x,'a) m = ('x,'a * log) Wrapped.m
+      type ('x,'a) result = ('x,'a * log) Wrapped.result
+      type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
+      let elevate w =
+        Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero))
+      let bind u f =
+        Wrapped.bind u (fun (a, w) ->
+        Wrapped.bind (f a) (fun (b, w') ->
+        Wrapped.unit (b, Log.plus w w')))
+      let zero () = elevate (Wrapped.zero ())
+      let plus u v = Wrapped.plus u v
+      let run u = Wrapped.run u
+      let run_exn u = Wrapped.run_exn u
+    end
+    include Monad.MakeT(BaseT)
+    let tell entries = Wrapped.unit ((), entries)
+    let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w))
+    let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w))
+    (* rest are derived in same way as before *)
+    let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w)
+    let censor f u = pass (u >>= fun a -> unit (a, f))
+  end
 end
 
 (* pre-define simple Writer *)
@@ -766,6 +838,7 @@ module Writer2 = struct
 end
 
 
+(* TODO needs a T *)
 module IO_monad : sig
   (* declare additional operation, while still hiding implementation of type m *)
   type ('x,'a) result = 'a
@@ -787,7 +860,7 @@ end = struct
        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
     let run a = let () = a.run () in a.value
     let run_exn = run
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -812,6 +885,16 @@ module Continuation_monad : sig
   (* val abort : ('a,'a) m -> ('a,'b) m *)
   val abort : 'a -> ('a,'b) m
   val run0 : ('a,'a) m -> 'a
+  (* ContinuationT transformer *)
+  module T : functor (Wrapped : Monad.S) -> sig
+    type ('r,'a) m
+    type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
+    type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
+    include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
+    (* TODO: reset,shift,abort,run0 *)
+  end
 end = struct
   let id = fun i -> i
   module Base = struct
@@ -823,7 +906,7 @@ end = struct
     let bind u f = (fun k -> (u) (fun a -> (f a) k))
     let run u k = (u) k
     let run_exn = run
-    let zero () = Util.undefined
+    let zero () = Util.undef
     let plus u v = u
   end
   include Monad.Make(Base)
@@ -850,6 +933,24 @@ end = struct
   (* let abort a = shift (fun _ -> a) *)
   let abort a = shift (fun _ -> unit a)
   let run0 (u : ('a,'a) m) = (u) id
+  module T(Wrapped : Monad.S) = struct
+    module BaseT = struct
+      module Wrapped = Wrapped
+      type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m
+      type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
+      type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
+      let elevate w = fun k -> Wrapped.bind w k
+      let bind u f = fun k -> u (fun a -> f a k)
+      let run u k = Wrapped.run (u k)
+      let run_exn u k = Wrapped.run_exn (u k)
+      let zero () = Util.undef
+      let plus u v = u
+    end
+    include Monad.MakeT(BaseT)
+    let callcc f = (fun k ->
+      let usek a = (fun _ -> k a)
+      in (f usek) k)
+  end
 end
 
 
@@ -874,53 +975,17 @@ end
  *   >>= fun x -> unit (x, 0)
  *   in run u)
  *
- *
- * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
- * let example1 () : int =
- *   Continuation_monad.(let v = reset (
- *       let u = shift (fun k -> unit (10 + 1))
- *       in u >>= fun x -> unit (100 + x)
- *     ) in let w = v >>= fun x -> unit (1000 + x)
- *     in run w)
- *
- * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
- * let example2 () =
- *   Continuation_monad.(let v = reset (
- *       let u = shift (fun k -> k (10 :: [1]))
- *       in u >>= fun x -> unit (100 :: x)
- *     ) in let w = v >>= fun x -> unit (1000 :: x)
- *     in run w)
- *
- * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
- * let example3 () =
- *   Continuation_monad.(let v = reset (
- *       let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
- *       in u >>= fun x -> unit (100 :: x)
- *     ) in let w = v >>= fun x -> unit (1000 :: x)
- *     in run w)
- *
- * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
- * (* not sure if this example can be typed without a sum-type *)
- *
- * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
- * let example5 () : int =
- *   Continuation_monad.(let v = reset (
- *       let u = shift (fun k -> k 1 >>= fun x -> k x)
- *       in u >>= fun x -> unit (10 + x)
- *     ) in let w = v >>= fun x -> unit (100 + x)
- *     in run w)
- *
  *)
 
 
-module Leaf_monad : sig
+module Tree_monad : sig
   (* We implement the type as `'a tree option` because it has a natural`plus`,
    * and the rest of the library expects that `plus` and `zero` will come together. *)
   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
   type ('x,'a) result = 'a tree option
   type ('x,'a) result_exn = 'a tree
   include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
-  (* LeafT transformer *)
+  (* TreeT transformer *)
   module T : functor (Wrapped : Monad.S) -> sig
     type ('x,'a) result = ('x,'a tree option) Wrapped.result
     type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
@@ -929,6 +994,7 @@ module Leaf_monad : sig
     (* note that second argument is an 'a tree?, not the more abstract 'a m *)
     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
+    val expose : ('x,'a) m -> ('x,'a tree option) Wrapped.m
   end
 end = struct
   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
@@ -947,6 +1013,7 @@ end = struct
     type ('x,'a) result_exn = 'a tree
     let unit a = Some (Leaf a)
     let zero () = None
+    (* satisfies Distrib *)
     let plus u v = match (u, v) with
       | None, _ -> v
       | _, None -> u
@@ -962,10 +1029,8 @@ end = struct
       | Some us -> us
   end
   include Monad.Make(Base)
-  let base_plus = plus
-  let base_lift = lift
   module T(Wrapped : Monad.S) = struct
-    module Trans = struct
+    module BaseT = struct
       include Monad.MakeT(struct
         module Wrapped = Wrapped
         type ('x,'a) m = ('x,'a tree option) Wrapped.m
@@ -975,226 +1040,22 @@ end = struct
         let plus u v =
           Wrapped.bind u (fun us ->
           Wrapped.bind v (fun vs ->
-          Wrapped.unit (base_plus us vs)))
+          Wrapped.unit (Base.plus us vs)))
         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
         let run u = Wrapped.run u
         let run_exn u =
             let w = Wrapped.bind u (fun t -> match t with
-              | None -> failwith "no values"
-              | Some ts -> Wrapped.unit ts)
-            in Wrapped.run_exn w
+              | None -> Wrapped.zero ()
+              | Some ts -> Wrapped.unit ts
+            in Wrapped.run_exn w
       end)
     end
-    include Trans
-    (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
+    include BaseT
     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
+    let expose u = u
   end
-end
-
-
-module L = List_monad;;
-module R = Reader_monad(struct type env = int -> int end);;
-module S = State_monad(struct type store = int end);;
-module T = Leaf_monad;;
-module LR = L.T(R);;
-module LS = L.T(S);;
-module TL = T.T(L);;
-module TR = T.T(R);;
-module TS = T.T(S);;
-module C = Continuation_monad
-module TC = T.T(C);;
-
-
-print_endline "=== test Leaf(...).distribute ==================";;
-
-let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
-
-let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
-TS.run ts 0;;
-(*
-- : int T.tree option * S.store =
-(Some
-  (T.Node
-    (T.Node (T.Leaf 2, T.Leaf 3),
-     T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))),
- 5)
-*)
-
-let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;;
-TS.run_exn ts2 0;;
-(*
-- : (int * S.store) T.tree option * S.store =
-(Some
-  (T.Node
-    (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)),
-     T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))),
- 5)
-*)
-
-let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;;
-TR.run_exn tr (fun i -> i+i);;
-(*
-- : int T.tree option =
-Some
- (T.Node
-   (T.Node (T.Leaf 4, T.Leaf 6),
-    T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22))))
-*)
-
-let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;;
-TL.run_exn tl;;
-(*
-- : (int * int) TL.result =
-[Some
-  (T.Node
-    (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)),
-     T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))]
-*)
-
-let l2 = [1;2;3;4;5];;
-let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));;
-
-LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
-(* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *)
-
-TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);;
-(*
-int T.tree option =
-Some
- (T.Node
-   (T.Node (T.Leaf 10, T.Leaf 11),
-    T.Node
-     (T.Node
-       (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)),
-        T.Node (T.Leaf 40, T.Leaf 41)),
-      T.Node (T.Leaf 50, T.Leaf 51))))
- *)
-
-LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;;
-(*
-- : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
-*)
-
-print_endline "=== test Leaf(Continuation).distribute ==================";;
-
-let id : 'z. 'z -> 'z = fun x -> x
-
-let example n : (int * int) =
-  Continuation_monad.(let u = callcc (fun k ->
-      (if n < 0 then k 0 else unit [n + 100])
-      (* all of the following is skipped by k 0; the end type int is k's input type *)
-      >>= fun [x] -> unit (x + 1)
-  )
-  (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *)
-  >>= fun x -> unit (x, 0)
-  in run0 u)
-
-
-(* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *)
-let example1 () : int =
-  Continuation_monad.(let v = reset (
-      let u = shift (fun k -> unit (10 + 1))
-      in u >>= fun x -> unit (100 + x)
-    ) in let w = v >>= fun x -> unit (1000 + x)
-    in run0 w)
-
-(* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *)
-let example2 () =
-  Continuation_monad.(let v = reset (
-      let u = shift (fun k -> k (10 :: [1]))
-      in u >>= fun x -> unit (100 :: x)
-    ) in let w = v >>= fun x -> unit (1000 :: x)
-    in run0 w)
-
-(* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *)
-let example3 () =
-  Continuation_monad.(let v = reset (
-      let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x))
-      in u >>= fun x -> unit (100 :: x)
-    ) in let w = v >>= fun x -> unit (1000 :: x)
-    in run0 w)
-
-(* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *)
-(* not sure if this example can be typed without a sum-type *)
-
-(* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
-let example5 () : int =
-  Continuation_monad.(let v = reset (
-      let u = shift (fun k -> k 1 >>= k)
-      in u >>= fun x -> unit (10 + x)
-    ) in let w = v >>= fun x -> unit (100 + x)
-    in run0 w)
-
-;;
-
-print_endline "=== test bare Continuation ============";;
-
-(1011, 1111, 1111, 121);;
-(example1(), example2(), example3(), example5());;
-((111,0), (0,0));;
-(example ~+10, example ~-10);;
-
-let testc df ic =
-    C.run_exn TC.(run (distribute df t1)) ic;;
-
-
-(*
-(* do nothing *)
-let initial_continuation = fun t -> t in
-TreeCont.monadize t1 Continuation_monad.unit initial_continuation;;
-*)
-testc (C.unit) id;;
-
-(*
-(* count leaves, using continuation *)
-let initial_continuation = fun t -> 0 in
-TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;;
-*)
-
-testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);;
-
-(*
-(* convert tree to list of leaves *)
-let initial_continuation = fun t -> [] in
-TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;;
-*)
-
-testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));;
-
-(*
-(* square each leaf using continuation *)
-let initial_continuation = fun t -> t in
-TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;;
-*)
-
-testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);;
-
-
-(*
-(* replace leaves with list, using continuation *)
-let initial_continuation = fun t -> t in
-TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
-*)
-
-testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
-
-print_endline "=== pa_monad's Continuation Tests ============";;
-
-(1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
-(2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
-(3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
-(4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
-(5, 27 = C.(run0 (
-              let c = reset(abort 5 >>= fun y -> unit (y+6))
-              in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
-
-(7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
-
-(8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
-
-(12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
 
+end;;
 
-(0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;