First-class functions
Most functional programming languages (really, most modern programming languages of all stripes) support some form of first-class functions. Right now, functions in our language are totally separate from values. We can define them and call them, but we can’t:
- Put a function in a variable
- Return a function from a function
- Pass a function to a function as an argument
Here are some programs, written in a slightly extended version of our language, that use first-class functions:
(define (f g) (g 2))
(define (mul2 x) (+ x x))
(print (f mul2))

(define (f g) (g 2))
(print (f (lambda (x) (+ x x))))

(define (f g) (g 2))
(let ((y 3)) (print (f (lambda (x) (+ x y)))))
Today we’ll add support for the feature you see in the first program: we’ll be able to pass functions around like other values. Next time we’ll add the features we need to support the other two programs. Notice what’s different about programs two and three. In the second, we’re making an anonymous function, (lambda (x) (+ x x)). In the third, we’re making an anonymous function that references a variable y that’s let-bound outside of the function. We could jump all the way to supporting all three of these programs, but this is actually pretty advanced stuff! Most undergrad compilers classes never get around to supporting language features like first-class functions. So we’ll break it down into stages.
Q: Before you scroll down, take a moment to think about this question. Will we need to change our AST in order to support this change to our language? If yes, how?
Extending the AST
Now that we have an AST for our language, we’ll need to extend it in order to support calling function values that aren’t literal function names. The change is pretty minor.
Before, the AST constructor for calls looked like this:
Call of string * expr list
Now we’ll use the version that appears below:
type expr =
  | Prim0 of prim0
  | Prim1 of prim1 * expr
  | Prim2 of prim2 * expr * expr
  | Let of string * expr * expr
  | If of expr * expr * expr
  | Do of expr list
  | Num of int
  | Var of string
  | Call of expr * expr list
  | True
  | False
We’ll also want to make a change to our helper function that translates S-expression ASTs into ASTs for our language. Before, the call case looked like this:
let rec expr_of_s_exp : s_exp -> expr = function
  (* some cases elided ... *)
  | Lst (Sym f :: args) ->
      Call (f, List.map expr_of_s_exp args)
Now we’ll want it to look like this:
let rec expr_of_s_exp : s_exp -> expr = function
  (* some cases elided ... *)
  | Lst (f :: args) ->
      Call (expr_of_s_exp f, List.map expr_of_s_exp args)
Notice that f is no longer a Sym; instead it’s an arbitrary s_exp. So we also have to make a recursive call to expr_of_s_exp in order to get the appropriate representation of f into our AST.
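The parsing change can be tried out in isolation. The sketch below assumes a pared-down S-expression type (wrapped in a module S to avoid a constructor clash) and a three-constructor expr; both are simplified stand-ins for the course’s real definitions, and the names pick, by_name, and computed are made up for the example.

```ocaml
(* Simplified stand-ins for the course's types; assumptions for illustration. *)
module S = struct
  type t = Num of int | Sym of string | Lst of t list
end

type expr = Num of int | Var of string | Call of expr * expr list

let rec expr_of_s_exp : S.t -> expr = function
  | S.Num n -> Num n
  | S.Sym x -> Var x
  (* The callee is translated recursively, like any other subexpression. *)
  | S.Lst (f :: args) -> Call (expr_of_s_exp f, List.map expr_of_s_exp args)
  | S.Lst [] -> failwith "empty application"

(* (f mul2): the callee is a plain name. *)
let by_name = expr_of_s_exp (S.Lst [S.Sym "f"; S.Sym "mul2"])

(* ((pick 0) 2): the callee is itself a call, which a Sym-only pattern
   would not match. *)
let computed = expr_of_s_exp (S.Lst [S.Lst [S.Sym "pick"; S.Num 0]; S.Num 2])
```

With the old Call of string * expr list constructor there would be no way to represent the second expression at all, since its callee is not a name.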
Extending the interpreter
Now we’re ready to extend the interpreter. Consider our current value type:
type value =
  | Number of int
  | Boolean of bool
  | Pair of (value * value)
Q: What will we need to add, if anything, to our value type in order to support first-class functions?
You could consider a couple of options:
- We could add Function of string list * expr, and then our function values could carry around all the same information that our defn type carries around. Any time we encounter a function value in a call, we could do the same thing we do when we’ve looked a function up (by name) in our list of defns: evaluate the arguments, make a new symbol table based on those arguments, then evaluate the body with the new symbol table. We’d just be skipping over the part where we look it up in defns.
- We could add Function of (value list -> value). We’ve had great luck in the past with building on top of OCaml’s own features when we want to implement features in the interpreter. Why not use OCaml functions to implement functions in our language?
We’re going to pick an alternative that will make the structure of the interpreter match a bit more closely with the structure of the compiler. But that doesn’t mean those other options are bad! We’re just picking one of many reasonable ways to implement this. We’ll add Function of string.
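To make the second option concrete, here is a minimal sketch of what a value type built on OCaml closures might look like. The notes do not adopt this design; the reduced value type and the helper apply are assumptions made only for the example.

```ocaml
(* Hypothetical variant of the value type for the closure-based option. *)
type value =
  | Number of int
  | Boolean of bool
  | Function of (value list -> value)

(* mul2 from the first example program, written as an OCaml closure. *)
let mul2 =
  Function
    (function
      | [Number x] -> Number (x + x)
      | _ -> failwith "mul2 expects one number")

(* Calling a function value is just OCaml function application. *)
let apply (f : value) (args : value list) : value =
  match f with
  | Function g -> g args
  | _ -> failwith "not a function"

let result = apply mul2 [Number 2]
```

This keeps calls trivial in the interpreter, but the call logic lives inside an OCaml closure, which has no direct counterpart in the assembly the compiler emits. That mismatch is one reason to prefer the Function of string representation chosen above.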
Q: How is this going to be enough information for us to implement first-class functions?
This is going to be enough information because we’ll just use this string to look up the appropriate definition in our list of defns! Even though we’re no longer demanding a literal function name in the first position in every Call S-expression, we haven’t introduced anonymous functions yet. So we can go ahead and evaluate the expression in the first position, figure out what name we originally gave the function in our definitions block, and that will be enough information for us to retrieve the definition when we need to use it.
Here’s how it’ll look in context:
type value =
  | Number of int
  | Boolean of bool
  | Pair of (value * value)
  | Function of string
Let’s also make sure we can print out function values:
let rec string_of_value (v : value) : string =
  match v with
  (* some cases elided ... *)
  | Function _ -> "<function>"
Let’s go ahead and think about how our Call case should change.
Here’s where we’re starting:
let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) when is_defn defns f ->
      let defn = get_defn defns f in
      if List.length args = List.length defn.args then
        let vals = args |> List.map (interp_exp defns env) in
        let fenv = List.combine defn.args vals |> Symtab.of_list in
        interp_exp defns fenv defn.body
      else raise (BadExpression exp)
  | Call _ -> raise (BadExpression exp)
Looks like there are a couple of problems for our current purposes. First, now that f might be something other than the function’s original name, we can’t just check if it shows up in our defns list via is_defn. We’d better evaluate it first. Then we can check if it shows up in defns. And of course, once we’ve gone ahead and evaluated f, we should make sure it’s actually a function value. Once we make these changes, we end up here:
let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) -> (
      let vals = args |> List.map (interp_exp defns env) in
      let fv = interp_exp defns env f in
      match fv with
      | Function name when is_defn defns name ->
          let defn = get_defn defns name in
          if List.length args = List.length defn.args then
            let fenv = List.combine defn.args vals |> Symtab.of_list in
            interp_exp defns fenv defn.body
          else raise (BadExpression exp)
      | _ -> raise (BadExpression exp) )
Looks good! The only problem is, when will we make one of these Function values in the first place? Nothing in our interpreter actually outputs a Function at this point.
Q: Given how we’ve used function values above, when do we need to produce a function value as the output of our interp_exp function?
When we use a function by name, it’s going to look a lot like using any other variable! But we won’t find it in our symbol table. So in addition to our original Var case (see first Var case below), we’ll want to add the second Var case in the snippet below.
let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value =
  match exp with
  (* some cases elided ... *)
  | Var var when Symtab.mem var env -> Symtab.find var env
  | Var var when is_defn defns var -> Function var
This gives us a case for situations where the name isn’t in the symbol table but is the name of one of the functions in our definition block. Once we know we’re in that case, we can just make a Function value with the original name of the function.
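The two Var cases and the new Call case can be exercised together on the first example program. The sketch below is a pared-down interpreter, not the course code: it assumes an association list in place of Symtab, a single Add constructor in place of Prim2, and a reduced value type.

```ocaml
(* Simplified, self-contained types; assumptions made for this sketch. *)
type expr =
  | Num of int
  | Var of string
  | Add of expr * expr
  | Call of expr * expr list

type defn = { name : string; args : string list; body : expr }

type value = Number of int | Function of string

exception BadExpression of expr

let is_defn defns name = List.exists (fun d -> d.name = name) defns
let get_defn defns name = List.find (fun d -> d.name = name) defns

let rec interp_exp defns env exp =
  match exp with
  | Num n -> Number n
  (* A name bound in the environment wins over a definition. *)
  | Var x when List.mem_assoc x env -> List.assoc x env
  (* Otherwise a definition's name evaluates to a function value. *)
  | Var x when is_defn defns x -> Function x
  | Var _ -> raise (BadExpression exp)
  | Add (l, r) -> (
      match (interp_exp defns env l, interp_exp defns env r) with
      | Number a, Number b -> Number (a + b)
      | _ -> raise (BadExpression exp))
  | Call (f, args) -> (
      let vals = List.map (interp_exp defns env) args in
      match interp_exp defns env f with
      | Function name when is_defn defns name ->
          let defn = get_defn defns name in
          if List.length vals = List.length defn.args then
            interp_exp defns (List.combine defn.args vals) defn.body
          else raise (BadExpression exp)
      | _ -> raise (BadExpression exp))

(* (define (f g) (g 2)) (define (mul2 x) (+ x x)) *)
let defns =
  [ { name = "f"; args = ["g"]; body = Call (Var "g", [Num 2]) }
  ; { name = "mul2"; args = ["x"]; body = Add (Var "x", Var "x") } ]

(* (f mul2) evaluates to Number 4 *)
let result = interp_exp defns [] (Call (Var "f", [Var "mul2"]))
```

Inside the body of f, the name g is found in the environment and evaluates to Function "mul2", so the call dispatches to mul2 even though that name never appears at the call site.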
Extending the compiler
Now we’re ready to extend the compiler! Let’s start by deciding how to represent a function value at run time. As always, we’ll need to figure out how to use a 64-bit slot to represent our value. Fortunately, we know that the addresses of the various instructions in our programs are 64 bits! Recall that we’re already adding labels immediately before the assembly for all of our functions; we use those labels when we Call or Jmp to the functions. We’ll go ahead and use the address of a function’s assembly code as our way to represent the function at run time. As with our other types, we’ll have to tag function values with a special tag that tells us what the type is.
Let’s start with the usual preliminaries. We’ll add a fn_tag and a helper function ensure_fn that emits assembly for checking if an operand is a function value.
let fn_tag = 0b110

let ensure_fn (op : operand) : directive list =
  [ Mov (Reg R8, op)
  ; And (Reg R8, Imm heap_mask)
  ; Cmp (Reg R8, Imm fn_tag)
  ; Jnz "error" ]
But what if the first instruction of a function is at an address that doesn’t end with 0b000? To make sure we have those last three bits available for tagging, we’ll have to make sure that doesn’t happen. We’ll go ahead and use the align directive to keep those last three bits clear.
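The arithmetic behind tagging and alignment can be checked with plain integers. In this sketch fn_tag matches the notes, while the 0b111 mask, the helper names, and the sample addresses are assumptions chosen for illustration.

```ocaml
(* Bit-level sketch of the function-value representation. *)
let fn_tag = 0b110
let tag_mask = 0b111 (* assumed value of the three-bit tag mask *)

(* Round an address up to the next multiple of 8, which is the effect that
   aligning to 8 has on the address of the label that follows. *)
let align8 addr = (addr + 7) land lnot 7

let tag_fn addr = addr lor fn_tag
let is_fn v = v land tag_mask = fn_tag
let untag_fn v = v - fn_tag

let label_addr = align8 0x401003 (* 0x401008 *)
let fn_value = tag_fn label_addr (* 0x40100e *)

(* With an unaligned address, the low bits collide with the tag and the
   original address cannot be recovered by subtracting the tag. *)
let broken = untag_fn (tag_fn 0x401003) (* 0x401001, not 0x401003 *)
```

Tagging and untagging round-trip only when the low three bits of the address start out as zero, which is the property the alignment guarantees.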
Before, compile_defn looked like this:
let compile_defn defns defn =
  let ftab =
    defn.args |> List.mapi (fun i arg -> (arg, -8 * (i + 1))) |> Symtab.of_list
  in
  [Label (defn_label defn.name)]
  @ compile_exp defns ftab (-8 * (List.length defn.args + 1)) defn.body true
  @ [Ret]
We’ll just change the line [Label (defn_label defn.name)] to [Align 8; Label (defn_label defn.name)].
Now it’s:
let compile_defn defns defn =
  let ftab =
    defn.args |> List.mapi (fun i arg -> (arg, -8 * (i + 1))) |> Symtab.of_list
  in
  [Align 8; Label (defn_label defn.name)]
  @ compile_exp defns ftab (-8 * (List.length defn.args + 1)) defn.body true
  @ [Ret]
And of course we need to make some function values in the first place!
As in the interpreter, we just want to add an additional Var case inside compile_exp:
let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* some cases elided ... *)
  | Var var when Symtab.mem var tab ->
      [Mov (Reg Rax, stack_address (Symtab.find var tab))]
  | Var var when is_defn defns var ->
      [ LeaLabel (Reg Rax, defn_label var)
      ; Or (Reg Rax, Imm fn_tag) ]
Notice what we’re doing in that second Var case. First we use defn_label var to get the string we use as the label for a function named var. Next we use LeaLabel (which you’ve seen on your homeworks) to get the address associated with that label into rax. Finally, we or it with the function tag, fn_tag, to get the final run-time representation of the function value.
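To see roughly what that case produces, here is a sketch with a miniature directive type and a printer. Everything in it is hypothetical: the types are cut-down imitations of the course’s assembly library, the defn_label scheme is invented, and the printed syntax is only one plausible NASM-style rendering.

```ocaml
(* Hypothetical miniature of the assembly types; not the real library. *)
type reg = Rax
type operand = Reg of reg | Imm of int
type directive = LeaLabel of operand * string | Or of operand * operand

let fn_tag = 0b110

(* Assumed label scheme; the real defn_label may mangle names differently. *)
let defn_label name = "function_" ^ name

(* The directives emitted when a definition's name is used as a value. *)
let compile_fn_ref name =
  [LeaLabel (Reg Rax, defn_label name); Or (Reg Rax, Imm fn_tag)]

let string_of_operand = function Reg Rax -> "rax" | Imm n -> string_of_int n

(* Assumed NASM-style rendering of each directive. *)
let string_of_directive = function
  | LeaLabel (dst, label) ->
      Printf.sprintf "lea %s, [%s]" (string_of_operand dst) label
  | Or (dst, src) ->
      Printf.sprintf "or %s, %s" (string_of_operand dst) (string_of_operand src)

(* asm = ["lea rax, [function_mul2]"; "or rax, 6"] *)
let asm = List.map string_of_directive (compile_fn_ref "mul2")
```

The point is only the shape of the output: one instruction to load the label’s address, one to set the tag bits.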
Now all that’s left is to actually use some of these function values in some function calls!
Recall that we have two cases for Call, one for when the call is in tail position and one for when it’s not. We’ll need to update both, but the updates will be similar.
First, we’ll want to remove the part of the when clause that checks if f appears in our defns. This is the same thing we did in the interpreter. (Remember, even though we’re still using the fixed set of definitions in our programs’ definitions block, we may be using them by a different name! For example, we may be using a definition originally called f, but we may have let-bound it to g and called g. So looking for g in our original definitions list won’t do us any good.)
Relatedly, we won’t get to do our check for whether the number of arguments matches at compile time anymore. Remember, at compile time, we just don’t know which definition we’re actually going to be calling. (Is this g thing the function originally called f? The function originally called mul2? Something else?) If we hadn’t already done error handling in prior class sessions, this would be a perfect time to learn how to emit assembly that checks for errors! Since we have already done error handling, we’ll go ahead and just remove the check that the argument count matches. If we were going to make the compiler properly, we’d definitely have to add the assembly that checks for this error at run time!
Now we’re ready to make our big change. For each of the Call cases, we’ll add something like this:
@ compile_exp defns tab <next stack cell we can use> f false
@ ensure_fn (Reg Rax)
@ [Sub (Reg Rax, Imm fn_tag)]
This snippet first calls compile_exp on the expr that represents our function. The emitted code will put the run-time representation of the function into rax. Next we check that the value in rax is actually a function via ensure_fn. Finally, we strip off the fn_tag so we can use the value in rax as an address. The only thing that will vary in the is_tail and not is_tail cases is what we use for <next stack cell we can use>.
Our last change is to Call in the not is_tail case and Jmp in the is_tail case. Previously, we used Call (defn_label f). This ran our OCaml code for transforming from function name f to the string label we use for function f. Now that we may not have the original name for f available in our call (remember again, we may be calling it g or h!), we’ll instead want to call whatever address we put into rax when we ran the assembly for evaluating f. So we’ll replace Call (defn_label f) with ComputedCall (Reg Rax). Likewise, in the is_tail case, we’ll replace Jmp (defn_label f) with ComputedJmp (Reg Rax).
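The run-time behavior of a computed call can be simulated in OCaml. This is a model, not generated code: the code_memory table, its addresses, and the OCaml functions standing in for compiled bodies are all assumptions, and each function takes a single integer argument to keep the sketch short.

```ocaml
let fn_tag = 0b110
let tag_mask = 0b111 (* assumed value of the three-bit tag mask *)

(* Aligned "addresses" mapped to the code that starts there. *)
let code_memory : (int, int -> int) Hashtbl.t = Hashtbl.create 4
let () = Hashtbl.add code_memory 0x1000 (fun x -> x + x) (* stands in for mul2 *)
let () = Hashtbl.add code_memory 0x1008 (fun x -> x + 1) (* a second function *)

(* Mirrors the emitted sequence: the ensure_fn check, the Sub that strips the
   tag, then a call through the resulting address. *)
let computed_call (v : int) (arg : int) : int =
  if v land tag_mask <> fn_tag then failwith "error"
  else
    let addr = v - fn_tag in
    (Hashtbl.find code_memory addr) arg

(* The caller only holds a tagged value; the name mul2 is not needed. *)
let g = 0x1000 lor fn_tag
let result = computed_call g 2
```

Because the target is read from a value rather than fixed by a label, the same call site can reach any function in the program, which is also why the argument-count check can no longer happen at compile time.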
Putting it all together, we get:
let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) when not is_tail ->
      let stack_base = align_stack_index (stack_index + 8) in
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_base - (8 * (i + 2))) arg false
               @ [Mov (stack_address (stack_base - (8 * (i + 2))), Reg Rax)])
        |> List.concat
      in
      compiled_args
      @ compile_exp defns tab
          (stack_base - (8 * (List.length args + 2)))
          f false
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]
      @ [ Add (Reg Rsp, Imm stack_base)
        ; ComputedCall (Reg Rax)
        ; Sub (Reg Rsp, Imm stack_base) ]
  | Call (f, args) when is_tail ->
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_index - (8 * i)) arg false
               @ [Mov (stack_address (stack_index - (8 * i)), Reg Rax)])
        |> List.concat
      in
      let moved_args =
        args
        |> List.mapi (fun i _ ->
               [ Mov (Reg R8, stack_address (stack_index - (8 * i)))
               ; Mov (stack_address ((i + 1) * -8), Reg R8) ])
        |> List.concat
      in
      compiled_args
      @ compile_exp defns tab (stack_index - (8 * List.length args)) f false
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]
      @ moved_args
      @ [ComputedJmp (Reg Rax)]