added MySQL support

This commit is contained in:
Eugene Pankov 2022-07-20 10:02:37 +02:00
parent 6be92356c2
commit 7682e17080
No known key found for this signature in database
GPG key ID: 5896FCBBDD1CF4F4
84 changed files with 5377 additions and 277 deletions

View file

@ -27,6 +27,10 @@ replace = version = "{new_version}"
search = version = "{current_version}"
replace = version = "{new_version}"
[bumpversion:file:warpgate-protocol-mysql/Cargo.toml]
search = version = "{current_version}"
replace = version = "{new_version}"
[bumpversion:file:warpgate-protocol-ssh/Cargo.toml]
search = version = "{current_version}"
replace = version = "{new_version}"

449
Cargo.lock generated
View file

@ -118,6 +118,12 @@ dependencies = [
"password-hash 0.4.1",
]
[[package]]
name = "arrayvec"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
[[package]]
name = "async-attributes"
version = "1.1.2"
@ -396,9 +402,9 @@ checksum = "8a32fd6af2b5827bce66c29053ba0e7c42b9dcab01835835058558c10851a46b"
[[package]]
name = "bcrypt-pbkdf"
version = "0.6.2"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c38c03b9506bd92bf1ef50665a81eda156f615438f7654bffba58907e6149d7"
checksum = "12621b8e87feb183a6e5dbb315e49026b2229c4398797ee0ae2d1bc00aef41b9"
dependencies = [
"blowfish",
"crypto-mac",
@ -407,12 +413,42 @@ dependencies = [
"zeroize",
]
[[package]]
name = "bigdecimal"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6aaf33151a6429fe9211d1b276eafdf70cdff28b071e76c0b0e1503221ea3744"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "bimap"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc0455254eb5c6964c4545d8bac815e1a1be4f3afe0ae695ea539c12d728d44b"
[[package]]
name = "bindgen"
version = "0.59.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bd2a9a458e8f4304c52c43ebb0cfbd520289f8379a52e329a38afda99bf8eb8"
dependencies = [
"bitflags",
"cexpr",
"clang-sys",
"lazy_static",
"lazycell",
"peeking_take_while",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
]
[[package]]
name = "bit-vec"
version = "0.6.3"
@ -425,6 +461,18 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
"funty",
"radium",
"tap",
"wyz",
]
[[package]]
name = "blake2"
version = "0.10.4"
@ -529,6 +577,15 @@ version = "1.0.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11"
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-if"
version = "1.0.0"
@ -564,6 +621,17 @@ dependencies = [
"generic-array",
]
[[package]]
name = "clang-sys"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a050e2153c5be08febd6734e29298e844fdb0fa21aeddd63b4eb7baa106c69b"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "clap"
version = "3.2.2"
@ -603,6 +671,15 @@ dependencies = [
"os_str_bytes",
]
[[package]]
name = "cmake"
version = "0.1.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8ad8cef104ac57b68b89df3208164d228503abbdce70f6880ffa3d970e7443a"
dependencies = [
"cc",
]
[[package]]
name = "color_quant"
version = "1.1.0"
@ -1031,6 +1108,7 @@ dependencies = [
"cfg-if",
"crc32fast",
"libc",
"libz-sys",
"miniz_oxide",
]
@ -1077,6 +1155,70 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "frunk"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cd67cf7d54b7e72d0ea76f3985c3747d74aee43e0218ad993b7903ba7a5395e"
dependencies = [
"frunk_core",
"frunk_derives",
"frunk_proc_macros",
]
[[package]]
name = "frunk_core"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1246cf43ec80bf8b2505b5c360b8fb999c97dabd17dbb604d85558d5cbc25482"
[[package]]
name = "frunk_derives"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dbc4f084ec5a3f031d24ccedeb87ab2c3189a2f33b8d070889073837d5ea09e"
dependencies = [
"frunk_proc_macro_helpers",
"quote",
"syn",
]
[[package]]
name = "frunk_proc_macro_helpers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99f11257f106c6753f5ffcb8e601fb39c390a088017aaa55b70c526bff15f63e"
dependencies = [
"frunk_core",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "frunk_proc_macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a078bd8459eccbb85e0b007b8f756585762a72a9efc53f359b371c3b6351dbcc"
dependencies = [
"frunk_core",
"frunk_proc_macros_impl",
"proc-macro-hack",
]
[[package]]
name = "frunk_proc_macros_impl"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ffba99f0fa4f57e42f57388fbb9a0ca863bc2b4261f3c5570fed579d5df6c32"
dependencies = [
"frunk_core",
"frunk_proc_macro_helpers",
"proc-macro-hack",
"quote",
"syn",
]
[[package]]
name = "fsevent-sys"
version = "4.1.0"
@ -1086,6 +1228,12 @@ dependencies = [
"libc",
]
[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "futures"
version = "0.3.21"
@ -1238,6 +1386,12 @@ version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4"
[[package]]
name = "glob"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
[[package]]
name = "gloo-timers"
version = "0.2.4"
@ -1648,12 +1802,101 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "lexical"
version = "6.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7aefb36fd43fef7003334742cbf77b243fcd36418a1d1bdd480d613a67968f6"
dependencies = [
"lexical-core",
]
[[package]]
name = "lexical-core"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46"
dependencies = [
"lexical-parse-float",
"lexical-parse-integer",
"lexical-util",
"lexical-write-float",
"lexical-write-integer",
]
[[package]]
name = "lexical-parse-float"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f"
dependencies = [
"lexical-parse-integer",
"lexical-util",
"static_assertions",
]
[[package]]
name = "lexical-parse-integer"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9"
dependencies = [
"lexical-util",
"static_assertions",
]
[[package]]
name = "lexical-util"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc"
dependencies = [
"static_assertions",
]
[[package]]
name = "lexical-write-float"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862"
dependencies = [
"lexical-util",
"lexical-write-integer",
"static_assertions",
]
[[package]]
name = "lexical-write-integer"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446"
dependencies = [
"lexical-util",
"static_assertions",
]
[[package]]
name = "libc"
version = "0.2.124"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21a41fed9d98f27ab1c6d161da622a4fa35e8a54a8adc24bbf3ddd0ef70b0e50"
[[package]]
name = "libloading"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efbc0f03f9a775e9f6aed295c6a1ba2253c5757a9e03d55c6caa46a681abcddd"
dependencies = [
"cfg-if",
"winapi",
]
[[package]]
name = "libsodium-sys"
version = "0.2.7"
@ -1677,6 +1920,17 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "libz-sys"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9702761c3935f8cc2f101793272e202c72b99da8f4224a19ddcf1279a6450bbf"
dependencies = [
"cc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "linked-hash-map"
version = "0.5.4"
@ -1818,6 +2072,43 @@ dependencies = [
"version_check",
]
[[package]]
name = "mysql_common"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20ce6fdcef94a8e87fea3f9402de4d7e403f10e8c20b6e2bd207e6b6f8a4eab0"
dependencies = [
"base64",
"bigdecimal",
"bindgen",
"bitflags",
"bitvec",
"byteorder",
"bytes",
"cc",
"cmake",
"crc32fast",
"flate2",
"frunk",
"lazy_static",
"lexical",
"num-bigint",
"num-traits",
"rand",
"regex",
"rust_decimal",
"saturating",
"serde",
"serde_json",
"sha-1",
"sha2 0.10.2",
"smallvec",
"subprocess",
"thiserror",
"time 0.3.11",
"uuid",
]
[[package]]
name = "native-tls"
version = "0.2.10"
@ -2163,6 +2454,12 @@ dependencies = [
"sha2 0.9.9",
]
[[package]]
name = "peeking_take_while"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]]
name = "pem"
version = "1.0.2"
@ -2444,6 +2741,12 @@ dependencies = [
"version_check",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5"
[[package]]
name = "proc-macro2"
version = "1.0.39"
@ -2505,6 +2808,12 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "rand"
version = "0.8.5"
@ -2783,6 +3092,17 @@ dependencies = [
"walkdir",
]
[[package]]
name = "rust_decimal"
version = "1.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34a3bb58e85333f1ab191bf979104b586ebd77475bc6681882825f4532dfe87c"
dependencies = [
"arrayvec",
"num-traits",
"serde",
]
[[package]]
name = "rustc-demangle"
version = "0.1.21"
@ -2858,6 +3178,12 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "saturating"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ece8e78b2f38ec51c51f5d475df0a7187ba5111b2a28bdc761ee05b075d40a71"
[[package]]
name = "schannel"
version = "0.1.19"
@ -3139,6 +3465,17 @@ dependencies = [
"digest 0.10.3",
]
[[package]]
name = "sha1"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77f4e7f65455545c2153c1253d25056825e77ee2533f0e41deb65a93a34852f"
dependencies = [
"cfg-if",
"cpufeatures",
"digest 0.10.3",
]
[[package]]
name = "sha2"
version = "0.9.9"
@ -3172,6 +3509,12 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "shlex"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3"
[[package]]
name = "signal-hook-registry"
version = "1.4.0"
@ -3329,6 +3672,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "stringprep"
version = "0.1.2"
@ -3345,6 +3694,16 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "subprocess"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086"
dependencies = [
"libc",
"winapi",
]
[[package]]
name = "subtle"
version = "2.4.1"
@ -3368,6 +3727,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "tempfile"
version = "3.3.0"
@ -3488,10 +3853,11 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]]
name = "tokio"
version = "1.19.2"
version = "1.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439"
checksum = "57aec3cfa4c296db7255446efb4928a6be304b431a806216105542a67b6ca82e"
dependencies = [
"autocfg",
"bytes",
"libc",
"memchr",
@ -3561,9 +3927,9 @@ dependencies = [
[[package]]
name = "tokio-tungstenite"
version = "0.17.1"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06cda1232a49558c46f8a504d5b93101d42c0bf7f911f12a105ba48168f821ae"
checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181"
dependencies = [
"futures-util",
"log",
@ -3778,9 +4144,9 @@ checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "tungstenite"
version = "0.17.2"
version = "0.17.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96a2dea40e7570482f28eb57afbe42d97551905da6a9400acc5c328d24004f5"
checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0"
dependencies = [
"base64",
"byteorder",
@ -3995,6 +4361,7 @@ dependencies = [
"warpgate-admin",
"warpgate-common",
"warpgate-protocol-http",
"warpgate-protocol-mysql",
"warpgate-protocol-ssh",
]
@ -4011,6 +4378,7 @@ dependencies = [
"mime_guess",
"poem",
"poem-openapi",
"regex",
"russh-keys",
"rust-embed",
"sea-orm",
@ -4041,6 +4409,7 @@ dependencies = [
"once_cell",
"packet",
"password-hash 0.4.1",
"poem",
"poem-openapi",
"rand",
"rand_chacha",
@ -4060,6 +4429,19 @@ dependencies = [
"warpgate-db-migrations",
]
[[package]]
name = "warpgate-database-protocols"
version = "0.3.0"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"memchr",
"thiserror",
"tokio",
]
[[package]]
name = "warpgate-db-entities"
version = "0.3.0"
@ -4098,6 +4480,7 @@ dependencies = [
"percent-encoding",
"poem",
"poem-openapi",
"regex",
"reqwest",
"serde",
"serde_json",
@ -4111,6 +4494,33 @@ dependencies = [
"warpgate-web",
]
[[package]]
name = "warpgate-protocol-mysql"
version = "0.3.0"
dependencies = [
"anyhow",
"async-trait",
"bytes",
"delegate",
"mysql_common",
"password-hash 0.2.3",
"rand",
"rustls",
"rustls-pemfile",
"sha1",
"thiserror",
"tokio",
"tokio-rustls",
"tracing",
"uuid",
"warpgate-admin",
"warpgate-common",
"warpgate-database-protocols",
"warpgate-db-entities",
"webpki",
"webpki-roots",
]
[[package]]
name = "warpgate-protocol-ssh"
version = "0.3.0"
@ -4132,6 +4542,7 @@ dependencies = [
"uuid",
"warpgate-common",
"warpgate-db-entities",
"zeroize",
]
[[package]]
@ -4242,6 +4653,15 @@ dependencies = [
"untrusted",
]
[[package]]
name = "webpki-roots"
version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1c760f0d366a6c24a02ed7816e23e691f5d92291f94d15e836006fd11b04daf"
dependencies = [
"webpki",
]
[[package]]
name = "wepoll-ffi"
version = "0.1.2"
@ -4334,6 +4754,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "wyz"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30b31594f29d27036c383b53b59ed3476874d518f0efb151b27a4c275141390e"
dependencies = [
"tap",
]
[[package]]
name = "yaml-rust"
version = "0.4.5"
@ -4364,6 +4793,6 @@ dependencies = [
[[package]]
name = "zeroize"
version = "1.3.0"
version = "1.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4756f7db3f7b5574938c3eb1c117038b8e07f95ee6718c0efad4ac21508f1efd"
checksum = "20b578acffd8516a6c3f2a1bdefc1ec37e547bb4e0fb8b6b01a4cafc886b4442"

View file

@ -5,7 +5,9 @@ members = [
"warpgate-common",
"warpgate-db-migrations",
"warpgate-db-entities",
"warpgate-database-protocols",
"warpgate-protocol-http",
"warpgate-protocol-mysql",
"warpgate-protocol-ssh",
"warpgate-web",
]

View file

@ -1,4 +1,4 @@
projects := "warpgate warpgate-admin warpgate-common warpgate-db-entities warpgate-db-migrations warpgate-protocol-ssh"
projects := "warpgate warpgate-admin warpgate-common warpgate-db-entities warpgate-db-migrations warpgate-database-protocols warpgate-protocol-ssh warpgate-protocol-mysql"
run *ARGS:
RUST_BACKTRACE=1 RUST_LOG=warpgate cd warpgate && cargo run -- --config ../config.yaml {{ARGS}}

View file

@ -5,25 +5,40 @@ name = "warpgate-admin"
version = "0.3.0"
[dependencies]
anyhow = {version = "1.0", features = ["std"]}
anyhow = { version = "1.0", features = ["std"] }
async-trait = "0.1"
bytes = "1.1"
chrono = "0.4"
futures = "0.3"
hex = "0.4"
mime_guess = {version = "2.0", default_features = false}
poem = {version = "^1.3.30", features = ["cookie", "session", "anyhow", "websocket"]}
poem-openapi = {version = "^2.0.4", features = ["swagger-ui", "chrono", "uuid", "static-files"]}
russh-keys = {version = "0.22.0-beta.3", features = ["openssl"]}
mime_guess = { version = "2.0", default_features = false }
poem = { version = "^1.3.30", features = [
"cookie",
"session",
"anyhow",
"websocket",
] }
poem-openapi = { version = "^2.0.4", features = [
"swagger-ui",
"chrono",
"uuid",
"static-files",
] }
russh-keys = { version = "0.22.0-beta.3", features = ["openssl"] }
rust-embed = "6.3"
sea-orm = {version = "^0.9", features = ["sqlx-sqlite", "runtime-tokio-native-tls", "macros"], default-features = false}
sea-orm = { version = "^0.9", features = [
"sqlx-sqlite",
"runtime-tokio-native-tls",
"macros",
], default-features = false }
serde = "1.0"
serde_json = "1.0"
thiserror = "1.0"
tokio = {version = "1.19", features = ["tracing"]}
tokio = {version = "1.20", features = ["tracing"]}
tracing = "0.1"
uuid = {version = "1.0", features = ["v4", "serde"]}
warpgate-common = {version = "*", path = "../warpgate-common"}
warpgate-db-entities = {version = "*", path = "../warpgate-db-entities"}
warpgate-protocol-ssh = {version = "*", path = "../warpgate-protocol-ssh"}
warpgate-web = {version = "*", path = "../warpgate-web"}
uuid = { version = "1.0", features = ["v4", "serde"] }
warpgate-common = { version = "*", path = "../warpgate-common" }
warpgate-db-entities = { version = "*", path = "../warpgate-db-entities" }
warpgate-protocol-ssh = { version = "*", path = "../warpgate-protocol-ssh" }
warpgate-web = { version = "*", path = "../warpgate-web" }
regex = "1.5"

View file

@ -1,10 +1,18 @@
#![feature(type_alias_impl_trait, let_else, try_blocks)]
mod api;
use poem_openapi::OpenApiService;
use regex::Regex;
pub fn main() {
let api_service =
OpenApiService::new(api::get(), "Warpgate Web Admin", env!("CARGO_PKG_VERSION"))
.server("/@warpgate/admin/api");
println!("{}", api_service.spec());
let spec = api_service.spec();
let re = Regex::new(r"TargetOptions\[(?P<name>\w+)\]").unwrap();
let spec = re.replace_all(&spec, "TargetOptions$name");
let re = Regex::new(r"PaginatedResponse<(?P<name>\w+)>").unwrap();
let spec = re.replace_all(&spec, "Paginated$name");
println!("{}", spec);
}

View file

@ -16,7 +16,8 @@ lazy_static = "1.4"
once_cell = "1.10"
packet = "0.1"
password-hash = "0.4"
poem-openapi = {version = "^2.0.4", features = ["swagger-ui", "chrono", "uuid", "static-files"]}
poem = "^1.3.30"
poem-openapi = {version = "2.0.4", features = ["swagger-ui", "chrono", "uuid", "static-files"]}
rand = "0.8"
rand_chacha = "0.3"
rand_core = {version = "0.6", features = ["std"]}
@ -24,7 +25,7 @@ sea-orm = {version = "^0.9", features = ["sqlx-sqlite", "runtime-tokio-native-tl
serde = "1.0"
serde_json = "1.0"
thiserror = "1.0"
tokio = {version = "1.19", features = ["tracing"]}
tokio = {version = "1.20", features = ["tracing"]}
totp-rs = {version = "2.0", features = ["otpauth"]}
tracing = "0.1"
tracing-core = "0.1"

View file

@ -20,7 +20,9 @@ impl From<&String> for AuthSelector {
return AuthSelector::Ticket { secret };
}
let mut parts = selector.splitn(2, ':');
let separator = if selector.contains('#') { '#' } else { ':' };
let mut parts = selector.splitn(2, separator);
let username = parts.next().unwrap_or("").to_string();
let target_name = parts.next().unwrap_or("").to_string();
AuthSelector::User {

View file

@ -1,12 +1,13 @@
use std::collections::HashMap;
use std::net::ToSocketAddrs;
use std::path::PathBuf;
use std::time::Duration;
use poem_openapi::{Object, Union};
use poem_openapi::{Enum, Object, Union};
use serde::{Deserialize, Serialize};
use crate::helpers::otp::OtpSecretKey;
use crate::Secret;
use crate::{ListenEndpoint, Secret};
const fn _default_true() -> bool {
true
@ -16,10 +17,14 @@ const fn _default_false() -> bool {
false
}
const fn _default_port() -> u16 {
const fn _default_ssh_port() -> u16 {
22
}
const fn _default_mysql_port() -> u16 {
3306
}
#[inline]
fn _default_username() -> String {
"root".to_owned()
@ -41,8 +46,13 @@ fn _default_database_url() -> Secret<String> {
}
#[inline]
fn _default_http_listen() -> String {
"0.0.0.0:8888".to_owned()
fn _default_http_listen() -> ListenEndpoint {
ListenEndpoint("0.0.0.0:8888".to_socket_addrs().unwrap().next().unwrap())
}
#[inline]
fn _default_mysql_listen() -> ListenEndpoint {
ListenEndpoint("0.0.0.0:33306".to_socket_addrs().unwrap().next().unwrap())
}
#[inline]
@ -58,7 +68,7 @@ fn _default_empty_vec<T>() -> Vec<T> {
#[derive(Debug, Deserialize, Serialize, Clone, Object)]
pub struct TargetSSHOptions {
pub host: String,
#[serde(default = "_default_port")]
#[serde(default = "_default_ssh_port")]
pub port: u16,
#[serde(default = "_default_username")]
pub username: String,
@ -91,6 +101,59 @@ pub struct TargetHTTPOptions {
pub headers: Option<HashMap<String, String>>,
}
#[derive(Debug, Deserialize, Serialize, Clone, Enum, PartialEq, Eq)]
pub enum TlsMode {
Disabled,
Preferred,
Required,
}
impl Default for TlsMode {
fn default() -> Self {
TlsMode::Preferred
}
}
#[derive(Debug, Deserialize, Serialize, Clone, Object)]
pub struct Tls {
#[serde(default)]
pub mode: TlsMode,
#[serde(default)]
pub verify: bool,
}
#[allow(clippy::derivable_impls)]
impl Default for Tls {
fn default() -> Self {
Self {
mode: TlsMode::default(),
verify: false,
}
}
}
#[derive(Debug, Deserialize, Serialize, Clone, Object)]
pub struct TargetMySqlOptions {
#[serde(default = "_default_empty_string")]
pub host: String,
#[serde(default = "_default_mysql_port")]
pub port: u16,
#[serde(default = "_default_username")]
pub username: String,
#[serde(default)]
pub password: Option<String>,
#[serde(default)]
pub tls: Tls,
#[serde(default)]
pub verify_tls: bool,
}
#[derive(Debug, Deserialize, Serialize, Clone, Object, Default)]
pub struct TargetWebAdminOptions {}
@ -104,12 +167,14 @@ pub struct Target {
}
#[derive(Debug, Deserialize, Serialize, Clone, Union)]
#[oai(discriminator_name = "kind")]
#[oai(discriminator_name = "kind", one_of)]
pub enum TargetOptions {
#[serde(rename = "ssh")]
Ssh(TargetSSHOptions),
#[serde(rename = "http")]
Http(TargetHTTPOptions),
#[serde(rename = "mysql")]
MySql(TargetMySqlOptions),
#[serde(rename = "web_admin")]
WebAdmin(TargetWebAdminOptions),
}
@ -134,6 +199,8 @@ pub struct UserRequireCredentialsPolicy {
pub http: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ssh: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub mysql: Option<Vec<String>>,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
@ -150,8 +217,8 @@ pub struct Role {
pub name: String,
}
fn _default_ssh_listen() -> String {
"0.0.0.0:2222".to_owned()
fn _default_ssh_listen() -> ListenEndpoint {
ListenEndpoint("0.0.0.0:2222".to_socket_addrs().unwrap().next().unwrap())
}
fn _default_ssh_client_key() -> String {
@ -164,8 +231,11 @@ fn _default_ssh_keys_path() -> String {
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct SSHConfig {
#[serde(default = "_default_false")]
pub enable: bool,
#[serde(default = "_default_ssh_listen")]
pub listen: String,
pub listen: ListenEndpoint,
#[serde(default = "_default_ssh_keys_path")]
pub keys: String,
@ -177,6 +247,7 @@ pub struct SSHConfig {
impl Default for SSHConfig {
fn default() -> Self {
SSHConfig {
enable: true,
listen: _default_ssh_listen(),
keys: _default_ssh_keys_path(),
client_key: _default_ssh_client_key(),
@ -186,11 +257,11 @@ impl Default for SSHConfig {
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct HTTPConfig {
#[serde(default = "_default_false")]
#[serde(default = "_default_true")]
pub enable: bool,
#[serde(default = "_default_http_listen")]
pub listen: String,
pub listen: ListenEndpoint,
#[serde(default)]
pub certificate: String,
@ -210,6 +281,32 @@ impl Default for HTTPConfig {
}
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct MySQLConfig {
#[serde(default = "_default_false")]
pub enable: bool,
#[serde(default = "_default_mysql_listen")]
pub listen: ListenEndpoint,
#[serde(default)]
pub certificate: String,
#[serde(default)]
pub key: String,
}
impl Default for MySQLConfig {
fn default() -> Self {
MySQLConfig {
enable: true,
listen: _default_http_listen(),
certificate: "".to_owned(),
key: "".to_owned(),
}
}
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct RecordingsConfig {
#[serde(default = "_default_false")]
@ -267,6 +364,9 @@ pub struct WarpgateConfigStore {
#[serde(default)]
pub http: HTTPConfig,
#[serde(default)]
pub mysql: MySQLConfig,
#[serde(default)]
pub log: LogConfig,
}
@ -282,6 +382,7 @@ impl Default for WarpgateConfigStore {
database_url: _default_database_url(),
ssh: SSHConfig::default(),
http: HTTPConfig::default(),
mysql: MySQLConfig::default(),
log: LogConfig::default(),
}
}

View file

@ -1,7 +1,6 @@
use std::collections::HashSet;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use data_encoding::BASE64_MIME;
use sea_orm::ActiveValue::Set;
@ -16,7 +15,7 @@ use crate::helpers::hash::verify_password_hash;
use crate::helpers::otp::verify_totp;
use crate::{
AuthCredential, AuthResult, ProtocolName, Target, User, UserAuthCredential, UserSnapshot,
WarpgateConfig,
WarpgateConfig, WarpgateError,
};
pub struct FileConfigProvider {
@ -46,7 +45,7 @@ fn credential_is_type(c: &UserAuthCredential, k: &str) -> bool {
#[async_trait]
impl ConfigProvider for FileConfigProvider {
async fn list_users(&mut self) -> Result<Vec<UserSnapshot>> {
async fn list_users(&mut self) -> Result<Vec<UserSnapshot>, WarpgateError> {
Ok(self
.config
.lock()
@ -58,7 +57,7 @@ impl ConfigProvider for FileConfigProvider {
.collect::<Vec<_>>())
}
async fn list_targets(&mut self) -> Result<Vec<Target>> {
async fn list_targets(&mut self) -> Result<Vec<Target>, WarpgateError> {
Ok(self
.config
.lock()
@ -75,7 +74,7 @@ impl ConfigProvider for FileConfigProvider {
username: &str,
credentials: &[AuthCredential],
protocol: ProtocolName,
) -> Result<AuthResult> {
) -> Result<AuthResult, WarpgateError> {
if credentials.is_empty() {
return Ok(AuthResult::Rejected);
}
@ -167,6 +166,7 @@ impl ConfigProvider for FileConfigProvider {
let required_kinds = match protocol {
"SSH" => &policy.ssh,
"HTTP" => &policy.http,
"MySQL" => &policy.mysql,
_ => {
error!(%protocol, "Unkown protocol");
return Ok(AuthResult::Rejected);
@ -204,7 +204,11 @@ impl ConfigProvider for FileConfigProvider {
})
}
async fn authorize_target(&mut self, username: &str, target_name: &str) -> Result<bool> {
async fn authorize_target(
&mut self,
username: &str,
target_name: &str,
) -> Result<bool, WarpgateError> {
let config = self.config.lock().await;
let user = config
.store
@ -220,7 +224,7 @@ impl ConfigProvider for FileConfigProvider {
};
let Some(target) = target else {
error!("Selected target not found: {}", target_name);
warn!("Selected target not found: {}", target_name);
return Ok(false);
};
@ -244,11 +248,11 @@ impl ConfigProvider for FileConfigProvider {
Ok(intersect)
}
async fn consume_ticket(&mut self, ticket_id: &Uuid) -> Result<()> {
async fn consume_ticket(&mut self, ticket_id: &Uuid) -> Result<(), WarpgateError> {
let db = self.db.lock().await;
let ticket = Ticket::Entity::find_by_id(*ticket_id).one(&*db).await?;
let Some(ticket) = ticket else {
anyhow::bail!("Ticket not found: {}", ticket_id);
return Err(WarpgateError::InvalidTicket(*ticket_id));
};
if let Some(uses_left) = ticket.uses_left {

View file

@ -1,7 +1,6 @@
mod file;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
pub use file::FileConfigProvider;
@ -11,7 +10,7 @@ use tracing::*;
use uuid::Uuid;
use warpgate_db_entities::Ticket;
use crate::{ProtocolName, Secret, Target, UserSnapshot};
use crate::{ProtocolName, Secret, Target, UserSnapshot, WarpgateError};
pub enum AuthResult {
Accepted { username: String },
@ -30,27 +29,31 @@ pub enum AuthCredential {
#[async_trait]
pub trait ConfigProvider {
async fn list_users(&mut self) -> Result<Vec<UserSnapshot>>;
async fn list_users(&mut self) -> Result<Vec<UserSnapshot>, WarpgateError>;
async fn list_targets(&mut self) -> Result<Vec<Target>>;
async fn list_targets(&mut self) -> Result<Vec<Target>, WarpgateError>;
async fn authorize(
&mut self,
username: &str,
credentials: &[AuthCredential],
protocol: ProtocolName,
) -> Result<AuthResult>;
) -> Result<AuthResult, WarpgateError>;
async fn authorize_target(&mut self, username: &str, target: &str) -> Result<bool>;
async fn authorize_target(
&mut self,
username: &str,
target: &str,
) -> Result<bool, WarpgateError>;
async fn consume_ticket(&mut self, ticket_id: &Uuid) -> Result<()>;
async fn consume_ticket(&mut self, ticket_id: &Uuid) -> Result<(), WarpgateError>;
}
//TODO: move this somewhere
pub async fn authorize_ticket(
db: &Arc<Mutex<DatabaseConnection>>,
secret: &Secret<String>,
) -> Result<Option<Ticket::Model>> {
) -> Result<Option<Ticket::Model>, WarpgateError> {
let ticket = {
let db = db.lock().await;
Ticket::Entity::find()

View file

@ -0,0 +1,26 @@
use std::error::Error;
use poem::error::ResponseError;
use uuid::Uuid;
#[derive(thiserror::Error, Debug)]
pub enum WarpgateError {
#[error("database error: {0}")]
DatabaseError(#[from] sea_orm::DbErr),
#[error("ticket not found: {0}")]
InvalidTicket(Uuid),
#[error(transparent)]
Other(Box<dyn Error + Send + Sync>),
}
impl ResponseError for WarpgateError {
fn status(&self) -> poem::http::StatusCode {
poem::http::StatusCode::INTERNAL_SERVER_ERROR
}
}
impl WarpgateError {
pub fn other<E: Error + Send + Sync + 'static>(err: E) -> Self {
Self::Other(Box::new(err))
}
}

View file

@ -5,6 +5,7 @@ mod config_providers;
pub mod consts;
mod data;
pub mod db;
mod error;
pub mod eventhub;
pub mod helpers;
pub mod logging;
@ -18,6 +19,7 @@ mod types;
pub use config::*;
pub use config_providers::*;
pub use data::*;
pub use error::WarpgateError;
pub use protocols::*;
pub use services::*;
pub use state::{SessionState, SessionStateInit, State};

View file

@ -1,11 +1,10 @@
use std::sync::Arc;
use anyhow::{Context, Result};
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter};
use tokio::sync::Mutex;
use warpgate_db_entities::Session;
use crate::{SessionId, SessionState, State, Target};
use crate::{SessionId, SessionState, State, Target, WarpgateError};
pub trait SessionHandle {
fn close(&mut self);
@ -41,7 +40,7 @@ impl WarpgateServerHandle {
&self.session_state
}
pub async fn set_username(&mut self, username: String) -> Result<()> {
pub async fn set_username(&self, username: String) -> Result<(), WarpgateError> {
use sea_orm::ActiveValue::Set;
{
@ -64,7 +63,7 @@ impl WarpgateServerHandle {
Ok(())
}
pub async fn set_target(&self, target: &Target) -> Result<()> {
pub async fn set_target(&self, target: &Target) -> Result<(), WarpgateError> {
use sea_orm::ActiveValue::Set;
{
let mut state = self.session_state.lock().await;
@ -77,7 +76,7 @@ impl WarpgateServerHandle {
Session::Entity::update_many()
.set(Session::ActiveModel {
target_snapshot: Set(Some(
serde_json::to_string(&target).context("Error serializing target")?,
serde_json::to_string(&target).map_err(WarpgateError::other)?,
)),
..Default::default()
})

View file

@ -1,4 +1,6 @@
use std::fmt::Debug;
use std::net::{SocketAddr, ToSocketAddrs};
use std::ops::Deref;
use bytes::Bytes;
use data_encoding::HEXLOWER;
@ -66,3 +68,51 @@ impl<T> Debug for Secret<T> {
write!(f, "<secret>")
}
}
#[derive(Clone)]
pub struct ListenEndpoint(pub SocketAddr);
impl Deref for ListenEndpoint {
type Target = SocketAddr;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'de> Deserialize<'de> for ListenEndpoint {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let v: String = Deserialize::deserialize::<D>(deserializer)?;
let v = v
.to_socket_addrs()
.map_err(|e| {
serde::de::Error::custom(format!(
"failed to resolve {v} into a TCP endpoint: {e:?}"
))
})?
.next()
.ok_or_else(|| {
return serde::de::Error::custom(format!(
"failed to resolve {v} into a TCP endpoint"
));
})?;
Ok(Self(v))
}
}
impl Serialize for ListenEndpoint {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.0.serialize(serializer)
}
}
impl Debug for ListenEndpoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

View file

@ -0,0 +1,24 @@
[package]
name = "warpgate-database-protocols"
version = "0.3.0"
description = "Core of SQLx, the rust SQL toolkit. Just the database protocol parts."
license = "MIT OR Apache-2.0"
edition = "2021"
authors = [
"Ryan Leckey <leckey.ryan@gmail.com>",
"Austin Bonander <austin.bonander@gmail.com>",
"Chloe Ross <orangesnowfox@gmail.com>",
"Daniel Akhterov <akhterovd@gmail.com>",
]
[dependencies]
tokio = { version = "1.20", features = ["io-util"] }
bitflags = { version = "1.3", default-features = false }
bytes = "1.1"
futures-core = { version = "0.3", default-features = false }
futures-util = { version = "0.3", default-features = false, features = [
"alloc",
"sink",
] }
memchr = { version = "2.4.1", default-features = false }
thiserror = "1.0"

View file

@ -0,0 +1 @@
This is an extract from sqlx-core with Encode/Decode impls added for server-side packet flow

View file

@ -0,0 +1,241 @@
//! Types for working with errors produced by SQLx.
use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt::Display;
use std::io;
use std::result::Result as StdResult;
/// A specialized `Result` type for SQLx.
pub type Result<T> = StdResult<T, Error>;
// Convenience type alias for usage within SQLx.
// Do not make this type public.
pub type BoxDynError = Box<dyn StdError + 'static + Send + Sync>;
/// An unexpected `NULL` was encountered during decoding.
///
/// Returned from [`Row::get`](crate::row::Row::get) if the value from the database is `NULL`,
/// and you are not decoding into an `Option`.
#[derive(thiserror::Error, Debug)]
#[error("unexpected null; try decoding as an `Option`")]
pub struct UnexpectedNullError;
/// Represents all the ways a method can fail within SQLx.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
/// Error occurred while parsing a connection string.
#[error("error with configuration: {0}")]
Configuration(#[source] BoxDynError),
/// Error returned from the database.
#[error("error returned from database: {0}")]
Database(#[source] Box<dyn DatabaseError>),
/// Error communicating with the database backend.
#[error("error communicating with database: {0}")]
Io(#[from] io::Error),
/// Error occurred while attempting to establish a TLS connection.
#[error("error occurred while attempting to establish a TLS connection: {0}")]
Tls(#[source] BoxDynError),
/// Unexpected or invalid data encountered while communicating with the database.
///
/// This should indicate there is a programming error in a SQLx driver or there
/// is something corrupted with the connection to the database itself.
#[error("encountered unexpected or invalid data: {0}")]
Protocol(String),
/// No rows returned by a query that expected to return at least one row.
#[error("no rows returned by a query that expected to return at least one row")]
RowNotFound,
/// Type in query doesn't exist. Likely due to typo or missing user type.
#[error("type named {type_name} not found")]
TypeNotFound { type_name: String },
/// Column index was out of bounds.
#[error("column index out of bounds: the len is {len}, but the index is {index}")]
ColumnIndexOutOfBounds { index: usize, len: usize },
/// No column found for the given name.
#[error("no column found for name: {0}")]
ColumnNotFound(String),
/// Error occurred while decoding a value from a specific column.
#[error("error occurred while decoding column {index}: {source}")]
ColumnDecode {
index: String,
#[source]
source: BoxDynError,
},
/// Error occurred while decoding a value.
#[error("error occurred while decoding: {0}")]
Decode(#[source] BoxDynError),
/// A [`Pool::acquire`] timed out due to connections not becoming available or
/// because another task encountered too many errors while trying to open a new connection.
///
/// [`Pool::acquire`]: crate::pool::Pool::acquire
#[error("pool timed out while waiting for an open connection")]
PoolTimedOut,
/// [`Pool::close`] was called while we were waiting in [`Pool::acquire`].
///
/// [`Pool::acquire`]: crate::pool::Pool::acquire
/// [`Pool::close`]: crate::pool::Pool::close
#[error("attempted to acquire a connection on a closed pool")]
PoolClosed,
/// A background worker has crashed.
#[error("attempted to communicate with a crashed background worker")]
WorkerCrashed,
#[cfg(feature = "migrate")]
#[error("{0}")]
Migrate(#[source] Box<crate::migrate::MigrateError>),
}
impl StdError for Box<dyn DatabaseError> {}
impl Error {
#[allow(dead_code)]
#[inline]
pub(crate) fn protocol(err: impl Display) -> Self {
Error::Protocol(err.to_string())
}
#[allow(dead_code)]
#[inline]
pub(crate) fn config(err: impl StdError + Send + Sync + 'static) -> Self {
Error::Configuration(err.into())
}
}
/// An error that was returned from the database.
pub trait DatabaseError: 'static + Send + Sync + StdError {
/// The primary, human-readable error message.
fn message(&self) -> &str;
/// The (SQLSTATE) code for the error.
fn code(&self) -> Option<Cow<'_, str>> {
None
}
#[doc(hidden)]
fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static);
#[doc(hidden)]
fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static);
#[doc(hidden)]
fn into_error(self: Box<Self>) -> Box<dyn StdError + Send + Sync + 'static>;
#[doc(hidden)]
fn is_transient_in_connect_phase(&self) -> bool {
false
}
/// Returns the name of the constraint that triggered the error, if applicable.
/// If the error was caused by a conflict of a unique index, this will be the index name.
///
/// ### Note
/// Currently only populated by the Postgres driver.
fn constraint(&self) -> Option<&str> {
None
}
}
impl dyn DatabaseError {
/// Downcast a reference to this generic database error to a specific
/// database error type.
///
/// # Panics
///
/// Panics if the database error type is not `E`. This is a deliberate contrast from
/// `Error::downcast_ref` which returns `Option<&E>`. In normal usage, you should know the
/// specific error type. In other cases, use `try_downcast_ref`.
pub fn downcast_ref<E: DatabaseError>(&self) -> &E {
self.try_downcast_ref().unwrap_or_else(|| {
panic!(
"downcast to wrong DatabaseError type; original error: {}",
self
)
})
}
/// Downcast this generic database error to a specific database error type.
///
/// # Panics
///
/// Panics if the database error type is not `E`. This is a deliberate contrast from
/// `Error::downcast` which returns `Option<E>`. In normal usage, you should know the
/// specific error type. In other cases, use `try_downcast`.
pub fn downcast<E: DatabaseError>(self: Box<Self>) -> Box<E> {
self.try_downcast().unwrap_or_else(|e| {
panic!(
"downcast to wrong DatabaseError type; original error: {}",
e
)
})
}
/// Downcast a reference to this generic database error to a specific
/// database error type.
#[inline]
pub fn try_downcast_ref<E: DatabaseError>(&self) -> Option<&E> {
self.as_error().downcast_ref()
}
/// Downcast this generic database error to a specific database error type.
#[inline]
pub fn try_downcast<E: DatabaseError>(self: Box<Self>) -> StdResult<Box<E>, Box<Self>> {
if self.as_error().is::<E>() {
Ok(self.into_error().downcast().unwrap())
} else {
Err(self)
}
}
}
impl<E> From<E> for Error
where
E: DatabaseError,
{
#[inline]
fn from(error: E) -> Self {
Error::Database(Box::new(error))
}
}
#[cfg(feature = "migrate")]
impl From<crate::migrate::MigrateError> for Error {
#[inline]
fn from(error: crate::migrate::MigrateError) -> Self {
Error::Migrate(Box::new(error))
}
}
#[cfg(feature = "_tls-native-tls")]
impl From<sqlx_rt::native_tls::Error> for Error {
#[inline]
fn from(error: sqlx_rt::native_tls::Error) -> Self {
Error::Tls(Box::new(error))
}
}
// Format an error message as a `Protocol` error
#[macro_export]
macro_rules! err_protocol {
($expr:expr) => {
$crate::error::Error::Protocol($expr.into())
};
($fmt:expr, $($arg:tt)*) => {
$crate::error::Error::Protocol(format!($fmt, $($arg)*))
};
}

View file

@ -0,0 +1,59 @@
use std::str::from_utf8;
use bytes::{Buf, Bytes};
use memchr::memchr;
use crate::err_protocol;
use crate::error::Error;
pub trait BufExt: Buf {
// Read a nul-terminated byte sequence
fn get_bytes_nul(&mut self) -> Result<Bytes, Error>;
// Read a byte sequence of the exact length
fn get_bytes(&mut self, len: usize) -> Bytes;
// Read a nul-terminated string
fn get_str_nul(&mut self) -> Result<String, Error>;
// Read a string of the exact length
fn get_str(&mut self, len: usize) -> Result<String, Error>;
}
impl BufExt for Bytes {
fn get_bytes_nul(&mut self) -> Result<Bytes, Error> {
let nul =
memchr(b'\0', self).ok_or_else(|| err_protocol!("expected NUL in byte sequence"))?;
let v = self.slice(0..nul);
self.advance(nul + 1);
Ok(v)
}
fn get_bytes(&mut self, len: usize) -> Bytes {
let v = self.slice(..len);
self.advance(len);
v
}
fn get_str_nul(&mut self) -> Result<String, Error> {
self.get_bytes_nul().and_then(|bytes| {
from_utf8(&*bytes)
.map(ToOwned::to_owned)
.map_err(|err| err_protocol!("{}", err))
})
}
fn get_str(&mut self, len: usize) -> Result<String, Error> {
let v = from_utf8(&self[..len])
.map_err(|err| err_protocol!("{}", err))
.map(ToOwned::to_owned)?;
self.advance(len);
Ok(v)
}
}

View file

@ -0,0 +1,12 @@
use bytes::BufMut;
pub trait BufMutExt: BufMut {
fn put_str_nul(&mut self, s: &str);
}
impl BufMutExt for Vec<u8> {
fn put_str_nul(&mut self, s: &str) {
self.extend(s.as_bytes());
self.push(0);
}
}

View file

@ -0,0 +1,166 @@
#![allow(dead_code)]
use std::io;
use std::io::Cursor;
use std::ops::{Deref, DerefMut};
use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use crate::error::Error;
use crate::io::decode::Decode;
use crate::io::encode::Encode;
use crate::io::write_and_flush::WriteAndFlush;
pub struct BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub(crate) stream: S,
// writes with `write` to the underlying stream are buffered
// this can be flushed with `flush`
pub(crate) wbuf: Vec<u8>,
// we read into the read buffer using 100% safe code
rbuf: BytesMut,
}
impl<S> BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(stream: S) -> Self {
Self {
stream,
wbuf: Vec::with_capacity(512),
rbuf: BytesMut::with_capacity(4096),
}
}
pub fn write<'en, T>(&mut self, value: T)
where
T: Encode<'en, ()>,
{
self.write_with(value, ())
}
pub fn write_with<'en, T, C>(&mut self, value: T, context: C)
where
T: Encode<'en, C>,
{
value.encode_with(&mut self.wbuf, context);
}
pub fn flush(&mut self) -> WriteAndFlush<'_, S> {
WriteAndFlush {
stream: &mut self.stream,
buf: Cursor::new(&mut self.wbuf),
}
}
pub async fn read<'de, T>(&mut self, cnt: usize) -> Result<T, Error>
where
T: Decode<'de, ()>,
{
self.read_with(cnt, ()).await
}
pub async fn read_with<'de, T, C>(&mut self, cnt: usize, context: C) -> Result<T, Error>
where
T: Decode<'de, C>,
{
T::decode_with(self.read_raw(cnt).await?.freeze(), context)
}
pub async fn read_raw(&mut self, cnt: usize) -> Result<BytesMut, Error> {
read_raw_into(&mut self.stream, &mut self.rbuf, cnt).await?;
let buf = self.rbuf.split_to(cnt);
Ok(buf)
}
pub async fn read_raw_into(&mut self, buf: &mut BytesMut, cnt: usize) -> Result<(), Error> {
read_raw_into(&mut self.stream, buf, cnt).await
}
pub fn take(self) -> S {
self.stream
}
}
impl<S> Deref for BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
type Target = S;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
impl<S> DerefMut for BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.stream
}
}
// Holds a buffer which has been temporarily extended, so that
// we can read into it. Automatically shrinks the buffer back
// down if the read is cancelled.
struct BufTruncator<'a> {
buf: &'a mut BytesMut,
filled_len: usize,
}
impl<'a> BufTruncator<'a> {
fn new(buf: &'a mut BytesMut) -> Self {
let filled_len = buf.len();
Self { buf, filled_len }
}
fn reserve(&mut self, space: usize) {
self.buf.resize(self.filled_len + space, 0);
}
async fn read<S: AsyncRead + Unpin>(&mut self, stream: &mut S) -> Result<usize, Error> {
let n = stream.read(&mut self.buf[self.filled_len..]).await?;
self.filled_len += n;
Ok(n)
}
fn is_full(&self) -> bool {
self.filled_len >= self.buf.len()
}
}
impl Drop for BufTruncator<'_> {
fn drop(&mut self) {
self.buf.truncate(self.filled_len);
}
}
async fn read_raw_into<S: AsyncRead + Unpin>(
stream: &mut S,
buf: &mut BytesMut,
cnt: usize,
) -> Result<(), Error> {
let mut buf = BufTruncator::new(buf);
buf.reserve(cnt);
while !buf.is_full() {
let n = buf.read(stream).await?;
if n == 0 {
// a zero read when we had space in the read buffer
// should be treated as an EOF
// and an unexpected EOF means the server told us to go away
return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into());
}
}
Ok(())
}

View file

@ -0,0 +1,29 @@
use bytes::Bytes;
use crate::error::Error;
pub trait Decode<'de, Context = ()>
where
Self: Sized,
{
fn decode(buf: Bytes) -> Result<Self, Error>
where
Self: Decode<'de, ()>,
{
Self::decode_with(buf, ())
}
fn decode_with(buf: Bytes, context: Context) -> Result<Self, Error>;
}
impl Decode<'_> for Bytes {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
Ok(buf)
}
}
impl Decode<'_> for () {
fn decode_with(_: Bytes, _: ()) -> Result<(), Error> {
Ok(())
}
}

View file

@ -0,0 +1,16 @@
pub trait Encode<'en, Context = ()> {
fn encode(&self, buf: &mut Vec<u8>)
where
Self: Encode<'en, ()>,
{
self.encode_with(buf, ());
}
fn encode_with(&self, buf: &mut Vec<u8>, context: Context);
}
impl<'en, C> Encode<'en, C> for &'_ [u8] {
fn encode_with(&self, buf: &mut Vec<u8>, _: C) {
buf.extend_from_slice(self);
}
}

View file

@ -0,0 +1,12 @@
mod buf;
mod buf_mut;
mod buf_stream;
mod decode;
mod encode;
mod write_and_flush;
pub use buf::BufExt;
pub use buf_mut::BufMutExt;
pub use buf_stream::BufStream;
pub use decode::Decode;
pub use encode::Encode;

View file

@ -0,0 +1,47 @@
use std::io::{BufRead, Cursor};
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_core::Future;
use futures_util::ready;
use tokio::io::AsyncWrite;
use crate::error::Error;
// Atomic operation that writes the full buffer to the stream, flushes the stream, and then
// clears the buffer (even if either of the two previous operations failed).
pub struct WriteAndFlush<'a, S> {
pub(super) stream: &'a mut S,
pub(super) buf: Cursor<&'a mut Vec<u8>>,
}
impl<S: AsyncWrite + Unpin> Future for WriteAndFlush<'_, S> {
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Self {
ref mut stream,
ref mut buf,
} = *self;
loop {
let read = buf.fill_buf()?;
if !read.is_empty() {
let written = ready!(Pin::new(&mut *stream).poll_write(cx, read)?);
buf.consume(written);
} else {
break;
}
}
Pin::new(stream).poll_flush(cx).map_err(Error::Io)
}
}
impl<'a, S> Drop for WriteAndFlush<'a, S> {
fn drop(&mut self) {
// clear the buffer regardless of whether the flush succeeded or not
self.buf.get_mut().clear();
}
}

View file

@ -0,0 +1,6 @@
#![allow(dead_code)]
pub mod io;
pub mod mysql;
#[macro_use]
pub mod error;

View file

@ -0,0 +1,901 @@
use std::str::FromStr;
use crate::error::Error;
#[allow(non_camel_case_types)]
#[derive(Copy, Clone)]
pub(crate) enum CharSet {
armscii8,
ascii,
big5,
binary,
cp1250,
cp1251,
cp1256,
cp1257,
cp850,
cp852,
cp866,
cp932,
dec8,
eucjpms,
euckr,
gb18030,
gb2312,
gbk,
geostd8,
greek,
hebrew,
hp8,
keybcs2,
koi8r,
koi8u,
latin1,
latin2,
latin5,
latin7,
macce,
macroman,
sjis,
swe7,
tis620,
ucs2,
ujis,
utf16,
utf16le,
utf32,
utf8,
utf8mb4,
}
impl CharSet {
pub(crate) fn as_str(&self) -> &'static str {
match self {
CharSet::armscii8 => "armscii8",
CharSet::ascii => "ascii",
CharSet::big5 => "big5",
CharSet::binary => "binary",
CharSet::cp1250 => "cp1250",
CharSet::cp1251 => "cp1251",
CharSet::cp1256 => "cp1256",
CharSet::cp1257 => "cp1257",
CharSet::cp850 => "cp850",
CharSet::cp852 => "cp852",
CharSet::cp866 => "cp866",
CharSet::cp932 => "cp932",
CharSet::dec8 => "dec8",
CharSet::eucjpms => "eucjpms",
CharSet::euckr => "euckr",
CharSet::gb18030 => "gb18030",
CharSet::gb2312 => "gb2312",
CharSet::gbk => "gbk",
CharSet::geostd8 => "geostd8",
CharSet::greek => "greek",
CharSet::hebrew => "hebrew",
CharSet::hp8 => "hp8",
CharSet::keybcs2 => "keybcs2",
CharSet::koi8r => "koi8r",
CharSet::koi8u => "koi8u",
CharSet::latin1 => "latin1",
CharSet::latin2 => "latin2",
CharSet::latin5 => "latin5",
CharSet::latin7 => "latin7",
CharSet::macce => "macce",
CharSet::macroman => "macroman",
CharSet::sjis => "sjis",
CharSet::swe7 => "swe7",
CharSet::tis620 => "tis620",
CharSet::ucs2 => "ucs2",
CharSet::ujis => "ujis",
CharSet::utf16 => "utf16",
CharSet::utf16le => "utf16le",
CharSet::utf32 => "utf32",
CharSet::utf8 => "utf8",
CharSet::utf8mb4 => "utf8mb4",
}
}
pub(crate) fn default_collation(&self) -> Collation {
match self {
CharSet::armscii8 => Collation::armscii8_general_ci,
CharSet::ascii => Collation::ascii_general_ci,
CharSet::big5 => Collation::big5_chinese_ci,
CharSet::binary => Collation::binary,
CharSet::cp1250 => Collation::cp1250_general_ci,
CharSet::cp1251 => Collation::cp1251_general_ci,
CharSet::cp1256 => Collation::cp1256_general_ci,
CharSet::cp1257 => Collation::cp1257_general_ci,
CharSet::cp850 => Collation::cp850_general_ci,
CharSet::cp852 => Collation::cp852_general_ci,
CharSet::cp866 => Collation::cp866_general_ci,
CharSet::cp932 => Collation::cp932_japanese_ci,
CharSet::dec8 => Collation::dec8_swedish_ci,
CharSet::eucjpms => Collation::eucjpms_japanese_ci,
CharSet::euckr => Collation::euckr_korean_ci,
CharSet::gb18030 => Collation::gb18030_chinese_ci,
CharSet::gb2312 => Collation::gb2312_chinese_ci,
CharSet::gbk => Collation::gbk_chinese_ci,
CharSet::geostd8 => Collation::geostd8_general_ci,
CharSet::greek => Collation::greek_general_ci,
CharSet::hebrew => Collation::hebrew_general_ci,
CharSet::hp8 => Collation::hp8_english_ci,
CharSet::keybcs2 => Collation::keybcs2_general_ci,
CharSet::koi8r => Collation::koi8r_general_ci,
CharSet::koi8u => Collation::koi8u_general_ci,
CharSet::latin1 => Collation::latin1_swedish_ci,
CharSet::latin2 => Collation::latin2_general_ci,
CharSet::latin5 => Collation::latin5_turkish_ci,
CharSet::latin7 => Collation::latin7_general_ci,
CharSet::macce => Collation::macce_general_ci,
CharSet::macroman => Collation::macroman_general_ci,
CharSet::sjis => Collation::sjis_japanese_ci,
CharSet::swe7 => Collation::swe7_swedish_ci,
CharSet::tis620 => Collation::tis620_thai_ci,
CharSet::ucs2 => Collation::ucs2_general_ci,
CharSet::ujis => Collation::ujis_japanese_ci,
CharSet::utf16 => Collation::utf16_general_ci,
CharSet::utf16le => Collation::utf16le_general_ci,
CharSet::utf32 => Collation::utf32_general_ci,
CharSet::utf8 => Collation::utf8_unicode_ci,
CharSet::utf8mb4 => Collation::utf8mb4_unicode_ci,
}
}
}
impl FromStr for CharSet {
type Err = Error;
fn from_str(char_set: &str) -> Result<Self, Self::Err> {
Ok(match char_set {
"armscii8" => CharSet::armscii8,
"ascii" => CharSet::ascii,
"big5" => CharSet::big5,
"binary" => CharSet::binary,
"cp1250" => CharSet::cp1250,
"cp1251" => CharSet::cp1251,
"cp1256" => CharSet::cp1256,
"cp1257" => CharSet::cp1257,
"cp850" => CharSet::cp850,
"cp852" => CharSet::cp852,
"cp866" => CharSet::cp866,
"cp932" => CharSet::cp932,
"dec8" => CharSet::dec8,
"eucjpms" => CharSet::eucjpms,
"euckr" => CharSet::euckr,
"gb18030" => CharSet::gb18030,
"gb2312" => CharSet::gb2312,
"gbk" => CharSet::gbk,
"geostd8" => CharSet::geostd8,
"greek" => CharSet::greek,
"hebrew" => CharSet::hebrew,
"hp8" => CharSet::hp8,
"keybcs2" => CharSet::keybcs2,
"koi8r" => CharSet::koi8r,
"koi8u" => CharSet::koi8u,
"latin1" => CharSet::latin1,
"latin2" => CharSet::latin2,
"latin5" => CharSet::latin5,
"latin7" => CharSet::latin7,
"macce" => CharSet::macce,
"macroman" => CharSet::macroman,
"sjis" => CharSet::sjis,
"swe7" => CharSet::swe7,
"tis620" => CharSet::tis620,
"ucs2" => CharSet::ucs2,
"ujis" => CharSet::ujis,
"utf16" => CharSet::utf16,
"utf16le" => CharSet::utf16le,
"utf32" => CharSet::utf32,
"utf8" => CharSet::utf8,
"utf8mb4" => CharSet::utf8mb4,
_ => {
return Err(Error::Configuration(
format!("unsupported MySQL charset: {}", char_set).into(),
));
}
})
}
}
#[derive(Copy, Clone)]
#[allow(non_camel_case_types)]
#[repr(u8)]
pub(crate) enum Collation {
armscii8_bin = 64,
armscii8_general_ci = 32,
ascii_bin = 65,
ascii_general_ci = 11,
big5_bin = 84,
big5_chinese_ci = 1,
binary = 63,
cp1250_bin = 66,
cp1250_croatian_ci = 44,
cp1250_czech_cs = 34,
cp1250_general_ci = 26,
cp1250_polish_ci = 99,
cp1251_bin = 50,
cp1251_bulgarian_ci = 14,
cp1251_general_ci = 51,
cp1251_general_cs = 52,
cp1251_ukrainian_ci = 23,
cp1256_bin = 67,
cp1256_general_ci = 57,
cp1257_bin = 58,
cp1257_general_ci = 59,
cp1257_lithuanian_ci = 29,
cp850_bin = 80,
cp850_general_ci = 4,
cp852_bin = 81,
cp852_general_ci = 40,
cp866_bin = 68,
cp866_general_ci = 36,
cp932_bin = 96,
cp932_japanese_ci = 95,
dec8_bin = 69,
dec8_swedish_ci = 3,
eucjpms_bin = 98,
eucjpms_japanese_ci = 97,
euckr_bin = 85,
euckr_korean_ci = 19,
gb18030_bin = 249,
gb18030_chinese_ci = 248,
gb18030_unicode_520_ci = 250,
gb2312_bin = 86,
gb2312_chinese_ci = 24,
gbk_bin = 87,
gbk_chinese_ci = 28,
geostd8_bin = 93,
geostd8_general_ci = 92,
greek_bin = 70,
greek_general_ci = 25,
hebrew_bin = 71,
hebrew_general_ci = 16,
hp8_bin = 72,
hp8_english_ci = 6,
keybcs2_bin = 73,
keybcs2_general_ci = 37,
koi8r_bin = 74,
koi8r_general_ci = 7,
koi8u_bin = 75,
koi8u_general_ci = 22,
latin1_bin = 47,
latin1_danish_ci = 15,
latin1_general_ci = 48,
latin1_general_cs = 49,
latin1_german1_ci = 5,
latin1_german2_ci = 31,
latin1_spanish_ci = 94,
latin1_swedish_ci = 8,
latin2_bin = 77,
latin2_croatian_ci = 27,
latin2_czech_cs = 2,
latin2_general_ci = 9,
latin2_hungarian_ci = 21,
latin5_bin = 78,
latin5_turkish_ci = 30,
latin7_bin = 79,
latin7_estonian_cs = 20,
latin7_general_ci = 41,
latin7_general_cs = 42,
macce_bin = 43,
macce_general_ci = 38,
macroman_bin = 53,
macroman_general_ci = 39,
sjis_bin = 88,
sjis_japanese_ci = 13,
swe7_bin = 82,
swe7_swedish_ci = 10,
tis620_bin = 89,
tis620_thai_ci = 18,
ucs2_bin = 90,
ucs2_croatian_ci = 149,
ucs2_czech_ci = 138,
ucs2_danish_ci = 139,
ucs2_esperanto_ci = 145,
ucs2_estonian_ci = 134,
ucs2_general_ci = 35,
ucs2_general_mysql500_ci = 159,
ucs2_german2_ci = 148,
ucs2_hungarian_ci = 146,
ucs2_icelandic_ci = 129,
ucs2_latvian_ci = 130,
ucs2_lithuanian_ci = 140,
ucs2_persian_ci = 144,
ucs2_polish_ci = 133,
ucs2_roman_ci = 143,
ucs2_romanian_ci = 131,
ucs2_sinhala_ci = 147,
ucs2_slovak_ci = 141,
ucs2_slovenian_ci = 132,
ucs2_spanish_ci = 135,
ucs2_spanish2_ci = 142,
ucs2_swedish_ci = 136,
ucs2_turkish_ci = 137,
ucs2_unicode_520_ci = 150,
ucs2_unicode_ci = 128,
ucs2_vietnamese_ci = 151,
ujis_bin = 91,
ujis_japanese_ci = 12,
utf16_bin = 55,
utf16_croatian_ci = 122,
utf16_czech_ci = 111,
utf16_danish_ci = 112,
utf16_esperanto_ci = 118,
utf16_estonian_ci = 107,
utf16_general_ci = 54,
utf16_german2_ci = 121,
utf16_hungarian_ci = 119,
utf16_icelandic_ci = 102,
utf16_latvian_ci = 103,
utf16_lithuanian_ci = 113,
utf16_persian_ci = 117,
utf16_polish_ci = 106,
utf16_roman_ci = 116,
utf16_romanian_ci = 104,
utf16_sinhala_ci = 120,
utf16_slovak_ci = 114,
utf16_slovenian_ci = 105,
utf16_spanish_ci = 108,
utf16_spanish2_ci = 115,
utf16_swedish_ci = 109,
utf16_turkish_ci = 110,
utf16_unicode_520_ci = 123,
utf16_unicode_ci = 101,
utf16_vietnamese_ci = 124,
utf16le_bin = 62,
utf16le_general_ci = 56,
utf32_bin = 61,
utf32_croatian_ci = 181,
utf32_czech_ci = 170,
utf32_danish_ci = 171,
utf32_esperanto_ci = 177,
utf32_estonian_ci = 166,
utf32_general_ci = 60,
utf32_german2_ci = 180,
utf32_hungarian_ci = 178,
utf32_icelandic_ci = 161,
utf32_latvian_ci = 162,
utf32_lithuanian_ci = 172,
utf32_persian_ci = 176,
utf32_polish_ci = 165,
utf32_roman_ci = 175,
utf32_romanian_ci = 163,
utf32_sinhala_ci = 179,
utf32_slovak_ci = 173,
utf32_slovenian_ci = 164,
utf32_spanish_ci = 167,
utf32_spanish2_ci = 174,
utf32_swedish_ci = 168,
utf32_turkish_ci = 169,
utf32_unicode_520_ci = 182,
utf32_unicode_ci = 160,
utf32_vietnamese_ci = 183,
utf8_bin = 83,
utf8_croatian_ci = 213,
utf8_czech_ci = 202,
utf8_danish_ci = 203,
utf8_esperanto_ci = 209,
utf8_estonian_ci = 198,
utf8_general_ci = 33,
utf8_general_mysql500_ci = 223,
utf8_german2_ci = 212,
utf8_hungarian_ci = 210,
utf8_icelandic_ci = 193,
utf8_latvian_ci = 194,
utf8_lithuanian_ci = 204,
utf8_persian_ci = 208,
utf8_polish_ci = 197,
utf8_roman_ci = 207,
utf8_romanian_ci = 195,
utf8_sinhala_ci = 211,
utf8_slovak_ci = 205,
utf8_slovenian_ci = 196,
utf8_spanish_ci = 199,
utf8_spanish2_ci = 206,
utf8_swedish_ci = 200,
utf8_tolower_ci = 76,
utf8_turkish_ci = 201,
utf8_unicode_520_ci = 214,
utf8_unicode_ci = 192,
utf8_vietnamese_ci = 215,
utf8mb4_0900_ai_ci = 255,
utf8mb4_bin = 46,
utf8mb4_croatian_ci = 245,
utf8mb4_czech_ci = 234,
utf8mb4_danish_ci = 235,
utf8mb4_esperanto_ci = 241,
utf8mb4_estonian_ci = 230,
utf8mb4_general_ci = 45,
utf8mb4_german2_ci = 244,
utf8mb4_hungarian_ci = 242,
utf8mb4_icelandic_ci = 225,
utf8mb4_latvian_ci = 226,
utf8mb4_lithuanian_ci = 236,
utf8mb4_persian_ci = 240,
utf8mb4_polish_ci = 229,
utf8mb4_roman_ci = 239,
utf8mb4_romanian_ci = 227,
utf8mb4_sinhala_ci = 243,
utf8mb4_slovak_ci = 237,
utf8mb4_slovenian_ci = 228,
utf8mb4_spanish_ci = 231,
utf8mb4_spanish2_ci = 238,
utf8mb4_swedish_ci = 232,
utf8mb4_turkish_ci = 233,
utf8mb4_unicode_520_ci = 246,
utf8mb4_unicode_ci = 224,
utf8mb4_vietnamese_ci = 247,
}
impl Collation {
pub(crate) fn as_str(&self) -> &'static str {
match self {
Collation::armscii8_bin => "armscii8_bin",
Collation::armscii8_general_ci => "armscii8_general_ci",
Collation::ascii_bin => "ascii_bin",
Collation::ascii_general_ci => "ascii_general_ci",
Collation::big5_bin => "big5_bin",
Collation::big5_chinese_ci => "big5_chinese_ci",
Collation::binary => "binary",
Collation::cp1250_bin => "cp1250_bin",
Collation::cp1250_croatian_ci => "cp1250_croatian_ci",
Collation::cp1250_czech_cs => "cp1250_czech_cs",
Collation::cp1250_general_ci => "cp1250_general_ci",
Collation::cp1250_polish_ci => "cp1250_polish_ci",
Collation::cp1251_bin => "cp1251_bin",
Collation::cp1251_bulgarian_ci => "cp1251_bulgarian_ci",
Collation::cp1251_general_ci => "cp1251_general_ci",
Collation::cp1251_general_cs => "cp1251_general_cs",
Collation::cp1251_ukrainian_ci => "cp1251_ukrainian_ci",
Collation::cp1256_bin => "cp1256_bin",
Collation::cp1256_general_ci => "cp1256_general_ci",
Collation::cp1257_bin => "cp1257_bin",
Collation::cp1257_general_ci => "cp1257_general_ci",
Collation::cp1257_lithuanian_ci => "cp1257_lithuanian_ci",
Collation::cp850_bin => "cp850_bin",
Collation::cp850_general_ci => "cp850_general_ci",
Collation::cp852_bin => "cp852_bin",
Collation::cp852_general_ci => "cp852_general_ci",
Collation::cp866_bin => "cp866_bin",
Collation::cp866_general_ci => "cp866_general_ci",
Collation::cp932_bin => "cp932_bin",
Collation::cp932_japanese_ci => "cp932_japanese_ci",
Collation::dec8_bin => "dec8_bin",
Collation::dec8_swedish_ci => "dec8_swedish_ci",
Collation::eucjpms_bin => "eucjpms_bin",
Collation::eucjpms_japanese_ci => "eucjpms_japanese_ci",
Collation::euckr_bin => "euckr_bin",
Collation::euckr_korean_ci => "euckr_korean_ci",
Collation::gb18030_bin => "gb18030_bin",
Collation::gb18030_chinese_ci => "gb18030_chinese_ci",
Collation::gb18030_unicode_520_ci => "gb18030_unicode_520_ci",
Collation::gb2312_bin => "gb2312_bin",
Collation::gb2312_chinese_ci => "gb2312_chinese_ci",
Collation::gbk_bin => "gbk_bin",
Collation::gbk_chinese_ci => "gbk_chinese_ci",
Collation::geostd8_bin => "geostd8_bin",
Collation::geostd8_general_ci => "geostd8_general_ci",
Collation::greek_bin => "greek_bin",
Collation::greek_general_ci => "greek_general_ci",
Collation::hebrew_bin => "hebrew_bin",
Collation::hebrew_general_ci => "hebrew_general_ci",
Collation::hp8_bin => "hp8_bin",
Collation::hp8_english_ci => "hp8_english_ci",
Collation::keybcs2_bin => "keybcs2_bin",
Collation::keybcs2_general_ci => "keybcs2_general_ci",
Collation::koi8r_bin => "koi8r_bin",
Collation::koi8r_general_ci => "koi8r_general_ci",
Collation::koi8u_bin => "koi8u_bin",
Collation::koi8u_general_ci => "koi8u_general_ci",
Collation::latin1_bin => "latin1_bin",
Collation::latin1_danish_ci => "latin1_danish_ci",
Collation::latin1_general_ci => "latin1_general_ci",
Collation::latin1_general_cs => "latin1_general_cs",
Collation::latin1_german1_ci => "latin1_german1_ci",
Collation::latin1_german2_ci => "latin1_german2_ci",
Collation::latin1_spanish_ci => "latin1_spanish_ci",
Collation::latin1_swedish_ci => "latin1_swedish_ci",
Collation::latin2_bin => "latin2_bin",
Collation::latin2_croatian_ci => "latin2_croatian_ci",
Collation::latin2_czech_cs => "latin2_czech_cs",
Collation::latin2_general_ci => "latin2_general_ci",
Collation::latin2_hungarian_ci => "latin2_hungarian_ci",
Collation::latin5_bin => "latin5_bin",
Collation::latin5_turkish_ci => "latin5_turkish_ci",
Collation::latin7_bin => "latin7_bin",
Collation::latin7_estonian_cs => "latin7_estonian_cs",
Collation::latin7_general_ci => "latin7_general_ci",
Collation::latin7_general_cs => "latin7_general_cs",
Collation::macce_bin => "macce_bin",
Collation::macce_general_ci => "macce_general_ci",
Collation::macroman_bin => "macroman_bin",
Collation::macroman_general_ci => "macroman_general_ci",
Collation::sjis_bin => "sjis_bin",
Collation::sjis_japanese_ci => "sjis_japanese_ci",
Collation::swe7_bin => "swe7_bin",
Collation::swe7_swedish_ci => "swe7_swedish_ci",
Collation::tis620_bin => "tis620_bin",
Collation::tis620_thai_ci => "tis620_thai_ci",
Collation::ucs2_bin => "ucs2_bin",
Collation::ucs2_croatian_ci => "ucs2_croatian_ci",
Collation::ucs2_czech_ci => "ucs2_czech_ci",
Collation::ucs2_danish_ci => "ucs2_danish_ci",
Collation::ucs2_esperanto_ci => "ucs2_esperanto_ci",
Collation::ucs2_estonian_ci => "ucs2_estonian_ci",
Collation::ucs2_general_ci => "ucs2_general_ci",
Collation::ucs2_general_mysql500_ci => "ucs2_general_mysql500_ci",
Collation::ucs2_german2_ci => "ucs2_german2_ci",
Collation::ucs2_hungarian_ci => "ucs2_hungarian_ci",
Collation::ucs2_icelandic_ci => "ucs2_icelandic_ci",
Collation::ucs2_latvian_ci => "ucs2_latvian_ci",
Collation::ucs2_lithuanian_ci => "ucs2_lithuanian_ci",
Collation::ucs2_persian_ci => "ucs2_persian_ci",
Collation::ucs2_polish_ci => "ucs2_polish_ci",
Collation::ucs2_roman_ci => "ucs2_roman_ci",
Collation::ucs2_romanian_ci => "ucs2_romanian_ci",
Collation::ucs2_sinhala_ci => "ucs2_sinhala_ci",
Collation::ucs2_slovak_ci => "ucs2_slovak_ci",
Collation::ucs2_slovenian_ci => "ucs2_slovenian_ci",
Collation::ucs2_spanish_ci => "ucs2_spanish_ci",
Collation::ucs2_spanish2_ci => "ucs2_spanish2_ci",
Collation::ucs2_swedish_ci => "ucs2_swedish_ci",
Collation::ucs2_turkish_ci => "ucs2_turkish_ci",
Collation::ucs2_unicode_520_ci => "ucs2_unicode_520_ci",
Collation::ucs2_unicode_ci => "ucs2_unicode_ci",
Collation::ucs2_vietnamese_ci => "ucs2_vietnamese_ci",
Collation::ujis_bin => "ujis_bin",
Collation::ujis_japanese_ci => "ujis_japanese_ci",
Collation::utf16_bin => "utf16_bin",
Collation::utf16_croatian_ci => "utf16_croatian_ci",
Collation::utf16_czech_ci => "utf16_czech_ci",
Collation::utf16_danish_ci => "utf16_danish_ci",
Collation::utf16_esperanto_ci => "utf16_esperanto_ci",
Collation::utf16_estonian_ci => "utf16_estonian_ci",
Collation::utf16_general_ci => "utf16_general_ci",
Collation::utf16_german2_ci => "utf16_german2_ci",
Collation::utf16_hungarian_ci => "utf16_hungarian_ci",
Collation::utf16_icelandic_ci => "utf16_icelandic_ci",
Collation::utf16_latvian_ci => "utf16_latvian_ci",
Collation::utf16_lithuanian_ci => "utf16_lithuanian_ci",
Collation::utf16_persian_ci => "utf16_persian_ci",
Collation::utf16_polish_ci => "utf16_polish_ci",
Collation::utf16_roman_ci => "utf16_roman_ci",
Collation::utf16_romanian_ci => "utf16_romanian_ci",
Collation::utf16_sinhala_ci => "utf16_sinhala_ci",
Collation::utf16_slovak_ci => "utf16_slovak_ci",
Collation::utf16_slovenian_ci => "utf16_slovenian_ci",
Collation::utf16_spanish_ci => "utf16_spanish_ci",
Collation::utf16_spanish2_ci => "utf16_spanish2_ci",
Collation::utf16_swedish_ci => "utf16_swedish_ci",
Collation::utf16_turkish_ci => "utf16_turkish_ci",
Collation::utf16_unicode_520_ci => "utf16_unicode_520_ci",
Collation::utf16_unicode_ci => "utf16_unicode_ci",
Collation::utf16_vietnamese_ci => "utf16_vietnamese_ci",
Collation::utf16le_bin => "utf16le_bin",
Collation::utf16le_general_ci => "utf16le_general_ci",
Collation::utf32_bin => "utf32_bin",
Collation::utf32_croatian_ci => "utf32_croatian_ci",
Collation::utf32_czech_ci => "utf32_czech_ci",
Collation::utf32_danish_ci => "utf32_danish_ci",
Collation::utf32_esperanto_ci => "utf32_esperanto_ci",
Collation::utf32_estonian_ci => "utf32_estonian_ci",
Collation::utf32_general_ci => "utf32_general_ci",
Collation::utf32_german2_ci => "utf32_german2_ci",
Collation::utf32_hungarian_ci => "utf32_hungarian_ci",
Collation::utf32_icelandic_ci => "utf32_icelandic_ci",
Collation::utf32_latvian_ci => "utf32_latvian_ci",
Collation::utf32_lithuanian_ci => "utf32_lithuanian_ci",
Collation::utf32_persian_ci => "utf32_persian_ci",
Collation::utf32_polish_ci => "utf32_polish_ci",
Collation::utf32_roman_ci => "utf32_roman_ci",
Collation::utf32_romanian_ci => "utf32_romanian_ci",
Collation::utf32_sinhala_ci => "utf32_sinhala_ci",
Collation::utf32_slovak_ci => "utf32_slovak_ci",
Collation::utf32_slovenian_ci => "utf32_slovenian_ci",
Collation::utf32_spanish_ci => "utf32_spanish_ci",
Collation::utf32_spanish2_ci => "utf32_spanish2_ci",
Collation::utf32_swedish_ci => "utf32_swedish_ci",
Collation::utf32_turkish_ci => "utf32_turkish_ci",
Collation::utf32_unicode_520_ci => "utf32_unicode_520_ci",
Collation::utf32_unicode_ci => "utf32_unicode_ci",
Collation::utf32_vietnamese_ci => "utf32_vietnamese_ci",
Collation::utf8_bin => "utf8_bin",
Collation::utf8_croatian_ci => "utf8_croatian_ci",
Collation::utf8_czech_ci => "utf8_czech_ci",
Collation::utf8_danish_ci => "utf8_danish_ci",
Collation::utf8_esperanto_ci => "utf8_esperanto_ci",
Collation::utf8_estonian_ci => "utf8_estonian_ci",
Collation::utf8_general_ci => "utf8_general_ci",
Collation::utf8_general_mysql500_ci => "utf8_general_mysql500_ci",
Collation::utf8_german2_ci => "utf8_german2_ci",
Collation::utf8_hungarian_ci => "utf8_hungarian_ci",
Collation::utf8_icelandic_ci => "utf8_icelandic_ci",
Collation::utf8_latvian_ci => "utf8_latvian_ci",
Collation::utf8_lithuanian_ci => "utf8_lithuanian_ci",
Collation::utf8_persian_ci => "utf8_persian_ci",
Collation::utf8_polish_ci => "utf8_polish_ci",
Collation::utf8_roman_ci => "utf8_roman_ci",
Collation::utf8_romanian_ci => "utf8_romanian_ci",
Collation::utf8_sinhala_ci => "utf8_sinhala_ci",
Collation::utf8_slovak_ci => "utf8_slovak_ci",
Collation::utf8_slovenian_ci => "utf8_slovenian_ci",
Collation::utf8_spanish_ci => "utf8_spanish_ci",
Collation::utf8_spanish2_ci => "utf8_spanish2_ci",
Collation::utf8_swedish_ci => "utf8_swedish_ci",
Collation::utf8_tolower_ci => "utf8_tolower_ci",
Collation::utf8_turkish_ci => "utf8_turkish_ci",
Collation::utf8_unicode_520_ci => "utf8_unicode_520_ci",
Collation::utf8_unicode_ci => "utf8_unicode_ci",
Collation::utf8_vietnamese_ci => "utf8_vietnamese_ci",
Collation::utf8mb4_0900_ai_ci => "utf8mb4_0900_ai_ci",
Collation::utf8mb4_bin => "utf8mb4_bin",
Collation::utf8mb4_croatian_ci => "utf8mb4_croatian_ci",
Collation::utf8mb4_czech_ci => "utf8mb4_czech_ci",
Collation::utf8mb4_danish_ci => "utf8mb4_danish_ci",
Collation::utf8mb4_esperanto_ci => "utf8mb4_esperanto_ci",
Collation::utf8mb4_estonian_ci => "utf8mb4_estonian_ci",
Collation::utf8mb4_general_ci => "utf8mb4_general_ci",
Collation::utf8mb4_german2_ci => "utf8mb4_german2_ci",
Collation::utf8mb4_hungarian_ci => "utf8mb4_hungarian_ci",
Collation::utf8mb4_icelandic_ci => "utf8mb4_icelandic_ci",
Collation::utf8mb4_latvian_ci => "utf8mb4_latvian_ci",
Collation::utf8mb4_lithuanian_ci => "utf8mb4_lithuanian_ci",
Collation::utf8mb4_persian_ci => "utf8mb4_persian_ci",
Collation::utf8mb4_polish_ci => "utf8mb4_polish_ci",
Collation::utf8mb4_roman_ci => "utf8mb4_roman_ci",
Collation::utf8mb4_romanian_ci => "utf8mb4_romanian_ci",
Collation::utf8mb4_sinhala_ci => "utf8mb4_sinhala_ci",
Collation::utf8mb4_slovak_ci => "utf8mb4_slovak_ci",
Collation::utf8mb4_slovenian_ci => "utf8mb4_slovenian_ci",
Collation::utf8mb4_spanish_ci => "utf8mb4_spanish_ci",
Collation::utf8mb4_spanish2_ci => "utf8mb4_spanish2_ci",
Collation::utf8mb4_swedish_ci => "utf8mb4_swedish_ci",
Collation::utf8mb4_turkish_ci => "utf8mb4_turkish_ci",
Collation::utf8mb4_unicode_520_ci => "utf8mb4_unicode_520_ci",
Collation::utf8mb4_unicode_ci => "utf8mb4_unicode_ci",
Collation::utf8mb4_vietnamese_ci => "utf8mb4_vietnamese_ci",
}
}
}
// Handshake packet have only 1 byte for collation_id.
// So we can't use collations with ID > 255.
impl FromStr for Collation {
type Err = Error;
fn from_str(collation: &str) -> Result<Self, Self::Err> {
Ok(match collation {
"big5_chinese_ci" => Collation::big5_chinese_ci,
"swe7_swedish_ci" => Collation::swe7_swedish_ci,
"utf16_unicode_ci" => Collation::utf16_unicode_ci,
"utf16_icelandic_ci" => Collation::utf16_icelandic_ci,
"utf16_latvian_ci" => Collation::utf16_latvian_ci,
"utf16_romanian_ci" => Collation::utf16_romanian_ci,
"utf16_slovenian_ci" => Collation::utf16_slovenian_ci,
"utf16_polish_ci" => Collation::utf16_polish_ci,
"utf16_estonian_ci" => Collation::utf16_estonian_ci,
"utf16_spanish_ci" => Collation::utf16_spanish_ci,
"utf16_swedish_ci" => Collation::utf16_swedish_ci,
"ascii_general_ci" => Collation::ascii_general_ci,
"utf16_turkish_ci" => Collation::utf16_turkish_ci,
"utf16_czech_ci" => Collation::utf16_czech_ci,
"utf16_danish_ci" => Collation::utf16_danish_ci,
"utf16_lithuanian_ci" => Collation::utf16_lithuanian_ci,
"utf16_slovak_ci" => Collation::utf16_slovak_ci,
"utf16_spanish2_ci" => Collation::utf16_spanish2_ci,
"utf16_roman_ci" => Collation::utf16_roman_ci,
"utf16_persian_ci" => Collation::utf16_persian_ci,
"utf16_esperanto_ci" => Collation::utf16_esperanto_ci,
"utf16_hungarian_ci" => Collation::utf16_hungarian_ci,
"ujis_japanese_ci" => Collation::ujis_japanese_ci,
"utf16_sinhala_ci" => Collation::utf16_sinhala_ci,
"utf16_german2_ci" => Collation::utf16_german2_ci,
"utf16_croatian_ci" => Collation::utf16_croatian_ci,
"utf16_unicode_520_ci" => Collation::utf16_unicode_520_ci,
"utf16_vietnamese_ci" => Collation::utf16_vietnamese_ci,
"ucs2_unicode_ci" => Collation::ucs2_unicode_ci,
"ucs2_icelandic_ci" => Collation::ucs2_icelandic_ci,
"sjis_japanese_ci" => Collation::sjis_japanese_ci,
"ucs2_latvian_ci" => Collation::ucs2_latvian_ci,
"ucs2_romanian_ci" => Collation::ucs2_romanian_ci,
"ucs2_slovenian_ci" => Collation::ucs2_slovenian_ci,
"ucs2_polish_ci" => Collation::ucs2_polish_ci,
"ucs2_estonian_ci" => Collation::ucs2_estonian_ci,
"ucs2_spanish_ci" => Collation::ucs2_spanish_ci,
"ucs2_swedish_ci" => Collation::ucs2_swedish_ci,
"ucs2_turkish_ci" => Collation::ucs2_turkish_ci,
"ucs2_czech_ci" => Collation::ucs2_czech_ci,
"ucs2_danish_ci" => Collation::ucs2_danish_ci,
"cp1251_bulgarian_ci" => Collation::cp1251_bulgarian_ci,
"ucs2_lithuanian_ci" => Collation::ucs2_lithuanian_ci,
"ucs2_slovak_ci" => Collation::ucs2_slovak_ci,
"ucs2_spanish2_ci" => Collation::ucs2_spanish2_ci,
"ucs2_roman_ci" => Collation::ucs2_roman_ci,
"ucs2_persian_ci" => Collation::ucs2_persian_ci,
"ucs2_esperanto_ci" => Collation::ucs2_esperanto_ci,
"ucs2_hungarian_ci" => Collation::ucs2_hungarian_ci,
"ucs2_sinhala_ci" => Collation::ucs2_sinhala_ci,
"ucs2_german2_ci" => Collation::ucs2_german2_ci,
"ucs2_croatian_ci" => Collation::ucs2_croatian_ci,
"latin1_danish_ci" => Collation::latin1_danish_ci,
"ucs2_unicode_520_ci" => Collation::ucs2_unicode_520_ci,
"ucs2_vietnamese_ci" => Collation::ucs2_vietnamese_ci,
"ucs2_general_mysql500_ci" => Collation::ucs2_general_mysql500_ci,
"hebrew_general_ci" => Collation::hebrew_general_ci,
"utf32_unicode_ci" => Collation::utf32_unicode_ci,
"utf32_icelandic_ci" => Collation::utf32_icelandic_ci,
"utf32_latvian_ci" => Collation::utf32_latvian_ci,
"utf32_romanian_ci" => Collation::utf32_romanian_ci,
"utf32_slovenian_ci" => Collation::utf32_slovenian_ci,
"utf32_polish_ci" => Collation::utf32_polish_ci,
"utf32_estonian_ci" => Collation::utf32_estonian_ci,
"utf32_spanish_ci" => Collation::utf32_spanish_ci,
"utf32_swedish_ci" => Collation::utf32_swedish_ci,
"utf32_turkish_ci" => Collation::utf32_turkish_ci,
"utf32_czech_ci" => Collation::utf32_czech_ci,
"utf32_danish_ci" => Collation::utf32_danish_ci,
"utf32_lithuanian_ci" => Collation::utf32_lithuanian_ci,
"utf32_slovak_ci" => Collation::utf32_slovak_ci,
"utf32_spanish2_ci" => Collation::utf32_spanish2_ci,
"utf32_roman_ci" => Collation::utf32_roman_ci,
"utf32_persian_ci" => Collation::utf32_persian_ci,
"utf32_esperanto_ci" => Collation::utf32_esperanto_ci,
"utf32_hungarian_ci" => Collation::utf32_hungarian_ci,
"utf32_sinhala_ci" => Collation::utf32_sinhala_ci,
"tis620_thai_ci" => Collation::tis620_thai_ci,
"utf32_german2_ci" => Collation::utf32_german2_ci,
"utf32_croatian_ci" => Collation::utf32_croatian_ci,
"utf32_unicode_520_ci" => Collation::utf32_unicode_520_ci,
"utf32_vietnamese_ci" => Collation::utf32_vietnamese_ci,
"euckr_korean_ci" => Collation::euckr_korean_ci,
"utf8_unicode_ci" => Collation::utf8_unicode_ci,
"utf8_icelandic_ci" => Collation::utf8_icelandic_ci,
"utf8_latvian_ci" => Collation::utf8_latvian_ci,
"utf8_romanian_ci" => Collation::utf8_romanian_ci,
"utf8_slovenian_ci" => Collation::utf8_slovenian_ci,
"utf8_polish_ci" => Collation::utf8_polish_ci,
"utf8_estonian_ci" => Collation::utf8_estonian_ci,
"utf8_spanish_ci" => Collation::utf8_spanish_ci,
"latin2_czech_cs" => Collation::latin2_czech_cs,
"latin7_estonian_cs" => Collation::latin7_estonian_cs,
"utf8_swedish_ci" => Collation::utf8_swedish_ci,
"utf8_turkish_ci" => Collation::utf8_turkish_ci,
"utf8_czech_ci" => Collation::utf8_czech_ci,
"utf8_danish_ci" => Collation::utf8_danish_ci,
"utf8_lithuanian_ci" => Collation::utf8_lithuanian_ci,
"utf8_slovak_ci" => Collation::utf8_slovak_ci,
"utf8_spanish2_ci" => Collation::utf8_spanish2_ci,
"utf8_roman_ci" => Collation::utf8_roman_ci,
"utf8_persian_ci" => Collation::utf8_persian_ci,
"utf8_esperanto_ci" => Collation::utf8_esperanto_ci,
"latin2_hungarian_ci" => Collation::latin2_hungarian_ci,
"utf8_hungarian_ci" => Collation::utf8_hungarian_ci,
"utf8_sinhala_ci" => Collation::utf8_sinhala_ci,
"utf8_german2_ci" => Collation::utf8_german2_ci,
"utf8_croatian_ci" => Collation::utf8_croatian_ci,
"utf8_unicode_520_ci" => Collation::utf8_unicode_520_ci,
"utf8_vietnamese_ci" => Collation::utf8_vietnamese_ci,
"koi8u_general_ci" => Collation::koi8u_general_ci,
"utf8_general_mysql500_ci" => Collation::utf8_general_mysql500_ci,
"utf8mb4_unicode_ci" => Collation::utf8mb4_unicode_ci,
"utf8mb4_icelandic_ci" => Collation::utf8mb4_icelandic_ci,
"utf8mb4_latvian_ci" => Collation::utf8mb4_latvian_ci,
"utf8mb4_romanian_ci" => Collation::utf8mb4_romanian_ci,
"utf8mb4_slovenian_ci" => Collation::utf8mb4_slovenian_ci,
"utf8mb4_polish_ci" => Collation::utf8mb4_polish_ci,
"cp1251_ukrainian_ci" => Collation::cp1251_ukrainian_ci,
"utf8mb4_estonian_ci" => Collation::utf8mb4_estonian_ci,
"utf8mb4_spanish_ci" => Collation::utf8mb4_spanish_ci,
"utf8mb4_swedish_ci" => Collation::utf8mb4_swedish_ci,
"utf8mb4_turkish_ci" => Collation::utf8mb4_turkish_ci,
"utf8mb4_czech_ci" => Collation::utf8mb4_czech_ci,
"utf8mb4_danish_ci" => Collation::utf8mb4_danish_ci,
"utf8mb4_lithuanian_ci" => Collation::utf8mb4_lithuanian_ci,
"utf8mb4_slovak_ci" => Collation::utf8mb4_slovak_ci,
"utf8mb4_spanish2_ci" => Collation::utf8mb4_spanish2_ci,
"utf8mb4_roman_ci" => Collation::utf8mb4_roman_ci,
"gb2312_chinese_ci" => Collation::gb2312_chinese_ci,
"utf8mb4_persian_ci" => Collation::utf8mb4_persian_ci,
"utf8mb4_esperanto_ci" => Collation::utf8mb4_esperanto_ci,
"utf8mb4_hungarian_ci" => Collation::utf8mb4_hungarian_ci,
"utf8mb4_sinhala_ci" => Collation::utf8mb4_sinhala_ci,
"utf8mb4_german2_ci" => Collation::utf8mb4_german2_ci,
"utf8mb4_croatian_ci" => Collation::utf8mb4_croatian_ci,
"utf8mb4_unicode_520_ci" => Collation::utf8mb4_unicode_520_ci,
"utf8mb4_vietnamese_ci" => Collation::utf8mb4_vietnamese_ci,
"gb18030_chinese_ci" => Collation::gb18030_chinese_ci,
"gb18030_bin" => Collation::gb18030_bin,
"greek_general_ci" => Collation::greek_general_ci,
"gb18030_unicode_520_ci" => Collation::gb18030_unicode_520_ci,
"utf8mb4_0900_ai_ci" => Collation::utf8mb4_0900_ai_ci,
"cp1250_general_ci" => Collation::cp1250_general_ci,
"latin2_croatian_ci" => Collation::latin2_croatian_ci,
"gbk_chinese_ci" => Collation::gbk_chinese_ci,
"cp1257_lithuanian_ci" => Collation::cp1257_lithuanian_ci,
"dec8_swedish_ci" => Collation::dec8_swedish_ci,
"latin5_turkish_ci" => Collation::latin5_turkish_ci,
"latin1_german2_ci" => Collation::latin1_german2_ci,
"armscii8_general_ci" => Collation::armscii8_general_ci,
"utf8_general_ci" => Collation::utf8_general_ci,
"cp1250_czech_cs" => Collation::cp1250_czech_cs,
"ucs2_general_ci" => Collation::ucs2_general_ci,
"cp866_general_ci" => Collation::cp866_general_ci,
"keybcs2_general_ci" => Collation::keybcs2_general_ci,
"macce_general_ci" => Collation::macce_general_ci,
"macroman_general_ci" => Collation::macroman_general_ci,
"cp850_general_ci" => Collation::cp850_general_ci,
"cp852_general_ci" => Collation::cp852_general_ci,
"latin7_general_ci" => Collation::latin7_general_ci,
"latin7_general_cs" => Collation::latin7_general_cs,
"macce_bin" => Collation::macce_bin,
"cp1250_croatian_ci" => Collation::cp1250_croatian_ci,
"utf8mb4_general_ci" => Collation::utf8mb4_general_ci,
"utf8mb4_bin" => Collation::utf8mb4_bin,
"latin1_bin" => Collation::latin1_bin,
"latin1_general_ci" => Collation::latin1_general_ci,
"latin1_general_cs" => Collation::latin1_general_cs,
"latin1_german1_ci" => Collation::latin1_german1_ci,
"cp1251_bin" => Collation::cp1251_bin,
"cp1251_general_ci" => Collation::cp1251_general_ci,
"cp1251_general_cs" => Collation::cp1251_general_cs,
"macroman_bin" => Collation::macroman_bin,
"utf16_general_ci" => Collation::utf16_general_ci,
"utf16_bin" => Collation::utf16_bin,
"utf16le_general_ci" => Collation::utf16le_general_ci,
"cp1256_general_ci" => Collation::cp1256_general_ci,
"cp1257_bin" => Collation::cp1257_bin,
"cp1257_general_ci" => Collation::cp1257_general_ci,
"hp8_english_ci" => Collation::hp8_english_ci,
"utf32_general_ci" => Collation::utf32_general_ci,
"utf32_bin" => Collation::utf32_bin,
"utf16le_bin" => Collation::utf16le_bin,
"binary" => Collation::binary,
"armscii8_bin" => Collation::armscii8_bin,
"ascii_bin" => Collation::ascii_bin,
"cp1250_bin" => Collation::cp1250_bin,
"cp1256_bin" => Collation::cp1256_bin,
"cp866_bin" => Collation::cp866_bin,
"dec8_bin" => Collation::dec8_bin,
"koi8r_general_ci" => Collation::koi8r_general_ci,
"greek_bin" => Collation::greek_bin,
"hebrew_bin" => Collation::hebrew_bin,
"hp8_bin" => Collation::hp8_bin,
"keybcs2_bin" => Collation::keybcs2_bin,
"koi8r_bin" => Collation::koi8r_bin,
"koi8u_bin" => Collation::koi8u_bin,
"utf8_tolower_ci" => Collation::utf8_tolower_ci,
"latin2_bin" => Collation::latin2_bin,
"latin5_bin" => Collation::latin5_bin,
"latin7_bin" => Collation::latin7_bin,
"latin1_swedish_ci" => Collation::latin1_swedish_ci,
"cp850_bin" => Collation::cp850_bin,
"cp852_bin" => Collation::cp852_bin,
"swe7_bin" => Collation::swe7_bin,
"utf8_bin" => Collation::utf8_bin,
"big5_bin" => Collation::big5_bin,
"euckr_bin" => Collation::euckr_bin,
"gb2312_bin" => Collation::gb2312_bin,
"gbk_bin" => Collation::gbk_bin,
"sjis_bin" => Collation::sjis_bin,
"tis620_bin" => Collation::tis620_bin,
"latin2_general_ci" => Collation::latin2_general_ci,
"ucs2_bin" => Collation::ucs2_bin,
"ujis_bin" => Collation::ujis_bin,
"geostd8_general_ci" => Collation::geostd8_general_ci,
"geostd8_bin" => Collation::geostd8_bin,
"latin1_spanish_ci" => Collation::latin1_spanish_ci,
"cp932_japanese_ci" => Collation::cp932_japanese_ci,
"cp932_bin" => Collation::cp932_bin,
"eucjpms_japanese_ci" => Collation::eucjpms_japanese_ci,
"eucjpms_bin" => Collation::eucjpms_bin,
"cp1250_polish_ci" => Collation::cp1250_polish_ci,
_ => {
return Err(Error::Configuration(
format!("unsupported MySQL collation: {}", collation).into(),
));
}
})
}
}

View file

@ -0,0 +1,40 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::BufExt;
pub trait MySqlBufExt: Buf {
// Read a length-encoded integer.
// NOTE: 0xfb or NULL is only returned for binary value encoding to indicate NULL.
// NOTE: 0xff is only returned during a result set to indicate ERR.
// <https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger>
fn get_uint_lenenc(&mut self) -> u64;
// Read a length-encoded string.
fn get_str_lenenc(&mut self) -> Result<String, Error>;
// Read a length-encoded byte sequence.
fn get_bytes_lenenc(&mut self) -> Bytes;
}
impl MySqlBufExt for Bytes {
fn get_uint_lenenc(&mut self) -> u64 {
match self.get_u8() {
0xfc => u64::from(self.get_u16_le()),
0xfd => self.get_uint_le(3),
0xfe => self.get_u64_le(),
v => u64::from(v),
}
}
fn get_str_lenenc(&mut self) -> Result<String, Error> {
let size = self.get_uint_lenenc();
self.get_str(size as usize)
}
fn get_bytes_lenenc(&mut self) -> Bytes {
let size = self.get_uint_lenenc();
self.split_to(size as usize)
}
}

View file

@ -0,0 +1,126 @@
use bytes::BufMut;
pub trait MySqlBufMutExt: BufMut {
fn put_uint_lenenc(&mut self, v: u64);
fn put_str_lenenc(&mut self, v: &str);
fn put_bytes_lenenc(&mut self, v: &[u8]);
}
impl MySqlBufMutExt for Vec<u8> {
fn put_uint_lenenc(&mut self, v: u64) {
// https://dev.mysql.com/doc/internals/en/integer.html
// https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers
if v < 251 {
self.push(v as u8);
} else if v < 0x1_00_00 {
self.push(0xfc);
self.extend(&(v as u16).to_le_bytes());
} else if v < 0x1_00_00_00 {
self.push(0xfd);
self.extend(&(v as u32).to_le_bytes()[..3]);
} else {
self.push(0xfe);
self.extend(&v.to_le_bytes());
}
}
fn put_str_lenenc(&mut self, v: &str) {
self.put_bytes_lenenc(v.as_bytes());
}
fn put_bytes_lenenc(&mut self, v: &[u8]) {
self.put_uint_lenenc(v.len() as u64);
self.extend(v);
}
}
#[test]
fn test_encodes_int_lenenc_u8() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFA_u64);
assert_eq!(&buf[..], b"\xFA");
}
#[test]
fn test_encodes_int_lenenc_u16() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(std::u16::MAX as u64);
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
}
#[test]
fn test_encodes_int_lenenc_u24() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFF_FF_FF_u64);
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
}
#[test]
fn test_encodes_int_lenenc_u64() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(std::u64::MAX);
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
#[test]
fn test_encodes_int_lenenc_fb() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFB_u64);
assert_eq!(&buf[..], b"\xFC\xFB\x00");
}
#[test]
fn test_encodes_int_lenenc_fc() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFC_u64);
assert_eq!(&buf[..], b"\xFC\xFC\x00");
}
#[test]
fn test_encodes_int_lenenc_fd() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFD_u64);
assert_eq!(&buf[..], b"\xFC\xFD\x00");
}
#[test]
fn test_encodes_int_lenenc_fe() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFE_u64);
assert_eq!(&buf[..], b"\xFC\xFE\x00");
}
#[test]
fn test_encodes_int_lenenc_ff() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFF_u64);
assert_eq!(&buf[..], b"\xFC\xFF\x00");
}
#[test]
fn test_encodes_string_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_str_lenenc("random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
#[test]
fn test_encodes_byte_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_bytes_lenenc(b"random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}

View file

@ -0,0 +1,5 @@
mod buf;
mod buf_mut;
pub use buf::MySqlBufExt;
pub use buf_mut::MySqlBufMutExt;

View file

@ -0,0 +1,5 @@
//! **MySQL** database driver.
pub mod collation;
pub mod io;
pub mod protocol;

View file

@ -0,0 +1,38 @@
use std::str::FromStr;
use crate::err_protocol;
use crate::error::Error;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AuthPlugin {
MySqlClearPassword,
MySqlNativePassword,
CachingSha2Password,
Sha256Password,
}
impl AuthPlugin {
pub(crate) fn name(self) -> &'static str {
match self {
AuthPlugin::MySqlClearPassword => "mysql_clear_password",
AuthPlugin::MySqlNativePassword => "mysql_native_password",
AuthPlugin::CachingSha2Password => "caching_sha2_password",
AuthPlugin::Sha256Password => "sha256_password",
}
}
}
impl FromStr for AuthPlugin {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"mysql_clear_password" => Ok(AuthPlugin::MySqlClearPassword),
"mysql_native_password" => Ok(AuthPlugin::MySqlNativePassword),
"caching_sha2_password" => Ok(AuthPlugin::CachingSha2Password),
"sha256_password" => Ok(AuthPlugin::Sha256Password),
_ => Err(err_protocol!("unknown authentication plugin: {}", s)),
}
}
}

View file

@ -0,0 +1,86 @@
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__capabilities__flags.html
// https://mariadb.com/kb/en/library/connection/#capabilities
bitflags::bitflags! {
pub struct Capabilities: u64 {
// [MariaDB] MySQL compatibility
const MYSQL = 1;
// [*] Send found rows instead of affected rows in EOF_Packet.
const FOUND_ROWS = 2;
// Get all column flags.
const LONG_FLAG = 4;
// [*] Database (schema) name can be specified on connect in Handshake Response Packet.
const CONNECT_WITH_DB = 8;
// Don't allow database.table.column
const NO_SCHEMA = 16;
// [*] Compression protocol supported
const COMPRESS = 32;
// Special handling of ODBC behavior.
const ODBC = 64;
// Can use LOAD DATA LOCAL
const LOCAL_FILES = 128;
// [*] Ignore spaces before '('
const IGNORE_SPACE = 256;
// [*] New 4.1+ protocol
const PROTOCOL_41 = 512;
// This is an interactive client
const INTERACTIVE = 1024;
// Use SSL encryption for this session
const SSL = 2048;
// Client knows about transactions
const TRANSACTIONS = 8192;
// 4.1+ authentication
const SECURE_CONNECTION = (1 << 15);
// Enable/disable multi-statement support for COM_QUERY *and* COM_STMT_PREPARE
const MULTI_STATEMENTS = (1 << 16);
// Enable/disable multi-results for COM_QUERY
const MULTI_RESULTS = (1 << 17);
// Enable/disable multi-results for COM_STMT_PREPARE
const PS_MULTI_RESULTS = (1 << 18);
// Client supports plugin authentication
const PLUGIN_AUTH = (1 << 19);
// Client supports connection attributes
const CONNECT_ATTRS = (1 << 20);
// Enable authentication response packet to be larger than 255 bytes.
const PLUGIN_AUTH_LENENC_DATA = (1 << 21);
// Don't close the connection for a user account with expired password.
const CAN_HANDLE_EXPIRED_PASSWORDS = (1 << 22);
// Capable of handling server state change information.
const SESSION_TRACK = (1 << 23);
// Client no longer needs EOF_Packet and will use OK_Packet instead.
const DEPRECATE_EOF = (1 << 24);
// Support ZSTD protocol compression
const ZSTD_COMPRESSION_ALGORITHM = (1 << 26);
// Verify server certificate
const SSL_VERIFY_SERVER_CERT = (1 << 30);
// The client can handle optional metadata information in the resultset
const OPTIONAL_RESULTSET_METADATA = (1 << 25);
// Don't reset the options after an unsuccessful connect
const REMEMBER_OPTIONS = (1 << 31);
}
}

View file

@ -0,0 +1,58 @@
use bytes::{Buf, BufMut, Bytes};
use crate::err_protocol;
use crate::error::Error;
use crate::io::{BufExt, BufMutExt, Decode, Encode};
use crate::mysql::protocol::auth::AuthPlugin;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html
#[derive(Debug)]
pub struct AuthSwitchRequest {
pub plugin: AuthPlugin,
pub data: Bytes,
}
impl Decode<'_> for AuthSwitchRequest {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfe {
return Err(err_protocol!(
"expected 0xfe (AUTH_SWITCH) but found 0x{:x}",
header
));
}
let plugin = buf.get_str_nul()?.parse()?;
// See: https://github.com/mysql/mysql-server/blob/ea7d2e2d16ac03afdd9cb72a972a95981107bf51/sql/auth/sha2_password.cc#L942
if buf.len() != 21 {
return Err(err_protocol!(
"expected 21 bytes but found {} bytes",
buf.len()
));
}
let data = buf.get_bytes(20);
buf.advance(1); // NUL-terminator
Ok(Self { plugin, data })
}
}
impl Encode<'_, ()> for AuthSwitchRequest {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.put_u8(0xfe);
buf.put_str_nul(self.plugin.name());
buf.extend(&self.data);
}
}
#[derive(Debug)]
pub struct AuthSwitchResponse(pub Vec<u8>);
impl Encode<'_, Capabilities> for AuthSwitchResponse {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.extend_from_slice(&self.0);
}
}

View file

@ -0,0 +1,233 @@
use bytes::buf::Chain;
use bytes::{Buf, BufMut, Bytes};
use crate::error::Error;
use crate::io::{BufExt, BufMutExt, Decode, Encode};
use crate::mysql::protocol::auth::AuthPlugin;
use crate::mysql::protocol::response::Status;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
#[derive(Debug)]
pub struct Handshake {
#[allow(unused)]
pub protocol_version: u8,
pub server_version: String,
#[allow(unused)]
pub connection_id: u32,
pub server_capabilities: Capabilities,
#[allow(unused)]
pub server_default_collation: u8,
#[allow(unused)]
pub status: Status,
pub auth_plugin: Option<AuthPlugin>,
pub auth_plugin_data: Chain<Bytes, Bytes>,
}
impl Decode<'_> for Handshake {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let protocol_version = buf.get_u8(); // int<1>
let server_version = buf.get_str_nul()?; // string<NUL>
let connection_id = buf.get_u32_le(); // int<4>
let auth_plugin_data_1 = buf.get_bytes(8); // string<8>
buf.advance(1); // reserved: string<1>
let capabilities_1 = buf.get_u16_le(); // int<2>
let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into());
let collation = buf.get_u8(); // int<1>
let status = Status::from_bits_truncate(buf.get_u16_le());
let capabilities_2 = buf.get_u16_le(); // int<2>
capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into());
let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
buf.get_u8()
} else {
buf.advance(1); // int<1>
0
};
buf.advance(6); // reserved: string<6>
if capabilities.contains(Capabilities::MYSQL) {
buf.advance(4); // reserved: string<4>
} else {
let capabilities_3 = buf.get_u32_le(); // int<4>
capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32);
}
let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) {
let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize;
let v = buf.get_bytes(len);
buf.advance(1); // NUL-terminator
v
} else {
Bytes::new()
};
let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
Some(buf.get_str_nul()?.parse()?)
} else {
None
};
Ok(Self {
protocol_version,
server_version,
connection_id,
server_default_collation: collation,
status,
server_capabilities: capabilities,
auth_plugin,
auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2),
})
}
}
impl Encode<'_, ()> for Handshake {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.put_u8(self.protocol_version);
buf.put_str_nul(&self.server_version);
buf.put_u32_le(self.connection_id);
buf.put_slice(self.auth_plugin_data.first_ref());
buf.put_u8(0x00);
buf.put_u16_le((self.server_capabilities.bits() & 0x0000_FFFF) as u16);
buf.put_u8(self.server_default_collation);
buf.put_u16_le(self.status.bits());
buf.put_u16_le(((self.server_capabilities.bits() & 0xFFFF_0000) >> 16) as u16);
if self.server_capabilities.contains(Capabilities::PLUGIN_AUTH) {
buf.put_u8((self.auth_plugin_data.last_ref().len() + 8 + 1) as u8);
} else {
buf.put_u8(0);
}
buf.put_slice(&[0_u8; 10][..]);
if self
.server_capabilities
.contains(Capabilities::SECURE_CONNECTION)
{
buf.put_slice(self.auth_plugin_data.last_ref());
buf.put_u8(0);
}
if self.server_capabilities.contains(Capabilities::PLUGIN_AUTH) {
if let Some(auth_plugin) = self.auth_plugin {
buf.put_str_nul(auth_plugin.name());
}
}
}
}
#[test]
fn test_decode_handshake_mysql_8_0_18() {
const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";
let mut p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18.into()).unwrap();
assert_eq!(p.protocol_version, 10);
p.server_capabilities.toggle(
Capabilities::MYSQL
| Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::SSL
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
| Capabilities::ZSTD_COMPRESSION_ALGORITHM
| Capabilities::SSL_VERIFY_SERVER_CERT
| Capabilities::OPTIONAL_RESULTSET_METADATA
| Capabilities::REMEMBER_OPTIONS,
);
assert!(p.server_capabilities.is_empty());
assert_eq!(p.server_default_collation, 255);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(matches!(
p.auth_plugin,
Some(AuthPlugin::CachingSha2Password)
));
assert_eq!(
&*p.auth_plugin_data.into_iter().collect::<Vec<_>>(),
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,]
);
}
#[test]
fn test_decode_handshake_mariadb_10_4_7() {
const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";
let mut p = Handshake::decode(HANDSHAKE_MARIA_DB_10_4_7.into()).unwrap();
assert_eq!(p.protocol_version, 10);
assert_eq!(
&*p.server_version,
"5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic"
);
p.server_capabilities.toggle(
Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
| Capabilities::REMEMBER_OPTIONS,
);
assert!(p.server_capabilities.is_empty());
assert_eq!(p.server_default_collation, 8);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(matches!(
p.auth_plugin,
Some(AuthPlugin::MySqlNativePassword)
));
assert_eq!(
&*p.auth_plugin_data.into_iter().collect::<Vec<_>>(),
&[116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110,]
);
}

