manip trees: more explanation
[lambda.git] / code / tree_monadize.ml
1 (*
2  * tree_monadize.ml
3  *
4  * 'a and so on are type variables in OCaml; they stand for arbitrary types
5  * What if you want a variable for a type constructor? For example, you want to
6  * generalize this pattern:
7  *      type ('a) t1 = 'a -> 'a list
8  *      type ('a) t2 = 'a -> 'a option
9  *      type ('a) t3 = 'a -> 'a reader
10  * and so on? OCaml won't let you do this:
11  *      type ('a, 'b) t = 'a -> 'a 'b
12  * to generalize on the 'b position, we instead have to use OCaml's modules,
13  * and in particular its ability to make modules parameterized on other modules
14  * (OCaml calls these parameterized modules Functors, but that name is also
15  * used in other ways in this literature, so I won't give in to it.)
16  *
17  * Here's how you'd have to define the t type from above:
18  *      module T_maker(S: sig
19  *          type 'a b
20  *      end) = struct
21  *          type 'a t = 'a -> 'a S.b
22  *      end
23  * And here's how you'd use it:
24  *      module T_list = T_maker(struct type 'a b = 'a list end);;
25  *      type 'a t1 = 'a T_list.t;;
26  *
27  * I know, it seems unnecessarily complicated.
28  *)
29
30 type 'a tree = Leaf of 'a | Node of ('a tree) * ('a tree);;
31
32 let t1 = Node
33            (Node
34              (Leaf 2, Leaf 3),
35             Node
36              (Leaf 5,
37               Node
38                 (Leaf 7, Leaf 11)));;
39
40
41 module Tree_monadizer(S: sig
42   type 'a m
43   val unit : 'a -> 'a m
44   val bind : 'a m -> ('a -> 'b m) -> 'b m
45 end) = struct
46   let rec monadize (f: 'a -> 'b S.m) (t: 'a tree) : 'b tree S.m =
47     match t with
48     | Leaf a -> S.bind (f a) (fun b -> S.unit (Leaf b))
49     | Node(l, r) ->
50         S.bind (monadize f l) (fun l' ->
51           S.bind (monadize f r) (fun r' ->
52             S.unit (Node (l', r'))))
53 end;;
54
55
56 type env = int -> int;;
57
58 type 'a reader = env -> 'a;;
59 let unit_reader a : 'a reader = fun e -> a;;
60 let bind_reader (u : 'a reader) (f : 'a -> 'b reader) : 'b reader =
61   fun e -> f (u e) e;;
62
63 module TreeReader = Tree_monadizer(struct
64   type 'a m = 'a reader
65   let unit = unit_reader
66   let bind = bind_reader
67 end);;
68
69
70 type store = int;;
71
72 type 'a state = store -> 'a * store;;
73 let unit_state a : 'a state  = fun s -> (a, s);;
74 let bind_state (u : 'a state) (f : 'a -> 'b state) : 'b state =
75   fun s -> (let (a, s') = u s in (f a) s');;
76
77 module TreeState =  Tree_monadizer(struct
78   type 'a m = 'a state
79   let unit = unit_state
80   let bind = bind_state
81 end);;
82
83
84 let unit_list a = [a];;
85 let bind_list (u: 'a list) (f : 'a -> 'b list) : 'b list =
86   List.concat(List.map f u);;
87
88 module TreeList =  Tree_monadizer(struct
89   type 'a m = 'a list
90   let unit = unit_list
91   let bind = bind_list
92 end);;
93
94
95
96 (* we need to a new module when the monad is parameterized on two types *)
97 module Tree_monadizer2(S: sig
98   type ('a,'x) m
99   val unit : 'a -> ('a,'x) m
100   val bind : ('a,'x) m -> ('a -> ('b,'x) m) -> ('b,'x) m
101 end) = struct
102   let rec monadize (f: 'a -> ('b,'x) S.m) (t: 'a tree) : ('b tree,'x) S.m =
103     (* the definition is the same, the difference is only in the types *)
104     match t with
105     | Leaf a -> S.bind (f a) (fun b -> S.unit (Leaf b))
106     | Node(l, r) ->
107         S.bind (monadize f l) (fun l' ->
108           S.bind (monadize f r) (fun r' ->
109             S.unit (Node (l', r'))))
110 end;;
111
112 type ('a,'r) cont = ('a -> 'r) -> 'r;;
113 let unit_cont a : ('a,'r) cont = fun k -> k a;;
114 let bind_cont (u: ('a,'r) cont) (f: 'a -> ('b,'r) cont) : ('b,'r) cont =
115   fun k -> u (fun a -> f a k);;
116
117 module TreeCont =  Tree_monadizer2(struct
118   type ('a,'r) m = ('a,'r) cont
119   let unit = unit_cont
120   let bind = bind_cont
121 end);;
122
123
124
125 (* 
126  * Here are all the examples from
127  * http://lambda.jimpryor.net/manipulating_trees_with_monads/
128  *)
129
130 let int_readerize : int -> int reader =
131   fun (a : int) (modifier : int -> int) -> modifier a;;
132
133 (* double each leaf *)
134 TreeReader.monadize int_readerize t1 (fun i -> i + i);;
135
136 (* square each leaf *)
137 TreeReader.monadize int_readerize t1 (fun i -> i * i);;
138
139 (* count leaves *)
140 TreeState.monadize (fun a s -> (a, s+1)) t1 0;;
141
142 (* replace leaves with list *)
143 TreeList.monadize (fun i -> [[i;i*i]]) t1;;
144
145 (* convert tree to list of leaves *)
146 TreeCont.monadize (fun a k -> a :: k a) t1 (fun t -> []);;
147
148 (* do nothing *)
149 TreeCont.monadize unit_cont t1 (fun t-> t);;
150
151 (* square each leaf using continuation *)
152 TreeCont.monadize (fun a k -> k (a*a)) t1 (fun t -> t);;
153
154 (* replace leaves with list, using continuation *)
155 TreeCont.monadize (fun a k -> k [a; a*a]) t1 (fun t -> t);;
156
157 (* count leaves, using continuation *)
158 TreeCont.monadize (fun a k -> 1 + k a) t1 (fun t -> 0);;
159
160