add tests

This commit is contained in:
MatthieuCoder 2023-01-14 19:06:14 +04:00
parent 3f34284614
commit db93991c41
8 changed files with 376 additions and 68 deletions

View file

@ -56,7 +56,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
async fn listen(sub: &mut Subscriber, cache: &mut Cache, features: Vec<String>) { async fn listen(sub: &mut Subscriber, cache: &mut Cache, features: Vec<String>) {
while let Some(data) = sub.next().await { while let Some(data) = sub.next().await {
let cp: CachePayload = serde_json::from_slice(&data.payload).unwrap(); let cp: CachePayload = serde_json::from_slice(&data.payload).unwrap();
let event = cp.data.data; let event = cp.data.0;
match event { match event {
// Channel events // Channel events
DispatchEvent::ChannelCreate(_) DispatchEvent::ChannelCreate(_)

View file

@ -93,9 +93,7 @@ async fn handle_event(event: Event, nats: &Client) -> anyhow::Result<()> {
debug!("handling event {}", name.unwrap()); debug!("handling event {}", name.unwrap());
let data = CachePayload { let data = CachePayload {
data: DispatchEventTagged { data: DispatchEventTagged(dispatch_event),
data: dispatch_event,
},
}; };
let value = serde_json::to_string(&data)?; let value = serde_json::to_string(&data)?;
let bytes = bytes::Bytes::from(value); let bytes = bytes::Bytes::from(value);

View file

@ -1,33 +1,46 @@
use tokio::sync::{ use tokio::sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender}, mpsc::{self, UnboundedReceiver, UnboundedSender},
oneshot::Sender,
Mutex, Mutex,
}; };
/// Queue of ratelimit requests for a bucket. /// Simple async fifo (first in fist out) queue based on unbounded channels
///
/// # Usage
/// ```
/// # use ratelimit::buckets::async_queue::AsyncQueue;
/// # tokio_test::block_on(async {
/// let queue = AsyncQueue::<i64>::default();
/// // Pushing into the queue is syncronous
/// queue.push(123);
///
/// // Popping from the queue is asyncronous
/// let value = queue.pop().await;
///
/// // Our value should be the same!
/// assert_eq!(value, Some(123));
/// # });
/// ```
#[derive(Debug)] #[derive(Debug)]
pub struct AsyncQueue { pub struct AsyncQueue<T: Send> {
/// Receiver for the ratelimit requests. rx: Mutex<UnboundedReceiver<T>>,
rx: Mutex<UnboundedReceiver<Sender<()>>>, tx: UnboundedSender<T>,
/// Sender for the ratelimit requests.
tx: UnboundedSender<Sender<()>>,
} }
impl AsyncQueue { impl<T: Send> AsyncQueue<T> {
/// Add a new ratelimit request to the queue. /// Add a new item to the queue
pub fn push(&self, tx: Sender<()>) { pub fn push(&self, tx: T) {
let _sent = self.tx.send(tx); let _sent = self.tx.send(tx);
} }
/// Receive the first incoming ratelimit request. /// Receive the first incoming ratelimit request.
pub async fn pop(&self) -> Option<Sender<()>> { pub async fn pop(&self) -> Option<T> {
let mut rx = self.rx.lock().await; let mut rx = self.rx.lock().await;
rx.recv().await rx.recv().await
} }
} }
impl Default for AsyncQueue { impl<T: Send> Default for AsyncQueue<T> {
fn default() -> Self { fn default() -> Self {
let (tx, rx) = mpsc::unbounded_channel(); let (tx, rx) = mpsc::unbounded_channel();
@ -37,3 +50,23 @@ impl Default for AsyncQueue {
} }
} }
} }
#[cfg(test)]
mod tests {
use crate::buckets::async_queue::AsyncQueue;
#[test_log::test(tokio::test)]
async fn should_queue_dequeue_fifo() {
let queue = AsyncQueue::<i64>::default();
// queue data
for i in 0..2_000_000 {
queue.push(i);
}
for i in 0..2_000_000 {
let result = queue.pop().await.unwrap();
assert_eq!(i, result);
}
}
}

View file

@ -1,19 +1,36 @@
use std::{ use std::{
hash::Hash,
ops::{Add, AddAssign, Sub},
sync::atomic::{AtomicU64, Ordering}, sync::atomic::{AtomicU64, Ordering},
time::{Duration, SystemTime, UNIX_EPOCH}, time::{Duration, SystemTime, UNIX_EPOCH},
}; };
use tracing::debug; /// Instant implementation based on an atomic number
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
///
/// let now = AtomicInstant::now();
/// let max_seconds = u64::MAX / 1_000_000_000;
/// let duration = Duration::new(max_seconds, 0);
/// println!("{:?}", now + duration);
/// ```
#[derive(Default, Debug)] #[derive(Default, Debug)]
#[cfg(not(target_feature = "atomic128"))]
pub struct AtomicInstant(AtomicU64); pub struct AtomicInstant(AtomicU64);
impl AtomicInstant { impl AtomicInstant {
#[must_use] /// Calculates the duration since the instant.
pub const fn empty() -> Self { /// # Example
Self(AtomicU64::new(0)) /// ```
} /// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::now();
/// std::thread::sleep(Duration::from_secs(1));
///
/// assert_eq!(instant.elapsed().as_secs(), 1);
/// ```
pub fn elapsed(&self) -> Duration { pub fn elapsed(&self) -> Duration {
// Truncation is expected // Truncation is expected
#[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_possible_truncation)]
@ -25,24 +42,228 @@ impl AtomicInstant {
- self.0.load(Ordering::Relaxed), - self.0.load(Ordering::Relaxed),
) )
} }
/// Gets the current time in millis
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::default();
/// instant.set_millis(1000);
///
/// assert_eq!(instant.as_millis(), 1000);
/// ```
pub fn as_millis(&self) -> u64 { pub fn as_millis(&self) -> u64 {
self.0.load(Ordering::Relaxed) self.0.load(Ordering::Relaxed)
} }
/// Creates an instant at the current time
/// # Safety
/// Truncates if the current unix time is greater than `u64::MAX`
#[allow(clippy::cast_possible_truncation)]
#[must_use]
pub fn now() -> Self {
Self(AtomicU64::new(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time went backwards")
.as_millis() as u64,
))
}
/// Sets the unix time of the instant
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::default();
/// instant.set_millis(1000);
///
/// assert_eq!(instant.as_millis(), 1000);
/// ```
pub fn set_millis(&self, millis: u64) { pub fn set_millis(&self, millis: u64) {
// get address of struct
let b = self as *const _ as usize;
debug!(millis, this = ?b, "settings instant millis");
self.0.store(millis, Ordering::Relaxed); self.0.store(millis, Ordering::Relaxed);
} }
/// Determines if the current instant is at the default value
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::default();
///
/// assert!(instant.is_empty());
/// ```
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
let millis = self.as_millis(); self.as_millis() == 0
// get address of struct }
let b = self as *const _ as usize; }
debug!(millis, this = ?b, "settings instant millis");
debug!(empty = (millis == 0), millis, this = ?b, "instant empty check"); impl Add<Duration> for AtomicInstant {
millis == 0 type Output = Self;
/// # Safety
/// This panics if the right hand side is greater than `i64::MAX`
/// You can remedy to this using the 128bits feature with changes the
/// underlying atomic.
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::default();
///
/// // we add one second to our instant
/// instant = instant + Duration::from_secs(1);
///
/// // should be equal to a second
/// assert_eq!(instant.as_millis(), 1000);
/// ```
fn add(self, rhs: Duration) -> Self::Output {
self.0
.fetch_add(rhs.as_millis().try_into().unwrap(), Ordering::Relaxed);
self
}
}
impl AddAssign<Duration> for AtomicInstant {
/// # Safety
/// This panics if the right hand side is greater than `i64::MAX`
/// You can remedy to this using the 128bits feature with changes the
/// underlying atomic.
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::default();
///
/// // we add one second to our instant
/// instant += Duration::from_secs(1);
///
/// // should be equal to a second
/// assert_eq!(instant.as_millis(), 1000);
/// ```
fn add_assign(&mut self, rhs: Duration) {
self.0
.fetch_add(rhs.as_millis().try_into().unwrap(), Ordering::Relaxed);
}
}
impl Hash for AtomicInstant {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.load(Ordering::Relaxed).hash(state);
}
}
impl PartialEq for AtomicInstant {
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::default();
/// let mut instant2 = AtomicInstant::default();
///
/// assert_eq!(instant, instant2);
/// ```
fn eq(&self, other: &Self) -> bool {
self.0.load(Ordering::Relaxed) == other.0.load(Ordering::Relaxed)
}
}
impl Eq for AtomicInstant {}
impl PartialOrd for AtomicInstant {
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::default();
/// let mut instant2 = AtomicInstant::default();
///
/// assert!(instant == instant2);
/// instant.set_millis(1000);
/// assert!(instant > instant2);
/// instant.set_millis(0);
/// instant2.set_millis(1000);
/// assert!(instant < instant2);
/// ```
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.0
.load(Ordering::Relaxed)
.partial_cmp(&other.0.load(Ordering::Relaxed))
}
}
impl Ord for AtomicInstant {
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::default();
/// let mut instant2 = AtomicInstant::default();
///
/// assert!(instant == instant2);
/// instant.set_millis(1000);
/// assert!(instant > instant2);
/// instant.set_millis(0);
/// instant2.set_millis(1000);
/// assert!(instant < instant2);
/// ```
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0
.load(Ordering::Relaxed)
.cmp(&other.0.load(Ordering::Relaxed))
}
}
impl Sub<Duration> for AtomicInstant {
type Output = Self;
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::default();
/// instant.set_millis(1000);
///
/// instant = instant - Duration::from_secs(1);
///
/// assert!(instant.is_empty());
/// ```
fn sub(self, rhs: Duration) -> Self::Output {
self.0
.fetch_sub(rhs.as_millis().try_into().unwrap(), Ordering::Relaxed);
self
}
}
impl Sub<Self> for AtomicInstant {
type Output = Self;
/// # Example
/// ```
/// # use ratelimit::buckets::atomic_instant::AtomicInstant;
/// # use std::time::Duration;
/// let mut instant = AtomicInstant::default();
/// let mut instant2 = AtomicInstant::default();
/// instant.set_millis(1000);
/// instant2.set_millis(2000);
///
/// instant = instant2 - instant;
///
/// assert_eq!(instant.as_millis(), 1000);
/// ```
fn sub(self, rhs: Self) -> Self::Output {
self.0
.fetch_sub(rhs.0.load(Ordering::Relaxed), Ordering::Relaxed);
self
}
}
#[cfg(test)]
mod tests {
use super::AtomicInstant;
#[test]
fn should_detect_default() {
let instant = AtomicInstant::default();
assert!(instant.is_empty());
instant.set_millis(1000);
assert!(!instant.is_empty());
} }
} }