View file

@ -0,0 +1,147 @@
use std::str::FromStr;
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{BufExt, BufMutExt, Decode, Encode};
use crate::mysql::io::{MySqlBufExt, MySqlBufMutExt};
use crate::mysql::protocol::auth::AuthPlugin;
use crate::mysql::protocol::connect::ssl_request::SslRequest;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
// https://mariadb.com/kb/en/connection/#client-handshake-response
#[derive(Debug)]
pub struct HandshakeResponse {
pub database: Option<String>,
/// Max size of a command packet that the client wants to send to the server
pub max_packet_size: u32,
/// Default collation for the connection
pub collation: u8,
/// Name of the SQL account which client wants to log in
pub username: String,
/// Authentication method used by the client
pub auth_plugin: Option<AuthPlugin>,
/// Opaque authentication response
pub auth_response: Option<Bytes>,
}
impl Encode<'_, Capabilities> for HandshakeResponse {
fn encode_with(&self, buf: &mut Vec<u8>, mut capabilities: Capabilities) {
if self.auth_plugin.is_none() {
// ensure PLUGIN_AUTH is set *only* if we have a defined plugin
capabilities.remove(Capabilities::PLUGIN_AUTH);
}
// NOTE: Half of this packet is identical to the SSL Request packet
SslRequest {
max_packet_size: self.max_packet_size,
collation: self.collation,
}
.encode_with(buf, capabilities);
buf.put_str_nul(&self.username);
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
if let Some(response) = &self.auth_response {
buf.put_bytes_lenenc(response);
} else {
buf.put_bytes_lenenc(&[]);
}
} else if capabilities.contains(Capabilities::SECURE_CONNECTION) {
if let Some(response) = &self.auth_response {
buf.push(response.len() as u8);
buf.extend(response);
} else {
buf.push(0);
}
} else {
buf.push(0);
}
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
if let Some(database) = &self.database {
buf.put_str_nul(database);
} else {
buf.push(0);
}
}
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
if let Some(plugin) = &self.auth_plugin {
buf.put_str_nul(plugin.name());
} else {
buf.push(0);
}
}
}
}
impl Decode<'_, &mut Capabilities> for HandshakeResponse {
fn decode_with(mut buf: Bytes, server_capabilities: &mut Capabilities) -> Result<Self, Error> {
let mut capabilities = buf.get_u32_le() as u64;
let max_packet_size = buf.get_u32_le();
let collation = buf.get_u8();
buf.advance(19);
let partial_cap = Capabilities::from_bits_truncate(capabilities);
if partial_cap.contains(Capabilities::MYSQL) {
// reserved: string<4>
buf.advance(4);
} else {
capabilities += (buf.get_u32_le() as u64) << 32;
}
let partial_cap = Capabilities::from_bits_truncate(capabilities);
if partial_cap.contains(Capabilities::SSL) && buf.is_empty() {
return Ok(HandshakeResponse {
collation,
max_packet_size,
username: "".to_string(),
auth_response: None,
auth_plugin: None,
database: None,
});
}
let username = buf.get_str_nul()?;
let auth_response = if partial_cap.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
Some(buf.get_bytes_lenenc())
} else if partial_cap.contains(Capabilities::SECURE_CONNECTION) {
let len = buf.get_u8();
Some(buf.get_bytes(len as usize))
} else {
Some(buf.get_bytes_nul()?)
};
let database = if partial_cap.contains(Capabilities::CONNECT_WITH_DB) {
Some(buf.get_str_nul()?)
} else {
None
};
let auth_plugin: Option<AuthPlugin> = if partial_cap.contains(Capabilities::PLUGIN_AUTH) {
Some(AuthPlugin::from_str(&buf.get_str_nul()?)?)
} else {
None
};
*server_capabilities &= Capabilities::from_bits_truncate(capabilities);
Ok(HandshakeResponse {
collation,
max_packet_size,
username,
auth_response,
auth_plugin,
database,
})
}
}

