Lambdas
After last class, we can run programs like this:
(define (range lo hi)
  (if (< lo hi)
    (pair lo (range (add1 lo) hi))
    false))
(define (map f l)
  (if (not l) l
    (pair (f (left l)) (map f (right l)))))
(define (g x) (+ x 1))
(print (map g (range 0 2)))
Very cool! We can implement map in our language!
Q: What will be the output from running this program?
But, hey, it’s kind of annoying that we had to define this g function at the top level. Normally when we’re using map, we just make a little anonymous function right where we’re calling map. We’d really like our program to look more like this:
(define (range lo hi)
  (if (< lo hi)
    (pair lo (range (add1 lo) hi))
    false))
(define (map f l)
  (if (not l) l
    (pair (f (left l)) (map f (right l)))))
(print (map (lambda (x) (+ x 1)) (range 0 2)))
What would it take to add support for this in our language? Possibly less effort than you’d think! Think about how similar the first and second program are. What if every time we see a program like the second one, we transform it into one that looks like the first? Once we see a lambda, or anonymous function, we can add a new definition to our list of top-level definitions. We’ll give it a special name that programmer-defined functions aren’t allowed to use. Then once our new definition is available in our usual list of top-level definitions, we can use it the same way we’ve used our definitions before: by putting the appropriate function name where we previously had the lambda itself. Let’s take a look at how this works.
Anonymous functions
Here’s the coolest thing about how we’re going to do this: We’re not going to change compile.ml or interp.ml at all! Remember that compile.ml and interp.ml already have all the functionality we need to run programs that look like the first program above. So if we can transform the AST before the AST even reaches the compiler or the interpreter, we don’t need to extend the compiler or interpreter at all. We’ll just transform every AST that uses the new feature into an equivalent AST that only uses the old features. This is sometimes called “syntactic sugar,” because it gives us a sweeter syntax for writing programs that we could already write even without the sugar!
To do this, we’re going to need to make a new type that will look a lot like expr, our existing type for ASTs. We’ll call the new type expr_lam (lam for lambdas). expr will stay the same as always, but expr_lam will include lambda uses. Our plan is to parse programs into the new type, expr_lam. Then we’ll write a transformer for turning things of type expr_lam into things of type expr (which doesn’t include lambdas). (Remember, our interpreter and compiler only know about things of type expr! They don’t know anything about expr_lam. So we better produce something of type expr at the end.)
Let’s start by copying the expr type, replacing all uses of expr in the constructors with uses of expr_lam and adding a constructor for Lambda.
type expr_lam =
  | Prim0 of prim0
  | Prim1 of prim1 * expr_lam
  | Prim2 of prim2 * expr_lam * expr_lam
  | Let of string * expr_lam * expr_lam
  | If of expr_lam * expr_lam * expr_lam
  | Do of expr_lam list
  | Num of int
  | Var of string
  | Call of expr_lam * expr_lam list
  | True
  | False
  | Lambda of string list * expr_lam
Then let’s write our translator for translating from expr_lam to expr:
let rec expr_of_expr_lam (defns : defn list ref) : expr_lam -> expr = function
  | Num x -> Num x
  | Var s -> Var s
  | True -> True
  | False -> False
  | If (test_exp, then_exp, else_exp) ->
      If
        ( expr_of_expr_lam defns test_exp
        , expr_of_expr_lam defns then_exp
        , expr_of_expr_lam defns else_exp )
  | Let (var, exp, body) ->
      Let (var, expr_of_expr_lam defns exp, expr_of_expr_lam defns body)
  | Prim0 p -> Prim0 p
  | Prim1 (p, e) -> Prim1 (p, expr_of_expr_lam defns e)
  | Prim2 (p, e1, e2) ->
      Prim2 (p, expr_of_expr_lam defns e1, expr_of_expr_lam defns e2)
  | Do exps -> Do (List.map (expr_of_expr_lam defns) exps)
  | Call (exp, args) ->
      Call (expr_of_expr_lam defns exp, List.map (expr_of_expr_lam defns) args)
  | Lambda (args, body) ->
      let name = gensym "_lambda" in
      defns := {name; args; body= expr_of_expr_lam defns body} :: !defns ;
      Var name
Now most of this may look like it’s doing almost nothing, but in fact when we see something constructed with the expr_lam Num constructor, we’re calling the expr Num constructor, so even the boring cases are doing important work. But the most interesting case is the Lambda case. Here we see that we’re generating a new name for the (previously) anonymous function. Then we’re adding the definition directly to our existing list of defns. Finally, we replace the lambda expression with a use of the newly created name.
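To see the hoisting in isolation, here is a minimal, self-contained sketch on a cut-down AST. The types, the constructor names (LNum, LLambda, and so on), the function name lift, and this particular gensym are all assumptions made for illustration; they are not the definitions from the course code, which has more constructors.

```ocaml
(* A cut-down AST with lambdas (hypothetical, much smaller than expr_lam). *)
type expr_lam =
  | LNum of int
  | LVar of string
  | LAdd of expr_lam * expr_lam
  | LCall of expr_lam * expr_lam list
  | LLambda of string list * expr_lam

(* The lambda-free target AST. *)
type expr =
  | Num of int
  | Var of string
  | Add of expr * expr
  | Call of expr * expr list

type defn = { name : string; args : string list; body : expr }

(* Assumed gensym: appends a fresh counter value to the given prefix. *)
let counter = ref 0

let gensym prefix =
  incr counter;
  Printf.sprintf "%s_%d" prefix !counter

(* Translate away lambdas, pushing each one onto the list of definitions. *)
let rec lift (defns : defn list ref) : expr_lam -> expr = function
  | LNum n -> Num n
  | LVar s -> Var s
  | LAdd (a, b) ->
      let a' = lift defns a in
      let b' = lift defns b in
      Add (a', b')
  | LCall (f, args) ->
      let f' = lift defns f in
      Call (f', List.map (lift defns) args)
  | LLambda (args, body) ->
      let name = gensym "_lambda" in
      let body' = lift defns body in
      defns := { name; args; body = body' } :: !defns;
      Var name

let defns : defn list ref = ref []

(* ((lambda (x) (+ x 1)) 5) *)
let lifted =
  lift defns (LCall (LLambda ([ "x" ], LAdd (LVar "x", LNum 1)), [ LNum 5 ]))
```

After this runs, lifted is a call to Var "_lambda_1", and defns holds one new definition whose body is the old lambda body.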
And that’s it! We’ve added anonymous functions to our language, for both the interpreter and the compiler. Before, our compilation process was: string -> tokens -> AST -> assembly. Now our compilation process is: string -> tokens -> AST with lambdas -> AST without lambdas -> assembly. This was a pretty simple transformation, but it gives you a sense of how AST-to-AST translation works. In an industrial compiler, you’ll often have many layers of AST-to-AST transformations.
Closures
Ok, that’s all well and good, but what about this program?
(print (let ((x 2)) ((lambda (y) (+ y x)) 3) ) )
Q: What should this produce?
Q: Why won’t it work with our current interpreter and compiler?
The problem is, say we take this lambda and hoist it to the top of our program, as our transformation does. We can even do it manually on paper. If we do this, how will we know what x should be when we run the function?
It makes sense that this is a limitation of our approach. If the context in which we’re defining the lambda doesn’t matter, then we can hoist it to the top level without issue. But if it does matter, what now?
Closures in the interpreter
Let’s start by fixing this in the interpreter.
First, we’ll add a constructor to our expr type:
type expr = (* some constructors elided *) | Closure of string
And we’ll start using that constructor whenever we find a lambda in our program:
let rec expr_of_expr_lam (defns : defn list ref) : expr_lam -> expr = function
  (* some cases elided *)
  | Lambda (args, body) ->
      let name = gensym "_lambda" in
      defns := {name; args; body= expr_of_expr_lam defns body} :: !defns ;
      Closure name
Why closure? Well, what we’re going to do when we see a lambda in our interpreter is grab the appropriate name for retrieving the definition from our list of definitions but also save the environment that existed when the lambda was created. Remember, when we encounter the lambda in the interpreter, we’re in an environment where all the symbols we need are mapped to their values. That’s exactly the information we’ll need when we’re ready to run the function body! A closure encloses or packages up a function and an environment to use for the function. So now we’ve set it up so that our ASTs have closures in all the places where we’ll want to save a function with its environment.
Let’s move on to actually changing the interpreter.
First, we know that we want to start storing more things in our function values, so let’s change our value type:
type value = Number of int | Boolean of bool | Pair of (value * value) | Function of (string * value symtab)
Now we’re ready to change a couple cases in interp_exp to make this new type of function value:
let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value =
  match exp with
  (* some cases elided *)
  | Var var when is_defn defns var -> Function (var, Symtab.empty)
  | Closure f -> Function (f, env)
These are the only cases where we need to make function values, so these are the only cases we need to change (in the Var case) or introduce (in the Closure case). In the Var case, we’re just seeing the same thing we always used to see when we called functions. Someone used a name that appears in our list of definitions. So we know that’s at the top level, and we don’t need to save the context. We’ll just use an empty environment for that function value.
The Closure case is the exciting case, the AST node that we’ve introduced specifically for anonymous functions. In this case, we’ll want to save f, which is the new name we gave the anonymous function. But we’ll also want to save env, which is the environment that was passed as an argument to interp_exp. Which is to say, it’s the context in which the lambda was created. So we make a function value with Function (f, env).
Now we’re ready to use our new function values!
We still want to make sure that the thing we’re calling is actually a function value. But the function value is now carrying some extra information, the environment that we’ve packaged up with it. How should we use that information?
let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value =
  match exp with
  | Call (f, args) -> (
      let vals = args |> List.map (interp_exp defns env) in
      let fv = interp_exp defns env f in
      match fv with
      | Function (name, saved_env) when is_defn defns name ->
          let defn = get_defn defns name in
          if List.length args = List.length defn.args then
            let fenv = List.combine defn.args vals |> Symtab.add_list saved_env in
            interp_exp defns fenv defn.body
          else raise (BadExpression exp)
      | _ -> raise (BadExpression exp) )
Notice that we’re adding the argument bindings to saved_env, the environment captured when the function value was created, rather than to the caller’s env. That way, all the relevant context is available while we’re running the anonymous function.
And now our anonymous function that uses values from the outside world works!
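The whole idea fits in a few lines if we shrink the language. The following is a rough sketch, not the course interpreter: it assumes a tiny expr type, uses the standard library Map in place of Symtab, and leaves out the Var case for top-level definitions as well as the arity check.

```ocaml
(* Stand-in for the course's Symtab: a standard string-keyed map. *)
module Env = Map.Make (String)

type expr =
  | Num of int
  | Var of string
  | Add of expr * expr
  | Let of string * expr * expr
  | Closure of string
  | Call of expr * expr list

type defn = { name : string; args : string list; body : expr }

type value =
  | Number of int
  | Function of string * value Env.t

let rec interp (defns : defn list) (env : value Env.t) (exp : expr) : value =
  match exp with
  | Num n -> Number n
  | Var x -> Env.find x env
  | Add (a, b) -> (
      match (interp defns env a, interp defns env b) with
      | Number x, Number y -> Number (x + y)
      | _ -> failwith "expected numbers")
  | Let (x, e, body) -> interp defns (Env.add x (interp defns env e) env) body
  (* Package the function name with the environment at creation time. *)
  | Closure f -> Function (f, env)
  | Call (f, args) -> (
      let vals = List.map (interp defns env) args in
      match interp defns env f with
      | Function (name, saved_env) ->
          let defn = List.find (fun d -> d.name = name) defns in
          (* Extend the saved environment, not the caller's, with the arguments. *)
          let fenv =
            List.fold_left2
              (fun acc a v -> Env.add a v acc)
              saved_env defn.args vals
          in
          interp defns fenv defn.body
      | Number _ -> failwith "not a function")

(* The example program after hoisting, with the lambda replaced by a Closure. *)
let defns = [ { name = "_lambda_1"; args = [ "y" ]; body = Add (Var "y", Var "x") } ]
let program = Let ("x", Num 2, Call (Closure "_lambda_1", [ Num 3 ]))
let result = interp defns Env.empty program
```

Here result is Number 5: the body (+ y x) finds y among the arguments and x in the saved environment.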
Closures in the compiler
This is looking pretty tough for our compiler. In the interpreter, we have this nice symbol table that we can tote around. It may look like we didn’t actually change the function values that much in the interpreter, but adding a symbol table is actually a lot! That’s a whole balanced binary tree that we stuffed in there! We suddenly started carrying a lot of information with our function values. And it’s not clear how to do the same thing in the compiler. After all, the compiler’s symbol table only exists at compile time. It doesn’t exist at run time at all!
One insight that will help us out here: we don’t need the whole symbol table.
Take a look at our example from before:
(print (let ((x 2)) ((lambda (y) (+ y x)) 3) ) )
Our lambda doesn’t need everything. It needs x! And we can tell that just from looking at the lambda. If we look through the lambda, x is the only thing that gets referenced without being defined. Variables that an expression uses but doesn’t define are called free variables. So from the example above, we’d say x is free in (lambda (y) (+ y x)). We’ll want to package up not all our variables but just the free variables.
Q: What’s free in this expression?
(lambda (y) (+ y y))
Q: What’s free in this expression?
(lambda (y) (+ x y))
Q: What’s free in this expression?
(print ((adder 2) 3))
Q: What’s free in this expression? (f a b c d)
Which variables are free really depends on the context! As PL researcher Brigitte Pientka (https://www.cs.mcgill.ca/~bpientka/) once said in answer to a question about how her system handled free variables, without missing a beat: “I don’t believe in free variables. Even birds are chained to the sky.” It’s not that our free variables are truly free! They have some definition. Those definitions are just outside the expression. It’s that they’re going to depend on the context.
So the free variables are the ones for which we need to keep context. Let’s start by figuring out which variables are free inside a given lambda. Remember, we can tell which of them are free just by looking at the lambda itself.
Here’s our helper function for identifying which variables are free in a given lambda:
let rec fv (bound : string list) (exp : expr) =
  match exp with
  | Var s when not (List.mem s bound) -> [s]
  | Let (v, e, body) -> fv bound e @ fv (v :: bound) body
  | If (te, the, ee) -> fv bound te @ fv bound the @ fv bound ee
  | Do es -> List.concat_map (fv bound) es
  | Call (exp, args) -> fv bound exp @ List.concat_map (fv bound) args
  | Prim1 (_, e) -> fv bound e
  | Prim2 (_, e1, e2) -> fv bound e1 @ fv bound e2
  | _ -> []
Here’s what you should notice about fv (fv for free variables). The second argument is an expr, which represents the body of the lambda we’re analyzing. The other argument, bound, is a list of bound variables. When we first call fv, these are the arguments to the lambda (together with the names of the top-level definitions, as the calls to fv later in these notes show). Later we may add other items. (Note that the Let case adds a new item to the bound list for when we traverse the body of the let.) Our goal is to identify all variables that are referenced but not defined, so if we see a case where we use Var s and it’s not in bound, we’ll add s to our list of free variables. On the other hand, if we see a case where we use Var s and it is in bound, it’s not a free variable, and we also don’t need to make a recursive call to fv, so we don’t even have a case for that. For most of our cases, we’re just making the recursive calls and joining together the lists of free variables we get from each recursive call.
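To try fv on small inputs, here is a self-contained sketch over a reduced expr type. The reduced type and the two example bindings are assumptions for illustration; the fv cases themselves mirror the ones above.

```ocaml
(* Reduced AST (hypothetical): just enough constructors to exercise fv. *)
type expr =
  | Num of int
  | Var of string
  | Add of expr * expr
  | Let of string * expr * expr
  | Call of expr * expr list

let rec fv (bound : string list) (exp : expr) : string list =
  match exp with
  | Var s when not (List.mem s bound) -> [ s ]
  | Let (v, e, body) -> fv bound e @ fv (v :: bound) body
  | Add (e1, e2) -> fv bound e1 @ fv bound e2
  | Call (f, args) -> fv bound f @ List.concat_map (fv bound) args
  | _ -> [] (* numbers and bound variables contribute nothing *)

(* Body of (lambda (y) (+ y x)), starting with the parameter y as bound. *)
let free_in_lambda = fv [ "y" ] (Add (Var "y", Var "x"))

(* (let ((z w)) (+ z w)): z is bound in the body only, w is free throughout. *)
let free_in_let = fv [] (Let ("z", Var "w", Add (Var "z", Var "w")))
```

The first call returns ["x"], matching the example in the text. The second returns ["w"; "w"], which shows that a variable referenced more than once appears more than once in the result, since fv just concatenates lists.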
Now what should we do when it comes time to actually compile one of these lambdas? That’s what we need to decide to implement our Closure case in compile_exp. We already have a way of using the name to get a function pointer for cases where the function had a top-level definition:
let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* some cases elided *)
  | Var var when is_defn defns var ->
      [ LeaLabel (Reg Rax, defn_label var)
      ; Or (Reg Rax, Imm fn_tag) ]
We can reuse that approach for grabbing a function pointer when we encounter a closure. But what else will we need? Well, remember that all we have to do to figure out what variables are free in an expression is look at the expression itself. And when we’re compiling a lambda, we can look in our defns list and know exactly what expression we’ll be using. So we can figure out the list of free variables at compile time. Which means we can go ahead and store a package of the function pointer and the values of the fixed, finite set of free variables the function will need. The values themselves are not fixed, but the set of free variables is fixed.
Let’s look back at our sample program again:
(print (let ((x 2)) ((lambda (y) (+ y x)) 3) ) )
At this point, we’ve transformed it into something like this:
(define (_lambda_1 y) (+ y x)) (print (let ((x 2)) (_lambda_1 3) ) )
So for this program, we’ll put together a package of the pointer to _lambda_1 and the value 2, which was the value of x when we created the lambda. And then we just need to make sure that this package is available when we call the lambda.
Whiteboard: Try drawing a picture of how this should look at run time. We’ll put this package of values, this function value, on the heap. In the first free cell of the heap, put the function pointer, that is, the address of the first instruction of _lambda_1. (This is the same kind of function pointer we’ve already been using.) In the second free cell of the heap, put the value of x, which is 2. (In our runtime representation, 8.) This two-cell block of the heap (more cells if we have more free variables) represents our lambda. When we call the lambda, we need to have a pointer to this block. So we’ll add an additional argument to the stack when we make a call! Remember that we’re calling _lambda_1 on 3, so we’ll have to put the runtime representation of 3 (12) on the first slot of the stack. After that, we’ll add an extra argument, which is the pointer to our lambda, and we’ll put that in the next available slot on the stack.
Then the function itself needs to handle moving free variables from the heap onto the appropriate places on the stack. (Where appropriate means the places where the already-compiled body of _lambda_1 knows to look for them, relative to the rsp when we enter _lambda_1.)
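In place of a whiteboard, the picture can be modeled with a toy heap. This is only a sketch under stated assumptions: the heap is an OCaml array of cells indexed by position rather than real memory, the address 0x1000 is a made-up stand-in for the label of _lambda_1, numbers are assumed to be encoded by shifting left two bits (which is consistent with 2 becoming 8 and 3 becoming 12), and the function tag on the pointer is left out.

```ocaml
(* Assumed number encoding: shift left by two bits, so 2 -> 8 and 3 -> 12. *)
let num_shift = 2
let encode n = n lsl num_shift
let decode v = v asr num_shift

(* A toy heap: an array of cells plus a bump pointer (a cell index). *)
let heap = Array.make 16 0
let heap_ptr = ref 0

(* Allocate a closure block: the first cell holds the code address, the
   following cells hold the free-variable values. Returns the block's index. *)
let alloc_closure (code_addr : int) (free_vals : int list) : int =
  let base = !heap_ptr in
  heap.(base) <- code_addr;
  List.iteri (fun i v -> heap.(base + 1 + i) <- v) free_vals;
  heap_ptr := base + 1 + List.length free_vals;
  base

(* 0x1000 is a made-up stand-in for the address of _lambda_1. *)
let lambda_1_addr = 0x1000
let closure = alloc_closure lambda_1_addr [ encode 2 ]

(* What the call site lays out for (_lambda_1 3): the argument, then the
   pointer to the closure block. *)
let stack_args = [ encode 3; closure ]

(* What the start of _lambda_1 recovers before running (+ y x). *)
let x_val = heap.(closure + 1)
let result = decode (List.hd stack_args + x_val)
```

The block occupies two cells, the bump pointer ends up just past it, and adding the encoded values 12 and 8 decodes to 5.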
So here’s what our closure case will do:
let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* some cases elided *)
  | Closure f ->
      let defn = get_defn defns f in
      let fvs = fv (List.map (fun d -> d.name) defns @ defn.args) defn.body in
      let fv_movs =
        List.mapi
          (fun i var ->
            [ Mov (Reg Rax, stack_address (Symtab.find var tab))
            ; Mov (MemOffset (Reg Rdi, Imm (8 * (i + 1))), Reg Rax) ])
          fvs
      in
      if List.exists (fun v -> not (Symtab.mem v tab)) fvs then
        raise (BadExpression exp)
      else
        [ LeaLabel (Reg Rax, defn_label f)
        ; Mov (MemOffset (Reg Rdi, Imm 0), Reg Rax) ]
        @ List.concat fv_movs
        @ [ Mov (Reg Rax, Reg Rdi)
          ; Or (Reg Rax, Imm fn_tag)
          ; Add (Reg Rdi, Imm (8 * (List.length fvs + 1))) ]
First, grab the definition based on the name, f. Next, figure out the list of free variables, which we call fvs. The assembly instructions in fv_movs are in charge of moving the values of the free variables onto the heap. At compile time, we use the symbol table to figure out where the value of the free variable will be stored at run time. Then we emit the mov instruction that will grab the value from that slot and put it into rax. Then we emit the mov instruction that grabs the value from rax and puts it on the heap.
We use LeaLabel to get the function pointer, and we put that on the heap too. (We get the function pointer the same way as before, in our Var case. The difference is that now we’re putting it on the heap.) Finally, with our last three instructions, we make a pointer to the package we’ve put on the heap, save that in rax, and then update our heap pointer.
Now, when the time comes to make a call, we better be calling one of these function values! So let’s take a look at our Call cases. Remember, we have one for is_tail and one for not is_tail.
let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  | Call (f, args) when not is_tail ->
      let stack_base = align_stack_index (stack_index + 8) in
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_base - (8 * (i + 2))) arg false
               @ [Mov (stack_address (stack_base - (8 * (i + 2))), Reg Rax)])
        |> List.concat
      in
      compiled_args
      @ compile_exp defns tab
          (stack_base - (8 * (List.length args + 2)))
          f false
      @ ensure_fn (Reg Rax)
      @ [ Mov (stack_address (stack_base - (8 * (List.length args + 2))), Reg Rax)
        ; Sub (Reg Rax, Imm fn_tag)
        ; Mov (Reg Rax, MemOffset (Reg Rax, Imm 0)) ]
      @ [ Add (Reg Rsp, Imm stack_base)
        ; ComputedCall (Reg Rax)
        ; Sub (Reg Rsp, Imm stack_base) ]
  | Call (f, args) when is_tail ->
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_index - (8 * i)) arg false
               @ [Mov (stack_address (stack_index - (8 * i)), Reg Rax)])
        |> List.concat
      in
      let moved_args =
        args
        |> List.mapi (fun i _ ->
               [ Mov (Reg R8, stack_address (stack_index - (8 * i)))
               ; Mov (stack_address ((i + 1) * -8), Reg R8) ])
        |> List.concat
      in
      compiled_args
      @ compile_exp defns tab (stack_index - (8 * List.length args)) f false
      @ ensure_fn (Reg Rax)
      @ moved_args
      @ [ Mov (stack_address ((List.length args + 1) * -8), Reg Rax)
        ; Sub (Reg Rax, Imm fn_tag)
        ; Mov (Reg Rax, MemOffset (Reg Rax, Imm 0)) ]
      @ [ComputedJmp (Reg Rax)]
Most of this isn’t so different from what we were already doing for Calls. We need to compile the arguments. We need to put them on the stack. But after that, we add one more argument, right after the standard arguments: the pointer to that lambda package that we put on the heap.
Using that pointer, we can also grab the address of the first instruction of the function, which we use for our ComputedCall (or, in the tail case, our ComputedJmp). So the only new thing we’ve done is emit assembly to add this pointer to the function package as a final argument of our function.
Which means we have one last thing to do! The part of the compiler that emits code for each function definition has to be ready to actually use these function value packages. So we’ll need to update each compiled function definition to grab the values it needs from the heap.
Here’s our old compile_defn:
let compile_defn defns defn =
  let ftab =
    defn.args
    |> List.mapi (fun i arg -> (arg, -8 * (i + 1)))
    |> Symtab.of_list
  in
  [Align 8; Label (defn_label defn.name)]
  @ compile_exp defns ftab (-8 * (List.length defn.args + 1)) defn.body true
  @ [Ret]
Here’s our updated compile_defn:
let compile_defn defns defn =
  let fvs = fv (List.map (fun d -> d.name) defns @ defn.args) defn.body in
  let ftab =
    defn.args @ fvs
    |> List.mapi (fun i arg -> (arg, -8 * (i + 1)))
    |> Symtab.of_list
  in
  let fvs_to_stack =
    [ Mov (Reg Rax, stack_address (-8 * (List.length defn.args + 1)))
    ; Sub (Reg Rax, Imm fn_tag)
    ; Add (Reg Rax, Imm 8) ]
    @ List.concat
        (List.mapi
           (fun i _ ->
             [ Mov (Reg R8, MemOffset (Reg Rax, Imm (i * 8)))
             ; Mov (stack_address (-8 * (List.length defn.args + 1 + i)), Reg R8) ])
           fvs)
  in
  [Align 8; Label (defn_label defn.name)]
  @ fvs_to_stack
  @ compile_exp defns ftab (-8 * (Symtab.cardinal ftab + 1)) defn.body true
  @ [Ret]
What’s changed? Well, the old version made ftab, our function’s symbol table, by mapping each function argument to an offset on the stack. Then we emitted a label, then emitted the assembly for the function (compiled with compile_exp, using our new ftab), and finally we emitted the ret.
Our new version calls fv on the function body to get the list of free variables. Then we use both defn.args and fvs to make our new symbol table (let ftab = defn.args @ fvs...). We’re assuming that we’ll have the normal function arguments on the stack, followed by the variables that are free in the definition body. And then we emit the code that takes care of putting those free variables in the right spots.
Remember, the Call cases of compile_exp are in charge of putting the function argument values on the stack, because it’s when we’re executing a call that we know how to calculate the argument values. But it’s while we’re looking at the definition of the function that we know what variables are free, and so this is the place where we want to move the free variables from the heap to the stack.
Once we’ve updated the symbol table to include the free variables, we’re ready to move them from the heap to the stack (fvs_to_stack). Notice that we use stack_address (-8 * (List.length defn.args + 1 + i)) to put them past the normal function arguments on the stack. And notice that we find the pointer to the function value by using that special last argument we added in our call cases (Mov (Reg Rax, stack_address (-8 * (List.length defn.args + 1)))).
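The offset arithmetic is easy to check by hand with a small sketch. The helper names slots and closure_slot are hypothetical, introduced only for this illustration; the formulas are the ones used in compile_defn.

```ocaml
(* Map each name in args @ fvs to its stack offset, -8 * (i + 1),
   as the ftab computation in compile_defn does. *)
let slots (args : string list) (fvs : string list) : (string * int) list =
  List.mapi (fun i name -> (name, -8 * (i + 1))) (args @ fvs)

(* Offset where the call site stored the pointer to the function value:
   the slot right after the ordinary arguments. *)
let closure_slot (args : string list) : int = -8 * (List.length args + 1)

(* _lambda_1 has one argument, y, and one free variable, x. *)
let example = slots [ "y" ] [ "x" ]
let example_closure_slot = closure_slot [ "y" ]
```

For _lambda_1 this gives y at offset -8 and x at offset -16, and the function-value pointer also arrives at offset -16. So the first free variable is written over the slot that held the pointer, which is safe in the code above because the pointer has already been loaded into rax before any free variable is stored.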
And that’s it! When we encounter a lambda, it closes over its environment to save all the context we need. We have full, general lambdas.
And with that, we’ve successfully completed a fully operational functional programming language!