tweak monads-lib
[lambda.git] / code / monads.ml
index 4dbd851..ae0ecc5 100644 (file)
@@ -111,7 +111,7 @@ module Monad = struct
     val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m
     val do_when :  bool -> unit m -> unit m
     val do_unless :  bool -> unit m -> unit m
-    val forever : 'a m -> 'b m
+    val forever : (unit -> 'a m) -> 'b m
     val sequence : 'a m list -> 'a list m
     val sequence_ : 'a m list -> unit m
   end
@@ -135,7 +135,9 @@ module Monad = struct
     let (>=>) f g = fun a -> f a >>= g
     let do_when test u = if test then u else unit ()
     let do_unless test u = if test then unit () else u
-    let rec forever u = u >> forever u
+    let forever uthunk =
+        let rec loop () = uthunk () >>= fun _ -> loop ()
+        in loop ()
     let sequence ms =
       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
         Util.fold_right op ms (unit [])
@@ -247,7 +249,7 @@ module Monad = struct
     val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m
     val do_when :  bool -> ('x,unit) m -> ('x,unit) m
     val do_unless :  bool -> ('x,unit) m -> ('x,unit) m
-    val forever : ('x,'a) m -> ('x,'b) m
+    val forever : (unit -> ('x,'a) m) -> ('x,'b) m
     val sequence : ('x,'a) m list -> ('x,'a list) m
     val sequence_ : ('x,'a) m list -> ('x,unit) m
   end
@@ -263,7 +265,9 @@ module Monad = struct
     let (>=>) f g = fun a -> f a >>= g
     let do_when test u = if test then u else unit ()
     let do_unless test u = if test then unit () else u
-    let rec forever u = u >> forever u
+    let forever uthunk =
+        let rec loop () = uthunk () >>= fun _ -> loop ()
+        in loop ()
     let sequence ms =
       let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
         Util.fold_right op ms (unit [])
@@ -583,7 +587,7 @@ end) : sig
   (* declare additional operations, while still hiding implementation of type m *)
   type err = Err.err
   type 'a error = Error of err | Success of 'a
-  type 'a result = 'a
+  type 'a result = 'a error
   type 'a result_exn = 'a
   include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
   (* include Monad.PLUS with type 'a m := 'a m *)
@@ -591,37 +595,55 @@ end) : sig
   val catch : 'a m -> (err -> 'a m) -> 'a m
   (* ErrorT transformer *)
   module T : functor (Wrapped : Monad.S) -> sig
+    type 'a result = 'a error Wrapped.result
+    type 'a result_exn = 'a Wrapped.result_exn
+    include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
+    val elevate : 'a Wrapped.m -> 'a m
+    val throw : err -> 'a m
+    val catch : 'a m -> (err -> 'a m) -> 'a m
+  end
+  (* ErrorT transformer when wrapped monad has plus, zero *)
+  module TP : functor (Wrapped : Monad.P) -> sig
     type 'a result = 'a Wrapped.result
     type 'a result_exn = 'a Wrapped.result_exn
     include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn
     val elevate : 'a Wrapped.m -> 'a m
     val throw : err -> 'a m
     val catch : 'a m -> (err -> 'a m) -> 'a m
+    include Monad.PLUS with type 'a m := 'a m
   end
   module T2 : functor (Wrapped : Monad.S2) -> sig
+    type ('x,'a) result = ('x,'a error) Wrapped.result
+    type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+    include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+    val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+    val throw : err -> ('x,'a) m
+    val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
+  end
+  module TP2 : functor (Wrapped : Monad.P2) -> sig
     type ('x,'a) result = ('x,'a) Wrapped.result
     type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
     include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
     val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
     val throw : err -> ('x,'a) m
     val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m
+    include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m
   end
 end = struct
   type err = Err.err
   type 'a error = Error of err | Success of 'a
   module Base = struct
     type 'a m = 'a error
