use std::collections::{HashMap, HashSet};
use std::fmt::{self, Display};
use std::rc::Rc;
use std::sync::LazyLock;

use assert_matches::assert_matches;
use common_macros::hash_map;
use thiserror::Error;

use crate::ast::*;
use crate::obj::prelude::*;
use crate::obj::BUILTINS;
use crate::token::TokenKind;
use crate::vm::*;

// Result type used by every visitor in this file (written as `Result<()>` by callers).
// NOTE(review): the generic parameters of this alias appear to be missing in this copy of the
// file (the trailing `>` is unbalanced); it is presumably generic over the success type with a
// boxed error, since `CompileError` values are converted with `.into()` — confirm against VCS.
pub type Result = std::result::Result>;

////////////////////////////////////////////////////////////////////////////////
// LineNumber visitor
////////////////////////////////////////////////////////////////////////////////

/// AST visitor that computes the `(start, end)` source-line range covered by a statement or
/// expression.
///
/// `start` is the first line recorded through `update_start` (later calls are ignored once
/// `lock_start` is set). `end` is overwritten by every `update_end` call, so it holds the last
/// line recorded during the traversal.
#[derive(Default)]
struct LineNumber {
    // set after the first `update_start` so later visits cannot move `start`
    lock_start: bool,
    // first line of the visited node
    start: usize,
    // last line recorded so far
    end: usize,
}

impl LineNumber {
    /// Records `start` only if no start line has been recorded yet.
    fn update_start(&mut self, start: usize) {
        if !self.lock_start {
            self.start = start;
            self.lock_start = true;
        }
    }

    /// Unconditionally records `end`; the last call wins.
    fn update_end(&mut self, end: usize) {
        self.end = end;
    }
}

// Every method of this visitor returns `Ok(())`, so the `.unwrap()` calls below never fire.
impl StmtVisitor for LineNumber {
    fn visit_expr_stmt(&mut self, stmt: &ExprStmt) -> Result<()> {
        stmt.expr.accept(self).unwrap();
        Ok(())
    }

    fn visit_assign_stmt(&mut self, stmt: &AssignStmt) -> Result<()> {
        // the assigned name is the first token of the statement
        self.update_start(stmt.lhs.line);
        stmt.rhs.accept(self).unwrap();
        Ok(())
    }

    fn visit_set_stmt(&mut self, stmt: &SetStmt) -> Result<()> {
        stmt.expr.accept(self).unwrap();
        stmt.rhs.accept(self).unwrap();
        Ok(())
    }

    fn visit_block_stmt(&mut self, stmt: &BlockStmt) -> Result<()> {
        // the braces delimit the block, so its inner statements are not visited
        self.update_start(stmt.lbrace.line);
        self.update_end(stmt.rbrace.line);
        Ok(())
    }

    fn visit_return_stmt(&mut self, stmt: &ReturnStmt) -> Result<()> {
        self.update_start(stmt.return_kw.line);
        // a bare `return` ends on its own keyword; a returned expression overrides this below
        self.update_end(stmt.return_kw.line);
        if let Some(expr) = stmt.expr.as_ref() {
            expr.accept(self).unwrap();
        }
        Ok(())
    }

    fn visit_if_stmt(&mut self, stmt: &IfStmt) -> Result<()> {
        self.update_start(stmt.if_kw.line);
        stmt.condition.accept(self).unwrap();
        stmt.then_branch.accept(self).unwrap();
        for stmt in &stmt.else_branch {
            stmt.accept(self).unwrap();
        }
        Ok(())
    }
}

impl ExprVisitor for LineNumber {
    fn visit_binary_expr(&mut self, expr: &BinaryExpr) -> Result<()> {
        expr.lhs.accept(self).unwrap();
        expr.rhs.accept(self).unwrap();
        Ok(())
    }

    fn visit_unary_expr(&mut self, expr: &UnaryExpr) -> Result<()> {
        // the operator precedes its operand
        self.update_start(expr.op.line);
        expr.expr.accept(self).unwrap();
        Ok(())
    }

    fn visit_call_expr(&mut self, expr: &CallExpr) -> Result<()> {
        expr.expr.accept(self).unwrap();
        // `rparen` closes the argument list, so the arguments themselves are not visited
        self.update_end(expr.rparen.line);
        Ok(())
    }

