use std::net::SocketAddr; use std::ops::ControlFlow; use axum::body::Bytes; use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade}; use axum::extract::ConnectInfo; use axum::response::IntoResponse; use axum_extra::{headers, TypedHeader}; use futures::StreamExt; use macros::ws; #[ws("/ws")] pub async fn websocket_handler( ws: WebSocketUpgrade, user_agent: Option>, ConnectInfo(addr): ConnectInfo ) -> impl IntoResponse { tracing::info!("`{:?}` at {:?} connected.", user_agent, addr); ws.on_upgrade(move |socket| handle_socket(socket, addr)) } /// Actual websocket statemachine (one will be spawned per connection) async fn handle_socket(mut socket: WebSocket, who: SocketAddr) { // send a ping (unsupported by some browsers) just to kick things off and get a response if socket .send(Message::Ping(Bytes::from_static(&[1, 2, 3]))) .await .is_ok() { tracing::info!("Pinged {who}..."); } else { tracing::info!("Could not send ping {who}!"); // no Error here since the only thing we can do is to close the connection. // If we can not send messages, there is no way to salvage the statemachine anyway. return; } let (mut sender, mut receiver) = socket.split(); // Spawn a task that will push several messages to the client (does not matter what client does) // let mut send_task = tokio::spawn(async move { // let n_msg = 20; // for i in 0..n_msg { // // In case of any websocket error, we exit. // if sender // .send(Message::Text(format!("Server message {i} ...").into())) // .await // .is_err() // { // return i; // } // tokio::time::sleep(std::time::Duration::from_millis(300)).await; // } // tracing::info!("Sending close to {who}..."); // if let Err(e) = sender // .send(Message::Close(Some(CloseFrame { // code: axum::extract::ws::close_code::NORMAL, // reason: Utf8Bytes::from_static("Goodbye"), // }))) // .await // { // tracing::info!("Could not send Close due to {e}, probably it is ok?"); // } // n_msg // }); // This second task will receive messages from client and print them on server console let mut recv_task = tokio::spawn(async move { let mut cnt = 0; while let Some(Ok(msg)) = receiver.next().await { cnt += 1; // print message and break if instructed to do so if process_message(msg, who).is_break() { break; } } cnt }); // If any one of the tasks exit, abort the other. loop { tokio::select! { rv_b = (&mut recv_task) => { match rv_b { Ok(b) => tracing::info!("Received {b} messages"), Err(b) => tracing::info!("Error receiving messages {b:?}") } } } } // returning from the handler closes the websocket connection // tracing::info!("Websocket context {who} destroyed"); } fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> { match msg { Message::Text(t) => { tracing::info!(">>> {who} sent str: {t:?}"); } Message::Binary(d) => { tracing::info!(">>> {} sent {} bytes: {:?}", who, d.len(), d); } Message::Close(c) => { if let Some(cf) = c { tracing::info!( ">>> {} sent close with code {} and reason `{}`", who, cf.code, cf.reason ); } else { tracing::info!(">>> {who} somehow sent close message without CloseFrame"); } return ControlFlow::Break(()); } Message::Pong(v) => { tracing::info!(">>> {who} sent pong with {v:?}"); } // You should never need to manually handle Message::Ping, as axum's websocket library // will do so for you automagically by replying with Pong and copying the v according to // spec. But if you need the contents of the pings you can see them here. Message::Ping(v) => { tracing::info!(">>> {who} sent ping with {v:?}"); } } ControlFlow::Continue(()) }