View file

@ -0,0 +1,13 @@
//! Connection Phase
//!
//! <https://dev.mysql.com/doc/internals/en/connection-phase.html>
mod auth_switch;
mod handshake;
mod handshake_response;
mod ssl_request;
pub use auth_switch::{AuthSwitchRequest, AuthSwitchResponse};
pub use handshake::Handshake;
pub use handshake_response::HandshakeResponse;
pub use ssl_request::SslRequest;

View file

@ -0,0 +1,30 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
#[derive(Debug)]
pub struct SslRequest {
pub max_packet_size: u32,
pub collation: u8,
}
impl Encode<'_, Capabilities> for SslRequest {
fn encode_with(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
buf.extend(&(capabilities.bits() as u32).to_le_bytes());
buf.extend(&self.max_packet_size.to_le_bytes());
buf.push(self.collation);
// reserved: string<19>
buf.extend(&[0_u8; 19]);
if capabilities.contains(Capabilities::MYSQL) {
// reserved: string<4>
buf.extend(&[0_u8; 4]);
} else {
// extended client capabilities (MariaDB-specified): int<4>
buf.extend(&((capabilities.bits() >> 32) as u32).to_le_bytes());
}
}
}

View file

@ -0,0 +1,11 @@
pub mod auth;
pub mod capabilities;
pub mod connect;
pub mod packet;
pub mod response;
pub mod row;
pub mod text;
pub use capabilities::Capabilities;
pub use packet::Packet;
pub use row::Row;

