This commit is contained in:
Eugene Pankov 2022-07-14 21:22:03 +02:00
parent 8b143934f1
commit 659deb0d75
No known key found for this signature in database
GPG key ID: 5896FCBBDD1CF4F4
5 changed files with 233 additions and 23 deletions

3
Cargo.lock generated
View file

@ -4667,9 +4667,12 @@ dependencies = [
"mysql_common",
"password-hash 0.2.3",
"rand",
"rustls",
"rustls-pemfile",
"sha1",
"sqlx-core-guts",
"tokio",
"tokio-rustls",
"tracing",
"uuid 0.8.2",
"warpgate-admin",

View file

@ -23,3 +23,6 @@ rand = "0.8"
sha1 = "0.10.1"
password-hash = { version = "0.2", features = ["std"] }
delegate = "0.6"
rustls = "0.20"
rustls-pemfile = "1.0"
tokio-rustls = "0.23"

View file

@ -33,7 +33,7 @@ impl MySQLClient {
let handshake = Handshake::decode(payload)?;
options.capabilities &= handshake.server_capabilities;
// options.capabilities.remove(Capabilities::CONNECT_ATTRS);
options.capabilities |= Capabilities::SSL;
debug!(?handshake, "Received handshake");
debug!(capabilities=?options.capabilities, "Capabilities");

View file

@ -7,6 +7,9 @@ use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use common::compute_auth_challenge_response;
use rand::Rng;
use rustls::server::{ClientHello, NoClientAuth, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use rustls::{Certificate, PrivateKey, ServerConfig};
use sqlx_core_guts::io::Decode;
use sqlx_core_guts::mysql::protocol::auth::AuthPlugin;
use sqlx_core_guts::mysql::protocol::connect::{AuthSwitchRequest, Handshake, HandshakeResponse};
@ -15,6 +18,7 @@ use sqlx_core_guts::mysql::protocol::text::Query;
use sqlx_core_guts::mysql::protocol::Capabilities;
use std::fmt::Debug;
use std::net::SocketAddr;
use std::sync::Arc;
use stream::MySQLStream;
use tokio::net::{TcpListener, TcpStream};
use tracing::*;
@ -35,21 +39,87 @@ impl MySQLProtocolServer {
}
}
struct ResolveServerCert(Arc<CertifiedKey>);
impl ResolvesServerCert for ResolveServerCert {
fn resolve(&self, _: ClientHello) -> Option<Arc<CertifiedKey>> {
Some(self.0.clone())
}
}
#[async_trait]
impl ProtocolServer for MySQLProtocolServer {
async fn run(self, address: SocketAddr) -> Result<()> {
let (certificates, key_bytes) = {
let config = self.services.config.lock().await;
let certificate_path = config
.paths_relative_to
.join(&config.store.mysql.certificate);
let key_path = config.paths_relative_to.join(&config.store.mysql.key);
(
rustls_pemfile::certs(
&mut std::fs::read(&certificate_path)
.with_context(|| {
format!(
"reading SSL certificate from '{}'",
certificate_path.display()
)
})?
.as_slice(),
)
.map(|mut certs| {
certs
.drain(..)
.map(Certificate)
.collect::<Vec<Certificate>>()
})
.context("failed to parse tls certificates")?,
std::fs::read(&key_path).with_context(|| {
format!("reading SSL private key from '{}'", key_path.display())
})?,
)
};
let mut key = rustls_pemfile::pkcs8_private_keys(&mut key_bytes.as_slice())?
.drain(..)
.next()
.map(PrivateKey);
if key.is_none() {
key = rustls_pemfile::rsa_private_keys(&mut key_bytes.as_slice())?
.drain(..)
.next()
.map(PrivateKey);
}
let key = key.context("no private keys in file")?;
let key = rustls::sign::any_supported_type(&key)?;
let cert_key = Arc::new(CertifiedKey {
cert: certificates,
key,
ocsp: None,
sct_list: None,
});
let tls_config = ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(NoClientAuth::new())
.with_cert_resolver(Arc::new(ResolveServerCert(cert_key)));
info!(?address, "Listening");
let listener = TcpListener::bind(address).await?;
loop {
let (stream, addr) = listener.accept().await?;
let tls_config = tls_config.clone();
tokio::spawn(async move {
match Session::new(stream).run().await {
match Session::new(stream, tls_config).run().await {
Ok(_) => info!(?addr, "Session finished"),
Err(e) => error!(?addr, ?e, "Session failed"),
}
});
}
Ok(())
}
async fn test_target(self, _target: Target) -> Result<(), TargetTestError> {
@ -67,10 +137,11 @@ struct Session {
stream: MySQLStream,
capabilities: Capabilities,
challenge: [u8; 20],
tls_config: Arc<ServerConfig>,
}
impl Session {
pub fn new(stream: TcpStream) -> Self {
pub fn new(stream: TcpStream, tls_config: ServerConfig) -> Self {
Self {
stream: MySQLStream::new(stream),
capabilities: Capabilities::PROTOCOL_41
@ -90,8 +161,10 @@ impl Session {
| Capabilities::TRANSACTIONS
// | Capabilities::MULTI_FACTOR_AUTHENTICATION
| Capabilities::DEPRECATE_EOF
| Capabilities::SECURE_CONNECTION,
| Capabilities::SECURE_CONNECTION
| Capabilities::SSL,
challenge: get_crypto_rng().gen(),
tls_config: Arc::new(tls_config),
}
}
@ -142,12 +215,24 @@ impl Session {
self.stream.push(&handshake, ())?;
self.stream.flush().await?;
let payload = self.stream.recv().await?;
let resp = HandshakeResponse::decode_with(payload, &mut self.capabilities)
.context("Failed to parse packet")?;
let resp = loop {
let payload = self.stream.recv().await?;
let resp = HandshakeResponse::decode_with(payload, &mut self.capabilities)
.context("Failed to parse packet")?;
trace!(?resp, "Handshake response");
info!(capabilities=?self.capabilities, username=%resp.username, "handshake complete");
trace!(?resp, "Handshake response");
info!(capabilities=?self.capabilities, username=%resp.username, "handshake complete");
if self.capabilities.contains(Capabilities::SSL) {
if self.stream.is_tls() {
break resp
}
self.stream.upgrade(self.tls_config.clone()).await?;
continue
} else {
break resp
}
};
if resp.auth_plugin == Some(AuthPlugin::MySqlNativePassword) {
if let Some(response) = resp.auth_response.as_ref() {
@ -213,8 +298,10 @@ impl Session {
let payload = self.stream.recv().await?;
trace!(?payload, "server got packet");
let com = payload.get(0);
// COM_QUERY
if payload.get(0) == Some(&0x03) {
if com == Some(&0x03) {
let query = Query::decode(payload)?;
trace!(?query, "server got query");
@ -227,23 +314,42 @@ impl Session {
trace!(?response, "client got packet");
self.stream.push(&&response[..], ())?;
self.stream.flush().await?;
if let Some(b) = response.get(0) {
if b == &0xfe {
if let Some(com) = response.get(0) {
if com == &0xfe {
eof_ctr += 1;
if eof_ctr == 2 && !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
// tood check multiple results
if eof_ctr == 2
&& !self.capabilities.contains(Capabilities::DEPRECATE_EOF)
{
// todo check multiple results
break;
}
}
if b == &0 || b == &0xff {
if com == &0 || com == &0xff {
break;
}
}
}
// COM_QUIT
} else if com == Some(&0x01) {
break;
// COM_FIELD_LIST
} else if com == Some(&0x04) {
client.stream.push(&&payload[..], ())?;
client.stream.flush().await?;
loop {
let response = client.stream.recv().await?;
trace!(?response, "client got packet");
self.stream.push(&&response[..], ())?;
self.stream.flush().await?;
if let Some(com) = response.get(0) {
if com == &0 || com == &0xff || com == &0xfe {
break;
}
}
}
} else {
warn!("Unknown packet type {:?}", payload.get(0));
self.send_error(999, "Not implemented").await?;
self.send_error(1047, "Not implemented").await?;
}
}

View file

@ -2,12 +2,15 @@ use anyhow::{Context, Result};
use bytes::{Bytes, BytesMut};
use mysql_common::proto::codec::PacketCodec;
use sqlx_core_guts::io::{BufStream, Encode};
use tokio::io::AsyncReadExt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tracing::*;
pub struct MySQLStream {
stream: BufStream<TcpStream>,
stream: BufStream<MaybeTlsStream<TcpStream>>,
codec: PacketCodec,
inbound_buffer: BytesMut,
outbound_buffer: BytesMut,
@ -16,7 +19,7 @@ pub struct MySQLStream {
impl MySQLStream {
pub fn new(stream: TcpStream) -> Self {
Self {
stream: BufStream::new(stream),
stream: BufStream::new(MaybeTlsStream::Raw(stream)),
codec: PacketCodec::default(),
inbound_buffer: BytesMut::new(),
outbound_buffer: BytesMut::new(),
@ -50,7 +53,7 @@ impl MySQLStream {
let got_full_packet = self.codec.decode(&mut self.inbound_buffer, &mut payload)?;
if got_full_packet {
trace!(?payload, "received");
return Ok(payload.freeze())
return Ok(payload.freeze());
}
}
let read_bytes = self.stream.read_buf(&mut self.inbound_buffer).await?;
@ -61,7 +64,102 @@ impl MySQLStream {
}
}
pub fn reset_sequence_id (&mut self) {
pub fn reset_sequence_id(&mut self) {
self.codec.reset_seq_id();
}
pub async fn upgrade(&mut self, tls_config: Arc<rustls::ServerConfig>) -> Result<()> {
if let MaybeTlsStream::Raw(stream) =
std::mem::replace(&mut self.stream, BufStream::new(MaybeTlsStream::Upgrading)).take()
{
let acceptor = tokio_rustls::TlsAcceptor::from(tls_config);
let accept = acceptor.accept(stream).await.context("TLS setup failed")?;
self.stream = BufStream::new(MaybeTlsStream::ServerTls(accept));
Ok(())
} else {
anyhow::bail!("bad state")
}
}
pub fn is_tls(&self) -> bool {
match *self.stream {
MaybeTlsStream::Raw(_) => false,
MaybeTlsStream::ServerTls(_) => true,
MaybeTlsStream::ClientTls(_) => true,
MaybeTlsStream::Upgrading => false,
}
}
}
enum MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
ClientTls(tokio_rustls::client::TlsStream<S>),
ServerTls(tokio_rustls::server::TlsStream<S>),
Raw(S),
Upgrading,
}
impl<S> AsyncRead for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<tokio::io::Result<()>> {
match self.get_mut() {
MaybeTlsStream::ClientTls(tls) => Pin::new(tls).poll_read(cx, buf),
MaybeTlsStream::ServerTls(tls) => Pin::new(tls).poll_read(cx, buf),
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
_ => unreachable!(),
}
}
}
impl<S> AsyncWrite for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::ClientTls(tls) => Pin::new(tls).poll_write(cx, buf),
MaybeTlsStream::ServerTls(tls) => Pin::new(tls).poll_write(cx, buf),
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_write(cx, buf),
_ => unreachable!(),
}
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::ClientTls(tls) => Pin::new(tls).poll_flush(cx),
MaybeTlsStream::ServerTls(tls) => Pin::new(tls).poll_flush(cx),
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_flush(cx),
_ => unreachable!(),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::ClientTls(tls) => Pin::new(tls).poll_shutdown(cx),
MaybeTlsStream::ServerTls(tls) => Pin::new(tls).poll_shutdown(cx),
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_shutdown(cx),
_ => unreachable!(),
}
}
}