diff --git a/warpgate-protocol-mysql/Cargo.toml b/warpgate-protocol-mysql/Cargo.toml index c827d77..38d8a91 100644 --- a/warpgate-protocol-mysql/Cargo.toml +++ b/warpgate-protocol-mysql/Cargo.toml @@ -21,5 +21,5 @@ bytes = "1.1" mysql_common = "0.29" rand = "0.8" sha1 = "0.10.1" -password-hash="0.2" +password-hash = { version = "0.2", features = ["std"] } delegate = "0.6" diff --git a/warpgate-protocol-mysql/src/client.rs b/warpgate-protocol-mysql/src/client.rs new file mode 100644 index 0000000..e80bb3d --- /dev/null +++ b/warpgate-protocol-mysql/src/client.rs @@ -0,0 +1,98 @@ +use anyhow::{Context, Result}; +use bytes::BytesMut; +use sqlx_core_guts::io::Decode; +use sqlx_core_guts::mysql::options::MySqlConnectOptions; +use sqlx_core_guts::mysql::protocol::auth::AuthPlugin; +use sqlx_core_guts::mysql::protocol::connect::{Handshake, HandshakeResponse}; +use sqlx_core_guts::mysql::protocol::response::ErrPacket; +use sqlx_core_guts::mysql::protocol::Capabilities; +use tokio::net::TcpStream; +use tracing::*; + +use crate::common::compute_auth_challenge_response; +use crate::stream::MySQLStream; + +pub struct MySQLClient { + pub stream: MySQLStream, + pub capabilities: Capabilities, +} + +pub struct ConnectionOptions { + pub collation: u8, + pub database: Option, + pub max_packet_size: u32, + pub capabilities: Capabilities, +} + +impl MySQLClient { + pub async fn connect(uri: &str, mut options: ConnectionOptions) -> Result { + let opts: MySqlConnectOptions = uri.parse()?; + let mut stream = MySQLStream::new(TcpStream::connect((opts.host, opts.port)).await?); + + let payload = stream.recv().await?; + let handshake = Handshake::decode(payload)?; + + options.capabilities &= handshake.server_capabilities; + // options.capabilities.remove(Capabilities::CONNECT_ATTRS); + + debug!(?handshake, "Received handshake"); + debug!(capabilities=?options.capabilities, "Capabilities"); + + let mut response = HandshakeResponse { + auth_plugin: None, + auth_response: None, + collation: options.collation, + database: options.database, + max_packet_size: options.max_packet_size, + username: opts.username, + }; + + if handshake.auth_plugin == Some(AuthPlugin::MySqlNativePassword) { + let scramble_bytes = [ + &handshake.auth_plugin_data.first_ref()[..], + &handshake.auth_plugin_data.last_ref()[..], + ] + .concat(); + match scramble_bytes.try_into() as Result<[u8; 20], Vec> { + Err(scramble_bytes) => { + warn!("Invalid scramble length ({})", scramble_bytes.len()); + } + Ok(scramble) => { + let Some(password) = opts.password else { + error!("Password not set in the connection URI"); + anyhow::bail!("Password not set"); + }; + response.auth_plugin = Some(AuthPlugin::MySqlNativePassword); + response.auth_response = Some( + BytesMut::from( + compute_auth_challenge_response(scramble, &password)?.as_bytes(), + ) + .freeze(), + ); + trace!(response=?response.auth_response, ?scramble, "auth"); + } + } + } + + stream.push(&response, options.capabilities)?; + stream.flush().await?; + + let response = stream.recv().await?; + if response.get(0) == Some(&0) || response.get(0) == Some(&0xfe) { + debug!("Authorized"); + } else if response.get(0) == Some(&0xff) { + let error = ErrPacket::decode_with(response, options.capabilities)?; + error!(?error, "Handshake failed"); + anyhow::bail!("Handshake failed"); + } else { + anyhow::bail!("Unknown response type {:?}", response.get(0)); + } + + stream.reset_sequence_id(); + + Ok(Self { + stream, + capabilities: options.capabilities, + }) + } +} diff --git a/warpgate-protocol-mysql/src/common.rs b/warpgate-protocol-mysql/src/common.rs index 111d958..a80b191 100644 --- a/warpgate-protocol-mysql/src/common.rs +++ b/warpgate-protocol-mysql/src/common.rs @@ -1,3 +1,25 @@ +use sha1::Digest; use warpgate_common::ProtocolName; pub const PROTOCOL_NAME: ProtocolName = "MySQL"; + +pub fn compute_auth_challenge_response( + challenge: [u8; 20], + password: &str, +) -> Result { + password_hash::Output::new( + &{ + let password_sha: [u8; 20] = sha1::Sha1::digest(password).into(); + let password_sha_sha: [u8; 20] = sha1::Sha1::digest(password_sha).into(); + let password_seed_2sha_sha: [u8; 20] = + sha1::Sha1::digest([challenge, password_sha_sha].concat()).into(); + + let mut result = password_sha; + result + .iter_mut() + .zip(password_seed_2sha_sha.iter()) + .for_each(|(x1, x2)| *x1 ^= *x2); + result + }[..], + ) +} diff --git a/warpgate-protocol-mysql/src/lib.rs b/warpgate-protocol-mysql/src/lib.rs index 63f75a7..973aa27 100644 --- a/warpgate-protocol-mysql/src/lib.rs +++ b/warpgate-protocol-mysql/src/lib.rs @@ -1,26 +1,28 @@ #![feature(type_alias_impl_trait, let_else, try_blocks)] +mod client; mod common; +mod stream; use anyhow::{Context, Result}; use async_trait::async_trait; -use bytes::{Buf, Bytes, BytesMut}; -use mysql_common::proto::codec::PacketCodec; +use bytes::{Buf, BytesMut}; +use common::compute_auth_challenge_response; use rand::Rng; -use sha1::Digest; -use sqlx_core_guts::io::{BufStream, Decode, Encode}; +use sqlx_core_guts::io::Decode; use sqlx_core_guts::mysql::protocol::auth::AuthPlugin; use sqlx_core_guts::mysql::protocol::connect::{AuthSwitchRequest, Handshake, HandshakeResponse}; use sqlx_core_guts::mysql::protocol::response::{ErrPacket, OkPacket, Status}; use sqlx_core_guts::mysql::protocol::text::Query; use sqlx_core_guts::mysql::protocol::Capabilities; use std::fmt::Debug; - use std::net::SocketAddr; -use tokio::io::{AsyncReadExt}; +use stream::MySQLStream; use tokio::net::{TcpListener, TcpStream}; use tracing::*; use warpgate_common::helpers::rng::get_crypto_rng; use warpgate_common::{ProtocolServer, Services, Target, TargetTestError}; +use crate::client::{ConnectionOptions, MySQLClient}; + pub struct MySQLProtocolServer { services: Services, } @@ -62,28 +64,23 @@ impl Debug for MySQLProtocolServer { } struct Session { - stream: BufStream, - codec: PacketCodec, + stream: MySQLStream, capabilities: Capabilities, - inbound_buffer: BytesMut, - outbound_buffer: BytesMut, challenge: [u8; 20], } - impl Session { pub fn new(stream: TcpStream) -> Self { Self { - stream: BufStream::new(stream), + stream: MySQLStream::new(stream), capabilities: Capabilities::PROTOCOL_41 | Capabilities::PLUGIN_AUTH | Capabilities::FOUND_ROWS | Capabilities::LONG_FLAG | Capabilities::NO_SCHEMA - | Capabilities::MULTI_RESULTS + // | Capabilities::MULTI_RESULTS | Capabilities::MULTI_STATEMENTS - | Capabilities::PS_MULTI_RESULTS - | Capabilities::CONNECT_ATTRS + // | Capabilities::PS_MULTI_RESULTS | Capabilities::PLUGIN_AUTH_LENENC_DATA | Capabilities::CONNECT_WITH_DB | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS @@ -94,101 +91,42 @@ impl Session { // | Capabilities::MULTI_FACTOR_AUTHENTICATION | Capabilities::DEPRECATE_EOF | Capabilities::SECURE_CONNECTION, - codec: PacketCodec::default(), - inbound_buffer: BytesMut::new(), - outbound_buffer: BytesMut::new(), challenge: get_crypto_rng().gen(), } } - fn push<'a, C, P: Encode<'a, C>>(&mut self, packet: &'a P, context: C) -> Result<()> { - let mut buf = vec![]; - packet.encode_with(&mut buf, context); - self.codec - .encode(&mut &*buf, &mut self.outbound_buffer) - .context("Failed to encode packet")?; - Ok(()) - } - - async fn flush(&mut self) -> Result<()> { - trace!(outbound_buffer=?self.outbound_buffer, "flushing"); - self.stream.write(&self.outbound_buffer[..]); - self.outbound_buffer = BytesMut::new(); - self.stream - .flush() - .await - .context("Failed to flush stream")?; - Ok(()) - } - - async fn recv(&mut self) -> Result { - let mut payload = BytesMut::new(); - loop { - let read_bytes = self.stream.read_buf(&mut self.inbound_buffer).await?; - if read_bytes == 0 { - anyhow::bail!("Unexpected EOF"); - } - trace!(inbound_buffer=?self.inbound_buffer, "chunk"); - { - let got_full_packet = self.codec.decode(&mut self.inbound_buffer, &mut payload)?; - if got_full_packet { - break; - } - } - } - trace!(inbound_buffer=?self.inbound_buffer, "after packet"); - Ok(payload.freeze()) - // let result = P::deserialize(ctx, &mut pb); - // drop(pb); - // return result.context("Failed to deserialize"); - } - async fn check_auth_response(&mut self, response: &[u8]) -> Result { - let expected_response = password_hash::Output::new( - &{ - let true_password = b"123"; - let password_sha: [u8; 20] = sha1::Sha1::digest(true_password).into(); - let password_sha_sha: [u8; 20] = sha1::Sha1::digest(password_sha).into(); - let password_seed_2sha_sha: [u8; 20] = - sha1::Sha1::digest([self.challenge, password_sha_sha].concat()).into(); - - let mut result = password_sha; - result - .iter_mut() - .zip(password_seed_2sha_sha.iter()) - .for_each(|(x1, x2)| *x1 ^= *x2); - result - }[..], - ); - - let client_response = password_hash::Output::new(response); - info!(?client_response, "client_response"); - info!(?expected_response, "exp response"); - - info!("correct {}", client_response == expected_response); + let expected_response = compute_auth_challenge_response(self.challenge, "123")?; + let client_response = password_hash::Output::new(response)?; if client_response == expected_response { - self.push(&OkPacket { - affected_rows: 0, - last_insert_id: 0, - status: Status::empty(), - warnings: 0, - }, ())?; + self.stream.push( + &OkPacket { + affected_rows: 0, + last_insert_id: 0, + status: Status::empty(), + warnings: 0, + }, + (), + )?; } else { - self.push(&ErrPacket { - error_code: 1, - error_message: "Access denied".to_owned(), - sql_state: None, - }, ())?; + self.stream.push( + &ErrPacket { + error_code: 1, + error_message: "Access denied".to_owned(), + sql_state: None, + }, + (), + )?; } - self.flush().await?; + self.stream.flush().await?; Ok(client_response == expected_response) } pub async fn run(mut self) -> Result<()> { let mut challenge_1 = BytesMut::from(&self.challenge[..]); - let mut challenge_2 = challenge_1.split_off(8); + let challenge_2 = challenge_1.split_off(8); let challenge_chain = challenge_1.freeze().chain(challenge_2.freeze()); let handshake = Handshake { @@ -201,19 +139,22 @@ impl Session { status: Status::empty(), auth_plugin: Some(AuthPlugin::MySqlNativePassword), }; - self.push(&handshake, ())?; - self.flush().await?; + self.stream.push(&handshake, ())?; + self.stream.flush().await?; - let mut payload = self.recv().await?; + let payload = self.stream.recv().await?; let resp = HandshakeResponse::decode_with(payload, &mut self.capabilities) .context("Failed to parse packet")?; - info!(?resp, "got response"); + + trace!(?resp, "Handshake response"); + info!(capabilities=?self.capabilities, username=%resp.username, "handshake complete"); if resp.auth_plugin == Some(AuthPlugin::MySqlNativePassword) { if let Some(response) = resp.auth_response.as_ref() { - if self.check_auth_response(response).await? { - return self.run_authorized().await; - }} + if self.check_auth_response(response).await? { + return self.run_authorized(resp).await; + } + } } let challenge = self.challenge.clone(); @@ -221,35 +162,88 @@ impl Session { plugin: AuthPlugin::MySqlNativePassword, data: BytesMut::from(&challenge[..]).freeze(), }; - self.push(&req, ())?; + self.stream.push(&req, ())?; // self.push(&RawBytes::< - self.flush().await?; + self.stream.flush().await?; - let response = &self.recv().await?; + let response = &self.stream.recv().await?; if self.check_auth_response(response).await? { - return self.run_authorized().await; + return self.run_authorized(resp).await; } Ok(()) } - pub async fn run_authorized(mut self) -> Result<()> { + async fn send_error(&mut self, code: u16, message: &str) -> Result<()> { + self.stream.push( + &ErrPacket { + error_code: code, + error_message: message.to_owned(), + sql_state: None, + }, + (), + )?; + self.stream.flush().await + } + + pub async fn run_authorized(mut self, handshake: HandshakeResponse) -> Result<()> { + let mut client = match MySQLClient::connect( + "mysql://dev:123@localhost:3306/elements_web", + ConnectionOptions { + collation: handshake.collation, + database: handshake.database, + max_packet_size: handshake.max_packet_size, + capabilities: self.capabilities.clone(), + }, + ) + .await + { + Ok(c) => c, + Err(error) => { + error!(?error, "Target connection failed"); + self.send_error(1045, "Access denied").await?; + return Err(error); + } + }; + loop { - self.codec.reset_seq_id(); - let payload = self.recv().await?; - trace!(?payload, "got packet"); + self.stream.reset_sequence_id(); + client.stream.reset_sequence_id(); + let payload = self.stream.recv().await?; + trace!(?payload, "server got packet"); // COM_QUERY if payload.get(0) == Some(&0x03) { let query = Query::decode(payload)?; - trace!(?query, "got query"); - self.push(&ErrPacket { - error_code: 1, - error_message: "Whoops".to_owned(), - sql_state: None, - }, ())?; - self.flush().await?; + trace!(?query, "server got query"); + + client.stream.push(&query, ())?; + client.stream.flush().await?; + + let mut eof_ctr = 0; + loop { + let response = client.stream.recv().await?; + trace!(?response, "client got packet"); + self.stream.push(&&response[..], ())?; + self.stream.flush().await?; + if let Some(b) = response.get(0) { + if b == &0xfe { + eof_ctr += 1; + if eof_ctr == 2 && !self.capabilities.contains(Capabilities::DEPRECATE_EOF) { + // tood check multiple results + break; + } + } + if b == &0 || b == &0xff { + break; + } + } + } + + } else { + warn!("Unknown packet type {:?}", payload.get(0)); + self.send_error(999, "Not implemented").await?; } } diff --git a/warpgate-protocol-mysql/src/stream.rs b/warpgate-protocol-mysql/src/stream.rs new file mode 100644 index 0000000..99c955b --- /dev/null +++ b/warpgate-protocol-mysql/src/stream.rs @@ -0,0 +1,67 @@ +use anyhow::{Context, Result}; +use bytes::{Bytes, BytesMut}; +use mysql_common::proto::codec::PacketCodec; +use sqlx_core_guts::io::{BufStream, Encode}; +use tokio::io::AsyncReadExt; +use tokio::net::TcpStream; +use tracing::*; + +pub struct MySQLStream { + stream: BufStream, + codec: PacketCodec, + inbound_buffer: BytesMut, + outbound_buffer: BytesMut, +} + +impl MySQLStream { + pub fn new(stream: TcpStream) -> Self { + Self { + stream: BufStream::new(stream), + codec: PacketCodec::default(), + inbound_buffer: BytesMut::new(), + outbound_buffer: BytesMut::new(), + } + } + + pub fn push<'a, C, P: Encode<'a, C>>(&mut self, packet: &'a P, context: C) -> Result<()> { + let mut buf = vec![]; + packet.encode_with(&mut buf, context); + self.codec + .encode(&mut &*buf, &mut self.outbound_buffer) + .context("Failed to encode packet")?; + Ok(()) + } + + pub async fn flush(&mut self) -> Result<()> { + trace!(outbound_buffer=?self.outbound_buffer, "sending"); + self.stream.write(&self.outbound_buffer[..]); + self.outbound_buffer = BytesMut::new(); + self.stream + .flush() + .await + .context("Failed to flush stream")?; + Ok(()) + } + + pub async fn recv(&mut self) -> Result { + let mut payload = BytesMut::new(); + loop { + { + let got_full_packet = self.codec.decode(&mut self.inbound_buffer, &mut payload)?; + if got_full_packet { + trace!(?payload, "received"); + return Ok(payload.freeze()) + } + } + let read_bytes = self.stream.read_buf(&mut self.inbound_buffer).await?; + if read_bytes == 0 { + anyhow::bail!("Unexpected EOF"); + } + trace!(inbound_buffer=?self.inbound_buffer, "received chunk"); + } + } + + pub fn reset_sequence_id (&mut self) { + self.codec.reset_seq_id(); + } +}