View file

@ -0,0 +1,89 @@
use std::ops::{Deref, DerefMut};
use bytes::Bytes;
use crate::error::Error;
use crate::io::{Decode, Encode};
use crate::mysql::protocol::response::{EofPacket, OkPacket};
use crate::mysql::protocol::Capabilities;
#[derive(Debug)]
pub struct Packet<T>(pub T);
impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet<T>
where
T: Encode<'en, Capabilities>,
{
fn encode_with(
&self,
buf: &mut Vec<u8>,
(capabilities, sequence_id): (Capabilities, &'stream mut u8),
) {
// reserve space to write the prefixed length
let offset = buf.len();
buf.extend(&[0_u8; 4]);
// encode the payload
self.0.encode_with(buf, capabilities);
// determine the length of the encoded payload
// and write to our reserved space
let len = buf.len() - offset - 4;
let header = &mut buf[offset..];
// FIXME: Support larger packets
assert!(len < 0xFF_FF_FF);
header[..4].copy_from_slice(&(len as u32).to_le_bytes());
header[3] = *sequence_id;
*sequence_id = sequence_id.wrapping_add(1);
}
}
impl Packet<Bytes> {
pub(crate) fn decode<'de, T>(self) -> Result<T, Error>
where
T: Decode<'de, ()>,
{
self.decode_with(())
}
pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result<T, Error>
where
T: Decode<'de, C>,
{
T::decode_with(self.0, context)
}
pub(crate) fn ok(self) -> Result<OkPacket, Error> {
self.decode()
}
pub(crate) fn eof(self, capabilities: Capabilities) -> Result<EofPacket, Error> {
if capabilities.contains(Capabilities::DEPRECATE_EOF) {
let ok = self.ok()?;
Ok(EofPacket {
warnings: ok.warnings,
status: ok.status,
})
} else {
self.decode_with(capabilities)
}
}
}
impl Deref for Packet<Bytes> {
type Target = Bytes;
fn deref(&self) -> &Bytes {
&self.0
}
}
impl DerefMut for Packet<Bytes> {
fn deref_mut(&mut self) -> &mut Bytes {
&mut self.0
}
}

View file

@ -0,0 +1,36 @@
use bytes::{Buf, Bytes};
use crate::err_protocol;
use crate::error::Error;
use crate::io::Decode;
use crate::mysql::protocol::response::Status;
use crate::mysql::protocol::Capabilities;
/// Marks the end of a result set, returning status and warnings.
///
/// # Note
///
/// The EOF packet is deprecated as of MySQL 5.7.5. SQLx only uses this packet for MySQL
/// prior MySQL versions.
#[derive(Debug)]
pub struct EofPacket {
pub warnings: u16,
pub status: Status,
}
impl Decode<'_, Capabilities> for EofPacket {
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfe {
return Err(err_protocol!(
"expected 0xfe (EOF_Packet) but found 0x{:x}",
header
));
}
let warnings = buf.get_u16_le();
let status = Status::from_bits_truncate(buf.get_u16_le());
Ok(Self { status, warnings })
}
}