-    type 'a result = 'a
+    type 'a result = 'a error
     type 'a result_exn = 'a
     let unit a = Success a
     let bind u f = match u with
       | Success a -> f a
       | Error e -> Error e (* input and output may be of different 'a types *)
-    (* TODO: should run refrain from failing? *)
-    let run u = match u with
+    let run u = u
+    let run_exn u = match u with
       | Success a -> a
       | Error e -> raise (Err.Exc e)
-    let run_exn = run
     (*
     let zero () = Error Err.zero
     let plus u v = match (u, v) with
@@ -644,23 +666,16 @@ end = struct
     module Trans = struct
       module Wrapped = Wrapped
       type 'a m = 'a error Wrapped.m
-      type 'a result = 'a Wrapped.result
+      type 'a result = 'a error Wrapped.result
       type 'a result_exn = 'a Wrapped.result_exn
       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
       let bind u f = Wrapped.bind u (fun t -> match t with
         | Success a -> f a
         | Error e -> Wrapped.unit (Error e))
-      (* TODO: should run refrain from failing? *)
-      let run u =
-        let w = Wrapped.bind u (fun t -> match t with
-          | Success a -> Wrapped.unit a
-          (* | _ -> Wrapped.fail () *)
-          | Error e -> raise (Err.Exc e))
-        in Wrapped.run w
+      let run u = Wrapped.run u
       let run_exn u =
         let w = Wrapped.bind u (fun t -> match t with
           | Success a -> Wrapped.unit a
-          (* | _ -> Wrapped.fail () *)
           | Error e -> raise (Err.Exc e))
         in Wrapped.run_exn w
     end
@@ -670,22 +685,51 @@ end = struct
       | Success _ -> Wrapped.unit t
       | Error e -> handler e)
   end
+  module TP(Wrapped : Monad.P) = struct
+    (* code repetition, ugh *)
+    module TransP = struct
+      include Monad.MakeT(struct
+        module Wrapped = Wrapped
+        type 'a m = 'a error Wrapped.m
+        type 'a result = 'a Wrapped.result
+        type 'a result_exn = 'a Wrapped.result_exn
+        let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
+        let bind u f = Wrapped.bind u (fun t -> match t with
+          | Success a -> f a
+          | Error e -> Wrapped.unit (Error e))
+        let run u =
+          let w = Wrapped.bind u (fun t -> match t with
+            | Success a -> Wrapped.unit a
+            | Error e -> Wrapped.zero ())
+          in Wrapped.run w
+        let run_exn u =
+          let w = Wrapped.bind u (fun t -> match t with
+            | Success a -> Wrapped.unit a
+            | Error e -> raise (Err.Exc e))
+          in Wrapped.run_exn w
+      end)
+      let throw e = Wrapped.unit (Error e)
+      let catch u handler = Wrapped.bind u (fun t -> match t with
+        | Success _ -> Wrapped.unit t
+        | Error e -> handler e)
+      let plus u v = Wrapped.plus u v
+      let zero () = elevate (Wrapped.zero ())
+    end
+    include TransP
+    include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m)
+  end
   module T2(Wrapped : Monad.S2) = struct
     module Trans = struct
       module Wrapped = Wrapped
       type ('x,'a) m = ('x,'a error) Wrapped.m
-      type ('x,'a) result = ('x,'a) Wrapped.result
+      type ('x,'a) result = ('x,'a error) Wrapped.result
       type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
       (* code repetition, ugh *)
       let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
       let bind u f = Wrapped.bind u (fun t -> match t with
         | Success a -> f a
         | Error e -> Wrapped.unit (Error e))
-      let run u =
-        let w = Wrapped.bind u (fun t -> match t with
-          | Success a -> Wrapped.unit a
-          | Error e -> raise (Err.Exc e))
-        in Wrapped.run w
+      let run u = Wrapped.run u
       let run_exn u =
         let w = Wrapped.bind u (fun t -> match t with
           | Success a -> Wrapped.unit a
@@ -698,6 +742,39 @@ end = struct
       | Success _ -> Wrapped.unit t
       | Error e -> handler e)
   end
+  module TP2(Wrapped : Monad.P2) = struct
+    (* code repetition, ugh *)
+    module TransP = struct
+      include Monad.MakeT2(struct
+        module Wrapped = Wrapped
+        type ('x,'a) m = ('x,'a error) Wrapped.m
+        type ('x,'a) result = ('x,'a) Wrapped.result
+        type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn
+        let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a))
+        let bind u f = Wrapped.bind u (fun t -> match t with
+          | Success a -> f a
+          | Error e -> Wrapped.unit (Error e))
+        let run u =
+          let w = Wrapped.bind u (fun t -> match t with
+            | Success a -> Wrapped.unit a
+            | Error e -> Wrapped.zero ())
+          in Wrapped.run w
+        let run_exn u =
+          let w = Wrapped.bind u (fun t -> match t with
+            | Success a -> Wrapped.unit a
+            | Error e -> raise (Err.Exc e))
+          in Wrapped.run_exn w
+      end)
+      let throw e = Wrapped.unit (Error e)
+      let catch u handler = Wrapped.bind u (fun t -> match t with
+        | Success _ -> Wrapped.unit t
+        | Error e -> handler e)
+      let plus u v = Wrapped.plus u v
+      let zero () = elevate (Wrapped.zero ())
+    end
+    include TransP
+    include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m)
+  end
 end
 
 (* pre-define common instance of Error_monad *)
@@ -710,6 +787,27 @@ module Failure = Error_monad(struct
   *)
 end)
 
