diff --git a/Cargo.toml b/Cargo.toml index ec1a8979..caeefefa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ tokio_io = ["tokio"] [dependencies] byteorder = "^1.3" -chrono-tz = "0.5" +chrono-tz = "0.6" crossbeam = "0.8.0" thiserror = "1.0.20" futures-core = "0.3.5" @@ -33,6 +33,7 @@ pin-project = "1.0.4" url="^2" uuid = "0.8.1" combine = "4.2.3" +percent-encoding = "2.1.0" [dependencies.futures-util] version = "0.3.12" diff --git a/src/types/options.rs b/src/types/options.rs index f2d84c0f..cbd746a3 100644 --- a/src/types/options.rs +++ b/src/types/options.rs @@ -14,6 +14,7 @@ use crate::errors::{Error, Result, UrlError}; use std::fmt::Formatter; #[cfg(feature = "tls")] use native_tls; +use percent_encoding::percent_decode; use url::Url; const DEFAULT_MIN_CONNS: usize = 10; @@ -432,11 +433,11 @@ fn from_url(url_str: &str) -> Result { let mut options = Options::default(); - if let Some(username) = get_username_from_url(&url)? { + if let Some(username) = get_username_from_url(&url) { options.username = username.into(); } - if let Some(password) = get_password_from_url(&url)? { + if let Some(password) = get_password_from_url(&url) { options.password = password.into() } @@ -507,7 +508,8 @@ fn parse_param<'a, F, T, E>( where F: Fn(&str) -> std::result::Result, { - match parse(value.as_ref()) { + let source = percent_decode(value.as_bytes()).decode_utf8_lossy(); + match parse(source.as_ref()) { Ok(value) => Ok(value), Err(_) => Err(UrlError::InvalidParamValue { param: param.into(), @@ -516,19 +518,19 @@ where } } -fn get_username_from_url(url: &Url) -> Result> { +fn get_username_from_url(url: &Url) -> Option> { let user = url.username(); if user.is_empty() { - return Ok(None); + return None; } - Ok(Some(user)) + Some(percent_decode(user.as_bytes()) + .decode_utf8_lossy()) } -fn get_password_from_url(url: &Url) -> Result> { - match url.password() { - None => Ok(None), - Some(password) => Ok(Some(password)), - } +fn get_password_from_url(url: &Url) -> Option> { + let password = url.password()?; + Some(percent_decode(password.as_bytes()) + .decode_utf8_lossy()) } fn get_database_from_url(url: &Url) -> Result> { @@ -653,6 +655,25 @@ mod test { ); } + #[test] + fn test_parse_encoded_creds() { + let url = "tcp://user%20%3Cbar%3E:password%20%3Cbar%3E@host1:9001/database?ping_timeout=42ms&keepalive=99s&compression=lz4&connection_timeout=10s"; + assert_eq!( + Options { + username: "user ".into(), + password: "password ".into(), + addr: Url::parse("tcp://user%20%3Cbar%3E:password%20%3Cbar%3E@host1:9001").unwrap(), + database: "database".into(), + keepalive: Some(Duration::from_secs(99)), + ping_timeout: Duration::from_millis(42), + connection_timeout: Duration::from_secs(10), + compression: true, + ..Options::default() + }, + from_url(url).unwrap(), + ); + } + #[test] fn test_parse_options() { let url = "tcp://username:password@host1:9001/database?ping_timeout=42ms&keepalive=99s&compression=lz4&connection_timeout=10s";