    fn visit_get_expr(&mut self, expr: &GetExpr) -> Result<()> {
        expr.expr.accept(self).unwrap();
        self.update_end(expr.name.line);
        Ok(())
    }

    fn visit_primary_expr(&mut self, expr: &PrimaryExpr) -> Result<()> {
        self.update_start(expr.token.line);
        self.update_end(expr.token.line);
        Ok(())
    }

    fn visit_function_expr(&mut self, expr: &FunctionExpr) -> Result<()> {
        // a function literal spans from its parameter list to its closing brace
        self.update_start(expr.lparen.line);
        self.update_end(expr.rbrace.line);
        Ok(())
    }
}

/// Returns the `(start, end)` line range covered by `expr`.
fn expr_line_number(expr: &dyn Expr) -> LineRange {
    let mut line_number = LineNumber::default();
    expr.accept(&mut line_number).unwrap();
    (line_number.start, line_number.end)
}

/// Returns the `(start, end)` line range covered by `stmt`.
fn stmt_line_number(stmt: &dyn Stmt) -> LineRange {
    let mut line_number = LineNumber::default();
    stmt.accept(&mut line_number).unwrap();
    (line_number.start, line_number.end)
}

////////////////////////////////////////////////////////////////////////////////
// LocalAssignCollector and LocalNameCollector
////////////////////////////////////////////////////////////////////////////////

