diff --git a/utoipa-gen/src/security_requirement.rs b/utoipa-gen/src/security_requirement.rs index 62bfa51d..3c6a6f66 100644 --- a/utoipa-gen/src/security_requirement.rs +++ b/utoipa-gen/src/security_requirement.rs @@ -5,16 +5,16 @@ use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, token::Comma, - LitStr, Token, + Token, }; -use crate::Array; +use crate::parse_utils; #[derive(Default)] #[cfg_attr(feature = "debug", derive(Debug))] pub struct SecurityRequirementsAttrItem { pub name: Option, - pub scopes: Option>, + pub scopes: Option>, } #[derive(Default)] @@ -30,17 +30,17 @@ impl Parse for SecurityRequirementsAttr { impl Parse for SecurityRequirementsAttrItem { fn parse(input: ParseStream) -> syn::Result { - let name = input.parse::()?.value(); + let name = input.parse::()?.value(); input.parse::()?; let scopes_stream; bracketed!(scopes_stream in input); - let scopes = Punctuated::::parse_terminated(&scopes_stream)? - .iter() - .map(LitStr::value) - .collect::>(); + let scopes = + Punctuated::::parse_terminated(&scopes_stream)? + .into_iter() + .collect::>(); Ok(Self { name: Some(name), @@ -57,11 +57,14 @@ impl ToTokens for SecurityRequirementsAttr { for requirement in &self.0 { if let (Some(name), Some(scopes)) = (&requirement.name, &requirement.scopes) { - let scopes = scopes.iter().collect::>(); + let scopes_tokens = scopes.iter().map(|scope| match scope { + parse_utils::LitStrOrExpr::LitStr(lit) => quote! { #lit.to_string() }, + parse_utils::LitStrOrExpr::Expr(expr) => quote! { #expr.to_string() }, + }); let scopes_len = scopes.len(); tokens.extend(quote! { - .add::<&str, [&str; #scopes_len], &str>(#name, #scopes) + .add::<&str, [String; #scopes_len], String>(#name, [#(#scopes_tokens),*]) }); } } diff --git a/utoipa-gen/tests/path_derive.rs b/utoipa-gen/tests/path_derive.rs index 2988362f..2dcb5124 100644 --- a/utoipa-gen/tests/path_derive.rs +++ b/utoipa-gen/tests/path_derive.rs @@ -285,6 +285,60 @@ fn derive_path_with_security_requirements() { } } +#[test] +fn derive_path_with_security_requirements_display_types() { + use std::fmt::Display; + + #[derive(Debug)] + enum Scope { + Read, + Write, + } + + impl Display for Scope { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Scope::Read => write!(f, "read:items"), + Scope::Write => write!(f, "write:items"), + } + } + } + + const READ_SCOPE: &str = "read:items"; + + #[utoipa::path( + get, + path = "/items", + responses( + (status = 200, description = "success response") + ), + security( + (), + ("api_oauth" = [Scope::Read.to_string(), Scope::Write.to_string()]), + ("jwt_token" = []), + ("mixed" = [READ_SCOPE, Scope::Write.to_string()]) + ) + )] + #[allow(unused)] + fn get_items() -> String { + "".to_string() + } + let operation = test_api_fn_doc! { + get_items, + operation: get, + path: "/items" + }; + + assert_value! {operation=> + "security.[0]" = "{}", "Optional security requirement" + "security.[1].api_oauth.[0]" = r###""read:items""###, "api_oauth first scope with Display" + "security.[1].api_oauth.[1]" = r###""write:items""###, "api_oauth second scope with Display" + "security.[2].jwt_token" = "[]", "jwt_token auth scopes" + "security.[3].mixed.[0]" = r###""read:items""###, "mixed first scope literal" + "security.[3].mixed.[1]" = r###""write:items""###, "mixed second scope Display" + } +} + #[test] fn derive_path_with_extensions() { #[utoipa::path( diff --git a/utoipa-gen/tests/utoipa_gen_test.rs b/utoipa-gen/tests/utoipa_gen_test.rs index 319f1efd..d6a80a4b 100644 --- a/utoipa-gen/tests/utoipa_gen_test.rs +++ b/utoipa-gen/tests/utoipa_gen_test.rs @@ -161,3 +161,53 @@ fn derive_openapi() { build_foo!(GetFooBody, Foo, FooResources); } + +#[test] +fn derive_openapi_with_security_display_types() { + use std::fmt::Display; + + #[derive(Debug)] + enum AuthScope { + Read, + Write, + Admin, + } + + impl Display for AuthScope { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AuthScope::Read => write!(f, "read:all"), + AuthScope::Write => write!(f, "write:all"), + AuthScope::Admin => write!(f, "admin:all"), + } + } + } + + const CUSTOM_SCOPE: &str = "custom:scope"; + + #[derive(Default, OpenApi)] + #[openapi( + security( + (), + ("oauth2" = [AuthScope::Read.to_string(), AuthScope::Write.to_string()]), + ("api_key" = []), + ("mixed" = [CUSTOM_SCOPE, AuthScope::Admin.to_string()]) + ) + )] + struct ApiDocWithDisplay; + + let api = ApiDocWithDisplay::openapi(); + let json = api.to_json().unwrap(); + let security = serde_json::from_str::(&json).unwrap()["security"].clone(); + + assert_eq!(security[0], serde_json::json!({})); + assert_eq!( + security[1]["oauth2"], + serde_json::json!(["read:all", "write:all"]) + ); + assert_eq!(security[2]["api_key"], serde_json::json!([])); + assert_eq!( + security[3]["mixed"], + serde_json::json!(["custom:scope", "admin:all"]) + ); +}