View file

@ -0,0 +1,81 @@
use bytes::{Buf, BufMut, Bytes};
use crate::err_protocol;
use crate::error::Error;
use crate::io::{BufExt, Decode, Encode};
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html
// https://mariadb.com/kb/en/err_packet/
/// Indicates that an error occurred.
#[derive(Debug)]
pub struct ErrPacket {
pub error_code: u16,
pub sql_state: Option<String>,
pub error_message: String,
}
impl Decode<'_, Capabilities> for ErrPacket {
fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xff {
return Err(err_protocol!(
"expected 0xff (ERR_Packet) but found 0x{:x}",
header
));
}
let error_code = buf.get_u16_le();
let mut sql_state = None;
if capabilities.contains(Capabilities::PROTOCOL_41) {
// If the next byte is '#' then we have a SQL STATE
if buf.get(0) == Some(&0x23) {
buf.advance(1);
sql_state = Some(buf.get_str(5)?);
}
}
let error_message = buf.get_str(buf.len())?;
Ok(Self {
error_code,
sql_state,
error_message,
})
}
}
impl Encode<'_, ()> for ErrPacket {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.put_u8(0xff);
buf.put_u16_le(self.error_code);
buf.extend_from_slice(self.error_message.as_bytes())
//TODO: sql_state
}
}
#[test]
fn test_decode_err_packet_out_of_order() {
const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order";
let p =
ErrPacket::decode_with(ERR_PACKETS_OUT_OF_ORDER.into(), Capabilities::PROTOCOL_41).unwrap();
assert_eq!(&p.error_message, "Got packets out of order");
assert_eq!(p.error_code, 1156);
assert_eq!(p.sql_state, None);
}
#[test]
fn test_decode_err_packet_unknown_database() {
const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'";
let p =
ErrPacket::decode_with(ERR_HANDSHAKE_UNKNOWN_DB.into(), Capabilities::PROTOCOL_41).unwrap();
assert_eq!(p.error_code, 1049);
assert_eq!(p.sql_state.as_deref(), Some("42000"));
assert_eq!(&p.error_message, "Unknown database \'unknown\'");
}

View file

@ -0,0 +1,14 @@
//! Generic Response Packets
//!
//! <https://dev.mysql.com/doc/internals/en/generic-response-packets.html>
//! <https://mariadb.com/kb/en/4-server-response-packets/>
mod eof;
mod err;
mod ok;
mod status;
pub use eof::EofPacket;
pub use err::ErrPacket;
pub use ok::OkPacket;
pub use status::Status;

View file

@ -0,0 +1,63 @@
use bytes::{Buf, BufMut, Bytes};
use crate::err_protocol;
use crate::error::Error;
use crate::io::{Decode, Encode};
use crate::mysql::io::{MySqlBufExt, MySqlBufMutExt};
use crate::mysql::protocol::response::Status;
/// Indicates successful completion of a previous command sent by the client.
#[derive(Debug)]
pub struct OkPacket {
pub affected_rows: u64,
pub last_insert_id: u64,
pub status: Status,
pub warnings: u16,
}
impl Decode<'_> for OkPacket {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0 && header != 0xfe {
return Err(err_protocol!(
"expected 0x00 or 0xfe (OK_Packet) but found 0x{:02x}",
header
));
}
let affected_rows = buf.get_uint_lenenc();
let last_insert_id = buf.get_uint_lenenc();
let status = Status::from_bits_truncate(buf.get_u16_le());
let warnings = buf.get_u16_le();
Ok(Self {
affected_rows,
last_insert_id,
status,
warnings,
})
}
}
impl Encode<'_, ()> for OkPacket {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.put_u8(0);
buf.put_uint_lenenc(self.affected_rows);
buf.put_uint_lenenc(self.last_insert_id);
buf.put_u16_le(self.status.bits());
buf.put_u16_le(self.warnings);
}
}
#[test]
fn test_decode_ok_packet() {
const DATA: &[u8] = b"\x00\x00\x00\x02@\x00\x00";
let p = OkPacket::decode(DATA.into()).unwrap();
assert_eq!(p.affected_rows, 0);
assert_eq!(p.last_insert_id, 0);
assert_eq!(p.warnings, 0);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED));
}

View file

@ -0,0 +1,49 @@
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad
// https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status
bitflags::bitflags! {
pub struct Status: u16 {
// Is raised when a multi-statement transaction has been started, either explicitly,
// by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first
// transactional statement, when autocommit=off.
const SERVER_STATUS_IN_TRANS = 1;
// Autocommit mode is set
const SERVER_STATUS_AUTOCOMMIT = 2;
// Multi query - next query exists.
const SERVER_MORE_RESULTS_EXISTS = 8;
const SERVER_QUERY_NO_GOOD_INDEX_USED = 16;
const SERVER_QUERY_NO_INDEX_USED = 32;
// When using COM_STMT_FETCH, indicate that current cursor still has result
const SERVER_STATUS_CURSOR_EXISTS = 64;
// When using COM_STMT_FETCH, indicate that current cursor has finished to send results
const SERVER_STATUS_LAST_ROW_SENT = 128;
// Database has been dropped
const SERVER_STATUS_DB_DROPPED = (1 << 8);
// Current escape mode is "no backslash escape"
const SERVER_STATUS_NO_BACKSLASH_ESCAPES = (1 << 9);
// A DDL change did have an impact on an existing PREPARE (an automatic
// re-prepare has been executed)
const SERVER_STATUS_METADATA_CHANGED = (1 << 10);
// Last statement took more than the time value specified
// in server variable long_query_time.
const SERVER_QUERY_WAS_SLOW = (1 << 11);
// This result-set contain stored procedure output parameter.
const SERVER_PS_OUT_PARAMS = (1 << 12);
// Current transaction is a read-only transaction.
const SERVER_STATUS_IN_TRANS_READONLY = (1 << 13);
// This status flag, when on, implies that one of the state information has changed
// on the server because of the execution of the last statement.
const SERVER_SESSION_STATE_CHANGED = (1 << 14);
}
}

View file

@ -0,0 +1,17 @@
use std::ops::Range;
use bytes::Bytes;
#[derive(Debug)]
pub struct Row {
pub(crate) storage: Bytes,
pub(crate) values: Vec<Option<Range<usize>>>,
}
impl Row {
pub(crate) fn get(&self, index: usize) -> Option<&[u8]> {
self.values[index]
.as_ref()
.map(|col| &self.storage[(col.start as usize)..(col.end as usize)])
}
}

View file