View file

@ -6,7 +6,10 @@ use std::{
}, },
time::Duration, time::Duration,
}; };
use tokio::{sync::oneshot, task::JoinHandle}; use tokio::{
sync::oneshot::{self, Sender},
task::JoinHandle,
};
use tracing::{debug, trace}; use tracing::{debug, trace};
use twilight_http_ratelimiting::headers::Present; use twilight_http_ratelimiting::headers::Present;
@ -22,29 +25,36 @@ pub enum TimeRemaining {
/// ///
/// # Usage /// # Usage
/// ``` /// ```
/// use ratelimit::buckets::bucket::Bucket; /// # use ratelimit::buckets::bucket::Bucket;
/// use twilight_http_ratelimiting::RatelimitHeaders; /// # use twilight_http_ratelimiting::RatelimitHeaders;
/// use std::time::SystemTime; /// # use std::time::SystemTime;
/// # tokio_test::block_on(async {
/// ///
/// let bucket = Bucket::new(); /// let bucket = Bucket::new();
/// ///
/// // Feed the headers informations into the bucket to update it /// // Feed the headers informations into the bucket to update it
/// let headers = [ /// let headers = [
/// ( "x-ratelimit-bucket", "bucket id".as_bytes()), /// ( "x-ratelimit-bucket", "bucket id".as_bytes()),
/// ("x-ratelimit-limit", "100".as_bytes()), /// ("x-ratelimit-limit", "100".as_bytes()),
/// ("x-ratelimit-remaining", "0".as_bytes()), /// ("x-ratelimit-remaining", "0".as_bytes()),
/// ("x-ratelimit-reset", "".as_bytes()), /// ("x-ratelimit-reset", "99999999999999".as_bytes()),
/// ("x-ratelimit-reset-after", "10.000".as_bytes()), /// ("x-ratelimit-reset-after", "10.000".as_bytes()),
/// ]; /// ];
/// ///
/// // Parse the headers /// // Parse the headers
/// let present = if let Ok(RatelimitHeaders::Present(present)) = RatelimitHeaders::from_pairs(headers.into_iter()) { present } else { todo!() }; /// let present = if let Ok(RatelimitHeaders::Present(present))
/// = RatelimitHeaders::from_pairs(headers.into_iter()) {
/// present
/// } else { todo!() };
/// ///
/// // this should idealy the time of the request /// // this should idealy the time of the request
/// let current_time = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_millis() as u64; /// let current_time = SystemTime::now()
/// /// .duration_since(SystemTime::UNIX_EPOCH)
/// bucket.update(present, current_time).await; /// .unwrap()
/// .as_millis() as u64;
/// ///
/// bucket.update(&present, current_time);
/// # })
/// ``` /// ```
/// ///
/// # Async /// # Async
@ -63,7 +73,7 @@ pub struct Bucket {
/// List of tasks that dequeue tasks from [`Self::queue`] /// List of tasks that dequeue tasks from [`Self::queue`]
tasks: Vec<JoinHandle<()>>, tasks: Vec<JoinHandle<()>>,
/// Queue of tickets to be processed. /// Queue of tickets to be processed.
queue: AsyncQueue, queue: AsyncQueue<Sender<()>>,
} }
impl Drop for Bucket { impl Drop for Bucket {
@ -88,7 +98,7 @@ impl Bucket {
queue: AsyncQueue::default(), queue: AsyncQueue::default(),
remaining: AtomicU64::new(u64::max_value()), remaining: AtomicU64::new(u64::max_value()),
reset_after: AtomicU64::new(u64::max_value()), reset_after: AtomicU64::new(u64::max_value()),
last_update: AtomicInstant::empty(), last_update: AtomicInstant::default(),
tasks, tasks,
}); });
@ -292,7 +302,7 @@ mod tests {
("x-ratelimit-limit", b"100"), ("x-ratelimit-limit", b"100"),
("x-ratelimit-remaining", b"0"), ("x-ratelimit-remaining", b"0"),
("x-ratelimit-reset", mreset.as_bytes()), ("x-ratelimit-reset", mreset.as_bytes()),
("x-ratelimit-reset-after", b"100.000"), ("x-ratelimit-reset-after", b"10.000"),
]; ];
if let RatelimitHeaders::Present(present) = if let RatelimitHeaders::Present(present) =

