add token from config and change the signal handler to SIGTERM

This commit is contained in:
MatthieuCoder 2023-01-02 19:53:53 +04:00
parent 867e7d7a0c
commit f152af136f
12 changed files with 128 additions and 48 deletions

3
Cargo.lock generated
View file

@ -1143,6 +1143,7 @@ name = "leash"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"pretty_env_logger",
"serde", "serde",
"shared", "shared",
"tokio", "tokio",
@ -2183,7 +2184,6 @@ dependencies = [
"hyper", "hyper",
"inner", "inner",
"log", "log",
"pretty_env_logger",
"prometheus", "prometheus",
"redis", "redis",
"serde", "serde",
@ -3006,6 +3006,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"ed25519-dalek", "ed25519-dalek",
"futures-util",
"hex", "hex",
"hyper", "hyper",
"lazy_static", "lazy_static",

View file

@ -6,18 +6,24 @@ use shared::{
nats_crate::Client, nats_crate::Client,
payloads::{CachePayload, DispatchEventTagged, Tracing}, payloads::{CachePayload, DispatchEventTagged, Tracing},
}; };
use tokio::sync::oneshot;
use std::{convert::TryFrom, pin::Pin}; use std::{convert::TryFrom, pin::Pin};
use twilight_gateway::{Event, Shard}; use twilight_gateway::{Event, Shard};
mod config; mod config;
use futures::{Future, StreamExt}; use futures::{Future, StreamExt, select};
use twilight_model::gateway::event::DispatchEvent; use twilight_model::gateway::event::DispatchEvent;
use futures::FutureExt;
struct GatewayServer {} struct GatewayServer {}
impl Component for GatewayServer { impl Component for GatewayServer {
type Config = GatewayConfig; type Config = GatewayConfig;
const SERVICE_NAME: &'static str = "gateway"; const SERVICE_NAME: &'static str = "gateway";
fn start(&self, settings: Settings<Self::Config>) -> AnyhowResultFuture<()> { fn start(
&self,
settings: Settings<Self::Config>,
stop: oneshot::Receiver<()>,
) -> AnyhowResultFuture<()> {
Box::pin(async move { Box::pin(async move {
let (shard, mut events) = Shard::builder(settings.token.to_owned(), settings.intents) let (shard, mut events) = Shard::builder(settings.token.to_owned(), settings.intents)
.shard(settings.shard, settings.shard_total)? .shard(settings.shard, settings.shard_total)?
@ -29,34 +35,48 @@ impl Component for GatewayServer {
shard.start().await?; shard.start().await?;
while let Some(event) = events.next().await { let mut stop = stop.fuse();
match event { loop {
Event::Ready(ready) => {
info!("Logged in as {}", ready.user.name);
}
_ => { select! {
let name = event.kind().name(); event = events.next().fuse() => {
if let Ok(dispatch_event) = DispatchEvent::try_from(event) { if let Some(event) = event {
let data = CachePayload { match event {
tracing: Tracing { Event::Ready(ready) => {
node_id: "".to_string(), info!("Logged in as {}", ready.user.name);
span: None, }
},
data: DispatchEventTagged { _ => {
data: dispatch_event, let name = event.kind().name();
}, if let Ok(dispatch_event) = DispatchEvent::try_from(event) {
}; let data = CachePayload {
let value = serde_json::to_string(&data)?; tracing: Tracing {
debug!("nats send: {}", value); node_id: "".to_string(),
let bytes = bytes::Bytes::from(value); span: None,
nats.publish(format!("nova.cache.dispatch.{}", name.unwrap()), bytes) },
.await?; data: DispatchEventTagged {
data: dispatch_event,
},
};
let value = serde_json::to_string(&data)?;
debug!("nats send: {}", value);
let bytes = bytes::Bytes::from(value);
nats.publish(format!("nova.cache.dispatch.{}", name.unwrap()), bytes)
.await?;
}
}
}
} else {
break
} }
} },
} _ = stop => break
};
} }
info!("stopping shard...");
shard.shutdown();
Ok(()) Ok(())
}) })
} }

View file

@ -2,7 +2,7 @@ use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use serde::Deserialize; use serde::Deserialize;
fn default_listening_address() -> SocketAddr { fn default_listening_address() -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)) SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 8080))
} }
#[derive(Debug, Deserialize, Clone)] #[derive(Debug, Deserialize, Clone)]