@ -0,0 +1,265 @@
use std::str::from_utf8;
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use crate::err_protocol;
use crate::error::Error;
use crate::io::Decode;
use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html
bitflags! {
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct ColumnFlags: u16 {
/// Field can't be `NULL`.
const NOT_NULL = 1;
/// Field is part of a primary key.
const PRIMARY_KEY = 2;
/// Field is part of a unique key.
const UNIQUE_KEY = 4;
/// Field is part of a multi-part unique or primary key.
const MULTIPLE_KEY = 8;
/// Field is a blob.
const BLOB = 16;
/// Field is unsigned.
const UNSIGNED = 32;
/// Field is zero filled.
const ZEROFILL = 64;
/// Field is binary.
const BINARY = 128;
/// Field is an enumeration.
const ENUM = 256;
/// Field is an auto-incement field.
const AUTO_INCREMENT = 512;
/// Field is a timestamp.
const TIMESTAMP = 1024;
/// Field is a set.
const SET = 2048;
/// Field does not have a default value.
const NO_DEFAULT_VALUE = 4096;
/// Field is set to NOW on UPDATE.
const ON_UPDATE_NOW = 8192;
/// Field is a number.
const NUM = 32768;
}
}
// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ColumnType {
Decimal = 0x00,
Tiny = 0x01,
Short = 0x02,
Long = 0x03,
Float = 0x04,
Double = 0x05,
Null = 0x06,
Timestamp = 0x07,
LongLong = 0x08,
Int24 = 0x09,
Date = 0x0a,
Time = 0x0b,
Datetime = 0x0c,
Year = 0x0d,
VarChar = 0x0f,
Bit = 0x10,
Json = 0xf5,
NewDecimal = 0xf6,
Enum = 0xf7,
Set = 0xf8,
TinyBlob = 0xf9,
MediumBlob = 0xfa,
LongBlob = 0xfb,
Blob = 0xfc,
VarString = 0xfd,
String = 0xfe,
Geometry = 0xff,
}
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html
// https://mariadb.com/kb/en/resultset/#column-definition-packet
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
#[derive(Debug)]
pub struct ColumnDefinition {
#[allow(unused)]
catalog: Bytes,
#[allow(unused)]
schema: Bytes,
#[allow(unused)]
table_alias: Bytes,
#[allow(unused)]
table: Bytes,
alias: Bytes,
name: Bytes,
pub(crate) char_set: u16,
pub(crate) max_size: u32,
pub(crate) r#type: ColumnType,
pub(crate) flags: ColumnFlags,
#[allow(unused)]
decimals: u8,
}
impl ColumnDefinition {
// NOTE: strings in-protocol are transmitted according to the client character set
// as this is UTF-8, all these strings should be UTF-8
pub(crate) fn name(&self) -> Result<&str, Error> {
from_utf8(&self.name).map_err(Error::protocol)
}
pub(crate) fn alias(&self) -> Result<&str, Error> {
from_utf8(&self.alias).map_err(Error::protocol)
}
}
impl Decode<'_, Capabilities> for ColumnDefinition {
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
let catalog = buf.get_bytes_lenenc();
let schema = buf.get_bytes_lenenc();
let table_alias = buf.get_bytes_lenenc();
let table = buf.get_bytes_lenenc();
let alias = buf.get_bytes_lenenc();
let name = buf.get_bytes_lenenc();
let _next_len = buf.get_uint_lenenc(); // always 0x0c
let char_set = buf.get_u16_le();
let max_size = buf.get_u32_le();
let type_id = buf.get_u8();
let flags = buf.get_u16_le();
let decimals = buf.get_u8();
Ok(Self {
catalog,
schema,
table_alias,
table,
alias,
name,
char_set,
max_size,
r#type: ColumnType::try_from_u16(type_id)?,
flags: ColumnFlags::from_bits_truncate(flags),
decimals,
})
}
}
impl ColumnType {
pub(crate) fn name(
self,
char_set: u16,
flags: ColumnFlags,
max_size: Option<u32>,
) -> &'static str {
let is_binary = char_set == 63;
let is_unsigned = flags.contains(ColumnFlags::UNSIGNED);
let is_enum = flags.contains(ColumnFlags::ENUM);
match self {
ColumnType::Tiny if max_size == Some(1) => "BOOLEAN",
ColumnType::Tiny if is_unsigned => "TINYINT UNSIGNED",
ColumnType::Short if is_unsigned => "SMALLINT UNSIGNED",
ColumnType::Long if is_unsigned => "INT UNSIGNED",
ColumnType::Int24 if is_unsigned => "MEDIUMINT UNSIGNED",
ColumnType::LongLong if is_unsigned => "BIGINT UNSIGNED",
ColumnType::Tiny => "TINYINT",
ColumnType::Short => "SMALLINT",
ColumnType::Long => "INT",
ColumnType::Int24 => "MEDIUMINT",
ColumnType::LongLong => "BIGINT",
ColumnType::Float => "FLOAT",
ColumnType::Double => "DOUBLE",
ColumnType::Null => "NULL",
ColumnType::Timestamp => "TIMESTAMP",
ColumnType::Date => "DATE",
ColumnType::Time => "TIME",
ColumnType::Datetime => "DATETIME",
ColumnType::Year => "YEAR",
ColumnType::Bit => "BIT",
ColumnType::Enum => "ENUM",
ColumnType::Set => "SET",
ColumnType::Decimal | ColumnType::NewDecimal => "DECIMAL",
ColumnType::Geometry => "GEOMETRY",
ColumnType::Json => "JSON",
ColumnType::String if is_binary => "BINARY",
ColumnType::String if is_enum => "ENUM",
ColumnType::VarChar | ColumnType::VarString if is_binary => "VARBINARY",
ColumnType::String => "CHAR",
ColumnType::VarChar | ColumnType::VarString => "VARCHAR",
ColumnType::TinyBlob if is_binary => "TINYBLOB",
ColumnType::TinyBlob => "TINYTEXT",
ColumnType::Blob if is_binary => "BLOB",
ColumnType::Blob => "TEXT",
ColumnType::MediumBlob if is_binary => "MEDIUMBLOB",
ColumnType::MediumBlob => "MEDIUMTEXT",
ColumnType::LongBlob if is_binary => "LONGBLOB",
ColumnType::LongBlob => "LONGTEXT",
}
}
pub(crate) fn try_from_u16(id: u8) -> Result<Self, Error> {
Ok(match id {
0x00 => ColumnType::Decimal,
0x01 => ColumnType::Tiny,
0x02 => ColumnType::Short,
0x03 => ColumnType::Long,
0x04 => ColumnType::Float,
0x05 => ColumnType::Double,
0x06 => ColumnType::Null,
0x07 => ColumnType::Timestamp,
0x08 => ColumnType::LongLong,
0x09 => ColumnType::Int24,
0x0a => ColumnType::Date,
0x0b => ColumnType::Time,
0x0c => ColumnType::Datetime,
0x0d => ColumnType::Year,
// [internal] 0x0e => ColumnType::NewDate,
0x0f => ColumnType::VarChar,
0x10 => ColumnType::Bit,
// [internal] 0x11 => ColumnType::Timestamp2,
// [internal] 0x12 => ColumnType::Datetime2,
// [internal] 0x13 => ColumnType::Time2,
0xf5 => ColumnType::Json,
0xf6 => ColumnType::NewDecimal,
0xf7 => ColumnType::Enum,
0xf8 => ColumnType::Set,
0xf9 => ColumnType::TinyBlob,
0xfa => ColumnType::MediumBlob,
0xfb => ColumnType::LongBlob,
0xfc => ColumnType::Blob,
0xfd => ColumnType::VarString,
0xfe => ColumnType::String,
0xff => ColumnType::Geometry,
_ => {
return Err(err_protocol!("unknown column type 0x{:02x}", id));
}
})
}
}

View file

@ -0,0 +1,9 @@
mod column;
mod ping;
mod query;
mod quit;
pub use column::{ColumnDefinition, ColumnFlags, ColumnType};
pub use ping::Ping;
pub use query::Query;
pub use quit::Quit;

View file

@ -0,0 +1,13 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-ping.html
#[derive(Debug)]
pub struct Ping;
impl Encode<'_, Capabilities> for Ping {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x0e); // COM_PING
}
}

View file

@ -0,0 +1,32 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{BufExt, Decode, Encode};
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-query.html
#[derive(Debug)]
pub struct Query(pub String);
impl Encode<'_, ()> for Query {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.push(0x03); // COM_QUERY
buf.extend(self.0.as_bytes())
}
}
impl Encode<'_, Capabilities> for Query {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x03); // COM_QUERY
buf.extend(self.0.as_bytes())
}
}
impl Decode<'_> for Query {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
buf.advance(1);
let q = buf.get_str(buf.len())?;
Ok(Query(q))
}
}

View file

@ -0,0 +1,13 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-quit.html
#[derive(Debug)]
pub struct Quit;
impl Encode<'_, Capabilities> for Quit {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x01); // COM_QUIT
}
}

View file

@ -18,7 +18,7 @@ poem-openapi = {version = "^2.0.4", features = ["swagger-ui"]}
reqwest = {version = "0.11", features = ["rustls-tls-native-roots", "stream"]}
serde = "1.0"
serde_json = "1.0"
tokio = {version = "1.19", features = ["tracing", "signal"]}
tokio = {version = "1.20", features = ["tracing", "signal"]}
tokio-tungstenite = {version = "0.17", features = ["rustls-tls-native-roots"]}
tracing = "0.1"
warpgate-admin = {version = "*", path = "../warpgate-admin"}
@ -27,3 +27,4 @@ warpgate-db-entities = {version = "*", path = "../warpgate-db-entities"}
warpgate-web = {version = "*", path = "../warpgate-web"}
percent-encoding = "2.1"
uuid = {version = "1.0", features = ["v4"]}
regex = "1.5"

View file

@ -1,5 +1,6 @@
use crate::common::SessionExt;
use crate::session::SessionMiddleware;
use anyhow::Context;
use poem::session::Session;
use poem::web::Data;
use poem::Request;
@ -66,7 +67,7 @@ impl Api {
config_provider
.authorize(&body.username, &credentials, crate::common::PROTOCOL_NAME)
.await
.map_err(|e| e.context("Failed to authorize user"))?
.context("Failed to authorize user")?
};
match result {

View file

@ -1,6 +1,3 @@
use std::net::ToSocketAddrs;
use crate::common::SessionExt;
use poem::session::Session;
use poem::web::Data;
use poem::Request;
@ -9,11 +6,14 @@ use poem_openapi::{ApiResponse, Object, OpenApi};
use serde::Serialize;
use warpgate_common::Services;
use crate::common::SessionExt;
pub struct Api;
#[derive(Serialize, Object)]
pub struct PortsInfo {
ssh: u16,
ssh: Option<u16>,
mysql: Option<u16>,
}
#[derive(Serialize, Object)]
@ -54,15 +54,22 @@ impl Api {
external_host: external_host.map(&str::to_string),
ports: if session.is_authenticated() {
PortsInfo {
ssh: config
.store
.ssh
.listen
.to_socket_addrs()
.map_or(0, |mut x| x.next().map(|x| x.port()).unwrap_or(0)),
ssh: if config.store.ssh.enable {
Some(config.store.ssh.listen.port())
} else {
None
},
mysql: if config.store.mysql.enable {
Some(config.store.mysql.listen.port())
} else {
None
},
}
} else {
PortsInfo { ssh: 0 }
PortsInfo {
ssh: None,
mysql: None,
}
},
})))
}

View file

@ -12,6 +12,7 @@ pub struct Api;
#[derive(Debug, Serialize, Clone, Enum)]
pub enum TargetKind {
Http,
MySql,
Ssh,
WebAdmin,
}
@ -68,6 +69,7 @@ impl Api {
kind: match t.options {
TargetOptions::Ssh(_) => TargetKind::Ssh,
TargetOptions::Http(_) => TargetKind::Http,
TargetOptions::MySql(_) => TargetKind::MySql,
TargetOptions::WebAdmin(_) => TargetKind::WebAdmin,
},
})

View file

@ -1,5 +1,6 @@
#![feature(type_alias_impl_trait, let_else, try_blocks)]
mod api;
use regex::Regex;
mod common;
mod session;
mod session_handle;
@ -12,5 +13,10 @@ pub fn main() {
env!("CARGO_PKG_VERSION"),
)
.server("/@warpgate/api");
println!("{}", api_service.spec());
let spec = api_service.spec();
let re = Regex::new(r"PaginatedResponse<(?P<name>\w+)>").unwrap();
let spec = re.replace_all(&spec, "Paginated$name");
println!("{}", spec);
}

View file

@ -0,0 +1,28 @@
[package]
edition = "2021"
license = "Apache-2.0"
name = "warpgate-protocol-mysql"
version = "0.3.0"
[dependencies]
warpgate-admin = { version = "*", path = "../warpgate-admin" }
warpgate-common = { version = "*", path = "../warpgate-common" }
warpgate-db-entities = { version = "*", path = "../warpgate-db-entities" }
warpgate-database-protocols = { version = "*", path = "../warpgate-database-protocols" }
anyhow = { version = "1.0", features = ["std"] }
async-trait = "0.1"
tokio = { version = "1.20", features = ["tracing", "signal"] }
tracing = "0.1"
uuid = { version = "1.0", features = ["v4"] }
bytes = "1.1"
mysql_common = "0.29"
rand = "0.8"
sha1 = "0.10.1"
password-hash = { version = "0.2", features = ["std"] }
delegate = "0.6"
rustls = { version = "0.20", features = ["dangerous_configuration"] }
rustls-pemfile = "1.0"
tokio-rustls = "0.23"
thiserror = "1.0"
webpki = "0.22"
webpki-roots = "0.22"

View file

@ -0,0 +1,150 @@
use std::sync::Arc;
use bytes::BytesMut;
use tokio::net::TcpStream;
use tracing::*;
use warpgate_common::{TargetMySqlOptions, TlsMode};
use warpgate_database_protocols::io::Decode;
use warpgate_database_protocols::mysql::protocol::auth::AuthPlugin;
use warpgate_database_protocols::mysql::protocol::connect::{
Handshake, HandshakeResponse, SslRequest,
};
use warpgate_database_protocols::mysql::protocol::response::ErrPacket;
use warpgate_database_protocols::mysql::protocol::Capabilities;
use crate::common::compute_auth_challenge_response;
use crate::error::MySqlError;
use crate::stream::MySqlStream;
use crate::tls::configure_tls_connector;
pub struct MySqlClient {
pub stream: MySqlStream<tokio_rustls::client::TlsStream<TcpStream>>,
pub capabilities: Capabilities,
}
pub struct ConnectionOptions {
pub collation: u8,
pub database: Option<String>,
pub max_packet_size: u32,
pub capabilities: Capabilities,
}
impl MySqlClient {
pub async fn connect(
target: &TargetMySqlOptions,
mut options: ConnectionOptions,
) -> Result<Self, MySqlError> {
let mut stream =
MySqlStream::new(TcpStream::connect((target.host.clone(), target.port)).await?);
options.capabilities.remove(Capabilities::SSL);
if target.tls.mode != TlsMode::Disabled {
options.capabilities |= Capabilities::SSL;
}
let Some(payload) = stream.recv().await? else {
return Err(MySqlError::Eof)
};
let handshake = Handshake::decode(payload)?;
options.capabilities &= handshake.server_capabilities;
if target.tls.mode == TlsMode::Required && !options.capabilities.contains(Capabilities::SSL)
{
return Err(MySqlError::TlsNotSupported);
}
info!(capabilities=?options.capabilities, "Target handshake");
if options.capabilities.contains(Capabilities::SSL) && target.tls.mode != TlsMode::Disabled
{
let accept_invalid_certs = !target.tls.verify;
let accept_invalid_hostname = false; // ca + hostname verification
let client_config = Arc::new(
configure_tls_connector(accept_invalid_certs, accept_invalid_hostname, None)
.await?,
);
let req = SslRequest {
collation: options.collation,
max_packet_size: options.max_packet_size,
};
stream.push(&req, options.capabilities)?;
stream.flush().await?;
stream = stream
.upgrade((
target
.host
.as_str()
.try_into()
.map_err(|_| MySqlError::InvalidDomainName)?,
client_config,
))
.await?;
info!("Target connection upgraded to TLS");
}
let mut response = HandshakeResponse {
auth_plugin: None,
auth_response: None,
collation: options.collation,
database: options.database,
max_packet_size: options.max_packet_size,
username: target.username.clone(),
};
if handshake.auth_plugin == Some(AuthPlugin::MySqlNativePassword) {
let scramble_bytes = [
&handshake.auth_plugin_data.first_ref()[..],
&handshake.auth_plugin_data.last_ref()[..],
]
.concat();
match scramble_bytes.try_into() as Result<[u8; 20], Vec<u8>> {
Err(scramble_bytes) => {
warn!("Invalid scramble length ({})", scramble_bytes.len());
}
Ok(scramble) => {
response.auth_plugin = Some(AuthPlugin::MySqlNativePassword);
response.auth_response = Some(
BytesMut::from(
compute_auth_challenge_response(
scramble,
target.password.as_deref().unwrap_or(""),
)
.map_err(MySqlError::other)?
.as_bytes(),
)
.freeze(),
);
trace!(response=?response.auth_response, ?scramble, "auth");
}
}
}
stream.push(&response, options.capabilities)?;
stream.flush().await?;
let Some(response) = stream.recv().await? else {
return Err(MySqlError::Eof)
};
if response.get(0) == Some(&0) || response.get(0) == Some(&0xfe) {
debug!("Authorized");
} else if response.get(0) == Some(&0xff) {
let error = ErrPacket::decode_with(response, options.capabilities)?;
return Err(MySqlError::ProtocolError(format!(
"handshake failed: {:?}",
error
)));
} else {
return Err(MySqlError::ProtocolError(format!(
"unknown response type {:?}",
response.get(0)
)));
}
stream.reset_sequence_id();
Ok(Self {
stream,
capabilities: options.capabilities,
})
}
}

View file

@ -0,0 +1,25 @@
use sha1::Digest;
use warpgate_common::ProtocolName;
pub const PROTOCOL_NAME: ProtocolName = "MySQL";
pub fn compute_auth_challenge_response(
challenge: [u8; 20],
password: &str,
) -> Result<password_hash::Output, password_hash::Error> {
password_hash::Output::new(
&{
let password_sha: [u8; 20] = sha1::Sha1::digest(password).into();
let password_sha_sha: [u8; 20] = sha1::Sha1::digest(password_sha).into();
let password_seed_2sha_sha: [u8; 20] =
sha1::Sha1::digest([challenge, password_sha_sha].concat()).into();
let mut result = password_sha;
result
.iter_mut()
.zip(password_seed_2sha_sha.iter())
.for_each(|(x1, x2)| *x1 ^= *x2);
result
}[..],
)
}

View file