View file

@ -82,11 +82,9 @@ impl WebhookService {
// this should hopefully not fail ? // this should hopefully not fail ?
let data = CachePayload { let data = CachePayload {
data: DispatchEventTagged { data: DispatchEventTagged(DispatchEvent::InteractionCreate(Box::new(
data: DispatchEvent::InteractionCreate(Box::new(InteractionCreate( InteractionCreate(interaction),
interaction, ))),
))),
},
}; };
let payload = serde_json::to_string(&data).unwrap(); let payload = serde_json::to_string(&data).unwrap();

View file

@ -128,5 +128,7 @@ mod test {
} }
} }
ignite!(TestComponent); ignite!(TestComponent);
} }

View file

@ -1,4 +1,5 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::ops::{Deref, DerefMut};
use serde::de::DeserializeSeed; use serde::de::DeserializeSeed;
use serde::Deserializer; use serde::Deserializer;
@ -7,10 +8,20 @@ use serde_json::Value;
use tracing::trace_span; use tracing::trace_span;
use twilight_model::gateway::event::{DispatchEvent, DispatchEventWithTypeDeserializer}; use twilight_model::gateway::event::{DispatchEvent, DispatchEventWithTypeDeserializer};
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq)]
#[repr(transparent)] #[repr(transparent)]
pub struct DispatchEventTagged { pub struct DispatchEventTagged(pub DispatchEvent);
pub data: DispatchEvent,
impl Deref for DispatchEventTagged {
type Target = DispatchEvent;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DispatchEventTagged {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -31,9 +42,7 @@ impl<'de> Deserialize<'de> for DispatchEventTagged {
let tagged = DispatchEventTaggedSerialized::deserialize(deserializer)?; let tagged = DispatchEventTaggedSerialized::deserialize(deserializer)?;
let deserializer_seed = DispatchEventWithTypeDeserializer::new(&tagged.kind); let deserializer_seed = DispatchEventWithTypeDeserializer::new(&tagged.kind);
let dispatch_event = deserializer_seed.deserialize(tagged.data).unwrap(); let dispatch_event = deserializer_seed.deserialize(tagged.data).unwrap();
Ok(Self { Ok(Self(dispatch_event))
data: dispatch_event,
})
} }
} }
@ -43,9 +52,9 @@ impl Serialize for DispatchEventTagged {
S: serde::Serializer, S: serde::Serializer,
{ {
let _s = trace_span!("serializing DispatchEventTagged"); let _s = trace_span!("serializing DispatchEventTagged");
let kind = self.data.kind().name().unwrap(); let kind = self.0.kind().name().unwrap();
DispatchEventTaggedSerialized { DispatchEventTaggedSerialized {
data: serde_json::to_value(&self.data).unwrap(), data: serde_json::to_value(&self.0).unwrap(),
kind: kind.to_string(), kind: kind.to_string(),
} }
.serialize(serializer) .serialize(serializer)
@ -58,3 +67,40 @@ pub struct CachePayload {
#[serde(flatten)] #[serde(flatten)]
pub data: DispatchEventTagged, pub data: DispatchEventTagged,
} }
#[cfg(test)]
mod tests {
use serde_json::json;
use twilight_model::gateway::event::DispatchEvent;
use super::DispatchEventTagged;
#[test]
fn serialize_event_tagged() {
let dispatch_event = DispatchEvent::GiftCodeUpdate;
let value = serde_json::to_value(&dispatch_event);
assert!(value.is_ok());
let value = value.unwrap();
let kind = value.get("t").and_then(serde_json::Value::as_str);
assert_eq!(kind, Some("GIFT_CODE_UPDATE"));
}
#[test]
fn deserialize_event_tagged() {
let json = json!({
"t": "GIFT_CODE_UPDATE",
"d": {}
});
let dispatch_event = serde_json::from_value::<DispatchEventTagged>(json);
assert!(dispatch_event.is_ok());
let dispatch_event_tagged = dispatch_event.unwrap();
assert_eq!(
DispatchEventTagged(DispatchEvent::GiftCodeUpdate),
dispatch_event_tagged
);
}
}