diff --git a/Cargo.lock b/Cargo.lock index bf2407d..959b6e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1066,6 +1066,7 @@ dependencies = [ "http-body", "http-body-util", "jsonwebtoken", + "lazy_static", "moka", "once_cell", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index 4ad0a2a..1378534 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,4 +39,5 @@ futures-util = "0.3" reqwest = "0.12" futures-executor = "0.3" error-stack = "0.4" -jsonwebtoken = "9.3.0" \ No newline at end of file +jsonwebtoken = "9.3.0" +lazy_static = "1.4.0" \ No newline at end of file diff --git a/library/Cargo.toml b/library/Cargo.toml index c2e0d90..9e02e7b 100644 --- a/library/Cargo.toml +++ b/library/Cargo.toml @@ -26,4 +26,5 @@ tokio = { workspace = true, features = ["rt-multi-thread", "macros" ] } futures-util = { workspace = true } jsonwebtoken = { workspace = true } reqwest = { workspace = true, features = ["blocking", "json"] } -validator = { workspace = true } \ No newline at end of file +validator = { workspace = true } +lazy_static = { workspace = true } \ No newline at end of file diff --git a/library/src/social/google.rs b/library/src/social/google.rs index 42b9453..f9edf1d 100644 --- a/library/src/social/google.rs +++ b/library/src/social/google.rs @@ -1,7 +1,9 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use chrono::Utc; +use futures_util::lock::Mutex; use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; +use lazy_static::lazy_static; use reqwest::Client; use serde::Deserialize; use serde_json::Value; @@ -10,6 +12,15 @@ use crate::resp::response::ResErr; use super::SocialResult; +lazy_static! { + pub static ref GOOGLE_SOCIAL: GoogleSocial = GoogleSocial::default(); +} + +#[derive(Default)] +pub struct GoogleSocial { + google_public_keys: Arc>, +} + // 假设GOOGLE_PUBLIC_CERT_URL是Google提供的公钥URL const GOOGLE_PUBLIC_CERT_URL: &str = "https://www.googleapis.com/oauth2/v3/certs"; @@ -93,7 +104,7 @@ impl From for GoogleJwtProfile { } } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Default)] struct GooglePublicKey { e: String, n: String, @@ -103,69 +114,78 @@ struct GooglePublicKey { kid: String, } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Default)] struct GooglePublicKeys { keys: Vec, + refresh_at: i64, } -// 异步获取并解析Google公钥 -async fn fetch_and_parse_google_public_keys() -> SocialResult> { - let response = Client::new().get(GOOGLE_PUBLIC_CERT_URL).send().await?; - let google_keys: GooglePublicKeys = response.json().await?; - tracing::info!("Google公钥获取成功, {:?}", google_keys); - - let mut key_map = HashMap::new(); - // 解析公钥 - for key in google_keys.keys.iter() { - if key.kty == "RSA" { - key_map.insert(key.kid.to_owned(), key.to_owned()); +impl GoogleSocial { + // 异步获取并解析Google公钥 + async fn fetch_and_parse_google_public_keys(&self) -> SocialResult> + { + let mut public_keys = self.google_public_keys.lock().await; + if public_keys.keys.is_empty() || public_keys.refresh_at < Utc::now().timestamp() { + let response = Client::new().get(GOOGLE_PUBLIC_CERT_URL).send().await?; + let mut google_keys: GooglePublicKeys = response.json().await?; + tracing::info!("Google公钥获取成功, {:?}", google_keys); + google_keys.refresh_at = Utc::now().timestamp() + 3600; + *public_keys = google_keys; } + + let mut key_map = HashMap::new(); + // 解析公钥 + for key in public_keys.keys.iter() { + if key.kty == "RSA" { + key_map.insert(key.kid.to_owned(), key.to_owned()); + } + } + + tracing::info!("Google公钥解析成功, {:?}", key_map); + Ok(key_map) } - tracing::info!("Google公钥解析成功, {:?}", key_map); - Ok(key_map) -} - -// 验证ID Token并考虑kid匹配 -pub async fn verify_id_token(id_token: &str) -> SocialResult { - // 获取并解析公钥 - let public_keys = fetch_and_parse_google_public_keys().await?; - - // 解码Token头部以获取kid - let token_header = jsonwebtoken::decode_header(id_token).unwrap(); - let kid = token_header.kid; - - // 检查是否找到了有效的kid - if kid.is_none() { - return Err(Box::new(ResErr::social("校验Token失败,未找到有效的kid"))); - } - let kid = kid.unwrap(); - - // 根据kid找到正确的公钥 - let key = public_keys - .get(&kid) - .ok_or_else(|| Box::new(ResErr::social("校验Token失败,未找到正确的公钥")))?; - - tracing::info!("public key : {:?}", key); - - // 验证Token - let mut validation: Validation = Validation::new(Algorithm::RS256); - validation.set_issuer(&["https://accounts.google.com", "accounts.google.com"]); // 设置预期的发行者 - validation.validate_aud = false; - - let decoded = decode::( - id_token, - &DecodingKey::from_rsa_components(key.n.as_str(), key.e.as_str()).unwrap(), - &validation, - )?; - - let claims: Value = decoded.claims; - let google_jwt_profile = GoogleJwtProfile::from(claims); - - // 校验有效期 - if google_jwt_profile.exp < Utc::now().timestamp() { - return Err(Box::new(ResErr::social("校验Token失败,token有效期无效"))); - } - - Ok(google_jwt_profile) + // 验证ID Token并考虑kid匹配 + pub async fn verify_id_token(&self, id_token: &str) -> SocialResult { + // 获取并解析公钥 + let public_keys = self.fetch_and_parse_google_public_keys().await?; + + // 解码Token头部以获取kid + let token_header = jsonwebtoken::decode_header(id_token).unwrap(); + let kid = token_header.kid; + + // 检查是否找到了有效的kid + if kid.is_none() { + return Err(Box::new(ResErr::social("校验Token失败,未找到有效的kid"))); + } + let kid = kid.unwrap(); + + // 根据kid找到正确的公钥 + let key = public_keys + .get(&kid) + .ok_or_else(|| Box::new(ResErr::social("校验Token失败,未找到正确的公钥")))?; + + tracing::info!("public key : {:?}", key); + + // 验证Token + let mut validation: Validation = Validation::new(Algorithm::RS256); + validation.set_issuer(&["https://accounts.google.com", "accounts.google.com"]); // 设置预期的发行者 + validation.validate_aud = false; + + let decoded = decode::( + id_token, + &DecodingKey::from_rsa_components(key.n.as_str(), key.e.as_str()).unwrap(), + &validation, + )?; + + let claims: Value = decoded.claims; + let google_jwt_profile = GoogleJwtProfile::from(claims); + + // 校验有效期 + if google_jwt_profile.exp < Utc::now().timestamp() { + return Err(Box::new(ResErr::social("校验Token失败,token有效期无效"))); + } + + Ok(google_jwt_profile) + } } diff --git a/service/src/account.rs b/service/src/account.rs index 0dbde9a..f52f1ba 100644 --- a/service/src/account.rs +++ b/service/src/account.rs @@ -4,10 +4,10 @@ use domain::entities::account::Account; use library::{db, token}; use library::resp::response::ResErr::ErrPerm; use library::resp::response::{ResErr, ResOK, ResResult}; -use library::social::google; +use library::social::google::GOOGLE_SOCIAL; pub async fn authticate_google(req: AuthenticateGooleAccountReq) -> ResResult> { - let verify_result = google::verify_id_token(&req.id_token.unwrap()).await.map_err(|err| { + let verify_result = GOOGLE_SOCIAL.verify_id_token(&req.id_token.unwrap()).await.map_err(|err| { tracing::error!(error = ?err, "校验Google Token失败"); ErrPerm(None) })?;