diff --git a/Cargo.lock b/Cargo.lock index 8b42f1aa14b3a8558fcdce46e4ba1b3f3ca85523..4f8124f4696ac81c0893a0ff2181452de5b81007 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -171,6 +171,7 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sync_wrapper 1.0.2", + "tokio", "tower 0.5.2", "tower-layer", "tower-service", @@ -3628,6 +3629,7 @@ dependencies = [ "futures-util", "pin-project-lite", "sync_wrapper 1.0.2", + "tokio", "tower-layer", "tower-service", ] diff --git a/Cargo.toml b/Cargo.toml index a8c5918ca296eb6475af8cd891f20f773a955d44..8d13e6b6e2d0be669a521c1eb0d2355fdf37ef38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ axum = { version = "0.8", default-features = false, features = [ "http2", "json", "matched-path", + "tokio", ], optional = true } axum-extra = { version = "0.10", features = ["typed-header"] } axum-server = { version = "0.7", features = ["tls-rustls"] } diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index e922b1576f2166cd5a41e64e2564d86594a25ea1..721e4ef242bc6d47859fe8d90a623e27c4297a68 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -3,7 +3,13 @@ use std::time::Duration; -use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma}; +use crate::{ + service::{ + media::{size, FileMeta}, + rate_limiting::Target, + }, + services, utils, Error, Result, Ruma, +}; use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE}; use ruma::{ api::{ @@ -54,6 +60,8 @@ pub async fn get_media_config_auth_route( pub async fn create_content_route( body: Ruma, ) -> Result { + let sender_user = body.sender_user.expect("user is authenticated"); + let create_content::v3::Request { filename, content_type, @@ -61,6 +69,13 @@ pub async fn create_content_route( .. } = body.body; + let target = Target::from_client_request(body.appservice_info, &sender_user); + + services() + .rate_limiting + .check_media_upload(target, size(&file)?) 
+ .await?; + let media_id = utils::random_string(MXC_LENGTH); services() @@ -71,7 +86,7 @@ pub async fn create_content_route( filename.as_deref(), content_type.as_deref(), &file, - body.sender_user.as_deref(), + Some(&sender_user), ) .await?; @@ -84,7 +99,13 @@ pub async fn create_content_route( pub async fn get_remote_content( server_name: &ServerName, media_id: String, + target: Target, ) -> Result { + services() + .rate_limiting + .check_media_pre_fetch(&target) + .await?; + let content_response = match services() .sending .send_federation_request( @@ -153,6 +174,11 @@ pub async fn get_remote_content( ) .await?; + services() + .rate_limiting + .update_media_post_fetch(target, size(&content_response.file)?) + .await; + Ok(content_response) } @@ -171,11 +197,21 @@ pub async fn get_content_route( } = get_content( &body.server_name, body.media_id.clone(), - body.allow_remote, - false, + body.sender_ip_address.map(Target::Ip), ) .await?; + if let Some(target) = Target::from_client_request_optional_auth( + body.appservice_info, + &body.sender_user, + body.sender_ip_address, + ) { + services() + .rate_limiting + .update_media_post_fetch(target, size(&file)?) + .await; + } + Ok(media::get_content::v3::Response { file, content_type, @@ -190,14 +226,24 @@ pub async fn get_content_route( pub async fn get_content_auth_route( body: Ruma, ) -> Result { - get_content(&body.server_name, body.media_id.clone(), true, true).await + let Ruma:: { + body, + sender_user, + appservice_info, + .. 
+ } = body; + + let sender_user = sender_user.as_ref().expect("user is authenticated"); + + let target = Target::from_client_request(appservice_info, sender_user); + + get_content(&body.server_name, body.media_id.clone(), Some(target)).await } pub async fn get_content( server_name: &ServerName, media_id: String, - allow_remote: bool, - authenticated: bool, + target: Option, ) -> Result { services().media.check_blocked(server_name, &media_id)?; @@ -207,7 +253,7 @@ pub async fn get_content( file, })) = services() .media - .get(server_name, &media_id, authenticated) + .get(server_name, &media_id, target.clone()) .await { Ok(get_content::v1::Response { @@ -215,16 +261,25 @@ pub async fn get_content( content_type, content_disposition: Some(content_disposition), }) - } else if server_name != services().globals.server_name() && allow_remote && authenticated { - let remote_content_response = get_remote_content(server_name, media_id.clone()).await?; - - Ok(get_content::v1::Response { - content_disposition: remote_content_response.content_disposition, - content_type: remote_content_response.content_type, - file: remote_content_response.file, - }) } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); + + if let Some(target) = target { + if server_name != services().globals.server_name() && target.is_authenticated() { + let remote_content_response = + get_remote_content(server_name, media_id.clone(), target).await?; + + Ok(get_content::v1::Response { + content_disposition: remote_content_response.content_disposition, + content_type: remote_content_response.content_type, + file: remote_content_response.file, + }) + } else { + error + } + } else { + error + } } } @@ -244,8 +299,7 @@ pub async fn get_content_as_filename_route( &body.server_name, body.media_id.clone(), body.filename.clone(), - body.allow_remote, - false, + body.sender_ip_address.map(Target::Ip), ) .await?; @@ 
-263,12 +317,22 @@ pub async fn get_content_as_filename_route( pub async fn get_content_as_filename_auth_route( body: Ruma, ) -> Result { + let Ruma:: { + body, + sender_user, + appservice_info, + .. + } = body; + + let sender_user = sender_user.as_ref().expect("user is authenticated"); + + let target = Target::from_client_request(appservice_info, sender_user); + get_content_as_filename( &body.server_name, body.media_id.clone(), body.filename.clone(), - true, - true, + Some(target), ) .await } @@ -277,8 +341,7 @@ async fn get_content_as_filename( server_name: &ServerName, media_id: String, filename: String, - allow_remote: bool, - authenticated: bool, + target: Option, ) -> Result { services().media.check_blocked(server_name, &media_id)?; @@ -286,7 +349,7 @@ async fn get_content_as_filename( file, content_type, .. })) = services() .media - .get(server_name, &media_id, authenticated) + .get(server_name, &media_id, target.clone()) .await { Ok(get_content_as_filename::v1::Response { @@ -297,19 +360,28 @@ async fn get_content_as_filename( .with_filename(Some(filename.clone())), ), }) - } else if server_name != services().globals.server_name() && allow_remote && authenticated { - let remote_content_response = get_remote_content(server_name, media_id.clone()).await?; - - Ok(get_content_as_filename::v1::Response { - content_disposition: Some( - ContentDisposition::new(ContentDispositionType::Inline) - .with_filename(Some(filename.clone())), - ), - content_type: remote_content_response.content_type, - file: remote_content_response.file, - }) } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); + + if let Some(target) = target { + if server_name != services().globals.server_name() && target.is_authenticated() { + let remote_content_response = + get_remote_content(server_name, media_id.clone(), target).await?; + + Ok(get_content_as_filename::v1::Response { + 
content_disposition: Some( + ContentDisposition::new(ContentDispositionType::Inline) + .with_filename(Some(filename.clone())), + ), + content_type: remote_content_response.content_type, + file: remote_content_response.file, + }) + } else { + error + } + } else { + error + } } } @@ -321,6 +393,17 @@ async fn get_content_as_filename( pub async fn get_content_thumbnail_route( body: Ruma, ) -> Result { + let Ruma:: { + body, + sender_user, + sender_ip_address, + appservice_info, + .. + } = body; + + let target = + Target::from_client_request_optional_auth(appservice_info, &sender_user, sender_ip_address); + let get_content_thumbnail::v1::Response { file, content_type, @@ -332,8 +415,7 @@ pub async fn get_content_thumbnail_route( body.width, body.method.clone(), body.animated, - body.allow_remote, - false, + target, ) .await?; @@ -351,6 +433,15 @@ pub async fn get_content_thumbnail_route( pub async fn get_content_thumbnail_auth_route( body: Ruma, ) -> Result { + let Ruma:: { + body, + sender_user, + appservice_info, + .. + } = body; + let sender_user = sender_user.as_ref().expect("user is authenticated"); + let target = Target::from_client_request(appservice_info, sender_user); + get_content_thumbnail( &body.server_name, body.media_id.clone(), @@ -358,8 +449,7 @@ pub async fn get_content_thumbnail_auth_route( body.width, body.method.clone(), body.animated, - true, - true, + Some(target), ) .await } @@ -372,8 +462,7 @@ async fn get_content_thumbnail( width: UInt, method: Option, animated: Option, - allow_remote: bool, - authenticated: bool, + target: Option, ) -> Result { services().media.check_blocked(server_name, &media_id)?; @@ -392,7 +481,7 @@ async fn get_content_thumbnail( height .try_into() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?, - authenticated, + target.clone(), ) .await? 
{ @@ -401,99 +490,117 @@ async fn get_content_thumbnail( content_type, content_disposition: Some(content_disposition), }) - } else if server_name != services().globals.server_name() && allow_remote && authenticated { - let thumbnail_response = match services() - .sending - .send_federation_request( - server_name, - federation_media::get_content_thumbnail::v1::Request { - height, - width, - method: method.clone(), - media_id: media_id.clone(), - timeout_ms: Duration::from_secs(20), - animated, - }, - ) - .await - { - Ok(federation_media::get_content_thumbnail::v1::Response { - metadata: _, - content: FileOrLocation::File(content), - }) => get_content_thumbnail::v1::Response { - file: content.file, - content_type: content.content_type, - content_disposition: content.content_disposition, - }, + } else { + let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); - Ok(federation_media::get_content_thumbnail::v1::Response { - metadata: _, - content: FileOrLocation::Location(url), - }) => { - let get_content::v1::Response { - file, - content_type, - content_disposition, - } = get_location_content(url).await?; - - get_content_thumbnail::v1::Response { - file, - content_type, - content_disposition, - } - } - Err(Error::BadRequest(ErrorKind::Unrecognized, _)) => { - let media::get_content_thumbnail::v3::Response { - file, - content_type, - content_disposition, - .. 
- } = services() + if let Some(target) = target { + if server_name != services().globals.server_name() && target.is_authenticated() { + services() + .rate_limiting + .check_media_pre_fetch(&target) + .await?; + + let thumbnail_response = match services() .sending .send_federation_request( server_name, - media::get_content_thumbnail::v3::Request { + federation_media::get_content_thumbnail::v1::Request { height, width, method: method.clone(), - server_name: server_name.to_owned(), media_id: media_id.clone(), timeout_ms: Duration::from_secs(20), - allow_redirect: false, animated, - allow_remote: false, }, ) + .await + { + Ok(federation_media::get_content_thumbnail::v1::Response { + metadata: _, + content: FileOrLocation::File(content), + }) => get_content_thumbnail::v1::Response { + file: content.file, + content_type: content.content_type, + content_disposition: content.content_disposition, + }, + + Ok(federation_media::get_content_thumbnail::v1::Response { + metadata: _, + content: FileOrLocation::Location(url), + }) => { + let get_content::v1::Response { + file, + content_type, + content_disposition, + } = get_location_content(url).await?; + + get_content_thumbnail::v1::Response { + file, + content_type, + content_disposition, + } + } + Err(Error::BadRequest(ErrorKind::Unrecognized, _)) => { + let media::get_content_thumbnail::v3::Response { + file, + content_type, + content_disposition, + .. 
+ } = services() + .sending + .send_federation_request( + server_name, + media::get_content_thumbnail::v3::Request { + height, + width, + method: method.clone(), + server_name: server_name.to_owned(), + media_id: media_id.clone(), + timeout_ms: Duration::from_secs(20), + allow_redirect: false, + animated, + allow_remote: false, + }, + ) + .await?; + + get_content_thumbnail::v1::Response { + file, + content_type, + content_disposition, + } + } + Err(e) => return Err(e), + }; + + services() + .rate_limiting + .update_media_post_fetch(target, size(&thumbnail_response.file)?)
+ .await; + + services() + .media + .upload_thumbnail( + server_name, + &media_id, + thumbnail_response + .content_disposition + .as_ref() + .and_then(|cd| cd.filename.as_deref()), + thumbnail_response.content_type.as_deref(), + width.try_into().expect("all UInts are valid u32s"), + height.try_into().expect("all UInts are valid u32s"), + &thumbnail_response.file, + ) .await?; - get_content_thumbnail::v1::Response { - file, - content_type, - content_disposition, - } + Ok(thumbnail_response) + } else { + error } - Err(e) => return Err(e), - }; - - services() - .media - .upload_thumbnail( - server_name, - &media_id, - thumbnail_response - .content_disposition - .as_ref() - .and_then(|cd| cd.filename.as_deref()), - thumbnail_response.content_type.as_deref(), - width.try_into().expect("all UInts are valid u32s"), - height.try_into().expect("all UInts are valid u32s"), - &thumbnail_response.file, - ) - .await?; - - Ok(thumbnail_response) - } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + } else { + error + } } } diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 044565437d7978ce8789b9289a182d672ad64a09..7e1722b15c846b19359bb63eeaa7b9e75a9f3086 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -1,8 +1,14 @@ -use std::{collections::BTreeMap, error::Error as _, iter::FromIterator, str}; +use std::{ + collections::BTreeMap, + error::Error as _, + iter::FromIterator, + net::{IpAddr, SocketAddr}, + str::{self, FromStr}, +}; use axum::{ body::Body, - extract::{FromRequest, Path}, + extract::{ConnectInfo, FromRequest, Path}, response::{IntoResponse, Response}, RequestPartsExt, }; @@ -24,11 +30,16 @@ use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use crate::{ + config::IpAddrDetection, + service::{appservice::RegistrationInfo, rate_limiting::Target}, + services, Error, 
Result, +}; enum Token { Appservice(Box), User((OwnedUserId, OwnedDeviceId)), + AuthRateLimited(Error), Invalid, None, } @@ -99,8 +110,33 @@ where None => query_params.access_token.as_deref(), }; + let sender_ip_address: Option = + match &services().globals.config.ip_address_detection { + IpAddrDetection::None => None, + IpAddrDetection::Socket => { + let addr: ConnectInfo = parts.extract().await?; + Some(addr.ip()) + } + IpAddrDetection::Header(name) => parts + .headers + .get(name) + .and_then(|header| header.to_str().ok()) + .map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header)) + .and_then(|ip| IpAddr::from_str(ip).ok()), + }; + let token = if let Some(token) = token { - if let Some(reg_info) = services().appservice.find_from_token(token).await { + let mut rate_limited = None; + + if let Some(ip_addr) = sender_ip_address { + if let Err(instant) = services().rate_limiting.pre_auth_check(ip_addr).await { + rate_limited = Some(instant); + } + } + + if let Some(instant) = rate_limited { + Token::AuthRateLimited(instant) + } else if let Some(reg_info) = services().appservice.find_from_token(token).await { Token::Appservice(Box::new(reg_info.clone())) } else if let Some((user_id, device_id)) = services().users.find_from_token(token)? 
{ Token::User((user_id, device_id)) @@ -115,6 +151,23 @@ where let (sender_user, sender_device, sender_servername, appservice_info) = match (metadata.authentication, token) { + ( + AuthScheme::AccessToken + | AuthScheme::AppserviceToken + | AuthScheme::AccessTokenOptional + | AuthScheme::AppserviceTokenOptional, + Token::AuthRateLimited(instant), + ) => { + services() + .rate_limiting + .update_post_auth_failure( + sender_ip_address + .expect("Token variant could only be set if sender ip was Some"), + ) + .await; + + return Err(instant); + } (_, Token::Invalid) => { // OpenID endpoint uses a query param with the same name, drop this once query params for user auth are removed from the spec if query_params.access_token.is_some() { @@ -177,7 +230,7 @@ where AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None, Token::User((user_id, device_id)), ) => (Some(user_id), Some(device_id), None, None), - (AuthScheme::ServerSignatures, Token::None) => { + (AuthScheme::ServerSignatures, Token::None | Token::AuthRateLimited(_)) => { let TypedHeader(Authorization(x_matrix)) = parts .extract::>>() .await @@ -309,7 +362,8 @@ where | AuthScheme::AppserviceTokenOptional | AuthScheme::AccessTokenOptional, Token::None, - ) => (None, None, None, None), + ) + | (AuthScheme::None, Token::AuthRateLimited(_)) => (None, None, None, None), (AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => { return Err(Error::BadRequest( ErrorKind::Unauthorized, @@ -327,6 +381,16 @@ where } }; + let target = if let Some(server_name) = sender_servername.clone() { + Some(Target::Server(server_name)) + } else if let Some(user) = &sender_user { + Some(Target::from_client_request(appservice_info.clone(), user)) + } else { + sender_ip_address.map(Target::Ip) + }; + + services().rate_limiting.check(target, metadata).await?; + let mut http_request = Request::builder().uri(parts.uri).method(parts.method); *http_request.headers_mut().unwrap() = parts.headers; @@ -377,6 
+441,7 @@ where sender_servername, appservice_info, json_body, + sender_ip_address, }) } } diff --git a/src/api/ruma_wrapper/mod.rs b/src/api/ruma_wrapper/mod.rs index 862da1dcff7f382112e36d91d9b43cf3baaa479f..a741676cde7617a22ca53237ac790ac15f02d046 100644 --- a/src/api/ruma_wrapper/mod.rs +++ b/src/api/ruma_wrapper/mod.rs @@ -3,7 +3,7 @@ use ruma::{ api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, }; -use std::ops::Deref; +use std::{net::IpAddr, ops::Deref}; #[cfg(feature = "conduit_bin")] mod axum; @@ -14,6 +14,7 @@ pub struct Ruma { pub sender_user: Option, pub sender_device: Option, pub sender_servername: Option, + pub sender_ip_address: Option, // This is None when body is not a valid string pub json_body: Option, pub appservice_info: Option, diff --git a/src/api/server_server.rs b/src/api/server_server.rs index adc764ff11a9db6c6807b3d3a6a1bcc282111300..10f4fb8c5c62d54c18d5d3857aa1826a03644af9 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -6,6 +6,7 @@ use crate::{ globals::SigningKeys, media::FileMeta, pdu::{gen_event_id_canonical_json, PduBuilder}, + rate_limiting::Target, }, services, utils, Error, PduEvent, Result, Ruma, SUPPORTED_VERSIONS, }; @@ -2237,6 +2238,13 @@ pub async fn create_invite_route( pub async fn get_content_route( body: Ruma, ) -> Result { + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + let target = Some(Target::Server(sender_servername.to_owned())); + services() .media .check_blocked(services().globals.server_name(), &body.media_id)?; @@ -2247,7 +2255,11 @@ pub async fn get_content_route( file, }) = services() .media - .get(services().globals.server_name(), &body.media_id, true) + .get( + services().globals.server_name(), + &body.media_id, + target.clone(), + ) .await? 
{ Ok(get_content::v1::Response::new( @@ -2269,6 +2281,13 @@ pub async fn get_content_route( pub async fn get_content_thumbnail_route( body: Ruma, ) -> Result { + let Ruma:: { + body, + sender_servername, + .. + } = body; + let sender_servername = sender_servername.expect("server is authenticated"); + services() .media .check_blocked(services().globals.server_name(), &body.media_id)?; @@ -2288,7 +2307,7 @@ pub async fn get_content_thumbnail_route( body.height .try_into() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, - true, + Some(Target::Server(sender_servername)), ) .await? else { diff --git a/src/config/mod.rs b/src/config/mod.rs index 098dc20ddb59b9475baa7c747bbe5545e3312e4d..36c5631e79deb49d30165cbf48b5ea8d9c1671d0 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -17,7 +17,9 @@ use url::Url; use crate::Error; mod proxy; -use self::proxy::ProxyConfig; +pub mod rate_limiting; + +use self::{proxy::ProxyConfig, rate_limiting::Config as RateLimitingConfig}; const SHA256_HEX_LENGTH: u8 = 64; @@ -80,6 +82,8 @@ pub struct IncompleteConfig { pub trusted_servers: Vec, #[serde(default = "default_log")] pub log: String, + #[serde(default)] + pub ip_address_detection: IpAddrDetection, pub turn_username: Option, pub turn_password: Option, pub turn_uris: Option>, @@ -92,6 +96,8 @@ pub struct IncompleteConfig { #[serde(default)] pub media: IncompleteMediaConfig, + pub rate_limiting: RateLimitingConfig, + pub emergency_password: Option, #[serde(flatten)] @@ -133,11 +139,14 @@ pub struct Config { pub jwt_secret: Option, pub trusted_servers: Vec, pub log: String, + pub ip_address_detection: IpAddrDetection, pub turn: Option, pub media: MediaConfig, + pub rate_limiting: RateLimitingConfig, + pub emergency_password: Option, pub catchall: BTreeMap, @@ -177,6 +186,7 @@ impl From for Config { jwt_secret, trusted_servers, log, + ip_address_detection, turn_username, turn_password, turn_uris, @@ -184,6 +194,7 @@ impl From for Config { turn_ttl, 
turn, media, + rate_limiting, emergency_password, catchall, } = val; @@ -279,8 +290,10 @@ impl From for Config { jwt_secret, trusted_servers, log, + ip_address_detection, turn, media, + rate_limiting, emergency_password, catchall, } @@ -609,6 +622,19 @@ pub struct S3MediaBackend { pub directory_structure: DirectoryStructure, } +#[derive(Deserialize, Debug, Clone)] +pub enum IpAddrDetection { + None, + Header(String), + Socket, +} + +impl Default for IpAddrDetection { + fn default() -> Self { + Self::Header("X-Forwarded-For".to_owned()) + } +} + const DEPRECATED_KEYS: &[&str] = &[ "cache_capacity", "turn_username", diff --git a/src/config/rate_limiting.rs b/src/config/rate_limiting.rs new file mode 100644 index 0000000000000000000000000000000000000000..9ebb4f04494c5cc01dff06f7fb0b5909c3ea733d --- /dev/null +++ b/src/config/rate_limiting.rs @@ -0,0 +1,117 @@ +use std::{collections::HashMap, num::NonZeroU64}; + +use bytesize::ByteSize; +use serde::Deserialize; + +use crate::service::rate_limiting::{ClientRestriction, FederationRestriction, Restriction}; + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + #[serde(flatten)] + pub target: ConfigFragment, + pub global: ConfigFragment, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ConfigFragment { + pub client: ConfigClientFragment, + pub federation: ConfigFederationFragment, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ConfigClientFragment { + pub map: HashMap, + pub media: ClientMediaConfig, + // TODO: Only have available on target, not global (same with most authenticated endpoints too maybe)? 
+ pub authentication_failures: RequestLimitation, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ConfigFederationFragment { + pub map: HashMap, + pub media: FederationMediaConfig, +} + +impl ConfigFragment { + pub fn get(&self, restriction: &Restriction) -> &RequestLimitation { + // Maybe look into https://github.com/moriyoshi-kasuga/enum-table + match restriction { + Restriction::Client(client_restriction) => { + self.client.map.get(client_restriction).unwrap() + } + Restriction::Federation(federation_restriction) => { + self.federation.map.get(federation_restriction).unwrap() + } + } + } +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct RequestLimitation { + #[serde(flatten)] + pub timeframe: Timeframe, + pub burst_capacity: NonZeroU64, +} + +#[derive(Deserialize, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix +#[allow(clippy::enum_variant_names)] +pub enum Timeframe { + PerSecond(NonZeroU64), + PerMinute(NonZeroU64), + PerHour(NonZeroU64), + PerDay(NonZeroU64), +} + +impl Timeframe { + pub fn nano_gap(&self) -> u64 { + match self { + Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(), + Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(), + Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(), + Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(), + } + } +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct ClientMediaConfig { + pub download: MediaLimitation, + pub upload: MediaLimitation, + pub fetch: MediaLimitation, +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct FederationMediaConfig { + pub download: MediaLimitation, +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct MediaLimitation { + #[serde(flatten)] + pub timeframe: MediaTimeframe, + pub burst_capacity: ByteSize, +} + +#[derive(Deserialize, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix 
+#[allow(clippy::enum_variant_names)] +pub enum MediaTimeframe { + PerSecond(ByteSize), + PerMinute(ByteSize), + PerHour(ByteSize), + PerDay(ByteSize), +} + +impl MediaTimeframe { + pub fn bytes_per_sec(&self) -> u64 { + match self { + MediaTimeframe::PerSecond(t) => t.as_u64(), + MediaTimeframe::PerMinute(t) => t.as_u64() / 60, + MediaTimeframe::PerHour(t) => t.as_u64() / (60 * 60), + MediaTimeframe::PerDay(t) => t.as_u64() / (60 * 60 * 24), + } + } +} diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index 695c7d3ca16282902b8a606486797d0ae1eadcbd..831d6c3052c586d2756f34168156880eb79befbf 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -203,19 +203,7 @@ impl service::media::Data for KeyValueDatabase { let is_blocked_via_filehash = self.is_blocked_filehash(&sha256_digest)?; - let time_info = if let Some(filehash_meta) = self - .filehash_metadata - .get(&sha256_digest)? - .map(FilehashMetadata::from_vec) - { - Some(FileInfo { - creation: filehash_meta.creation(&sha256_digest)?, - last_access: filehash_meta.last_access(&sha256_digest)?, - size: filehash_meta.size(&sha256_digest)?, - }) - } else { - None - }; + let file_info = self.file_info(&sha256_digest)?; Some(MediaQueryFileInfo { uploader_localpart, @@ -224,7 +212,7 @@ impl service::media::Data for KeyValueDatabase { content_type, unauthenticated_access_permitted, is_blocked_via_filehash, - file_info: time_info, + file_info, }) } else { None @@ -1353,6 +1341,24 @@ impl service::media::Data for KeyValueDatabase { Ok(()) } } + + fn file_info(&self, sha256_digest: &[u8]) -> Result, Error> { + Ok( + if let Some(filehash_meta) = self + .filehash_metadata + .get(sha256_digest)? 
+ .map(FilehashMetadata::from_vec) + { + Some(FileInfo { + creation: filehash_meta.creation(sha256_digest)?, + last_access: filehash_meta.last_access(sha256_digest)?, + size: filehash_meta.size(sha256_digest)?, + }) + } else { + None + }, + ) + } } impl KeyValueDatabase { diff --git a/src/main.rs b/src/main.rs index 809852dd937045d27f3bd7347a2d8ab1d6806fde..543464e91282d6143ee46ddb15e5202b2dd79762 100644 --- a/src/main.rs +++ b/src/main.rs @@ -242,7 +242,9 @@ async fn run_server() -> io::Result<()> { ) .layer(map_response(set_csp_header)); - let app = routes(config).layer(middlewares).into_make_service(); + let app = routes(config) + .layer(middlewares) + .into_make_service_with_connect_info::(); let handle = ServerHandle::new(); tokio::spawn(shutdown_signal(handle.clone())); diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 0a3f87b95e9349d05cbdb65697963cc20d32109e..a7ef3d8361f125e8f06ad8b46f0e0eb9bd6133d9 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -41,6 +41,7 @@ use tokio::sync::{mpsc, Mutex, RwLock}; use crate::{ api::client_server::{self, leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH}, + service::rate_limiting::Target, services, utils::{self, HtmlEscape}, Error, PduEvent, Result, @@ -1174,8 +1175,12 @@ impl Service { file, content_type, content_disposition, - } = client_server::media::get_content(server_name, media_id.to_owned(), true, true) - .await?; + } = client_server::media::get_content( + server_name, + media_id.to_owned(), + Some(Target::User(services().globals.server_user().to_owned())), + ) + .await?; if let Ok(image) = image::load_from_memory(&file) { let filename = content_disposition.and_then(|cd| cd.filename); diff --git a/src/service/media/data.rs b/src/service/media/data.rs index 444f5f9a9cdabd123c7372ee78ec23da3eb1d0c3..539348374ce72dcf3c01434e3b9b4fe4f43c5a3d 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,7 +1,7 @@ use ruma::{OwnedServerName, ServerName, UserId}; use 
sha2::{digest::Output, Sha256}; -use crate::{config::MediaRetentionConfig, Error, Result}; +use crate::{config::MediaRetentionConfig, service::media::FileInfo, Error, Result}; use super::{ BlockedMediaInfo, DbFileMeta, MediaListItem, MediaQuery, MediaType, ServerNameOrUserId, @@ -124,4 +124,7 @@ pub trait Data: Send + Sync { fn update_last_accessed(&self, server_name: &ServerName, media_id: &str) -> Result<()>; fn update_last_accessed_filehash(&self, sha256_digest: &[u8]) -> Result<()>; + + /// Returns the known information about a file + fn file_info(&self, sha256_digest: &[u8]) -> Result>; } diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 2f5c814df6f3ae3af42272c2ae1af212fdf8b827..5d3283adf6fb17acad21c11db5bb0a2fd5ef2663 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -17,6 +17,7 @@ use tracing::{error, info, warn}; use crate::{ config::{DirectoryStructure, MediaBackendConfig, S3MediaBackend}, + service::rate_limiting::Target, services, utils, Error, Result, }; use image::imageops::FilterType; @@ -237,7 +238,7 @@ impl Service { &self, servername: &ServerName, media_id: &str, - authenticated: bool, + target: Option, ) -> Result> { let DbFileMeta { sha256_digest, @@ -246,12 +247,19 @@ impl Service { unauthenticated_access_permitted, } = self.db.search_file_metadata(servername, media_id)?; - if !(authenticated || unauthenticated_access_permitted) { + if !(target.as_ref().is_some_and(Target::is_authenticated) + || unauthenticated_access_permitted) + { return Ok(None); } let file = self.get_file(&sha256_digest, None).await?; + services() + .rate_limiting + .check_media_download(target, size(&file)?) 
+ .await?; + Ok(Some(FileMeta { content_disposition: content_disposition(filename, &content_type), content_type, @@ -288,7 +296,7 @@ impl Service { media_id: &str, width: u32, height: u32, - authenticated: bool, + target: Option, ) -> Result> { if let Some((width, height, crop)) = self.thumbnail_properties(width, height) { if let Ok(DbFileMeta { @@ -300,10 +308,19 @@ impl Service { .db .search_thumbnail_metadata(servername, media_id, width, height) { - if !(authenticated || unauthenticated_access_permitted) { + if !(target.as_ref().is_some_and(Target::is_authenticated) + || unauthenticated_access_permitted) + { return Ok(None); } + let file_info = self.file_info(&sha256_digest)?; + + services() + .rate_limiting + .check_media_download(target, file_info.size) + .await?; + // Using saved thumbnail let file = self .get_file(&sha256_digest, Some((servername, media_id))) @@ -314,19 +331,15 @@ impl Service { content_type, file, })) - } else if !authenticated { + } else if !target.as_ref().is_some_and(Target::is_authenticated) { return Ok(None); } else if let Ok(DbFileMeta { sha256_digest, filename, content_type, - unauthenticated_access_permitted, + .. 
}) = self.db.search_file_metadata(servername, media_id) { - if !(authenticated || unauthenticated_access_permitted) { - return Ok(None); - } - let content_disposition = content_disposition(filename.clone(), &content_type); // Generate a thumbnail let file = self.get_file(&sha256_digest, None).await?; @@ -426,7 +439,9 @@ impl Service { return Ok(None); }; - if !(authenticated || unauthenticated_access_permitted) { + if !(target.as_ref().is_some_and(Target::is_authenticated) + || unauthenticated_access_permitted) + { return Ok(None); } @@ -662,6 +677,13 @@ impl Service { .update_last_accessed_filehash(sha256_digest) .map(|_| file) } + + fn file_info(&self, sha256_digest: &[u8]) -> Result { + self.db + .file_info(sha256_digest) + .transpose() + .unwrap_or_else(|| Err(Error::BadRequest(ErrorKind::NotFound, "File not found"))) + } } /// Creates the media file, using the configured media backend diff --git a/src/service/mod.rs b/src/service/mod.rs index 432c0e7acab5d457748e71d8855745f753d992fd..6c511391517f90c29b46c0df11f1082b0fafa28c 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -17,6 +17,7 @@ pub mod key_backups; pub mod media; pub mod pdu; pub mod pusher; +pub mod rate_limiting; pub mod rooms; pub mod sending; pub mod transaction_ids; @@ -36,6 +37,7 @@ pub struct Services { pub key_backups: key_backups::Service, pub media: Arc, pub sending: Arc, + pub rate_limiting: Arc, } impl Services { @@ -123,6 +125,8 @@ impl Services { media: Arc::new(media::Service { db }), sending: sending::Service::build(db, &config), + rate_limiting: rate_limiting::Service::build(&config), + globals: globals::Service::load(db, config)?, }) } diff --git a/src/service/rate_limiting/mod.rs b/src/service/rate_limiting/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..bf393067275f94578db43c7543b3ad85655ac532 --- /dev/null +++ b/src/service/rate_limiting/mod.rs @@ -0,0 +1,713 @@ +use std::{ + collections::{hash_map::Entry, HashMap}, + net::IpAddr, +
sync::Arc, + time::Duration, +}; + +use ruma::{ + api::{ + client::error::{ErrorKind, RetryAfter}, + federation::membership::create_knock_event, + Metadata, + }, + OwnedServerName, OwnedUserId, UserId, +}; +use serde::Deserialize; +use tokio::{ + sync::{Mutex, MutexGuard, RwLock}, + time::Instant, +}; + +use crate::{ + config::rate_limiting::{MediaLimitation, RequestLimitation}, + service::appservice::RegistrationInfo, + services, Config, Error, Result, +}; + +/// The entity whose requests are counted against a rate-limit bucket. +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Target { + User(OwnedUserId), + // Server endpoints should be rate-limited on a server and room basis + Server(OwnedServerName), + Appservice { id: String, rate_limited: bool }, + // The client's IP address, used when no authenticated user is known + Ip(IpAddr), +} + +impl Target { + /// Returns the target for an authenticated client-server request: the appservice the + /// request was made through, if any, otherwise the sending user. + pub fn from_client_request( + registration_info: Option, + sender_user: &UserId, + ) -> Self { + if let Some(info) = registration_info { + // `rate_limited` only affects "masqueraded users", "The sender [user?] is excluded" + return Target::Appservice { + id: info.registration.id, + rate_limited: info.registration.rate_limited.unwrap_or(true) + && !(sender_user.server_name() == services().globals.server_name() + && info.registration.sender_localpart == sender_user.localpart()), + }; + } + + Target::User(sender_user.to_owned()) + } + + /// Like `from_client_request`, but falls back to the client's IP address when no user is + /// authenticated. Returns `None` if neither a user nor an IP address is known. + pub fn from_client_request_optional_auth( + registration_info: Option, + sender_user: &Option, + ip_addr: Option, + ) -> Option { + if let Some(sender_user) = sender_user.as_ref() { + Some(Self::from_client_request(registration_info, sender_user)) + } else { + ip_addr.map(Self::Ip) + } + } + + /// Whether this target is subject to rate limiting at all. The server user and appservice + /// targets whose `rate_limited` flag is false are exempt. + fn rate_limited(&self) -> bool { + match self { + Target::User(user_id) => user_id != services().globals.server_user(), + Target::Appservice { + id: _, + rate_limited, + } => *rate_limited, + _ => true, + } + } + + /// Whether the target identifies an authenticated requester, i.e. anything but a bare IP + /// address. + pub fn is_authenticated(&self) -> bool { + !matches!(self, Target::Ip(_)) + } +} + +/// The rate-limit category of an endpoint, split by API. +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Restriction { +
Client(ClientRestriction), + Federation(FederationRestriction), +} + +/// Endpoint categories of the client-server API, used to key rate-limit buckets. +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "snake_case")] +pub enum ClientRestriction { + Registration, + Login, + RegistrationTokenValidity, + + SendEvent, + + Join, + Invite, + Knock, + + SendReport, + CreateAlias, + + MediaDownload, + MediaCreate, +} + +/// Endpoint categories of the server-server (federation) API, used to key rate-limit buckets. +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "snake_case")] +pub enum FederationRestriction { + Join, + Knock, + Invite, + + // Transactions should be handled by a completely dedicated rate-limiter + Transaction, + + MediaDownload, +} + +impl TryFrom for Restriction { + type Error = (); + + /// Maps an endpoint's metadata to its restriction; endpoints that aren't rate-limited + /// yield `Err(())`. + fn try_from(value: Metadata) -> Result { + use ruma::api::{ + client::{ + account::{check_registration_token_validity, register}, + alias::create_alias, + authenticated_media::{ + get_content, get_content_as_filename, get_content_thumbnail, get_media_preview, + }, + knock::knock_room, + media::{self, create_content, create_mxc_uri}, + membership::{invite_user, join_room_by_id, join_room_by_id_or_alias}, + message::send_message_event, + reporting::report_user, + room::{report_content, report_room}, + session::login, + state::send_state_event, + }, + federation::{ + authenticated_media::{ + get_content as federation_get_content, + get_content_thumbnail as federation_get_content_thumbnail, + }, + membership::{create_invite, create_join_event}, + }, + IncomingRequest, + }; + use Restriction::*; + + Ok(match value { + register::v3::Request::METADATA => Client(ClientRestriction::Registration), + check_registration_token_validity::v1::Request::METADATA => { + Client(ClientRestriction::RegistrationTokenValidity) + } + login::v3::Request::METADATA => Client(ClientRestriction::Login), + send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => { + Client(ClientRestriction::SendEvent) + } +
join_room_by_id::v3::Request::METADATA + | join_room_by_id_or_alias::v3::Request::METADATA => Client(ClientRestriction::Join), + invite_user::v3::Request::METADATA => Client(ClientRestriction::Invite), + knock_room::v3::Request::METADATA => Client(ClientRestriction::Knock), + report_user::v3::Request::METADATA + | report_content::v3::Request::METADATA + | report_room::v3::Request::METADATA => Client(ClientRestriction::SendReport), + create_alias::v3::Request::METADATA => Client(ClientRestriction::CreateAlias), + // NOTE: handle async media upload in a way that doesn't halve the number of uploads you can do within a short timeframe, while not allowing pre-generation of MXC URIs to allow uploading double the number of media at once + create_content::v3::Request::METADATA | create_mxc_uri::v1::Request::METADATA => { + Client(ClientRestriction::MediaCreate) + } + // Unauthenticated media is deprecated + #[allow(deprecated)] + media::get_content::v3::Request::METADATA + | media::get_content_as_filename::v3::Request::METADATA + | media::get_content_thumbnail::v3::Request::METADATA + | media::get_media_preview::v3::Request::METADATA + | get_content::v1::Request::METADATA + | get_content_as_filename::v1::Request::METADATA + | get_content_thumbnail::v1::Request::METADATA + | get_media_preview::v1::Request::METADATA => Client(ClientRestriction::MediaDownload), + federation_get_content::v1::Request::METADATA + | federation_get_content_thumbnail::v1::Request::METADATA => { + Federation(FederationRestriction::MediaDownload) + } + // v1 is deprecated + #[allow(deprecated)] + create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => { + Federation(FederationRestriction::Join) + } + create_knock_event::v1::Request::METADATA => Federation(FederationRestriction::Knock), + create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => { + Federation(FederationRestriction::Invite) + } + + _ => return Err(()), + }) + } +} + +/// Per-target buckets for one kind of media transfer. +type MediaBucket = 
Mutex>>>; +/// A single bucket shared by all targets for one kind of media transfer. +type GlobalMediaBucket = Arc>; + +/// Rate-limiting state. Each bucket is an `Instant` that accepted requests push forward; a +/// request is rejected when that instant lies, or would lie, in the future. +pub struct Service { + // Request buckets per (target, restriction) pair + buckets: Mutex>>>, + // Request buckets per restriction, shared by all targets + global_bucket: Mutex>>>, + + media_upload: MediaBucket, + media_fetch: MediaBucket, + media_download: MediaBucket, + + global_media_upload: GlobalMediaBucket, + global_media_fetch: GlobalMediaBucket, + global_media_download_client: GlobalMediaBucket, + global_media_download_federation: GlobalMediaBucket, + + // Per-IP buckets, advanced on each failed access-token authentication + authentication_failures: RwLock>>>, +} + +impl Service { + /// Creates the service with empty per-target maps and global media buckets seeded (via + /// `default_media_entry`) from the global media limits in `config`. + pub fn build(config: &Config) -> Arc { + let now = Instant::now(); + let global_media_config = &config.rate_limiting.global; + + Arc::new(Self { + buckets: Mutex::new(HashMap::new()), + global_bucket: Mutex::new(HashMap::new()), + + media_upload: Mutex::new(HashMap::new()), + media_fetch: Mutex::new(HashMap::new()), + media_download: Mutex::new(HashMap::new()), + + global_media_upload: default_media_entry(global_media_config.client.media.upload, now), + global_media_fetch: default_media_entry(global_media_config.client.media.fetch, now), + global_media_download_client: default_media_entry( + global_media_config.client.media.download, + now, + ), + global_media_download_federation: default_media_entry( + global_media_config.federation.media.download, + now, + ), + + authentication_failures: RwLock::new(HashMap::new()), + }) + } + + //TODO: use checked and saturating arithmetic + + /// Takes the target and request, and either accepts the request while adding to the + /// bucket, or rejects it with a `LimitExceeded` error whose `retry_after` is the duration + /// to wait before retrying. Endpoints without an associated `Restriction` are always + /// accepted. +
+ pub async fn check(&self, target: Option, request: Metadata) -> Result<()> { + let Ok(restriction) = request.try_into() else { + // Endpoint has no associated restriction + return Ok(()); + }; + let arrival = Instant::now(); + + // Reject early when the global bucket is already exhausted, before the target's own + // bucket gets charged. + { + let map = self.global_bucket.lock().await; + + if let Some(value) = map.get(&restriction) { + let value = value.lock().await; + + if arrival.checked_duration_since(*value).is_none() { + instant_to_err(&value)?; + } + } + } + + // The server user and appservices exempt from rate limiting skip their own bucket. + if let Some(target) = target.filter(Target::rate_limited) { + let config = services() + .globals + .config + .rate_limiting + .target + .get(&restriction); + + let mut map = self.buckets.lock().await; + let entry = map.entry((target, restriction)); + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.lock().await; + + if arrival.checked_duration_since(*entry).is_none() { + return instant_to_err(&entry); + } + + // Requests are accepted while the bucket is not in the future, so an idle + // bucket may lag at most `burst_capacity - 1` gaps behind `arrival`; that + // admits exactly `burst_capacity` back-to-back requests. + let min_instant = arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), + ); + *entry = + entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap()); + } + Entry::Vacant(vacant_entry) => { + // Charge the first request exactly like one arriving at an idle bucket. + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), + ) + + Duration::from_nanos(config.timeframe.nano_gap()), + ))); + } + } + } + + { + let config = services() + .globals + .config + .rate_limiting + .global + .get(&restriction); + + let mut map = self.global_bucket.lock().await; + + let entry = map.entry(restriction); + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.lock().await; + + if arrival.checked_duration_since(*entry).is_none() { + return instant_to_err(&entry); + } + + // Same burst accounting as the per-target bucket above. + let min_instant = arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), + ); + *entry = + entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap()); + } +
Entry::Vacant(vacant_entry) => { + // Charge the first request exactly like one arriving at an idle bucket. + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), + ) + + Duration::from_nanos(config.timeframe.nano_gap()), + ))); + } + } + } + + Ok(()) + } + + /// Charges a media download of `size` bytes against the target's and the global download + /// buckets. Servers use the federation limits, every other target the client ones. + pub async fn check_media_download(&self, target: Option, size: u64) -> Result<()> { + // All targets besides servers use the client-server API + let (target_limitation, global_limitation, global_bucket) = + if let Some(Target::Server(_)) = &target { + ( + services() + .globals + .config + .rate_limiting + .target + .federation + .media + .download, + services() + .globals + .config + .rate_limiting + .global + .federation + .media + .download, + &self.global_media_download_federation, + ) + } else { + ( + services() + .globals + .config + .rate_limiting + .target + .client + .media + .download, + services() + .globals + .config + .rate_limiting + .global + .client + .media + .download, + &self.global_media_download_client, + ) + }; + + check_media( + target, + size, + target_limitation, + global_limitation, + &self.media_download, + global_bucket, + ) + .await + } + + /// Charges a media upload of `size` bytes against the target's and the global upload + /// buckets. + pub async fn check_media_upload(&self, target: Target, size: u64) -> Result<()> { + let target_limitation = services() + .globals + .config + .rate_limiting + .target + // Media can only be uploaded on the client-server API + .client + .media + .upload; + + let global_limitation = services() + .globals + .config + .rate_limiting + .global + // Media can only be uploaded on the client-server API + .client + .media + .upload; + + check_media( + Some(target), + size, + target_limitation, + global_limitation, + &self.media_upload, + &self.global_media_upload, + ) + .await + } + + /// Rejects a remote media fetch up front if the target's or the global fetch or download + /// bucket is already exhausted. The fetched size is not known yet; callers charge it + /// afterwards via `update_media_post_fetch`. + pub async fn check_media_pre_fetch(&self, target: &Target) -> Result<()> { + if !target.rate_limited() { + return Ok(()); + } + + let arrival = Instant::now(); + + let check = async |map: &MediaBucket, global_bucket: &GlobalMediaBucket| { + let map = map.lock().await; + if let Some(mutex) = map.get(target) { +
let mutex = mutex.lock().await; + + if arrival.checked_duration_since(*mutex).is_none() { + return instant_to_err(&mutex); + } + } + // Release the per-target map before locking the global bucket: `check_media` takes + // these two locks in the opposite order, so holding both here could deadlock. + drop(map); + + let global_bucket = global_bucket.lock().await; + + if arrival.checked_duration_since(*global_bucket).is_none() { + return instant_to_err(&global_bucket); + } + + Ok(()) + }; + + // checking fetch + check(&self.media_fetch, &self.global_media_fetch).await?; + + // checking download as well + check(&self.media_download, &self.global_media_download_client).await + } + + /// Checks whether the IP address has been rate-limited due to too many bad access tokens being sent. + pub async fn pre_auth_check(&self, ip_addr: IpAddr) -> Result<()> { + let arrival = Instant::now(); + + if let Some(instant) = self.authentication_failures.read().await.get(&ip_addr) { + let instant = instant.read().await; + + if arrival.checked_duration_since(*instant).is_none() { + return instant_to_err(&instant); + } + } + + Ok(()) + } + + /// Updates the bad-auth rate limiter when a bad access token is sent where access token auth is an option. +
+ pub async fn update_post_auth_failure(&self, ip_addr: IpAddr) { + let arrival = Instant::now(); + + let RequestLimitation { + timeframe, + burst_capacity, + } = services() + .globals + .config + .rate_limiting + .target + .client + .authentication_failures; + + let gap = Duration::from_nanos(timeframe.nano_gap()); + // `pre_auth_check` admits attempts while the bucket is not in the future, so an idle + // bucket may lag at most `burst_capacity - 1` gaps behind now; this allows exactly + // `burst_capacity` failures before the address is locked out. + let max_credit = Duration::from_nanos(timeframe.nano_gap() * (burst_capacity.get() - 1)); + + let mut map = self.authentication_failures.write().await; + let entry = map.entry(ip_addr); + + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.write().await; + + *entry = entry.max(arrival - max_credit) + gap; + } + Entry::Vacant(vacant_entry) => { + // Charge the first failure exactly like one hitting an idle bucket. + vacant_entry.insert(Arc::new(RwLock::new(arrival - max_credit + gap))); + } + } + } + + /// Charges `size` bytes of fetched remote media against the target's and the global fetch + /// and download buckets. Never rejects, because the fetch has already happened; an + /// overdrawn bucket makes later requests wait instead. + pub async fn update_media_post_fetch(&self, target: Target, size: u64) { + if !target.rate_limited() { + return; + } + + let arrival = Instant::now(); + + let update = async |map: &MediaBucket, + target_limitation: &MediaLimitation, + global_bucket: &GlobalMediaBucket, + global_limitation: &MediaLimitation| { + // Hold the map lock only while fetching (or creating) the target's entry: + // `check_media` locks the global bucket before this map, so holding both here + // could deadlock. New entries are seeded via `default_media_entry` and then + // charged below like existing ones. + let entry = { + let mut map = map.lock().await; + Arc::clone( + map.entry(target.clone()) + .or_insert_with(|| default_media_entry(*target_limitation, arrival)), + ) + }; + + let _ = update_media_entry(size, target_limitation, &arrival, entry, false).await; + + let _ = update_media_entry( + size, + global_limitation, + &arrival, + Arc::clone(global_bucket), + false, + ) + .await; + }; + + // updating fetch + update( + &self.media_fetch, + &services() +
.globals + .config + .rate_limiting + .global + .client + .media + .fetch, + ) + .await; + + // updating download as well + update( + &self.media_download, + &services() + .globals + .config + .rate_limiting + .target + .client + .media + .download, + &self.global_media_download_client, + &services() + .globals + .config + .rate_limiting + .global + .client + .media + .download, + ) + .await; + } +} + +/// Charges `size` bytes against the bucket `entry`. With `and_check`, fails with +/// `LimitExceeded` (leaving the bucket untouched) when the bucket cannot afford it. +async fn update_media_entry( + size: u64, + limitation: &MediaLimitation, + arrival: &Instant, + entry: Arc>, + and_check: bool, +) -> Result<()> { + let mut entry = entry.lock().await; + + let proposed_entry = get_proposed_entry(size, limitation, arrival, &entry, and_check)?; + + *entry = proposed_entry; + + Ok(()) +} + +/// Computes the bucket's new instant after charging `size` bytes at `arrival`. With +/// `and_check`, returns a `LimitExceeded` error if that instant would lie in the future. +fn get_proposed_entry( + size: u64, + limitation: &MediaLimitation, + arrival: &Instant, + entry: &MutexGuard<'_, Instant>, + and_check: bool, +) -> Result { + // Time needed to "earn" `bytes` at the configured rate. Sub-second precision matters: + // whole seconds would make any transfer smaller than one second's worth of bytes free. + let bytes_per_sec = u64::from(limitation.timeframe.bytes_per_sec()) as f64; + let bytes_to_duration = |bytes: u64| Duration::from_secs_f64(bytes as f64 / bytes_per_sec); + + // An idle bucket holds at most `burst_capacity` bytes of credit. `checked_sub` avoids a + // panic when the clock cannot reach back that far (e.g. shortly after boot). + let min_instant = arrival + .checked_sub(bytes_to_duration(limitation.burst_capacity.as_u64())) + .unwrap_or(*arrival); + + let proposed_entry = entry.max(min_instant) + bytes_to_duration(size); + + if and_check && arrival.checked_duration_since(proposed_entry).is_none() { + return instant_to_err(&proposed_entry).map(|_| proposed_entry); + } + + Ok(proposed_entry) +} + +/// Charges `size` bytes against the global bucket and the target's bucket in `target_map`, +/// rejecting the request if either cannot afford it, in which case neither is charged. +/// Requests without a target, and exempt targets, are always accepted. +async fn check_media( + target: Option, + size: u64, + target_limitation: MediaLimitation, + global_limitation: MediaLimitation, + target_map: &MediaBucket, + global_bucket: &GlobalMediaBucket, +) -> Result<()> { + if !target.as_ref().is_some_and(Target::rate_limited) { + return Ok(()); + } + + let arrival = Instant::now(); + + // Lock order: the global bucket first, then the per-target map. + let mut global_bucket = global_bucket.lock().await; + let proposed = get_proposed_entry(size, &global_limitation, &arrival, &global_bucket, true)?; + + if let Some(target) = target { + let mut map = target_map.lock().await; + let entry = map.entry(target); + + match entry { +
Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + + update_media_entry(size, &target_limitation, &arrival, entry, true).await?; + } + Entry::Vacant(vacant_entry) => { + // Seed a new target with a full bucket, then charge (and check) this request + // against it exactly as for an existing target. + let entry = + Arc::clone(vacant_entry.insert(default_media_entry(target_limitation, arrival))); + + update_media_entry(size, &target_limitation, &arrival, entry, true).await?; + } + } + } + + *global_bucket = proposed; + + Ok(()) +} + +/// Creates a bucket holding a full `burst_capacity` of credit at `arrival`, or as much as the +/// clock allows (e.g. shortly after boot). +fn default_media_entry( + target_limitation: MediaLimitation, + arrival: Instant, +) -> Arc> { + // The burst capacity is a byte count: convert it to time at the configured rate rather + // than reading the quotient as nanoseconds. + let burst = Duration::from_secs_f64( + target_limitation.burst_capacity.as_u64() as f64 + / u64::from(target_limitation.timeframe.bytes_per_sec()) as f64, + ); + + Arc::new(Mutex::new(arrival.checked_sub(burst).unwrap_or(arrival))) +} + +/// Builds the `LimitExceeded` error telling the client to retry once `instant` is reached. +/// Always returns `Err`. +fn instant_to_err(instant: &Instant) -> Result<()> { + let now = Instant::now(); + + Err(Error::BadRequest( + ErrorKind::LimitExceeded { + // Not using ::DateTime because conversion from Instant to SystemTime is convoluted + retry_after: Some(RetryAfter::Delay(instant.duration_since(now))), + }, + "Rate limit exceeded", + )) +}