X-Git-Url: http://lambda.jimpryor.net/git/gitweb.cgi?p=lambda.git;a=blobdiff_plain;f=code%2Fmonads.ml;h=d8725937adb384bb104ce7f1781636c03f2424e3;hp=0a205f6a8c617e449f3c91d3b7d4a72ac2fefc5d;hb=109034aa514a67fcaed0607b4deb4b339f67ab76;hpb=58bf3ee4a3e5ee6e343787e432602b677a596109 diff --git a/code/monads.ml b/code/monads.ml index 0a205f6a..d8725937 100644 --- a/code/monads.ml +++ b/code/monads.ml @@ -38,14 +38,28 @@ * making their implementations private. The interpreter won't let * let you freely interchange the `'a Reader_monad.m`s defined below * with `Reader_monad.env -> 'a`. The code in this library can see that - * those are equivalent, but code outside the library can't. Instead, you'll + * those are equivalent, but code outside the library can't. Instead, you'll * have to use operations like `run` to convert the abstract monadic types * to types whose internals you have free access to. * + * Acknowledgements: This is largely based on the mtl library distributed + * with the Glasgow Haskell Compiler. I've also been helped in + * various ways by posts and direct feedback from Oleg Kiselyov and + * Chung-chieh Shan. The following were also useful: + * - + * - Ken Shan "Monads for natural language semantics" + * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf + * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers + * + * Licensing: MIT (if that's compatible with the ghc sources this is partly + * derived from) *) (* Some library functions used below. *) + +exception Undefined + module Util = struct let fold_right = List.fold_right let map = List.map @@ -60,28 +74,31 @@ module Util = struct let rec loop n accu = if n == 0 then accu else loop (pred n) (fill :: accu) in loop len [] - let undefined = Obj.magic "" + (* Dirty hack to be a default polymorphic zero. + * To implement this cleanly, monads without a natural zero + * should always wrap themselves in an option layer (see Tree_monad). *) + let undef = Obj.magic (fun () -> raise Undefined) end - - (* * This module contains factories that extend a base set of * monadic definitions with a larger family of standard derived values. *) module Monad = struct + (* * Signature extenders: * Make :: BASE -> S - * MakeT :: TRANS (with Wrapped : S) -> custom sig + * MakeT :: BASET (with Wrapped : S) -> result sig not declared *) (* type of base definitions *) module type BASE = sig - (* The only constraints we impose here on how the monadic type - * is implemented is that it have a single type parameter 'a. *) + (* We make all monadic types doubly-parameterized so that they + * can layer nicely with Continuation, which needs the second + * type parameter. *) type ('x,'a) m type ('x,'a) result type ('x,'a) result_exn @@ -97,11 +114,12 @@ module Monad = struct * Additionally, they will obey one of the following laws: * (Catch) plus (unit a) v === unit a * (Distrib) plus u v >>= f === plus (u >>= f) (v >>= f) - * When no natural zero is available, use `let zero () = Util.undefined - * The Make process automatically detects for zero >>= ..., and + * When no natural zero is available, use `let zero () = Util.undef`. + * The Make functor automatically detects for zero >>= ..., and * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures. *) val zero : unit -> ('x,'a) m + (* zero has to be thunked to ensure results are always poly enough *) val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m end module type S = sig @@ -125,15 +143,19 @@ module Monad = struct module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct include B let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m = - if u == Util.undefined then Util.undefined - else bind u (fun a -> try f a with Match_failure _ -> zero ()) + if u == Util.undef then Util.undef + else B.bind u (fun a -> try f a with Match_failure _ -> zero ()) let plus u v = - if u == Util.undefined then v else if v == Util.undefined then u else plus u v + if u == Util.undef then v else if v == Util.undef then u else B.plus u v let run u = - if u == Util.undefined then failwith "no zero" else run u + if u == Util.undef then raise Undefined else B.run u let run_exn u = - if u == Util.undefined then failwith "no zero" else run_exn u + if u == Util.undef then raise Undefined else B.run_exn u let (>>=) = bind + (* expressions after >> will be evaluated before they're passed to + * bind, so you can't do `zero () >> assert false` + * this works though: `zero () >>= fun _ -> assert false` + *) let (>>) u v = u >>= fun _ -> v let lift f u = u >>= fun a -> unit (f a) (* lift is called listM, fmap, and <$> in Haskell *) @@ -149,9 +171,21 @@ module Monad = struct let (>=>) f g = fun a -> f a >>= g let do_when test u = if test then u else unit () let do_unless test u = if test then unit () else u + (* A Haskell-like version works: + let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk + * but the recursive call is not in tail position so this can stack overflow. *) let forever uthunk = - let rec loop () = uthunk () >>= fun _ -> loop () - in loop () + let z = zero () in + let id result = result in + let kcell = ref id in + let rec loop _ = + let result = uthunk (kcell := id) >>= chained + in !kcell result + and chained _ = + kcell := loop; z (* we use z only for its polymorphism *) + in loop z + (* Reimplementations of the preceding using a hand-rolled State or StateT +can also stack overflow. *) let sequence ms = let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in Util.fold_right op ms (unit []) @@ -184,7 +218,7 @@ module Monad = struct end (* Signatures for MonadT *) - module type TRANS = sig + module type BASET = sig module Wrapped : S type ('x,'a) m type ('x,'a) result @@ -200,7 +234,7 @@ module Monad = struct val zero : unit -> ('x,'a) m val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m end - module MakeT(T : TRANS) = struct + module MakeT(T : BASET) = struct include Make(struct include T let unit a = elevate (Wrapped.unit a) @@ -228,7 +262,7 @@ end = struct let bind a f = f a let run a = a let run_exn a = a - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -264,7 +298,7 @@ end = struct end include Monad.Make(Base) module T(Wrapped : Monad.S) = struct - module Trans = struct + module BaseT = struct include Monad.MakeT(struct module Wrapped = Wrapped type ('x,'a) m = ('x,'a option) Wrapped.m @@ -278,13 +312,13 @@ end = struct let run_exn u = let w = Wrapped.bind u (fun t -> match t with | Some a -> Wrapped.unit a - | None -> failwith "no value") - in Wrapped.run_exn w + | None -> Wrapped.zero () + ) in Wrapped.run_exn w let zero () = Wrapped.unit None let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u) end) end - include Trans + include BaseT end end @@ -305,10 +339,9 @@ module List_monad : sig (* note that second argument is an 'a list, not the more abstract 'a m *) (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *) val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a list -> ('x,'b) m -(* TODO - val permute : 'a m -> 'a m m - val select : 'a m -> ('a * 'a m) m -*) + val permute : ('x,'a) m -> ('x,('x,'a) m) m + val select : ('x,'a) m -> ('x,('a * ('x,'a) m)) m + val expose : ('x,'a) m -> ('x,'a list) Wrapped.m end end = struct module Base = struct @@ -323,6 +356,7 @@ end = struct | [a] -> a | many -> failwith "multiple values" let zero () = [] + (* satisfies Distrib *) let plus = Util.append end include Monad.Make(Base) @@ -341,9 +375,8 @@ end = struct let rec select u = match u with | [] -> zero () | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs')) - let base_plus = plus module T(Wrapped : Monad.S) = struct - (* Wrapped.sequence ms === + (* Wrapped.sequence ms === let plus1 u v = Wrapped.bind u (fun x -> Wrapped.bind v (fun xs -> @@ -365,20 +398,40 @@ end = struct let run u = Wrapped.run u let run_exn u = let w = Wrapped.bind u (fun ts -> match ts with - | [] -> failwith "no values" + | [] -> Wrapped.zero () | [a] -> Wrapped.unit a - | many -> failwith "multiple values" + | many -> Wrapped.zero () ) in Wrapped.run_exn w let zero () = Wrapped.unit [] let plus u v = Wrapped.bind u (fun us -> Wrapped.bind v (fun vs -> - Wrapped.unit (base_plus us vs))) + Wrapped.unit (Base.plus us vs))) end) -(* - let permute : 'a m -> 'a m m - let select : 'a m -> ('a * 'a m) m -*) + + (* insert 3 {[1;2]} ~~> {[ {[3;1;2]}; {[1;3;2]}; {[1;2;3]} ]} *) + let rec insert a u = + plus + (unit (Wrapped.bind u (fun us -> Wrapped.unit (a :: us)))) + (Wrapped.bind u (fun us -> match us with + | [] -> zero () + | x::xs -> (insert a (Wrapped.unit xs)) >>= fun v -> unit (Wrapped.bind v (fun vs -> Wrapped.unit (x :: vs))))) + + (* select {[1;2;3]} ~~> {[ (1,{[2;3]}); (2,{[1;3]}), (3;{[1;2]}) ]} *) + let rec select u = + Wrapped.bind u (fun us -> match us with + | [] -> zero () + | x::xs -> plus (unit (x, Wrapped.unit xs)) + (select (Wrapped.unit xs) >>= fun (x', xs') -> unit (x', Wrapped.bind xs' (fun ys -> Wrapped.unit (x :: ys))))) + + (* permute {[1;2;3]} ~~> {[ {[1;2;3]}; {[2;1;3]}; {[2;3;1]}; {[1;3;2]}; {[3;1;2]}; {[3;2;1]} ]} *) + + let rec permute u = + Wrapped.bind u (fun us -> match us with + | [] -> unit (zero ()) + | x::xs -> permute (Wrapped.unit xs) >>= (fun v -> insert x v)) + + let expose u = u end end @@ -424,19 +477,11 @@ end = struct let run_exn u = match u with | Success a -> a | Error e -> raise (Err.Exc e) - let zero () = Util.undefined - let plus u v = u - (* - let zero () = Error Err.zero - let plus u v = match (u, v) with - | Success _, _ -> u - (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _ - * otherwise, plus (Error _) v = v *) - | Error _, _ when v = zero -> u - (* combine errors *) - | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2) - | Error _, _ -> v - *) + let zero () = Util.undef + (* satisfies Catch *) + let plus u v = match u with + | Success _ -> u + | Error _ -> if v == Util.undef then u else v end include Monad.Make(Base) (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *) @@ -457,15 +502,15 @@ end = struct let run u = let w = Wrapped.bind u (fun t -> match t with | Success a -> Wrapped.unit a - | Error e -> Wrapped.zero ()) - in Wrapped.run w + | Error e -> Wrapped.zero () + ) in Wrapped.run w let run_exn u = let w = Wrapped.bind u (fun t -> match t with | Success a -> Wrapped.unit a | Error e -> raise (Err.Exc e)) in Wrapped.run_exn w let plus u v = Wrapped.plus u v - let zero () = elevate (Wrapped.zero ()) + let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *) end) let throw e = Wrapped.unit (Error e) let catch u handler = Wrapped.bind u (fun t -> match t with @@ -484,26 +529,6 @@ module Failure = Error_monad(struct *) end) -(* -# EL.(run( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));; -- : int EL.result = [Failure.Error "bye"; Failure.Success 30] -# LE.(run( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));; -- : int LE.result = Failure.Error "bye" -# EL.(run_exn( plus (throw "bye") (unit 20) >>= fun i -> unit(i+10)));; -Exception: Failure "bye". -# LE.(run_exn( plus (elevate (Failure.throw "bye")) (unit 20) >>= fun i -> unit(i+10)));; -Exception: Failure "bye". - -# ES.(run( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;; -- : int Failure.error * S.store = (Failure.Error "bye", 1) -# SE.(run( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;; -- : (int * S.store) Failure.result = Failure.Error "bye" -# ES.(run_exn( elevate (S.puts succ) >> throw "bye" >> elevate S.get >>= fun i -> unit(i+10) )) 0;; -Exception: Failure "bye". -# SE.(run_exn( puts succ >> elevate (Failure.throw "bye") >> get >>= fun i -> unit(i+10) )) 0;; -Exception: Failure "bye". - *) - (* must be parameterized on (struct type env = ... end) *) module Reader_monad(Env : sig type env end) : sig @@ -514,6 +539,7 @@ module Reader_monad(Env : sig type env end) : sig include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn val ask : ('x,env) m val asks : (env -> 'a) -> ('x,'a) m + (* lookup i == `fun e -> e i` would assume env is a functional type *) val local : (env -> env) -> ('x,'a) m -> ('x,'a) m (* ReaderT transformer *) module T : functor (Wrapped : Monad.S) -> sig @@ -524,6 +550,7 @@ module Reader_monad(Env : sig type env end) : sig val ask : ('x,env) m val asks : (env -> 'a) -> ('x,'a) m val local : (env -> env) -> ('x,'a) m -> ('x,'a) m + val expose : ('x,'a) m -> env -> ('x,'a) Wrapped.m end end = struct type env = Env.env @@ -535,7 +562,7 @@ end = struct let bind u f = fun e -> let a = u e in let u' = f a in u' e let run u = fun e -> u e let run_exn = run - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -543,24 +570,26 @@ end = struct let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *) let local modifier u = fun e -> u (modifier e) module T(Wrapped : Monad.S) = struct - module Trans = struct + module BaseT = struct module Wrapped = Wrapped type ('x,'a) m = env -> ('x,'a) Wrapped.m type ('x,'a) result = env -> ('x,'a) Wrapped.result type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn let elevate w = fun e -> w - let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e) + let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e) let run u = fun e -> Wrapped.run (u e) let run_exn u = fun e -> Wrapped.run_exn (u e) - let plus u v = fun s -> Wrapped.plus (u s) (v s) - let zero () = elevate (Wrapped.zero ()) + (* satisfies Distrib *) + let plus u v = fun e -> Wrapped.plus (u e) (v e) + let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *) end - include Monad.MakeT(Trans) - let ask = fun e -> Wrapped.unit e + include Monad.MakeT(BaseT) + let ask = Wrapped.unit let local modifier u = fun e -> u (modifier e) let asks selector = ask >>= (fun e -> try unit (selector e) with Not_found -> fun e -> Wrapped.zero ()) + let expose u = u end end @@ -586,6 +615,8 @@ module State_monad(Store : sig type store end) : sig val gets : (store -> 'a) -> ('x,'a) m val put : store -> ('x,unit) m val puts : (store -> store) -> ('x,unit) m + (* val passthru : ('x,'a) m -> (('x,'a * store) Wrapped.result * store -> 'b) -> ('x,'b) m *) + val expose : ('x,'a) m -> store -> ('x,'a * store) Wrapped.m end end = struct type store = Store.store @@ -597,7 +628,7 @@ end = struct let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s' let run u = fun s -> (u s) let run_exn u = fun s -> fst (u s) - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -606,7 +637,7 @@ end = struct let put s = fun _ -> ((), s) let puts modifier = fun s -> ((), modifier s) module T(Wrapped : Monad.S) = struct - module Trans = struct + module BaseT = struct module Wrapped = Wrapped type ('x,'a) m = store -> ('x,'a * store) Wrapped.m type ('x,'a) result = store -> ('x,'a * store) Wrapped.result @@ -619,19 +650,23 @@ end = struct let run_exn u = fun s -> let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a) in Wrapped.run_exn w + (* satisfies Distrib *) let plus u v = fun s -> Wrapped.plus (u s) (v s) - let zero () = elevate (Wrapped.zero ()) + let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *) end - include Monad.MakeT(Trans) + include Monad.MakeT(BaseT) let get = fun s -> Wrapped.unit (s, s) let gets viewer = fun s -> try Wrapped.unit (viewer s, s) with Not_found -> Wrapped.zero () let put s = fun _ -> Wrapped.unit ((), s) let puts modifier = fun s -> Wrapped.unit ((), modifier s) + (* let passthru u f = fun s -> Wrapped.unit (f (Wrapped.run (u s), s), s) *) + let expose u = u end end + (* State monad with different interface (structured store) *) module Ref_monad(V : sig type value @@ -674,7 +709,7 @@ end = struct let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s' let run u = fst (u empty) let run_exn = run - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -682,7 +717,7 @@ end = struct let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *) let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *) module T(Wrapped : Monad.S) = struct - module Trans = struct + module BaseT = struct module Wrapped = Wrapped type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m type ('x,'a) result = ('x,'a) Wrapped.result @@ -697,10 +732,11 @@ end = struct let run_exn u = let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a) in Wrapped.run_exn w + (* satisfies Distrib *) let plus u v = fun s -> Wrapped.plus (u s) (v s) - let zero () = elevate (Wrapped.zero ()) + let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *) end - include Monad.MakeT(Trans) + include Monad.MakeT(BaseT) let newref value = fun s -> Wrapped.unit (alloc value s) let deref key = fun s -> Wrapped.unit (read key s, s) let change key value = fun s -> Wrapped.unit ((), write key value s) @@ -724,6 +760,17 @@ end) : sig val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m (* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *) val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m + (* WriterT transformer *) + module T : functor (Wrapped : Monad.S) -> sig + type ('x,'a) result = ('x,'a * log) Wrapped.result + type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn + include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn + val elevate : ('x,'a) Wrapped.m -> ('x,'a) m + val tell : log -> ('x,unit) m + val listen : ('x,'a) m -> ('x,'a * log) m + val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m + val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m + end end = struct type log = Log.log module Base = struct @@ -731,10 +778,10 @@ end = struct type ('x,'a) result = 'a * log type ('x,'a) result_exn = 'a * log let unit a = (a, Log.zero) - let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w') + let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w') let run u = u let run_exn = run - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -743,6 +790,31 @@ end = struct let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *) let pass ((a, f), w) = (a, f w) (* usually use censor helper *) let censor f u = pass (u >>= fun a -> unit (a, f)) + module T(Wrapped : Monad.S) = struct + module BaseT = struct + module Wrapped = Wrapped + type ('x,'a) m = ('x,'a * log) Wrapped.m + type ('x,'a) result = ('x,'a * log) Wrapped.result + type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn + let elevate w = + Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero)) + let bind u f = + Wrapped.bind u (fun (a, w) -> + Wrapped.bind (f a) (fun (b, w') -> + Wrapped.unit (b, Log.plus w w'))) + let zero () = elevate (Wrapped.zero ()) + let plus u v = Wrapped.plus u v + let run u = Wrapped.run u + let run_exn u = Wrapped.run_exn u + end + include Monad.MakeT(BaseT) + let tell entries = Wrapped.unit ((), entries) + let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w)) + let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w)) + (* rest are derived in same way as before *) + let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) + let censor f u = pass (u >>= fun a -> unit (a, f)) + end end (* pre-define simple Writer *) @@ -766,6 +838,7 @@ module Writer2 = struct end +(* TODO needs a T *) module IO_monad : sig (* declare additional operation, while still hiding implementation of type m *) type ('x,'a) result = 'a @@ -787,7 +860,7 @@ end = struct { run = (fun () -> a.run (); fres.run ()); value = fres.value } let run a = let () = a.run () in a.value let run_exn = run - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -812,6 +885,16 @@ module Continuation_monad : sig (* val abort : ('a,'a) m -> ('a,'b) m *) val abort : 'a -> ('a,'b) m val run0 : ('a,'a) m -> 'a + (* ContinuationT transformer *) + module T : functor (Wrapped : Monad.S) -> sig + type ('r,'a) m + type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result + type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn + include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m + val elevate : ('x,'a) Wrapped.m -> ('x,'a) m + val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m + (* TODO: reset,shift,abort,run0 *) + end end = struct let id = fun i -> i module Base = struct @@ -823,7 +906,7 @@ end = struct let bind u f = (fun k -> (u) (fun a -> (f a) k)) let run u k = (u) k let run_exn = run - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -850,6 +933,24 @@ end = struct (* let abort a = shift (fun _ -> a) *) let abort a = shift (fun _ -> unit a) let run0 (u : ('a,'a) m) = (u) id + module T(Wrapped : Monad.S) = struct + module BaseT = struct + module Wrapped = Wrapped + type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m + type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result + type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn + let elevate w = fun k -> Wrapped.bind w k + let bind u f = fun k -> u (fun a -> f a k) + let run u k = Wrapped.run (u k) + let run_exn u k = Wrapped.run_exn (u k) + let zero () = Util.undef + let plus u v = u + end + include Monad.MakeT(BaseT) + let callcc f = (fun k -> + let usek a = (fun _ -> k a) + in (f usek) k) + end end @@ -874,53 +975,17 @@ end * >>= fun x -> unit (x, 0) * in run u) * - * - * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *) - * let example1 () : int = - * Continuation_monad.(let v = reset ( - * let u = shift (fun k -> unit (10 + 1)) - * in u >>= fun x -> unit (100 + x) - * ) in let w = v >>= fun x -> unit (1000 + x) - * in run w) - * - * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *) - * let example2 () = - * Continuation_monad.(let v = reset ( - * let u = shift (fun k -> k (10 :: [1])) - * in u >>= fun x -> unit (100 :: x) - * ) in let w = v >>= fun x -> unit (1000 :: x) - * in run w) - * - * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *) - * let example3 () = - * Continuation_monad.(let v = reset ( - * let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x)) - * in u >>= fun x -> unit (100 :: x) - * ) in let w = v >>= fun x -> unit (1000 :: x) - * in run w) - * - * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *) - * (* not sure if this example can be typed without a sum-type *) - * - * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *) - * let example5 () : int = - * Continuation_monad.(let v = reset ( - * let u = shift (fun k -> k 1 >>= fun x -> k x) - * in u >>= fun x -> unit (10 + x) - * ) in let w = v >>= fun x -> unit (100 + x) - * in run w) - * *) -module Leaf_monad : sig +module Tree_monad : sig (* We implement the type as `'a tree option` because it has a natural`plus`, * and the rest of the library expects that `plus` and `zero` will come together. *) type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree) type ('x,'a) result = 'a tree option type ('x,'a) result_exn = 'a tree include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn - (* LeafT transformer *) + (* TreeT transformer *) module T : functor (Wrapped : Monad.S) -> sig type ('x,'a) result = ('x,'a tree option) Wrapped.result type ('x,'a) result_exn = ('x,'a tree) Wrapped.result_exn @@ -929,6 +994,7 @@ module Leaf_monad : sig (* note that second argument is an 'a tree?, not the more abstract 'a m *) (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *) val distribute : ('a -> ('x,'b) Wrapped.m) -> 'a tree option -> ('x,'b) m + val expose : ('x,'a) m -> ('x,'a tree option) Wrapped.m end end = struct type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree) @@ -947,6 +1013,7 @@ end = struct type ('x,'a) result_exn = 'a tree let unit a = Some (Leaf a) let zero () = None + (* satisfies Distrib *) let plus u v = match (u, v) with | None, _ -> v | _, None -> u @@ -962,10 +1029,8 @@ end = struct | Some us -> us end include Monad.Make(Base) - let base_plus = plus - let base_lift = lift module T(Wrapped : Monad.S) = struct - module Trans = struct + module BaseT = struct include Monad.MakeT(struct module Wrapped = Wrapped type ('x,'a) m = ('x,'a tree option) Wrapped.m @@ -975,226 +1040,22 @@ end = struct let plus u v = Wrapped.bind u (fun us -> Wrapped.bind v (fun vs -> - Wrapped.unit (base_plus us vs))) + Wrapped.unit (Base.plus us vs))) let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a))) let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus) let run u = Wrapped.run u let run_exn u = let w = Wrapped.bind u (fun t -> match t with - | None -> failwith "no values" - | Some ts -> Wrapped.unit ts) - in Wrapped.run_exn w + | None -> Wrapped.zero () + | Some ts -> Wrapped.unit ts + ) in Wrapped.run_exn w end) end - include Trans - (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *) + include BaseT let distribute f t = mapT (fun a -> elevate (f a)) t zero plus + let expose u = u end -end - - -module L = List_monad;; -module R = Reader_monad(struct type env = int -> int end);; -module S = State_monad(struct type store = int end);; -module T = Leaf_monad;; -module LR = L.T(R);; -module LS = L.T(S);; -module TL = T.T(L);; -module TR = T.T(R);; -module TS = T.T(S);; -module C = Continuation_monad -module TC = T.T(C);; - - -print_endline "=== test Leaf(...).distribute ==================";; - -let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));; - -let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;; -TS.run ts 0;; -(* -- : int T.tree option * S.store = -(Some - (T.Node - (T.Node (T.Leaf 2, T.Leaf 3), - T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))), - 5) -*) - -let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;; -TS.run_exn ts2 0;; -(* -- : (int * S.store) T.tree option * S.store = -(Some - (T.Node - (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)), - T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))), - 5) -*) - -let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;; -TR.run_exn tr (fun i -> i+i);; -(* -- : int T.tree option = -Some - (T.Node - (T.Node (T.Leaf 4, T.Leaf 6), - T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22)))) -*) - -let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;; -TL.run_exn tl;; -(* -- : (int * int) TL.result = -[Some - (T.Node - (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)), - T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))] -*) - -let l2 = [1;2;3;4;5];; -let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));; - -LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);; -(* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *) - -TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);; -(* -int T.tree option = -Some - (T.Node - (T.Node (T.Leaf 10, T.Leaf 11), - T.Node - (T.Node - (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)), - T.Node (T.Leaf 40, T.Leaf 41)), - T.Node (T.Leaf 50, T.Leaf 51)))) - *) - -LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;; -(* -- : S.store list * S.store = ([10; 0; 0; 1; 20], 1) -*) - -print_endline "=== test Leaf(Continuation).distribute ==================";; - -let id : 'z. 'z -> 'z = fun x -> x - -let example n : (int * int) = - Continuation_monad.(let u = callcc (fun k -> - (if n < 0 then k 0 else unit [n + 100]) - (* all of the following is skipped by k 0; the end type int is k's input type *) - >>= fun [x] -> unit (x + 1) - ) - (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *) - >>= fun x -> unit (x, 0) - in run0 u) - - -(* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *) -let example1 () : int = - Continuation_monad.(let v = reset ( - let u = shift (fun k -> unit (10 + 1)) - in u >>= fun x -> unit (100 + x) - ) in let w = v >>= fun x -> unit (1000 + x) - in run0 w) - -(* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *) -let example2 () = - Continuation_monad.(let v = reset ( - let u = shift (fun k -> k (10 :: [1])) - in u >>= fun x -> unit (100 :: x) - ) in let w = v >>= fun x -> unit (1000 :: x) - in run0 w) - -(* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *) -let example3 () = - Continuation_monad.(let v = reset ( - let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x)) - in u >>= fun x -> unit (100 :: x) - ) in let w = v >>= fun x -> unit (1000 :: x) - in run0 w) - -(* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *) -(* not sure if this example can be typed without a sum-type *) - -(* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *) -let example5 () : int = - Continuation_monad.(let v = reset ( - let u = shift (fun k -> k 1 >>= k) - in u >>= fun x -> unit (10 + x) - ) in let w = v >>= fun x -> unit (100 + x) - in run0 w) - -;; - -print_endline "=== test bare Continuation ============";; - -(1011, 1111, 1111, 121);; -(example1(), example2(), example3(), example5());; -((111,0), (0,0));; -(example ~+10, example ~-10);; - -let testc df ic = - C.run_exn TC.(run (distribute df t1)) ic;; - - -(* -(* do nothing *) -let initial_continuation = fun t -> t in -TreeCont.monadize t1 Continuation_monad.unit initial_continuation;; -*) -testc (C.unit) id;; - -(* -(* count leaves, using continuation *) -let initial_continuation = fun t -> 0 in -TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;; -*) - -testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);; - -(* -(* convert tree to list of leaves *) -let initial_continuation = fun t -> [] in -TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;; -*) - -testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));; - -(* -(* square each leaf using continuation *) -let initial_continuation = fun t -> t in -TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;; -*) - -testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);; - - -(* -(* replace leaves with list, using continuation *) -let initial_continuation = fun t -> t in -TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;; -*) - -testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);; - -print_endline "=== pa_monad's Continuation Tests ============";; - -(1, 5 = C.(run0 (unit 1 >>= fun x -> unit (x+4))) );; -(2, 9 = C.(run0 (reset (unit 5 >>= fun x -> unit (x+4)))) );; -(3, 9 = C.(run0 (reset (abort 5 >>= fun y -> unit (y+6)) >>= fun x -> unit (x+4))) );; -(4, 9 = C.(run0 (reset (reset (abort 5 >>= fun y -> unit (y+6))) >>= fun x -> unit (x+4))) );; -(5, 27 = C.(run0 ( - let c = reset(abort 5 >>= fun y -> unit (y+6)) - in reset(c >>= fun v1 -> abort 7 >>= fun v2 -> unit (v2+10) ) >>= fun x -> unit (x+20))) );; - -(7, 117 = C.(run0 (reset (shift (fun sk -> sk 3 >>= sk >>= fun v3 -> unit (v3+100) ) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );; - -(8, 115 = C.(run0 (reset (shift (fun sk -> sk 3 >>= fun v3 -> unit (v3+100)) >>= fun v1 -> unit (v1+2)) >>= fun x -> unit (x+10))) );; - -(12, ["a"] = C.(run0 (reset (shift (fun f -> f [] >>= fun t -> unit ("a"::t) ) >>= fun xv -> shift (fun _ -> unit xv)))) );; +end;; -(0, 15 = C.(run0 (let f k = k 10 >>= fun v-> unit (v+100) in reset (callcc f >>= fun v -> unit (v+5)))) );;