monads.ml: make Error.TP,TP2 drop fail to Wrapped.zero
[lambda.git] / code / monads.ml
index 4db605c..0b6af2e 100644 (file)
@@ -604,7 +604,12 @@ end) : sig
   end
   (* ErrorT transformer when wrapped monad has plus, zero *)
   module TP : functor (Wrapped : Monad.P) -> sig
-    include module type of T(Wrapped)
+    type 'a result = 'a Wrapped.result
+    type 'a result_exn = 'a Wrapped.result_exn
+    include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
+    val elevate : 'a Wrapped.m -> 'a m
+    val throw : err -> 'a m
+    val catch : 'a m -> (err -> 'a m) -> 'a m
     include Monad.PLUS with type 'a m := 'a m
   end
   module T2 : functor (Wrapped : Monad.S2) -> sig
@@ -616,7 +621,12 @@ end) : sig
     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
   end
   module TP2 : functor (Wrapped : Monad.P2) -> sig
-    include module type of T2(Wrapped)
+    type ('x,'a) result = ('x,'a) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val throw : err -> ('x,'a) m
+    val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
     include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
   end
 end = struct
@@ -677,8 +687,33 @@ end = struct
       | Error e -> handler e)
   end
   module TP(Wrapped : Monad.P) = struct
+    (* code repetition, ugh *)
     module TransP = struct
-      include T(Wrapped)
+      include Monad.MakeT(struct
+        module Wrapped = Wrapped
+        type 'a m = 'a error Wrapped.m
+        type 'a result = 'a Wrapped.result
+        type 'a result_exn = 'a Wrapped.result_exn
+        let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
+        let bind u f = Wrapped.bind u (fun t -> match t with
+          | Success a -> f a
+          | Error e -> Wrapped.unit (Error e))
+        let run u =
+          let w = Wrapped.bind u (fun t -> match t with
+            | Success a -> Wrapped.unit a
+            | Error e -> Wrapped.zero ())
+          in Wrapped.run w
+        let run_exn u =
+          let w = Wrapped.bind u (fun t -> match t with
+            | Success a -> Wrapped.unit a
+            (* | _ -> Wrapped.fail () *)
+            | Error e -> raise (Err.Exc e))
+          in Wrapped.run_exn w
+      end)
+      let throw e = Wrapped.unit (Error e)
+      let catch u handler = Wrapped.bind u (fun t -> match t with
+        | Success _ -> Wrapped.unit t
+        | Error e -> handler e)
       let plus u v = Wrapped.plus u v
       let zero () = elevate (Wrapped.zero ())
     end
@@ -710,9 +745,33 @@ end = struct
       | Error e -> handler e)
   end
   module TP2(Wrapped : Monad.P2) = struct
+    (* code repetition, ugh *)
     module TransP = struct
-      include T2(Wrapped)
-      (* code repetition, ugh *)
+      include Monad.MakeT2(struct
+        module Wrapped = Wrapped
+        type ('x,'a) m = ('x,'a error) Wrapped.m
+        type ('x,'a) result = ('x,'a) Wrapped.result
+        type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+        let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
+        let bind u f = Wrapped.bind u (fun t -> match t with
+          | Success a -> f a
+          | Error e -> Wrapped.unit (Error e))
+        let run u =
+          let w = Wrapped.bind u (fun t -> match t with
+            | Success a -> Wrapped.unit a
+            | Error e -> Wrapped.zero ())
+          in Wrapped.run w
+        let run_exn u =
+          let w = Wrapped.bind u (fun t -> match t with
+            | Success a -> Wrapped.unit a
+            (* | _ -> Wrapped.fail () *)
+            | Error e -> raise (Err.Exc e))
+          in Wrapped.run_exn w
+      end)
+      let throw e = Wrapped.unit (Error e)
+      let catch u handler = Wrapped.bind u (fun t -> match t with
+        | Success _ -> Wrapped.unit t
+        | Error e -> handler e)
       let plus u v = Wrapped.plus u v
       let zero () = elevate (Wrapped.zero ())
     end