Function scopes (#1032)

This commit is contained in:
Pg Biel 2023-05-03 09:20:53 -03:00 committed by GitHub
parent db6a710638
commit f88ef45ee6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 383 additions and 52 deletions

View File

@ -88,6 +88,9 @@ pub fn panic(
/// Fails with an error if the condition is not fulfilled. Does not
/// produce any output in the document.
///
/// If you wish to test equality between two values, see
/// [`assert.eq`]($func/assert.eq) and [`assert.ne`]($func/assert.ne).
///
/// ## Example
/// ```typ
/// #assert(1 < 2, message: "math broke")
@ -97,6 +100,11 @@ pub fn panic(
/// Category: foundations
/// Returns:
#[func]
#[scope(
scope.define("eq", assert_eq);
scope.define("ne", assert_ne);
scope
)]
pub fn assert(
/// The condition that must be true for the assertion to pass.
condition: bool,
@ -115,6 +123,90 @@ pub fn assert(
Value::None
}
/// Ensure that two values are equal.
///
/// Fails with an error if the first value is not equal to the second. Does not
/// produce any output in the document.
///
/// ## Example
/// ```example
/// #assert.eq(10, 10)
/// ```
///
/// Display: Assert Equals
/// Category: foundations
/// Returns:
#[func]
pub fn assert_eq(
/// The first value to compare.
left: Value,
/// The second value to compare.
right: Value,
/// An optional message to display on error instead of the representations
/// of the compared values.
#[named]
#[default]
message: Option<EcoString>,
) -> Value {
if left != right {
if let Some(message) = message {
bail!(args.span, "equality assertion failed: {}", message);
} else {
bail!(
args.span,
"equality assertion failed: value {:?} was not equal to {:?}",
left,
right
);
}
}
Value::None
}
/// Ensure that two values are not equal.
///
/// Fails with an error if the first value is equal to the second. Does not
/// produce any output in the document.
///
/// ## Example
/// ```example
/// #assert.ne(3, 4)
/// ```
///
/// Display: Assert Not Equals
/// Category: foundations
/// Returns:
#[func]
pub fn assert_ne(
/// The first value to compare.
left: Value,
/// The second value to compare.
right: Value,
/// An optional message to display on error instead of the representations
/// of the compared values.
#[named]
#[default]
message: Option<EcoString>,
) -> Value {
if left == right {
if let Some(message) = message {
bail!(args.span, "inequality assertion failed: {}", message);
} else {
bail!(
args.span,
"inequality assertion failed: value {:?} was equal to {:?}",
left,
right
);
}
}
Value::None
}
/// Evaluate a string as Typst code.
///
/// This function should only be used as a last resort.

View File

@ -36,6 +36,17 @@ use super::GridLayouter;
/// + Don't forget step two
/// ```
///
/// You can also use [`enum.item`]($func/enum.item) to programmatically
/// customize the number of each item in the enumeration:
///
/// ```example
/// #enum(
/// enum.item(1)[First step],
/// enum.item(5)[Fifth step],
/// enum.item(10)[Tenth step]
/// )
/// ```
///
/// ## Syntax
/// This functions also has dedicated syntax:
///
@ -51,6 +62,10 @@ use super::GridLayouter;
/// Display: Numbered List
/// Category: layout
#[element(Layout)]
#[scope(
scope.define("item", EnumItem::func());
scope
)]
pub struct EnumElem {
/// If this is `{false}`, the items are spaced apart with
/// [enum spacing]($func/enum.spacing). If it is `{true}`, they use normal

View File

@ -15,6 +15,7 @@ struct Elem {
ident: Ident,
capable: Vec<Ident>,
fields: Vec<Field>,
scope: Option<BlockWithReturn>,
}
struct Field {
@ -28,7 +29,7 @@ struct Field {
synthesized: bool,
fold: bool,
resolve: bool,
parse: Option<FieldParser>,
parse: Option<BlockWithReturn>,
default: syn::Expr,
vis: syn::Visibility,
ident: Ident,
@ -50,21 +51,6 @@ impl Field {
}
}
struct FieldParser {
prefix: Vec<syn::Stmt>,
expr: syn::Stmt,
}
impl Parse for FieldParser {
fn parse(input: ParseStream) -> Result<Self> {
let mut stmts = syn::Block::parse_within(input)?;
let Some(expr) = stmts.pop() else {
return Err(input.error("expected at least on expression"));
};
Ok(Self { prefix: stmts, expr })
}
}
/// Preprocess the element's definition.
fn prepare(stream: TokenStream, body: &syn::ItemStruct) -> Result<Elem> {
let syn::Fields::Named(named) = &body.fields else {
@ -137,7 +123,8 @@ fn prepare(stream: TokenStream, body: &syn::ItemStruct) -> Result<Elem> {
.into_iter()
.collect();
let docs = documentation(&body.attrs);
let mut attrs = body.attrs.clone();
let docs = documentation(&attrs);
let mut lines = docs.split('\n').collect();
let category = meta_line(&mut lines, "Category")?.into();
let display = meta_line(&mut lines, "Display")?.into();
@ -152,9 +139,10 @@ fn prepare(stream: TokenStream, body: &syn::ItemStruct) -> Result<Elem> {
ident: body.ident.clone(),
capable,
fields,
scope: parse_attr(&mut attrs, "scope")?.flatten(),
};
validate_attrs(&body.attrs)?;
validate_attrs(&attrs)?;
Ok(element)
}
@ -351,6 +339,7 @@ fn create_pack_impl(element: &Elem) -> TokenStream {
.iter()
.filter(|field| !field.internal && !field.synthesized)
.map(create_param_info);
let scope = create_scope_builder(element.scope.as_ref());
quote! {
impl ::typst::model::Element for #ident {
fn pack(self) -> ::typst::model::Content {
@ -377,6 +366,7 @@ fn create_pack_impl(element: &Elem) -> TokenStream {
params: ::std::vec![#(#infos),*],
returns: ::std::vec!["content"],
category: #category,
scope: #scope,
}),
};
(&NATIVE).into()
@ -519,7 +509,7 @@ fn create_set_impl(element: &Elem) -> TokenStream {
/// Create argument parsing code for a field.
fn create_field_parser(field: &Field) -> (TokenStream, TokenStream) {
if let Some(FieldParser { prefix, expr }) = &field.parse {
if let Some(BlockWithReturn { prefix, expr }) = &field.parse {
return (quote! { #(#prefix);* }, quote! { #expr });
}

View File

@ -18,6 +18,7 @@ struct Func {
params: Vec<Param>,
returns: Vec<String>,
body: syn::Block,
scope: Option<BlockWithReturn>,
}
struct Param {
@ -72,7 +73,8 @@ fn prepare(item: &syn::ItemFn) -> Result<Func> {
validate_attrs(&attrs)?;
}
let docs = documentation(&item.attrs);
let mut attrs = item.attrs.clone();
let docs = documentation(&attrs);
let mut lines = docs.split('\n').collect();
let returns = meta_line(&mut lines, "Returns")?
.split(" or ")
@ -92,9 +94,10 @@ fn prepare(item: &syn::ItemFn) -> Result<Func> {
params,
returns,
body: (*item.block).clone(),
scope: parse_attr(&mut attrs, "scope")?.flatten(),
};
validate_attrs(&item.attrs)?;
validate_attrs(&attrs)?;
Ok(func)
}
@ -113,6 +116,7 @@ fn create(func: &Func) -> TokenStream {
} = func;
let handlers = params.iter().filter(|param| !param.external).map(create_param_parser);
let params = params.iter().map(create_param_info);
let scope = create_scope_builder(func.scope.as_ref());
quote! {
#[doc = #docs]
#vis fn #ident() -> &'static ::typst::eval::NativeFunc {
@ -129,6 +133,7 @@ fn create(func: &Func) -> TokenStream {
params: ::std::vec![#(#params),*],
returns: ::std::vec![#(#returns),*],
category: #category,
scope: #scope,
}),
};
&FUNC

View File

@ -18,6 +18,27 @@ macro_rules! bail {
};
}
/// For parsing attributes of the form:
/// #[attr(
/// statement;
/// statement;
/// returned_expression
/// )]
pub struct BlockWithReturn {
pub prefix: Vec<syn::Stmt>,
pub expr: syn::Stmt,
}
impl Parse for BlockWithReturn {
fn parse(input: ParseStream) -> Result<Self> {
let mut stmts = syn::Block::parse_within(input)?;
let Some(expr) = stmts.pop() else {
return Err(input.error("expected at least one expression"));
};
Ok(Self { prefix: stmts, expr })
}
}
/// Whether an attribute list has a specified attribute.
pub fn has_attr(attrs: &mut Vec<syn::Attribute>, target: &str) -> bool {
take_attr(attrs, target).is_some()
@ -88,3 +109,16 @@ pub fn meta_line<'a>(lines: &mut Vec<&'a str>, key: &str) -> Result<&'a str> {
None => bail!(callsite, "missing metadata key: {}", key),
}
}
/// Creates a block responsible for building a Scope.
pub fn create_scope_builder(scope_block: Option<&BlockWithReturn>) -> TokenStream {
if let Some(BlockWithReturn { prefix, expr }) = scope_block {
quote! { {
let mut scope = ::typst::eval::Scope::deduplicating();
#(#prefix);*
#expr
} }
} else {
quote! { ::typst::eval::Scope::new() }
}
}

View File

@ -5,12 +5,13 @@ use std::hash::{Hash, Hasher};
use std::sync::Arc;
use comemo::{Prehashed, Track, Tracked, TrackedMut};
use ecow::eco_format;
use once_cell::sync::Lazy;
use super::{
cast_to_value, Args, CastInfo, Eval, Flow, Route, Scope, Scopes, Tracer, Value, Vm,
};
use crate::diag::{bail, SourceResult};
use crate::diag::{bail, SourceResult, StrResult};
use crate::model::{ElemFunc, Introspector, StabilityProvider, Vt};
use crate::syntax::ast::{self, AstNode, Expr, Ident};
use crate::syntax::{SourceId, Span, SyntaxNode};
@ -144,6 +145,30 @@ impl Func {
_ => None,
}
}
/// Get a field from this function's scope, if possible.
pub fn get(&self, field: &str) -> StrResult<&Value> {
match &self.repr {
Repr::Native(func) => func.info.scope.get(field).ok_or_else(|| {
eco_format!(
"function `{}` does not contain field `{}`",
func.info.name,
field
)
}),
Repr::Elem(func) => func.info().scope.get(field).ok_or_else(|| {
eco_format!(
"function `{}` does not contain field `{}`",
func.name(),
field
)
}),
Repr::Closure(_) => {
Err(eco_format!("cannot access fields on user-defined functions"))
}
Repr::With(arc) => arc.0.get(field),
}
}
}
impl Debug for Func {
@ -225,6 +250,8 @@ pub struct FuncInfo {
pub returns: Vec<&'static str>,
/// Which category the function is part of.
pub category: &'static str,
/// The function's own scope of fields and sub-functions.
pub scope: Scope,
}
impl FuncInfo {

View File

@ -42,7 +42,7 @@ use std::mem;
use std::path::{Path, PathBuf};
use comemo::{Track, Tracked, TrackedMut};
use ecow::EcoVec;
use ecow::{EcoString, EcoVec};
use unicode_segmentation::UnicodeSegmentation;
use crate::diag::{
@ -1077,7 +1077,15 @@ impl Eval for ast::FuncCall {
if methods::is_mutating(&field) {
let args = args.eval(vm)?;
let target = target.access(vm)?;
if !matches!(target, Value::Symbol(_) | Value::Module(_)) {
// Prioritize a function's own methods (with, where) over its
// fields. This is fine as we define each field of a function,
// if it has any.
// ('methods_on' will be empty for Symbol and Module - their
// method calls always refer to their fields.)
if !matches!(target, Value::Symbol(_) | Value::Module(_) | Value::Func(_))
|| methods_on(target.type_name()).iter().any(|(m, _)| m == &field)
{
return methods::call_mut(target, &field, args, span).trace(
vm.world(),
point,
@ -1088,7 +1096,10 @@ impl Eval for ast::FuncCall {
} else {
let target = target.eval(vm)?;
let args = args.eval(vm)?;
if !matches!(target, Value::Symbol(_) | Value::Module(_)) {
if !matches!(target, Value::Symbol(_) | Value::Module(_) | Value::Func(_))
|| methods_on(target.type_name()).iter().any(|(m, _)| m == &field)
{
return methods::call(vm, target, &field, args, span).trace(
vm.world(),
point,
@ -1613,6 +1624,42 @@ impl Eval for ast::ForLoop {
}
}
/// Applies imports from `import` to the current scope.
fn apply_imports<V: Into<Value>>(
imports: Option<ast::Imports>,
vm: &mut Vm,
source_value: V,
name: impl Fn(&V) -> EcoString,
scope: impl Fn(&V) -> &Scope,
) -> SourceResult<()> {
match imports {
None => {
vm.scopes.top.define(name(&source_value), source_value);
}
Some(ast::Imports::Wildcard) => {
for (var, value) in scope(&source_value).iter() {
vm.scopes.top.define(var.clone(), value.clone());
}
}
Some(ast::Imports::Items(idents)) => {
let mut errors = vec![];
let scope = scope(&source_value);
for ident in idents {
if let Some(value) = scope.get(&ident) {
vm.define(ident, value.clone());
} else {
errors.push(error!(ident.span(), "unresolved import"));
}
}
if !errors.is_empty() {
return Err(Box::new(errors));
}
}
}
Ok(())
}
impl Eval for ast::ModuleImport {
type Output = Value;
@ -1620,30 +1667,26 @@ impl Eval for ast::ModuleImport {
fn eval(&self, vm: &mut Vm) -> SourceResult<Self::Output> {
let span = self.source().span();
let source = self.source().eval(vm)?;
let module = import(vm, source, span)?;
match self.imports() {
None => {
vm.scopes.top.define(module.name().clone(), module);
}
Some(ast::Imports::Wildcard) => {
for (var, value) in module.scope().iter() {
vm.scopes.top.define(var.clone(), value.clone());
}
}
Some(ast::Imports::Items(idents)) => {
let mut errors = vec![];
for ident in idents {
if let Some(value) = module.scope().get(&ident) {
vm.define(ident, value.clone());
} else {
errors.push(error!(ident.span(), "unresolved import"));
}
}
if !errors.is_empty() {
return Err(Box::new(errors));
}
if let Value::Func(func) = source {
if func.info().is_none() {
bail!(span, "cannot import from user-defined functions");
}
apply_imports(
self.imports(),
vm,
func,
|func| func.info().unwrap().name.into(),
|func| &func.info().unwrap().scope,
)?;
} else {
let module = import(vm, source, span, true)?;
apply_imports(
self.imports(),
vm,
module,
|module| module.name().clone(),
|module| module.scope(),
)?;
}
Ok(Value::None)
@ -1657,17 +1700,28 @@ impl Eval for ast::ModuleInclude {
fn eval(&self, vm: &mut Vm) -> SourceResult<Self::Output> {
let span = self.source().span();
let source = self.source().eval(vm)?;
let module = import(vm, source, span)?;
let module = import(vm, source, span, false)?;
Ok(module.content())
}
}
/// Process an import of a module relative to the current location.
fn import(vm: &mut Vm, source: Value, span: Span) -> SourceResult<Module> {
fn import(
vm: &mut Vm,
source: Value,
span: Span,
accept_functions: bool,
) -> SourceResult<Module> {
let path = match source {
Value::Str(path) => path,
Value::Module(module) => return Ok(module),
v => bail!(span, "expected path or module, found {}", v.type_name()),
v => {
if accept_functions {
bail!(span, "expected path, module or function, found {}", v.type_name())
} else {
bail!(span, "expected path or module, found {}", v.type_name())
}
}
};
// Load the source file.

View File

@ -127,6 +127,7 @@ impl Value {
Self::Dict(dict) => dict.at(field, None).cloned(),
Self::Content(content) => content.at(field, None),
Self::Module(module) => module.get(field).cloned(),
Self::Func(func) => func.get(field).cloned(),
v => Err(eco_format!("cannot access fields on type {}", v.type_name())),
}
}

View File

@ -388,6 +388,14 @@ fn field_access_completions(ctx: &mut CompletionContext, value: &Value) {
ctx.value_completion(Some(name.clone()), value, true, None);
}
}
Value::Func(func) => {
if let Some(info) = func.info() {
// Consider all names from the function's scope.
for (name, value) in info.scope.iter() {
ctx.value_completion(Some(name.clone()), value, true, None);
}
}
}
_ => {}
}
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.3 KiB

After

Width:  |  Height:  |  Size: 5.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 18 KiB

View File

@ -22,6 +22,30 @@
- B
- C
---
// Test fields on function scopes.
#enum.item
#assert.eq
#assert.ne
---
// Error: 9-16 function `assert` does not contain field `invalid`
#assert.invalid
---
// Error: 7-14 function `enum` does not contain field `invalid`
#enum.invalid
---
// Error: 7-14 function `enum` does not contain field `invalid`
#enum.invalid()
---
// Closures cannot have fields.
#let f(x) = x
// Error: 4-11 cannot access fields on user-defined functions
#f.invalid
---
// Error: 6-13 dictionary does not contain key "invalid" and no default value was specified
#(:).invalid

View File

@ -1,4 +1,4 @@
// Test module imports.
// Test function and module imports.
// Ref: false
---
@ -34,6 +34,20 @@
// It exists now!
#test(d, 3)
---
// Test importing from function scopes.
// Ref: true
#import enum: item
#import assert.with(true): *
#enum(
item(1)[First],
item(5)[Fifth]
)
#eq(10, 10)
#ne(5, 6)
---
// A module import without items.
#import "module.typ"
@ -59,6 +73,35 @@
// Allow the trailing comma.
#import "module.typ": a, c,
---
// Usual importing syntax also works for function scopes
#import enum
#let d = (e: enum)
#import d.e
#import d.e: item
#item(2)[a]
---
// Can't import from closures.
#let f(x) = x
// Error: 9-10 cannot import from user-defined functions
#import f: x
---
// Can't import from closures, despite modifiers.
#let f(x) = x
// Error: 9-18 cannot import from user-defined functions
#import f.with(5): x
---
// Error: 9-18 cannot import from user-defined functions
#import () => {5}: x
---
// Error: 9-10 expected path, module or function, found integer
#import 5: something
---
// Error: 9-11 failed to load file (is a directory)
#import "": name

View File

@ -40,6 +40,32 @@
// Error: 9-15 expected boolean, found string
#assert("true")
---
// Test failing assertions.
// Error: 11-19 equality assertion failed: value 10 was not equal to 11
#assert.eq(10, 11)
---
// Test failing assertions.
// Error: 11-55 equality assertion failed: 10 and 12 are not equal
#assert.eq(10, 12, message: "10 and 12 are not equal")
---
// Test failing assertions.
// Error: 11-19 inequality assertion failed: value 11 was equal to 11
#assert.ne(11, 11)
---
// Test failing assertions.
// Error: 11-57 inequality assertion failed: must be different from 11
#assert.ne(11, 11, message: "must be different from 11")
---
// Test successful assertions.
#assert(5 > 3)
#assert.eq(15, 15)
#assert.ne(10, 12)
---
// Test the `type` function.
#test(type(1), "integer")

View File

@ -35,6 +35,18 @@ Empty \
+Nope \
a + 0.
---
// Test item number overriding.
1. first
+ second
5. fifth
#enum(
enum.item(1)[First],
[Second],
enum.item(5)[Fifth]
)
---
// Alignment shouldn't affect number
#set align(horizon)