@ -0,0 +1,50 @@
use std::error::Error;
use warpgate_common::WarpgateError;
use warpgate_database_protocols::error::Error as SqlxError;
use crate::stream::MySqlStreamError;
use crate::tls::{MaybeTlsStreamError, RustlsSetupError};
#[derive(thiserror::Error, Debug)]
pub enum MySqlError {
#[error("protocol error: {0}")]
ProtocolError(String),
#[error("sudden disconnection")]
Eof,
#[error("server doesn't offer TLS")]
TlsNotSupported,
#[error("client doesn't support TLS")]
TlsNotSupportedByClient,
#[error("TLS setup failed: {0}")]
TlsSetup(#[from] RustlsSetupError),
#[error("TLS stream error: {0}")]
Tls(#[from] MaybeTlsStreamError),
#[error("Invalid domain name")]
InvalidDomainName,
#[error("sqlx error: {0}")]
Sqlx(#[from] SqlxError),
#[error("MySQL stream error: {0}")]
MySqlStream(#[from] MySqlStreamError),
#[error("I/O: {0}")]
Io(#[from] std::io::Error),
#[error("packet decode error: {0}")]
Decode(Box<dyn Error + Send + Sync>),
#[error(transparent)]
Warpgate(#[from] WarpgateError),
#[error(transparent)]
Other(Box<dyn Error + Send + Sync>),
}
impl MySqlError {
pub fn other<E: Error + Send + Sync + 'static>(err: E) -> Self {
Self::Other(Box::new(err))
}
pub fn decode(err: SqlxError) -> Self {
match err {
SqlxError::Decode(err) => Self::Decode(err),
_ => Self::Sqlx(err),
}
}
}

View file

@ -0,0 +1,108 @@
#![feature(type_alias_impl_trait, let_else, try_blocks)]
mod client;
mod common;
mod error;
mod session;
mod session_handle;
mod stream;
mod tls;
use std::fmt::Debug;
use std::net::SocketAddr;
use anyhow::{Context, Result};
use async_trait::async_trait;
use rustls::ServerConfig;
use tokio::net::TcpListener;
use tracing::*;
use warpgate_common::{ProtocolServer, Services, SessionStateInit, Target, TargetTestError};
use crate::session::MySqlSession;
use crate::session_handle::MySqlSessionHandle;
use crate::tls::FromCertificateAndKey;
pub struct MySQLProtocolServer {
services: Services,
}
impl MySQLProtocolServer {
pub async fn new(services: &Services) -> Result<Self> {
Ok(MySQLProtocolServer {
services: services.clone(),
})
}
}
#[async_trait]
impl ProtocolServer for MySQLProtocolServer {
async fn run(self, address: SocketAddr) -> Result<()> {
let (certificate, key) = {
let config = self.services.config.lock().await;
let certificate_path = config
.paths_relative_to
.join(&config.store.mysql.certificate);
let key_path = config.paths_relative_to.join(&config.store.mysql.key);
(
std::fs::read(&certificate_path).with_context(|| {
format!(
"reading SSL certificate from '{}'",
certificate_path.display()
)
})?,
std::fs::read(&key_path).with_context(|| {
format!("reading SSL private key from '{}'", key_path.display())
})?,
)
};
let tls_config = ServerConfig::try_from_certificate_and_key(certificate, key)?;
info!(?address, "Listening");
let listener = TcpListener::bind(address).await?;
loop {
let (stream, remote_address) = listener.accept().await?;
let tls_config = tls_config.clone();
let services = self.services.clone();
tokio::spawn(async move {
let (session_handle, mut abort_rx) = MySqlSessionHandle::new();
let server_handle = services
.state
.lock()
.await
.register_session(
&crate::common::PROTOCOL_NAME,
SessionStateInit {
remote_address: Some(remote_address),
handle: Box::new(session_handle),
},
)
.await?;
let session = MySqlSession::new(server_handle, services, stream, tls_config).await;
let span = session.make_logging_span();
tokio::select! {
result = session.run().instrument(span) => match result {
Ok(_) => info!("Session ended"),
Err(e) => error!(error=%e, "Session failed"),
},
_ = abort_rx.recv() => {
warn!("Session aborted by admin");
},
}
Ok::<(), anyhow::Error>(())
});
}
}
async fn test_target(self, _target: Target) -> Result<(), TargetTestError> {
Ok(())
}
}
impl Debug for MySQLProtocolServer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MySQLProtocolServer")
}
}

View file

@ -0,0 +1,426 @@
use std::sync::Arc;
use bytes::{Buf, Bytes, BytesMut};
use rand::Rng;
use rustls::ServerConfig;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tracing::*;
use uuid::Uuid;
use warpgate_common::auth::AuthSelector;
use warpgate_common::helpers::rng::get_crypto_rng;
use warpgate_common::{
authorize_ticket, AuthCredential, AuthResult, Secret, Services, TargetMySqlOptions,
TargetOptions, WarpgateServerHandle,
};
use warpgate_database_protocols::io::{BufExt, Decode};
use warpgate_database_protocols::mysql::protocol::auth::AuthPlugin;
use warpgate_database_protocols::mysql::protocol::connect::{
AuthSwitchRequest, Handshake, HandshakeResponse,
};
use warpgate_database_protocols::mysql::protocol::response::{ErrPacket, OkPacket, Status};
use warpgate_database_protocols::mysql::protocol::text::Query;
use warpgate_database_protocols::mysql::protocol::Capabilities;
use crate::client::{ConnectionOptions, MySqlClient};
use crate::error::MySqlError;
use crate::stream::MySqlStream;
pub struct MySqlSession {
stream: MySqlStream<tokio_rustls::server::TlsStream<TcpStream>>,
capabilities: Capabilities,
challenge: [u8; 20],
username: Option<String>,
database: Option<String>,
tls_config: Arc<ServerConfig>,
server_handle: Arc<Mutex<WarpgateServerHandle>>,
id: Uuid,
services: Services,
}
impl MySqlSession {
pub async fn new(
server_handle: Arc<Mutex<WarpgateServerHandle>>,
services: Services,
stream: TcpStream,
tls_config: ServerConfig,
) -> Self {
let id = server_handle.lock().await.id();
Self {
services,
stream: MySqlStream::new(stream),
capabilities: Capabilities::PROTOCOL_41
| Capabilities::PLUGIN_AUTH
| Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::NO_SCHEMA
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CONNECT_WITH_DB
| Capabilities::SESSION_TRACK
| Capabilities::IGNORE_SPACE
| Capabilities::INTERACTIVE
| Capabilities::TRANSACTIONS
| Capabilities::DEPRECATE_EOF
| Capabilities::SECURE_CONNECTION
| Capabilities::SSL,
challenge: get_crypto_rng().gen(),
tls_config: Arc::new(tls_config),
username: None,
database: None,
server_handle,
id,
}
}
pub fn make_logging_span(&self) -> tracing::Span {
match self.username {
Some(ref username) => info_span!("MySQL", session=%self.id, session_username=%username),
None => info_span!("MySQL", session=%self.id),
}
}
pub async fn run(mut self) -> Result<(), MySqlError> {
let mut challenge_1 = BytesMut::from(&self.challenge[..]);
let challenge_2 = challenge_1.split_off(8);
let challenge_chain = challenge_1.freeze().chain(challenge_2.freeze());
let handshake = Handshake {
protocol_version: 10,
server_version: "8.0.0-Warpgate".to_owned(),
connection_id: 1,
auth_plugin_data: challenge_chain,
server_capabilities: self.capabilities,
server_default_collation: 45,
status: Status::empty(),
auth_plugin: Some(AuthPlugin::MySqlNativePassword),
};
self.stream.push(&handshake, ())?;
self.stream.flush().await?;
let resp = loop {
let Some(payload) = self.stream.recv().await? else {
return Err(MySqlError::Eof);
};
let resp = HandshakeResponse::decode_with(payload, &mut self.capabilities)
.map_err(MySqlError::decode)?;
trace!(?resp, "Handshake response");
info!(capabilities=?self.capabilities, username=%resp.username, "User handshake");
if self.capabilities.contains(Capabilities::SSL) {
if self.stream.is_tls() {
break resp;
}
self.stream = self.stream.upgrade(self.tls_config.clone()).await?;
continue;
} else {
self.send_error(1002, "Warpgate requires TLS - please enable it in your client: add `--ssl` on the CLI or add `?sslMode=PREFERRED` to your database URI").await?;
return Err(MySqlError::TlsNotSupportedByClient);
}
};
if resp.auth_plugin == Some(AuthPlugin::MySqlClearPassword) {
if let Some(mut response) = resp.auth_response.clone() {
let password = Secret::new(response.get_str_nul()?);
return self.run_authorization(resp, password).await;
}
}
let req = AuthSwitchRequest {
plugin: AuthPlugin::MySqlClearPassword,
data: Bytes::new(),
};
self.stream.push(&req, ())?;
// self.push(&RawBytes::<
self.stream.flush().await?;
let Some(response) = &self.stream.recv().await? else {
return Err(MySqlError::Eof);
};
let password = Secret::new(response.clone().get_str_nul()?);
return self.run_authorization(resp, password).await;
}
async fn send_error(&mut self, code: u16, message: &str) -> Result<(), MySqlError> {
self.stream.push(
&ErrPacket {
error_code: code,
error_message: message.to_owned(),
sql_state: None,
},
(),
)?;
self.stream.flush().await?;
Ok(())
}
pub async fn run_authorization(
mut self,
handshake: HandshakeResponse,
password: Secret<String>,
) -> Result<(), MySqlError> {
let selector: AuthSelector = (&handshake.username).into();
async fn fail(this: &mut MySqlSession) -> Result<(), MySqlError> {
this.stream.push(
&ErrPacket {
error_code: 1,
error_message: "Warpgate access denied".to_owned(),
sql_state: None,
},
(),
)?;
this.stream.flush().await?;
Ok(())
}
let credentials = vec![AuthCredential::Password(password)];
match selector {
AuthSelector::User {
username,
target_name,
} => {
let user_auth_result: AuthResult = {
self.services
.config_provider
.lock()
.await
.authorize(&username, &credentials, crate::common::PROTOCOL_NAME)
.await
.map_err(MySqlError::other)?
};
match user_auth_result {
AuthResult::Accepted { username } => {
let target_auth_result = {
self.services
.config_provider
.lock()
.await
.authorize_target(&username, &target_name)
.await
.map_err(MySqlError::other)?
};
if !target_auth_result {
warn!(
"Target {} not authorized for user {}",
target_name, username
);
return fail(&mut self).await;
}
return self.run_authorized(handshake, username, target_name).await;
}
AuthResult::Rejected | AuthResult::OtpNeeded => {
return fail(&mut self).await;
}
}
}
AuthSelector::Ticket { secret } => {
match authorize_ticket(&self.services.db, &secret)
.await
.map_err(MySqlError::other)?
{
Some(ticket) => {
info!("Authorized for {} with a ticket", ticket.target);
self.services
.config_provider
.lock()
.await
.consume_ticket(&ticket.id)
.await
.map_err(MySqlError::other)?;
return self
.run_authorized(handshake, ticket.username, ticket.target)
.await;
}
_ => return fail(&mut self).await,
}
}
}
}
async fn run_authorized(
mut self,
handshake: HandshakeResponse,
username: String,
target_name: String,
) -> Result<(), MySqlError> {
self.stream.push(
&OkPacket {
affected_rows: 0,
last_insert_id: 0,
status: Status::empty(),
warnings: 0,
},
(),
)?;
self.stream.flush().await?;
info!(%username, "Authenticated");
let target = {
self.services
.config
.lock()
.await
.store
.targets
.iter()
.filter_map(|t| match t.options {
TargetOptions::MySql(ref options) => Some((t, options)),
_ => None,
})
.find(|(t, _)| t.name == target_name)
.map(|(t, opt)| (t.clone(), opt.clone()))
};
let Some((target, mysql_options)) = target else {
warn!("Selected target not found");
self.stream.push(
&ErrPacket {
error_code: 1,
error_message: "Warpgate access denied".to_owned(),
sql_state: None,
},
(),
)?;
self.stream.flush().await?;
return Ok(());
};
{
let handle = self.server_handle.lock().await;
handle.set_username(username).await?;
handle.set_target(&target).await?;
}
let span = self.make_logging_span();
return self
.run_authorized_inner(handshake, mysql_options)
.instrument(span)
.await;
}
async fn run_authorized_inner(
mut self,
handshake: HandshakeResponse,
options: TargetMySqlOptions,
) -> Result<(), MySqlError> {
self.database = handshake.database.clone();
self.username = Some(handshake.username);
if let Some(ref database) = handshake.database {
info!("Selected database: {database}");
}
let mut client = match MySqlClient::connect(
&options,
ConnectionOptions {
collation: handshake.collation,
database: handshake.database,
max_packet_size: handshake.max_packet_size,
capabilities: self.capabilities,
},
)
.await
{
Err(error) => {
error!(%error, "Target connection failed");
self.send_error(1045, "Access denied").await?;
Err(error)
}
x => x,
}?;
loop {
self.stream.reset_sequence_id();
client.stream.reset_sequence_id();
let Some(payload) = self.stream.recv().await? else {
break;
};
trace!(?payload, "server got packet");
let com = payload.get(0);
// COM_QUERY
if com == Some(&0x03) {
let query = Query::decode(payload)?;
info!(query=%query.0, "SQL");
client.stream.push(&query, ())?;
client.stream.flush().await?;
let mut eof_ctr = 0;
loop {
let Some(response) = client.stream.recv().await? else {
return Err(MySqlError::Eof);
};
trace!(?response, "client got packet");
self.stream.push(&&response[..], ())?;
self.stream.flush().await?;
if let Some(com) = response.get(0) {
if com == &0xfe {
if self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
break;
}
eof_ctr += 1;
if eof_ctr == 2 {
// todo check multiple results
break;
}
}
if com == &0 || com == &0xff {
break;
}
}
}
// COM_QUIT
} else if com == Some(&0x01) {
break;
// COM_INIT_DB
} else if com == Some(&0x02) {
let mut buf = payload.clone();
buf.advance(1);
let db = buf.get_str(buf.len())?;
self.database = Some(db.clone());
info!("Selected database: {db}");
client.stream.push(&&payload[..], ())?;
client.stream.flush().await?;
self.passthrough_until_result(&mut client).await?;
// COM_FIELD_LIST, COM_PING, COM_RESET_CONNECTION
} else if com == Some(&0x04) || com == Some(&0x0e) || com == Some(&0x1f) {
client.stream.push(&&payload[..], ())?;
client.stream.flush().await?;
self.passthrough_until_result(&mut client).await?;
} else if let Some(com) = com {
warn!("Unknown packet type {com}");
self.send_error(1047, "Not implemented").await?;
} else {
break;
}
}
Ok(())
}
async fn passthrough_until_result(
&mut self,
client: &mut MySqlClient,
) -> Result<(), MySqlError> {
loop {
let Some(response) = client.stream.recv().await? else{
return Err(MySqlError::Eof);
};
trace!(?response, "client got packet");
self.stream.push(&&response[..], ())?;
self.stream.flush().await?;
if let Some(com) = response.get(0) {
if com == &0 || com == &0xff || com == &0xfe {
break;
}
}
}
Ok(())
}
}

View file

@ -0,0 +1,19 @@
use tokio::sync::mpsc;
use warpgate_common::SessionHandle;
pub struct MySqlSessionHandle {
abort_tx: mpsc::UnboundedSender<()>,
}
impl MySqlSessionHandle {
pub fn new() -> (Self, mpsc::UnboundedReceiver<()>) {
let (abort_tx, abort_rx) = mpsc::unbounded_channel();
(MySqlSessionHandle { abort_tx }, abort_rx)
}
}
impl SessionHandle for MySqlSessionHandle {
fn close(&mut self) {
let _ = self.abort_tx.send(());
}
}

View file

@ -0,0 +1,100 @@
use bytes::{Bytes, BytesMut};
use mysql_common::proto::codec::error::PacketCodecError;
use mysql_common::proto::codec::PacketCodec;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::*;
use warpgate_database_protocols::io::Encode;
use crate::tls::{MaybeTlsStream, MaybeTlsStreamError, UpgradableStream};
#[derive(thiserror::Error, Debug)]
pub enum MySqlStreamError {
#[error("packet codec error: {0}")]
Codec(#[from] PacketCodecError),
#[error("I/O: {0}")]
Io(#[from] std::io::Error),
}
pub struct MySqlStream<TS>
where
TcpStream: UpgradableStream<TS>,
TS: AsyncRead + AsyncWrite + Unpin,
{
stream: MaybeTlsStream<TcpStream, TS>,
codec: PacketCodec,
inbound_buffer: BytesMut,
outbound_buffer: BytesMut,
}
impl<TS> MySqlStream<TS>
where
TcpStream: UpgradableStream<TS>,
TS: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(stream: TcpStream) -> Self {
Self {
stream: MaybeTlsStream::new(stream),
codec: PacketCodec::default(),
inbound_buffer: BytesMut::new(),
outbound_buffer: BytesMut::new(),
}
}
pub fn push<'a, C, P: Encode<'a, C>>(
&mut self,
packet: &'a P,
context: C,
) -> Result<(), MySqlStreamError> {
let mut buf = vec![];
packet.encode_with(&mut buf, context);
self.codec.encode(&mut &*buf, &mut self.outbound_buffer)?;
Ok(())
}
pub async fn flush(&mut self) -> std::io::Result<()> {
trace!(outbound_buffer=?self.outbound_buffer, "sending");
self.stream.write_all(&self.outbound_buffer[..]).await?;
self.outbound_buffer = BytesMut::new();
self.stream.flush().await?;
Ok(())
}
pub async fn recv(&mut self) -> Result<Option<Bytes>, MySqlStreamError> {
let mut payload = BytesMut::new();
loop {
{
let got_full_packet = self.codec.decode(&mut self.inbound_buffer, &mut payload)?;
if got_full_packet {
trace!(?payload, "received");
return Ok(Some(payload.freeze()));
}
}
let read_bytes = self.stream.read_buf(&mut self.inbound_buffer).await?;
if read_bytes == 0 {
return Ok(None);
}
trace!(inbound_buffer=?self.inbound_buffer, "received chunk");
}
}
pub fn reset_sequence_id(&mut self) {
self.codec.reset_seq_id();
}
pub async fn upgrade(
mut self,
config: <TcpStream as UpgradableStream<TS>>::UpgradeConfig,
) -> Result<Self, MaybeTlsStreamError> {
self.stream = self.stream.upgrade(config).await?;
Ok(self)
}
pub fn is_tls(&self) -> bool {
match self.stream {
MaybeTlsStream::Raw(_) => false,
MaybeTlsStream::Tls(_) => true,
MaybeTlsStream::Upgrading => false,
}
}
}

View file

@ -0,0 +1,155 @@
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use async_trait::async_trait;
use rustls::{ClientConfig, ServerConfig, ServerName};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::*;
#[derive(thiserror::Error, Debug)]
pub enum MaybeTlsStreamError {
#[error("stream is already upgraded")]
AlreadyUpgraded,
#[error("I/O: {0}")]
Io(#[from] std::io::Error),
}
#[async_trait]
pub trait UpgradableStream<T>
where
Self: Sized,
T: AsyncRead + AsyncWrite + Unpin,
{
type UpgradeConfig;
async fn upgrade(self, config: Self::UpgradeConfig) -> Result<T, MaybeTlsStreamError>;
}
pub enum MaybeTlsStream<S, TS>
where
S: AsyncRead + AsyncWrite + Unpin + UpgradableStream<TS>,
TS: AsyncRead + AsyncWrite + Unpin,
{
Tls(TS),
Raw(S),
Upgrading,
}
impl<S, TS> MaybeTlsStream<S, TS>
where
S: AsyncRead + AsyncWrite + Unpin + UpgradableStream<TS>,
TS: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(stream: S) -> Self {
Self::Raw(stream)
}
}
impl<S, TS> MaybeTlsStream<S, TS>
where
S: AsyncRead + AsyncWrite + Unpin + UpgradableStream<TS>,
TS: AsyncRead + AsyncWrite + Unpin,
{
pub async fn upgrade(
mut self,
tls_config: S::UpgradeConfig,
) -> Result<Self, MaybeTlsStreamError> {
if let Self::Raw(stream) = std::mem::replace(&mut self, Self::Upgrading) {
let stream = stream.upgrade(tls_config).await?;
Ok(MaybeTlsStream::Tls(stream))
} else {
Err(MaybeTlsStreamError::AlreadyUpgraded)
}
}
}
impl<S, TS> AsyncRead for MaybeTlsStream<S, TS>
where
S: AsyncRead + AsyncWrite + Unpin + UpgradableStream<TS>,
TS: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<tokio::io::Result<()>> {
match self.get_mut() {
MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
_ => unreachable!(),
}
}
}
impl<S, TS> AsyncWrite for MaybeTlsStream<S, TS>
where
S: AsyncRead + AsyncWrite + Unpin + UpgradableStream<TS>,
TS: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
match self.get_mut() {
MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_write(cx, buf),
_ => unreachable!(),
}
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() {
MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_flush(cx),
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_flush(cx),
_ => unreachable!(),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() {
MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_shutdown(cx),
MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_shutdown(cx),
_ => unreachable!(),
}
}
}
#[async_trait]
impl<S> UpgradableStream<tokio_rustls::client::TlsStream<S>> for S
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
type UpgradeConfig = (ServerName, Arc<ClientConfig>);
async fn upgrade(
mut self,
config: Self::UpgradeConfig,
) -> Result<tokio_rustls::client::TlsStream<S>, MaybeTlsStreamError> {
let (domain, tls_config) = config;
let connector = tokio_rustls::TlsConnector::from(tls_config);
Ok(connector.connect(domain, self).await?)
}
}
#[async_trait]
impl<S> UpgradableStream<tokio_rustls::server::TlsStream<S>> for S
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
type UpgradeConfig = Arc<ServerConfig>;
async fn upgrade(
mut self,
tls_config: Self::UpgradeConfig,
) -> Result<tokio_rustls::server::TlsStream<S>, MaybeTlsStreamError> {
let acceptor = tokio_rustls::TlsAcceptor::from(tls_config);
Ok(acceptor.accept(self).await?)
}
}

View file

@ -0,0 +1,5 @@
mod maybe_tls_stream;
mod rustls_helpers;
pub use maybe_tls_stream::{MaybeTlsStream, MaybeTlsStreamError, UpgradableStream};
pub use rustls_helpers::{configure_tls_connector, FromCertificateAndKey, RustlsSetupError};

View file