// TODO - the two collectors below duplicate most of their traversal code; consider sharing it.
#[derive(Default)] struct LocalAssignCollector { names: HashSet, } impl LocalAssignCollector { fn collect(body: &Vec) -> HashSet { let mut collector = Self::default(); for stmt in body { stmt.accept(&mut collector).unwrap(); } collector.names } } impl StmtVisitor for LocalAssignCollector { fn visit_expr_stmt(&mut self, stmt: &ExprStmt) -> Result<()> { stmt.expr.accept(self)?; Ok(()) } fn visit_assign_stmt(&mut self, stmt: &AssignStmt) -> Result<()> { self.names.insert(stmt.lhs.text.to_string()); Ok(()) } fn visit_set_stmt(&mut self, stmt: &SetStmt) -> Result<()> { stmt.expr.accept(self)?; stmt.rhs.accept(self)?; Ok(()) } fn visit_block_stmt(&mut self, stmt: &BlockStmt) -> Result<()> { // we visit the block statement because even though it goes below the current "local" // scope, we're ultimately trying to get a list of ALL local names that are assigned to in // this scope. // TODO FIXME BUG this does create some weirdness, for example take this: // outer_function = () { // some_value = 1234 // inner_function = () { // { // # this is a local value because we're assigning to it // some_value = 5678 // } // # our local named "some_value" has gone out of scope, so hypothetically we // # should be using the "some_value" that was defined in the scope above us. // # however, since we're collecting local assignments in all blocks, this should // # error out as "unknown local 'some_value'" // println(some_value) // } // return inner_function // } // // Ideally, we would be checking nonlocals with every new scope layer, and every new block. // This is a pretty tough bug to solve with how things are set up right now. not sure how // we'll go about solving this one. 
for stmt in &stmt.stmts { stmt.accept(self)?; } Ok(()) } fn visit_return_stmt(&mut self, stmt: &ReturnStmt) -> Result<()> { if let Some(expr) = stmt.expr.as_ref() { expr.accept(self)?; } Ok(()) } fn visit_if_stmt(&mut self, stmt: &IfStmt) -> Result<()> { stmt.condition.accept(self)?; stmt.then_branch.accept(self)?; for stmt in &stmt.else_branch { stmt.accept(self)?; } Ok(()) } } impl ExprVisitor for LocalAssignCollector { fn visit_binary_expr(&mut self, expr: &BinaryExpr) -> Result<()> { expr.lhs.accept(self)?; expr.rhs.accept(self)?; Ok(()) } fn visit_unary_expr(&mut self, expr: &UnaryExpr) -> Result<()> { expr.expr.accept(self)?; Ok(()) } fn visit_call_expr(&mut self, expr: &CallExpr) -> Result<()> { expr.expr.accept(self)?; Ok(()) } fn visit_get_expr(&mut self, expr: &GetExpr) -> Result<()> { expr.expr.accept(self)?; Ok(()) } fn visit_primary_expr(&mut self, _expr: &PrimaryExpr) -> Result<()> { Ok(()) } fn visit_function_expr(&mut self, _expr: &FunctionExpr) -> Result<()> { // don't visit function expr, we're only collecting local assigns Ok(()) } } #[derive(Default)] struct LocalNameCollector { names: HashSet, } impl LocalNameCollector { fn collect(body: &Vec) -> HashSet { let mut collector = Self::default(); for stmt in body { stmt.accept(&mut collector).unwrap(); } collector.names } } impl StmtVisitor for LocalNameCollector { fn visit_expr_stmt(&mut self, stmt: &ExprStmt) -> Result<()> { stmt.expr.accept(self)?; Ok(()) } fn visit_assign_stmt(&mut self, stmt: &AssignStmt) -> Result<()> { stmt.rhs.accept(self)?; Ok(()) } fn visit_set_stmt(&mut self, stmt: &SetStmt) -> Result<()> { stmt.expr.accept(self)?; stmt.rhs.accept(self)?; Ok(()) } fn visit_block_stmt(&mut self, stmt: &BlockStmt) -> Result<()> { for stmt in &stmt.stmts { stmt.accept(self)?; } Ok(()) } fn visit_return_stmt(&mut self, stmt: &ReturnStmt) -> Result<()> { if let Some(expr) = stmt.expr.as_ref() { expr.accept(self)?; } Ok(()) } fn visit_if_stmt(&mut self, stmt: &IfStmt) -> Result<()> { 
stmt.condition.accept(self)?; stmt.then_branch.accept(self)?; for stmt in &stmt.else_branch { stmt.accept(self)?; } Ok(()) } } impl ExprVisitor for LocalNameCollector { fn visit_binary_expr(&mut self, expr: &BinaryExpr) -> Result<()> { expr.lhs.accept(self)?; expr.rhs.accept(self)?; Ok(()) } fn visit_unary_expr(&mut self, expr: &UnaryExpr) -> Result<()> { expr.expr.accept(self)?; Ok(()) } fn visit_call_expr(&mut self, expr: &CallExpr) -> Result<()> { expr.expr.accept(self)?; Ok(()) } fn visit_get_expr(&mut self, expr: &GetExpr) -> Result<()> { expr.expr.accept(self)?; Ok(()) } fn visit_primary_expr(&mut self, expr: &PrimaryExpr) -> Result<()> { if expr.token.kind == TokenKind::Name { self.names.insert(expr.token.text.to_string()); } Ok(()) } fn visit_function_expr(&mut self, _expr: &FunctionExpr) -> Result<()> { // don't visit function expr, we're only collecting local assigns Ok(()) } } //////////////////////////////////////////////////////////////////////////////// // Misc //////////////////////////////////////////////////////////////////////////////// fn unescape(s: &str) -> String { s.chars() .skip(1) .take(s.len() - 2) // first and last chars are guaranteed to be 1 byte long .collect::() .replace("\\n", "\n") .replace("\\r", "\r") .replace("\\t", "\t") .replace("\\\"", "\"") .replace("\\\'", "\'") .replace("\\\\", "\\") } //////////////////////////////////////////////////////////////////////////////// // Scope //////////////////////////////////////////////////////////////////////////////// #[derive(Debug, PartialEq)] enum ScopeKind { Local, Function, //Class, } #[derive(Debug)] struct Scope { kind: ScopeKind, scope: Vec, } impl Scope { pub fn new(kind: ScopeKind) -> Self { Self { kind, scope: Default::default(), } } } //////////////////////////////////////////////////////////////////////////////// // CompileError //////////////////////////////////////////////////////////////////////////////// #[derive(Error, Debug)] pub struct CompileError { pub line: Option, 
pub message: String, } impl Display for CompileError { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { if let Some(line) = &self.line { write!(fmt, "line {:?}: {}", line, self.message) } else { write!(fmt, "{}", self.message) } } } //////////////////////////////////////////////////////////////////////////////// // Compiler //////////////////////////////////////////////////////////////////////////////// #[derive(Debug)] pub struct Compiler { chunks: Vec, scopes: Vec, constants: Vec, globals: Vec, } impl Compiler { pub fn new() -> Self { Compiler { chunks: Default::default(), scopes: Default::default(), constants: Default::default(), globals: BUILTINS .with_borrow(|builtins| builtins.keys().map(ToString::to_string).collect()), } } fn chunk(&self) -> &Chunk { self.chunks.last().expect("no chunk") } fn chunk_mut(&mut self) -> &mut Chunk { self.chunks.last_mut().expect("no chunk") } fn scope(&self) -> &Scope { self.scopes.last().expect("no scope") } fn scope_mut(&mut self) -> &mut Scope { self.scopes.last_mut().expect("no scope") } fn is_global_scope(&self) -> bool { self.scopes.is_empty() } /// Compiles a body of code. /// /// This returns a tuple of `Chunk`, the constants table, and the list of globals. 
pub fn compile(mut self, body: &Vec) -> Result<(Chunk, Vec, Vec)> { self.chunks.push(Chunk::default()); for stmt in body { self.compile_stmt(stmt)?; } // add halt instruction with last line, if any let mut last_line = (0, 0); if let Some(last) = body.last() { last_line = stmt_line_number(last.as_ref()); } self.emit(last_line, Op::Halt); let chunk = self.chunks.pop().expect("no chunk"); Ok((chunk, self.constants, self.globals)) } fn compile_stmt(&mut self, stmt: &StmtP) -> Result<()> { stmt.accept(self) } fn compile_expr(&mut self, expr: &ExprP) -> Result<()> { expr.accept(self) } fn insert_constant(&mut self, constant: ObjP) -> Result { // simple interning - try to find a constant that is exactly equal to this one and just // return its value instead for (index, interned) in self.constants.iter().enumerate() { if constant.borrow().equals(&*interned.borrow()) { return Ok(index as ConstantId); } } let index = self.constants.len(); if index > (ConstantId::MAX as usize) { return Err(CompileError { line: None, message: format!("too many constants (maximum {})", ConstantId::MAX), } .into()); } // convert this to a pointer, upcast, and then re-GC self.constants.push(constant); Ok(index as ConstantId) } fn get_global(&self, name: &str) -> Option { self.globals .iter() .position(|global| global == &name) .map(|id| id as GlobalId) } fn insert_global(&mut self, name: &str) -> Result { if let Some(id) = self.get_global(name) { return Ok(id); } let index = self.globals.len(); if index > (GlobalId::MAX as usize) { return Err(CompileError { line: None, message: format!("too many globals (maximum {})", GlobalId::MAX), } .into()); } self.globals.push(name.to_string()); Ok(index as GlobalId) } /// Get a nonlocal binding to a variable. /// /// This will return how many stack frames up we should look for this nonlocal, the `Local` /// that defines this binding. 
fn get_nonlocal(&self, name: &str) -> Option<(FrameDepth, &Local)> { let mut is_local = true; let mut depth = 0; for scope in self.scopes.iter().rev() { if scope.kind == ScopeKind::Function { // no longer inside the local scope if is_local { is_local = false; continue; } // increase stack frame search depth += 1; } // skip local variables if is_local { continue; } // outside of the local scope, check if we hvae defined the sought-after name for local in &scope.scope { if local.name == name { return Some((depth, local)); } } } None } fn get_local(&self, name: &str) -> Option<&Local> { for scope in self.scopes.iter().rev() { for local in &scope.scope { if local.name == name { return Some(local); } } if scope.kind == ScopeKind::Function { break; } } None } fn insert_local(&mut self, name: String) -> Result<&Local> { let index = self.chunk().locals.len(); if index > (LocalIndex::MAX as usize) { return Err(CompileError { line: None, message: format!("too many locals (maximum: {})", LocalIndex::MAX), } .into()); } let mut local = Local { slot: 0, index: index as LocalIndex, name, }; // get the last allocated slot for scope in self.scopes.iter().rev() { if scope.scope.len() == 0 { if scope.kind == ScopeKind::Function { // don't go above the current function's scope (which was just determined to be // empty) break; } continue; } // get the last allocated slot and increment by one let last = &scope.scope.last().unwrap(); if last.slot == LocalSlot::MAX { return Err(CompileError { line: None, message: format!( "too many stack slots used by locals(maximum: {})", LocalSlot::MAX ), } .into()); } local.slot = last.slot + 1; break; } self.scope_mut().scope.push(local.clone()); self.chunk_mut().locals.push(local); Ok(self.scope().scope.last().unwrap()) } fn begin_scope(&mut self, kind: ScopeKind) { self.scopes.push(Scope::new(kind)); } fn end_scope(&mut self, line: LineRange) { let scope = self.scopes.pop().expect("no scope"); for _local in scope.scope { self.emit(line, Op::Pop); } 
} fn emit(&mut self, line: LineRange, op: Op) { let chunk = self.chunk_mut(); chunk.code.push(op); chunk.lines.push(line); } } impl StmtVisitor for Compiler { fn visit_expr_stmt(&mut self, stmt: &ExprStmt) -> Result<()> { self.compile_expr(&stmt.expr)?; self.emit(stmt_line_number(stmt), Op::Pop); Ok(()) } fn visit_assign_stmt(&mut self, stmt: &AssignStmt) -> Result<()> { let name = &stmt.lhs.text; if self.is_global_scope() { let global = self.insert_global(name)?; self.compile_expr(&stmt.rhs)?; self.emit(stmt_line_number(stmt), Op::SetGlobal(global)); } else { let mut declare = false; let local = if let Some(local) = self.get_local(name) { local } else { declare = true; self.insert_local(name.to_string())? } .clone(); // gotta clone so we can borrow self as mutable for compile_expr self.compile_expr(&stmt.rhs)?; if !declare { self.emit(stmt_line_number(stmt), Op::SetLocal(local.index)); } } // If the last value that was assigned to is a function, set its name here // TODO - maybe this would be smarter to set up in the AST. I'm 99% sure that the last // object created, if it were a function object, will be what we're assigning it to, but I // want to be 100% sure instead of 99%. 
let obj = self.constants.last().unwrap().as_ref(); if let Some(fun) = obj.borrow_mut().as_any_mut().downcast_mut::() { fun.set_name(Rc::new(name.to_string())); } Ok(()) } fn visit_set_stmt(&mut self, stmt: &SetStmt) -> Result<()> { self.compile_expr(&stmt.expr)?; let name = self.insert_constant(Str::create(&stmt.name.text))?; self.compile_expr(&stmt.rhs)?; self.emit(stmt_line_number(stmt), Op::SetAttr(name)); Ok(()) } fn visit_block_stmt(&mut self, stmt: &BlockStmt) -> Result<()> { self.begin_scope(ScopeKind::Local); for s in &stmt.stmts { self.compile_stmt(s)?; } self.end_scope((stmt.rbrace.line, stmt.rbrace.line)); Ok(()) } fn visit_return_stmt(&mut self, stmt: &ReturnStmt) -> Result<()> { if let Some(expr) = &stmt.expr { self.compile_expr(expr)?; } else { let nil = self.insert_constant(Nil::create())?; self.emit(stmt_line_number(stmt), Op::PushConstant(nil)); } self.emit(stmt_line_number(stmt), Op::Return); Ok(()) } fn visit_if_stmt(&mut self, stmt: &IfStmt) -> Result<()> { // condition self.compile_expr(&stmt.condition)?; // call obj.to_bool() let bool_attr = self.insert_constant(Str::create("to_bool"))?; self.emit(expr_line_number(&*stmt.condition), Op::GetAttr(bool_attr)); self.emit(expr_line_number(&*stmt.condition), Op::Call(0)); let condition_patch_index = self.chunk().code.len(); self.emit(expr_line_number(&*stmt.condition), Op::JumpFalse(0)); // then branch // pop the condition on top of the stack (no jump taken) self.emit(expr_line_number(&*stmt.condition), Op::Pop); // not using compile_stmt because then_branch isn't a pointer, it's an honest-to-goodness // value stmt.then_branch.accept(self)?; let exit_patch_index = self.chunk().code.len(); self.emit(stmt_line_number(&stmt.then_branch), Op::Jump(0)); // else branch // patch the condition index - this is where the JUMP_FALSE will jump to assert_matches!(self.chunk().code[condition_patch_index], Op::JumpFalse(_)); let offset = self.chunk().code.len() - condition_patch_index; assert!( offset <= 
(JumpOpArg::MAX as usize), "jump offset too large between lines {:?} - this is a compiler limitation, sorry", stmt_line_number(&stmt.then_branch) ); self.chunk_mut().code[condition_patch_index] = Op::JumpFalse(offset as JumpOpArg); // pop the condition on top of the stack (jump taken) self.emit(expr_line_number(&*stmt.condition), Op::Pop); for s in &stmt.else_branch { self.compile_stmt(s)?; } // patch the "then" branch exit jump address - this is where Op::Jump will jump to. // TODO : see if we can eliminate duplicates by checking the last two instructions assert_matches!(self.chunk().code[exit_patch_index], Op::Jump(_)); let offset = self.chunk().code.len() - condition_patch_index; assert!( offset <= (JumpOpArg::MAX as usize), "jump offset too large between lines {:?} - this is a compiler limitation, sorry", stmt_line_number(&stmt.then_branch) ); self.chunk_mut().code[exit_patch_index] = Op::Jump(offset as JumpOpArg); Ok(()) } } impl ExprVisitor for Compiler { fn visit_binary_expr(&mut self, expr: &BinaryExpr) -> Result<()> { static OP_NAMES: LazyLock> = LazyLock::new(|| { hash_map! 
{ TokenKind::Plus => "__add__", TokenKind::Minus => "__sub__", TokenKind::Star => "__mul__", TokenKind::Slash => "__div__", TokenKind::And => "__and__", TokenKind::Or => "__or__", TokenKind::BangEq => "__ne__", TokenKind::EqEq => "__eq__", TokenKind::Greater => "__gt__", TokenKind::GreaterEq => "__ge__", TokenKind::Less => "__lt__", TokenKind::LessEq => "__le__", } }); self.compile_expr(&expr.lhs)?; // short-circuit setup let mut exit_patch_index = 0; if let TokenKind::And | TokenKind::Or = expr.op.kind { let constant_id = self.insert_constant(Str::create("to_bool"))?; self.emit(expr_line_number(&*expr.lhs), Op::GetAttr(constant_id)); self.emit(expr_line_number(&*expr.lhs), Op::Call(0)); exit_patch_index = self.chunk().code.len(); if expr.op.kind == TokenKind::And { self.emit((expr.op.line, expr.op.line), Op::JumpFalse(0)); } else { self.emit((expr.op.line, expr.op.line), Op::JumpTrue(0)); } } let name = OP_NAMES .get(&expr.op.kind) .expect("invalid binary operator"); let constant_id = self.insert_constant(Str::create(name))?; self.emit(expr_line_number(expr), Op::GetAttr(constant_id)); self.compile_expr(&expr.rhs)?; // convert RHS to a bool if we're doing AND or OR if let TokenKind::And | TokenKind::Or = expr.op.kind { let constant_id = self.insert_constant(Str::create("to_bool"))?; self.emit(expr_line_number(&*expr.rhs), Op::GetAttr(constant_id)); self.emit(expr_line_number(&*expr.rhs), Op::Call(0)); } // call operator function self.emit(expr_line_number(expr), Op::Call(1)); // patch exit if we're doing a short circuit if exit_patch_index != 0 { assert_matches!( self.chunk().code[exit_patch_index], Op::JumpTrue(_) | Op::JumpFalse(_) ); let offset = self.chunk().code.len() - exit_patch_index; // don't worry about doing a check on if offset is small enough for JumpOpArg, if you // have 4 billion instructions between jumps that is probably your own fault let new_op = match self.chunk().code[exit_patch_index] { Op::JumpTrue(_) => Op::JumpTrue(offset as JumpOpArg), 
Op::JumpFalse(_) => Op::JumpFalse(offset as JumpOpArg), _ => unreachable!(), }; self.chunk_mut().code[exit_patch_index] = new_op; } Ok(()) } fn visit_unary_expr(&mut self, expr: &UnaryExpr) -> Result<()> { static OP_NAMES: LazyLock> = LazyLock::new(|| { hash_map! { TokenKind::Plus => "__pos__", TokenKind::Minus => "__neg__", TokenKind::Bang => "__not__", } }); self.compile_expr(&expr.expr)?; let name = OP_NAMES.get(&expr.op.kind).expect("invalid unary operator"); let constant_id = self.insert_constant(Str::create(name))?; self.emit(expr_line_number(expr), Op::GetAttr(constant_id)); self.emit(expr_line_number(expr), Op::Call(0)); Ok(()) } fn visit_call_expr(&mut self, expr: &CallExpr) -> Result<()> { self.compile_expr(&expr.expr)?; for arg in &expr.args { self.compile_expr(arg)?; } if expr.args.len() > (Argc::MAX as usize) { return Err(CompileError { line: Some(expr_line_number(expr)), message: format!("too many function arguments (maximum: {})", Argc::MAX), } .into()); } self.emit(expr_line_number(expr), Op::Call(expr.args.len() as Argc)); Ok(()) } fn visit_get_expr(&mut self, expr: &GetExpr) -> Result<()> { self.compile_expr(&expr.expr)?; let constant_id = self.insert_constant(Str::create(&expr.name.text))?; self.emit(expr_line_number(expr), Op::GetAttr(constant_id)); Ok(()) } fn visit_primary_expr(&mut self, expr: &PrimaryExpr) -> Result<()> { match expr.token.kind { TokenKind::Name => { let name = &expr.token.text; // check if there's a local with this name, otherwise check globals if let Some(local) = self.get_local(name) { self.emit(expr_line_number(expr), Op::GetLocal(local.index)); } else { let global = self.get_global(name).ok_or_else(|| CompileError { line: Some(expr_line_number(expr)), message: if self.is_global_scope() { format!("unknown global {}", name) } else { format!("unknown local {}", name) }, })?; self.emit(expr_line_number(expr), Op::GetGlobal(global)); } } TokenKind::Number => { let obj = if expr.token.text.contains('.') { 
Float::create(expr.token.text.parse().unwrap()) } else if expr.token.text.starts_with("0x") || expr.token.text.starts_with("0X") { Int::create(i64::from_str_radix(&expr.token.text[2..], 16).unwrap()) } else if expr.token.text.starts_with("0b") || expr.token.text.starts_with("0B") { Int::create(i64::from_str_radix(&expr.token.text[2..], 2).unwrap()) } else { Int::create(expr.token.text.parse().unwrap()) }; let constant_id = self.insert_constant(obj)?; self.emit(expr_line_number(expr), Op::PushConstant(constant_id)); } TokenKind::String => { let constant_id = self.insert_constant(Str::create(unescape(&expr.token.text)))?; self.emit(expr_line_number(expr), Op::PushConstant(constant_id)); } TokenKind::True | TokenKind::False => { let constant_id = self.insert_constant(Bool::create(expr.token.kind == TokenKind::True))?; self.emit(expr_line_number(expr), Op::PushConstant(constant_id)); } TokenKind::Nil => { let constant_id = self.insert_constant(Nil::create())?; self.emit(expr_line_number(expr), Op::PushConstant(constant_id)); } _ => unreachable!(), } Ok(()) } fn visit_function_expr(&mut self, expr: &FunctionExpr) -> Result<()> { let end_line = (expr.rbrace.line, expr.rbrace.line); self.begin_scope(ScopeKind::Function); self.chunks.push(Chunk::default()); let mut locals: HashSet = Default::default(); for (param, _ty) in &expr.params { // register all params as locals locals.insert(param.text.to_string()); // also insert them as locals in the scope self.insert_local(param.text.to_string())?; } // closures: figure out all other locals that are assigned to in the function for local in LocalAssignCollector::collect(&expr.body) { locals.insert(local); } // figure out all nonlocals being used, and then re-register them as locals // when a user function is called, all values of the nonlocal are pushed to the top of the // stack on top of the function parameters. 
let all_names = LocalNameCollector::collect(&expr.body); // these are the nonlocals that we're copying/re-registering as locals let mut captures: HashMap = Default::default(); let mut nonlocals: HashMap = Default::default(); for name in &all_names { // already registered as a local if locals.contains(name) { continue; } // already captured if captures.contains_key(name) { continue; } if let Some((depth, nonlocal)) = self.get_nonlocal(name) { let nonlocal = nonlocal.clone(); nonlocals.insert(name.to_string(), (depth, nonlocal)); captures.insert( name.to_string(), self.insert_local(name.to_string())?.clone(), ); } } // compile body for stmt in &expr.body { self.compile_stmt(stmt)?; } // always end with a "return nil" let nil = self.insert_constant(Nil::create())?; self.emit(end_line, Op::PushConstant(nil)); self.emit(end_line, Op::Return); self.end_scope(end_line); // create the function let chunk = self.chunks.pop().unwrap(); let fun = UserFunction::create(chunk, expr.params.len() as Argc); // register the function as a constant let fun_constant = self.insert_constant(fun)?; self.emit(expr_line_number(expr), Op::PushConstant(fun_constant)); // close over the captured values for (depth, local) in nonlocals.values() { self.emit( expr_line_number(expr), Op::CloseOver { depth: *depth, slot: local.slot, }, ); } Ok(()) } }