2019-12-03 05:46:54 +08:00
|
|
|
import base64
|
|
|
|
import string
|
|
|
|
import struct
|
2022-01-24 12:07:52 +08:00
|
|
|
import typing as _t
|
2019-12-03 05:46:54 +08:00
|
|
|
|
|
|
|
from .exc import BadData
|
|
|
|
|
2022-01-24 12:07:52 +08:00
|
|
|
# Annotation for arguments that may be given either as text or as raw bytes.
_t_str_bytes = _t.Union[str, bytes]


def want_bytes(
    s: _t_str_bytes, encoding: str = "utf-8", errors: str = "strict"
) -> bytes:
    """Return *s* as bytes, encoding it first when it is text.

    :param s: Text or bytes. Any value that is not a ``str`` is returned
        unchanged (the very same object, not a copy).
    :param encoding: Codec used to encode ``str`` input.
    :param errors: Error handler passed to :meth:`str.encode`.
    :return: The encoded bytes, or *s* itself if it was not text.
    """
    if not isinstance(s, str):
        return s
    return s.encode(encoding, errors)
|
|
|
|
|
|
|
|
|
2022-01-24 12:07:52 +08:00
|
|
|
def base64_encode(string: _t_str_bytes) -> bytes:
    """Encode text or bytes with the URL-safe base64 alphabet.

    ``str`` input is first converted with :func:`want_bytes` (UTF-8 by
    default). Trailing ``=`` padding is stripped, so the result consists
    only of ASCII letters, digits, ``-`` and ``_`` and can be placed in a
    URL without further escaping.

    :param string: Text or bytes to encode.
    :return: Unpadded URL-safe base64 bytes.
    """
    raw = want_bytes(string)
    encoded = base64.urlsafe_b64encode(raw)
    # The padding carries no information; base64_decode re-adds it.
    return encoded.rstrip(b"=")
|
|
|
|
|
|
|
|
|
2022-01-24 12:07:52 +08:00
|
|
|
def base64_decode(string: _t_str_bytes) -> bytes:
    """Base64 decode a URL-safe string of bytes or text. The result is
    bytes.

    Text input is encoded as ASCII with any non-ASCII characters dropped
    (``errors="ignore"``). Missing ``=`` padding, such as the padding
    removed by :func:`base64_encode`, is restored before decoding.

    :param string: URL-safe base64 data as ``str`` or ``bytes``.
    :return: The decoded bytes.
    :raises BadData: If the data cannot be base64-decoded.
    """
    data = want_bytes(string, encoding="ascii", errors="ignore")
    # Build a new object instead of using ``+=``: want_bytes returns
    # non-str input unchanged, so for a caller-supplied ``bytearray`` an
    # augmented assignment would append the padding to the caller's buffer.
    data = data + b"=" * (-len(data) % 4)

    try:
        return base64.urlsafe_b64decode(data)
    except (TypeError, ValueError) as e:
        # binascii.Error (raised on malformed input) subclasses ValueError.
        raise BadData("Invalid base64-encoded data") from e
|
2019-12-03 05:46:54 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Every byte that can appear in base64.urlsafe_* output: ASCII letters,
# digits, the URL-safe substitutes "-" and "_", and the "=" padding byte.
_base64_alphabet = (string.ascii_letters + string.digits + "-_=").encode("ascii")
|
2019-12-03 05:46:54 +08:00
|
|
|
|
|
|
|
# Fixed-width big-endian unsigned 64-bit layout shared by the helpers below.
_int64_struct = struct.Struct(">Q")
_int_to_bytes = _int64_struct.pack
_bytes_to_int = _t.cast(
    "_t.Callable[[bytes], _t.Tuple[int]]", _int64_struct.unpack
)


def int_to_bytes(num: int) -> bytes:
    """Pack *num* as big-endian bytes with leading zero bytes removed.

    *num* must fit the unsigned 64-bit ``>Q`` format; otherwise
    :class:`struct.error` is raised. ``0`` packs to ``b""``.
    """
    packed = _int_to_bytes(num)
    # Strip the zero high-order bytes so small values stay short.
    return packed.lstrip(b"\x00")


def bytes_to_int(bytestr: bytes) -> int:
    """Inverse of :func:`int_to_bytes`: read big-endian unsigned bytes.

    Input shorter than 8 bytes is left-padded with zero bytes; input
    longer than 8 bytes is not truncated and makes the unpack raise
    :class:`struct.error`.
    """
    # Restore the fixed width that the struct format requires.
    padded = bytestr.rjust(_int64_struct.size, b"\x00")
    (value,) = _bytes_to_int(padded)
    return value
|