@ -0,0 +1,174 @@
use std::io::Cursor;
use std::sync::Arc;
use std::time::SystemTime;
use rustls::client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier};
use rustls::server::{ClientHello, NoClientAuth, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use rustls::{
Certificate, ClientConfig, Error as TlsError, OwnedTrustAnchor, PrivateKey, RootCertStore,
ServerConfig, ServerName,
};
#[derive(thiserror::Error, Debug)]
pub enum RustlsSetupError {
#[error("rustls")]
Rustls(#[from] rustls::Error),
#[error("sign")]
Sign(#[from] rustls::sign::SignError),
#[error("no private keys in key file")]
NoKeys,
#[error("I/O")]
Io(#[from] std::io::Error),
#[error("PKI")]
Pki(#[from] webpki::Error),
}
pub trait FromCertificateAndKey<E>
where
Self: Sized,
{
fn try_from_certificate_and_key(cert: Vec<u8>, key: Vec<u8>) -> Result<Self, E>;
}
impl FromCertificateAndKey<RustlsSetupError> for rustls::ServerConfig {
fn try_from_certificate_and_key(
cert: Vec<u8>,
key_bytes: Vec<u8>,
) -> Result<Self, RustlsSetupError> {
let certificates = rustls_pemfile::certs(&mut &cert[..]).map(|mut certs| {
certs
.drain(..)
.map(Certificate)
.collect::<Vec<Certificate>>()
})?;
let mut key = rustls_pemfile::pkcs8_private_keys(&mut key_bytes.as_slice())?
.drain(..)
.next()
.map(PrivateKey);
if key.is_none() {
key = rustls_pemfile::rsa_private_keys(&mut key_bytes.as_slice())?
.drain(..)
.next()
.map(PrivateKey);
}
let key = key.ok_or(RustlsSetupError::NoKeys)?;
let key = rustls::sign::any_supported_type(&key)?;
let cert_key = Arc::new(CertifiedKey {
cert: certificates,
key,
ocsp: None,
sct_list: None,
});
Ok(ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(NoClientAuth::new())
.with_cert_resolver(Arc::new(ResolveServerCert(cert_key))))
}
}
struct ResolveServerCert(Arc<CertifiedKey>);
impl ResolvesServerCert for ResolveServerCert {
fn resolve(&self, _: ClientHello) -> Option<Arc<CertifiedKey>> {
Some(self.0.clone())
}
}
pub async fn configure_tls_connector(
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
root_cert: Option<&[u8]>,
) -> Result<ClientConfig, RustlsSetupError> {
let config = ClientConfig::builder().with_safe_defaults();
let config = if accept_invalid_certs {
config
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
.with_no_client_auth()
} else {
let mut cert_store = RootCertStore::empty();
cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
if let Some(data) = root_cert {
let mut cursor = Cursor::new(data);
for cert in rustls_pemfile::certs(&mut cursor)? {
cert_store.add(&rustls::Certificate(cert))?;
}
}
if accept_invalid_hostnames {
let verifier = WebPkiVerifier::new(cert_store, None);
config
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_no_client_auth()
} else {
config
.with_root_certificates(cert_store)
.with_no_client_auth()
}
};
Ok(config)
}
struct DummyTlsVerifier;
impl ServerCertVerifier for DummyTlsVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::Certificate,
_intermediates: &[rustls::Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime,
) -> Result<ServerCertVerified, TlsError> {
Ok(ServerCertVerified::assertion())
}
}
pub struct NoHostnameTlsVerifier {
verifier: WebPkiVerifier,
}
impl ServerCertVerifier for NoHostnameTlsVerifier {
fn verify_server_cert(
&self,
end_entity: &rustls::Certificate,
intermediates: &[rustls::Certificate],
server_name: &ServerName,
scts: &mut dyn Iterator<Item = &[u8]>,
ocsp_response: &[u8],
now: SystemTime,
) -> Result<ServerCertVerified, TlsError> {
match self.verifier.verify_server_cert(
end_entity,
intermediates,
server_name,
scts,
ocsp_response,
now,
) {
Err(TlsError::InvalidCertificateData(reason))
if reason.contains("CertNotValidForName") =>
{
Ok(ServerCertVerified::assertion())
}
res => res,
}
}
}

View file

@ -17,8 +17,9 @@ russh-keys = {version = "0.22.0-beta.3", features = ["openssl"]}
sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false}
thiserror = "1.0"
time = "0.3"
tokio = {version = "1.19", features = ["tracing", "signal"]}
tokio = {version = "1.20", features = ["tracing", "signal"]}
tracing = "0.1"
uuid = {version = "1.0", features = ["v4"]}
warpgate-common = {version = "*", path = "../warpgate-common"}
warpgate-db-entities = {version = "*", path = "../warpgate-db-entities"}
zeroize="^1.5"

View file

@ -1003,7 +1003,7 @@ impl ServerSession {
}
async fn _auth_accept(&mut self, username: &str, target_name: &str) {
info!(username = username, "Authenticated");
info!(%username, "Authenticated");
let _ = self
.server_handle
@ -1031,7 +1031,7 @@ impl ServerSession {
let Some((target, ssh_options)) = target else {
self.target = TargetSelection::NotFound(target_name.to_string());
info!("Selected target not found");
warn!("Selected target not found");
return;
};

View file

@ -13,8 +13,8 @@
"postinstall": "yarn run openapi:client:gateway && yarn run openapi:client:admin",
"openapi:schema:gateway": "cargo run -p warpgate-protocol-http > src/gateway/lib/openapi-schema.json",
"openapi:schema:admin": "cargo run -p warpgate-admin > src/admin/lib/openapi-schema.json",
"openapi:client:gateway": "openapi-generator-cli generate -g typescript-fetch -i src/gateway/lib/openapi-schema.json -o src/gateway/lib/api-client -p npmName=warpgate-gateway-api-client -p useSingleRequestParameter=true && cd src/gateway/lib/api-client && npm i && yarn tsc --target esnext --module esnext && rm -rf src",
"openapi:client:admin": "openapi-generator-cli generate -g typescript-fetch -i src/admin/lib/openapi-schema.json -o src/admin/lib/api-client -p npmName=warpgate-admin-api-client -p useSingleRequestParameter=true && cd src/admin/lib/api-client && npm i && yarn tsc --target esnext --module esnext && rm -rf src",
"openapi:client:gateway": "openapi-generator-cli generate -g typescript-fetch -i src/gateway/lib/openapi-schema.json -o src/gateway/lib/api-client -p npmName=warpgate-gateway-api-client -p useSingleRequestParameter=true && cd src/gateway/lib/api-client && npm i typescript@3.5 && npm i && yarn tsc --target esnext --module esnext && rm -rf src",
"openapi:client:admin": "openapi-generator-cli generate -g typescript-fetch -i src/admin/lib/openapi-schema.json -o src/admin/lib/api-client -p npmName=warpgate-admin-api-client -p useSingleRequestParameter=true && cd src/admin/lib/api-client && npm i typescript@3.5 && npm i && yarn tsc --target esnext --module esnext && rm -rf src",
"openapi": "yarn run openapi:schema:admin && yarn run openapi:schema:gateway && yarn run openapi:client:admin && yarn run openapi:client:gateway"
},
"devDependencies": {

View file

@ -1,6 +1,7 @@
<script lang="ts">
import { api, UserSnapshot, Target, TicketAndSecret } from 'admin/lib/api'
import AsyncButton from 'common/AsyncButton.svelte'
import { TargetKind } from 'gateway/lib/api'
import { link } from 'svelte-spa-router'
import { Alert, FormGroup } from 'sveltestrap'
import { firstBy } from 'thenby'
@ -56,7 +57,7 @@ async function create () {
The secret is only shown once - you won't be able to see it again.
</Alert>
{#if selectedTarget?.options.kind === 'TargetSSHOptions'}
{#if selectedTarget?.options.kind === TargetKind.Ssh}
<h3>Connection instructions</h3>
<FormGroup floating label="SSH username">
@ -68,6 +69,18 @@ async function create () {
</FormGroup>
{/if}
{#if selectedTarget?.options.kind === TargetKind.MySql}
<h3>Connection instructions</h3>
<FormGroup floating label="MySQL username">
<input type="text" class="form-control" readonly value={'ticket-' + result.secret} />
</FormGroup>
<FormGroup floating label="Example command">
<input type="text" class="form-control" readonly value={'mysql -u ticket-' + result.secret + ' --host warpgate-host --port warpgate-port'} />
</FormGroup>
{/if}
<a
class="btn btn-secondary"
href="/tickets"

View file

@ -128,7 +128,7 @@
}
.protocol {
min-width: 3rem;
min-width: 3.5rem;
}
.meta {

View file

@ -1,5 +1,5 @@
<script lang="ts">
import { api, SessionSnapshot, Recording, TargetSSHOptions, TargetHTTPOptions } from 'admin/lib/api'
import { api, SessionSnapshot, Recording, TargetSSHOptions, TargetHTTPOptions, TargetMySqlOptions } from 'admin/lib/api'
import { timeAgo } from 'admin/lib/time'
import AsyncButton from 'common/AsyncButton.svelte'
import moment from 'moment'
@ -27,11 +27,15 @@ async function close () {
function getTargetDescription () {
if (session?.target) {
let address = '<unknown>'
if (session.target.options.kind === 'TargetSSHOptions') {
if (session.target.options.kind === 'Ssh') {
const options = session.target.options as TargetSSHOptions
address = `${options.host}:${options?.port}`
}
if (session.target.options.kind === 'TargetHTTPOptions') {
if (session.target.options.kind === 'MySql') {
const options = session.target.options as TargetMySqlOptions
address = `${options.host}:${options?.port}`
}
if (session.target.options.kind === 'Http') {
const options = session.target.options as unknown as TargetHTTPOptions
address = options.url
}

View file

@ -1,17 +1,14 @@
<script lang="ts">
import { api, Target, UserSnapshot } from 'admin/lib/api'
import { makeExampleSSHCommand, makeSSHUsername } from 'common/ssh'
import ConnectionInstructions from 'common/ConnectionInstructions.svelte'
import { TargetKind } from 'gateway/lib/api'
import { Alert, FormGroup, Modal, ModalBody, ModalHeader } from 'sveltestrap'
import { serverInfo } from 'gateway/lib/store'
import CopyButton from 'common/CopyButton.svelte'
import { makeTargetURL } from 'common/http'
let error: Error|undefined
let targets: Target[]|undefined
let selectedTarget: Target|undefined
let users: UserSnapshot[]|undefined
let selectedUser: UserSnapshot|undefined
let sshUsername = ''
async function load () {
targets = await api.getTargets()
@ -22,11 +19,6 @@ async function load () {
load().catch(e => {
error = e
})
$: sshUsername = makeSSHUsername(selectedTarget?.name, selectedUser?.username)
$: exampleCommand = makeExampleSSHCommand(selectedTarget?.name, selectedUser?.username, $serverInfo)
$: targetURL = selectedTarget ? makeTargetURL(selectedTarget.name) : ''
</script>
{#if error}
@ -42,7 +34,7 @@ $: targetURL = selectedTarget ? makeTargetURL(selectedTarget.name) : ''
{#each targets as target}
<!-- svelte-ignore a11y-missing-attribute -->
<a class="list-group-item list-group-item-action" on:click={() => {
if (target.options.kind !== 'TargetWebAdminOptions') {
if (target.options.kind !== 'WebAdmin') {
selectedTarget = target
}
}}>
@ -50,10 +42,16 @@ $: targetURL = selectedTarget ? makeTargetURL(selectedTarget.name) : ''
{target.name}
</strong>
<small class="text-muted ms-auto">
{#if target.options.kind === 'TargetSSHOptions'}
{#if target.options.kind === 'Http'}
HTTP
{/if}
{#if target.options.kind === 'MySql'}
MySQL
{/if}
{#if target.options.kind === 'Ssh'}
SSH
{/if}
{#if target.options.kind === 'TargetWebAdminOptions'}
{#if target.options.kind === 'WebAdmin'}
This web admin interface
{/if}
</small>
@ -67,17 +65,21 @@ $: targetURL = selectedTarget ? makeTargetURL(selectedTarget.name) : ''
{selectedTarget?.name}
</div>
<div class="target-type-label">
{#if selectedTarget?.options.kind === 'TargetSSHOptions'}
{#if selectedTarget?.options.kind === 'MySql'}
MySQL target
{/if}
{#if selectedTarget?.options.kind === 'Ssh'}
SSH target
{/if}
{#if selectedTarget?.options.kind === 'TargetWebAdminOptions'}
{#if selectedTarget?.options.kind === 'WebAdmin'}
This web admin interface
{/if}
</div>
</ModalHeader>
<ModalBody>
<h3>Access instructions</h3>
{#if selectedTarget?.options.kind === 'TargetSSHOptions'}
{#if selectedTarget?.options.kind === 'Ssh' || selectedTarget?.options.kind === 'MySql'}
{#if users}
<FormGroup floating label="Select a user">
<select bind:value={selectedUser} class="form-control">
@ -89,24 +91,18 @@ $: targetURL = selectedTarget ? makeTargetURL(selectedTarget.name) : ''
</select>
</FormGroup>
{/if}
<FormGroup floating label="SSH username" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={sshUsername} />
<CopyButton text={sshUsername} />
</FormGroup>
<FormGroup floating label="Example command" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={exampleCommand} />
<CopyButton text={exampleCommand} />
</FormGroup>
{/if}
{#if selectedTarget?.options.kind === 'TargetHTTPOptions'}
<FormGroup floating label="Access URL" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={targetURL} />
<CopyButton text={targetURL} />
</FormGroup>
{/if}
<ConnectionInstructions
targetName={selectedTarget?.name}
username={selectedUser?.username}
targetKind={{
Ssh: TargetKind.Ssh,
WebAdmin: TargetKind.WebAdmin,
Http: TargetKind.Http,
MySql: TargetKind.MySql,
}[selectedTarget?.options.kind ?? '']}
/>
</ModalBody>
</Modal>
{/if}

View file

@ -2,7 +2,7 @@
"openapi": "3.0.0",
"info": {
"title": "Warpgate Web Admin",
"version": "0.2.5"
"version": "0.3.0"
},
"servers": [
{
@ -59,28 +59,7 @@
"content": {
"application/json": {
"schema": {
"type": "object",
"required": [
"items",
"offset",
"total"
],
"properties": {
"items": {
"type": "array",
"items": {
"$ref": "#/components/schemas/SessionSnapshot"
}
},
"offset": {
"type": "integer",
"format": "uint64"
},
"total": {
"type": "integer",
"format": "uint64"
}
}
"$ref": "#/components/schemas/PaginatedSessionSnapshot"
}
}
}
@ -509,6 +488,30 @@
}
}
},
"PaginatedSessionSnapshot": {
"type": "object",
"required": [
"items",
"offset",
"total"
],
"properties": {
"items": {
"type": "array",
"items": {
"$ref": "#/components/schemas/SessionSnapshot"
}
},
"offset": {
"type": "integer",
"format": "uint64"
},
"total": {
"type": "integer",
"format": "uint64"
}
}
},
"Recording": {
"type": "object",
"required": [
@ -668,74 +671,139 @@
}
}
},
"TargetMySqlOptions": {
"type": "object",
"required": [
"host",
"port",
"username",
"tls",
"verify_tls"
],
"properties": {
"host": {
"type": "string"
},
"port": {
"type": "integer",
"format": "uint16"
},
"username": {
"type": "string"
},
"password": {
"type": "string"
},
"tls": {
"$ref": "#/components/schemas/Tls"
},
"verify_tls": {
"type": "boolean"
}
}
},
"TargetOptions": {
"type": "object",
"anyOf": [
"oneOf": [
{
"required": [
"kind"
],
"allOf": [
{
"$ref": "#/components/schemas/TargetSSHOptions"
},
{
"type": "object",
"title": "TargetSSHOptions",
"properties": {
"kind": {
"type": "string",
"example": "TargetSSHOptions"
}
}
}
]
"$ref": "#/components/schemas/TargetOptionsTargetSSHOptions"
},
{
"required": [
"kind"
],
"allOf": [
{
"$ref": "#/components/schemas/TargetHTTPOptions"
},
{
"type": "object",
"title": "TargetHTTPOptions",
"properties": {
"kind": {
"type": "string",
"example": "TargetHTTPOptions"
}
}
}
]
"$ref": "#/components/schemas/TargetOptionsTargetHTTPOptions"
},
{
"required": [
"kind"
],
"allOf": [
{
"$ref": "#/components/schemas/TargetWebAdminOptions"
},
{
"type": "object",
"title": "TargetWebAdminOptions",
"properties": {
"kind": {
"type": "string",
"example": "TargetWebAdminOptions"
}
}
}
]
"$ref": "#/components/schemas/TargetOptionsTargetMySqlOptions"
},
{
"$ref": "#/components/schemas/TargetOptionsTargetWebAdminOptions"
}
],
"discriminator": {
"propertyName": "kind"
"propertyName": "kind",
"mapping": {
"Ssh": "#/components/schemas/TargetOptionsTargetSSHOptions",
"Http": "#/components/schemas/TargetOptionsTargetHTTPOptions",
"MySql": "#/components/schemas/TargetOptionsTargetMySqlOptions",
"WebAdmin": "#/components/schemas/TargetOptionsTargetWebAdminOptions"
}
}
},
"TargetOptionsTargetHTTPOptions": {
"allOf": [
{
"type": "object",
"required": [
"kind"
],
"properties": {
"kind": {
"type": "string",
"example": "Http"
}
}
},
{
"$ref": "#/components/schemas/TargetHTTPOptions"
}
]
},
"TargetOptionsTargetMySqlOptions": {
"allOf": [
{
"type": "object",
"required": [
"kind"
],
"properties": {
"kind": {
"type": "string",
"example": "MySql"
}
}
},
{
"$ref": "#/components/schemas/TargetMySqlOptions"
}
]
},
"TargetOptionsTargetSSHOptions": {
"allOf": [
{
"type": "object",
"required": [
"kind"
],
"properties": {
"kind": {
"type": "string",
"example": "Ssh"
}
}
},
{
"$ref": "#/components/schemas/TargetSSHOptions"
}
]
},
"TargetOptionsTargetWebAdminOptions": {
"allOf": [
{
"type": "object",
"required": [
"kind"
],
"properties": {
"kind": {
"type": "string",
"example": "WebAdmin"
}
}
},
{
"$ref": "#/components/schemas/TargetWebAdminOptions"
}
]
},
"TargetSSHOptions": {
"type": "object",
"required": [
@ -807,6 +875,29 @@
}
}
},
"Tls": {
"type": "object",
"required": [
"mode",
"verify"
],
"properties": {
"mode": {
"$ref": "#/components/schemas/TlsMode"
},
"verify": {
"type": "boolean"
}
}
},
"TlsMode": {
"type": "string",
"enum": [
"Disabled",
"Preferred",
"Required"
]
},
"UserSnapshot": {
"type": "object",
"required": [

View file

@ -0,0 +1,60 @@
<script type="ts">
import { Alert, FormGroup } from 'sveltestrap'
import { TargetKind } from 'gateway/lib/api'
import { serverInfo } from 'gateway/lib/store'
import { makeExampleSSHCommand, makeSSHUsername } from 'common/ssh'
import { makeExampleMySQLCommand, makeExampleMySQLURI, makeMySQLUsername } from 'common/mysql'
import CopyButton from 'common/CopyButton.svelte'
import { makeTargetURL } from 'common/http'
export let targetName: string|undefined
export let targetKind: TargetKind
export let username: string|undefined
$: sshUsername = makeSSHUsername(targetName, username)
$: exampleSSHCommand = makeExampleSSHCommand(targetName, username, $serverInfo)
$: mySQLUsername = makeMySQLUsername(targetName, username)
$: exampleMySQLCommand = makeExampleMySQLCommand(targetName, username, $serverInfo)
$: exampleMySQLURI = makeExampleMySQLURI(targetName, username, $serverInfo)
$: targetURL = targetName ? makeTargetURL(targetName) : ''
</script>
{#if targetKind === TargetKind.Ssh}
<FormGroup floating label="SSH username" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={sshUsername} />
<CopyButton text={sshUsername} />
</FormGroup>
<FormGroup floating label="Example command" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={exampleSSHCommand} />
<CopyButton text={exampleSSHCommand} />
</FormGroup>
{/if}
{#if targetKind === TargetKind.Http}
<FormGroup floating label="Access URL" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={targetURL} />
<CopyButton text={targetURL} />
</FormGroup>
{/if}
{#if targetKind === TargetKind.MySql}
<FormGroup floating label="MySQL username" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={mySQLUsername} />
<CopyButton text={mySQLUsername} />
</FormGroup>
<FormGroup floating label="Example command" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={exampleMySQLCommand} />
<CopyButton text={exampleMySQLCommand} />
</FormGroup>
<FormGroup floating label="Example database URL" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={exampleMySQLURI} />
<CopyButton text={exampleMySQLURI} />
</FormGroup>
<Alert color="info">
Make sure you've set your client to require TLS and allowed cleartext password authentication.
</Alert>
{/if}

View file

@ -0,0 +1,13 @@
import type { Info } from 'gateway/lib/api'
export function makeMySQLUsername (targetName?: string, username?: string): string {
return `${username ?? 'username'}#${targetName ?? 'target'}`
}
export function makeExampleMySQLCommand (targetName?: string, username?: string, serverInfo?: Info): string {
return `mysql -u ${makeMySQLUsername(targetName, username)} --host ${serverInfo?.externalHost ?? 'warpgate-host'} --port ${serverInfo?.ports.mysql ?? 'warpgate-mysql-port'} -p --ssl`
}
export function makeExampleMySQLURI (targetName?: string, username?: string, serverInfo?: Info): string {
return `mysql://${makeMySQLUsername(targetName, username)}:<password>@${serverInfo?.externalHost ?? 'warpgate-host'}:${serverInfo?.ports.mysql ?? 'warpgate-mysql-port'}?sslMode=required`
}

View file

@ -5,5 +5,5 @@ export function makeSSHUsername (targetName?: string, username?: string): string
}
export function makeExampleSSHCommand (targetName?: string, username?: string, serverInfo?: Info): string {
return `ssh ${makeSSHUsername(targetName, username)}@${serverInfo?.externalHost ?? 'warpgate-host'} -p ${serverInfo?.ports.ssh ?? 'warpgate-port'}`
return `ssh ${makeSSHUsername(targetName, username)}@${serverInfo?.externalHost ?? 'warpgate-host'} -p ${serverInfo?.ports.ssh ?? 'warpgate-ssh-port'}`
}

View file

@ -1,21 +1,16 @@
<script lang="ts">
import { faArrowRight } from '@fortawesome/free-solid-svg-icons'
import CopyButton from 'common/CopyButton.svelte'
import { makeExampleSSHCommand, makeSSHUsername } from 'common/ssh'
import ConnectionInstructions from 'common/ConnectionInstructions.svelte'
import { api, Target, TargetKind } from 'gateway/lib/api'
import { createEventDispatcher } from 'svelte'
import Fa from 'svelte-fa'
import { FormGroup, Modal, ModalBody, ModalHeader, Spinner } from 'sveltestrap'
import { Modal, ModalBody, ModalHeader, Spinner } from 'sveltestrap'
import { serverInfo } from './lib/store'
const dispatch = createEventDispatcher()
let targets: Target[]|undefined
let selectedTarget: Target|undefined
let sshUsername: string
$: sshUsername = makeSSHUsername(selectedTarget?.name, $serverInfo?.username)
$: exampleCommand = makeExampleSSHCommand(selectedTarget?.name, $serverInfo?.username, $serverInfo)
async function init () {
targets = await api.getTargets()
@ -63,6 +58,9 @@ init()
{#if target.kind === TargetKind.Ssh}
SSH
{/if}
{#if target.kind === TargetKind.MySql}
MySQL
{/if}
</small>
{#if target.kind === TargetKind.Http || target.kind === TargetKind.WebAdmin}
<Fa icon={faArrowRight} fw />
@ -81,19 +79,12 @@ init()
</div>
</ModalHeader>
<ModalBody>
{#if selectedTarget?.kind === TargetKind.Ssh}
<h3>Connection instructions</h3>
<FormGroup floating label="SSH username" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={sshUsername} />
<CopyButton text={sshUsername} />
</FormGroup>
<FormGroup floating label="Example command" class="d-flex align-items-center">
<input type="text" class="form-control" readonly value={exampleCommand} />
<CopyButton text={exampleCommand} />
</FormGroup>
{/if}
<h3>Connection instructions</h3>
<ConnectionInstructions
targetName={selectedTarget?.name}
username={$serverInfo?.username}
targetKind={selectedTarget?.kind ?? TargetKind.Ssh}
/>
</ModalBody>
</Modal>

View file

@ -2,7 +2,7 @@
"openapi": "3.0.0",
"info": {
"title": "Warpgate HTTP proxy",
"version": "0.2.5"
"version": "0.3.0"
},
"servers": [
{
@ -153,13 +153,14 @@
},
"PortsInfo": {
"type": "object",
"required": [
"ssh"
],
"properties": {
"ssh": {
"type": "integer",
"format": "uint16"
},
"mysql": {
"type": "integer",
"format": "uint16"
}
}
},
@ -182,6 +183,7 @@
"type": "string",
"enum": [
"Http",
"MySql",
"Ssh",
"WebAdmin"
]

View file

@ -24,12 +24,13 @@ qrcode = "0.12"
rcgen = {version = "0.9", features = ["zeroize"]}
serde_yaml = "0.8.23"
time = "0.3"
tokio = {version = "1.19", features = ["tracing", "signal", "macros"]}
tokio = {version = "1.20", features = ["tracing", "signal", "macros"]}
tracing = "0.1"
tracing-subscriber = {version = "0.3", features = ["env-filter", "local-time"]}
warpgate-admin = {version = "*", path = "../warpgate-admin"}
warpgate-common = {version = "*", path = "../warpgate-common"}
warpgate-protocol-http = {version = "*", path = "../warpgate-protocol-http"}
warpgate-protocol-mysql = {version = "*", path = "../warpgate-protocol-mysql"}
warpgate-protocol-ssh = {version = "*", path = "../warpgate-protocol-ssh"}
[target.'cfg(target_os = "linux")'.dependencies]

View file

@ -1,24 +1,10 @@
use std::net::ToSocketAddrs;
use anyhow::{Context, Result};
use anyhow::Result;
use tracing::*;
use crate::config::load_config;
pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
let config = load_config(&cli.config, true)?;
config
.store
.ssh
.listen
.to_socket_addrs()
.context("Failed to parse SSH listen address")?;
config
.store
.http
.listen
.to_socket_addrs()
.context("Failed to parse admin server listen address")?;
load_config(&cli.config, true)?;
info!("No problems found");
Ok(())
}

View file

@ -1,5 +1,3 @@
use std::net::ToSocketAddrs;
use anyhow::Result;
use futures::StreamExt;
#[cfg(target_os = "linux")]
@ -9,6 +7,7 @@ use warpgate_common::db::cleanup_db;
use warpgate_common::logging::install_database_logger;
use warpgate_common::{ProtocolServer, Services};
use warpgate_protocol_http::HTTPProtocolServer;
use warpgate_protocol_mysql::MySQLProtocolServer;
use warpgate_protocol_ssh::SSHProtocolServer;
use crate::config::{load_config, watch_config};
@ -24,29 +23,27 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
let mut protocol_futures = futures::stream::FuturesUnordered::new();
protocol_futures.push(
SSHProtocolServer::new(&services).await?.run(
config
.store
.ssh
.listen
.to_socket_addrs()?
.next()
.ok_or_else(|| anyhow::anyhow!("Failed to resolve the listen address"))?,
),
);
if config.store.ssh.enable {
protocol_futures.push(
SSHProtocolServer::new(&services)
.await?
.run(*config.store.ssh.listen),
);
}
if config.store.http.enable {
protocol_futures.push(
HTTPProtocolServer::new(&services).await?.run(
config
.store
.http
.listen
.to_socket_addrs()?
.next()
.ok_or_else(|| anyhow::anyhow!("Failed to resolve the listen address"))?,
),
HTTPProtocolServer::new(&services)
.await?
.run(*config.store.http.listen),
);
}
if config.store.mysql.enable {
protocol_futures.push(
MySQLProtocolServer::new(&services)
.await?
.run(*config.store.mysql.listen),
);
}
@ -68,10 +65,10 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
if console::user_attended() {
info!("--------------------------------------------");
info!("Warpgate is now running.");
info!("Accepting SSH connections on {}", config.store.ssh.listen);
info!("Accepting SSH connections on {:?}", config.store.ssh.listen);
if config.store.http.enable {
info!(
"Accepting HTTP connections on https://{}",
"Accepting HTTP connections on https://{:?}",
config.store.http.listen
);
}
@ -100,6 +97,10 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
drop(config);
if protocol_futures.is_empty() {
anyhow::bail!("No protocols are enabled in the config file, exiting");
}
tokio::spawn(watch_config(cli.config.clone(), services.config.clone()));
loop {

View file

@ -1,5 +1,6 @@
use std::fs::{create_dir_all, File};
use std::io::Write;
use std::net::ToSocketAddrs;
use std::path::{Path, PathBuf};
use anyhow::Result;
@ -9,12 +10,33 @@ use tracing::*;
use warpgate_common::helpers::fs::{secure_directory, secure_file};
use warpgate_common::helpers::hash::hash_password;
use warpgate_common::{
HTTPConfig, Role, SSHConfig, Secret, Services, Target, TargetOptions, TargetWebAdminOptions,
User, UserAuthCredential, WarpgateConfigStore,
HTTPConfig, ListenEndpoint, MySQLConfig, Role, SSHConfig, Secret, Services, Target,
TargetOptions, TargetWebAdminOptions, User, UserAuthCredential, WarpgateConfigStore,
};
use crate::config::load_config;
fn prompt_endpoint(prompt: &str, default: ListenEndpoint) -> ListenEndpoint {
loop {
let v = dialoguer::Input::with_theme(&ColorfulTheme::default())
.default(format!("{:?}", default))
.with_prompt(prompt)
.interact_text()
.and_then(|v| v.to_socket_addrs());
match v {
Ok(mut addr) => match addr.next() {
Some(addr) => return ListenEndpoint(addr),
None => {
error!("No endpoints resolved");
}
},
Err(err) => {
error!("Failed to resolve this endpoint: {err}")
}
}
}
}
pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
let version = env!("CARGO_PKG_VERSION");
info!("Welcome to Warpgate {version}");
@ -57,10 +79,12 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
// ---
info!(
"* Paths can be either absolute or relative to {}.",
config_dir.canonicalize()?.display()
);
if !is_docker {
info!(
"* Paths can be either absolute or relative to {}.",
config_dir.canonicalize()?.display()
);
}
// ---
@ -89,17 +113,41 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
// ---
if !is_docker {
store.ssh.listen = dialoguer::Input::with_theme(&theme)
.default(SSHConfig::default().listen)
.with_prompt("Endpoint to listen for SSH connections on")
.interact_text()?;
store.http.listen = prompt_endpoint(
"Endpoint to listen for HTTP connections on",
HTTPConfig::default().listen,
);
info!("You will now choose specific protocol listeners to be enabled.");
info!("");
info!("NB: Nothing will be exposed by default -");
info!(" you'll set target hosts in the config file later.");
store.ssh.enable = dialoguer::Confirm::with_theme(&theme)
.default(true)
.with_prompt("Accept SSH connections?")
.interact()?;
if store.ssh.enable {
store.ssh.listen = prompt_endpoint(
"Endpoint to listen for SSH connections on",
SSHConfig::default().listen,
);
}
// ---
store.http.listen = dialoguer::Input::with_theme(&theme)
.default(HTTPConfig::default().listen)
.with_prompt("Endpoint to listen for HTTP connections on")
.interact_text()?;
store.mysql.enable = dialoguer::Confirm::with_theme(&theme)
.default(true)
.with_prompt("Accept MySQL connections?")
.interact()?;
if store.mysql.enable {
store.mysql.listen = prompt_endpoint(
"Endpoint to listen for MySQL connections on",
MySQLConfig::default().listen,
);
}
}
if store.http.enable {
@ -120,6 +168,9 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
.to_string_lossy()
.to_string();
store.mysql.certificate = store.http.certificate.clone();
store.mysql.key = store.http.key.clone();
// ---
store.ssh.keys = PathBuf::from(&data_path)
@ -168,7 +219,7 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
warpgate_protocol_ssh::generate_client_keys(&config)?;
{
info!("Generating HTTPS certificate");
info!("Generating a TLS certificate");
let cert = generate_simple_self_signed(vec![
"warpgate.local".to_string(),
"localhost".to_string(),