From fd0047719919675eeb1e6279117d31d3db572ead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=BF=90=E5=AE=B6?= Date: Sun, 9 Mar 2025 13:51:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0websocket?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 2 +- Cargo.lock | 135 +++++++++++++++++- Cargo.toml | 3 +- macro/src/lib.rs | 9 ++ macro/src/ws_route.rs | 76 ++++++++++ server/Cargo.toml | 7 +- server/src/controller/mod.rs | 1 + server/src/controller/websocket_controller.rs | 129 +++++++++++++++++ server/src/lib.rs | 6 +- 9 files changed, 356 insertions(+), 12 deletions(-) create mode 100644 macro/src/ws_route.rs create mode 100644 server/src/controller/websocket_controller.rs diff --git a/.vscode/launch.json b/.vscode/launch.json index 93ffb39..b76eff2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -8,7 +8,7 @@ "program": "${workspaceRoot}/target/debug/${workspaceFolderBasename}", "args": [], "cwd": "${workspaceRoot}", - "preLaunchTask":"rust: cargo build", + "preLaunchTask":"rust: cargo build" } ] } \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index b3d85c2..df3b174 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,6 +98,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" dependencies = [ "axum-core", + "base64 0.22.1", "bytes", "form_urlencoded", "futures-util", @@ -117,8 +118,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -155,6 +158,7 @@ dependencies = [ "axum-core", "bytes", "futures-util", + "headers", "http", "http-body", "http-body-util", @@ -181,6 +185,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -474,6 +484,12 @@ dependencies = [ "syn", ] +[[package]] +name = "data-encoding" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010" + [[package]] name = "deadpool" version = "0.12.2" @@ -700,6 +716,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -773,6 +804,7 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -877,6 +909,30 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "headers" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9" +dependencies = [ + "base64 0.21.7", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http", +] + [[package]] name = "heck" version = "0.5.0" @@ -962,6 +1018,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c" + [[package]] name = "httparse" version = "1.10.1" @@ -1274,7 +1336,7 @@ version = "9.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" dependencies = [ - "base64", + "base64 0.22.1", "js-sys", "pem", "ring", @@ -1470,6 +1532,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "miniz_oxide" version = "0.8.5" @@ -1630,7 +1702,7 @@ version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" dependencies = [ - "base64", + "base64 0.22.1", "chrono", "getrandom 0.2.15", "http", @@ -1753,7 +1825,7 @@ version = "3.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38af38e8470ac9dee3ce1bae1af9c1671fffc44ddfd8bd1d0a3445bf349a8ef3" dependencies = [ - "base64", + "base64 0.22.1", "serde", ] @@ -2105,7 +2177,7 @@ version = "0.12.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "encoding_rs", "futures-channel", @@ -2394,6 +2466,7 @@ dependencies = [ "chrono", "domain", "error-stack", + "futures", "futures-executor", "i18n", "lazy_static", @@ -2631,7 +2704,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4560278f0e00ce64938540546f59f590d60beee33fffbd3b9cd47851e5fff233" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bitflags", "byteorder", "bytes", @@ -2675,7 +2748,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5b98a57f363ed6764d5b3a12bfedf62f07aa16e1856a7ddc2a0bb190a959613" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bitflags", "byteorder", "chrono", @@ -3033,6 +3106,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.13" @@ -3104,9 +3189,18 @@ checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" dependencies = [ "bitflags", "bytes", + "futures-util", "http", "http-body", + "http-body-util", + "http-range-header", + "httpdate", + "mime", + "mime_guess", + "percent-encoding", "pin-project-lite", + "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", @@ -3217,6 +3311,23 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.9.0", + "sha1", + "thiserror 2.0.12", + "utf-8", +] + [[package]] name = "typenum" version = "1.18.0" @@ -3233,6 +3344,12 @@ dependencies = [ "web-time", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-bidi" version = "0.3.18" @@ -3278,6 +3395,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf16_iter" version = "1.0.5" diff --git a/Cargo.toml b/Cargo.toml index a24fa4a..d900ff1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,4 +62,5 @@ redis = "0.29.1" deadpool-redis = "0.20.0" chrono-tz = "0.10.0" inventory = "0.3.17" -oauth2 = "5.0.0" \ No newline at end of file +oauth2 = "5.0.0" +futures = "0.3" \ No newline at end of file diff --git a/macro/src/lib.rs b/macro/src/lib.rs index 071acea..4715f50 100644 --- a/macro/src/lib.rs +++ b/macro/src/lib.rs @@ -10,6 +10,7 @@ use syn::{parse_macro_input, DeriveInput}; // 用于解析宏输入 mod responsable; mod route; mod task; +mod ws_route; /// `Responsable`的过程宏,将结构体实现`IntoResponse` trait // #[proc_macro_derive(Responsable, attributes(status, headers))] @@ -103,3 +104,11 @@ pub fn trace(attr: TokenStream, item: TokenStream) -> TokenStream { pub fn task(attr: TokenStream, item: TokenStream) -> TokenStream { task::gen_task(attr, item) } + +/// WebSocket路由 +/// +/// 参数为WebSocket的url路径 +#[proc_macro_attribute] +pub fn ws(attr: TokenStream, item: TokenStream) -> TokenStream { + ws_route::gen_ws_route(attr, item) +} diff --git a/macro/src/ws_route.rs b/macro/src/ws_route.rs new file mode 100644 index 0000000..1644975 --- /dev/null +++ b/macro/src/ws_route.rs @@ -0,0 +1,76 @@ +extern crate proc_macro2; +extern crate quote; +extern crate syn; +extern crate proc_macro; + +use parse::{Parse, ParseStream}; +use proc_macro::{Span, TokenStream}; +use punctuated::Punctuated; +use quote::quote; +use spanned::Spanned; +use syn::*; +use syn::{parse_macro_input, ItemFn}; + +struct WsRouteArgs { + path: String, +} + +impl Parse for WsRouteArgs { + fn parse(input: ParseStream) -> Result { + let args = Punctuated::::parse_terminated(input)?; + let mut path = None; + + for expr in args { + match expr { + Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) => { + let path_str = lit_str.value(); + validate_route_path(&path_str)?; + path = Some(path_str); + } + _ => {} + } + } + + let path = path.ok_or_else(|| Error::new( + Span::call_site().into(), + "WebSocket路由路径参数不能为空", + ))?; + + Ok(WsRouteArgs { path }) + } +} + +fn validate_route_path(path: &str) -> Result<()> { + if !path.starts_with('/') { + return Err(Error::new( + Span::call_site().into(), + "WebSocket路由路径必须以'/'开头", + )); + } + Ok(()) +} + +pub fn gen_ws_route(attr: TokenStream, item: TokenStream) -> TokenStream { + let func = parse_macro_input!(item as ItemFn); + let WsRouteArgs { path } = parse_macro_input!(attr as WsRouteArgs); + let ident = func.sig.ident.clone(); + + let generated = quote! { + #[allow(non_camel_case_types)] + struct #ident; + impl #ident { + #func + } + impl library::typed_router::RouteMethod for #ident { + fn ge_router(&self, router: axum::Router) -> axum::Router { + router.route(#path, axum::routing::any(#ident::#ident)) + } + } + ::library::submit_router_method!(#ident); + }; + + TokenStream::from(generated) +} diff --git a/server/Cargo.toml b/server/Cargo.toml index bdebb34..b4ed77f 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -6,14 +6,15 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -axum = { workspace = true } +axum = { workspace = true, features = ["default", "ws"] } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } -tower-http = { workspace = true, features = ["trace"] } +tower-http = { workspace = true, features = ["trace", "fs"] } validator = { workspace = true } -axum-extra = { workspace = true } +axum-extra = { workspace = true, features = ["default", "typed-header"] } chrono = { workspace = true } reqwest = { workspace = true } +futures = { workspace = true } futures-executor = { workspace = true } error-stack = { workspace = true } sqlx = { workspace = true, features = ["uuid"] } diff --git a/server/src/controller/mod.rs b/server/src/controller/mod.rs index c8ab81d..aff1f85 100644 --- a/server/src/controller/mod.rs +++ b/server/src/controller/mod.rs @@ -1,3 +1,4 @@ pub mod account_controller; pub mod feedback_controller; pub mod social_wx_controller; +pub mod websocket_controller; diff --git a/server/src/controller/websocket_controller.rs b/server/src/controller/websocket_controller.rs new file mode 100644 index 0000000..2e648e2 --- /dev/null +++ b/server/src/controller/websocket_controller.rs @@ -0,0 +1,129 @@ +use std::net::SocketAddr; +use std::ops::ControlFlow; + +use axum::body::Bytes; +use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade}; +use axum::extract::ConnectInfo; +use axum::response::IntoResponse; +use axum_extra::{headers, TypedHeader}; +use futures::StreamExt; +use macros::ws; + +#[ws("/ws")] +pub async fn websocket_handler( + ws: WebSocketUpgrade, + user_agent: Option>, + ConnectInfo(addr): ConnectInfo +) -> impl IntoResponse { + tracing::info!("`{:?}` at {:?} connected.", user_agent, addr); + ws.on_upgrade(move |socket| handle_socket(socket, addr)) +} + +/// Actual websocket statemachine (one will be spawned per connection) +async fn handle_socket(mut socket: WebSocket, who: SocketAddr) { + // send a ping (unsupported by some browsers) just to kick things off and get a response + if socket + .send(Message::Ping(Bytes::from_static(&[1, 2, 3]))) + .await + .is_ok() + { + tracing::info!("Pinged {who}..."); + } else { + tracing::info!("Could not send ping {who}!"); + // no Error here since the only thing we can do is to close the connection. + // If we can not send messages, there is no way to salvage the statemachine anyway. + return; + } + + let (mut sender, mut receiver) = socket.split(); + + // Spawn a task that will push several messages to the client (does not matter what client does) + // let mut send_task = tokio::spawn(async move { + // let n_msg = 20; + // for i in 0..n_msg { + // // In case of any websocket error, we exit. + // if sender + // .send(Message::Text(format!("Server message {i} ...").into())) + // .await + // .is_err() + // { + // return i; + // } + + // tokio::time::sleep(std::time::Duration::from_millis(300)).await; + // } + + // tracing::info!("Sending close to {who}..."); + // if let Err(e) = sender + // .send(Message::Close(Some(CloseFrame { + // code: axum::extract::ws::close_code::NORMAL, + // reason: Utf8Bytes::from_static("Goodbye"), + // }))) + // .await + // { + // tracing::info!("Could not send Close due to {e}, probably it is ok?"); + // } + // n_msg + // }); + + // This second task will receive messages from client and print them on server console + let mut recv_task = tokio::spawn(async move { + let mut cnt = 0; + while let Some(Ok(msg)) = receiver.next().await { + cnt += 1; + // print message and break if instructed to do so + if process_message(msg, who).is_break() { + break; + } + } + cnt + }); + + // If any one of the tasks exit, abort the other. + loop { + tokio::select! { + rv_b = (&mut recv_task) => { + match rv_b { + Ok(b) => tracing::info!("Received {b} messages"), + Err(b) => tracing::info!("Error receiving messages {b:?}") + } + } + } + } + + // returning from the handler closes the websocket connection + // tracing::info!("Websocket context {who} destroyed"); +} + +fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> { + match msg { + Message::Text(t) => { + tracing::info!(">>> {who} sent str: {t:?}"); + } + Message::Binary(d) => { + tracing::info!(">>> {} sent {} bytes: {:?}", who, d.len(), d); + } + Message::Close(c) => { + if let Some(cf) = c { + tracing::info!( + ">>> {} sent close with code {} and reason `{}`", + who, cf.code, cf.reason + ); + } else { + tracing::info!(">>> {who} somehow sent close message without CloseFrame"); + } + return ControlFlow::Break(()); + } + + Message::Pong(v) => { + tracing::info!(">>> {who} sent pong with {v:?}"); + } + // You should never need to manually handle Message::Ping, as axum's websocket library + // will do so for you automagically by replying with Pong and copying the v according to + // spec. But if you need the contents of the pings you can see them here. + Message::Ping(v) => { + tracing::info!(">>> {who} sent ping with {v:?}"); + } + } + ControlFlow::Continue(()) +} \ No newline at end of file diff --git a/server/src/lib.rs b/server/src/lib.rs index 3f00b04..078772a 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -1,3 +1,5 @@ +use std::net::SocketAddr; + use axum::{body::Body, extract::Request, http, routing::get, Router}; use i18n::message; use i18n::message_ids::MessageId; @@ -24,7 +26,9 @@ pub async fn serve() { // 启动任务 task::start().await; // 启动应用服务 - axum::serve(listener, init()).await.unwrap(); + axum::serve(listener, init().into_make_service_with_connect_info::()) + .await + .unwrap(); } /// 初始化router,包括router中间件和数据