From: Jim Pryor Date: Sat, 11 Dec 2010 00:00:14 +0000 (-0500) Subject: push monads library X-Git-Url: http://lambda.jimpryor.net/git/gitweb.cgi?p=lambda.git;a=commitdiff_plain;h=341631c0ca1850cc0e2bdaa459fc1bdc771a3175 push monads library Signed-off-by: Jim Pryor --- diff --git a/code/monads.ml b/code/monads.ml new file mode 100644 index 00000000..a334b0f0 --- /dev/null +++ b/code/monads.ml @@ -0,0 +1,1296 @@ +(* + * monads.ml + * + * Relies on features introduced in OCaml 3.12 + * + * This library uses parameterized modules, see tree_monadize.ml for + * more examples and explanation. + * + * Some comparisons with the Haskell monadic libraries, which we mostly follow: + * In Haskell, the Reader 'a monadic type would be defined something like this: + * newtype Reader a = Reader { runReader :: env -> a } + * (For simplicity, I'm suppressing the fact that Reader is also parameterized + * on the type of env.) + * This creates a type wrapper around `env -> a`, so that Haskell will + * distinguish between values that have been specifically designated as + * being of type `Reader a`, and common-garden values of type `env -> a`. + * To lift an aribtrary expression E of type `env -> a` into an `Reader a`, + * you do this: + * Reader { runReader = E } + * or use any of the following equivalent shorthands: + * Reader (E) + * Reader $ E + * To drop an expression R of type `Reader a` back into an `env -> a`, you do + * one of these: + * runReader (R) + * runReader $ R + * The `newtype` in the type declaration ensures that Haskell does this all + * efficiently: though it regards E and R as type-distinct, their underlying + * machine implementation is identical and doesn't need to be transformed when + * lifting/dropping from one type to the other. + * + * Now, you _could_ also declare monads as record types in OCaml, too, _but_ + * doing so would introduce an extra level of machine representation, and + * lifting/dropping from the one type to the other wouldn't be free like it is + * in Haskell. + * + * This library encapsulates the monadic types in another way: by + * making their implementations private. The interpreter won't let + * let you freely interchange the `'a Reader_monad.m`s defined below + * with `Reader_monad.env -> 'a`. The code in this library can see that + * those are equivalent, but code outside the library can't. Instead, you'll + * have to use operations like `run` to convert the abstract monadic types + * to types whose internals you have free access to. + * + *) + + +(* Some library functions used below. *) +module Util = struct + let fold_right = List.fold_right + let map = List.map + let append = List.append + let reverse = List.rev + let concat = List.concat + let concat_map f lst = List.concat (List.map f lst) + (* let zip = List.combine *) + let unzip = List.split + let zip_with = List.map2 + let replicate len fill = + let rec loop n accu = + if n == 0 then accu else loop (pred n) (fill :: accu) + in loop len [] +end + + + +(* + * This module contains factories that extend a base set of + * monadic definitions with a larger family of standard derived values. + *) + +module Monad = struct + + (* type of base definitions *) + module type BASE = sig + (* + * The only constraints we impose here on how the monadic type + * is implemented is that it have a single type parameter 'a. + *) + type 'a m + val unit : 'a -> 'a m + val bind : 'a m -> ('a -> 'b m) -> 'b m + type 'a result + val run : 'a m -> 'a result + (* run_exn tries to provide a more ground-level result, but may fail *) + type 'a result_exn + val run_exn : 'a m -> 'a result_exn + end + module type S = sig + include BASE + val (>>=) : 'a m -> ('a -> 'b m) -> 'b m + val (>>) : 'a m -> 'b m -> 'b m + val join : ('a m) m -> 'a m + val apply : ('a -> 'b) m -> 'a m -> 'b m + val lift : ('a -> 'b) -> 'a m -> 'b m + val lift2 : ('a -> 'b -> 'c) -> 'a m -> 'b m -> 'c m + val (>=>) : ('a -> 'b m) -> ('b -> 'c m) -> 'a -> 'c m + val do_when : bool -> unit m -> unit m + val do_unless : bool -> unit m -> unit m + val forever : 'a m -> 'b m + val sequence : 'a m list -> 'a list m + val sequence_ : 'a m list -> unit m + end + + (* Standard, single-type-parameter monads. *) + module Make(B : BASE) : S with type 'a m = 'a B.m and type 'a result = 'a B.result and type 'a result_exn = 'a B.result_exn = struct + include B + let (>>=) = bind + let (>>) u v = u >>= fun _ -> v + let lift f u = u >>= fun a -> unit (f a) + (* lift is called listM, fmap, and <$> in Haskell *) + let join uu = uu >>= fun u -> u + (* u >>= f === join (lift f u) *) + let apply u v = u >>= fun f -> v >>= fun a -> unit (f a) + (* [f] <*> [x1,x2] = [f x1,f x2] *) + (* let apply u v = u >>= fun f -> lift f v *) + (* let apply = lift2 id *) + let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a') + (* let lift f u === apply (unit f) u *) + (* let lift2 f u v = apply (lift f u) v *) + + let (>=>) f g = fun a -> f a >>= g + let do_when test u = if test then u else unit () + let do_unless test u = if test then unit () else u + let rec forever u = u >> forever u + let sequence ms = + let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in + Util.fold_right op ms (unit []) + let sequence_ ms = + Util.fold_right (>>) ms (unit ()) + + (* Haskell defines these other operations combining lists and monads. + * We don't, but notice that M.mapM == ListT(M).distribute + * There's also a parallel TreeT(M).distribute *) + (* + let mapM f alist = sequence (Util.map f alist) + let mapM_ f alist = sequence_ (Util.map f alist) + let rec filterM f lst = match lst with + | [] -> unit [] + | x::xs -> f x >>= fun flag -> filterM f xs >>= fun ys -> unit (if flag then x :: ys else ys) + let forM alist f = mapM f alist + let forM_ alist f = mapM_ f alist + let map_and_unzipM f xs = sequence (Util.map f xs) >>= fun x -> unit (Util.unzip x) + let zip_withM f xs ys = sequence (Util.zip_with f xs ys) + let zip_withM_ f xs ys = sequence_ (Util.zip_with f xs ys) + let rec foldM f z lst = match lst with + | [] -> unit z + | x::xs -> f z x >>= fun z' -> foldM f z' xs + let foldM_ f z xs = foldM f z xs >> unit () + let replicateM n x = sequence (Util.replicate n x) + let replicateM_ n x = sequence_ (Util.replicate n x) + *) + end + + (* Single-type-parameter monads that also define `plus` and `zero` + * operations. These obey the following laws: + * zero >>= f === zero + * plus zero u === u + * plus u zero === u + * Additionally, these monads will obey one of the following laws: + * (Catch) plus (unit a) v === unit a + * (Distrib) plus u v >>= f === plus (u >>= f) (v >>= f) + *) + module type PLUSBASE = sig + include BASE + val zero : unit -> 'a m + val plus : 'a m -> 'a m -> 'a m + end + module type PLUS = sig + type 'a m + val zero : unit -> 'a m + val plus : 'a m -> 'a m -> 'a m + val guard : bool -> unit m + val sum : 'a m list -> 'a m + end + (* MakeCatch and MakeDistrib have the same implementation; we just declare + * them twice to document which laws the client code is promising to honor. *) + module MakeCatch(B : PLUSBASE) : PLUS with type 'a m = 'a B.m = struct + type 'a m = 'a B.m + let zero = B.zero + let plus = B.plus + let guard test = if test then B.unit () else zero () + let sum ms = Util.fold_right plus ms (zero ()) + end + module MakeDistrib = MakeCatch + + (* We have to define BASE, S, and Make again for double-type-parameter monads. *) + module type BASE2 = sig + type ('x,'a) m + val unit : 'a -> ('x,'a) m + val bind : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m + type ('x,'a) result + val run : ('x,'a) m -> ('x,'a) result + type ('x,'a) result_exn + val run_exn : ('x,'a) m -> ('x,'a) result + end + module type S2 = sig + include BASE2 + val (>>=) : ('x,'a) m -> ('a -> ('x,'b) m) -> ('x,'b) m + val (>>) : ('x,'a) m -> ('x,'b) m -> ('x,'b) m + val join : ('x,('x,'a) m) m -> ('x,'a) m + val apply : ('x,'a -> 'b) m -> ('x,'a) m -> ('x,'b) m + val lift : ('a -> 'b) -> ('x,'a) m -> ('x,'b) m + val lift2 : ('a -> 'b -> 'c) -> ('x,'a) m -> ('x,'b) m -> ('x,'c) m + val (>=>) : ('a -> ('x,'b) m) -> ('b -> ('x,'c) m) -> 'a -> ('x,'c) m + val do_when : bool -> ('x,unit) m -> ('x,unit) m + val do_unless : bool -> ('x,unit) m -> ('x,unit) m + val forever : ('x,'a) m -> ('x,'b) m + val sequence : ('x,'a) m list -> ('x,'a list) m + val sequence_ : ('x,'a) m list -> ('x,unit) m + end + module Make2(B : BASE2) : S2 with type ('x,'a) m = ('x,'a) B.m and type ('x,'a) result = ('x,'a) B.result and type ('x,'a) result_exn = ('x,'a) B.result_exn = struct + include B + let (>>=) = bind + let (>>) u v = u >>= fun _ -> v + let lift f u = u >>= fun a -> unit (f a) + let join uu = uu >>= fun u -> u + let apply u v = u >>= fun f -> v >>= fun a -> unit (f a) + let lift2 f u v = u >>= fun a -> v >>= fun a' -> unit (f a a') + let (>=>) f g = fun a -> f a >>= g + let do_when test u = if test then u else unit () + let do_unless test u = if test then unit () else u + let rec forever u = u >> forever u + let sequence ms = + let op u v = u >>= fun x -> v >>= fun xs -> unit (x :: xs) in + Util.fold_right op ms (unit []) + let sequence_ ms = + Util.fold_right (>>) ms (unit ()) + end + + (* Signatures for MonadT *) + module type W = sig + include S + end + module type WP = sig + include W + val zero : unit -> 'a m + val plus : 'a m -> 'a m -> 'a m + end + module type TRANS = sig + type 'a m + val bind : 'a m -> ('a -> 'b m) -> 'b m + module Wrapped : W + type 'a result + val run : 'a m -> 'a result + type 'a result_exn + val run_exn : 'a m -> 'a result_exn + val elevate : 'a Wrapped.m -> 'a m + (* lift/elevate laws: + * elevate (W.unit a) == unit a + * elevate (W.bind w f) == elevate w >>= fun a -> elevate (f a) + *) + end + module MakeT(T : TRANS) = struct + include Make(struct + include T + let unit a = elevate (Wrapped.unit a) + end) + let elevate = T.elevate + end + +end + + + + + +module Identity_monad : sig + (* expose only the implementation of type `'a result` *) + type 'a result = 'a + type 'a result_exn = 'a + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn +end = struct + module Base = struct + type 'a m = 'a + let unit a = a + let bind a f = f a + type 'a result = 'a + let run a = a + type 'a result_exn = 'a + let run_exn a = a + end + include Monad.Make(Base) +end + + +module Maybe_monad : sig + (* expose only the implementation of type `'a result` *) + type 'a result = 'a option + type 'a result_exn = 'a + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + include Monad.PLUS with type 'a m := 'a m + (* MaybeT transformer *) + module T : functor (Wrapped : Monad.W) -> sig + type 'a result = 'a option Wrapped.result + type 'a result_exn = 'a Wrapped.result_exn + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + include Monad.PLUS with type 'a m := 'a m + val elevate : 'a Wrapped.m -> 'a m + end +end = struct + module Base = struct + type 'a m = 'a option + let unit a = Some a + let bind u f = match u with Some a -> f a | None -> None + type 'a result = 'a option + let run u = u + type 'a result_exn = 'a + let run_exn u = match u with + | Some a -> a + | None -> failwith "no value" + let zero () = None + let plus u v = match u with None -> v | _ -> u + end + include Monad.Make(Base) + include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) + module T(Wrapped : Monad.W) = struct + module Trans = struct + include Monad.MakeT(struct + module Wrapped = Wrapped + type 'a m = 'a option Wrapped.m + let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some a)) + let bind u f = Wrapped.bind u (fun t -> match t with + | Some a -> f a + | None -> Wrapped.unit None) + type 'a result = 'a option Wrapped.result + let run u = Wrapped.run u + type 'a result_exn = 'a Wrapped.result_exn + let run_exn u = + let w = Wrapped.bind u (fun t -> match t with + | Some a -> Wrapped.unit a + | None -> failwith "no value") + in Wrapped.run_exn w + end) + let zero () = Wrapped.unit None + let plus u v = Wrapped.bind u (fun t -> match t with | None -> v | _ -> u) + end + include Trans + include (Monad.MakeCatch(Trans) : Monad.PLUS with type 'a m := 'a m) + end +end + + +module List_monad : sig + (* declare additional operation, while still hiding implementation of type m *) + type 'a result = 'a list + type 'a result_exn = 'a + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + include Monad.PLUS with type 'a m := 'a m + val permute : 'a m -> 'a m m + val select : 'a m -> ('a * 'a m) m + (* ListT transformer *) + module T : functor (Wrapped : Monad.W) -> sig + type 'a result = 'a list Wrapped.result + type 'a result_exn = 'a Wrapped.result_exn + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + include Monad.PLUS with type 'a m := 'a m + val elevate : 'a Wrapped.m -> 'a m + (* note that second argument is an 'a list, not the more abstract 'a m *) + (* type is ('a -> 'b W) -> 'a list -> 'b list W == 'b listT(W) *) + val distribute : ('a -> 'b Wrapped.m) -> 'a list -> 'b m +(* TODO + val permute : 'a m -> 'a m m + val select : 'a m -> ('a * 'a m) m +*) + end +end = struct + module Base = struct + type 'a m = 'a list + let unit a = [a] + let bind u f = Util.concat_map f u + type 'a result = 'a list + let run u = u + type 'a result_exn = 'a + let run_exn u = match u with + | [] -> failwith "no values" + | [a] -> a + | many -> failwith "multiple values" + let zero () = [] + let plus = Util.append + end + include Monad.Make(Base) + include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m) + (* let either u v = plus u v *) + (* insert 3 [1;2] ~~> [[3;1;2]; [1;3;2]; [1;2;3]] *) + let rec insert a u = + plus (unit (a :: u)) (match u with + | [] -> zero () + | x :: xs -> (insert a xs) >>= fun v -> unit (x :: v) + ) + (* permute [1;2;3] ~~> [1;2;3]; [2;1;3]; [2;3;1]; [1;3;2]; [3;1;2]; [3;2;1] *) + let rec permute u = match u with + | [] -> unit [] + | x :: xs -> (permute xs) >>= (fun v -> insert x v) + (* select [1;2;3] ~~> [(1,[2;3]); (2,[1;3]), (3;[1;2])] *) + let rec select u = match u with + | [] -> zero () + | x::xs -> plus (unit (x, xs)) (select xs >>= fun (x', xs') -> unit (x', x :: xs')) + let base_plus = plus + module T(Wrapped : Monad.W) = struct + module Trans = struct + let zero () = Wrapped.unit [] + let plus u v = + Wrapped.bind u (fun us -> + Wrapped.bind v (fun vs -> + Wrapped.unit (base_plus us vs))) + (* Wrapped.sequence ms === + let plus1 u v = + Wrapped.bind u (fun x -> + Wrapped.bind v (fun xs -> + Wrapped.unit (x :: xs))) + in Util.fold_right plus1 ms (Wrapped.unit []) *) + (* distribute === Wrapped.mapM; copies alist to its image under f *) + let distribute f alist = Wrapped.sequence (Util.map f alist) + include Monad.MakeT(struct + module Wrapped = Wrapped + type 'a m = 'a list Wrapped.m + let elevate w = Wrapped.bind w (fun a -> Wrapped.unit [a]) + let bind u f = + Wrapped.bind u (fun ts -> + Wrapped.bind (distribute f ts) (fun tts -> + Wrapped.unit (Util.concat tts))) + type 'a result = 'a list Wrapped.result + let run u = Wrapped.run u + type 'a result_exn = 'a Wrapped.result_exn + let run_exn u = + let w = Wrapped.bind u (fun ts -> match ts with + | [] -> failwith "no values" + | [a] -> Wrapped.unit a + | many -> failwith "multiple values" + ) in Wrapped.run_exn w + end) + end + include Trans + include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m) +(* + let permute : 'a m -> 'a m m + let select : 'a m -> ('a * 'a m) m +*) + end +end + + +(* must be parameterized on (struct type err = ... end) *) +module Error_monad(Err : sig + type err + exception Exc of err + (* + val zero : unit -> err + val plus : err -> err -> err + *) +end) : sig + (* declare additional operations, while still hiding implementation of type m *) + type err = Err.err + type 'a error = Error of err | Success of 'a + type 'a result = 'a + type 'a result_exn = 'a + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + (* include Monad.PLUS with type 'a m := 'a m *) + val throw : err -> 'a m + val catch : 'a m -> (err -> 'a m) -> 'a m + (* ErrorT transformer *) + module T : functor (Wrapped : Monad.W) -> sig + type 'a result = 'a Wrapped.result + type 'a result_exn = 'a Wrapped.result_exn + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + val elevate : 'a Wrapped.m -> 'a m + val throw : err -> 'a m + val catch : 'a m -> (err -> 'a m) -> 'a m + end +end = struct + type err = Err.err + type 'a error = Error of err | Success of 'a + module Base = struct + type 'a m = 'a error + let unit a = Success a + let bind u f = match u with + | Success a -> f a + | Error e -> Error e (* input and output may be of different 'a types *) + type 'a result = 'a + (* TODO: should run refrain from failing? *) + let run u = match u with + | Success a -> a + | Error e -> raise (Err.Exc e) + type 'a result_exn = 'a + let run_exn = run + (* + let zero () = Error Err.zero + let plus u v = match (u, v) with + | Success _, _ -> u + (* to satisfy (Catch) laws, plus u zero = u, even if u = Error _ + * otherwise, plus (Error _) v = v *) + | Error _, _ when v = zero -> u + (* combine errors *) + | Error e1, Error e2 when u <> zero -> Error (Err.plus e1 e2) + | Error _, _ -> v + *) + end + include Monad.Make(Base) + (* include (Monad.MakeCatch(Base) : Monad.PLUS with type 'a m := 'a m) *) + let throw e = Error e + let catch u handler = match u with + | Success _ -> u + | Error e -> handler e + module T(Wrapped : Monad.W) = struct + module Trans = struct + module Wrapped = Wrapped + type 'a m = 'a Base.m Wrapped.m + let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Success a)) + let bind u f = Wrapped.bind u (fun t -> match t with + | Success a -> f a + | Error e -> Wrapped.unit (Error e)) + type 'a result = 'a Wrapped.result + (* TODO: should run refrain from failing? *) + let run u = + let w = Wrapped.bind u (fun t -> match t with + | Success a -> Wrapped.unit a + (* | _ -> Wrapped.fail () *) + | Error e -> raise (Err.Exc e)) + in Wrapped.run w + type 'a result_exn = 'a Wrapped.result_exn + let run_exn u = + let w = Wrapped.bind u (fun t -> match t with + | Success a -> Wrapped.unit a + (* | _ -> Wrapped.fail () *) + | Error e -> raise (Err.Exc e)) + in Wrapped.run_exn w + end + include Monad.MakeT(Trans) + let throw e = Wrapped.unit (Error e) + let catch u handler = Wrapped.bind u (fun t -> match t with + | Success _ -> Wrapped.unit t + | Error e -> handler e) + end +end + +(* pre-define common instance of Error_monad *) +module Failure = Error_monad(struct + type err = string + exception Exc = Failure + (* + let zero = "" + let plus s1 s2 = s1 ^ "\n" ^ s2 + *) +end) + +(* must be parameterized on (struct type env = ... end) *) +module Reader_monad(Env : sig type env end) : sig + (* declare additional operations, while still hiding implementation of type m *) + type env = Env.env + type 'a result = env -> 'a + type 'a result_exn = env -> 'a + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + val ask : env m + val asks : (env -> 'a) -> 'a m + val local : (env -> env) -> 'a m -> 'a m + (* ReaderT transformer *) + module T : functor (Wrapped : Monad.W) -> sig + type 'a result = env -> 'a Wrapped.result + type 'a result_exn = env -> 'a Wrapped.result_exn + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + val elevate : 'a Wrapped.m -> 'a m + val ask : env m + val asks : (env -> 'a) -> 'a m + val local : (env -> env) -> 'a m -> 'a m + end + (* ReaderT transformer when wrapped monad has plus, zero *) + module TP : functor (Wrapped : Monad.WP) -> sig + include module type of T(Wrapped) + include Monad.PLUS with type 'a m := 'a m + end +end = struct + type env = Env.env + module Base = struct + type 'a m = env -> 'a + let unit a = fun e -> a + let bind u f = fun e -> let a = u e in let u' = f a in u' e + type 'a result = env -> 'a + let run u = fun e -> u e + type 'a result_exn = env -> 'a + let run_exn = run + end + include Monad.Make(Base) + let ask = fun e -> e + let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *) + let local modifier u = fun e -> u (modifier e) + module T(Wrapped : Monad.W) = struct + module Trans = struct + module Wrapped = Wrapped + type 'a m = env -> 'a Wrapped.m + let elevate w = fun e -> w + let bind u f = fun e -> Wrapped.bind (u e) (fun v -> f v e) + type 'a result = env -> 'a Wrapped.result + let run u = fun e -> Wrapped.run (u e) + type 'a result_exn = env -> 'a Wrapped.result_exn + let run_exn u = fun e -> Wrapped.run_exn (u e) + end + include Monad.MakeT(Trans) + let ask = fun e -> Wrapped.unit e + let asks selector = ask >>= (fun e -> unit (selector e)) (* may fail *) + let local modifier u = fun e -> u (modifier e) + end + module TP(Wrapped : Monad.WP) = struct + module TransP = struct + include T(Wrapped) + let plus u v = fun s -> Wrapped.plus (u s) (v s) + let zero () = elevate (Wrapped.zero ()) + let asks selector = ask >>= (fun e -> + try unit (selector e) + with Not_found -> fun e -> Wrapped.zero ()) + end + include TransP + include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m) + end +end + + +(* must be parameterized on (struct type store = ... end) *) +module State_monad(Store : sig type store end) : sig + (* declare additional operations, while still hiding implementation of type m *) + type store = Store.store + type 'a result = store -> 'a * store + type 'a result_exn = store -> 'a + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + val get : store m + val gets : (store -> 'a) -> 'a m + val put : store -> unit m + val puts : (store -> store) -> unit m + (* StateT transformer *) + module T : functor (Wrapped : Monad.W) -> sig + type 'a result = store -> ('a * store) Wrapped.result + type 'a result_exn = store -> 'a Wrapped.result_exn + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + val elevate : 'a Wrapped.m -> 'a m + val get : store m + val gets : (store -> 'a) -> 'a m + val put : store -> unit m + val puts : (store -> store) -> unit m + end + (* StateT transformer when wrapped monad has plus, zero *) + module TP : functor (Wrapped : Monad.WP) -> sig + include module type of T(Wrapped) + include Monad.PLUS with type 'a m := 'a m + end +end = struct + type store = Store.store + module Base = struct + type 'a m = store -> 'a * store + let unit a = fun s -> (a, s) + let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s' + type 'a result = store -> 'a * store + let run u = fun s -> (u s) + type 'a result_exn = store -> 'a + let run_exn u = fun s -> fst (u s) + end + include Monad.Make(Base) + let get = fun s -> (s, s) + let gets viewer = fun s -> (viewer s, s) (* may fail *) + let put s = fun _ -> ((), s) + let puts modifier = fun s -> ((), modifier s) + module T(Wrapped : Monad.W) = struct + module Trans = struct + module Wrapped = Wrapped + type 'a m = store -> ('a * store) Wrapped.m + let elevate w = fun s -> + Wrapped.bind w (fun a -> Wrapped.unit (a, s)) + let bind u f = fun s -> + Wrapped.bind (u s) (fun (a, s') -> f a s') + type 'a result = store -> ('a * store) Wrapped.result + let run u = fun s -> Wrapped.run (u s) + type 'a result_exn = store -> 'a Wrapped.result_exn + let run_exn u = fun s -> + let w = Wrapped.bind (u s) (fun (a,s) -> Wrapped.unit a) + in Wrapped.run_exn w + end + include Monad.MakeT(Trans) + let get = fun s -> Wrapped.unit (s, s) + let gets viewer = fun s -> Wrapped.unit (viewer s, s) (* may fail *) + let put s = fun _ -> Wrapped.unit ((), s) + let puts modifier = fun s -> Wrapped.unit ((), modifier s) + end + module TP(Wrapped : Monad.WP) = struct + module TransP = struct + include T(Wrapped) + let plus u v = fun s -> Wrapped.plus (u s) (v s) + let zero () = elevate (Wrapped.zero ()) + end + let gets viewer = fun s -> + try Wrapped.unit (viewer s, s) + with Not_found -> Wrapped.zero () + include TransP + include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m) + end +end + +(* State monad with different interface (structured store) *) +module Ref_monad(V : sig + type value +end) : sig + type ref + type value = V.value + type 'a result = 'a + type 'a result_exn = 'a + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + val newref : value -> ref m + val deref : ref -> value m + val change : ref -> value -> unit m + (* RefT transformer *) + module T : functor (Wrapped : Monad.W) -> sig + type 'a result = 'a Wrapped.result + type 'a result_exn = 'a Wrapped.result_exn + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + val elevate : 'a Wrapped.m -> 'a m + val newref : value -> ref m + val deref : ref -> value m + val change : ref -> value -> unit m + end + (* RefT transformer when wrapped monad has plus, zero *) + module TP : functor (Wrapped : Monad.WP) -> sig + include module type of T(Wrapped) + include Monad.PLUS with type 'a m := 'a m + end +end = struct + type ref = int + type value = V.value + module D = Map.Make(struct type t = ref let compare = compare end) + type dict = { next: ref; tree : value D.t } + let empty = { next = 0; tree = D.empty } + let alloc (value : value) (d : dict) = + (d.next, { next = succ d.next; tree = D.add d.next value d.tree }) + let read (key : ref) (d : dict) = + D.find key d.tree + let write (key : ref) (value : value) (d : dict) = + { next = d.next; tree = D.add key value d.tree } + module Base = struct + type 'a m = dict -> 'a * dict + let unit a = fun s -> (a, s) + let bind u f = fun s -> let (a, s') = u s in let u' = f a in u' s' + type 'a result = 'a + let run u = fst (u empty) + type 'a result_exn = 'a + let run_exn = run + end + include Monad.Make(Base) + let newref value = fun s -> alloc value s + let deref key = fun s -> (read key s, s) (* shouldn't fail because key will have an abstract type, and we never garbage collect *) + let change key value = fun s -> ((), write key value s) (* shouldn't allocate because key will have an abstract type *) + module T(Wrapped : Monad.W) = struct + module Trans = struct + module Wrapped = Wrapped + type 'a m = dict -> ('a * dict) Wrapped.m + let elevate w = fun s -> + Wrapped.bind w (fun a -> Wrapped.unit (a, s)) + let bind u f = fun s -> + Wrapped.bind (u s) (fun (a, s') -> f a s') + type 'a result = 'a Wrapped.result + let run u = + let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a) + in Wrapped.run w + type 'a result_exn = 'a Wrapped.result_exn + let run_exn u = + let w = Wrapped.bind (u empty) (fun (a,s) -> Wrapped.unit a) + in Wrapped.run_exn w + end + include Monad.MakeT(Trans) + let newref value = fun s -> Wrapped.unit (alloc value s) + let deref key = fun s -> Wrapped.unit (read key s, s) + let change key value = fun s -> Wrapped.unit ((), write key value s) + end + module TP(Wrapped : Monad.WP) = struct + module TransP = struct + include T(Wrapped) + let plus u v = fun s -> Wrapped.plus (u s) (v s) + let zero () = elevate (Wrapped.zero ()) + end + include TransP + include (Monad.MakeDistrib(TransP) : Monad.PLUS with type 'a m := 'a m) + end +end + + +(* must be parameterized on (struct type log = ... end) *) +module Writer_monad(Log : sig + type log + val zero : log + val plus : log -> log -> log +end) : sig + (* declare additional operations, while still hiding implementation of type m *) + type log = Log.log + type 'a result = 'a * log + type 'a result_exn = 'a * log + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + val tell : log -> unit m + val listen : 'a m -> ('a * log) m + val listens : (log -> 'b) -> 'a m -> ('a * 'b) m + (* val pass : ('a * (log -> log)) m -> 'a m *) + val censor : (log -> log) -> 'a m -> 'a m +end = struct + type log = Log.log + module Base = struct + type 'a m = 'a * log + let unit a = (a, Log.zero) + let bind (a, w) f = let (a', w') = f a in (a', Log.plus w w') + type 'a result = 'a * log + let run u = u + type 'a result_exn = 'a * log + let run_exn = run + end + include Monad.Make(Base) + let tell entries = ((), entries) (* add entries to log *) + let listen (a, w) = ((a, w), w) + let listens selector u = listen u >>= fun (a, w) -> unit (a, selector w) (* filter listen through selector *) + let pass ((a, f), w) = (a, f w) (* usually use censor helper *) + let censor f u = pass (u >>= fun a -> unit (a, f)) +end + +(* pre-define simple Writer *) +module Writer1 = Writer_monad(struct + type log = string + let zero = "" + let plus s1 s2 = s1 ^ "\n" ^ s2 +end) + +(* slightly more efficient Writer *) +module Writer2 = struct + include Writer_monad(struct + type log = string list + let zero = [] + let plus w w' = Util.append w' w + end) + let tell_string s = tell [s] + let tell entries = tell (Util.reverse entries) + let run u = let (a, w) = run u in (a, Util.reverse w) + let run_exn = run +end + + +module IO_monad : sig + (* declare additional operation, while still hiding implementation of type m *) + type 'a result = 'a + type 'a result_exn = 'a + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + val printf : ('a, unit, string, unit m) format4 -> 'a + val print_string : string -> unit m + val print_int : int -> unit m + val print_hex : int -> unit m + val print_bool : bool -> unit m +end = struct + module Base = struct + type 'a m = { run : unit -> unit; value : 'a } + let unit a = { run = (fun () -> ()); value = a } + let bind (a : 'a m) (f: 'a -> 'b m) : 'b m = + let fres = f a.value in + { run = (fun () -> a.run (); fres.run ()); value = fres.value } + type 'a result = 'a + let run a = let () = a.run () in a.value + type 'a result_exn = 'a + let run_exn = run + end + include Monad.Make(Base) + let printf fmt = + Printf.ksprintf (fun s -> { Base.run = (fun () -> Pervasives.print_string s); value = () }) fmt + let print_string s = { Base.run = (fun () -> Printf.printf "%s\n" s); value = () } + let print_int i = { Base.run = (fun () -> Printf.printf "%d\n" i); value = () } + let print_hex i = { Base.run = (fun () -> Printf.printf "0x%x\n" i); value = () } + let print_bool b = { Base.run = (fun () -> Printf.printf "%B\n" b); value = () } +end + +module Continuation_monad : sig + (* expose only the implementation of type `('r,'a) result` *) + type 'a m + type 'a result = 'a m + type 'a result_exn = 'a m + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn and type 'a m := 'a m + (* misses that the answer types of all the cont's must be the same *) + val callcc : (('a -> 'b m) -> 'a m) -> 'a m + val reset : 'a m -> 'a m + (* misses that the answer types of second and third continuations must be b *) + val shift : (('a -> 'b m) -> 'b m) -> 'a m + (* overwrite the run declaration in S, because I can't declare 'a result = + * this polymorphic type (complains that 'r is unbound *) + val runk : 'a m -> ('a -> 'r) -> 'r + val run0 : 'a m -> 'a +end = struct + let id = fun i -> i + module Base = struct + (* 'r is result type of whole computation *) + type 'a m = { cont : 'r. ('a -> 'r) -> 'r } + let unit a = + let cont : 'r. ('a -> 'r) -> 'r = + fun k -> k a + in { cont } + let bind u f = + let cont : 'r. ('a -> 'r) -> 'r = + fun k -> u.cont (fun a -> (f a).cont k) + in { cont } + type 'a result = 'a m + let run (u : 'a m) : 'a result = u + type 'a result_exn = 'a m + let run_exn (u : 'a m) : 'a result_exn = u + let callcc f = + let cont : 'r. ('a -> 'r) -> 'r = fun k -> + (* Can't figure out how to make the type polymorphic enough + * to satisfy the OCaml type-checker (it's ('a -> 'r) -> 'r + * instead of 'r. ('a -> 'r) -> 'r); so we have to fudge + * with Obj.magic... which tells OCaml's type checker to + * relax, the supplied value has whatever type the context + * needs it to have. *) + let usek a = { cont = Obj.magic (fun _ -> k a) } + in (f usek).cont k + in { cont } + let reset u = unit (u.cont id) + let shift (f : ('a -> 'b m) -> 'b m) : 'a m = + let cont = + fun k -> (f (fun a -> unit (k a))).cont id + in { cont = Obj.magic cont } + let runk u k = u.cont k + let run0 u = u.cont id + end + include Monad.Make(Base) + let callcc = Base.callcc + let reset = Base.reset + let shift = Base.shift + let runk = Base.runk + let run0 = Base.run0 +end + +(* +(* This two-type parameter version works without Obj.magic *) + +module Continuation_monad2 : sig + (* expose only the implementation of type `('r,'a) result` *) + type ('r,'a) result = ('a -> 'r) -> 'r + type ('r,'a) result_exn = ('a -> 'r) -> 'r + include Monad.S2 with type ('r,'a) result := ('r,'a) result and type ('r,'a) result_exn := ('r,'a) result_exn + val callcc : (('a -> ('r,'b) m) -> ('r,'a) m) -> ('r,'a) m + val reset : ('a,'a) m -> ('r,'a) m + val shift : (('a -> ('q,'r) m) -> ('r,'r) m) -> ('r,'a) m + +end = struct + let id = fun i -> i + module Base = struct + (* 'r is result type of whole computation *) + type ('r,'a) m = ('a -> 'r) -> 'r + let unit a = fun k -> k a + let bind u f = fun k -> u (fun a -> (f a) k) + type ('r,'a) result = ('a -> 'r) -> 'r + let run u = u + type ('r,'a) result_exn = ('a -> 'r) -> 'r + let run_exn = run + end + include Monad.Make2(Base) + let callcc f = fun k -> + let usek a = fun _ -> k a + in f usek k + (* + val callcc : (('a -> 'r) -> ('r,'a) m) -> ('r,'a) m + val throw : ('a -> 'r) -> 'a -> ('r,'b) m + let callcc f = fun k -> f k k + let throw k a = fun _ -> k a + *) + (* from http://www.haskell.org/haskellwiki/MonadCont_done_right *) + let reset u = unit (u id) + let shift u = fun k -> u (fun a -> unit (k a)) id +end + *) + + +(* + * Scheme: + * (define (example n) + * (let ([u (let/cc k ; type int -> int pair + * (let ([v (if (< n 0) (k 0) (list (+ n 100)))]) + * (+ 1 (car v))))]) ; int + * (cons u 0))) ; int pair + * ; (example 10) ~~> '(111 . 0) + * ; (example -10) ~~> '(0 . 0) + * + * OCaml monads: + * let example n : (int * int) = + * Continuation_monad.(let u = callcc (fun k -> + * (if n < 0 then k 0 else unit [n + 100]) + * (* all of the following is skipped by k 0; the end type int is k's input type *) + * >>= fun [x] -> unit (x + 1) + * ) + * (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *) + * >>= fun x -> unit (x, 0) + * in run u) + * + * + * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *) + * let example1 () : int = + * Continuation_monad.(let v = reset ( + * let u = shift (fun k -> unit (10 + 1)) + * in u >>= fun x -> unit (100 + x) + * ) in let w = v >>= fun x -> unit (1000 + x) + * in run w) + * + * (* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *) + * let example2 () = + * Continuation_monad.(let v = reset ( + * let u = shift (fun k -> k (10 :: [1])) + * in u >>= fun x -> unit (100 :: x) + * ) in let w = v >>= fun x -> unit (1000 :: x) + * in run w) + * + * (* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *) + * let example3 () = + * Continuation_monad.(let v = reset ( + * let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x)) + * in u >>= fun x -> unit (100 :: x) + * ) in let w = v >>= fun x -> unit (1000 :: x) + * in run w) + * + * (* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *) + * (* not sure if this example can be typed without a sum-type *) + * + * (* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *) + * let example5 () : int = + * Continuation_monad.(let v = reset ( + * let u = shift (fun k -> k 1 >>= fun x -> k x) + * in u >>= fun x -> unit (10 + x) + * ) in let w = v >>= fun x -> unit (100 + x) + * in run w) + * + *) + + +module Leaf_monad : sig + (* We implement the type as `'a tree option` because it has a natural`plus`, + * and the rest of the library expects that `plus` and `zero` will come together. *) + type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree) + type 'a result = 'a tree option + type 'a result_exn = 'a tree + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + include Monad.PLUS with type 'a m := 'a m + (* LeafT transformer *) + module T : functor (Wrapped : Monad.W) -> sig + type 'a result = 'a tree option Wrapped.result + type 'a result_exn = 'a tree Wrapped.result_exn + include Monad.S with type 'a result := 'a result and type 'a result_exn := 'a result_exn + include Monad.PLUS with type 'a m := 'a m + val elevate : 'a Wrapped.m -> 'a m + (* note that second argument is an 'a tree?, not the more abstract 'a m *) + (* type is ('a -> 'b W) -> 'a tree? -> 'b tree? W == 'b treeT(W) *) + val distribute : ('a -> 'b Wrapped.m) -> 'a tree option -> 'b m + end +end = struct + type 'a tree = Leaf of 'a | Node of ('a tree * 'a tree) + (* uses supplied plus and zero to copy t to its image under f *) + let mapT (f : 'a -> 'b) (t : 'a tree option) (zero : unit -> 'b) (plus : 'b -> 'b -> 'b) : 'b = match t with + | None -> zero () + | Some ts -> let rec loop ts = (match ts with + | Leaf a -> f a + | Node (l, r) -> + (* recursive application of f may delete a branch *) + plus (loop l) (loop r) + ) in loop ts + module Base = struct + type 'a m = 'a tree option + let unit a = Some (Leaf a) + let zero () = None + let plus u v = match (u, v) with + | None, _ -> v + | _, None -> u + | Some us, Some vs -> Some (Node (us, vs)) + let bind u f = mapT f u zero plus + type 'a result = 'a tree option + let run u = u + type 'a result_exn = 'a tree + let run_exn u = match u with + | None -> failwith "no values" + (* + | Some (Leaf a) -> a + | many -> failwith "multiple values" + *) + | Some us -> us + end + include Monad.Make(Base) + include (Monad.MakeDistrib(Base) : Monad.PLUS with type 'a m := 'a m) + let base_plus = plus + let base_lift = lift + module T(Wrapped : Monad.W) = struct + module Trans = struct + let zero () = Wrapped.unit None + let plus u v = + Wrapped.bind u (fun us -> + Wrapped.bind v (fun vs -> + Wrapped.unit (base_plus us vs))) + include Monad.MakeT(struct + module Wrapped = Wrapped + type 'a m = 'a Base.m Wrapped.m + let elevate w = Wrapped.bind w (fun a -> Wrapped.unit (Some (Leaf a))) + let bind u f = Wrapped.bind u (fun t -> mapT f t zero plus) + type 'a result = 'a tree option Wrapped.result + let run u = Wrapped.run u + type 'a result_exn = 'a tree Wrapped.result_exn + let run_exn u = + let w = Wrapped.bind u (fun t -> match t with + | None -> failwith "no values" + | Some ts -> Wrapped.unit ts) + in Wrapped.run_exn w + end) + end + include Trans + include (Monad.MakeDistrib(Trans) : Monad.PLUS with type 'a m := 'a m) + (* let distribute f t = mapT (fun a -> a) (base_lift (fun a -> elevate (f a)) t) zero plus *) + let distribute f t = mapT (fun a -> elevate (f a)) t zero plus + end +end + + +module L = List_monad;; +module R = Reader_monad(struct type env = int -> int end);; +module S = State_monad(struct type store = int end);; +module T = Leaf_monad;; +module LR = L.T(R);; +module LS = L.T(S);; +module TL = T.T(L);; +module TR = T.T(R);; +module TS = T.T(S);; + +let t1 = Some (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11))));; + +(* +let ts = TS.distribute (fun i -> S.(puts succ >> unit i)) t1;; +TS.run ts 0;; +(* +- : int T.tree option * S.store = +(Some + (T.Node + (T.Node (T.Leaf 2, T.Leaf 3), + T.Node (T.Leaf 5, T.Node (T.Leaf 7, T.Leaf 11)))), + 5) +*) + +let ts2 = TS.distribute (fun i -> S.(puts succ >> get >>= fun n -> unit (i,n))) t1;; +TS.run_exn ts2 0;; +(* +- : (int * S.store) T.tree option * S.store = +(Some + (T.Node + (T.Node (T.Leaf (2, 1), T.Leaf (3, 2)), + T.Node (T.Leaf (5, 3), T.Node (T.Leaf (7, 4), T.Leaf (11, 5))))), + 5) +*) + +let tr = TR.distribute (fun i -> R.asks (fun e -> e i)) t1;; +TR.run_exn tr (fun i -> i+i);; +(* +- : int T.tree option = +Some + (T.Node + (T.Node (T.Leaf 4, T.Leaf 6), + T.Node (T.Leaf 10, T.Node (T.Leaf 14, T.Leaf 22)))) +*) + +let tl = TL.distribute (fun i -> L.(unit (i,i+1))) t1;; +TL.run_exn tl;; +(* +- : (int * int) TL.result = +[Some + (T.Node + (T.Node (T.Leaf (2, 3), T.Leaf (3, 4)), + T.Node (T.Leaf (5, 6), T.Node (T.Leaf (7, 8), T.Leaf (11, 12)))))] +*) + +let l2 = [1;2;3;4;5];; +let t2 = Some (T.Node (T.Leaf 1, (T.Node (T.Node (T.Node (T.Leaf 2, T.Leaf 3), T.Leaf 4), T.Leaf 5))));; + +LR.(run (distribute (fun i -> R.(asks (fun e -> e i))) l2 >>= fun j -> LR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);; +(* int list = [10; 11; 20; 21; 30; 31; 40; 41; 50; 51] *) + +TR.(run_exn (distribute (fun i -> R.(asks (fun e -> e i))) t2 >>= fun j -> TR.(plus (unit j) (unit (succ j))))) (fun i -> i*10);; +(* +int T.tree option = +Some + (T.Node + (T.Node (T.Leaf 10, T.Leaf 11), + T.Node + (T.Node + (T.Node (T.Node (T.Leaf 20, T.Leaf 21), T.Node (T.Leaf 30, T.Leaf 31)), + T.Node (T.Leaf 40, T.Leaf 41)), + T.Node (T.Leaf 50, T.Leaf 51)))) + *) + +LS.run (LS.distribute (fun i -> if i = -1 then S.get else if i < 0 then S.(puts succ >> unit 0) else S.unit i) [10;-1;-2;-1;20]) 0;; +(* +- : S.store list * S.store = ([10; 0; 0; 1; 20], 1) +*) + +*) + +let id : 'z. 'z -> 'z = fun x -> x + +let example n : (int * int) = + Continuation_monad.(let u = callcc (fun k -> + (if n < 0 then k 0 else unit [n + 100]) + (* all of the following is skipped by k 0; the end type int is k's input type *) + >>= fun [x] -> unit (x + 1) + ) + (* k 0 starts again here, outside the callcc (...); the end type int * int is k's output type *) + >>= fun x -> unit (x, 0) + in run0 u) + + +(* (+ 1000 (prompt (+ 100 (shift k (+ 10 1))))) ~~> 1011 *) +let example1 () : int = + Continuation_monad.(let v = reset ( + let u = shift (fun k -> unit (10 + 1)) + in u >>= fun x -> unit (100 + x) + ) in let w = v >>= fun x -> unit (1000 + x) + in run0 w) + +(* (+ 1000 (prompt (+ 100 (shift k (k (+ 10 1)))))) ~~> 1111 *) +let example2 () = + Continuation_monad.(let v = reset ( + let u = shift (fun k -> k (10 :: [1])) + in u >>= fun x -> unit (100 :: x) + ) in let w = v >>= fun x -> unit (1000 :: x) + in run0 w) + +(* (+ 1000 (prompt (+ 100 (shift k (+ 10 (k 1)))))) ~~> 1111 but added differently *) +let example3 () = + Continuation_monad.(let v = reset ( + let u = shift (fun k -> k [1] >>= fun x -> unit (10 :: x)) + in u >>= fun x -> unit (100 :: x) + ) in let w = v >>= fun x -> unit (1000 :: x) + in run0 w) + +(* (+ 100 ((prompt (+ 10 (shift k k))) 1)) ~~> 111 *) +(* not sure if this example can be typed without a sum-type *) + +(* (+ 100 (prompt (+ 10 (shift k (k (k 1)))))) ~~> 121 *) +let example5 () : int = + Continuation_monad.(let v = reset ( + let u = shift (fun k -> k 1 >>= fun x -> k x) + in u >>= fun x -> unit (10 + x) + ) in let w = v >>= fun x -> unit (100 + x) + in run0 w) + + +;; + +(1011, 1111, 1111, 121);; +(example1(), example2(), example3(), example5());; +((111,0), (0,0));; +(example ~+10, example ~-10);; + +module C = Continuation_monad +module TC = T.T(C) + +let testc df ic = + C.runk TC.(run_exn (distribute df t1)) ic;; + + +(* +(* do nothing *) +let initial_continuation = fun t -> t in +TreeCont.monadize t1 Continuation_monad.unit initial_continuation;; +*) +testc (C.unit) id;; + +(* +(* count leaves, using continuation *) +let initial_continuation = fun t -> 0 in +TreeCont.monadize t1 (fun a k -> 1 + k a) initial_continuation;; +*) + +testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (1 + v))) (fun t -> 0);; + +(* +(* convert tree to list of leaves *) +let initial_continuation = fun t -> [] in +TreeCont.monadize t1 (fun a k -> a :: k a) initial_continuation;; +*) + +testc C.(fun a -> shift (fun k -> k a >>= fun v -> unit (a::v))) (fun t -> ([] : int list));; + +(* +(* square each leaf using continuation *) +let initial_continuation = fun t -> t in +TreeCont.monadize t1 (fun a k -> k (a*a)) initial_continuation;; +*) + +testc C.(fun a -> shift (fun k -> k (a*a))) (fun t -> t);; + + +(* +(* replace leaves with list, using continuation *) +let initial_continuation = fun t -> t in +TreeCont.monadize t1 (fun a k -> k [a; a*a]) initial_continuation;; +*) + +testc C.(fun a -> shift (fun k -> k (a,a+1))) (fun t -> t);; +