diff --git a/library/src/typed_router.rs b/library/src/typed_router.rs index c379b55..fdb4840 100644 --- a/library/src/typed_router.rs +++ b/library/src/typed_router.rs @@ -20,7 +20,6 @@ macro_rules! submit_router_method { pub fn get_router() -> Router { let mut router = Router::new(); for method in inventory::iter::<&dyn RouteMethod> { - // result.push(Arc::new(method.ge_task())); router = method.ge_router(router); } router diff --git a/macro/src/route.rs b/macro/src/route.rs index 76f5baa..46eecfd 100644 --- a/macro/src/route.rs +++ b/macro/src/route.rs @@ -22,32 +22,7 @@ struct RouteArgs { // 实现 Parse trait 以支持解析参数 impl Parse for RouteArgs { fn parse(input: ParseStream) -> Result { - // 使用Meta解析 begin - // let args = Punctuated::::parse_terminated(input)?; - // let mut path = None; - // let mut methods = Vec::new(); - // for arg in args { - // if let Meta::NameValue(nv) = arg { - // if nv.path.is_ident("method") { - // if let Expr::Array(ExprArray { elems, .. }) = nv.value { - // for elem in elems { - // if let Expr::Lit(syn::ExprLit { lit: Lit::Str(lit_str), .. }) = elem { - // methods.push(lit_str.value()); - // } - // } - // } - // } else if nv.path.is_ident("path") { - // if let Expr::Lit(ExprLit { lit: Lit::Str(lit_str), .. }) = nv.value { - // path = Some(lit_str.value()); - // } - // } - // } - // } - // let path = path.expect("Expected a path argument"); - // 使用Meta解析 end - let args = Punctuated::::parse_terminated(input)?; - let mut path = None; let mut methods = Vec::new(); @@ -57,7 +32,10 @@ impl Parse for RouteArgs { lit: Lit::Str(lit_str), .. }) => { - path = Some(lit_str.value()); + let path_str = lit_str.value(); + // 验证路由路径 + validate_route_path(&path_str)?; + path = Some(path_str); } Expr::Assign(assign) => { if let Expr::Path(path) = *assign.left { @@ -80,43 +58,24 @@ impl Parse for RouteArgs { } } - let path = path.expect("路由参数不能为空"); + let path = path.ok_or_else(|| Error::new( + Span::call_site().into(), + "路由路径参数不能为空", + ))?; Ok(RouteArgs { path, methods }) } } -struct Args { - vars: Vec, -} - -impl Parse for Args { - fn parse(input: ParseStream) -> Result { - let vars = Punctuated::::parse_terminated(input)?; - - Ok(Args { - vars: vars.into_iter().collect(), - }) - } -} - -impl Args { - pub fn get_arg(&self, index: usize) -> Result> { - match self.vars.get(index) { - Some(var) => Ok(Some(var.to_owned())), - None => { - // 第一个参数使路由url,必须存在,其他的参数根据实际需求进一步解析 - if index != 0 { - Ok(None) - } else { - Err(Error::new( - Span::call_site().into(), - "route must have one argument", - )) - } - } - } +// 添加路由参数验证 +fn validate_route_path(path: &str) -> Result<()> { + if !path.starts_with('/') { + return Err(Error::new( + Span::call_site().into(), + "路由路径必须以'/'开头", + )); } + Ok(()) } pub fn gen_route(attr: TokenStream, item: TokenStream, method: &str) -> TokenStream { @@ -150,11 +109,12 @@ pub fn gen_route(attr: TokenStream, item: TokenStream, method: &str) -> TokenStr #func } impl library::typed_router::RouteMethod for #ident { - fn ge_router(&self, router: axum::Router) -> axum::Router { + fn ge_router(&self, mut router: axum::Router) -> axum::Router { let methods = vec![#(#method_routers),*]; - methods.into_iter().fold(router, |router, (path, method_router)| { - router.route(path, method_router) - }) + for (path, method_router) in methods { + router = router.route(path, method_router); + } + router } } ::library::submit_router_method!(#ident); @@ -162,45 +122,3 @@ pub fn gen_route(attr: TokenStream, item: TokenStream, method: &str) -> TokenStr TokenStream::from(generated) } - -#[allow(dead_code)] -pub fn gen_dyn_route(attr: TokenStream, item: TokenStream, method: &str) -> TokenStream { - let args = parse_macro_input!(attr as Args); - let func = parse_macro_input!(item as ItemFn); - - let vis = func.vis.clone(); - let ident = func.sig.ident.clone(); - - let route = args.get_arg(0).unwrap().unwrap(); - - let method_option = args.get_arg(1); - match method_option { - Ok(method) => { - if let Some(Expr::Assign(methods_value)) = method { - println!("method is {:?}", methods_value); - // if let Expr::Path(left) = methods_value.left { - // left.path.segments[0].ident.eq("d") { - - // } - // } - // if let Expr::Lit(right) = methods_value.right { - - // } - } - } - Err(err) => { - println!("error is {}", err); - } - } - - let method_name: Ident = Ident::new(method, route.span()); - - let expanded = quote! { - #vis fn #ident () -> (&'static str, axum::routing::method_routing::MethodRouter) { - #func - - (#route, axum::routing::#method_name(#ident)) - } - }; - expanded.into() -}