diff --git a/src/model/dict.rs b/src/model/dict.rs index a5dbeeae2..e94db9232 100644 --- a/src/model/dict.rs +++ b/src/model/dict.rs @@ -54,11 +54,10 @@ impl Dict { } /// Mutably borrow the value the given `key` maps to. - /// - /// This inserts the key with [`None`](Value::None) as the value if not - /// present so far. - pub fn at_mut(&mut self, key: Str) -> &mut Value { - Arc::make_mut(&mut self.0).entry(key).or_default() + pub fn at_mut(&mut self, key: &str) -> StrResult<&mut Value> { + Arc::make_mut(&mut self.0) + .get_mut(key) + .ok_or_else(|| missing_key(key)) } /// Remove the value if the dictionary contains the given key. diff --git a/src/model/eval.rs b/src/model/eval.rs index 66ff2cd53..67c733cef 100644 --- a/src/model/eval.rs +++ b/src/model/eval.rs @@ -877,6 +877,18 @@ impl ast::Binary { op: fn(Value, Value) -> StrResult, ) -> SourceResult { let rhs = self.rhs().eval(vm)?; + let lhs = self.lhs(); + + // An assignment to a dictionary field is different from a normal access + // since it can create the field instead of just modifying it. + if self.op() == ast::BinOp::Assign { + if let ast::Expr::FieldAccess(access) = &lhs { + let dict = access.access_dict(vm)?; + dict.insert(access.field().take().into(), rhs); + return Ok(Value::None); + } + } + let location = self.lhs().access(vm)?; let lhs = std::mem::take(&mut *location); *location = op(lhs, rhs).at(self.span())?; @@ -898,14 +910,7 @@ impl Eval for ast::FieldAccess { .field(&field) .ok_or_else(|| format!("unknown field `{field}`")) .at(span)?, - Value::Module(module) => module - .scope() - .get(&field) - .cloned() - .ok_or_else(|| { - format!("module `{}` does not contain `{field}`", module.name()) - }) - .at(span)?, + Value::Module(module) => module.get(&field).cloned().at(span)?, v => bail!( self.target().span(), "expected dictionary or content, found {}", @@ -921,7 +926,8 @@ impl Eval for ast::FuncCall { fn eval(&self, vm: &mut Vm) -> SourceResult { let callee = self.callee(); let callee = callee.eval(vm)?.cast::().at(callee.span())?; - self.eval_with_callee(vm, callee) + let args = self.args().eval(vm)?; + Self::eval_call(vm, &callee, args, self.span()) } } @@ -929,7 +935,8 @@ impl ast::FuncCall { fn eval_in_math(&self, vm: &mut Vm) -> SourceResult { let callee = self.callee().eval(vm)?; if let Value::Func(callee) = callee { - Ok(self.eval_with_callee(vm, callee)?.display_in_math()) + let args = self.args().eval(vm)?; + Ok(Self::eval_call(vm, &callee, args, self.span())?.display_in_math()) } else { let mut body = (vm.items.math_atom)('('.into()); let mut args = self.args().eval(vm)?; @@ -944,14 +951,18 @@ impl ast::FuncCall { } } - fn eval_with_callee(&self, vm: &mut Vm, callee: Func) -> SourceResult { + fn eval_call( + vm: &mut Vm, + callee: &Func, + args: Args, + span: Span, + ) -> SourceResult { if vm.depth >= MAX_CALL_DEPTH { - bail!(self.span(), "maximum function call depth exceeded"); + bail!(span, "maximum function call depth exceeded"); } - let args = self.args().eval(vm)?; let point = || Tracepoint::Call(callee.name().map(Into::into)); - callee.call(vm, args).trace(vm.world, point, self.span()) + callee.call(vm, args).trace(vm.world, point, span) } } @@ -960,18 +971,35 @@ impl Eval for ast::MethodCall { fn eval(&self, vm: &mut Vm) -> SourceResult { let span = self.span(); - let method = self.method().take(); + let method = self.method(); let result = if methods::is_mutating(&method) { let args = self.args().eval(vm)?; let value = self.target().access(vm)?; + + if let Value::Module(module) = &value { + if let Value::Func(callee) = + module.get(&method).cloned().at(method.span())? + { + return ast::FuncCall::eval_call(vm, &callee, args, self.span()); + } + } + methods::call_mut(value, &method, args, span) } else { let value = self.target().eval(vm)?; let args = self.args().eval(vm)?; + + if let Value::Module(module) = &value { + if let Value::Func(callee) = module.get(&method).at(method.span())? { + return ast::FuncCall::eval_call(vm, callee, args, self.span()); + } + } + methods::call(vm, value, &method, args, span) }; + let method = method.take(); let point = || Tracepoint::Call(Some(method.clone())); result.trace(vm.world, point, span) } @@ -1423,16 +1451,20 @@ impl Access for ast::Parenthesized { impl Access for ast::FieldAccess { fn access<'a>(&self, vm: &'a mut Vm) -> SourceResult<&'a mut Value> { - let value = self.target().access(vm)?; - let Value::Dict(dict) = value else { - bail!( + self.access_dict(vm)?.at_mut(&self.field().take()).at(self.span()) + } +} + +impl ast::FieldAccess { + fn access_dict<'a>(&self, vm: &'a mut Vm) -> SourceResult<&'a mut Dict> { + match self.target().access(vm)? { + Value::Dict(dict) => Ok(dict), + value => bail!( self.target().span(), "expected dictionary, found {}", value.type_name(), - ); - }; - - Ok(dict.at_mut(self.field().take().into())) + ), + } } } diff --git a/src/model/methods.rs b/src/model/methods.rs index 3823df415..f80839f8e 100644 --- a/src/model/methods.rs +++ b/src/model/methods.rs @@ -153,6 +153,7 @@ pub fn call_mut( }, Value::Dict(dict) => match method { + "insert" => dict.insert(args.expect::("key")?, args.expect("value")?), "remove" => { output = dict.remove(&args.expect::("key")?).at(span)? } @@ -184,7 +185,7 @@ pub fn call_access<'a>( _ => return missing(), }, Value::Dict(dict) => match method { - "at" => dict.at_mut(args.expect("index")?), + "at" => dict.at_mut(&args.expect::("key")?).at(span)?, _ => return missing(), }, _ => return missing(), diff --git a/src/model/module.rs b/src/model/module.rs index c7cb7fafe..ba6c76fb0 100644 --- a/src/model/module.rs +++ b/src/model/module.rs @@ -2,8 +2,9 @@ use std::fmt::{self, Debug, Formatter}; use std::path::Path; use std::sync::Arc; -use super::{Content, Scope}; -use crate::util::EcoString; +use super::{Content, Scope, Value}; +use crate::diag::StrResult; +use crate::util::{format_eco, EcoString}; /// An evaluated module, ready for importing or typesetting. #[derive(Clone, Hash)] @@ -46,6 +47,13 @@ impl Module { &self.0.scope } + /// Try to access a definition in the module. + pub fn get(&self, name: &str) -> StrResult<&Value> { + self.scope().get(&name).ok_or_else(|| { + format_eco!("module `{}` does not contain `{name}`", self.name()) + }) + } + /// Extract the module's content. pub fn content(self) -> Content { match Arc::try_unwrap(self.0) { diff --git a/tests/typ/compiler/dict.typ b/tests/typ/compiler/dict.typ index f9a0b3695..5f00ef3ae 100644 --- a/tests/typ/compiler/dict.typ +++ b/tests/typ/compiler/dict.typ @@ -36,11 +36,11 @@ } --- -// Missing lvalue is automatically none-initialized. +// Missing lvalue is not automatically none-initialized. { let dict = (:) - dict.at("b") += 1 - test(dict, (b: 1)) + // Error: 3-9 dictionary does not contain key "b" + dict.b += 1 } --- diff --git a/tests/typ/compiler/import.typ b/tests/typ/compiler/import.typ index 6b2d80750..a657be503 100644 --- a/tests/typ/compiler/import.typ +++ b/tests/typ/compiler/import.typ @@ -38,7 +38,19 @@ // A module import without items. #import "module.typ" #test(module.b, 1) -#test((module.item)(1, 2), 3) +#test(module.item(1, 2), 3) +#test(module.push(2), 3) + +--- +// Edge case for module access that isn't fixed. +#import "module.typ" + +// Works because the method name isn't categorized as mutating. +#test((module,).at(0).item(1, 2), 3) + +// Doesn't work because of mutating name. +// Error: 2-11 cannot mutate a temporary value +{(module,).at(0).push()} --- // Who needs whitespace anyways? diff --git a/tests/typ/compiler/module.typ b/tests/typ/compiler/module.typ index b0a3fbf34..f06526779 100644 --- a/tests/typ/compiler/module.typ +++ b/tests/typ/compiler/module.typ @@ -7,6 +7,7 @@ #let d = 3 #let value = [hi] #let item(a, b) = a + b +#let push(a) = a + 1 #let fn = rect.with(fill: conifer, inset: 5pt) Some _includable_ text.