From 59b6a5e45297a25eaaf0379c5e2a067ab2494ac0 Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Fri, 9 Apr 2021 13:01:39 -0600 Subject: [PATCH 1/2] add tor support --- Cargo.lock | 13 +++++++++++ Cargo.toml | 3 ++- src/database.rs | 37 ++++++++++++++++++++++++++++--- src/database/globals.rs | 49 ++++++++++++++++++++++++++++++++++------- src/server_server.rs | 43 +++++++++++++++++++----------------- src/utils.rs | 28 +++++++++++++++++++++++ 6 files changed, 141 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 51ccff7b..6cbc2173 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1508,6 +1508,7 @@ dependencies = [ "serde_urlencoded", "tokio", "tokio-native-tls", + "tokio-socks", "url", "wasm-bindgen", "wasm-bindgen-futures", @@ -2336,6 +2337,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-socks" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51165dfa029d2a65969413a6cc96f354b86b464498702f174a4efa13608fd8c0" +dependencies = [ + "either", + "futures-util", + "thiserror", + "tokio", +] + [[package]] name = "tokio-util" version = "0.6.3" diff --git a/Cargo.toml b/Cargo.toml index 8addf501..6fdeff48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,9 +71,10 @@ tracing-opentelemetry = "0.11.0" opentelemetry-jaeger = "0.11.0" [features] -default = ["conduit_bin"] +default = ["conduit_bin", "tor"] conduit_bin = [] # TODO: add rocket to this when it is optional tls_vendored = ["reqwest/native-tls-vendored"] +tor = ["reqwest/socks"] [[bin]] name = "conduit" diff --git a/src/database.rs b/src/database.rs index 6dc9c70e..4349c65f 100644 --- a/src/database.rs +++ b/src/database.rs @@ -17,9 +17,11 @@ use log::info; use rocket::futures::{self, channel::mpsc}; use ruma::{DeviceId, ServerName, UserId}; use serde::Deserialize; -use std::collections::HashMap; -use std::fs::remove_dir_all; -use std::sync::{Arc, RwLock}; +use std::{ + collections::HashMap, + fs::remove_dir_all, + sync::{Arc, RwLock}, +}; use tokio::sync::Semaphore; #[derive(Clone, Deserialize)] @@ -40,6 +42,10 @@ pub struct Config { allow_federation: bool, #[serde(default = "false_fn")] pub allow_jaeger: bool, + #[cfg(feature = "tor")] + #[serde(default)] + #[serde(flatten)] + tor_federation: TorFederation, jwt_secret: Option, } @@ -63,6 +69,31 @@ fn default_max_concurrent_requests() -> u16 { 4 } +#[cfg(feature = "tor")] +#[derive(Clone, Deserialize)] +#[serde(rename = "snake_case")] +#[serde(tag = "tor_federation")] +pub enum TorFederation { + Disabled, + Enabled { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + tor_proxy: reqwest::Url, + tor_only: bool, + }, +} +#[cfg(feature = "tor")] +impl TorFederation { + pub fn enabled(&self) -> bool { + matches!(self, &TorFederation::Enabled { .. }) + } +} +#[cfg(feature = "tor")] +impl Default for TorFederation { + fn default() -> Self { + TorFederation::Disabled + } +} + #[derive(Clone)] pub struct Database { pub globals: globals::Globals, diff --git a/src/database/globals.rs b/src/database/globals.rs index 6004c10a..7075dbeb 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -1,10 +1,11 @@ use crate::{database::Config, utils, Error, Result}; use log::error; use ruma::ServerName; -use std::collections::HashMap; -use std::sync::Arc; -use std::sync::RwLock; -use std::time::Duration; +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, + time::Duration, +}; use trust_dns_resolver::TokioAsyncResolver; pub const COUNTER: &str = "c"; @@ -57,12 +58,39 @@ impl Globals { } }; - let reqwest_client = reqwest::Client::builder() + let mut reqwest_client = reqwest::Client::builder(); + reqwest_client = reqwest_client .connect_timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(60 * 3)) - .pool_max_idle_per_host(1) - .build() - .unwrap(); + .pool_max_idle_per_host(1); + #[cfg(feature = "tor")] + { + use crate::database::TorFederation; + + if let TorFederation::Enabled { + tor_proxy, + tor_only, + } = config.tor_federation.clone() + { + let proxy = if tor_only { + reqwest::Proxy::all(tor_proxy).unwrap() + } else { + reqwest::Proxy::custom(move |url| { + if url + .host_str() + .map_or(false, |host| host.ends_with(".onion")) + { + Some(tor_proxy.clone()) + } else { + None + } + }) + }; + reqwest_client = reqwest_client.proxy(proxy); + } + } + + let reqwest_client = reqwest_client.build().unwrap(); let jwt_decoding_key = config .jwt_secret @@ -129,6 +157,11 @@ impl Globals { self.config.allow_federation } + #[cfg(feature = "tor")] + pub fn tor_federation_enabled(&self) -> bool { + self.config.tor_federation.enabled() + } + pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.dns_resolver } diff --git a/src/server_server.rs b/src/server_server.rs index b9b8e316..fb0a4bfe 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -3,6 +3,8 @@ use get_profile_information::v1::ProfileField; use http::header::{HeaderValue, AUTHORIZATION, HOST}; use log::{info, warn}; use regex::Regex; +#[cfg(feature = "conduit_bin")] +use rocket::{get, post, put}; use rocket::{response::content::Json, State}; use ruma::{ api::{ @@ -29,10 +31,6 @@ use std::{ net::{IpAddr, SocketAddr}, time::{Duration, SystemTime}, }; -#[cfg(feature = "conduit_bin")] -use { - rocket::{get, post, put} -}; #[tracing::instrument(skip(globals))] pub async fn send_request( @@ -231,7 +229,15 @@ async fn find_actual_destination( let mut host = None; let destination_str = destination.as_str().to_owned(); - let actual_destination = "https://".to_owned() + #[cfg(not(feature = "tor"))] + let protocol = "https://"; + #[cfg(feature = "tor")] + let protocol = if globals.tor_federation_enabled() && destination_str.ends_with(".onion") { + "http://" + } else { + "https://" + }; + let actual_destination = protocol.to_owned() + &match get_ip_with_port(destination_str.clone()) { Some(host_port) => { // 1: IP literal with provided or default port @@ -600,21 +606,18 @@ pub async fn send_transaction_message_route<'a>( let users = namespaces .get("users") .and_then(|users| users.as_sequence()) - .map_or_else( - Vec::new, - |users| { - users - .iter() - .map(|users| { - users - .get("regex") - .and_then(|regex| regex.as_str()) - .and_then(|regex| Regex::new(regex).ok()) - }) - .filter_map(|o| o) - .collect::>() - }, - ); + .map_or_else(Vec::new, |users| { + users + .iter() + .map(|users| { + users + .get("regex") + .and_then(|regex| regex.as_str()) + .and_then(|regex| Regex::new(regex).ok()) + }) + .filter_map(|o| o) + .collect::>() + }); let aliases = namespaces .get("aliases") .and_then(|users| users.get("regex")) diff --git a/src/utils.rs b/src/utils.rs index 0783567e..81769fe1 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -6,6 +6,7 @@ use sled::IVec; use std::{ cmp, convert::TryInto, + str::FromStr, time::{SystemTime, UNIX_EPOCH}, }; @@ -112,3 +113,30 @@ pub fn to_canonical_object( ))), } } + +#[allow(dead_code)] +pub fn deserialize_from_str< + 'de, + D: serde::de::Deserializer<'de>, + T: FromStr, + E: std::fmt::Display, +>( + deserializer: D, +) -> std::result::Result { + struct Visitor, E>(std::marker::PhantomData); + impl<'de, T: FromStr, Err: std::fmt::Display> serde::de::Visitor<'de> + for Visitor + { + type Value = T; + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "a parsable string") + } + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse().map_err(|e| serde::de::Error::custom(e)) + } + } + deserializer.deserialize_str(Visitor(std::marker::PhantomData)) +} -- GitLab From dd7159399ac1de608e3b25018bdaebc2324fbb1c Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Mon, 12 Apr 2021 17:06:23 -0600 Subject: [PATCH 2/2] fix tor config --- src/database.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/database.rs b/src/database.rs index 4349c65f..4b8da687 100644 --- a/src/database.rs +++ b/src/database.rs @@ -71,13 +71,15 @@ fn default_max_concurrent_requests() -> u16 { #[cfg(feature = "tor")] #[derive(Clone, Deserialize)] -#[serde(rename = "snake_case")] #[serde(tag = "tor_federation")] pub enum TorFederation { + #[serde(rename = "disabled")] Disabled, + #[serde(rename = "enabled")] Enabled { #[serde(deserialize_with = "crate::utils::deserialize_from_str")] tor_proxy: reqwest::Url, + #[serde(default)] tor_only: bool, }, } -- GitLab