monads.ml: add TP to Error
[lambda.git] / code / monads.ml
index 0dab871..7bb6894 100644 (file)
@@ -111,7 +111,7 @@ module Monad = struct
     val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m
     val do_when :  bool -> unit m -> unit m
     val do_unless :  bool -> unit m -> unit m
-    val forever : 'a m -> 'b m
+    val forever : (unit -> 'a m) -> 'b m
     val sequence : 'a m list -> 'a list m
     val sequence_ : 'a m list -> unit m
   end
@@ -135,7 +135,9 @@ module Monad = struct
     let (>=>) f g = fun a -> f a >>= g
     let do_when test u = if test then u else unit ()
     let do_unless test u = if test then unit () else u
-    let rec forever u = u >> forever u
+    let forever uthunk =
+        let rec loop () = uthunk () >>= fun _ -> loop ()
+        in loop ()
     let sequence ms =
       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
         Util.fold_right op ms (unit [])
@@ -247,7 +249,7 @@ module Monad = struct
     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
-    val forever : ('x,'a) m -> ('x,'b) m
+    val forever : (unit -> ('x,'a) m) -> ('x,'b) m
     val sequence : ('x,'a) m list -> ('x,'a list) m
     val sequence_ : ('x,'a) m list -> ('x,unit) m
   end
@@ -263,7 +265,9 @@ module Monad = struct
     let (>=>) f g = fun a -> f a >>= g
     let do_when test u = if test then u else unit ()
     let do_unless test u = if test then unit () else u
-    let rec forever u = u >> forever u
+    let forever uthunk =
+        let rec loop () = uthunk () >>= fun _ -> loop ()
+        in loop ()
     let sequence ms =
       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
         Util.fold_right op ms (unit [])
@@ -332,11 +336,11 @@ module Identity_monad : sig
 end = struct
   module Base = struct
     type 'a m = 'a
+    type 'a result = 'a
+    type 'a result_exn = 'a
     let unit a = a
     let bind a f = f a
-    type 'a result = 'a
     let run a = a
-    type 'a result_exn = 'a
     let run_exn a = a
   end
   include Monad.Make(Base)
@@ -367,11 +371,11 @@ module Maybe_monad : sig
 end = struct
   module Base = struct
    type 'a m = 'a option
+   type 'a result = 'a option
+   type 'a result_exn = 'a
    let unit a = Some a
    let bind u f = match u with Some a -> f a | None -> None
-   type 'a result = 'a option
    let run u = u
-   type 'a result_exn = 'a
    let run_exn u = match u with
      | Some a -> a
      | None -> failwith "no value"
@@ -455,7 +459,6 @@ module List_monad : sig
     val select : 'a m -> ('a * 'a m) m
 *)
   end
-(*
   module T2 : functor (Wrapped : Monad.S2) -> sig
     type ('x,'a) result = ('x,'a list) Wrapped.result
     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
@@ -464,15 +467,14 @@ module List_monad : sig
     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
     val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m
   end
- *)
 end = struct
   module Base = struct
    type 'a m = 'a list
+   type 'a result = 'a list
+   type 'a result_exn = 'a
    let unit a = [a]
    let bind u f = Util.concat_map f u
-   type 'a result = 'a list
    let run u = u
-   type 'a result_exn = 'a
    let run_exn u = match u with
      | [] -> failwith "no values"
      | [a] -> a
@@ -500,11 +502,6 @@ end = struct
   let base_plus = plus
   module T(Wrapped : Monad.S) = struct
     module Trans = struct
-      let zero () = Wrapped.unit []
-      let plus u v =
-        Wrapped.bind u (fun us ->
-        Wrapped.bind v (fun vs ->
-        Wrapped.unit (base_plus us vs)))
       (* Wrapped.sequence ms  ===  
            let plus1 u v =
              Wrapped.bind u (fun x ->
@@ -531,6 +528,11 @@ end = struct
             | many -> failwith "multiple values"
           ) in Wrapped.run_exn w
       end)
+      let zero () = Wrapped.unit []
+      let plus u v =
+        Wrapped.bind u (fun us ->
+        Wrapped.bind v (fun vs ->
+        Wrapped.unit (base_plus us vs)))
     end
     include Trans
     include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m)
