tweak monads-lib, migrate to T2
[lambda.git] / code / monads.ml
index 0dab871..4dbd851 100644 (file)
@@ -332,11 +332,11 @@ module Identity_monad : sig
 end = struct
   module Base = struct
     type 'a m = 'a
+    type 'a result = 'a
+    type 'a result_exn = 'a
     let unit a = a
     let bind a f = f a
-    type 'a result = 'a
     let run a = a
-    type 'a result_exn = 'a
     let run_exn a = a
   end
   include Monad.Make(Base)
@@ -367,11 +367,11 @@ module Maybe_monad : sig
 end = struct
   module Base = struct
    type 'a m = 'a option
+   type 'a result = 'a option
+   type 'a result_exn = 'a
    let unit a = Some a
    let bind u f = match u with Some a -> f a | None -> None
-   type 'a result = 'a option
    let run u = u
-   type 'a result_exn = 'a
    let run_exn u = match u with
      | Some a -> a
      | None -> failwith "no value"
@@ -455,7 +455,6 @@ module List_monad : sig
     val select : 'a m -> ('a * 'a m) m
 *)
   end
-(*
   module T2 : functor (Wrapped : Monad.S2) -> sig
     type ('x,'a) result = ('x,'a list) Wrapped.result
     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
@@ -464,15 +463,14 @@ module List_monad : sig
     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
   end
- *)
 end = struct
   module Base = struct
    type 'a m = 'a list
+   type 'a result = 'a list
+   type 'a result_exn = 'a
    let unit a = [a]
    let bind u f = Util.concat_map f u
-   type 'a result = 'a list
    let run u = u
-   type 'a result_exn = 'a
    let run_exn u = match u with
      | [] -> failwith "no values"
      | [a] -> a
@@ -500,11 +498,6 @@ end = struct
   let base_plus = plus
   module T(Wrapped : Monad.S) = struct
     module Trans = struct
-      let zero () = Wrapped.unit []
-      let plus u v =
-        Wrapped.bind u (fun us ->
-        Wrapped.bind v (fun vs ->
-        Wrapped.unit (base_plus us vs)))
       (* Wrapped.sequence ms  ===  
            let plus1 u v =
              Wrapped.bind u (fun x ->
@@ -531,6 +524,11 @@ end = struct
             | many -> failwith "multiple values"
           ) in Wrapped.run_exn w
       end)
+      let zero () = Wrapped.unit []
+      let plus u v =
+        Wrapped.bind u (fun us ->
+        Wrapped.bind v (fun vs ->
+        Wrapped.unit (base_plus us vs)))
     end
     include Trans
     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
@@ -539,6 +537,37 @@ end = struct
     let select : 'a m -> ('a * 'a m) m
 *)
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      let distribute f alist = Wrapped.sequence (Util.map f alist)
+      include Monad.MakeT2(struct
+        module Wrapped = Wrapped
+        type ('x,'a) m = ('x,'a list) Wrapped.m
+        type ('x,'a) result = ('x,'a list) Wrapped.result
+        type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+        (* code repetition, ugh *)
+        let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
+        let bind u f =
+          Wrapped.bind u (fun ts ->
+          Wrapped.bind (distribute f ts) (fun tts ->
+          Wrapped.unit (Util.concat tts)))
+        let run u = Wrapped.run u
+        let run_exn u =
+          let w = Wrapped.bind u (fun ts -> match ts with
+            | [] -> failwith "no values"
+            | [a] -> Wrapped.unit a
+            | many -> failwith "multiple values"
+          ) in Wrapped.run_exn w
+      end)
+      let zero () = Wrapped.unit []
+      let plus u v =
+        Wrapped.bind u (fun us ->
+        Wrapped.bind v (fun vs ->
+        Wrapped.unit (base_plus us vs)))
+    end
+    include Trans
+    include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+  end
 end
 
 
@@ -569,21 +598,29 @@ end) : sig
     val throw : err -> 'a m
     val catch : 'a m -> (err -> 'a m) -> 'a m
   end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = ('x,'a) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val throw : err -> ('x,'a) m
+    val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
+  end
 end = struct
   type err = Err.err
   type 'a error = Error of err | Success of 'a
   module Base = struct
     type 'a m = 'a error
+    type 'a result = 'a
+    type 'a result_exn = 'a
     let unit a = Success a
     let bind u f = match u with
       | Success a -> f a
       | Error e -> Error e (* input and output may be of different 'a types *)
-    type 'a result = 'a
     (* TODO: should run refrain from failing? *)
     let run u = match u with
       | Success a -> a
       | Error e -> raise (Err.Exc e)
-    type 'a result_exn = 'a
     let run_exn = run
     (*
     let zero () = Error Err.zero
@@ -606,12 +643,13 @@ end = struct
   module T(Wrapped : Monad.S) = struct
     module Trans = struct
       module Wrapped = Wrapped
-      type 'a m = 'a Base.m Wrapped.m
+      type 'a m = 'a error Wrapped.m
+      type 'a result = 'a Wrapped.result
+      type 'a result_exn = 'a Wrapped.result_exn
       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
       let bind u f = Wrapped.bind u (fun t -> match t with
         | Success a -> f a
         | Error e -> Wrapped.unit (Error e))
-      type 'a result = 'a Wrapped.result
       (* TODO: should run refrain from failing? *)
       let run u =
         let w = Wrapped.bind u (fun t -> match t with
@@ -619,7 +657,6 @@ end = struct
           (* | _ -> Wrapped.fail () *)
           | Error e -> raise (Err.Exc e))
         in Wrapped.run w
-      type 'a result_exn = 'a Wrapped.result_exn
       let run_exn u =
         let w = Wrapped.bind u (fun t -> match t with
           | Success a -> Wrapped.unit a
@@ -633,6 +670,34 @@ end = struct
       | Success _ -> Wrapped.unit t
       | Error e -> handler e)
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      module Wrapped = Wrapped
+      type ('x,'a) m = ('x,'a error) Wrapped.m
+      type ('x,'a) result = ('x,'a) Wrapped.result
+      type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+      (* code repetition, ugh *)
+      let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
+      let bind u f = Wrapped.bind u (fun t -> match t with
+        | Success a -> f a
+        | Error e -> Wrapped.unit (Error e))
+      let run u =
+        let w = Wrapped.bind u (fun t -> match t with
+          | Success a -> Wrapped.unit a
+          | Error e -> raise (Err.Exc e))
+        in Wrapped.run w
+      let run_exn u =
+        let w = Wrapped.bind u (fun t -> match t with
+          | Success a -> Wrapped.unit a
+          | Error e -> raise (Err.Exc e))
+        in Wrapped.run_exn w
+    end
+    include Monad.MakeT2(Trans)
+    let throw e = Wrapped.unit (Error e)
+    let catch u handler = Wrapped.bind u (fun t -> match t with
+      | Success _ -> Wrapped.unit t
+      | Error e -> handler e)
+  end
 end
 
 (* pre-define common instance of Error_monad *)
@@ -670,15 +735,28 @@ module Reader_monad(Env : sig type env end) : sig
     include module type of T(Wrapped)
     include Monad.PLUS with type 'a m := 'a m
   end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = env -> ('x,'a) Wrapped.result
+    type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val ask : ('x,env) m
+    val asks : (env -> 'a) -> ('x,'a) m
+    val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
+  end
+  module TP2 : functor (Wrapped : Monad.P2) -> sig
+    include module type of T2(Wrapped)
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
+  end
 end = struct
   type env = Env.env
   module Base = struct
     type 'a m = env -> 'a
+    type 'a result = env -> 'a
+    type 'a result_exn = env -> 'a
     let unit a = fun e -> a
     let bind u f = fun e -> let a = u e in let u' = f a in u' e
-    type 'a result = env -> 'a
     let run u = fun e -> u e
-    type 'a result_exn = env -> 'a
     let run_exn = run
   end
   include Monad.Make(Base)
@@ -689,11 +767,11 @@ end = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = env -> 'a Wrapped.m
+      type 'a result = env -> 'a Wrapped.result
+      type 'a result_exn = env -> 'a Wrapped.result_exn
       let elevate w = fun e -> w
       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
-      type 'a result = env -> 'a Wrapped.result
       let run u = fun e -> Wrapped.run (u e)
-      type 'a result_exn = env -> 'a Wrapped.result_exn
       let run_exn u = fun e -> Wrapped.run_exn (u e)
     end
     include Monad.MakeT(Trans)
@@ -713,6 +791,36 @@ end = struct
     include TransP
     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      module Wrapped = Wrapped
+      type ('x,'a) m = env -> ('x,'a) Wrapped.m
+      type ('x,'a) result = env -> ('x,'a) Wrapped.result
+      type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
+      (* code repetition, ugh *)
+      let elevate w = fun e -> w
+      let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
+      let run u = fun e -> Wrapped.run (u e)
+      let run_exn u = fun e -> Wrapped.run_exn (u e)
+    end
+    include Monad.MakeT2(Trans)
+    let ask = fun e -> Wrapped.unit e
+    let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
+    let local modifier u = fun e -> u (modifier e)
+  end
+  module TP2(Wrapped : Monad.P2) = struct
+    module TransP = struct
+      (* code repetition, ugh *)
+      include T2(Wrapped)
+      let plus u v = fun s -> Wrapped.plus (u s) (v s)
+      let zero () = elevate (Wrapped.zero ())
+      let asks selector = ask >>= (fun e ->
+        try unit (selector e)
+        with Not_found -> fun e -> Wrapped.zero ())
+    end
+    include TransP
+    include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+  end
 end
 
 
@@ -743,15 +851,29 @@ module State_monad(Store : sig type store end) : sig
     include module type of T(Wrapped)
     include Monad.PLUS with type 'a m := 'a m
   end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
+    type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val get : ('x,store) m
+    val gets : (store -> 'a) -> ('x,'a) m
+    val put : store -> ('x,unit) m
+    val puts : (store -> store) -> ('x,unit) m
+  end
+  module TP2 : functor (Wrapped : Monad.P2) -> sig
+    include module type of T2(Wrapped)
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
+  end
 end = struct
   type store = Store.store
   module Base = struct
     type 'a m = store -> 'a * store
+    type 'a result = store -> 'a * store
+    type 'a result_exn = store -> 'a
     let unit a = fun s -> (a, s)
     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
-    type 'a result = store -> 'a * store
     let run u = fun s -> (u s)
-    type 'a result_exn = store -> 'a
     let run_exn u = fun s -> fst (u s)
   end
   include Monad.Make(Base)
@@ -763,13 +885,13 @@ end = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = store -> ('a * store) Wrapped.m
+      type 'a result = store -> ('a * store) Wrapped.result
+      type 'a result_exn = store -> 'a Wrapped.result_exn
       let elevate w = fun s ->
         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
       let bind u f = fun s ->
         Wrapped.bind (u s) (fun (a, s') -> f a s')
-      type 'a result = store -> ('a * store) Wrapped.result
       let run u = fun s -> Wrapped.run (u s)
-      type 'a result_exn = store -> 'a Wrapped.result_exn
       let run_exn u = fun s ->
         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
         in Wrapped.run_exn w
@@ -792,6 +914,40 @@ end = struct
     include TransP
     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      module Wrapped = Wrapped
+      type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
+      type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
+      type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
+      (* code repetition, ugh *)
+      let elevate w = fun s ->
+        Wrapped.bind w (fun a -> Wrapped.unit (a, s))
+      let bind u f = fun s ->
+        Wrapped.bind (u s) (fun (a, s') -> f a s')
+      let run u = fun s -> Wrapped.run (u s)
+      let run_exn u = fun s ->
+        let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
+        in Wrapped.run_exn w
+    end
+    include Monad.MakeT2(Trans)
+    let get = fun s -> Wrapped.unit (s, s)
+    let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
+    let put s = fun _ -> Wrapped.unit ((), s)
+    let puts modifier = fun s -> Wrapped.unit ((), modifier s)
+  end
+  module TP2(Wrapped : Monad.P2) = struct
+    module TransP = struct
+      include T2(Wrapped)
+      let plus u v = fun s -> Wrapped.plus (u s) (v s)
+      let zero () = elevate (Wrapped.zero ())
+    end
+    let gets viewer = fun s ->
+      try Wrapped.unit (viewer s, s)
+      with Not_found -> Wrapped.zero ()
+    include TransP
+    include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+  end
 end
 
 (* State monad with different interface (structured store) *)
@@ -821,6 +977,19 @@ end) : sig
     include module type of T(Wrapped)
     include Monad.PLUS with type 'a m := 'a m
   end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = ('x,'a) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val newref : value -> ('x,ref) m
+    val deref : ref -> ('x,value) m
+    val change : ref -> value -> ('x,unit) m
+  end
+  module TP2 : functor (Wrapped : Monad.P2) -> sig
+    include module type of T2(Wrapped)
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
+  end
 end = struct
   type ref = int
   type value = V.value
@@ -835,11 +1004,11 @@ end = struct
     { next = d.next; tree = D.add key value d.tree }
   module Base = struct
     type 'a m = dict -> 'a * dict
+    type 'a result = 'a
+    type 'a result_exn = 'a
     let unit a = fun s -> (a, s)
     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
-    type 'a result = 'a
     let run u = fst (u empty)
-    type 'a result_exn = 'a
     let run_exn = run
   end
   include Monad.Make(Base)
@@ -850,15 +1019,15 @@ end = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = dict -> ('a * dict) Wrapped.m
+      type 'a result = 'a Wrapped.result
+      type 'a result_exn = 'a Wrapped.result_exn
       let elevate w = fun s ->
         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
       let bind u f = fun s ->
         Wrapped.bind (u s) (fun (a, s') -> f a s')
-      type 'a result = 'a Wrapped.result
       let run u =
         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
         in Wrapped.run w
-      type 'a result_exn = 'a Wrapped.result_exn
       let run_exn u =
         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
         in Wrapped.run_exn w
@@ -877,6 +1046,38 @@ end = struct
     include TransP
     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      module Wrapped = Wrapped
+      type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
+      type ('x,'a) result = ('x,'a) Wrapped.result
+      type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+      (* code repetition, ugh *)
+      let elevate w = fun s ->
+        Wrapped.bind w (fun a -> Wrapped.unit (a, s))
+      let bind u f = fun s ->
+        Wrapped.bind (u s) (fun (a, s') -> f a s')
+      let run u =
+        let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
+        in Wrapped.run w
+      let run_exn u =
+        let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
+        in Wrapped.run_exn w
+    end
+    include Monad.MakeT2(Trans)
+    let newref value = fun s -> Wrapped.unit (alloc value s)
+    let deref key = fun s -> Wrapped.unit (read key s, s)
+    let change key value = fun s -> Wrapped.unit ((), write key value s)
+  end
+  module TP2(Wrapped : Monad.P2) = struct
+    module TransP = struct
+      include T2(Wrapped)
+      let plus u v = fun s -> Wrapped.plus (u s) (v s)
+      let zero () = elevate (Wrapped.zero ())
+    end
+    include TransP
+    include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+  end
 end
 
 
@@ -900,11 +1101,11 @@ end = struct
   type log = Log.log
   module Base = struct
     type 'a m = 'a * log
+    type 'a result = 'a * log
+    type 'a result_exn = 'a * log
     let unit a = (a, Log.zero)
     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
-    type 'a result = 'a * log
     let run u = u
-    type 'a result_exn = 'a * log
     let run_exn = run
   end
   include Monad.Make(Base)
@@ -949,13 +1150,13 @@ module IO_monad : sig
 end = struct
   module Base = struct
     type 'a m = { run : unit -> unit; value : 'a }
+    type 'a result = 'a
+    type 'a result_exn = 'a
     let unit a = { run = (fun () -> ()); value = a }
     let bind (a : 'a m) (f: 'a -> 'b m) : 'b m =
      let fres = f a.value in
        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
-    type 'a result = 'a
     let run a = let () = a.run () in a.value
-    type 'a result_exn = 'a
     let run_exn = run
   end
   include Monad.Make(Base)
@@ -967,6 +1168,7 @@ end = struct
   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
 end
 
+(*
 module Continuation_monad : sig
   (* expose only the implementation of type `('r,'a) result` *)
   type 'a m
@@ -990,6 +1192,8 @@ end = struct
   module Base = struct
     (* 'r is result type of whole computation *)
     type 'a m = { cont : 'r. ('a -> 'r) -> 'r }
+    type 'a result = 'a m
+    type 'a result_exn = 'a m
     let unit a =
       let cont : 'r. ('a -> 'r) -> 'r =
         fun k -> k a
@@ -998,9 +1202,7 @@ end = struct
       let cont : 'r. ('a -> 'r) -> 'r =
         fun k -> u.cont (fun a -> (f a).cont k)
       in { cont }
-    type 'a result = 'a m
     let run (u : 'a m) : 'a result = u
-    type 'a result_exn = 'a m
     let run_exn (u : 'a m) : 'a result_exn = u
     let callcc f =
       let cont : 'r. ('a -> 'r) -> 'r =
@@ -1029,35 +1231,36 @@ end = struct
   let runk = Base.runk
   let run0 = Base.run0
 end
+ *)
 
-(*
 (* This two-type parameter version works without Obj.magic *)
-
-module Continuation_monad2 : sig
+module Continuation_monad : sig
   (* expose only the implementation of type `('r,'a) result` *)
-  type ('r,'a) result = ('a -> 'r) -> 'r
+  type ('r,'a) m
+  type ('r,'a) result = ('r,'a) m
   type ('r,'a) result_exn = ('a -> 'r) -> 'r
-  include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn
+  include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
   val reset : ('a,'a) m -> ('r,'a) m
   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
-
+  val abort : ('a,'a) m -> ('a,'b) m
+  val run0 : ('a,'a) m -> 'a
 end = struct
   let id = fun i -> i
   module Base = struct
     (* 'r is result type of whole computation *)
     type ('r,'a) m = ('a -> 'r) -> 'r
-    let unit a = fun k -> k a
-    let bind u f = fun k -> u (fun a -> (f a) k)
     type ('r,'a) result = ('a -> 'r) -> 'r
-    let run u = u
-    type ('r,'a) result_exn = ('a -> 'r) -> 'r
+    type ('r,'a) result_exn = ('r,'a) result
+    let unit a = (fun k -> k a)
+    let bind u f = (fun k -> (u) (fun a -> (f a) k))
+    let run u k = (u) k
     let run_exn = run
   end
   include Monad.Make2(Base)
-  let callcc f = fun k ->
-    let usek a = fun _ -> k a
-    in f usek k
+  let callcc f = (fun k ->
+    let usek a = (fun _ -> k a)
+    in (f usek) k)
   (*
   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
@@ -1065,10 +1268,11 @@ end = struct
   let throw k a = fun _ -> k a
   *)
   (* from http://www.haskell.org/haskellwiki/MonadCont_done_right *)
-  let reset u = unit (u id)
-  let shift u = fun k -> u (fun a -> unit (k a)) id
+  let reset u = unit ((u) id)
+  let shift f = (fun k -> (f (fun a -> unit (k a))) id)
+  let abort a = shift (fun _ -> a)
+  let run0 (u : ('a,'a) m) = (u) id
 end
- *)
 
 
 (*
@@ -1150,6 +1354,14 @@ module Leaf_monad : sig
     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
     val distribute : ('a -> 'b Wrapped.m) -> 'a tree option -> 'b m
   end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = ('x,'a tree option) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
+  end
 end = struct
   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
   (* uses supplied plus and zero to copy t to its image under f *)
@@ -1163,6 +1375,8 @@ end = struct
                    ) in loop ts
   module Base = struct
     type 'a m = 'a tree option
+    type 'a result = 'a tree option
+    type 'a result_exn = 'a tree
     let unit a = Some (Leaf a)
     let zero () = None
     let plus u v = match (u, v) with
@@ -1170,9 +1384,7 @@ end = struct
       | _, None -> u
       | Some us, Some vs -> Some (Node (us, vs))
     let bind u f = mapT f u zero plus
-    type 'a result = 'a tree option
     let run u = u
-    type 'a result_exn = 'a tree
     let run_exn u = match u with
       | None -> failwith "no values"
       (*
@@ -1194,12 +1406,12 @@ end = struct
         Wrapped.unit (base_plus us vs)))
       include Monad.MakeT(struct
         module Wrapped = Wrapped
-        type 'a m = 'a Base.m Wrapped.m
+        type 'a m = 'a tree option Wrapped.m
+        type 'a result = 'a tree option Wrapped.result
+        type 'a result_exn = 'a tree Wrapped.result_exn
         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
-        type 'a result = 'a tree option Wrapped.result
         let run u = Wrapped.run u
-        type 'a result_exn = 'a tree Wrapped.result_exn
         let run_exn u =
             let w = Wrapped.bind u (fun t -> match t with
               | None -> failwith "no values"
@@ -1212,6 +1424,33 @@ end = struct
     (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      let zero () = Wrapped.unit None
+      let plus u v =
+        Wrapped.bind u (fun us ->
+        Wrapped.bind v (fun vs ->
+        Wrapped.unit (base_plus us vs)))
+      include Monad.MakeT2(struct
+        module Wrapped = Wrapped
+        type ('x,'a) m = ('x,'a tree option) Wrapped.m
+        type ('x,'a) result = ('x,'a tree option) Wrapped.result
+        type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
+        (* code repetition, ugh *)
+        let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
+        let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
+        let run u = Wrapped.run u
+        let run_exn u =
+            let w = Wrapped.bind u (fun t -> match t with
+              | None -> failwith "no values"
+              | Some ts -> Wrapped.unit ts)
+            in Wrapped.run_exn w
+      end)
+    end
+    include Trans
+    include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+    let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
+  end
 end
 
 
@@ -1224,10 +1463,14 @@ module LS = L.T(S);;
 module TL = T.T(L);;
 module TR = T.T(R);;
 module TS = T.T(S);;
+module C = Continuation_monad
+module TC = T.T2(C);;
+
+
+print_endline "================================================";;
 
 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
 
-(*
 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
 TS.run ts 0;;
 (*
@@ -1294,7 +1537,6 @@ LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts
 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
 *)
 
-*)
 
 let id : 'z. 'z -> 'z = fun x -> x
 
@@ -1352,11 +1594,8 @@ let example5 () : int =
 ((111,0), (0,0));;
 (example ~+10, example ~-10);;
 
-module C = Continuation_monad
-module TC = T.T(C)
-
 let testc df ic =
-    C.runk TC.(run_exn (distribute df t1)) ic;;
+    C.run_exn TC.(run (distribute df t1)) ic;;
 
 
 (*