Important: These notes are not designed to substitute for class sessions. They do not cover all content required for the exams or for implementing HWs. These are just an aid for re-implementing the class compiler, after you’ve already come to class and completed the in-class activities!

First-class functions

Most functional programming languages (really, most modern programming languages of all stripes) support some form of first-class functions. Right now, functions in our language are totally separate from values. We can define them and call them, but we can’t treat them as values: for example, we can’t pass a function as an argument to another function.

Here are some programs, written in a slightly extended version of our language, that use first-class functions:

(define (f g) (g 2))
(define (mul2 x) (+ x x))
(print (f mul2))



(define (f g) (g 2))
(print (f (lambda (x) (+ x x))))


(define (f g) (g 2))
(let ((y 3)) (print (f (lambda (x) (+ x y)))))

Today we’ll add support for the feature you see in the first program: we’ll be able to pass functions around like other values. Next time we’ll add the features we need to support the other two programs. Notice what’s different about programs two and three. In the second, we’re making an anonymous function, (lambda (x) (+ x x)). In the third, we’re making an anonymous function that references a variable y that’s let-bound outside of the function. We could jump all the way to supporting all three of these programs, but this is actually pretty advanced stuff! Most undergrad compilers classes never get around to supporting language features like first-class functions. So we’ll break it down into stages.

Q: Before you scroll down, take a moment to think about this question. Will we need to change our AST in order to support this change to our language? If yes, how?

Extending the AST

Now that we have an AST for our language, we’ll need to extend it in order to support calling function values that aren’t literal function names. The change is pretty minor.

Before, the AST constructor for calls looked like this: Call of string * expr list

Now we’ll use the version that appears below:

type expr =
  | Prim0 of prim0
  | Prim1 of prim1 * expr
  | Prim2 of prim2 * expr * expr
  | Let of string * expr * expr
  | If of expr * expr * expr
  | Do of expr list
  | Num of int
  | Var of string
  | Call of expr * expr list
  | True
  | False

We’ll also want to make a change to our helper function that translates S-expression ASTs into ASTs for our language. Before the call case looked like this:

let rec expr_of_s_exp : s_exp -> expr = function
  (* some cases elided ... *)
  | Lst (Sym f :: args) ->
       Call (f, List.map expr_of_s_exp args)

Now we’ll want it to look like this:

let rec expr_of_s_exp : s_exp -> expr = function
  (* some cases elided ... *)
  | Lst (f :: args) ->
       Call (expr_of_s_exp f, List.map expr_of_s_exp args)

Notice that f is no longer required to be a Sym; it can be any s_exp. So we also have to make a recursive call to expr_of_s_exp in order to get the appropriate representation of f into our AST.
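To see the effect of the new case in isolation, here is a minimal sketch with cut-down stand-ins for the class types. The s_exp number constructor is renamed SNum (a hypothetical name) only so that it does not clash with the expr constructor Num in a single file, and the error for an empty list is an assumption, not something the class parser necessarily does.

```ocaml
(* Cut-down stand-ins for the class types; SNum is a hypothetical rename. *)
type s_exp = SNum of int | Sym of string | Lst of s_exp list

type expr =
  | Num of int
  | Var of string
  | Call of expr * expr list

let rec expr_of_s_exp : s_exp -> expr = function
  | SNum n -> Num n
  | Sym x -> Var x
  (* The head of a call is now translated like any other expression. *)
  | Lst (f :: args) -> Call (expr_of_s_exp f, List.map expr_of_s_exp args)
  | Lst [] -> failwith "empty application"

(* (f mul2): the head is a plain symbol, so it becomes a Var. *)
let simple = expr_of_s_exp (Lst [Sym "f"; Sym "mul2"])

(* ((pick 0) 2): the head is itself a call, which the old parser rejected. *)
let nested = expr_of_s_exp (Lst [Lst [Sym "pick"; SNum 0]; SNum 2])
```

One thing to keep in mind: since Lst (f :: args) now matches any non-empty list, in the full parser this case presumably has to come after the cases for special forms such as let and if.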

Extending the interpreter

Now we’re ready to extend the interpreter. Consider our current value type:

type value =
  | Number of int
  | Boolean of bool
  | Pair of (value * value)

Q: What will we need to add, if anything, to our value type in order to support first-class functions?

You could consider a couple of options. We could add Function of string list * expr, and then our function values could carry around all the same information that our defn type carries around. Any time we encounter a given function value in a call, we could go ahead and do the same thing we do when we’ve looked a function up (by name) in our list of defns. (Evaluate the arguments, make a new symbol table based on those arguments, then evaluate the body with the new symbol table.) We’d just be skipping over the part where we look it up in defns.

We could add Function of (value list -> value). We’ve had great luck in the past with building on top of OCaml’s own features when we want to implement features in the interpreter. Why not use OCaml functions to implement functions in our language?
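As a rough illustration of this second option (not the one these notes end up using), here is a small sketch in which function values wrap OCaml functions. The helper apply and the hand-written f and mul2 values are hypothetical; a real interpreter would build such values from definitions rather than by hand.

```ocaml
type value =
  | Number of int
  | Boolean of bool
  | Function of (value list -> value)

(* (define (mul2 x) (+ x x)), written directly as an OCaml function. *)
let mul2 : value =
  Function
    (function
    | [Number x] -> Number (x + x)
    | _ -> failwith "mul2: expected one number")

(* Calling a value: anything that is not a Function is an error. *)
let apply (f : value) (args : value list) : value =
  match f with
  | Function g -> g args
  | _ -> failwith "not a function"

(* (define (f g) (g 2)) *)
let f : value =
  Function
    (function
    | [g] -> apply g [Number 2]
    | _ -> failwith "f: expected one argument")

(* (f mul2) *)
let result = apply f [mul2]
```

Here result is Number 4. The appeal is that calling a function value is just an OCaml function application.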

We’re going to pick an alternative that will make the structure of the interpreter match a bit more closely with the structure of the compiler. But that doesn’t mean those other options are bad! We’re just picking one of many reasonable ways to implement this. We’ll add Function of string.

Q: How is this going to be enough information for us to implement first-class functions?

This is going to be enough information because we’ll just use this string to look up the appropriate definition in our list of defns! Even though we’re no longer demanding a literal function name in the first position in every Call S-expression, we haven’t introduced anonymous functions yet. So we can go ahead and evaluate the S-expression in the first position, figure out what name literal we originally gave the function in our definitions block, and that will be enough information for us to retrieve the definition when we need to use it.

Here’s how it’ll look in context:

type value =
  | Number of int
  | Boolean of bool
  | Pair of (value * value)
  | Function of string

Let’s also make sure we can print out function values:

let rec string_of_value (v : value) : string =
  match v with
  (* some cases elided ... *)
  | Function _ ->
      "<function>"

Let’s go ahead and think about how our Call case should change.

Here’s where we’re starting:

let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value
    =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) when is_defn defns f ->
      let defn = get_defn defns f in
      if List.length args = List.length defn.args then
        let vals = args |> List.map (interp_exp defns env) in
        let fenv = List.combine defn.args vals |> Symtab.of_list in
        interp_exp defns fenv defn.body
      else raise (BadExpression exp)
  | Call _ ->
      raise (BadExpression exp)

Looks like there are a couple problems for our current purposes. First, now that f might be something other than the function’s original name, we can’t just check if it shows up in our defns list via is_defn. We’d better evaluate it first. Then we can check if it shows up in defns. And of course, once we’ve gone ahead and evaluated f, we should make sure it’s actually a function value. Once we make these changes, we end up here:

let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value
    =
  match exp with
  (* some cases elided ... *)
  | Call (f, args) -> (
      let vals = args |> List.map (interp_exp defns env) in
      let fv = interp_exp defns env f in
      match fv with
      | Function name when is_defn defns name ->
          let defn = get_defn defns name in
          if List.length args = List.length defn.args then
            let fenv = List.combine defn.args vals |> Symtab.of_list in
            interp_exp defns fenv defn.body
          else raise (BadExpression exp)
      | _ ->
          raise (BadExpression exp) )

Looks good! The only problem is, when will we make one of these Function values in the first place? Nothing in our interpreter actually outputs a Function at this point.

Q: Given how we’ve used function values above, when do we need to produce a function value as the output of our interp_exp function?

When we use a function by name, it’s going to look a lot like using any other variable! But we won’t find it in our symbol table. So in addition to our original Var case (see first Var case below), we’ll want to add the second Var case in the snippet below.

let rec interp_exp (defns : defn list) (env : value symtab) (exp : expr) : value
    =
  match exp with
  (* some cases elided ... *)
  | Var var when Symtab.mem var env ->
      Symtab.find var env
  | Var var when is_defn defns var ->
      Function var

This gives us a case for situations where the name isn’t in the symbol table but is the name of one of the functions in our definition block. Once we know we’re in that case, we can just make a Function value with the original name of the function.
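To check that the Var and Call cases fit together, here is a self-contained miniature of the interpreter running the first example program. It is only a sketch under simplifying assumptions: an association list stands in for Symtab, a hypothetical Plus constructor stands in for the Prim2 addition case, and most of the language is left out.

```ocaml
(* Simplified stand-ins for the class types. *)
type expr =
  | Num of int
  | Var of string
  | Plus of expr * expr (* hypothetical stand-in for Prim2 addition *)
  | Call of expr * expr list

type defn = { name : string; args : string list; body : expr }

type value = Number of int | Function of string

exception BadExpression of expr

let is_defn defns name = List.exists (fun d -> d.name = name) defns
let get_defn defns name = List.find (fun d -> d.name = name) defns

let rec interp_exp (defns : defn list) (env : (string * value) list)
    (exp : expr) : value =
  match exp with
  | Num n -> Number n
  (* Ordinary variables are looked up in the environment first ... *)
  | Var x when List.mem_assoc x env -> List.assoc x env
  (* ... and otherwise a defined function's name becomes a Function value. *)
  | Var x when is_defn defns x -> Function x
  | Var _ -> raise (BadExpression exp)
  | Plus (e1, e2) -> (
      match (interp_exp defns env e1, interp_exp defns env e2) with
      | Number a, Number b -> Number (a + b)
      | _ -> raise (BadExpression exp))
  | Call (f, args) -> (
      let vals = List.map (interp_exp defns env) args in
      match interp_exp defns env f with
      | Function name when is_defn defns name ->
          let defn = get_defn defns name in
          if List.length vals = List.length defn.args then
            interp_exp defns (List.combine defn.args vals) defn.body
          else raise (BadExpression exp)
      | _ -> raise (BadExpression exp))

(* (define (f g) (g 2)) and (define (mul2 x) (+ x x)) *)
let defns =
  [ { name = "f"; args = ["g"]; body = Call (Var "g", [Num 2]) }
  ; { name = "mul2"; args = ["x"]; body = Plus (Var "x", Var "x") } ]

(* (f mul2) *)
let result = interp_exp defns [] (Call (Var "f", [Var "mul2"]))
```

Here result is Number 4. Inside the body of f, the variable g is found in the environment bound to Function "mul2", and the Call case then uses that name to retrieve the definition. Because the environment is checked first, a local variable with the same name as a definition shadows it.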

Extending the compiler

Now we’re ready to extend the compiler! Let’s start by deciding how to represent a function value at run time. As always, we’ll need to figure out how to use a 64-bit slot to represent our value. Fortunately, we know that the addresses of the various instructions in our programs are 64 bits! Recall that we’re already adding labels immediately before the assembly for all of our functions; we use those labels when we Call or Jmp to the functions. We’ll go ahead and use the address of a function’s assembly code as our way to represent the function at run time. As with our other types, we’ll have to tag function values with a special tag that tells us what the type is.

Let’s start with the usual preliminaries. We’ll add a fn_tag and a helper function ensure_fn that emits assembly for checking if an operand is a function value.

let fn_tag = 0b110

let ensure_fn (op : operand) : directive list =
  [ Mov (Reg R8, op)
  ; And (Reg R8, Imm heap_mask)
  ; Cmp (Reg R8, Imm fn_tag)
  ; Jnz "error" ]

But what if the first instruction of a function is at an address that doesn’t end with 0b000? To make sure we have those last three bits available for tagging, we’ll have to make sure that doesn’t happen. We’ll use the assembler’s align directive (Align 8 in our directive type) so that every function label lands on an address that is a multiple of 8, which keeps those last three bits clear.
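The arithmetic behind this can be sketched with plain OCaml integers. The helper align_up and the example addresses are made up for illustration; the real rounding is done by the assembler, not by our OCaml code.

```ocaml
let fn_tag = 0b110

(* Round [addr] up to the next multiple of [n], where [n] is a power of two.
   This models where a label ends up when it follows an 8-byte alignment. *)
let align_up (n : int) (addr : int) : int = (addr + n - 1) land lnot (n - 1)

(* A made-up address that is not a multiple of 8 (low bits 0b011). *)
let unaligned = 0x401003

(* 0x401008: the low three bits are now 0b000. *)
let aligned = align_up 8 unaligned

(* Tagging the aligned address leaves exactly fn_tag in the low bits ... *)
let good = aligned lor fn_tag

(* ... while tagging the unaligned one mixes address bits into the tag. *)
let bad = unaligned lor fn_tag
```

With these numbers, good land 0b111 is 0b110, so the tag can be read back and removed. For bad, the low bits come out as 0b111, which is neither the tag nor the original low bits, so the address could not be recovered.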

Before, compile_defn looked like this:

let compile_defn defns defn =
  let ftab =
    defn.args |> List.mapi (fun i arg -> (arg, -8 * (i + 1))) |> Symtab.of_list
  in
  [Label (defn_label defn.name)]
  @ compile_exp defns ftab (-8 * (List.length defn.args + 1)) defn.body true
  @ [Ret]

We’ll just change that third to last line from [Label (defn_label defn.name)] to [Align 8; Label (defn_label defn.name)].

Now it’s:

let compile_defn defns defn =
  let ftab =
    defn.args |> List.mapi (fun i arg -> (arg, -8 * (i + 1))) |> Symtab.of_list
  in
  [Align 8; Label (defn_label defn.name)]
  @ compile_exp defns ftab (-8 * (List.length defn.args + 1)) defn.body true
  @ [Ret]

And of course we need to make some function values in the first place!

As in the interpreter, we just want to add an additional Var case inside compile_exp:

let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* some cases elided ... *)

  | Var var when Symtab.mem var tab ->
      [Mov (Reg Rax, stack_address (Symtab.find var tab))]
  | Var var when is_defn defns var ->
      [ LeaLabel (Reg Rax, defn_label var)
      ; Or (Reg Rax, Imm fn_tag) ]

Notice what we’re doing in that second Var case. First we use defn_label var to get the string we use as the label for a function named var. Next we use LeaLabel (which you’ve seen on your homeworks) to get the address associated with that label into rax. Finally, we or it with the function tag, fn_tag, to get the final run-time representation of the function value.

Now all that’s left is to actually use some of these function values in some function calls!

Recall that we have two cases for Call, one for when the call is in tail position and one for when it’s not. We’ll need to update both, but the updates will be similar.

First, we’ll want to remove the part of the when clause that checks if f appears in our defns. This is the same thing we did in the interpreter. (Remember, even though we’re still using the fixed set of definitions in our programs’ definitions block, we may be using them by a different name! For example, we may be using a definition originally called f, but we may have let bound it to g and called g. So looking for g in our original definitions list won’t do us any good.)

Relatedly, we won’t get to do our check for whether the number of arguments matches at compile time anymore. Remember, at compile time, we just don’t know which definition we’re actually going to be calling. (Is this g thing the function originally called f? The function originally called mul2? Something else?) If we hadn’t already done error handling in prior class sessions, this would be a perfect time to learn how to emit assembly that checks for errors! Since we have already done error handling, we’ll go ahead and just remove the checks for argument number matches. If we were going to make the compiler properly, we’d definitely have to add the assembly that checks for errors at run time!

Now we’re ready to make our big change. For each of the Call cases, we’ll add something like this:

@ compile_exp defns tab <next stack cell we can use> f false
@ ensure_fn (Reg Rax)
@ [Sub (Reg Rax, Imm fn_tag)]

This snippet first calls compile_exp on the expr that represents our function. The emitted code will put the run-time representation of the function into rax. Next we check that the value in rax is actually a function via ensure_fn. Finally, we strip off the fn_tag so we can use the value in rax as an address. The only thing that will vary in the is_tail and not is_tail cases is what we use for <next stack cell we can use>.
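Here is a small model, in ordinary OCaml integers, of what the emitted check-and-strip sequence computes at run time. It assumes heap_mask is 0b111 (matching the three tag bits discussed above) and uses an assumed pair_tag of 0b010 purely as an example of a non-function value; the names tag_fn and untag_fn are hypothetical.

```ocaml
let heap_mask = 0b111 (* assumed: selects the low three tag bits *)
let fn_tag = 0b110
let pair_tag = 0b010 (* assumed value, for illustration only *)

(* What LeaLabel followed by Or produces for an 8-byte-aligned address. *)
let tag_fn (addr : int) : int = addr lor fn_tag

(* What ensure_fn followed by Sub computes: None models the jump to the
   error label, Some addr is the raw address left in rax. *)
let untag_fn (v : int) : int option =
  if v land heap_mask = fn_tag then Some (v - fn_tag) else None

let code_addr = 0x401008 (* made-up aligned code address *)

let round_trip = untag_fn (tag_fn code_addr)

let not_a_function = untag_fn (code_addr lor pair_tag)
```

Here round_trip is Some 0x401008 and not_a_function is None. Subtracting the tag is safe only because the check has already established that the low three bits are exactly fn_tag.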

Our last change is to Call in the not is_tail case and Jmp in the is_tail case. Previously, we used Call (defn_label f). This ran our OCaml code for transforming from function name f to the string label we use for function f. Now that we may not have the original name for f available in our call (remember again, we may be calling it g or h!), we’ll instead want to call whatever address we put into rax when we ran the assembly for evaluating f. So we’ll replace Call (defn_label f) with ComputedCall (Reg Rax). Likewise, in the is_tail case, we’ll replace Jmp (defn_label f) with ComputedJmp (Reg Rax).

Putting it all together, we get:

let rec compile_exp (defns : defn list) (tab : int symtab) (stack_index : int)
    (exp : expr) (is_tail : bool) : directive list =
  match exp with
  (* some cases elided ... *)

  | Call (f, args) when not is_tail ->
      let stack_base = align_stack_index (stack_index + 8) in
      (* evaluate each argument and store it where the callee expects it *)
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_base - (8 * (i + 2))) arg false
               @ [Mov (stack_address (stack_base - (8 * (i + 2))), Reg Rax)])
        |> List.concat
      in
      compiled_args
      (* evaluate f in the first stack cell past the stored arguments *)
      @ compile_exp defns tab
          (stack_base - (8 * (List.length args + 2)))
          f false
      (* check the tag, then strip it to leave a raw code address in rax *)
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]
      @ [ Add (Reg Rsp, Imm stack_base)
        ; ComputedCall (Reg Rax)
        ; Sub (Reg Rsp, Imm stack_base) ]

  | Call (f, args) when is_tail ->
      (* evaluate each argument into a temporary stack cell *)
      let compiled_args =
        args
        |> List.mapi (fun i arg ->
               compile_exp defns tab (stack_index - (8 * i)) arg false
               @ [Mov (stack_address (stack_index - (8 * i)), Reg Rax)])
        |> List.concat
      in
      (* copy the temporaries over the current frame's argument slots *)
      let moved_args =
        args
        |> List.mapi (fun i _ ->
               [ Mov (Reg R8, stack_address (stack_index - (8 * i)))
               ; Mov (stack_address ((i + 1) * -8), Reg R8) ])
        |> List.concat
      in
      compiled_args
      (* evaluate f in the first stack cell past the temporaries *)
      @ compile_exp defns tab (stack_index - (8 * List.length args)) f false
      @ ensure_fn (Reg Rax)
      @ [Sub (Reg Rax, Imm fn_tag)]
      @ moved_args @ [ComputedJmp (Reg Rax)]