From ef6e90fdda821a1d071c93dc50587d1a0fd207b9 Mon Sep 17 00:00:00 2001 From: Jim Pryor Date: Sat, 11 Dec 2010 14:41:13 -0500 Subject: [PATCH 1/1] tweak monads-lib Signed-off-by: Jim Pryor --- code/monads.ml | 105 +++++++++++++++++++++++++++++++-------------------------- 1 file changed, 58 insertions(+), 47 deletions(-) diff --git a/code/monads.ml b/code/monads.ml index 0a205f6a..34ad1cef 100644 --- a/code/monads.ml +++ b/code/monads.ml @@ -44,6 +44,7 @@ * *) +exception Undefined (* Some library functions used below. *) module Util = struct @@ -60,7 +61,10 @@ module Util = struct let rec loop n accu = if n == 0 then accu else loop (pred n) (fill :: accu) in loop len [] - let undefined = Obj.magic "" + (* Dirty hack to be a default polymorphic zero. + * To implement this cleanly, monads without a natural zero + * should always wrap themselves in an option layer (see Leaf_monad). *) + let undef = Obj.magic (fun () -> raise Undefined) end @@ -74,14 +78,15 @@ module Monad = struct (* * Signature extenders: * Make :: BASE -> S - * MakeT :: TRANS (with Wrapped : S) -> custom sig + * MakeT :: BASET (with Wrapped : S) -> result sig not declared *) (* type of base definitions *) module type BASE = sig - (* The only constraints we impose here on how the monadic type - * is implemented is that it have a single type parameter 'a. *) + (* We make all monadic types doubly-parameterized so that they + * can layer nicely with Continuation, which needs the second + * type parameter. *) type ('x,'a) m type ('x,'a) result type ('x,'a) result_exn @@ -97,11 +102,12 @@ module Monad = struct * Additionally, they will obey one of the following laws: * (Catch) plus (unit a) v === unit a * (Distrib) plus u v >>= f === plus (u >>= f) (v >>= f) - * When no natural zero is available, use `let zero () = Util.undefined - * The Make process automatically detects for zero >>= ..., and + * When no natural zero is available, use `let zero () = Util.undef`. + * The Make functor automatically detects for zero >>= ..., and * plus zero _, plus _ zero; it also substitutes zero for pattern-match failures. *) val zero : unit -> ('x,'a) m + (* zero has to be thunked to ensure results are always poly enough *) val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m end module type S = sig @@ -125,14 +131,14 @@ module Monad = struct module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct include B let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m = - if u == Util.undefined then Util.undefined - else bind u (fun a -> try f a with Match_failure _ -> zero ()) + if u == Util.undef then Util.undef + else B.bind u (fun a -> try f a with Match_failure _ -> zero ()) let plus u v = - if u == Util.undefined then v else if v == Util.undefined then u else plus u v + if u == Util.undef then v else if v == Util.undef then u else B.plus u v let run u = - if u == Util.undefined then failwith "no zero" else run u + if u == Util.undef then raise Undefined else B.run u let run_exn u = - if u == Util.undefined then failwith "no zero" else run_exn u + if u == Util.undef then raise Undefined else B.run_exn u let (>>=) = bind let (>>) u v = u >>= fun _ -> v let lift f u = u >>= fun a -> unit (f a) @@ -149,6 +155,7 @@ module Monad = struct let (>=>) f g = fun a -> f a >>= g let do_when test u = if test then u else unit () let do_unless test u = if test then unit () else u + (* not in tail position, will Stack overflow *) let forever uthunk = let rec loop () = uthunk () >>= fun _ -> loop () in loop () @@ -184,7 +191,7 @@ module Monad = struct end (* Signatures for MonadT *) - module type TRANS = sig + module type BASET = sig module Wrapped : S type ('x,'a) m type ('x,'a) result @@ -200,7 +207,7 @@ module Monad = struct val zero : unit -> ('x,'a) m val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m end - module MakeT(T : TRANS) = struct + module MakeT(T : BASET) = struct include Make(struct include T let unit a = elevate (Wrapped.unit a) @@ -228,7 +235,7 @@ end = struct let bind a f = f a let run a = a let run_exn a = a - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -264,7 +271,7 @@ end = struct end include Monad.Make(Base) module T(Wrapped : Monad.S) = struct - module Trans = struct + module BaseT = struct include Monad.MakeT(struct module Wrapped = Wrapped type ('x,'a) m = ('x,'a option) Wrapped.m @@ -278,13 +285,13 @@ end = struct let run_exn u = let w = Wrapped.bind u (fun t -> match t with | Some a -> Wrapped.unit a - | None -> failwith "no value") - in Wrapped.run_exn w + | None -> Wrapped.zero () + ) in Wrapped.run_exn w let zero () = Wrapped.unit None let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u) end) end - include Trans + include BaseT end end @@ -323,6 +330,7 @@ end = struct | [a] -> a | many -> failwith "multiple values" let zero () = [] + (* satisfies Distrib *) let plus = Util.append end include Monad.Make(Base) @@ -341,7 +349,6 @@ end = struct let rec select u = match u with | [] -> zero () | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs')) - let base_plus = plus module T(Wrapped : Monad.S) = struct (* Wrapped.sequence ms === let plus1 u v = @@ -365,15 +372,15 @@ end = struct let run u = Wrapped.run u let run_exn u = let w = Wrapped.bind u (fun ts -> match ts with - | [] -> failwith "no values" + | [] -> Wrapped.zero () | [a] -> Wrapped.unit a - | many -> failwith "multiple values" + | many -> Wrapped.zero () ) in Wrapped.run_exn w let zero () = Wrapped.unit [] let plus u v = Wrapped.bind u (fun us -> Wrapped.bind v (fun vs -> - Wrapped.unit (base_plus us vs))) + Wrapped.unit (Base.plus us vs))) end) (* let permute : 'a m -> 'a m m @@ -424,7 +431,7 @@ end = struct let run_exn u = match u with | Success a -> a | Error e -> raise (Err.Exc e) - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u (* let zero () = Error Err.zero @@ -457,8 +464,8 @@ end = struct let run u = let w = Wrapped.bind u (fun t -> match t with | Success a -> Wrapped.unit a - | Error e -> Wrapped.zero ()) - in Wrapped.run w + | Error e -> Wrapped.zero () + ) in Wrapped.run w let run_exn u = let w = Wrapped.bind u (fun t -> match t with | Success a -> Wrapped.unit a @@ -514,6 +521,7 @@ module Reader_monad(Env : sig type env end) : sig include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn val ask : ('x,env) m val asks : (env -> 'a) -> ('x,'a) m + (* lookup i == `fun e -> e i` would assume env is a functional type *) val local : (env -> env) -> ('x,'a) m -> ('x,'a) m (* ReaderT transformer *) module T : functor (Wrapped : Monad.S) -> sig @@ -535,7 +543,7 @@ end = struct let bind u f = fun e -> let a = u e in let u' = f a in u' e let run u = fun e -> u e let run_exn = run - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -543,7 +551,7 @@ end = struct let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *) let local modifier u = fun e -> u (modifier e) module T(Wrapped : Monad.S) = struct - module Trans = struct + module BaseT = struct module Wrapped = Wrapped type ('x,'a) m = env -> ('x,'a) Wrapped.m type ('x,'a) result = env -> ('x,'a) Wrapped.result @@ -552,10 +560,11 @@ end = struct let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e) let run u = fun e -> Wrapped.run (u e) let run_exn u = fun e -> Wrapped.run_exn (u e) + (* satisfies Distrib *) let plus u v = fun s -> Wrapped.plus (u s) (v s) let zero () = elevate (Wrapped.zero ()) end - include Monad.MakeT(Trans) + include Monad.MakeT(BaseT) let ask = fun e -> Wrapped.unit e let local modifier u = fun e -> u (modifier e) let asks selector = ask >>= (fun e -> @@ -597,7 +606,7 @@ end = struct let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s' let run u = fun s -> (u s) let run_exn u = fun s -> fst (u s) - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -606,7 +615,7 @@ end = struct let put s = fun _ -> ((), s) let puts modifier = fun s -> ((), modifier s) module T(Wrapped : Monad.S) = struct - module Trans = struct + module BaseT = struct module Wrapped = Wrapped type ('x,'a) m = store -> ('x,'a * store) Wrapped.m type ('x,'a) result = store -> ('x,'a * store) Wrapped.result @@ -619,10 +628,11 @@ end = struct let run_exn u = fun s -> let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a) in Wrapped.run_exn w + (* satisfies Distrib *) let plus u v = fun s -> Wrapped.plus (u s) (v s) let zero () = elevate (Wrapped.zero ()) end - include Monad.MakeT(Trans) + include Monad.MakeT(BaseT) let get = fun s -> Wrapped.unit (s, s) let gets viewer = fun s -> try Wrapped.unit (viewer s, s) @@ -674,7 +684,7 @@ end = struct let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s' let run u = fst (u empty) let run_exn = run - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -682,7 +692,7 @@ end = struct let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *) let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *) module T(Wrapped : Monad.S) = struct - module Trans = struct + module BaseT = struct module Wrapped = Wrapped type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m type ('x,'a) result = ('x,'a) Wrapped.result @@ -697,17 +707,18 @@ end = struct let run_exn u = let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a) in Wrapped.run_exn w + (* satisfies Distrib *) let plus u v = fun s -> Wrapped.plus (u s) (v s) let zero () = elevate (Wrapped.zero ()) end - include Monad.MakeT(Trans) + include Monad.MakeT(BaseT) let newref value = fun s -> Wrapped.unit (alloc value s) let deref key = fun s -> Wrapped.unit (read key s, s) let change key value = fun s -> Wrapped.unit ((), write key value s) end end - +(* TODO needs a T *) (* must be parameterized on (struct type log = ... end) *) module Writer_monad(Log : sig type log @@ -734,7 +745,7 @@ end = struct let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w') let run u = u let run_exn = run - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -766,6 +777,7 @@ module Writer2 = struct end +(* TODO needs a T *) module IO_monad : sig (* declare additional operation, while still hiding implementation of type m *) type ('x,'a) result = 'a @@ -787,7 +799,7 @@ end = struct { run = (fun () -> a.run (); fres.run ()); value = fres.value } let run a = let () = a.run () in a.value let run_exn = run - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -800,6 +812,7 @@ end = struct end +(* TODO needs a T *) module Continuation_monad : sig (* expose only the implementation of type `('r,'a) result` *) type ('r,'a) m @@ -823,7 +836,7 @@ end = struct let bind u f = (fun k -> (u) (fun a -> (f a) k)) let run u k = (u) k let run_exn = run - let zero () = Util.undefined + let zero () = Util.undef let plus u v = u end include Monad.Make(Base) @@ -947,6 +960,7 @@ end = struct type ('x,'a) result_exn = 'a tree let unit a = Some (Leaf a) let zero () = None + (* satisfies Distrib *) let plus u v = match (u, v) with | None, _ -> v | _, None -> u @@ -962,10 +976,8 @@ end = struct | Some us -> us end include Monad.Make(Base) - let base_plus = plus - let base_lift = lift module T(Wrapped : Monad.S) = struct - module Trans = struct + module BaseT = struct include Monad.MakeT(struct module Wrapped = Wrapped type ('x,'a) m = ('x,'a tree option) Wrapped.m @@ -975,19 +987,18 @@ end = struct let plus u v = Wrapped.bind u (fun us -> Wrapped.bind v (fun vs -> - Wrapped.unit (base_plus us vs))) + Wrapped.unit (Base.plus us vs))) let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a))) let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus) let run u = Wrapped.run u let run_exn u = let w = Wrapped.bind u (fun t -> match t with - | None -> failwith "no values" - | Some ts -> Wrapped.unit ts) - in Wrapped.run_exn w + | None -> Wrapped.zero () + | Some ts -> Wrapped.unit ts + ) in Wrapped.run_exn w end) end - include Trans - (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *) + include BaseT let distribute f t = mapT (fun a -> elevate (f a)) t zero plus end end -- 2.11.0