+(*
+# EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
+- : int EL.result = [Failure.Error "bye"; Failure.Success 30]
+# LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
+- : int LE.result = Failure.Error "bye"
+# EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));;
+Exception: Failure "bye".
+# LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));;
+Exception: Failure "bye".
+
+# ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
+- : int Failure.error * S.store = (Failure.Error "bye", 1)
+# SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
+- : (int * S.store) Failure.result = Failure.Error "bye"
+# ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;;
+Exception: Failure "bye".
+# SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;;
+Exception: Failure "bye".
+ *)
+
+
 (* must be parameterized on (struct type env = ... end) *)
 module Reader_monad(Env : sig type env end) : sig
   (* declare additional operations, while still hiding implementation of type m *)
@@ -810,8 +908,8 @@ end = struct
   end
   module TP2(Wrapped : Monad.P2) = struct
     module TransP = struct
-      (* code repetition, ugh *)
       include T2(Wrapped)
+      (* code repetition, ugh *)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
       let zero () = elevate (Wrapped.zero ())
       let asks selector = ask >>= (fun e ->
@@ -939,6 +1037,7 @@ end = struct
   module TP2(Wrapped : Monad.P2) = struct
     module TransP = struct
       include T2(Wrapped)
+      (* code repetition, ugh *)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
       let zero () = elevate (Wrapped.zero ())
     end
@@ -1072,6 +1171,7 @@ end = struct
   module TP2(Wrapped : Monad.P2) = struct
     module TransP = struct
       include T2(Wrapped)
+      (* code repetition, ugh *)
       let plus u v = fun s -> Wrapped.plus (u s) (v s)
       let zero () = elevate (Wrapped.zero ())
     end
@@ -1243,7 +1343,8 @@ module Continuation_monad : sig
   val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
   val reset : ('a,'a) m -> ('r,'a) m
   val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m
-  val abort : ('a,'a) m -> ('a,'b) m
+  (* val abort : ('a,'a) m -> ('a,'b) m *)
+  val abort : 'a -> ('a,'b) m
   val run0 : ('a,'a) m -> 'a
 end = struct
   let id = fun i -> i
@@ -1267,10 +1368,19 @@ end = struct
   let callcc f = fun k -> f k k
   let throw k a = fun _ -> k a
   *)
-  (* from http://www.haskell.org/haskellwiki/MonadCont_done_right *)
+
+  (* from http://www.haskell.org/haskellwiki/MonadCont_done_right
+   *
+   *  reset :: (Monad m) => ContT a m a -> ContT r m a
+   *  reset e = ContT $ \k -> runContT e return >>= k
+   *
+   *  shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a
+   *  shift e = ContT $ \k ->
+   *              runContT (e $ \v -> ContT $ \c -> k v >>= c) return *)
   let reset u = unit ((u) id)
   let shift f = (fun k -> (f (fun a -> unit (k a))) id)
-  let abort a = shift (fun _ -> a)
+  (* let abort a = shift (fun _ -> a) *)
+  let abort a = shift (fun _ -> unit a)
   let run0 (u : ('a,'a) m) = (u) id
 end
 
@@ -1467,7 +1577,7 @@ module C = Continuation_monad
 module TC = T.T2(C);;
 
 
-print_endline "================================================";;
+print_endline "=== test Leaf(...).distribute ==================";;
 
 let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));;
 
@@ -1537,6 +1647,7 @@ LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts
 - : S.store list * S.store = ([10; 0; 0; 1; 20], 1)
 *)
 
+print_endline "=== test Leaf(Continuation).distribute ==================";;
 
 let id : 'z. 'z -> 'z = fun x -> x
 
@@ -1581,14 +1692,15 @@ let example3 () =
 (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *)
 let example5 () : int =
   Continuation_monad.(let v = reset (
-      let u = shift (fun k -> k 1 >>= fun x -> k x)
+      let u = shift (fun k -> k 1 >>= k)
       in u >>= fun x -> unit (10 + x)
     ) in let w = v >>= fun x -> unit (100 + x)
     in run0 w)
 
-
 ;;
 
+print_endline "=== test bare Continuation ============";;
+
 (1011, 1111, 1111, 121);;
 (example1(), example2(), example3(), example5());;
 ((111,0), (0,0));;
@@ -1638,3 +1750,22 @@ TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;;
 
 testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);;
 
+print_endline "=== pa_monad's Continuation Tests ============";;
+
+(1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );;
+(2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );;
+(3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );;
+(4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );;
+(5, 27 = C.(run0 (
+              let c = reset(abort 5 >>= fun y -> unit (y+6))
+              in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );;
+
+(7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
+
+(8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );;
+
+(12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t)  ) >>= fun xv -> shift (fun _ -> unit xv)))) );;
+
+
+(0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;
+