Apply type-directed coercions to arguments in calls of user functions (#6179)

Signed-off-by: Nick Cameron <nrc@ncameron.org>
This commit is contained in:
Nick Cameron
2025-04-07 18:02:46 +12:00
committed by GitHub
parent e6ae89ebf9
commit bc22d888ee
2 changed files with 102 additions and 55 deletions

View File

@ -2006,32 +2006,102 @@ fn assign_args_to_params(
Ok(())
}
fn assign_args_to_params_kw(
fn type_check_params_kw(
fn_name: Option<&str>,
function_expression: NodeRef<'_, FunctionExpression>,
mut args: crate::std::args::KwArgs,
args: &mut crate::std::args::KwArgs,
exec_state: &mut ExecState,
) -> Result<(), KclError> {
for (label, arg) in &args.labeled {
for (label, arg) in &mut args.labeled {
match function_expression.params.iter().find(|p| &p.identifier.name == label) {
Some(p) => {
if !p.labeled {
exec_state.err(CompilationError::err(
arg.source_range,
format!(
"This function expects an unlabeled first parameter (`{label}`), but it is labelled in the call"
"{} expects an unlabeled first parameter (`{label}`), but it is labelled in the call",
fn_name
.map(|n| format!("The function `{}`", n))
.unwrap_or_else(|| "This function".to_owned()),
),
));
}
if let Some(ty) = &p.type_ {
arg.value = arg
.value
.coerce(
&RuntimeType::from_parsed(ty.inner.clone(), exec_state, arg.source_range).unwrap(),
exec_state,
)
.map_err(|e| {
let mut message = format!(
"{label} requires a value with type `{}`, but found {}",
ty.inner,
arg.value.human_friendly_type(),
);
if let Some(ty) = e.explicit_coercion {
// TODO if we have access to the AST for the argument we could choose which example to suggest.
message = format!("{message}\n\nYou may need to add information about the type of the argument, for example:\n using a numeric suffix: `42{ty}`\n or using type ascription: `foo(): number({ty})`");
}
KclError::Semantic(KclErrorDetails {
message,
source_ranges: vec![arg.source_range],
})
})?;
}
}
None => {
exec_state.err(CompilationError::err(
arg.source_range,
format!("`{label}` is not an argument of this function"),
format!(
"`{label}` is not an argument of {}",
fn_name
.map(|n| format!("`{}`", n))
.unwrap_or_else(|| "this function".to_owned()),
),
));
}
}
}
if let Some(arg) = &mut args.unlabeled {
if let Some(p) = function_expression.params.iter().find(|p| !p.labeled) {
if let Some(ty) = &p.type_ {
arg.value = arg
.value
.coerce(
&RuntimeType::from_parsed(ty.inner.clone(), exec_state, arg.source_range).unwrap(),
exec_state,
)
.map_err(|_| {
KclError::Semantic(KclErrorDetails {
message: format!(
"The input argument of {} requires a value with type `{}`, but found {}",
fn_name
.map(|n| format!("`{}`", n))
.unwrap_or_else(|| "this function".to_owned()),
ty.inner,
arg.value.human_friendly_type()
),
source_ranges: vec![arg.source_range],
})
})?;
}
}
}
Ok(())
}
fn assign_args_to_params_kw(
fn_name: Option<&str>,
function_expression: NodeRef<'_, FunctionExpression>,
mut args: crate::std::args::KwArgs,
exec_state: &mut ExecState,
) -> Result<(), KclError> {
type_check_params_kw(fn_name, function_expression, &mut args, exec_state)?;
// Add the arguments to the memory. A new call frame should have already
// been created.
let source_ranges = vec![function_expression.into()];
@ -2118,6 +2188,7 @@ async fn call_user_defined_function(
}
async fn call_user_defined_function_kw(
fn_name: Option<&str>,
args: crate::std::args::KwArgs,
memory: EnvironmentRef,
function_expression: NodeRef<'_, FunctionExpression>,
@ -2128,7 +2199,7 @@ async fn call_user_defined_function_kw(
// variables shadow variables in the parent scope. The new environment's
// parent should be the environment of the closure.
exec_state.mut_stack().push_new_env_for_call(memory);
if let Err(e) = assign_args_to_params_kw(function_expression, args, exec_state) {
if let Err(e) = assign_args_to_params_kw(fn_name, function_expression, args, exec_state) {
exec_state.mut_stack().pop_env();
return Err(e);
}
@ -2206,52 +2277,7 @@ impl FunctionSource {
));
}
for (label, arg) in &mut args.kw_args.labeled {
match ast.params.iter().find(|p| &p.identifier.name == label) {
Some(p) => {
if !p.labeled {
exec_state.err(CompilationError::err(
arg.source_range,
format!(
"The function `{}` expects an unlabeled first parameter (`{label}`), but it is labelled in the call",
props.name
),
));
}
if let Some(ty) = &p.type_ {
arg.value = arg
.value
.coerce(
&RuntimeType::from_parsed(ty.inner.clone(), exec_state, arg.source_range)
.unwrap(),
exec_state,
)
.map_err(|e| {
let mut message = format!(
"{label} requires a value with type `{}`, but found {}",
ty.inner,
arg.value.human_friendly_type(),
);
if let Some(ty) = e.explicit_coercion {
// TODO if we have access to the AST for the argument we could choose which example to suggest.
message = format!("{message}\n\nYou may need to add information about the type of the argument, for example:\n using a numeric suffix: `42{ty}`\n or using type ascription: `foo(): number({ty})`");
}
KclError::Semantic(KclErrorDetails {
message,
source_ranges: vec![callsite],
})
})?;
}
}
None => {
exec_state.err(CompilationError::err(
arg.source_range,
format!("`{label}` is not an argument of `{}`", props.name),
));
}
}
}
type_check_params_kw(Some(&props.name), ast, &mut args.kw_args, exec_state)?;
if let Some(arg) = &mut args.kw_args.unlabeled {
if let Some(p) = ast.params.iter().find(|p| !p.labeled) {
@ -2331,7 +2357,7 @@ impl FunctionSource {
.collect();
exec_state.global.operations.push(Operation::GroupBegin {
group: Group::FunctionCall {
name: fn_name,
name: fn_name.clone(),
function_source_range: ast.as_source_range(),
unlabeled_arg: args
.kw_args
@ -2344,7 +2370,7 @@ impl FunctionSource {
});
}
call_user_defined_function_kw(args.kw_args, *memory, ast, exec_state, ctx).await
call_user_defined_function_kw(fn_name.as_deref(), args.kw_args, *memory, ast, exec_state, ctx).await
}
FunctionSource::None => unreachable!(),
}
@ -2570,4 +2596,25 @@ a = foo()
let result = parse_execute(program).await;
assert!(result.unwrap_err().to_string().contains("return"));
}
#[tokio::test(flavor = "multi_thread")]
async fn user_coercion() {
let program = r#"fn foo(x: Axis2d) {
return 0
}
foo(x = { direction = [0, 0], origin = [0, 0]})
"#;
parse_execute(program).await.unwrap();
let program = r#"fn foo(x: Axis3d) {
return 0
}
foo(x = { direction = [0, 0], origin = [0, 0]})
"#;
parse_execute(program).await.unwrap_err();
}
}