diff --git a/Cargo.lock b/Cargo.lock index c76107e..8181e8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1143,6 +1143,7 @@ name = "leash" version = "0.1.0" dependencies = [ "anyhow", + "pretty_env_logger", "serde", "shared", "tokio", @@ -2183,7 +2184,6 @@ dependencies = [ "hyper", "inner", "log", - "pretty_env_logger", "prometheus", "redis", "serde", @@ -3006,6 +3006,7 @@ version = "0.1.0" dependencies = [ "anyhow", "ed25519-dalek", + "futures-util", "hex", "hyper", "lazy_static", diff --git a/exes/gateway/src/main.rs b/exes/gateway/src/main.rs index 7957b08..f2a4f93 100644 --- a/exes/gateway/src/main.rs +++ b/exes/gateway/src/main.rs @@ -6,18 +6,24 @@ use shared::{ nats_crate::Client, payloads::{CachePayload, DispatchEventTagged, Tracing}, }; +use tokio::sync::oneshot; use std::{convert::TryFrom, pin::Pin}; use twilight_gateway::{Event, Shard}; mod config; -use futures::{Future, StreamExt}; +use futures::{Future, StreamExt, select}; use twilight_model::gateway::event::DispatchEvent; +use futures::FutureExt; struct GatewayServer {} impl Component for GatewayServer { type Config = GatewayConfig; const SERVICE_NAME: &'static str = "gateway"; - fn start(&self, settings: Settings) -> AnyhowResultFuture<()> { + fn start( + &self, + settings: Settings, + stop: oneshot::Receiver<()>, + ) -> AnyhowResultFuture<()> { Box::pin(async move { let (shard, mut events) = Shard::builder(settings.token.to_owned(), settings.intents) .shard(settings.shard, settings.shard_total)? @@ -29,34 +35,48 @@ impl Component for GatewayServer { shard.start().await?; - while let Some(event) = events.next().await { - match event { - Event::Ready(ready) => { - info!("Logged in as {}", ready.user.name); - } + let mut stop = stop.fuse(); + loop { - _ => { - let name = event.kind().name(); - if let Ok(dispatch_event) = DispatchEvent::try_from(event) { - let data = CachePayload { - tracing: Tracing { - node_id: "".to_string(), - span: None, - }, - data: DispatchEventTagged { - data: dispatch_event, - }, - }; - let value = serde_json::to_string(&data)?; - debug!("nats send: {}", value); - let bytes = bytes::Bytes::from(value); - nats.publish(format!("nova.cache.dispatch.{}", name.unwrap()), bytes) - .await?; + select! { + event = events.next().fuse() => { + if let Some(event) = event { + match event { + Event::Ready(ready) => { + info!("Logged in as {}", ready.user.name); + } + + _ => { + let name = event.kind().name(); + if let Ok(dispatch_event) = DispatchEvent::try_from(event) { + let data = CachePayload { + tracing: Tracing { + node_id: "".to_string(), + span: None, + }, + data: DispatchEventTagged { + data: dispatch_event, + }, + }; + let value = serde_json::to_string(&data)?; + debug!("nats send: {}", value); + let bytes = bytes::Bytes::from(value); + nats.publish(format!("nova.cache.dispatch.{}", name.unwrap()), bytes) + .await?; + } + } + } + } else { + break } - } - } + }, + _ = stop => break + }; } + info!("stopping shard..."); + shard.shutdown(); + Ok(()) }) } diff --git a/exes/rest/src/config.rs b/exes/rest/src/config.rs index 9261de2..5c2698b 100644 --- a/exes/rest/src/config.rs +++ b/exes/rest/src/config.rs @@ -2,7 +2,7 @@ use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use serde::Deserialize; fn default_listening_address() -> SocketAddr { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)) + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 8080)) } #[derive(Debug, Deserialize, Clone)] diff --git a/exes/rest/src/handler.rs b/exes/rest/src/handler.rs index 8b0dd52..ea81ade 100644 --- a/exes/rest/src/handler.rs +++ b/exes/rest/src/handler.rs @@ -3,6 +3,7 @@ use std::{ convert::TryFrom, hash::{Hash, Hasher}, str::FromStr, + time::Instant, }; use anyhow::bail; @@ -38,7 +39,7 @@ fn normalize_path(request_path: &str) -> (&str, &str) { pub async fn handle_request( client: Client, Body>, ratelimiter: RemoteRatelimiter, - token: String, + token: &str, mut request: Request, ) -> Result, anyhow::Error> { let (hash, uri_string) = { @@ -57,7 +58,7 @@ pub async fn handle_request( let request_path = request.uri().path(); let (api_path, trimmed_path) = normalize_path(&request_path); - let mut uri_string = format!("http://192.168.0.27:8000{}{}", api_path, trimmed_path); + let mut uri_string = format!("https://discord.com{}{}", api_path, trimmed_path); if let Some(query) = request.uri().query() { uri_string.push('?'); uri_string.push_str(query); @@ -79,6 +80,7 @@ pub async fn handle_request( (hash.finish().to_string(), uri_string) }; + let start_ticket_request = Instant::now(); let header_sender = match ratelimiter.ticket(hash).await { Ok(sender) => sender, Err(e) => { @@ -86,6 +88,7 @@ pub async fn handle_request( bail!("failed to reteive ticket"); } }; + let time_took_ticket = Instant::now() - start_ticket_request; request.headers_mut().insert( AUTHORIZATION, @@ -106,9 +109,7 @@ pub async fn handle_request( request.headers_mut().remove(AUTHORIZATION); request.headers_mut().append( AUTHORIZATION, - HeaderValue::from_static( - "Bot ODA3MTg4MzM1NzE3Mzg0MjEy.G3sXFM.8gY2sVYDAq2WuPWwDskAAEFLfTg8htooxME-LE", - ), + HeaderValue::from_str(&format!("Bot {}", token))?, ); let uri = match Uri::from_str(&uri_string) { @@ -119,14 +120,26 @@ pub async fn handle_request( } }; *request.uri_mut() = uri; - let resp = match client.request(request).await { + + let start_upstream_req = Instant::now(); + let mut resp = match client.request(request).await { Ok(response) => response, Err(e) => { error!("Error when requesting the Discord API: {:?}", e); bail!("failed to request the discord api"); } }; + let upstream_time_took = Instant::now() - start_upstream_req; + resp.headers_mut().append( + "X-TicketRequest-Ms", + HeaderValue::from_str(&time_took_ticket.as_millis().to_string()).unwrap(), + ); + resp.headers_mut().append( + "X-Upstream-Ms", + HeaderValue::from_str(&upstream_time_took.as_millis().to_string()).unwrap(), + ); + let ratelimit_headers = resp .headers() .into_iter() diff --git a/exes/rest/src/main.rs b/exes/rest/src/main.rs index 8d014ab..07d835c 100644 --- a/exes/rest/src/main.rs +++ b/exes/rest/src/main.rs @@ -9,7 +9,8 @@ use hyper::{ use hyper_tls::HttpsConnector; use leash::{ignite, AnyhowResultFuture, Component}; use shared::config::Settings; -use std::convert::Infallible; +use std::{convert::Infallible, sync::Arc}; +use tokio::sync::oneshot; mod config; mod handler; @@ -20,21 +21,29 @@ impl Component for ReverseProxyServer { type Config = ReverseProxyConfig; const SERVICE_NAME: &'static str = "rest"; - fn start(&self, settings: Settings) -> AnyhowResultFuture<()> { + fn start( + &self, + settings: Settings, + stop: oneshot::Receiver<()>, + ) -> AnyhowResultFuture<()> { Box::pin(async move { // Client to the remote ratelimiters let ratelimiter = ratelimit_client::RemoteRatelimiter::new(); let client = Client::builder().build(HttpsConnector::new()); + let token = Arc::new(settings.discord.token.clone()); let service_fn = make_service_fn(move |_: &AddrStream| { let client = client.clone(); let ratelimiter = ratelimiter.clone(); + let token = token.clone(); async move { Ok::<_, Infallible>(service_fn(move |request: Request| { let client = client.clone(); let ratelimiter = ratelimiter.clone(); + let token = token.clone(); async move { - handle_request(client, ratelimiter, "token".to_string(), request).await + let token = token.as_str(); + handle_request(client, ratelimiter, token, request).await } })) } @@ -42,7 +51,11 @@ impl Component for ReverseProxyServer { let server = Server::bind(&settings.config.server.listening_adress).serve(service_fn); - server.await?; + server + .with_graceful_shutdown(async { + stop.await.expect("should not fail"); + }) + .await?; Ok(()) }) diff --git a/exes/rest/src/ratelimit_client/mod.rs b/exes/rest/src/ratelimit_client/mod.rs index 8263d15..87737dd 100644 --- a/exes/rest/src/ratelimit_client/mod.rs +++ b/exes/rest/src/ratelimit_client/mod.rs @@ -64,7 +64,7 @@ impl RemoteRatelimiter { obj_clone.get_ratelimiters().await.unwrap(); tokio::select! { () = &mut sleep => { - println!("timer elapsed"); + debug!("timer elapsed"); }, _ = tx.recv() => {} } diff --git a/exes/webhook/Cargo.toml b/exes/webhook/Cargo.toml index 12a6608..589b5bd 100644 --- a/exes/webhook/Cargo.toml +++ b/exes/webhook/Cargo.toml @@ -17,6 +17,7 @@ lazy_static = "1.4.0" ed25519-dalek = "1" twilight-model = { version = "0.14" } anyhow = "1.0.68" +futures-util = "0.3.25" [[bin]] name = "webhook" diff --git a/exes/webhook/src/config.rs b/exes/webhook/src/config.rs index 68f6a5f..e98de13 100644 --- a/exes/webhook/src/config.rs +++ b/exes/webhook/src/config.rs @@ -4,7 +4,7 @@ use ed25519_dalek::PublicKey; use serde::{Deserialize, Deserializer}; fn default_listening_address() -> SocketAddr { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)) + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 8080)) } #[derive(Debug, Deserialize, Clone, Copy)] diff --git a/exes/webhook/src/main.rs b/exes/webhook/src/main.rs index efd4147..0215e51 100644 --- a/exes/webhook/src/main.rs +++ b/exes/webhook/src/main.rs @@ -9,6 +9,7 @@ use crate::{ use hyper::Server; use leash::{ignite, AnyhowResultFuture, Component}; use shared::{config::Settings, log::info, nats_crate::Client}; +use tokio::sync::oneshot; #[derive(Clone, Copy)] struct WebhookServer {} @@ -17,7 +18,11 @@ impl Component for WebhookServer { type Config = WebhookConfig; const SERVICE_NAME: &'static str = "webhook"; - fn start(&self, settings: Settings) -> AnyhowResultFuture<()> { + fn start( + &self, + settings: Settings, + stop: oneshot::Receiver<()>, + ) -> AnyhowResultFuture<()> { Box::pin(async move { info!("Starting server on {}", settings.server.listening_adress); @@ -33,7 +38,9 @@ impl Component for WebhookServer { let server = Server::bind(&bind).serve(make_service); - server.await?; + server.with_graceful_shutdown(async { + stop.await.expect("should not fail"); + }).await?; Ok(()) }) diff --git a/libs/leash/Cargo.toml b/libs/leash/Cargo.toml index 5cd54a5..32f385c 100644 --- a/libs/leash/Cargo.toml +++ b/libs/leash/Cargo.toml @@ -9,5 +9,5 @@ edition = "2021" shared = { path = "../shared" } anyhow = "1.0.68" tokio = { version = "1.23.0", features = ["full"] } - +pretty_env_logger = "0.4" serde = "1.0.152" \ No newline at end of file diff --git a/libs/leash/src/lib.rs b/libs/leash/src/lib.rs index 360db12..1de7687 100644 --- a/libs/leash/src/lib.rs +++ b/libs/leash/src/lib.rs @@ -1,19 +1,29 @@ use anyhow::Result; use serde::de::DeserializeOwned; -use shared::config::Settings; +use shared::{ + config::Settings, + log::{error, info}, +}; use std::{future::Future, pin::Pin}; +use tokio::{signal::{unix::SignalKind}, sync::oneshot}; pub type AnyhowResultFuture = Pin>>>; pub trait Component: Send + Sync + 'static + Sized { type Config: Default + Clone + DeserializeOwned; const SERVICE_NAME: &'static str; - fn start(&self, settings: Settings) -> AnyhowResultFuture<()>; + fn start( + &self, + settings: Settings, + stop: oneshot::Receiver<()>, + ) -> AnyhowResultFuture<()>; fn new() -> Self; fn _internal_start(self) -> AnyhowResultFuture<()> { Box::pin(async move { + pretty_env_logger::init(); let settings = Settings::::new(Self::SERVICE_NAME); + let (stop, stop_channel) = oneshot::channel(); // Start the grpc healthcheck tokio::spawn(async move {}); @@ -21,7 +31,21 @@ pub trait Component: Send + Sync + 'static + Sized { // Start the prometheus monitoring job tokio::spawn(async move {}); - self.start(settings?).await + tokio::spawn(async move { + match tokio::signal::unix::signal(SignalKind::terminate()).unwrap().recv().await { + Some(()) => { + info!("Stopping program."); + + stop.send(()).unwrap(); + } + None => { + error!("Unable to listen for shutdown signal"); + // we also shut down in case of error + } + } + }); + + self.start(settings?, stop_channel).await }) } } @@ -41,6 +65,7 @@ macro_rules! ignite { #[cfg(test)] mod test { use serde::Deserialize; + use tokio::sync::oneshot; use crate::Component; @@ -57,6 +82,7 @@ mod test { fn start( &self, _settings: shared::config::Settings, + _stop: oneshot::Receiver<()>, ) -> crate::AnyhowResultFuture<()> { Box::pin(async move { Ok(()) }) } @@ -65,6 +91,6 @@ mod test { Self {} } } - + ignite!(TestComponent); } diff --git a/libs/shared/Cargo.toml b/libs/shared/Cargo.toml index ab19ce8..ce08fbc 100644 --- a/libs/shared/Cargo.toml +++ b/libs/shared/Cargo.toml @@ -4,7 +4,6 @@ version = "0.1.0" edition = "2021" [dependencies] -pretty_env_logger = "0.4" log = { version = "0.4", features = ["std"] } serde = { version = "1.0.8", features = ["derive"] } serde_repr = "0.1"