diff --git a/warpgate-protocol-postgres/Cargo.toml b/warpgate-protocol-postgres/Cargo.toml index f03474d9..ac26d9c1 100644 --- a/warpgate-protocol-postgres/Cargo.toml +++ b/warpgate-protocol-postgres/Cargo.toml @@ -18,7 +18,7 @@ rustls-pemfile = "1.0" tokio-rustls = "0.26" thiserror = "1.0" rustls-native-certs = "0.6" -pgwire = { version = "0.23", default-features = false, features = [ +pgwire = { version = "0.25", default-features = false, features = [ "server-api", ] } rsasl = { version = "2.1.0", default-features = false, features = ["config_builder", "scram-sha-2", "std", "plain", "provider"] } diff --git a/warpgate-protocol-postgres/src/client.rs b/warpgate-protocol-postgres/src/client.rs index a2584846..c1bb40a3 100644 --- a/warpgate-protocol-postgres/src/client.rs +++ b/warpgate-protocol-postgres/src/client.rs @@ -237,8 +237,6 @@ impl PostgresClient { return Err(PostgresError::Eof); }; - dbg!(&payload); - match payload.0 { PgWireBackendMessage::ErrorResponse(response) => return Err(response.into()), PgWireBackendMessage::Authentication( diff --git a/warpgate-protocol-postgres/src/stream.rs b/warpgate-protocol-postgres/src/stream.rs index 636cd6da..496b7174 100644 --- a/warpgate-protocol-postgres/src/stream.rs +++ b/warpgate-protocol-postgres/src/stream.rs @@ -1,9 +1,7 @@ use std::fmt::Debug; -use std::io::Cursor; -use bytes::{Buf, BytesMut}; +use bytes::BytesMut; use pgwire::error::{PgWireError, PgWireResult}; -use pgwire::messages::startup::MESSAGE_TYPE_BYTE_AUTHENTICATION; use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; @@ -59,25 +57,6 @@ impl PostgresDecode for PgWireGenericFrontendMessage { impl PostgresDecode for PgWireGenericBackendMessage { fn decode(buf: &mut BytesMut) -> PgWireResult> { - let first_byte = { - let mut peeker = Cursor::new(&mut buf[..]); - if peeker.remaining() > 1 { - Some(peeker.get_u8()) - } else { - None - } - }; - - #[allow(clippy::single_match)] - match first_byte { - Some(MESSAGE_TYPE_BYTE_AUTHENTICATION) => { - return Ok(AuthenticationMsgExt::decode(buf)?.map(|x| { - PgWireGenericBackendMessage(PgWireBackendMessage::Authentication(x.0)) - })); - } - _ => (), - } - PgWireBackendMessage::decode(buf).map(|x| x.map(PgWireGenericBackendMessage)) } } @@ -106,50 +85,6 @@ impl PostgresEncode for T { } } -mod authentication_ext { - use std::io::Cursor; - - use bytes::Buf; - use pgwire::messages::startup::Authentication; - use pgwire::messages::Message; - - use super::*; - - /// Workaround for https://github.com/sunng87/pgwire/issues/208 - #[derive(PartialEq, Eq, Debug)] - pub struct AuthenticationMsgExt(pub Authentication); - - impl Message for AuthenticationMsgExt { - #[inline] - fn message_type() -> Option { - Authentication::message_type() - } - - #[inline] - fn message_length(&self) -> usize { - self.0.message_length() - } - - fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { - self.0.encode_body(buf) - } - - fn decode_body(buf: &mut BytesMut, len: usize) -> PgWireResult { - let mut peeker = Cursor::new(&buf[..]); - let code = peeker.get_i32(); - Ok(match code { - 12 => { - buf.advance(4); - Self(Authentication::SASLFinal(buf.split_to(len - 8).freeze())) - } - _ => Self(Authentication::decode_body(buf, len)?), - }) - } - } -} - -pub use authentication_ext::AuthenticationMsgExt; - pub(crate) struct PostgresStream where TcpStream: UpgradableStream,