@@ -539,6 +541,37 @@ end = struct
     let select : 'a m -> ('a * 'a m) m
 *)
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      let distribute f alist = Wrapped.sequence (Util.map f alist)
+      include Monad.MakeT2(struct
+        module Wrapped = Wrapped
+        type ('x,'a) m = ('x,'a list) Wrapped.m
+        type ('x,'a) result = ('x,'a list) Wrapped.result
+        type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+        (* code repetition, ugh *)
+        let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a])
+        let bind u f =
+          Wrapped.bind u (fun ts ->
+          Wrapped.bind (distribute f ts) (fun tts ->
+          Wrapped.unit (Util.concat tts)))
+        let run u = Wrapped.run u
+        let run_exn u =
+          let w = Wrapped.bind u (fun ts -> match ts with
+            | [] -> failwith "no values"
+            | [a] -> Wrapped.unit a
+            | many -> failwith "multiple values"
+          ) in Wrapped.run_exn w
+      end)
+      let zero () = Wrapped.unit []
+      let plus u v =
+        Wrapped.bind u (fun us ->
+        Wrapped.bind v (fun vs ->
+        Wrapped.unit (base_plus us vs)))
+    end
+    include Trans
+    include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+  end
 end
 
 
@@ -569,21 +602,34 @@ end) : sig
     val throw : err -> 'a m
     val catch : 'a m -> (err -> 'a m) -> 'a m
   end
+  (* ErrorT transformer when wrapped monad has plus, zero *)
+  module TP : functor (Wrapped : Monad.P) -> sig
+    include module type of T(Wrapped)
+    include Monad.PLUS with type 'a m := 'a m
+  end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = ('x,'a) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val throw : err -> ('x,'a) m
+    val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
+  end
 end = struct
   type err = Err.err
   type 'a error = Error of err | Success of 'a
   module Base = struct
     type 'a m = 'a error
