X-Git-Url: http://lambda.jimpryor.net/git/gitweb.cgi?p=lambda.git;a=blobdiff_plain;f=code%2Fmonads.ml;h=ae0ecc51690c82d0bb5316ec142c4cd8786d1518;hp=4dbd851c431ce9d4a512ea32380af3a13f9fcb77;hb=1d3dba7d49400782f8d5529e4a3c850f4ab7f16c;hpb=0c24fa2c006d9a9c2224e6106042264d631e4a29 diff --git a/code/monads.ml b/code/monads.ml index 4dbd851c..ae0ecc51 100644 --- a/code/monads.ml +++ b/code/monads.ml @@ -111,7 +111,7 @@ module Monad = struct val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m val do_when : bool -> unit m -> unit m val do_unless : bool -> unit m -> unit m - val forever : 'a m -> 'b m + val forever : (unit -> 'a m) -> 'b m val sequence : 'a m list -> 'a list m val sequence_ : 'a m list -> unit m end @@ -135,7 +135,9 @@ module Monad = struct let (>=>) f g = fun a -> f a >>= g let do_when test u = if test then u else unit () let do_unless test u = if test then unit () else u - let rec forever u = u >> forever u + let forever uthunk = + let rec loop () = uthunk () >>= fun _ -> loop () + in loop () let sequence ms = let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in Util.fold_right op ms (unit []) @@ -247,7 +249,7 @@ module Monad = struct val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m val do_when : bool -> ('x,unit) m -> ('x,unit) m val do_unless : bool -> ('x,unit) m -> ('x,unit) m - val forever : ('x,'a) m -> ('x,'b) m + val forever : (unit -> ('x,'a) m) -> ('x,'b) m val sequence : ('x,'a) m list -> ('x,'a list) m val sequence_ : ('x,'a) m list -> ('x,unit) m end @@ -263,7 +265,9 @@ module Monad = struct let (>=>) f g = fun a -> f a >>= g let do_when test u = if test then u else unit () let do_unless test u = if test then unit () else u - let rec forever u = u >> forever u + let forever uthunk = + let rec loop () = uthunk () >>= fun _ -> loop () + in loop () let sequence ms = let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in Util.fold_right op ms (unit []) @@ -583,7 +587,7 @@ end) : sig (* declare additional operations, while still hiding implementation of type m *) type err = Err.err type 'a error = Error of err | Success of 'a - type 'a result = 'a + type 'a result = 'a error type 'a result_exn = 'a include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn (* include Monad.PLUS with type 'a m := 'a m *) @@ -591,37 +595,55 @@ end) : sig val catch : 'a m -> (err -> 'a m) -> 'a m (* ErrorT transformer *) module T : functor (Wrapped : Monad.S) -> sig + type 'a result = 'a error Wrapped.result + type 'a result_exn = 'a Wrapped.result_exn + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + val elevate : 'a Wrapped.m -> 'a m + val throw : err -> 'a m + val catch : 'a m -> (err -> 'a m) -> 'a m + end + (* ErrorT transformer when wrapped monad has plus, zero *) + module TP : functor (Wrapped : Monad.P) -> sig type 'a result = 'a Wrapped.result type 'a result_exn = 'a Wrapped.result_exn include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn val elevate : 'a Wrapped.m -> 'a m val throw : err -> 'a m val catch : 'a m -> (err -> 'a m) -> 'a m + include Monad.PLUS with type 'a m := 'a m end module T2 : functor (Wrapped : Monad.S2) -> sig + type ('x,'a) result = ('x,'a error) Wrapped.result + type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn + include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn + val elevate : ('x,'a) Wrapped.m -> ('x,'a) m + val throw : err -> ('x,'a) m + val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m + end + module TP2 : functor (Wrapped : Monad.P2) -> sig type ('x,'a) result = ('x,'a) Wrapped.result type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn include Monad.S2 with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn val elevate : ('x,'a) Wrapped.m -> ('x,'a) m val throw : err -> ('x,'a) m val catch : ('x,'a) m -> (err -> ('x,'a) m) -> ('x,'a) m + include Monad.PLUS2 with type ('x,'a) m := ('x,'a) m end end = struct type err = Err.err type 'a error = Error of err | Success of 'a module Base = struct type 'a m = 'a error - type 'a result = 'a + type 'a result = 'a error type 'a result_exn = 'a let unit a = Success a let bind u f = match u with | Success a -> f a | Error e -> Error e (* input and output may be of different 'a types *) - (* TODO: should run refrain from failing? *) - let run u = match u with + let run u = u + let run_exn u = match u with | Success a -> a | Error e -> raise (Err.Exc e) - let run_exn = run (* let zero () = Error Err.zero let plus u v = match (u, v) with @@ -644,23 +666,16 @@ end = struct module Trans = struct module Wrapped = Wrapped type 'a m = 'a error Wrapped.m - type 'a result = 'a Wrapped.result + type 'a result = 'a error Wrapped.result type 'a result_exn = 'a Wrapped.result_exn let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a)) let bind u f = Wrapped.bind u (fun t -> match t with | Success a -> f a | Error e -> Wrapped.unit (Error e)) - (* TODO: should run refrain from failing? *) - let run u = - let w = Wrapped.bind u (fun t -> match t with - | Success a -> Wrapped.unit a - (* | _ -> Wrapped.fail () *) - | Error e -> raise (Err.Exc e)) - in Wrapped.run w + let run u = Wrapped.run u let run_exn u = let w = Wrapped.bind u (fun t -> match t with | Success a -> Wrapped.unit a - (* | _ -> Wrapped.fail () *) | Error e -> raise (Err.Exc e)) in Wrapped.run_exn w end @@ -670,22 +685,51 @@ end = struct | Success _ -> Wrapped.unit t | Error e -> handler e) end + module TP(Wrapped : Monad.P) = struct + (* code repetition, ugh *) + module TransP = struct + include Monad.MakeT(struct + module Wrapped = Wrapped + type 'a m = 'a error Wrapped.m + type 'a result = 'a Wrapped.result + type 'a result_exn = 'a Wrapped.result_exn + let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a)) + let bind u f = Wrapped.bind u (fun t -> match t with + | Success a -> f a + | Error e -> Wrapped.unit (Error e)) + let run u = + let w = Wrapped.bind u (fun t -> match t with + | Success a -> Wrapped.unit a + | Error e -> Wrapped.zero ()) + in Wrapped.run w + let run_exn u = + let w = Wrapped.bind u (fun t -> match t with + | Success a -> Wrapped.unit a + | Error e -> raise (Err.Exc e)) + in Wrapped.run_exn w + end) + let throw e = Wrapped.unit (Error e) + let catch u handler = Wrapped.bind u (fun t -> match t with + | Success _ -> Wrapped.unit t + | Error e -> handler e) + let plus u v = Wrapped.plus u v + let zero () = elevate (Wrapped.zero ()) + end + include TransP + include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m) + end module T2(Wrapped : Monad.S2) = struct module Trans = struct module Wrapped = Wrapped type ('x,'a) m = ('x,'a error) Wrapped.m - type ('x,'a) result = ('x,'a) Wrapped.result + type ('x,'a) result = ('x,'a error) Wrapped.result type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn (* code repetition, ugh *) let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a)) let bind u f = Wrapped.bind u (fun t -> match t with | Success a -> f a | Error e -> Wrapped.unit (Error e)) - let run u = - let w = Wrapped.bind u (fun t -> match t with - | Success a -> Wrapped.unit a - | Error e -> raise (Err.Exc e)) - in Wrapped.run w + let run u = Wrapped.run u let run_exn u = let w = Wrapped.bind u (fun t -> match t with | Success a -> Wrapped.unit a @@ -698,6 +742,39 @@ end = struct | Success _ -> Wrapped.unit t | Error e -> handler e) end + module TP2(Wrapped : Monad.P2) = struct + (* code repetition, ugh *) + module TransP = struct + include Monad.MakeT2(struct + module Wrapped = Wrapped + type ('x,'a) m = ('x,'a error) Wrapped.m + type ('x,'a) result = ('x,'a) Wrapped.result + type ('x,'a) result_exn = ('x,'a) Wrapped.result_exn + let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a)) + let bind u f = Wrapped.bind u (fun t -> match t with + | Success a -> f a + | Error e -> Wrapped.unit (Error e)) + let run u = + let w = Wrapped.bind u (fun t -> match t with + | Success a -> Wrapped.unit a + | Error e -> Wrapped.zero ()) + in Wrapped.run w + let run_exn u = + let w = Wrapped.bind u (fun t -> match t with + | Success a -> Wrapped.unit a + | Error e -> raise (Err.Exc e)) + in Wrapped.run_exn w + end) + let throw e = Wrapped.unit (Error e) + let catch u handler = Wrapped.bind u (fun t -> match t with + | Success _ -> Wrapped.unit t + | Error e -> handler e) + let plus u v = Wrapped.plus u v + let zero () = elevate (Wrapped.zero ()) + end + include TransP + include (Monad.MakeDistrib2(TransP) : Monad.PLUS2 with type ('x,'a) m := ('x,'a) m) + end end (* pre-define common instance of Error_monad *) @@ -710,6 +787,27 @@ module Failure = Error_monad(struct *) end) +(* +# EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));; +- : int EL.result = [Failure.Error "bye"; Failure.Success 30] +# LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));; +- : int LE.result = Failure.Error "bye" +# EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));; +Exception: Failure "bye". +# LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));; +Exception: Failure "bye". + +# ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;; +- : int Failure.error * S.store = (Failure.Error "bye", 1) +# SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;; +- : (int * S.store) Failure.result = Failure.Error "bye" +# ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;; +Exception: Failure "bye". +# SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;; +Exception: Failure "bye". + *) + + (* must be parameterized on (struct type env = ... end) *) module Reader_monad(Env : sig type env end) : sig (* declare additional operations, while still hiding implementation of type m *) @@ -810,8 +908,8 @@ end = struct end module TP2(Wrapped : Monad.P2) = struct module TransP = struct - (* code repetition, ugh *) include T2(Wrapped) + (* code repetition, ugh *) let plus u v = fun s -> Wrapped.plus (u s) (v s) let zero () = elevate (Wrapped.zero ()) let asks selector = ask >>= (fun e -> @@ -939,6 +1037,7 @@ end = struct module TP2(Wrapped : Monad.P2) = struct module TransP = struct include T2(Wrapped) + (* code repetition, ugh *) let plus u v = fun s -> Wrapped.plus (u s) (v s) let zero () = elevate (Wrapped.zero ()) end @@ -1072,6 +1171,7 @@ end = struct module TP2(Wrapped : Monad.P2) = struct module TransP = struct include T2(Wrapped) + (* code repetition, ugh *) let plus u v = fun s -> Wrapped.plus (u s) (v s) let zero () = elevate (Wrapped.zero ()) end @@ -1243,7 +1343,8 @@ module Continuation_monad : sig val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m val reset : ('a,'a) m -> ('r,'a) m val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m - val abort : ('a,'a) m -> ('a,'b) m + (* val abort : ('a,'a) m -> ('a,'b) m *) + val abort : 'a -> ('a,'b) m val run0 : ('a,'a) m -> 'a end = struct let id = fun i -> i @@ -1267,10 +1368,19 @@ end = struct let callcc f = fun k -> f k k let throw k a = fun _ -> k a *) - (* from http://www.haskell.org/haskellwiki/MonadCont_done_right *) + + (* from http://www.haskell.org/haskellwiki/MonadCont_done_right + * + * reset :: (Monad m) => ContT a m a -> ContT r m a + * reset e = ContT $ \k -> runContT e return >>= k + * + * shift :: (Monad m) => ((a -> ContT r m b) -> ContT b m b) -> ContT b m a + * shift e = ContT $ \k -> + * runContT (e $ \v -> ContT $ \c -> k v >>= c) return *) let reset u = unit ((u) id) let shift f = (fun k -> (f (fun a -> unit (k a))) id) - let abort a = shift (fun _ -> a) + (* let abort a = shift (fun _ -> a) *) + let abort a = shift (fun _ -> unit a) let run0 (u : ('a,'a) m) = (u) id end @@ -1467,7 +1577,7 @@ module C = Continuation_monad module TC = T.T2(C);; -print_endline "================================================";; +print_endline "=== test Leaf(...).distribute ==================";; let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));; @@ -1537,6 +1647,7 @@ LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts - : S.store list * S.store = ([10; 0; 0; 1; 20], 1) *) +print_endline "=== test Leaf(Continuation).distribute ==================";; let id : 'z. 'z -> 'z = fun x -> x @@ -1581,14 +1692,15 @@ let example3 () = (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *) let example5 () : int = Continuation_monad.(let v = reset ( - let u = shift (fun k -> k 1 >>= fun x -> k x) + let u = shift (fun k -> k 1 >>= k) in u >>= fun x -> unit (10 + x) ) in let w = v >>= fun x -> unit (100 + x) in run0 w) - ;; +print_endline "=== test bare Continuation ============";; + (1011, 1111, 1111, 121);; (example1(), example2(), example3(), example5());; ((111,0), (0,0));; @@ -1638,3 +1750,22 @@ TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;; testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);; +print_endline "=== pa_monad's Continuation Tests ============";; + +(1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );; +(2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );; +(3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );; +(4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );; +(5, 27 = C.(run0 ( + let c = reset(abort 5 >>= fun y -> unit (y+6)) + in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );; + +(7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );; + +(8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );; + +(12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t) ) >>= fun xv -> shift (fun _ -> unit xv)))) );; + + +(0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );; +