添加websocket

This commit is contained in:
李运家 2025-03-09 13:51:57 +08:00
parent 82cf1b5965
commit fd00477199
9 changed files with 356 additions and 12 deletions

2
.vscode/launch.json vendored
View File

@ -8,7 +8,7 @@
"program": "${workspaceRoot}/target/debug/${workspaceFolderBasename}",
"args": [],
"cwd": "${workspaceRoot}",
"preLaunchTask":"rust: cargo build",
"preLaunchTask":"rust: cargo build"
}
]
}

135
Cargo.lock generated
View File

@ -98,6 +98,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8"
dependencies = [
"axum-core",
"base64 0.22.1",
"bytes",
"form_urlencoded",
"futures-util",
@ -117,8 +118,10 @@ dependencies = [
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sha1",
"sync_wrapper",
"tokio",
"tokio-tungstenite",
"tower",
"tower-layer",
"tower-service",
@ -155,6 +158,7 @@ dependencies = [
"axum-core",
"bytes",
"futures-util",
"headers",
"http",
"http-body",
"http-body-util",
@ -181,6 +185,12 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "base64"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "base64"
version = "0.22.1"
@ -474,6 +484,12 @@ dependencies = [
"syn",
]
[[package]]
name = "data-encoding"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010"
[[package]]
name = "deadpool"
version = "0.12.2"
@ -700,6 +716,21 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
@ -773,6 +804,7 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
@ -877,6 +909,30 @@ dependencies = [
"hashbrown",
]
[[package]]
name = "headers"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9"
dependencies = [
"base64 0.21.7",
"bytes",
"headers-core",
"http",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4"
dependencies = [
"http",
]
[[package]]
name = "heck"
version = "0.5.0"
@ -962,6 +1018,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-range-header"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c"
[[package]]
name = "httparse"
version = "1.10.1"
@ -1274,7 +1336,7 @@ version = "9.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde"
dependencies = [
"base64",
"base64 0.22.1",
"js-sys",
"pem",
"ring",
@ -1470,6 +1532,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "miniz_oxide"
version = "0.8.5"
@ -1630,7 +1702,7 @@ version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d"
dependencies = [
"base64",
"base64 0.22.1",
"chrono",
"getrandom 0.2.15",
"http",
@ -1753,7 +1825,7 @@ version = "3.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38af38e8470ac9dee3ce1bae1af9c1671fffc44ddfd8bd1d0a3445bf349a8ef3"
dependencies = [
"base64",
"base64 0.22.1",
"serde",
]
@ -2105,7 +2177,7 @@ version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"base64",
"base64 0.22.1",
"bytes",
"encoding_rs",
"futures-channel",
@ -2394,6 +2466,7 @@ dependencies = [
"chrono",
"domain",
"error-stack",
"futures",
"futures-executor",
"i18n",
"lazy_static",
@ -2631,7 +2704,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4560278f0e00ce64938540546f59f590d60beee33fffbd3b9cd47851e5fff233"
dependencies = [
"atoi",
"base64",
"base64 0.22.1",
"bitflags",
"byteorder",
"bytes",
@ -2675,7 +2748,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5b98a57f363ed6764d5b3a12bfedf62f07aa16e1856a7ddc2a0bb190a959613"
dependencies = [
"atoi",
"base64",
"base64 0.22.1",
"bitflags",
"byteorder",
"chrono",
@ -3033,6 +3106,18 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.13"
@ -3104,9 +3189,18 @@ checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697"
dependencies = [
"bitflags",
"bytes",
"futures-util",
"http",
"http-body",
"http-body-util",
"http-range-header",
"httpdate",
"mime",
"mime_guess",
"percent-encoding",
"pin-project-lite",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
@ -3217,6 +3311,23 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand 0.9.0",
"sha1",
"thiserror 2.0.12",
"utf-8",
]
[[package]]
name = "typenum"
version = "1.18.0"
@ -3233,6 +3344,12 @@ dependencies = [
"web-time",
]
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-bidi"
version = "0.3.18"
@ -3278,6 +3395,12 @@ dependencies = [
"serde",
]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utf16_iter"
version = "1.0.5"

View File

@ -62,4 +62,5 @@ redis = "0.29.1"
deadpool-redis = "0.20.0"
chrono-tz = "0.10.0"
inventory = "0.3.17"
oauth2 = "5.0.0"
oauth2 = "5.0.0"
futures = "0.3"

View File

@ -10,6 +10,7 @@ use syn::{parse_macro_input, DeriveInput}; // 用于解析宏输入
mod responsable;
mod route;
mod task;
mod ws_route;
/// `Responsable`的过程宏,将结构体实现`IntoResponse` trait
// #[proc_macro_derive(Responsable, attributes(status, headers))]
@ -103,3 +104,11 @@ pub fn trace(attr: TokenStream, item: TokenStream) -> TokenStream {
pub fn task(attr: TokenStream, item: TokenStream) -> TokenStream {
task::gen_task(attr, item)
}
/// WebSocket路由
///
/// 参数为WebSocket的url路径
#[proc_macro_attribute]
pub fn ws(attr: TokenStream, item: TokenStream) -> TokenStream {
ws_route::gen_ws_route(attr, item)
}

76
macro/src/ws_route.rs Normal file
View File

@ -0,0 +1,76 @@
extern crate proc_macro2;
extern crate quote;
extern crate syn;
extern crate proc_macro;
use parse::{Parse, ParseStream};
use proc_macro::{Span, TokenStream};
use punctuated::Punctuated;
use quote::quote;
use spanned::Spanned;
use syn::*;
use syn::{parse_macro_input, ItemFn};
struct WsRouteArgs {
path: String,
}
impl Parse for WsRouteArgs {
fn parse(input: ParseStream) -> Result<Self> {
let args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
let mut path = None;
for expr in args {
match expr {
Expr::Lit(ExprLit {
lit: Lit::Str(lit_str),
..
}) => {
let path_str = lit_str.value();
validate_route_path(&path_str)?;
path = Some(path_str);
}
_ => {}
}
}
let path = path.ok_or_else(|| Error::new(
Span::call_site().into(),
"WebSocket路由路径参数不能为空",
))?;
Ok(WsRouteArgs { path })
}
}
fn validate_route_path(path: &str) -> Result<()> {
if !path.starts_with('/') {
return Err(Error::new(
Span::call_site().into(),
"WebSocket路由路径必须以'/'开头",
));
}
Ok(())
}
pub fn gen_ws_route(attr: TokenStream, item: TokenStream) -> TokenStream {
let func = parse_macro_input!(item as ItemFn);
let WsRouteArgs { path } = parse_macro_input!(attr as WsRouteArgs);
let ident = func.sig.ident.clone();
let generated = quote! {
#[allow(non_camel_case_types)]
struct #ident;
impl #ident {
#func
}
impl library::typed_router::RouteMethod for #ident {
fn ge_router(&self, router: axum::Router) -> axum::Router {
router.route(#path, axum::routing::any(#ident::#ident))
}
}
::library::submit_router_method!(#ident);
};
TokenStream::from(generated)
}

View File

@ -6,14 +6,15 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = { workspace = true }
axum = { workspace = true, features = ["default", "ws"] }
tokio = { workspace = true, features = ["full"] }
tracing = { workspace = true }
tower-http = { workspace = true, features = ["trace"] }
tower-http = { workspace = true, features = ["trace", "fs"] }
validator = { workspace = true }
axum-extra = { workspace = true }
axum-extra = { workspace = true, features = ["default", "typed-header"] }
chrono = { workspace = true }
reqwest = { workspace = true }
futures = { workspace = true }
futures-executor = { workspace = true }
error-stack = { workspace = true }
sqlx = { workspace = true, features = ["uuid"] }

View File

@ -1,3 +1,4 @@
pub mod account_controller;
pub mod feedback_controller;
pub mod social_wx_controller;
pub mod websocket_controller;

View File

@ -0,0 +1,129 @@
use std::net::SocketAddr;
use std::ops::ControlFlow;
use axum::body::Bytes;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::extract::ConnectInfo;
use axum::response::IntoResponse;
use axum_extra::{headers, TypedHeader};
use futures::StreamExt;
use macros::ws;
#[ws("/ws")]
pub async fn websocket_handler(
ws: WebSocketUpgrade,
user_agent: Option<TypedHeader<headers::UserAgent>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>
) -> impl IntoResponse {
tracing::info!("`{:?}` at {:?} connected.", user_agent, addr);
ws.on_upgrade(move |socket| handle_socket(socket, addr))
}
/// Actual websocket statemachine (one will be spawned per connection)
async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
// send a ping (unsupported by some browsers) just to kick things off and get a response
if socket
.send(Message::Ping(Bytes::from_static(&[1, 2, 3])))
.await
.is_ok()
{
tracing::info!("Pinged {who}...");
} else {
tracing::info!("Could not send ping {who}!");
// no Error here since the only thing we can do is to close the connection.
// If we can not send messages, there is no way to salvage the statemachine anyway.
return;
}
let (mut sender, mut receiver) = socket.split();
// Spawn a task that will push several messages to the client (does not matter what client does)
// let mut send_task = tokio::spawn(async move {
// let n_msg = 20;
// for i in 0..n_msg {
// // In case of any websocket error, we exit.
// if sender
// .send(Message::Text(format!("Server message {i} ...").into()))
// .await
// .is_err()
// {
// return i;
// }
// tokio::time::sleep(std::time::Duration::from_millis(300)).await;
// }
// tracing::info!("Sending close to {who}...");
// if let Err(e) = sender
// .send(Message::Close(Some(CloseFrame {
// code: axum::extract::ws::close_code::NORMAL,
// reason: Utf8Bytes::from_static("Goodbye"),
// })))
// .await
// {
// tracing::info!("Could not send Close due to {e}, probably it is ok?");
// }
// n_msg
// });
// This second task will receive messages from client and print them on server console
let mut recv_task = tokio::spawn(async move {
let mut cnt = 0;
while let Some(Ok(msg)) = receiver.next().await {
cnt += 1;
// print message and break if instructed to do so
if process_message(msg, who).is_break() {
break;
}
}
cnt
});
// If any one of the tasks exit, abort the other.
loop {
tokio::select! {
rv_b = (&mut recv_task) => {
match rv_b {
Ok(b) => tracing::info!("Received {b} messages"),
Err(b) => tracing::info!("Error receiving messages {b:?}")
}
}
}
}
// returning from the handler closes the websocket connection
// tracing::info!("Websocket context {who} destroyed");
}
fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> {
match msg {
Message::Text(t) => {
tracing::info!(">>> {who} sent str: {t:?}");
}
Message::Binary(d) => {
tracing::info!(">>> {} sent {} bytes: {:?}", who, d.len(), d);
}
Message::Close(c) => {
if let Some(cf) = c {
tracing::info!(
">>> {} sent close with code {} and reason `{}`",
who, cf.code, cf.reason
);
} else {
tracing::info!(">>> {who} somehow sent close message without CloseFrame");
}
return ControlFlow::Break(());
}
Message::Pong(v) => {
tracing::info!(">>> {who} sent pong with {v:?}");
}
// You should never need to manually handle Message::Ping, as axum's websocket library
// will do so for you automagically by replying with Pong and copying the v according to
// spec. But if you need the contents of the pings you can see them here.
Message::Ping(v) => {
tracing::info!(">>> {who} sent ping with {v:?}");
}
}
ControlFlow::Continue(())
}

View File

@ -1,3 +1,5 @@
use std::net::SocketAddr;
use axum::{body::Body, extract::Request, http, routing::get, Router};
use i18n::message;
use i18n::message_ids::MessageId;
@ -24,7 +26,9 @@ pub async fn serve() {
// 启动任务
task::start().await;
// 启动应用服务
axum::serve(listener, init()).await.unwrap();
axum::serve(listener, init().into_make_service_with_connect_info::<SocketAddr>())
.await
.unwrap();
}
/// 初始化router包括router中间件和数据