+    type 'a result = 'a
+    type 'a result_exn = 'a
     let unit a = Success a
     let bind u f = match u with
       | Success a -> f a
       | Error e -> Error e (* input and output may be of different 'a types *)
-    type 'a result = 'a
     (* TODO: should run refrain from failing? *)
     let run u = match u with
       | Success a -> a
       | Error e -> raise (Err.Exc e)
-    type 'a result_exn = 'a
     let run_exn = run
     (*
     let zero () = Error Err.zero
@@ -606,12 +652,13 @@ end = struct
   module T(Wrapped : Monad.S) = struct
     module Trans = struct
       module Wrapped = Wrapped
-      type 'a m = 'a Base.m Wrapped.m
+      type 'a m = 'a error Wrapped.m
+      type 'a result = 'a Wrapped.result
+      type 'a result_exn = 'a Wrapped.result_exn
       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
       let bind u f = Wrapped.bind u (fun t -> match t with
         | Success a -> f a
         | Error e -> Wrapped.unit (Error e))
-      type 'a result = 'a Wrapped.result
       (* TODO: should run refrain from failing? *)
       let run u =
         let w = Wrapped.bind u (fun t -> match t with
@@ -619,7 +666,6 @@ end = struct
           (* | _ -> Wrapped.fail () *)
           | Error e -> raise (Err.Exc e))
         in Wrapped.run w
-      type 'a result_exn = 'a Wrapped.result_exn
       let run_exn u =
         let w = Wrapped.bind u (fun t -> match t with
           | Success a -> Wrapped.unit a
@@ -633,6 +679,43 @@ end = struct
       | Success _ -> Wrapped.unit t
       | Error e -> handler e)
   end
+  module TP(Wrapped : Monad.P) = struct
+    module TransP = struct
+      include T(Wrapped)
+      let plus u v = Wrapped.plus u v
+      let zero () = elevate (Wrapped.zero ())
+    end
+    include TransP
+    include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
+  end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      module Wrapped = Wrapped
+      type ('x,'a) m = ('x,'a error) Wrapped.m
+      type ('x,'a) result = ('x,'a) Wrapped.result
+      type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+      (* code repetition, ugh *)
+      let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
+      let bind u f = Wrapped.bind u (fun t -> match t with
+        | Success a -> f a
+        | Error e -> Wrapped.unit (Error e))
+      let run u =
+        let w = Wrapped.bind u (fun t -> match t with
+          | Success a -> Wrapped.unit a
+          | Error e -> raise (Err.Exc e))
+        in Wrapped.run w
+      let run_exn u =
+        let w = Wrapped.bind u (fun t -> match t with
+          | Success a -> Wrapped.unit a
+          | Error e -> raise (Err.Exc e))
+        in Wrapped.run_exn w
+    end
+    include Monad.MakeT2(Trans)
+    let throw e = Wrapped.unit (Error e)
+    let catch u handler = Wrapped.bind u (fun t -> match t with
+      | Success _ -> Wrapped.unit t
+      | Error e -> handler e)
+  end
 end
 
 (* pre-define common instance of Error_monad *)
@@ -670,15 +753,28 @@ module Reader_monad(Env : sig type env end) : sig
     include module type of T(Wrapped)
     include Monad.PLUS with type 'a m := 'a m
   end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = env -> ('x,'a) Wrapped.result
+    type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val ask : ('x,env) m
+    val asks : (env -> 'a) -> ('x,'a) m
+    val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
+  end
+  module TP2 : functor (Wrapped : Monad.P2) -> sig
+    include module type of T2(Wrapped)
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
+  end
 end = struct
   type env = Env.env
   module Base = struct
     type 'a m = env -> 'a
+    type 'a result = env -> 'a
+    type 'a result_exn = env -> 'a
     let unit a = fun e -> a
     let bind u f = fun e -> let a = u e in let u' = f a in u' e
-    type 'a result = env -> 'a
     let run u = fun e -> u e
-    type 'a result_exn = env -> 'a
     let run_exn = run
   end
   include Monad.Make(Base)
@@ -689,11 +785,11 @@ end = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = env -> 'a Wrapped.m
+      type 'a result = env -> 'a Wrapped.result
+      type 'a result_exn = env -> 'a Wrapped.result_exn
       let elevate w = fun e -> w
       let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
-      type 'a result = env -> 'a Wrapped.result
       let run u = fun e -> Wrapped.run (u e)
-      type 'a result_exn = env -> 'a Wrapped.result_exn
       let run_exn u = fun e -> Wrapped.run_exn (u e)
     end
     include Monad.MakeT(Trans)
@@ -713,6 +809,36 @@ end = struct
     include TransP
     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      module Wrapped = Wrapped
+      type ('x,'a) m = env -> ('x,'a) Wrapped.m
+      type ('x,'a) result = env -> ('x,'a) Wrapped.result
+      type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
+      (* code repetition, ugh *)
+      let elevate w = fun e -> w
+      let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
+      let run u = fun e -> Wrapped.run (u e)
+      let run_exn u = fun e -> Wrapped.run_exn (u e)
+    end
+    include Monad.MakeT2(Trans)
+    let ask = fun e -> Wrapped.unit e
+    let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
+    let local modifier u = fun e -> u (modifier e)
+  end
+  module TP2(Wrapped : Monad.P2) = struct
+    module TransP = struct
+      (* code repetition, ugh *)
+      include T2(Wrapped)
+      let plus u v = fun s -> Wrapped.plus (u s) (v s)
+      let zero () = elevate (Wrapped.zero ())
+      let asks selector = ask >>= (fun e ->
+        try unit (selector e)
+        with Not_found -> fun e -> Wrapped.zero ())
+    end
+    include TransP
+    include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+  end
 end
 
 
@@ -743,15 +869,29 @@ module State_monad(Store : sig type store end) : sig
     include module type of T(Wrapped)
     include Monad.PLUS with type 'a m := 'a m
   end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
+    type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val get : ('x,store) m
+    val gets : (store -> 'a) -> ('x,'a) m
+    val put : store -> ('x,unit) m
+    val puts : (store -> store) -> ('x,unit) m
+  end
+  module TP2 : functor (Wrapped : Monad.P2) -> sig
+    include module type of T2(Wrapped)
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
+  end
 end = struct
   type store = Store.store
   module Base = struct
     type 'a m = store -> 'a * store
+    type 'a result = store -> 'a * store
+    type 'a result_exn = store -> 'a
     let unit a = fun s -> (a, s)
     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
-    type 'a result = store -> 'a * store
     let run u = fun s -> (u s)
-    type 'a result_exn = store -> 'a
     let run_exn u = fun s -> fst (u s)
   end
   include Monad.Make(Base)
@@ -763,13 +903,13 @@ end = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = store -> ('a * store) Wrapped.m
+      type 'a result = store -> ('a * store) Wrapped.result
+      type 'a result_exn = store -> 'a Wrapped.result_exn
       let elevate w = fun s ->
         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
       let bind u f = fun s ->
         Wrapped.bind (u s) (fun (a, s') -> f a s')
-      type 'a result = store -> ('a * store) Wrapped.result
       let run u = fun s -> Wrapped.run (u s)
-      type 'a result_exn = store -> 'a Wrapped.result_exn
       let run_exn u = fun s ->
         let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
         in Wrapped.run_exn w
@@ -792,6 +932,40 @@ end = struct
     include TransP
     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      module Wrapped = Wrapped
+      type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
+      type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
+      type ('x,'a) result_exn = store -> ('x,'a) Wrapped.result_exn
+      (* code repetition, ugh *)
+      let elevate w = fun s ->
+        Wrapped.bind w (fun a -> Wrapped.unit (a, s))
+      let bind u f = fun s ->
+        Wrapped.bind (u s) (fun (a, s') -> f a s')
+      let run u = fun s -> Wrapped.run (u s)
+      let run_exn u = fun s ->
+        let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
+        in Wrapped.run_exn w
+    end
+    include Monad.MakeT2(Trans)
+    let get = fun s -> Wrapped.unit (s, s)
+    let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *)
+    let put s = fun _ -> Wrapped.unit ((), s)
+    let puts modifier = fun s -> Wrapped.unit ((), modifier s)
+  end
+  module TP2(Wrapped : Monad.P2) = struct
+    module TransP = struct
+      include T2(Wrapped)
+      let plus u v = fun s -> Wrapped.plus (u s) (v s)
+      let zero () = elevate (Wrapped.zero ())
+    end
+    let gets viewer = fun s ->
+      try Wrapped.unit (viewer s, s)
+      with Not_found -> Wrapped.zero ()
+    include TransP
+    include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+  end
 end
 
 (* State monad with different interface (structured store) *)
@@ -821,6 +995,19 @@ end) : sig
     include module type of T(Wrapped)
     include Monad.PLUS with type 'a m := 'a m
   end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = ('x,'a) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val newref : value -> ('x,ref) m
+    val deref : ref -> ('x,value) m
+    val change : ref -> value -> ('x,unit) m
+  end
+  module TP2 : functor (Wrapped : Monad.P2) -> sig
+    include module type of T2(Wrapped)
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
+  end
 end = struct
   type ref = int
   type value = V.value
@@ -835,11 +1022,11 @@ end = struct
     { next = d.next; tree = D.add key value d.tree }
   module Base = struct
     type 'a m = dict -> 'a * dict
+    type 'a result = 'a
+    type 'a result_exn = 'a
     let unit a = fun s -> (a, s)
     let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
-    type 'a result = 'a
     let run u = fst (u empty)
-    type 'a result_exn = 'a
     let run_exn = run
   end
   include Monad.Make(Base)
@@ -850,15 +1037,15 @@ end = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = dict -> ('a * dict) Wrapped.m
+      type 'a result = 'a Wrapped.result
+      type 'a result_exn = 'a Wrapped.result_exn
       let elevate w = fun s ->
         Wrapped.bind w (fun a -> Wrapped.unit (a, s))
       let bind u f = fun s ->
         Wrapped.bind (u s) (fun (a, s') -> f a s')
-      type 'a result = 'a Wrapped.result
       let run u =
         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
         in Wrapped.run w
-      type 'a result_exn = 'a Wrapped.result_exn
       let run_exn u =
         let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
         in Wrapped.run_exn w
@@ -877,6 +1064,38 @@ end = struct
     include TransP
     include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      module Wrapped = Wrapped
+      type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
+      type ('x,'a) result = ('x,'a) Wrapped.result
+      type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+      (* code repetition, ugh *)
+      let elevate w = fun s ->
+        Wrapped.bind w (fun a -> Wrapped.unit (a, s))
+      let bind u f = fun s ->
+        Wrapped.bind (u s) (fun (a, s') -> f a s')
+      let run u =
+        let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
+        in Wrapped.run w
+      let run_exn u =
+        let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
+        in Wrapped.run_exn w
+    end
+    include Monad.MakeT2(Trans)
+    let newref value = fun s -> Wrapped.unit (alloc value s)
+    let deref key = fun s -> Wrapped.unit (read key s, s)
+    let change key value = fun s -> Wrapped.unit ((), write key value s)
+  end
+  module TP2(Wrapped : Monad.P2) = struct
+    module TransP = struct
+      include T2(Wrapped)
+      let plus u v = fun s -> Wrapped.plus (u s) (v s)
+      let zero () = elevate (Wrapped.zero ())
+    end
+    include TransP
+    include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+  end
 end
 
 
@@ -900,11 +1119,11 @@ end = struct
   type log = Log.log
   module Base = struct
     type 'a m = 'a * log
+    type 'a result = 'a * log
+    type 'a result_exn = 'a * log
     let unit a = (a, Log.zero)
     let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
-    type 'a result = 'a * log
     let run u = u
-    type 'a result_exn = 'a * log
     let run_exn = run
   end
   include Monad.Make(Base)
@@ -949,13 +1168,13 @@ module IO_monad : sig
 end = struct
   module Base = struct
     type 'a m = { run : unit -> unit; value : 'a }
+    type 'a result = 'a
+    type 'a result_exn = 'a
     let unit a = { run = (fun () -> ()); value = a }
     let bind (a : 'a m) (f: 'a -> 'b m) : 'b m =
      let fres = f a.value in
        { run = (fun () -> a.run (); fres.run ()); value = fres.value }
-    type 'a result = 'a
     let run a = let () = a.run () in a.value
-    type 'a result_exn = 'a
     let run_exn = run
   end
   include Monad.Make(Base)
@@ -967,6 +1186,7 @@ end = struct
   let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () }
 end
 
+(*
 module Continuation_monad : sig
   (* expose only the implementation of type `('r,'a) result` *)
   type 'a m
@@ -990,6 +1210,8 @@ end = struct
   module Base = struct
     (* 'r is result type of whole computation *)
     type 'a m = { cont : 'r. ('a -> 'r) -> 'r }
+    type 'a result = 'a m
+    type 'a result_exn = 'a m
     let unit a =
       let cont : 'r. ('a -> 'r) -> 'r =
         fun k -> k a
@@ -998,9 +1220,7 @@ end = struct
       let cont : 'r. ('a -> 'r) -> 'r =
         fun k -> u.cont (fun a -> (f a).cont k)
       in { cont }
-    type 'a result = 'a m
     let run (u : 'a m) : 'a result = u
-    type 'a result_exn = 'a m
     let run_exn (u : 'a m) : 'a result_exn = u
     let callcc f =
       let cont : 'r. ('a -> 'r) -> 'r =
@@ -1029,46 +1249,58 @@ end = struct
   let runk = Base.runk
   let run0 = Base.run0
 end
+ *)
 
-(*
 (* This two-type parameter version works without Obj.magic *)
-
-module Continuation_monad2 : sig
+module Continuation_monad : sig
   (* expose only the implementation of type `('r,'a) result` *)
-  type ('r,'a) result = ('a -> 'r) -> 'r
+  type ('r,'a) m
+  type ('r,'a) result = ('r,'a) m
   type ('r,'a) result_exn = ('a -> 'r) -> 'r
-  include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn
+  include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
   val reset : ('a,'a) m -> ('r,'a) m
   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
-
+  (* val abort : ('a,'a) m -> ('a,'b) m *)
+  val abort : 'a -> ('a,'b) m
+  val run0 : ('a,'a) m -> 'a
 end = struct
   let id = fun i -> i
   module Base = struct
     (* 'r is result type of whole computation *)
     type ('r,'a) m = ('a -> 'r) -> 'r
-    let unit a = fun k -> k a
-    let bind u f = fun k -> u (fun a -> (f a) k)
     type ('r,'a) result = ('a -> 'r) -> 'r
-    let run u = u
-    type ('r,'a) result_exn = ('a -> 'r) -> 'r
+    type ('r,'a) result_exn = ('r,'a) result
+    let unit a = (fun k -> k a)
+    let bind u f = (fun k -> (u) (fun a -> (f a) k))
+    let run u k = (u) k
     let run_exn = run
   end
   include Monad.Make2(Base)
-  let callcc f = fun k ->
-    let usek a = fun _ -> k a
-    in f usek k
+  let callcc f = (fun k ->
+    let usek a = (fun _ -> k a)
+    in (f usek) k)
   (*
   val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m
   val throw : ('a -> 'r) -> 'a -> ('r,'b) m
   let callcc f = fun k -> f k k
   let throw k a = fun _ -> k a
   *)
-  (* from http://www.haskell.org/haskellwiki/MonadCont_done_right *)
-  let reset u = unit (u id)
-  let shift u = fun k -> u (fun a -> unit (k a)) id
+
+  (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
+   *
+   *  reset :: (Monad m) => ContT a m a -> ContT r m a
+   *  reset e = ContT $ \k -> runContT e return >>= k
+   *
+   *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
+   *  shift e = ContT $ \k ->
+   *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
+  let reset u = unit ((u) id)
+  let shift f = (fun k -> (f (fun a -> unit (k a))) id)
+  (* let abort a = shift (fun _ -> a) *)
+  let abort a = shift (fun _ -> unit a)
+  let run0 (u : ('a,'a) m) = (u) id
 end
- *)
 
 
 (*
@@ -1150,6 +1382,14 @@ module Leaf_monad : sig
     (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *)
     val distribute : ('a -> 'b Wrapped.m) -> 'a tree option -> 'b m
   end
+  module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = ('x,'a tree option) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m
+  end
 end = struct
   type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree)
   (* uses supplied plus and zero to copy t to its image under f *)
@@ -1163,6 +1403,8 @@ end = struct
                    ) in loop ts
   module Base = struct
     type 'a m = 'a tree option
+    type 'a result = 'a tree option
+    type 'a result_exn = 'a tree
     let unit a = Some (Leaf a)
     let zero () = None
     let plus u v = match (u, v) with
@@ -1170,9 +1412,7 @@ end = struct
       | _, None -> u
       | Some us, Some vs -> Some (Node (us, vs))
     let bind u f = mapT f u zero plus
-    type 'a result = 'a tree option
     let run u = u
-    type 'a result_exn = 'a tree
     let run_exn u = match u with
       | None -> failwith "no values"
       (*
@@ -1194,12 +1434,12 @@ end = struct
         Wrapped.unit (base_plus us vs)))
       include Monad.MakeT(struct
         module Wrapped = Wrapped
-        type 'a m = 'a Base.m Wrapped.m
+        type 'a m = 'a tree option Wrapped.m
+        type 'a result = 'a tree option Wrapped.result
+        type 'a result_exn = 'a tree Wrapped.result_exn
         let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
         let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
-        type 'a result = 'a tree option Wrapped.result
         let run u = Wrapped.run u
-        type 'a result_exn = 'a tree Wrapped.result_exn
         let run_exn u =
             let w = Wrapped.bind u (fun t -> match t with
               | None -> failwith "no values"
@@ -1212,6 +1452,33 @@ end = struct
     (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
     let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
   end
+  module T2(Wrapped : Monad.S2) = struct
+    module Trans = struct
+      let zero () = Wrapped.unit None
+      let plus u v =
+        Wrapped.bind u (fun us ->
+        Wrapped.bind v (fun vs ->
+        Wrapped.unit (base_plus us vs)))
+      include Monad.MakeT2(struct
+        module Wrapped = Wrapped
+        type ('x,'a) m = ('x,'a tree option) Wrapped.m
+        type ('x,'a) result = ('x,'a tree option) Wrapped.result
+        type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn
+        (* code repetition, ugh *)
+        let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
+        let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
+        let run u = Wrapped.run u
+        let run_exn u =
+            let w = Wrapped.bind u (fun t -> match t with
+              | None -> failwith "no values"
+              | Some ts -> Wrapped.unit ts)
+            in Wrapped.run_exn w
+      end)
+    end
+    include Trans
+    include (Monad.MakeDistrib2(Trans) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+    let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
+  end
 end
 
 
@@ -1224,10 +1491,14 @@ module LS = L.T(S);;
 module TL = T.T(L);;
 module TR = T.T(R);;
 module TS = T.T(S);;
+module C = Continuation_monad
+module TC = T.T2(C);;
+
+
+print_endline "=== test Leaf(...).distribute ==================";;
 
 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
 
-(*
 let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;;
 TS.run ts 0;;
 (*
@@ -1294,7 +1565,7 @@ LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts
 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
 *)
 
-*)
+print_endline "=== test Leaf(Continuation).distribute ==================";;
 
 let id : 'z. 'z -> 'z = fun x -> x
 
@@ -1339,24 +1610,22 @@ let example3 () =
 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
 let example5 () : int =
   Continuation_monad.(let v = reset (
-      let u = shift (fun k -> k 1 >>= fun x -> k x)
+      let u = shift (fun k -> k 1 >>= k)
       in u >>= fun x -> unit (10 + x)
     ) in let w = v >>= fun x -> unit (100 + x)
     in run0 w)
 
-
 ;;
 
+print_endline "=== test bare Continuation ============";;
+
 (1011, 1111, 1111, 121);;
 (example1(), example2(), example3(), example5());;
 ((111,0), (0,0));;
 (example ~+10, example ~-10);;
 
-module C = Continuation_monad
-module TC = T.T(C)
-
 let testc df ic =
-    C.runk TC.(run_exn (distribute df t1)) ic;;
+    C.run_exn TC.(run (distribute df t1)) ic;;
 
 
 (*
@@ -1399,3 +1668,22 @@ TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
 
 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
 
+print_endline "=== pa_monad's Continuation Tests ============";;
+
+(1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
+(2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
+(3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
+(4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
+(5, 27 = C.(run0 (
+              let c = reset(abort 5 >>= fun y -> unit (y+6))
+              in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
+
+(7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
+
+(8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
+
+(12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
+
+
+(0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
+