mirror of
https://github.com/warp-tech/warpgate.git
synced 2024-09-20 06:46:17 +08:00
tls
This commit is contained in:
parent
8b143934f1
commit
659deb0d75
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -4667,9 +4667,12 @@ dependencies = [
|
|||
"mysql_common",
|
||||
"password-hash 0.2.3",
|
||||
"rand",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"sha1",
|
||||
"sqlx-core-guts",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tracing",
|
||||
"uuid 0.8.2",
|
||||
"warpgate-admin",
|
||||
|
|
|
@ -23,3 +23,6 @@ rand = "0.8"
|
|||
sha1 = "0.10.1"
|
||||
password-hash = { version = "0.2", features = ["std"] }
|
||||
delegate = "0.6"
|
||||
rustls = "0.20"
|
||||
rustls-pemfile = "1.0"
|
||||
tokio-rustls = "0.23"
|
||||
|
|
|
@ -33,7 +33,7 @@ impl MySQLClient {
|
|||
let handshake = Handshake::decode(payload)?;
|
||||
|
||||
options.capabilities &= handshake.server_capabilities;
|
||||
// options.capabilities.remove(Capabilities::CONNECT_ATTRS);
|
||||
options.capabilities |= Capabilities::SSL;
|
||||
|
||||
debug!(?handshake, "Received handshake");
|
||||
debug!(capabilities=?options.capabilities, "Capabilities");
|
||||
|
|
|
@ -7,6 +7,9 @@ use async_trait::async_trait;
|
|||
use bytes::{Buf, BytesMut};
|
||||
use common::compute_auth_challenge_response;
|
||||
use rand::Rng;
|
||||
use rustls::server::{ClientHello, NoClientAuth, ResolvesServerCert};
|
||||
use rustls::sign::CertifiedKey;
|
||||
use rustls::{Certificate, PrivateKey, ServerConfig};
|
||||
use sqlx_core_guts::io::Decode;
|
||||
use sqlx_core_guts::mysql::protocol::auth::AuthPlugin;
|
||||
use sqlx_core_guts::mysql::protocol::connect::{AuthSwitchRequest, Handshake, HandshakeResponse};
|
||||
|
@ -15,6 +18,7 @@ use sqlx_core_guts::mysql::protocol::text::Query;
|
|||
use sqlx_core_guts::mysql::protocol::Capabilities;
|
||||
use std::fmt::Debug;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use stream::MySQLStream;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tracing::*;
|
||||
|
@ -35,21 +39,87 @@ impl MySQLProtocolServer {
|
|||
}
|
||||
}
|
||||
|
||||
struct ResolveServerCert(Arc<CertifiedKey>);
|
||||
|
||||
impl ResolvesServerCert for ResolveServerCert {
|
||||
fn resolve(&self, _: ClientHello) -> Option<Arc<CertifiedKey>> {
|
||||
Some(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProtocolServer for MySQLProtocolServer {
|
||||
async fn run(self, address: SocketAddr) -> Result<()> {
|
||||
let (certificates, key_bytes) = {
|
||||
let config = self.services.config.lock().await;
|
||||
let certificate_path = config
|
||||
.paths_relative_to
|
||||
.join(&config.store.mysql.certificate);
|
||||
let key_path = config.paths_relative_to.join(&config.store.mysql.key);
|
||||
|
||||
(
|
||||
rustls_pemfile::certs(
|
||||
&mut std::fs::read(&certificate_path)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"reading SSL certificate from '{}'",
|
||||
certificate_path.display()
|
||||
)
|
||||
})?
|
||||
.as_slice(),
|
||||
)
|
||||
.map(|mut certs| {
|
||||
certs
|
||||
.drain(..)
|
||||
.map(Certificate)
|
||||
.collect::<Vec<Certificate>>()
|
||||
})
|
||||
.context("failed to parse tls certificates")?,
|
||||
std::fs::read(&key_path).with_context(|| {
|
||||
format!("reading SSL private key from '{}'", key_path.display())
|
||||
})?,
|
||||
)
|
||||
};
|
||||
|
||||
let mut key = rustls_pemfile::pkcs8_private_keys(&mut key_bytes.as_slice())?
|
||||
.drain(..)
|
||||
.next()
|
||||
.map(PrivateKey);
|
||||
|
||||
if key.is_none() {
|
||||
key = rustls_pemfile::rsa_private_keys(&mut key_bytes.as_slice())?
|
||||
.drain(..)
|
||||
.next()
|
||||
.map(PrivateKey);
|
||||
}
|
||||
|
||||
let key = key.context("no private keys in file")?;
|
||||
let key = rustls::sign::any_supported_type(&key)?;
|
||||
|
||||
let cert_key = Arc::new(CertifiedKey {
|
||||
cert: certificates,
|
||||
key,
|
||||
ocsp: None,
|
||||
sct_list: None,
|
||||
});
|
||||
|
||||
let tls_config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_client_cert_verifier(NoClientAuth::new())
|
||||
.with_cert_resolver(Arc::new(ResolveServerCert(cert_key)));
|
||||
|
||||
info!(?address, "Listening");
|
||||
let listener = TcpListener::bind(address).await?;
|
||||
loop {
|
||||
let (stream, addr) = listener.accept().await?;
|
||||
let tls_config = tls_config.clone();
|
||||
tokio::spawn(async move {
|
||||
match Session::new(stream).run().await {
|
||||
match Session::new(stream, tls_config).run().await {
|
||||
Ok(_) => info!(?addr, "Session finished"),
|
||||
Err(e) => error!(?addr, ?e, "Session failed"),
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn test_target(self, _target: Target) -> Result<(), TargetTestError> {
|
||||
|
@ -67,10 +137,11 @@ struct Session {
|
|||
stream: MySQLStream,
|
||||
capabilities: Capabilities,
|
||||
challenge: [u8; 20],
|
||||
tls_config: Arc<ServerConfig>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(stream: TcpStream) -> Self {
|
||||
pub fn new(stream: TcpStream, tls_config: ServerConfig) -> Self {
|
||||
Self {
|
||||
stream: MySQLStream::new(stream),
|
||||
capabilities: Capabilities::PROTOCOL_41
|
||||
|
@ -90,8 +161,10 @@ impl Session {
|
|||
| Capabilities::TRANSACTIONS
|
||||
// | Capabilities::MULTI_FACTOR_AUTHENTICATION
|
||||
| Capabilities::DEPRECATE_EOF
|
||||
| Capabilities::SECURE_CONNECTION,
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
| Capabilities::SSL,
|
||||
challenge: get_crypto_rng().gen(),
|
||||
tls_config: Arc::new(tls_config),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -142,12 +215,24 @@ impl Session {
|
|||
self.stream.push(&handshake, ())?;
|
||||
self.stream.flush().await?;
|
||||
|
||||
let payload = self.stream.recv().await?;
|
||||
let resp = HandshakeResponse::decode_with(payload, &mut self.capabilities)
|
||||
.context("Failed to parse packet")?;
|
||||
let resp = loop {
|
||||
let payload = self.stream.recv().await?;
|
||||
let resp = HandshakeResponse::decode_with(payload, &mut self.capabilities)
|
||||
.context("Failed to parse packet")?;
|
||||
|
||||
trace!(?resp, "Handshake response");
|
||||
info!(capabilities=?self.capabilities, username=%resp.username, "handshake complete");
|
||||
trace!(?resp, "Handshake response");
|
||||
info!(capabilities=?self.capabilities, username=%resp.username, "handshake complete");
|
||||
|
||||
if self.capabilities.contains(Capabilities::SSL) {
|
||||
if self.stream.is_tls() {
|
||||
break resp
|
||||
}
|
||||
self.stream.upgrade(self.tls_config.clone()).await?;
|
||||
continue
|
||||
} else {
|
||||
break resp
|
||||
}
|
||||
};
|
||||
|
||||
if resp.auth_plugin == Some(AuthPlugin::MySqlNativePassword) {
|
||||
if let Some(response) = resp.auth_response.as_ref() {
|
||||
|
@ -213,8 +298,10 @@ impl Session {
|
|||
let payload = self.stream.recv().await?;
|
||||
trace!(?payload, "server got packet");
|
||||
|
||||
let com = payload.get(0);
|
||||
|
||||
// COM_QUERY
|
||||
if payload.get(0) == Some(&0x03) {
|
||||
if com == Some(&0x03) {
|
||||
let query = Query::decode(payload)?;
|
||||
trace!(?query, "server got query");
|
||||
|
||||
|
@ -227,23 +314,42 @@ impl Session {
|
|||
trace!(?response, "client got packet");
|
||||
self.stream.push(&&response[..], ())?;
|
||||
self.stream.flush().await?;
|
||||
if let Some(b) = response.get(0) {
|
||||
if b == &0xfe {
|
||||
if let Some(com) = response.get(0) {
|
||||
if com == &0xfe {
|
||||
eof_ctr += 1;
|
||||
if eof_ctr == 2 && !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
|
||||
// tood check multiple results
|
||||
if eof_ctr == 2
|
||||
&& !self.capabilities.contains(Capabilities::DEPRECATE_EOF)
|
||||
{
|
||||
// todo check multiple results
|
||||
break;
|
||||
}
|
||||
}
|
||||
if b == &0 || b == &0xff {
|
||||
if com == &0 || com == &0xff {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// COM_QUIT
|
||||
} else if com == Some(&0x01) {
|
||||
break;
|
||||
// COM_FIELD_LIST
|
||||
} else if com == Some(&0x04) {
|
||||
client.stream.push(&&payload[..], ())?;
|
||||
client.stream.flush().await?;
|
||||
loop {
|
||||
let response = client.stream.recv().await?;
|
||||
trace!(?response, "client got packet");
|
||||
self.stream.push(&&response[..], ())?;
|
||||
self.stream.flush().await?;
|
||||
if let Some(com) = response.get(0) {
|
||||
if com == &0 || com == &0xff || com == &0xfe {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
warn!("Unknown packet type {:?}", payload.get(0));
|
||||
self.send_error(999, "Not implemented").await?;
|
||||
self.send_error(1047, "Not implemented").await?;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,12 +2,15 @@ use anyhow::{Context, Result};
|
|||
use bytes::{Bytes, BytesMut};
|
||||
use mysql_common::proto::codec::PacketCodec;
|
||||
use sqlx_core_guts::io::{BufStream, Encode};
|
||||
use tokio::io::AsyncReadExt;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::*;
|
||||
|
||||
pub struct MySQLStream {
|
||||
stream: BufStream<TcpStream>,
|
||||
stream: BufStream<MaybeTlsStream<TcpStream>>,
|
||||
codec: PacketCodec,
|
||||
inbound_buffer: BytesMut,
|
||||
outbound_buffer: BytesMut,
|
||||
|
@ -16,7 +19,7 @@ pub struct MySQLStream {
|
|||
impl MySQLStream {
|
||||
pub fn new(stream: TcpStream) -> Self {
|
||||
Self {
|
||||
stream: BufStream::new(stream),
|
||||
stream: BufStream::new(MaybeTlsStream::Raw(stream)),
|
||||
codec: PacketCodec::default(),
|
||||
inbound_buffer: BytesMut::new(),
|
||||
outbound_buffer: BytesMut::new(),
|
||||
|
@ -50,7 +53,7 @@ impl MySQLStream {
|
|||
let got_full_packet = self.codec.decode(&mut self.inbound_buffer, &mut payload)?;
|
||||
if got_full_packet {
|
||||
trace!(?payload, "received");
|
||||
return Ok(payload.freeze())
|
||||
return Ok(payload.freeze());
|
||||
}
|
||||
}
|
||||
let read_bytes = self.stream.read_buf(&mut self.inbound_buffer).await?;
|
||||
|
@ -61,7 +64,102 @@ impl MySQLStream {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn reset_sequence_id (&mut self) {
|
||||
pub fn reset_sequence_id(&mut self) {
|
||||
self.codec.reset_seq_id();
|
||||
}
|
||||
|
||||
pub async fn upgrade(&mut self, tls_config: Arc<rustls::ServerConfig>) -> Result<()> {
|
||||
if let MaybeTlsStream::Raw(stream) =
|
||||
std::mem::replace(&mut self.stream, BufStream::new(MaybeTlsStream::Upgrading)).take()
|
||||
{
|
||||
let acceptor = tokio_rustls::TlsAcceptor::from(tls_config);
|
||||
|
||||
let accept = acceptor.accept(stream).await.context("TLS setup failed")?;
|
||||
|
||||
self.stream = BufStream::new(MaybeTlsStream::ServerTls(accept));
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
anyhow::bail!("bad state")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_tls(&self) -> bool {
|
||||
match *self.stream {
|
||||
MaybeTlsStream::Raw(_) => false,
|
||||
MaybeTlsStream::ServerTls(_) => true,
|
||||
MaybeTlsStream::ClientTls(_) => true,
|
||||
MaybeTlsStream::Upgrading => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum MaybeTlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
ClientTls(tokio_rustls::client::TlsStream<S>),
|
||||
ServerTls(tokio_rustls::server::TlsStream<S>),
|
||||
Raw(S),
|
||||
Upgrading,
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for MaybeTlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<tokio::io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::ClientTls(tls) => Pin::new(tls).poll_read(cx, buf),
|
||||
MaybeTlsStream::ServerTls(tls) => Pin::new(tls).poll_read(cx, buf),
|
||||
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncWrite for MaybeTlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::ClientTls(tls) => Pin::new(tls).poll_write(cx, buf),
|
||||
MaybeTlsStream::ServerTls(tls) => Pin::new(tls).poll_write(cx, buf),
|
||||
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::ClientTls(tls) => Pin::new(tls).poll_flush(cx),
|
||||
MaybeTlsStream::ServerTls(tls) => Pin::new(tls).poll_flush(cx),
|
||||
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_flush(cx),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::ClientTls(tls) => Pin::new(tls).poll_shutdown(cx),
|
||||
MaybeTlsStream::ServerTls(tls) => Pin::new(tls).poll_shutdown(cx),
|
||||
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue