(*
 * tree_monadize.ml
 *
 * In OCaml, 'a, 'b and so on are type variables: they stand for arbitrary
 * types. But what if you want a variable that stands for a type
 * constructor? For example, suppose you want to generalize this pattern:
 *
 *   type ('a) t1 = 'a -> 'a list
 *   type ('a) t2 = 'a -> 'a option
 *   type ('a) t3 = 'a -> 'a reader
 *
 * and so on. OCaml will not let you write
 *
 *   type ('a, 'b) t = 'a -> 'a 'b
 *
 * To abstract over the 'b position we must instead use OCaml's modules, and
 * in particular their ability to be parameterized by other modules. (OCaml
 * calls these parameterized modules Functors, but that word is also used in
 * other senses in this literature, so I will avoid it.)
 *
 * Here is how you would define the t type from above:
 *
 *   module T_maker(S: sig
 *     type 'a b
 *   end) = struct
 *     type 'a t = 'a -> 'a S.b
 *   end
 *
 * And here is how you would use it:
 *
 *   module T_list = T_maker(struct type 'a b = 'a list end);;
 *   type 'a t1 = 'a T_list.t;;
 *
 * I know, it seems unnecessarily complicated.
 *)

(* A binary tree that carries values only at its leaves. *)
type 'a tree = Leaf of 'a | Node of 'a tree * 'a tree;;

(* The sample tree used by the examples below. *)
let t1 =
  Node (Node (Leaf 2, Leaf 3),
        Node (Leaf 5, Node (Leaf 7, Leaf 11)));;

(* Given a monad S whose type constructor takes one parameter, build
 * monadize: apply the monadic function f to every leaf of a tree and
 * reassemble the results into a tree inside the monad. At each Node the
 * left subtree's computation is bound before the right one's. *)
module Tree_monadizer(S: sig
  type 'a m
  val unit : 'a -> 'a m
  val bind : 'a m -> ('a -> 'b m) -> 'b m
end) = struct
  let rec monadize (f : 'a -> 'b S.m) (t : 'a tree) : 'b tree S.m =
    let ( >>= ) = S.bind in
    match t with
    | Leaf x -> f x >>= (fun y -> S.unit (Leaf y))
    | Node (left, right) ->
      monadize f left >>= fun left' ->
      monadize f right >>= fun right' ->
      S.unit (Node (left', right'))
end;;

(* Reader monad: a computation that consults an environment, which here is
 * a function from int to int. *)
type env = int -> int;;
type 'a reader = env -> 'a;;

(* Ignore the environment and yield a. *)
let unit_reader a : 'a reader = fun _ -> a;;

(* Run u, then run the reader that f builds from its result, both against
 * the same environment. *)
let bind_reader (u : 'a reader) (f : 'a -> 'b reader) : 'b reader =
  fun env ->
    let a = u env in
    f a env;;

module TreeReader = Tree_monadizer(struct
  type 'a m = 'a reader
  let unit = unit_reader
  let bind = bind_reader
end);;

(* State monad: a computation that threads an int store through and
 * returns a value paired with the updated store. *)
type store = int;;
type 'a state = store -> 'a * store;;

(* Yield a and leave the store untouched. *)
let unit_state a : 'a state = fun s -> (a, s);;

(* Run u on the incoming store, then hand its result and the store it
 * produced to f. *)
let bind_state (u : 'a state) (f : 'a -> 'b state) : 'b state =
  fun s0 ->
    let (a, s1) = u s0 in
    f a s1;;

module TreeState = Tree_monadizer(struct
  type 'a m = 'a state
  let unit = unit_state
  let bind = bind_state
end);;

(* List monad: a computation with zero or more results. *)
let unit_list a = [a];;

(* Apply f to every element of u and concatenate the resulting lists. *)
let bind_list (u : 'a list) (f : 'a -> 'b list) : 'b list =
  u |> List.map f |> List.concat;;

module TreeList = Tree_monadizer(struct
  type 'a m = 'a list
  let unit = unit_list
  let bind = bind_list
end);;

(* A second parameterized module is needed when the monad's type
 * constructor takes two parameters, as the continuation monad below does.
 * The code of monadize is the same as in Tree_monadizer; only the types
 * differ. *)
module Tree_monadizer2(S: sig
  type ('a, 'x) m
  val unit : 'a -> ('a, 'x) m
  val bind : ('a, 'x) m -> ('a -> ('b, 'x) m) -> ('b, 'x) m
end) = struct
  let rec monadize (f : 'a -> ('b, 'x) S.m) (t : 'a tree) : ('b tree, 'x) S.m =
    let ( >>= ) = S.bind in
    match t with
    | Leaf x -> f x >>= (fun y -> S.unit (Leaf y))
    | Node (left, right) ->
      monadize f left >>= fun left' ->
      monadize f right >>= fun right' ->
      S.unit (Node (left', right'))
end;;

(* Continuation monad: a computation that takes the rest of the program,
 * a function of type 'a -> 'r, and produces the final answer of type 'r. *)
type ('a, 'r) cont = ('a -> 'r) -> 'r;;

(* Pass a straight to the continuation. *)
let unit_cont a : ('a, 'r) cont = fun k -> k a;;

(* Run u with a continuation that feeds its result to f and then resumes
 * with the outer continuation k. *)
let bind_cont (u : ('a, 'r) cont) (f : 'a -> ('b, 'r) cont) : ('b, 'r) cont =
  fun k ->
    let resume a = f a k in
    u resume;;

module TreeCont = Tree_monadizer2(struct
  type ('a, 'r) m = ('a, 'r) cont
  let unit = unit_cont
  let bind = bind_cont
end);;

(*
 * Here are all of the examples from
 * http://lambda.jimpryor.net/manipulating_trees_with_monads/
 *)

(* Turn an int into a reader that applies the environment function to it. *)
let int_readerize : int -> int reader =
  fun n modifier -> modifier n;;

(* double each leaf *)
TreeReader.monadize int_readerize t1 (fun n -> n + n);;

(* square each leaf *)
TreeReader.monadize int_readerize t1 (fun n -> n * n);;

(* count the leaves, threading a counter through the state *)
TreeState.monadize (fun leaf count -> (leaf, count + 1)) t1 0;;

(* replace each leaf with a list *)
TreeList.monadize (fun n -> [[n; n * n]]) t1;;

(* convert the tree to the list of its leaves *)
TreeCont.monadize (fun leaf k -> leaf :: k leaf) t1 (fun _ -> []);;

(* do nothing *)
TreeCont.monadize unit_cont t1 (fun tree -> tree);;

(* square each leaf, using continuations *)
TreeCont.monadize (fun n k -> k (n * n)) t1 (fun tree -> tree);;

(* replace each leaf with a list, using continuations *)
TreeCont.monadize (fun n k -> k [n; n * n]) t1 (fun tree -> tree);;

(* count the leaves, using continuations *)
TreeCont.monadize (fun leaf k -> 1 + k leaf) t1 (fun _ -> 0);;