X-Git-Url: http://lambda.jimpryor.net/git/gitweb.cgi?a=blobdiff_plain;ds=inline;f=code%2Fmonads.ml;h=bdc3eba2adc8982c64716d7a8a9e03b36c664e68;hb=092c31a2cad42975fb4751c2eda5ac03a13c8cd5;hp=0a205f6a8c617e449f3c91d3b7d4a72ac2fefc5d;hpb=58bf3ee4a3e5ee6e343787e432602b677a596109;p=lambda.git
diff --git a/code/monads.ml b/code/monads.ml
index 0a205f6a..bdc3eba2 100644
--- a/code/monads.ml
+++ b/code/monads.ml
@@ -42,8 +42,19 @@
* have to use operations like `run` to convert the abstract monadic types
* to types whose internals you have free access to.
*
+ * Acknowledgements: This is largely based on the mtl library distributed
+ * with the Glaskow Haskell Compiler. I've also been helped in
+ * various ways by posts and direct feedback from Oleg Kiselyov and
+ * Chung-chieh Shan. The following were also useful:
+ * -
+ * - Ken Shan "Monads for natural language semantics"
+ * - http://www.grabmueller.de/martin/www/pub/Transformers.pdf
+ * - http://en.wikibooks.org/wiki/Haskell/Monad_transformers
+ *
+ * Licensing: MIT (if that's compatible with the ghc sources).
*)
+exception Undefined
(* Some library functions used below. *)
module Util = struct
@@ -60,7 +71,10 @@ module Util = struct
let rec loop n accu =
if n == 0 then accu else loop (pred n) (fill :: accu)
in loop len []
- let undefined = Obj.magic ""
+ (* Dirty hack to be a default polymorphic zero.
+ * To implement this cleanly, monads without a natural zero
+ * should always wrap themselves in an option layer (see Leaf_monad). *)
+ let undef = Obj.magic (fun () -> raise Undefined)
end
@@ -74,14 +88,15 @@ module Monad = struct
(*
* Signature extenders:
* Make :: BASE -> S
- * MakeT :: TRANS (with Wrapped : S) -> custom sig
+ * MakeT :: BASET (with Wrapped : S) -> result sig not declared
*)
(* type of base definitions *)
module type BASE = sig
- (* The only constraints we impose here on how the monadic type
- * is implemented is that it have a single type parameter 'a. *)
+ (* We make all monadic types doubly-parameterized so that they
+ * can layer nicely with Continuation, which needs the second
+ * type parameter. *)
type ('x,'a) m
type ('x,'a) result
type ('x,'a) result_exn
@@ -97,11 +112,12 @@ module Monad = struct
* Additionally, they will obey one of the following laws:
* (Catch) plus (unit a) v === unit a
* (Distrib) plus u v >>= f === plus (u >>= f) (v >>= f)
- * When no natural zero is available, use `let zero () = Util.undefined
- * The Make process automatically detects for zero >>= ..., and
+ * When no natural zero is available, use `let zero () = Util.undef`.
+ * The Make functor automatically detects for zero >>= ..., and
* plus zero _, plus _ zero; it also substitutes zero for pattern-match failures.
*)
val zero : unit -> ('x,'a) m
+ (* zero has to be thunked to ensure results are always poly enough *)
val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
end
module type S = sig
@@ -125,15 +141,19 @@ module Monad = struct
module Make(B : BASE) : S with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct
include B
let bind (u : ('x,'a) m) (f : 'a -> ('x,'b) m) : ('x,'b) m =
- if u == Util.undefined then Util.undefined
- else bind u (fun a -> try f a with Match_failure _ -> zero ())
+ if u == Util.undef then Util.undef
+ else B.bind u (fun a -> try f a with Match_failure _ -> zero ())
let plus u v =
- if u == Util.undefined then v else if v == Util.undefined then u else plus u v
+ if u == Util.undef then v else if v == Util.undef then u else B.plus u v
let run u =
- if u == Util.undefined then failwith "no zero" else run u
+ if u == Util.undef then raise Undefined else B.run u
let run_exn u =
- if u == Util.undefined then failwith "no zero" else run_exn u
+ if u == Util.undef then raise Undefined else B.run_exn u
let (>>=) = bind
+ (* expressions after >> will be evaluated before they're passed to
+ * bind, so you can't do `zero () >> assert false`
+ * this works though: `zero () >>= fun _ -> assert false`
+ *)
let (>>) u v = u >>= fun _ -> v
let lift f u = u >>= fun a -> unit (f a)
(* lift is called listM, fmap, and <$> in Haskell *)
@@ -149,9 +169,20 @@ module Monad = struct
let (>=>) f g = fun a -> f a >>= g
let do_when test u = if test then u else unit ()
let do_unless test u = if test then unit () else u
+ (* A Haskell-like version:
+ let rec forever uthunk = uthunk () >>= fun _ -> forever uthunk
+ * is not in tail position and will stack overflow. *)
let forever uthunk =
- let rec loop () = uthunk () >>= fun _ -> loop ()
+ let z = zero () in
+ let id result = result in
+ let newk = ref id in
+ let rec loop () =
+ let result = uthunk (newk := id) >>= chained
+ in !newk result
+ and chained =
+ fun _ -> newk := (fun _ -> loop ()); z (* we use z only for its polymorphism *)
in loop ()
+ (* reimplementations of the preceding using a hand-rolled State or StateT also stack overflowed *)
let sequence ms =
let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in
Util.fold_right op ms (unit [])
@@ -184,7 +215,7 @@ module Monad = struct
end
(* Signatures for MonadT *)
- module type TRANS = sig
+ module type BASET = sig
module Wrapped : S
type ('x,'a) m
type ('x,'a) result
@@ -200,7 +231,7 @@ module Monad = struct
val zero : unit -> ('x,'a) m
val plus : ('x,'a) m -> ('x,'a) m -> ('x,'a) m
end
- module MakeT(T : TRANS) = struct
+ module MakeT(T : BASET) = struct
include Make(struct
include T
let unit a = elevate (Wrapped.unit a)
@@ -228,7 +259,7 @@ end = struct
let bind a f = f a
let run a = a
let run_exn a = a
- let zero () = Util.undefined
+ let zero () = Util.undef
let plus u v = u
end
include Monad.Make(Base)
@@ -264,7 +295,7 @@ end = struct
end
include Monad.Make(Base)
module T(Wrapped : Monad.S) = struct
- module Trans = struct
+ module BaseT = struct
include Monad.MakeT(struct
module Wrapped = Wrapped
type ('x,'a) m = ('x,'a option) Wrapped.m
@@ -278,13 +309,13 @@ end = struct
let run_exn u =
let w = Wrapped.bind u (fun t -> match t with
| Some a -> Wrapped.unit a
- | None -> failwith "no value")
- in Wrapped.run_exn w
+ | None -> Wrapped.zero ()
+ ) in Wrapped.run_exn w
let zero () = Wrapped.unit None
let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u)
end)
end
- include Trans
+ include BaseT
end
end
@@ -323,6 +354,7 @@ end = struct
| [a] -> a
| many -> failwith "multiple values"
let zero () = []
+ (* satisfies Distrib *)
let plus = Util.append
end
include Monad.Make(Base)
@@ -341,7 +373,6 @@ end = struct
let rec select u = match u with
| [] -> zero ()
| x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs'))
- let base_plus = plus
module T(Wrapped : Monad.S) = struct
(* Wrapped.sequence ms ===
let plus1 u v =
@@ -365,15 +396,15 @@ end = struct
let run u = Wrapped.run u
let run_exn u =
let w = Wrapped.bind u (fun ts -> match ts with
- | [] -> failwith "no values"
+ | [] -> Wrapped.zero ()
| [a] -> Wrapped.unit a
- | many -> failwith "multiple values"
+ | many -> Wrapped.zero ()
) in Wrapped.run_exn w
let zero () = Wrapped.unit []
let plus u v =
Wrapped.bind u (fun us ->
Wrapped.bind v (fun vs ->
- Wrapped.unit (base_plus us vs)))
+ Wrapped.unit (Base.plus us vs)))
end)
(*
let permute : 'a m -> 'a m m
@@ -382,6 +413,13 @@ end = struct
end
end
+(*
+# LL.(run(plus (unit 1) (unit 2) >>= fun i -> plus (unit i) (unit(10*i)) ));;
+- : ('_a, int) LL.result = [[1; 10; 2; 20]]
+# LL.(run(plus (unit 1) (unit 2) >>= fun i -> elevate L.(plus (unit i) (unit(10*i)) )));;
+- : ('_a, int) LL.result = [[1; 2]; [1; 20]; [10; 2]; [10; 20]]
+*)
+
(* must be parameterized on (struct type err = ... end) *)
module Error_monad(Err : sig
@@ -424,19 +462,11 @@ end = struct
let run_exn u = match u with
| Success a -> a
| Error e -> raise (Err.Exc e)
- let zero () = Util.undefined
- let plus u v = u
- (*
- let zero () = Error Err.zero
- let plus u v = match (u, v) with
- | Success _, _ -> u
- (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _
- * otherwise, plus (Error _) v = v *)
- | Error _, _ when v = zero -> u
- (* combine errors *)
- | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2)
- | Error _, _ -> v
- *)
+ let zero () = Util.undef
+ (* satisfies Catch *)
+ let plus u v = match u with
+ | Success _ -> u
+ | Error _ -> if v == Util.undef then u else v
end
include Monad.Make(Base)
(* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *)
@@ -457,15 +487,15 @@ end = struct
let run u =
let w = Wrapped.bind u (fun t -> match t with
| Success a -> Wrapped.unit a
- | Error e -> Wrapped.zero ())
- in Wrapped.run w
+ | Error e -> Wrapped.zero ()
+ ) in Wrapped.run w
let run_exn u =
let w = Wrapped.bind u (fun t -> match t with
| Success a -> Wrapped.unit a
| Error e -> raise (Err.Exc e))
in Wrapped.run_exn w
let plus u v = Wrapped.plus u v
- let zero () = elevate (Wrapped.zero ())
+ let zero () = Wrapped.zero () (* elevate (Wrapped.zero ()) *)
end)
let throw e = Wrapped.unit (Error e)
let catch u handler = Wrapped.bind u (fun t -> match t with
@@ -514,6 +544,7 @@ module Reader_monad(Env : sig type env end) : sig
include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
val ask : ('x,env) m
val asks : (env -> 'a) -> ('x,'a) m
+ (* lookup i == `fun e -> e i` would assume env is a functional type *)
val local : (env -> env) -> ('x,'a) m -> ('x,'a) m
(* ReaderT transformer *)
module T : functor (Wrapped : Monad.S) -> sig
@@ -535,7 +566,7 @@ end = struct
let bind u f = fun e -> let a = u e in let u' = f a in u' e
let run u = fun e -> u e
let run_exn = run
- let zero () = Util.undefined
+ let zero () = Util.undef
let plus u v = u
end
include Monad.Make(Base)
@@ -543,20 +574,21 @@ end = struct
let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *)
let local modifier u = fun e -> u (modifier e)
module T(Wrapped : Monad.S) = struct
- module Trans = struct
+ module BaseT = struct
module Wrapped = Wrapped
type ('x,'a) m = env -> ('x,'a) Wrapped.m
type ('x,'a) result = env -> ('x,'a) Wrapped.result
type ('x,'a) result_exn = env -> ('x,'a) Wrapped.result_exn
let elevate w = fun e -> w
- let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e)
+ let bind u f = fun e -> Wrapped.bind (u e) (fun a -> f a e)
let run u = fun e -> Wrapped.run (u e)
let run_exn u = fun e -> Wrapped.run_exn (u e)
- let plus u v = fun s -> Wrapped.plus (u s) (v s)
- let zero () = elevate (Wrapped.zero ())
+ (* satisfies Distrib *)
+ let plus u v = fun e -> Wrapped.plus (u e) (v e)
+ let zero () = fun e -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
end
- include Monad.MakeT(Trans)
- let ask = fun e -> Wrapped.unit e
+ include Monad.MakeT(BaseT)
+ let ask = Wrapped.unit
let local modifier u = fun e -> u (modifier e)
let asks selector = ask >>= (fun e ->
try unit (selector e)
@@ -597,7 +629,7 @@ end = struct
let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
let run u = fun s -> (u s)
let run_exn u = fun s -> fst (u s)
- let zero () = Util.undefined
+ let zero () = Util.undef
let plus u v = u
end
include Monad.Make(Base)
@@ -606,7 +638,7 @@ end = struct
let put s = fun _ -> ((), s)
let puts modifier = fun s -> ((), modifier s)
module T(Wrapped : Monad.S) = struct
- module Trans = struct
+ module BaseT = struct
module Wrapped = Wrapped
type ('x,'a) m = store -> ('x,'a * store) Wrapped.m
type ('x,'a) result = store -> ('x,'a * store) Wrapped.result
@@ -619,10 +651,11 @@ end = struct
let run_exn u = fun s ->
let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a)
in Wrapped.run_exn w
+ (* satisfies Distrib *)
let plus u v = fun s -> Wrapped.plus (u s) (v s)
- let zero () = elevate (Wrapped.zero ())
+ let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
end
- include Monad.MakeT(Trans)
+ include Monad.MakeT(BaseT)
let get = fun s -> Wrapped.unit (s, s)
let gets viewer = fun s ->
try Wrapped.unit (viewer s, s)
@@ -674,7 +707,7 @@ end = struct
let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s'
let run u = fst (u empty)
let run_exn = run
- let zero () = Util.undefined
+ let zero () = Util.undef
let plus u v = u
end
include Monad.Make(Base)
@@ -682,7 +715,7 @@ end = struct
let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *)
let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *)
module T(Wrapped : Monad.S) = struct
- module Trans = struct
+ module BaseT = struct
module Wrapped = Wrapped
type ('x,'a) m = dict -> ('x,'a * dict) Wrapped.m
type ('x,'a) result = ('x,'a) Wrapped.result
@@ -697,10 +730,11 @@ end = struct
let run_exn u =
let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a)
in Wrapped.run_exn w
+ (* satisfies Distrib *)
let plus u v = fun s -> Wrapped.plus (u s) (v s)
- let zero () = elevate (Wrapped.zero ())
+ let zero () = fun s -> Wrapped.zero () (* elevate (Wrapped.zero ()) *)
end
- include Monad.MakeT(Trans)
+ include Monad.MakeT(BaseT)
let newref value = fun s -> Wrapped.unit (alloc value s)
let deref key = fun s -> Wrapped.unit (read key s, s)
let change key value = fun s -> Wrapped.unit ((), write key value s)
@@ -724,6 +758,17 @@ end) : sig
val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
(* val pass : ('x,'a * (log -> log)) m -> ('x,'a) m *)
val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
+ (* WriterT transformer *)
+ module T : functor (Wrapped : Monad.S) -> sig
+ type ('x,'a) result = ('x,'a * log) Wrapped.result
+ type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
+ include Monad.S with type ('x,'a) result := ('x,'a) result and type ('x,'a) result_exn := ('x,'a) result_exn
+ val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+ val tell : log -> ('x,unit) m
+ val listen : ('x,'a) m -> ('x,'a * log) m
+ val listens : (log -> 'b) -> ('x,'a) m -> ('x,'a * 'b) m
+ val censor : (log -> log) -> ('x,'a) m -> ('x,'a) m
+ end
end = struct
type log = Log.log
module Base = struct
@@ -731,10 +776,10 @@ end = struct
type ('x,'a) result = 'a * log
type ('x,'a) result_exn = 'a * log
let unit a = (a, Log.zero)
- let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w')
+ let bind (a, w) f = let (b, w') = f a in (b, Log.plus w w')
let run u = u
let run_exn = run
- let zero () = Util.undefined
+ let zero () = Util.undef
let plus u v = u
end
include Monad.Make(Base)
@@ -743,6 +788,31 @@ end = struct
let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *)
let pass ((a, f), w) = (a, f w) (* usually use censor helper *)
let censor f u = pass (u >>= fun a -> unit (a, f))
+ module T(Wrapped : Monad.S) = struct
+ module BaseT = struct
+ module Wrapped = Wrapped
+ type ('x,'a) m = ('x,'a * log) Wrapped.m
+ type ('x,'a) result = ('x,'a * log) Wrapped.result
+ type ('x,'a) result_exn = ('x,'a * log) Wrapped.result_exn
+ let elevate w =
+ Wrapped.bind w (fun a -> Wrapped.unit (a, Log.zero))
+ let bind u f =
+ Wrapped.bind u (fun (a, w) ->
+ Wrapped.bind (f a) (fun (b, w') ->
+ Wrapped.unit (b, Log.plus w w')))
+ let zero () = elevate (Wrapped.zero ())
+ let plus u v = Wrapped.plus u v
+ let run u = Wrapped.run u
+ let run_exn u = Wrapped.run_exn u
+ end
+ include Monad.MakeT(BaseT)
+ let tell entries = Wrapped.unit ((), entries)
+ let listen u = Wrapped.bind u (fun (a, w) -> Wrapped.unit ((a, w), w))
+ let pass u = Wrapped.bind u (fun ((a, f), w) -> Wrapped.unit (a, f w))
+ (* rest are derived in same way as before *)
+ let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w)
+ let censor f u = pass (u >>= fun a -> unit (a, f))
+ end
end
(* pre-define simple Writer *)
@@ -766,6 +836,7 @@ module Writer2 = struct
end
+(* TODO needs a T *)
module IO_monad : sig
(* declare additional operation, while still hiding implementation of type m *)
type ('x,'a) result = 'a
@@ -787,7 +858,7 @@ end = struct
{ run = (fun () -> a.run (); fres.run ()); value = fres.value }
let run a = let () = a.run () in a.value
let run_exn = run
- let zero () = Util.undefined
+ let zero () = Util.undef
let plus u v = u
end
include Monad.Make(Base)
@@ -812,6 +883,16 @@ module Continuation_monad : sig
(* val abort : ('a,'a) m -> ('a,'b) m *)
val abort : 'a -> ('a,'b) m
val run0 : ('a,'a) m -> 'a
+ (* ContinuationT transformer *)
+ module T : functor (Wrapped : Monad.S) -> sig
+ type ('r,'a) m
+ type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
+ type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
+ include Monad.S with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn and type ('r,'a) m := ('r,'a) m
+ val elevate : ('x,'a) Wrapped.m -> ('x,'a) m
+ val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m
+ (* TODO: reset,shift,abort,run0 *)
+ end
end = struct
let id = fun i -> i
module Base = struct
@@ -823,7 +904,7 @@ end = struct
let bind u f = (fun k -> (u) (fun a -> (f a) k))
let run u k = (u) k
let run_exn = run
- let zero () = Util.undefined
+ let zero () = Util.undef
let plus u v = u
end
include Monad.Make(Base)
@@ -850,6 +931,24 @@ end = struct
(* let abort a = shift (fun _ -> a) *)
let abort a = shift (fun _ -> unit a)
let run0 (u : ('a,'a) m) = (u) id
+ module T(Wrapped : Monad.S) = struct
+ module BaseT = struct
+ module Wrapped = Wrapped
+ type ('r,'a) m = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.m
+ type ('r,'a) result = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result
+ type ('r,'a) result_exn = ('a -> ('r,'r) Wrapped.m) -> ('r,'r) Wrapped.result_exn
+ let elevate w = fun k -> Wrapped.bind w k
+ let bind u f = fun k -> u (fun a -> f a k)
+ let run u k = Wrapped.run (u k)
+ let run_exn u k = Wrapped.run_exn (u k)
+ let zero () = Util.undef
+ let plus u v = u
+ end
+ include Monad.MakeT(BaseT)
+ let callcc f = (fun k ->
+ let usek a = (fun _ -> k a)
+ in (f usek) k)
+ end
end
@@ -947,6 +1046,7 @@ end = struct
type ('x,'a) result_exn = 'a tree
let unit a = Some (Leaf a)
let zero () = None
+ (* satisfies Distrib *)
let plus u v = match (u, v) with
| None, _ -> v
| _, None -> u
@@ -962,10 +1062,8 @@ end = struct
| Some us -> us
end
include Monad.Make(Base)
- let base_plus = plus
- let base_lift = lift
module T(Wrapped : Monad.S) = struct
- module Trans = struct
+ module BaseT = struct
include Monad.MakeT(struct
module Wrapped = Wrapped
type ('x,'a) m = ('x,'a tree option) Wrapped.m
@@ -975,19 +1073,18 @@ end = struct
let plus u v =
Wrapped.bind u (fun us ->
Wrapped.bind v (fun vs ->
- Wrapped.unit (base_plus us vs)))
+ Wrapped.unit (Base.plus us vs)))
let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a)))
let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus)
let run u = Wrapped.run u
let run_exn u =
let w = Wrapped.bind u (fun t -> match t with
- | None -> failwith "no values"
- | Some ts -> Wrapped.unit ts)
- in Wrapped.run_exn w
+ | None -> Wrapped.zero ()
+ | Some ts -> Wrapped.unit ts
+ ) in Wrapped.run_exn w
end)
end
- include Trans
- (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *)
+ include BaseT
let distribute f t = mapT (fun a -> elevate (f a)) t zero plus
end
end