diff --git a/proxmox-api-macro/src/api/method.rs b/proxmox-api-macro/src/api/method.rs index 4577b2ca..fd961cae 100644 --- a/proxmox-api-macro/src/api/method.rs +++ b/proxmox-api-macro/src/api/method.rs @@ -109,11 +109,28 @@ impl TryFrom for ReturnSchema { } } +#[derive(Clone, Copy)] +enum MethodFlavor { + Normal, + Serializing, + Streaming, +} + +struct MethodInfo { + input_schema: Schema, + return_type: Option, + func: syn::ItemFn, + wrapper_ts: TokenStream, + default_consts: TokenStream, + flavor: MethodFlavor, + is_async: bool, +} + /// Parse `input`, `returns` and `protected` attributes out of an function annotated /// with an `#[api]` attribute and produce a `const ApiMethod` named after the function. /// /// See the top level macro documentation for a complete example. -pub fn handle_method(mut attribs: JSONObject, mut func: syn::ItemFn) -> Result { +pub fn handle_method(mut attribs: JSONObject, func: syn::ItemFn) -> Result { let input_schema: Schema = match attribs.remove("input") { Some(input) => input.into_object("input schema definition")?.try_into()?, None => Schema { @@ -124,7 +141,7 @@ pub fn handle_method(mut attribs: JSONObject, mut func: syn::ItemFn) -> Result Result = attribs + let return_type: Option = attribs .remove("returns") .map(|ret| ret.try_into()) .transpose()?; + /* FIXME: Once the deprecation period is over: + if let Some(streaming) = attribs.remove("streaming") { + error!( + streaming.span(), + "streaming attribute was renamed to 'serializing', as it did not actually stream" + ); + } + */ + + let streaming: Option = attribs + .remove("streaming") + .map(TryFrom::try_from) + .transpose()?; + let serializing: Option = attribs + .remove("serializing") + .map(TryFrom::try_from) + .transpose()?; + let deprecation_warning = if let Some(streaming) = streaming.clone() { + let deprecation_name = Ident::new( + &format!("attribute_in_{}", func.sig.ident), + streaming.span(), + ); + quote! { + mod #deprecation_name { + #[deprecated = "'streaming' attribute is being renamed to 'serializing'"] + fn streaming() {} + fn trigger_deprecation_warning() { streaming() } + } + } + } else { + TokenStream::new() + }; + let serializing = streaming + .or(serializing) + .unwrap_or(syn::LitBool::new(false, Span::call_site())); + let streaming: syn::LitBool = attribs + .remove("stream") + .map(TryFrom::try_from) + .transpose()? + .unwrap_or(syn::LitBool::new(false, Span::call_site())); + + let mut method_info = MethodInfo { + input_schema, + return_type, + wrapper_ts: TokenStream::new(), + default_consts: TokenStream::new(), + is_async: func.sig.asyncness.is_some(), + flavor: match (serializing.value(), streaming.value()) { + (false, false) => MethodFlavor::Normal, + (true, false) => MethodFlavor::Serializing, + (false, true) => MethodFlavor::Streaming, + (true, true) => { + error!(serializing => "'stream' and 'serializing' attributes are in conflict"); + MethodFlavor::Normal + } + }, + func, + }; + let access_setter = match attribs.remove("access") { Some(access) => { let access = Access::try_from(access.into_object("access rules")?)?; @@ -169,19 +245,6 @@ pub fn handle_method(mut attribs: JSONObject, mut func: syn::ItemFn) -> Result Result Result quote! { ::proxmox_router::ApiHandler::SerializingAsync(&#api_func_name) }, - (true, false) => quote! { ::proxmox_router::ApiHandler::SerializingSync(&#api_func_name) }, - (false, true) => quote! { ::proxmox_router::ApiHandler::Async(&#api_func_name) }, - (false, false) => quote! { ::proxmox_router::ApiHandler::Sync(&#api_func_name) }, + let api_handler = match (flavor, is_async) { + (MethodFlavor::Normal, true) => { + quote! { ::proxmox_router::ApiHandler::Async(&#api_func_name) } + } + (MethodFlavor::Normal, false) => { + quote! { ::proxmox_router::ApiHandler::Sync(&#api_func_name) } + } + (MethodFlavor::Serializing, true) => { + quote! { ::proxmox_router::ApiHandler::SerializingAsync(&#api_func_name) } + } + (MethodFlavor::Serializing, false) => { + quote! { ::proxmox_router::ApiHandler::SerializingSync(&#api_func_name) } + } + (MethodFlavor::Streaming, true) => { + quote! { ::proxmox_router::ApiHandler::StreamAsync(&#api_func_name) } + } + (MethodFlavor::Streaming, false) => { + quote! { ::proxmox_router::ApiHandler::StreamSync(&#api_func_name) } + } }; Ok(quote_spanned! { func.sig.span() => @@ -256,20 +336,22 @@ pub fn handle_method(mut attribs: JSONObject, mut func: syn::ItemFn) -> Result { +enum ParameterType { Value, ApiMethod, RpcEnv, - Normal(NormalParameter<'a>), + Normal(NormalParameter), } -struct NormalParameter<'a> { - ty: &'a syn::Type, - entry: &'a ObjectEntry, +struct NormalParameter { + ty: syn::Type, + entry: ObjectEntry, } fn check_input_type(input: &syn::FnArg) -> Result<(&syn::PatType, &syn::PatIdent), syn::Error> { @@ -288,16 +370,8 @@ fn check_input_type(input: &syn::FnArg) -> Result<(&syn::PatType, &syn::PatIdent Ok((pat_type, pat)) } -fn handle_function_signature( - input_schema: &mut Schema, - _return_type: &mut Option, - func: &mut syn::ItemFn, - wrapper_ts: &mut TokenStream, - default_consts: &mut TokenStream, - serializing: bool, -) -> Result { - let sig = &func.sig; - let is_async = sig.asyncness.is_some(); +fn handle_function_signature(method_info: &mut MethodInfo) -> Result { + let sig = &method_info.func.sig; let mut api_method_param = None; let mut rpc_env_param = None; @@ -315,7 +389,10 @@ fn handle_function_signature( }; // For any named type which exists on the function signature... - if let Some(entry) = input_schema.find_obj_property_by_ident_mut(&pat.ident.to_string()) { + if let Some(entry) = method_info + .input_schema + .find_obj_property_by_ident_mut(&pat.ident.to_string()) + { // try to infer the type in the schema if it is not specified explicitly: let is_option = util::infer_type(&mut entry.schema, &pat_type.ty)?; let has_default = entry.schema.find_schema_property("default").is_some(); @@ -363,75 +440,49 @@ fn handle_function_signature( // bail out with an error. let pat_ident = pat.ident.unraw(); let mut param_name: FieldName = pat_ident.clone().into(); - let param_type = - if let Some(entry) = input_schema.find_obj_property_by_ident(&pat_ident.to_string()) { - if let SchemaItem::Inferred(span) = &entry.schema.item { - bail!(*span, "failed to infer type"); - } - param_name = entry.name.clone(); - // Found an explicit parameter: extract it: - ParameterType::Normal(NormalParameter { - ty: &pat_type.ty, - entry, - }) - } else if is_api_method_type(&pat_type.ty) { - if api_method_param.is_some() { - error!(pat_type => "multiple ApiMethod parameters found"); - continue; - } - api_method_param = Some(param_list.len()); - ParameterType::ApiMethod - } else if is_rpc_env_type(&pat_type.ty) { - if rpc_env_param.is_some() { - error!(pat_type => "multiple RpcEnvironment parameters found"); - continue; - } - rpc_env_param = Some(param_list.len()); - ParameterType::RpcEnv - } else if is_value_type(&pat_type.ty) { - if value_param.is_some() { - error!(pat_type => "multiple additional Value parameters found"); - continue; - } - value_param = Some(param_list.len()); - ParameterType::Value - } else { - error!(&pat_ident => "unexpected parameter {:?}", pat_ident.to_string()); + let param_type = if let Some(entry) = method_info + .input_schema + .find_obj_property_by_ident(&pat_ident.to_string()) + { + if let SchemaItem::Inferred(span) = &entry.schema.item { + bail!(*span, "failed to infer type"); + } + param_name = entry.name.clone(); + // Found an explicit parameter: extract it: + ParameterType::Normal(NormalParameter { + ty: (*pat_type.ty).clone(), + entry: entry.clone(), + }) + } else if is_api_method_type(&pat_type.ty) { + if api_method_param.is_some() { + error!(pat_type => "multiple ApiMethod parameters found"); continue; - }; + } + api_method_param = Some(param_list.len()); + ParameterType::ApiMethod + } else if is_rpc_env_type(&pat_type.ty) { + if rpc_env_param.is_some() { + error!(pat_type => "multiple RpcEnvironment parameters found"); + continue; + } + rpc_env_param = Some(param_list.len()); + ParameterType::RpcEnv + } else if is_value_type(&pat_type.ty) { + if value_param.is_some() { + error!(pat_type => "multiple additional Value parameters found"); + continue; + } + value_param = Some(param_list.len()); + ParameterType::Value + } else { + error!(&pat_ident => "unexpected parameter {:?}", pat_ident.to_string()); + continue; + }; param_list.push((param_name, param_type)); } - /* - * Doing this is actually unreliable, since we cannot support aliased Result types, or all - * poassible combinations of paths like `result::Result<>` or `std::result::Result<>` or - * `ApiResult`. - - // Secondly, take a look at the return type, and then decide what to do: - // If our function has the correct signature we may not even need a wrapper. - if is_default_return_type(&sig.output) - && ( - param_list.len(), - value_param, - api_method_param, - rpc_env_param, - ) == (3, Some(0), Some(1), Some(2)) - { - return Ok(sig.ident.clone()); - } - */ - - create_wrapper_function( - //input_schema, - //return_type, - param_list, - func, - wrapper_ts, - default_consts, - is_async, - serializing, - ) + create_wrapper_function(method_info, param_list) } fn is_api_method_type(ty: &syn::Type) -> bool { @@ -481,24 +532,18 @@ fn is_value_type(ty: &syn::Type) -> bool { } fn create_wrapper_function( - //_input_schema: &Schema, - //_returns_schema: &Option, + method_info: &mut MethodInfo, param_list: Vec<(FieldName, ParameterType)>, - func: &syn::ItemFn, - wrapper_ts: &mut TokenStream, - default_consts: &mut TokenStream, - is_async: bool, - serializing: bool, ) -> Result { let api_func_name = Ident::new( - &format!("api_function_{}", &func.sig.ident), - func.sig.ident.span(), + &format!("api_function_{}", &method_info.func.sig.ident), + method_info.func.sig.ident.span(), ); let mut body = TokenStream::new(); let mut args = TokenStream::new(); - let func_uc = func.sig.ident.to_string().to_uppercase(); + let func_uc = method_info.func.sig.ident.to_string().to_uppercase(); for (name, param) in param_list { let span = name.span(); @@ -514,69 +559,71 @@ fn create_wrapper_function( &func_uc, name, span, - default_consts, + &mut method_info.default_consts, )?; } } } // build the wrapping function: - let func_name = &func.sig.ident; + let func_name = &method_info.func.sig.ident; - let await_keyword = if is_async { Some(quote!(.await)) } else { None }; + let await_keyword = if method_info.is_async { + Some(quote!(.await)) + } else { + None + }; - let question_mark = match func.sig.output { + let question_mark = match method_info.func.sig.output { syn::ReturnType::Default => None, _ => Some(quote!(?)), }; - let body = if serializing { - quote! { - if let ::serde_json::Value::Object(ref mut input_map) = &mut input_params { - #body - let res = #func_name(#args) #await_keyword #question_mark; - let res: ::std::boxed::Box = ::std::boxed::Box::new(res); - Ok(res) - } else { - ::anyhow::bail!("api function wrapper called with a non-object json value"); + let body = match method_info.flavor { + MethodFlavor::Normal => { + quote! { + if let ::serde_json::Value::Object(ref mut input_map) = &mut input_params { + #body + Ok(::serde_json::to_value(#func_name(#args) #await_keyword #question_mark)?) + } else { + ::anyhow::bail!("api function wrapper called with a non-object json value"); + } } } - } else { - quote! { - if let ::serde_json::Value::Object(ref mut input_map) = &mut input_params { - #body - Ok(::serde_json::to_value(#func_name(#args) #await_keyword #question_mark)?) + MethodFlavor::Serializing => { + quote! { + if let ::serde_json::Value::Object(ref mut input_map) = &mut input_params { + #body + let res = #func_name(#args) #await_keyword #question_mark; + let res: ::std::boxed::Box = ::std::boxed::Box::new(res); + Ok(res) + } else { + ::anyhow::bail!("api function wrapper called with a non-object json value"); + } + } + } + MethodFlavor::Streaming => { + let ty = if method_info.is_async { + quote! { ::proxmox_router::Stream } } else { - ::anyhow::bail!("api function wrapper called with a non-object json value"); + quote! { ::proxmox_router::SyncStream } + }; + quote! { + if let ::serde_json::Value::Object(ref mut input_map) = &mut input_params { + #body + let res = #func_name(#args) #await_keyword #question_mark; + let res = #ty::from(res); + Ok(res) + } else { + ::anyhow::bail!("api function wrapper called with a non-object json value"); + } } } }; - match (serializing, is_async) { - (true, true) => { - wrapper_ts.extend(quote! { - fn #api_func_name<'a>( - mut input_params: ::serde_json::Value, - api_method_param: &'static ::proxmox_router::ApiMethod, - rpc_env_param: &'a mut dyn ::proxmox_router::RpcEnvironment, - ) -> ::proxmox_router::SerializingApiFuture<'a> { - ::std::boxed::Box::pin(async move { #body }) - } - }); - } - (true, false) => { - wrapper_ts.extend(quote! { - fn #api_func_name( - mut input_params: ::serde_json::Value, - api_method_param: &::proxmox_router::ApiMethod, - rpc_env_param: &mut dyn ::proxmox_router::RpcEnvironment, - ) -> ::std::result::Result<::std::boxed::Box, ::anyhow::Error> { - #body - } - }); - } - (false, true) => { - wrapper_ts.extend(quote! { + match (method_info.flavor, method_info.is_async) { + (MethodFlavor::Normal, true) => { + method_info.wrapper_ts.extend(quote! { fn #api_func_name<'a>( mut input_params: ::serde_json::Value, api_method_param: &'static ::proxmox_router::ApiMethod, @@ -596,8 +643,8 @@ fn create_wrapper_function( } }); } - (false, false) => { - wrapper_ts.extend(quote! { + (MethodFlavor::Normal, false) => { + method_info.wrapper_ts.extend(quote! { fn #api_func_name( mut input_params: ::serde_json::Value, api_method_param: &::proxmox_router::ApiMethod, @@ -607,6 +654,50 @@ fn create_wrapper_function( } }); } + (MethodFlavor::Serializing, true) => { + method_info.wrapper_ts.extend(quote! { + fn #api_func_name<'a>( + mut input_params: ::serde_json::Value, + api_method_param: &'static ::proxmox_router::ApiMethod, + rpc_env_param: &'a mut dyn ::proxmox_router::RpcEnvironment, + ) -> ::proxmox_router::SerializingApiFuture<'a> { + ::std::boxed::Box::pin(async move { #body }) + } + }); + } + (MethodFlavor::Serializing, false) => { + method_info.wrapper_ts.extend(quote! { + fn #api_func_name( + mut input_params: ::serde_json::Value, + api_method_param: &::proxmox_router::ApiMethod, + rpc_env_param: &mut dyn ::proxmox_router::RpcEnvironment, + ) -> ::std::result::Result<::std::boxed::Box, ::anyhow::Error> { + #body + } + }); + } + (MethodFlavor::Streaming, true) => { + method_info.wrapper_ts.extend(quote! { + fn #api_func_name<'a>( + mut input_params: ::serde_json::Value, + api_method_param: &'static ::proxmox_router::ApiMethod, + rpc_env_param: &'a mut dyn ::proxmox_router::RpcEnvironment, + ) -> ::proxmox_router::StreamApiFuture<'a> { + ::std::boxed::Box::pin(async move { #body }) + } + }); + } + (MethodFlavor::Streaming, false) => { + method_info.wrapper_ts.extend(quote! { + fn #api_func_name( + mut input_params: ::serde_json::Value, + api_method_param: &::proxmox_router::ApiMethod, + rpc_env_param: &mut dyn ::proxmox_router::RpcEnvironment, + ) -> ::std::result::Result<::proxmox_router::SyncStream, ::anyhow::Error> { + #body + } + }); + } } Ok(api_func_name) @@ -655,7 +746,7 @@ fn extract_normal_parameter( }); } - let no_option_type = util::is_option_type(param.ty).is_none(); + let no_option_type = util::is_option_type(¶m.ty).is_none(); if let Some(def) = &default_value { let name_uc = name.as_ident().to_string().to_uppercase(); @@ -665,7 +756,7 @@ fn extract_normal_parameter( ); // strip possible Option<> from this type: - let ty = util::is_option_type(param.ty).unwrap_or(param.ty); + let ty = util::is_option_type(¶m.ty).unwrap_or(¶m.ty); default_consts.extend(quote_spanned! { span => pub const #name: #ty = #def; });