View file

@ -3,6 +3,7 @@ use std::{
convert::TryFrom, convert::TryFrom,
hash::{Hash, Hasher}, hash::{Hash, Hasher},
str::FromStr, str::FromStr,
time::Instant,
}; };
use anyhow::bail; use anyhow::bail;
@ -38,7 +39,7 @@ fn normalize_path(request_path: &str) -> (&str, &str) {
pub async fn handle_request( pub async fn handle_request(
client: Client<HttpsConnector<HttpConnector>, Body>, client: Client<HttpsConnector<HttpConnector>, Body>,
ratelimiter: RemoteRatelimiter, ratelimiter: RemoteRatelimiter,
token: String, token: &str,
mut request: Request<Body>, mut request: Request<Body>,
) -> Result<Response<Body>, anyhow::Error> { ) -> Result<Response<Body>, anyhow::Error> {
let (hash, uri_string) = { let (hash, uri_string) = {
@ -57,7 +58,7 @@ pub async fn handle_request(
let request_path = request.uri().path(); let request_path = request.uri().path();
let (api_path, trimmed_path) = normalize_path(&request_path); let (api_path, trimmed_path) = normalize_path(&request_path);
let mut uri_string = format!("http://192.168.0.27:8000{}{}", api_path, trimmed_path); let mut uri_string = format!("https://discord.com{}{}", api_path, trimmed_path);
if let Some(query) = request.uri().query() { if let Some(query) = request.uri().query() {
uri_string.push('?'); uri_string.push('?');
uri_string.push_str(query); uri_string.push_str(query);
@ -79,6 +80,7 @@ pub async fn handle_request(
(hash.finish().to_string(), uri_string) (hash.finish().to_string(), uri_string)
}; };
let start_ticket_request = Instant::now();
let header_sender = match ratelimiter.ticket(hash).await { let header_sender = match ratelimiter.ticket(hash).await {
Ok(sender) => sender, Ok(sender) => sender,
Err(e) => { Err(e) => {
@ -86,6 +88,7 @@ pub async fn handle_request(
bail!("failed to reteive ticket"); bail!("failed to reteive ticket");
} }
}; };
let time_took_ticket = Instant::now() - start_ticket_request;
request.headers_mut().insert( request.headers_mut().insert(
AUTHORIZATION, AUTHORIZATION,
@ -106,9 +109,7 @@ pub async fn handle_request(
request.headers_mut().remove(AUTHORIZATION); request.headers_mut().remove(AUTHORIZATION);
request.headers_mut().append( request.headers_mut().append(
AUTHORIZATION, AUTHORIZATION,
HeaderValue::from_static( HeaderValue::from_str(&format!("Bot {}", token))?,
"Bot ODA3MTg4MzM1NzE3Mzg0MjEy.G3sXFM.8gY2sVYDAq2WuPWwDskAAEFLfTg8htooxME-LE",
),
); );
let uri = match Uri::from_str(&uri_string) { let uri = match Uri::from_str(&uri_string) {
@ -119,14 +120,26 @@ pub async fn handle_request(
} }
}; };
*request.uri_mut() = uri; *request.uri_mut() = uri;
let resp = match client.request(request).await {
let start_upstream_req = Instant::now();
let mut resp = match client.request(request).await {
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
error!("Error when requesting the Discord API: {:?}", e); error!("Error when requesting the Discord API: {:?}", e);
bail!("failed to request the discord api"); bail!("failed to request the discord api");
} }
}; };
let upstream_time_took = Instant::now() - start_upstream_req;
resp.headers_mut().append(
"X-TicketRequest-Ms",
HeaderValue::from_str(&time_took_ticket.as_millis().to_string()).unwrap(),
);
resp.headers_mut().append(
"X-Upstream-Ms",
HeaderValue::from_str(&upstream_time_took.as_millis().to_string()).unwrap(),
);
let ratelimit_headers = resp let ratelimit_headers = resp
.headers() .headers()
.into_iter() .into_iter()

View file

@ -9,7 +9,8 @@ use hyper::{
use hyper_tls::HttpsConnector; use hyper_tls::HttpsConnector;
use leash::{ignite, AnyhowResultFuture, Component}; use leash::{ignite, AnyhowResultFuture, Component};
use shared::config::Settings; use shared::config::Settings;
use std::convert::Infallible; use std::{convert::Infallible, sync::Arc};
use tokio::sync::oneshot;
mod config; mod config;
mod handler; mod handler;
@ -20,21 +21,29 @@ impl Component for ReverseProxyServer {
type Config = ReverseProxyConfig; type Config = ReverseProxyConfig;
const SERVICE_NAME: &'static str = "rest"; const SERVICE_NAME: &'static str = "rest";
fn start(&self, settings: Settings<Self::Config>) -> AnyhowResultFuture<()> { fn start(
&self,
settings: Settings<Self::Config>,
stop: oneshot::Receiver<()>,
) -> AnyhowResultFuture<()> {
Box::pin(async move { Box::pin(async move {
// Client to the remote ratelimiters // Client to the remote ratelimiters
let ratelimiter = ratelimit_client::RemoteRatelimiter::new(); let ratelimiter = ratelimit_client::RemoteRatelimiter::new();
let client = Client::builder().build(HttpsConnector::new()); let client = Client::builder().build(HttpsConnector::new());
let token = Arc::new(settings.discord.token.clone());
let service_fn = make_service_fn(move |_: &AddrStream| { let service_fn = make_service_fn(move |_: &AddrStream| {
let client = client.clone(); let client = client.clone();
let ratelimiter = ratelimiter.clone(); let ratelimiter = ratelimiter.clone();
let token = token.clone();
async move { async move {
Ok::<_, Infallible>(service_fn(move |request: Request<Body>| { Ok::<_, Infallible>(service_fn(move |request: Request<Body>| {
let client = client.clone(); let client = client.clone();
let ratelimiter = ratelimiter.clone(); let ratelimiter = ratelimiter.clone();
let token = token.clone();
async move { async move {
handle_request(client, ratelimiter, "token".to_string(), request).await let token = token.as_str();
handle_request(client, ratelimiter, token, request).await
} }
})) }))
} }
@ -42,7 +51,11 @@ impl Component for ReverseProxyServer {
let server = Server::bind(&settings.config.server.listening_adress).serve(service_fn); let server = Server::bind(&settings.config.server.listening_adress).serve(service_fn);
server.await?; server
.with_graceful_shutdown(async {
stop.await.expect("should not fail");
})
.await?;
Ok(()) Ok(())
}) })

View file

@ -64,7 +64,7 @@ impl RemoteRatelimiter {
obj_clone.get_ratelimiters().await.unwrap(); obj_clone.get_ratelimiters().await.unwrap();
tokio::select! { tokio::select! {
() = &mut sleep => { () = &mut sleep => {
println!("timer elapsed"); debug!("timer elapsed");
}, },
_ = tx.recv() => {} _ = tx.recv() => {}
} }

View file

@ -17,6 +17,7 @@ lazy_static = "1.4.0"
ed25519-dalek = "1" ed25519-dalek = "1"
twilight-model = { version = "0.14" } twilight-model = { version = "0.14" }
anyhow = "1.0.68" anyhow = "1.0.68"
futures-util = "0.3.25"
[[bin]] [[bin]]
name = "webhook" name = "webhook"

View file

@ -4,7 +4,7 @@ use ed25519_dalek::PublicKey;
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
fn default_listening_address() -> SocketAddr { fn default_listening_address() -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)) SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 8080))
} }
#[derive(Debug, Deserialize, Clone, Copy)] #[derive(Debug, Deserialize, Clone, Copy)]

View file

@ -9,6 +9,7 @@ use crate::{
use hyper::Server; use hyper::Server;
use leash::{ignite, AnyhowResultFuture, Component}; use leash::{ignite, AnyhowResultFuture, Component};
use shared::{config::Settings, log::info, nats_crate::Client}; use shared::{config::Settings, log::info, nats_crate::Client};
use tokio::sync::oneshot;
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
struct WebhookServer {} struct WebhookServer {}
@ -17,7 +18,11 @@ impl Component for WebhookServer {
type Config = WebhookConfig; type Config = WebhookConfig;
const SERVICE_NAME: &'static str = "webhook"; const SERVICE_NAME: &'static str = "webhook";
fn start(&self, settings: Settings<Self::Config>) -> AnyhowResultFuture<()> { fn start(
&self,
settings: Settings<Self::Config>,
stop: oneshot::Receiver<()>,
) -> AnyhowResultFuture<()> {
Box::pin(async move { Box::pin(async move {
info!("Starting server on {}", settings.server.listening_adress); info!("Starting server on {}", settings.server.listening_adress);
@ -33,7 +38,9 @@ impl Component for WebhookServer {
let server = Server::bind(&bind).serve(make_service); let server = Server::bind(&bind).serve(make_service);
server.await?; server.with_graceful_shutdown(async {
stop.await.expect("should not fail");
}).await?;
Ok(()) Ok(())
}) })

View file

@ -9,5 +9,5 @@ edition = "2021"
shared = { path = "../shared" } shared = { path = "../shared" }
anyhow = "1.0.68" anyhow = "1.0.68"
tokio = { version = "1.23.0", features = ["full"] } tokio = { version = "1.23.0", features = ["full"] }
pretty_env_logger = "0.4"
serde = "1.0.152" serde = "1.0.152"

View file

@ -1,19 +1,29 @@
use anyhow::Result; use anyhow::Result;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use shared::config::Settings; use shared::{
config::Settings,
log::{error, info},
};
use std::{future::Future, pin::Pin}; use std::{future::Future, pin::Pin};
use tokio::{signal::{unix::SignalKind}, sync::oneshot};
pub type AnyhowResultFuture<T> = Pin<Box<dyn Future<Output = Result<T>>>>; pub type AnyhowResultFuture<T> = Pin<Box<dyn Future<Output = Result<T>>>>;
pub trait Component: Send + Sync + 'static + Sized { pub trait Component: Send + Sync + 'static + Sized {
type Config: Default + Clone + DeserializeOwned; type Config: Default + Clone + DeserializeOwned;
const SERVICE_NAME: &'static str; const SERVICE_NAME: &'static str;
fn start(&self, settings: Settings<Self::Config>) -> AnyhowResultFuture<()>; fn start(
&self,
settings: Settings<Self::Config>,
stop: oneshot::Receiver<()>,
) -> AnyhowResultFuture<()>;
fn new() -> Self; fn new() -> Self;
fn _internal_start(self) -> AnyhowResultFuture<()> { fn _internal_start(self) -> AnyhowResultFuture<()> {
Box::pin(async move { Box::pin(async move {
pretty_env_logger::init();
let settings = Settings::<Self::Config>::new(Self::SERVICE_NAME); let settings = Settings::<Self::Config>::new(Self::SERVICE_NAME);
let (stop, stop_channel) = oneshot::channel();
// Start the grpc healthcheck // Start the grpc healthcheck
tokio::spawn(async move {}); tokio::spawn(async move {});
@ -21,7 +31,21 @@ pub trait Component: Send + Sync + 'static + Sized {
// Start the prometheus monitoring job // Start the prometheus monitoring job
tokio::spawn(async move {}); tokio::spawn(async move {});
self.start(settings?).await tokio::spawn(async move {
match tokio::signal::unix::signal(SignalKind::terminate()).unwrap().recv().await {
Some(()) => {
info!("Stopping program.");
stop.send(()).unwrap();
}
None => {
error!("Unable to listen for shutdown signal");
// we also shut down in case of error
}
}
});
self.start(settings?, stop_channel).await
}) })
} }
} }
@ -41,6 +65,7 @@ macro_rules! ignite {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use serde::Deserialize; use serde::Deserialize;
use tokio::sync::oneshot;
use crate::Component; use crate::Component;
@ -57,6 +82,7 @@ mod test {
fn start( fn start(
&self, &self,
_settings: shared::config::Settings<Self::Config>, _settings: shared::config::Settings<Self::Config>,
_stop: oneshot::Receiver<()>,
) -> crate::AnyhowResultFuture<()> { ) -> crate::AnyhowResultFuture<()> {
Box::pin(async move { Ok(()) }) Box::pin(async move { Ok(()) })
} }
@ -65,6 +91,6 @@ mod test {
Self {} Self {}
} }
} }
ignite!(TestComponent); ignite!(TestComponent);
} }

View file

@ -4,7 +4,6 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
pretty_env_logger = "0.4"
log = { version = "0.4", features = ["std"] } log = { version = "0.4", features = ["std"] }
serde = { version = "1.0.8", features = ["derive"] } serde = { version = "1.0.8", features = ["derive"] }
serde_repr = "0.1" serde_repr = "0.1"