bazarr/libs/dns/query.py

1579 lines
54 KiB
Python
Raw Normal View History

# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
2018-11-01 00:08:29 +08:00
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Talk to a DNS server."""
import base64
import contextlib
import enum
2018-11-01 00:08:29 +08:00
import errno
import os
import os.path
import selectors
2018-11-01 00:08:29 +08:00
import socket
import struct
import time
from typing import Any, Dict, Optional, Tuple, Union
2018-11-01 00:08:29 +08:00
import dns._features
2018-11-01 00:08:29 +08:00
import dns.exception
import dns.inet
import dns.message
import dns.name
import dns.quic
import dns.rcode
2018-11-01 00:08:29 +08:00
import dns.rdataclass
import dns.rdatatype
import dns.serial
import dns.transaction
import dns.tsig
import dns.xfr
2018-11-01 00:08:29 +08:00
def _remaining(expiration):
    """Return how many seconds are left before *expiration*.

    *expiration* is an absolute ``time.time()`` value, or ``None`` for "no
    deadline", in which case ``None`` is returned.

    Raises ``dns.exception.Timeout`` if the deadline has already been reached.
    """
    if expiration is not None:
        left = expiration - time.time()
        if left <= 0.0:
            raise dns.exception.Timeout
        return left
    return None
def _expiration_for_this_attempt(timeout, expiration):
    """Return the absolute deadline for a single attempt.

    The attempt gets at most *timeout* seconds from now, but never extends
    past the overall *expiration*. ``None`` is returned when there is no
    overall deadline.
    """
    if expiration is not None:
        return min(time.time() + timeout, expiration)
    return None
_have_httpx = dns._features.have("doh")
if _have_httpx:
    import httpcore._backends.sync
    import httpx

    _CoreNetworkBackend = httpcore.NetworkBackend
    _CoreSyncStream = httpcore._backends.sync.SyncStream

    class _NetworkBackend(_CoreNetworkBackend):
        """httpcore network backend that resolves and connects via dnspython.

        Host names are resolved with the supplied dnspython resolver unless
        the host is already an address or a bootstrap address was given;
        connections are made with this module's non-blocking socket helpers.
        """

        def __init__(self, resolver, local_port, bootstrap_address, family):
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family

        def connect_tcp(
            self, host, port, timeout, local_address, socket_options=None
        ):  # pylint: disable=signature-differs
            # NOTE(review): socket_options is accepted for interface
            # compatibility but is not applied to the sockets created here.
            addresses = []
            _, expiration = _compute_times(timeout)
            if dns.inet.is_address(host):
                addresses.append(host)
            elif self._bootstrap_address is not None:
                addresses.append(self._bootstrap_address)
            else:
                timeout = _remaining(expiration)
                family = self._family
                if local_address:
                    # Resolve only addresses reachable from the bound family.
                    family = dns.inet.af_for_address(local_address)
                answers = self._resolver.resolve_name(
                    host, family=family, lifetime=timeout
                )
                addresses = answers.addresses()
            for address in addresses:
                af = dns.inet.af_for_address(address)
                if local_address is not None or self._local_port != 0:
                    source = dns.inet.low_level_address_tuple(
                        (local_address, self._local_port), af
                    )
                else:
                    source = None
                sock = _make_socket(af, socket.SOCK_STREAM, source)
                # Each attempt is limited to 2 seconds, capped by the overall
                # deadline.
                attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
                try:
                    _connect(
                        sock,
                        dns.inet.low_level_address_tuple((address, port), af),
                        attempt_expiration,
                    )
                    return _CoreSyncStream(sock)
                except Exception:
                    # This address failed; release its socket before trying
                    # the next candidate instead of leaking the descriptor.
                    sock.close()
            raise httpcore.ConnectError

        def connect_unix_socket(
            self, path, timeout, socket_options=None
        ):  # pylint: disable=signature-differs
            raise NotImplementedError

    class _HTTPTransport(httpx.HTTPTransport):
        """httpx transport whose connection pool uses :class:`_NetworkBackend`."""

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            if resolver is None:
                # pylint: disable=import-outside-toplevel,redefined-outer-name
                import dns.resolver

                resolver = dns.resolver.Resolver()
            super().__init__(*args, **kwargs)
            # NOTE(review): this reaches into httpx's private ``_pool``
            # attribute to swap the network backend.
            self._pool._network_backend = _NetworkBackend(
                resolver, local_port, bootstrap_address, family
            )

else:

    class _HTTPTransport:  # type: ignore
        def connect_tcp(self, host, port, timeout, local_address):
            raise NotImplementedError


have_doh = _have_httpx
try:
    import ssl
except ImportError:  # pragma: no cover

    class ssl:  # type: ignore
        """Minimal stand-in used when the interpreter has no :mod:`ssl`.

        It defines only the names this module references, so plain UDP/TCP
        queries keep working; requesting an SSL context raises.
        """

        CERT_NONE = 0

        class WantReadException(Exception):
            pass

        class WantWriteException(Exception):
            pass

        # The socket I/O helpers name ssl.SSLWantReadError and
        # ssl.SSLWantWriteError in ``except`` clauses, and Python evaluates
        # those expressions whenever an exception reaches them (for example
        # BlockingIOError). Without these aliases that would turn into an
        # AttributeError.
        SSLWantReadError = WantReadException
        SSLWantWriteError = WantWriteException

        class SSLContext:
            pass

        class SSLSocket:
            pass

        @classmethod
        def create_default_context(cls, *args, **kwargs):
            raise Exception("no ssl support")  # pylint: disable=broad-exception-raised
# Function used to create a socket. Can be overridden if needed in special
# situations.
#
# _make_socket() calls it as socket_factory(af, type) and then uses the result
# with setblocking(), bind(), close() and, for TLS, ssl_context.wrap_socket(),
# so a replacement must accept those two positional arguments and return a
# socket-like object supporting those operations.
socket_factory = socket.socket
2018-11-01 00:08:29 +08:00
# NOTE(review): dns.exception.DNSException may use a subclass's docstring as
# its default message; confirm before rewording any docstring below.


# Raised by _matches_destination() when a UDP response arrives from an
# address/port other than the one the query was sent to, unless
# ignore_unexpected is set.
class UnexpectedSource(dns.exception.DNSException):
    """A DNS query response came from an unexpected address or port."""


# Raised by https(), udp() and tcp() when q.is_response(r) is false for the
# message that came back.
class BadResponse(dns.exception.FormError):
    """A DNS query response does not respond to the question asked."""


# Raised by https() when the optional DoH dependencies are missing
# (have_doh is false).
class NoDOH(dns.exception.DNSException):
    """DNS over HTTPS (DOH) was requested but the httpx module is not
    available."""


# DNS-over-QUIC counterpart of NoDOH; it is not raised in the portion of the
# module shown here.
class NoDOQ(dns.exception.DNSException):
    """DNS over QUIC (DOQ) was requested but the aioquic module is not
    available."""
2018-11-01 00:08:29 +08:00
# for backwards compatibility: the exception class is defined in dns.xfr, and
# this alias keeps the old ``dns.query.TransferError`` name importable.
TransferError = dns.xfr.TransferError
2018-11-01 00:08:29 +08:00
def _compute_times(timeout):
    """Return ``(now, expiration)`` for a relative *timeout* in seconds.

    *expiration* is the absolute time ``now + timeout``, or ``None`` when
    *timeout* is ``None`` (meaning "wait forever").
    """
    start = time.time()
    deadline = None if timeout is None else start + timeout
    return (start, deadline)
def _wait_for(fd, readable, writable, _, expiration):
    # Use the selected selector class to wait for any of the specified
    # events. An "expiration" absolute time is converted into a relative
    # timeout.
    #
    # The unused parameter is 'error', which is always set when
    # selecting for read or write, and we have no error-only selects.
    #
    # Returns True immediately when waiting for readability on an SSL socket
    # whose pending() count is non-zero (data is already buffered); otherwise
    # returns None once the selector reports readiness. Raises
    # dns.exception.Timeout if the deadline passes first.

    if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
        return True
    # Use the selector as a context manager so it is closed on every exit
    # path, including the Timeout raises. Selector classes that hold an OS
    # resource (e.g. one installed via _set_selector_class()) would
    # otherwise leak it on every wait.
    with _selector_class() as sel:
        events = 0
        if readable:
            events |= selectors.EVENT_READ
        if writable:
            events |= selectors.EVENT_WRITE
        if events:
            sel.register(fd, events)
        if expiration is None:
            timeout = None
        else:
            timeout = expiration - time.time()
            if timeout <= 0.0:
                raise dns.exception.Timeout
        if not sel.select(timeout):
            raise dns.exception.Timeout
def _set_selector_class(selector_class):
    """Replace the :mod:`selectors` class that ``_wait_for()`` instantiates.

    Internal API. Do not use.
    """
    global _selector_class
    _selector_class = selector_class


# Prefer poll() on platforms that support it because it has no limits on
# the maximum value of a file descriptor (plus it will be more efficient for
# high values); otherwise fall back to select().
#
# We ignore typing here as we can't say _selector_class is Any on
# python < 3.8 due to a bug.
_selector_class = getattr(selectors, "PollSelector", selectors.SelectSelector)  # type: ignore
def _wait_for_readable(s, expiration):
    """Wait until *s* is readable, or raise Timeout at *expiration*."""
    _wait_for(s, readable=True, writable=False, _=True, expiration=expiration)


def _wait_for_writable(s, expiration):
    """Wait until *s* is writable, or raise Timeout at *expiration*."""
    _wait_for(s, readable=False, writable=True, _=True, expiration=expiration)
def _addresses_equal(af, a1, a2):
    """Return ``True`` if two address tuples name the same endpoint.

    The textual host (element 0) of each tuple is converted to binary form,
    so different spellings of the same address compare equal. The remaining
    elements are compared as-is. Hosts that fail to parse compare unequal.
    """
    try:
        packed = [dns.inet.inet_pton(af, addr[0]) for addr in (a1, a2)]
    except dns.exception.SyntaxError:
        return False
    return packed[0] == packed[1] and a1[1:] == a2[1:]
def _matches_destination(af, from_address, destination, ignore_unexpected):
    """Decide whether a datagram from *from_address* answers *destination*.

    A falsy *destination* accepts anything. Otherwise the sender must equal
    the destination, or the destination must be a multicast address with
    matching trailing tuple fields (port, etc.).

    A mismatch returns ``False`` when *ignore_unexpected* is set and raises
    ``UnexpectedSource`` otherwise.
    """
    if not destination:
        return True
    if _addresses_equal(af, from_address, destination):
        return True
    if dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:]:
        return True
    if ignore_unexpected:
        return False
    raise UnexpectedSource(
        f"got a response from {from_address} instead of {destination}"
    )
def _destination_and_source(
    where, port, source, source_port, where_must_be_address=True
):
    """Compute ``(af, destination, source)`` for connect()/sendto()/bind().

    *where* is normally an address literal. When *where_must_be_address* is
    false it may be something else (e.g. a URL); then the destination is
    ``None`` and the address family comes from *source*, if any.

    Raises ``ValueError`` if *source* and *where* disagree on address family,
    or if *source_port* is set but no address family can be determined.
    """
    try:
        af = dns.inet.af_for_address(where)
    except Exception:
        if where_must_be_address:
            raise
        # URLs are ok so eat the exception
        af = None
        destination = None
    else:
        destination = where

    if source:
        source_af = dns.inet.af_for_address(source)
        if not af:
            # The destination family is unknown, so the source decides it.
            af = source_af
        elif source_af != af:
            raise ValueError(
                "different address families for source and destination"
            )

    if source_port and not source:
        # A source port without an address binds to the family's wildcard.
        try:
            source = dns.inet.any_for_af(af)
        except Exception:
            # we catch this and raise ValueError for backwards compatibility
            raise ValueError("source_port specified but address family is unknown")

    # Turn high-level (address, port) pairs into low-level address tuples.
    if destination:
        destination = dns.inet.low_level_address_tuple((destination, port), af)
    if source:
        source = dns.inet.low_level_address_tuple((source, source_port), af)
    return (af, destination, source)
def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
    """Create a non-blocking socket via ``socket_factory``.

    The socket is bound to *source* when given. With an *ssl_context*, it is
    wrapped for TLS with the handshake deferred (the caller drives it) and
    *server_hostname* passed through. The socket is closed if any step fails.
    """
    sock = socket_factory(af, type)
    try:
        sock.setblocking(False)
        if source is not None:
            sock.bind(source)
        if not ssl_context:
            return sock
        # LGTM gets a false positive here, as our default context is OK
        return ssl_context.wrap_socket(
            sock,
            do_handshake_on_connect=False,  # lgtm[py/insecure-protocol]
            server_hostname=server_hostname,
        )
    except Exception:
        sock.close()
        raise
def https(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 443,
    source: Optional[str] = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    session: Optional[Any] = None,
    path: str = "/dns-query",
    post: bool = True,
    bootstrap_address: Optional[str] = None,
    verify: Union[bool, str] = True,
    resolver: Optional["dns.resolver.Resolver"] = None,
    family: Optional[int] = socket.AF_UNSPEC,
) -> dns.message.Message:
    """Return the response obtained after sending a query via DNS-over-HTTPS.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str``, the nameserver IP address or the full URL. If an IP address is
    given, the URL will be constructed using the following schema:
    https://<IP-address>:<port>/<path>.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
    times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the query to. The default is 443.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
    address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message. The default is
    0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
    received message.

    *session*, an ``httpx.Client``. If provided, the client session to use to send the
    queries. The session's own transport is used, so *source*, *source_port*,
    *verify*, *bootstrap_address*, *resolver* and *family* have no effect then.

    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
    construct the URL to send the DNS query to.

    *post*, a ``bool``. If ``True``, the default, POST method will be used.

    *bootstrap_address*, a ``str``, the IP address to use to bypass resolution.

    *verify*, a ``bool`` or ``str``. If ``True``, then TLS certificate verification
    of the server is done using the default CA bundle; if ``False``, then no
    verification is done; if a ``str`` then it specifies the path to a certificate file
    or directory which will be used for verification.

    *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
    resolution of hostnames in URLs. If not specified, a new resolver with a default
    configuration will be used; note this is *not* the default resolver as that resolver
    might have been configured to use DoH causing a chicken-and-egg problem. This
    parameter only has an effect if the HTTP library is httpx.

    *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
    and AAAA records will be retrieved.

    Raises ``NoDOH`` if DoH support is unavailable, ``ValueError`` for a bad
    *session* or a non-2xx HTTP status, and ``BadResponse`` if the reply does
    not answer *q*.

    Returns a ``dns.message.Message``.
    """
    if not have_doh:
        raise NoDOH  # pragma: no cover
    if session and not isinstance(session, httpx.Client):
        raise ValueError("session parameter must be an httpx.Client")

    wire = q.to_wire()
    (af, _, the_source) = _destination_and_source(
        where, port, source, source_port, False
    )
    headers = {"accept": "application/dns-message"}
    if af is not None and dns.inet.is_address(where):
        if af == socket.AF_INET:
            url = "https://{}:{}{}".format(where, port, path)
        elif af == socket.AF_INET6:
            url = "https://[{}]:{}{}".format(where, port, path)
    else:
        url = where

    if session:
        # The caller's client brings its own transport. Building one here
        # would only allocate an unused connection pool (and possibly a
        # throwaway resolver) that is never closed.
        cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
    else:
        # set source port and source address
        if the_source is None:
            local_address = None
            local_port = 0
        else:
            local_address = the_source[0]
            local_port = the_source[1]
        transport = _HTTPTransport(
            local_address=local_address,
            http1=True,
            http2=True,
            verify=verify,
            local_port=local_port,
            bootstrap_address=bootstrap_address,
            resolver=resolver,
            family=family,
        )
        cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
    with cm as session:
        # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
        # GET and POST examples
        if post:
            headers.update(
                {
                    "content-type": "application/dns-message",
                    "content-length": str(len(wire)),
                }
            )
            response = session.post(url, headers=headers, content=wire, timeout=timeout)
        else:
            wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
            twire = wire.decode()  # httpx does a repr() if we give it bytes
            response = session.get(
                url, headers=headers, timeout=timeout, params={"dns": twire}
            )

    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
    # status codes
    if response.status_code < 200 or response.status_code > 299:
        raise ValueError(
            "{} responded with status code {}"
            "\nResponse body: {}".format(where, response.status_code, response.content)
        )
    r = dns.message.from_wire(
        response.content,
        keyring=q.keyring,
        request_mac=q.request_mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    r.time = response.elapsed.total_seconds()
    if not q.is_response(r):
        raise BadResponse
    return r
def _udp_recv(sock, max_size, expiration):
    """Receive one datagram (at most *max_size* bytes) from *sock*.

    Returns the ``(data, address)`` pair from ``recvfrom()``. While the socket
    would block, waits for readability; a Timeout exception will be raised if
    the operation is not completed by the expiration time.
    """
    while True:
        try:
            datagram = sock.recvfrom(max_size)
        except BlockingIOError:
            _wait_for_readable(sock, expiration)
        else:
            return datagram
def _udp_send(sock, data, destination, expiration):
    """Send *data* as one datagram and return the number of bytes sent.

    Uses ``sendto(destination)`` when *destination* is truthy, otherwise
    ``send()`` on a connected socket. While the socket would block, waits for
    writability; a Timeout exception will be raised if the operation is not
    completed by the expiration time.
    """
    while True:
        try:
            return sock.sendto(data, destination) if destination else sock.send(data)
        except BlockingIOError:  # pragma: no cover
            _wait_for_writable(sock, expiration)
def send_udp(
    sock: Any,
    what: Union[dns.message.Message, bytes],
    destination: Any,
    expiration: Optional[float] = None,
) -> Tuple[int, float]:
    """Send a DNS message to the specified UDP socket.

    *sock*, a ``socket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send. A
    message object is rendered to wire format first.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where to send the query.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
    """
    payload = what.to_wire() if isinstance(what, dns.message.Message) else what
    sent_time = time.time()
    return (_udp_send(sock, payload, destination, expiration), sent_time)
def receive_udp(
    sock: Any,
    destination: Optional[Any] = None,
    expiration: Optional[float] = None,
    ignore_unexpected: bool = False,
    one_rr_per_rrset: bool = False,
    keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
    request_mac: Optional[bytes] = b"",
    ignore_trailing: bool = False,
    raise_on_truncation: bool = False,
    ignore_errors: bool = False,
    query: Optional[dns.message.Message] = None,
) -> Any:
    """Read a DNS message from a UDP socket.

    *sock*, a ``socket``.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where the message is expected to arrive from.
    When receiving a response, this would be where the associated query was
    sent.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
    the TC bit is set.

    *ignore_errors*, a ``bool``. If various format errors or response
    mismatches occur, ignore them and keep listening for a valid response.
    The default is ``False``.

    *query*, a ``dns.message.Message`` or ``None``. If not ``None`` and
    *ignore_errors* is ``True``, check that the received message is a response
    to this query, and if not keep listening for a valid response.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.

    If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
    tuple of the received message and the received time.

    If *destination* is ``None``, returns a
    ``(dns.message.Message, float, tuple)``
    tuple of the received message, the received time, and the address where
    the message arrived from.
    """
    while True:
        (wire, from_address) = _udp_recv(sock, 65535, expiration)
        if not _matches_destination(
            sock.family, from_address, destination, ignore_unexpected
        ):
            continue
        received_time = time.time()
        try:
            message = dns.message.from_wire(
                wire,
                keyring=keyring,
                request_mac=request_mac,
                one_rr_per_rrset=one_rr_per_rrset,
                ignore_trailing=ignore_trailing,
                raise_on_truncation=raise_on_truncation,
            )
        except dns.message.Truncated as e:
            # A Truncated (rather than FORMERR) means the header with TC set,
            # and very likely the question section, parsed. Surface it when
            # it plausibly answers our query, since callers need to know about
            # truncation; only skip it in ignore_errors mode when it clearly
            # is not a response, so a random injected TC packet cannot make
            # us bail out.
            if not ignore_errors or query is None or query.is_response(e.message()):
                raise
            continue
        except Exception:
            if not ignore_errors:
                raise
            continue
        if ignore_errors and query is not None and not query.is_response(message):
            continue
        if destination:
            return (message, received_time)
        return (message, received_time, from_address)
def udp(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 53,
    source: Optional[str] = None,
    source_port: int = 0,
    ignore_unexpected: bool = False,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    raise_on_truncation: bool = False,
    sock: Optional[Any] = None,
    ignore_errors: bool = False,
) -> dns.message.Message:
    """Return the response obtained after sending a query via UDP.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
    the TC bit is set.

    *sock*, a ``socket.socket``, or ``None``, the socket to use for the
    query. If ``None``, the default, a socket is created. Note that
    if a socket is provided, it must be a nonblocking datagram socket,
    and the *source* and *source_port* are ignored.

    *ignore_errors*, a ``bool``. If various format errors or response
    mismatches occur, ignore them and keep listening for a valid response.
    The default is ``False``.

    Returns a ``dns.message.Message``.
    """
    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(
        where, port, source, source_port
    )
    (begin_time, expiration) = _compute_times(timeout)
    # A caller-supplied socket is borrowed (never closed); our own is closed.
    cm: contextlib.AbstractContextManager = (
        contextlib.nullcontext(sock)
        if sock
        else _make_socket(af, socket.SOCK_DGRAM, source)
    )
    with cm as s:
        send_udp(s, wire, destination, expiration)
        (r, received_time) = receive_udp(
            s,
            destination,
            expiration=expiration,
            ignore_unexpected=ignore_unexpected,
            one_rr_per_rrset=one_rr_per_rrset,
            keyring=q.keyring,
            request_mac=q.mac,
            ignore_trailing=ignore_trailing,
            raise_on_truncation=raise_on_truncation,
            ignore_errors=ignore_errors,
            query=q,
        )
        r.time = received_time - begin_time
        # In ignore_errors mode receive_udp() has already verified that r
        # answers q, so the check is only needed otherwise.
        if not ignore_errors and not q.is_response(r):
            raise BadResponse
        return r
    assert (
        False  # help mypy figure out we can't get here lgtm[py/unreachable-statement]
    )
def udp_with_fallback(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 53,
    source: Optional[str] = None,
    source_port: int = 0,
    ignore_unexpected: bool = False,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    udp_sock: Optional[Any] = None,
    tcp_sock: Optional[Any] = None,
    ignore_errors: bool = False,
) -> Tuple[dns.message.Message, bool]:
    """Return the response to the query, trying UDP first and falling back
    to TCP if UDP results in a truncated response.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
    times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
    address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message. The default is
    0.

    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from unexpected
    sources.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
    received message.

    *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the UDP query.
    If ``None``, the default, a socket is created. Note that if a socket is provided,
    it must be a nonblocking datagram socket, and the *source* and *source_port* are
    ignored for the UDP query.

    *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
    TCP query. If ``None``, the default, a socket is created. Note that if a socket is
    provided, it must be a nonblocking connected stream socket, and *where*, *source*
    and *source_port* are ignored for the TCP query.

    *ignore_errors*, a ``bool``. If various format errors or response mismatches occur
    while listening for UDP, ignore them and keep listening for a valid response. The
    default is ``False``.

    Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if
    TCP was used.
    """
    try:
        udp_response = udp(
            q,
            where,
            timeout=timeout,
            port=port,
            source=source,
            source_port=source_port,
            ignore_unexpected=ignore_unexpected,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            raise_on_truncation=True,
            sock=udp_sock,
            ignore_errors=ignore_errors,
        )
    except dns.message.Truncated:
        # NOTE(review): the TCP retry is given the full *timeout* again rather
        # than the time left over from the UDP attempt.
        tcp_response = tcp(
            q,
            where,
            timeout=timeout,
            port=port,
            source=source,
            source_port=source_port,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            sock=tcp_sock,
        )
        return (tcp_response, True)
    return (udp_response, False)
def _net_read(sock, count, expiration):
    """Read exactly *count* bytes from *sock* and return them as ``bytes``.

    Keeps reading until the full amount has arrived; raises ``EOFError`` if the
    peer closes the connection first. While the socket would block (including
    TLS want-read/want-write), waits for it; a Timeout exception will be raised
    if the operation is not completed by the expiration time.
    """
    chunks = []
    left = count
    while left > 0:
        try:
            chunk = sock.recv(left)
            if not chunk:
                raise EOFError
            chunks.append(chunk)
            left -= len(chunk)
        except (BlockingIOError, ssl.SSLWantReadError):
            _wait_for_readable(sock, expiration)
        except ssl.SSLWantWriteError:  # pragma: no cover
            _wait_for_writable(sock, expiration)
    return b"".join(chunks)
def _net_write(sock, data, expiration):
    """Write all of *data* to *sock*, retrying after partial sends.

    While the socket would block (including TLS want-read/want-write), waits
    for it; a Timeout exception will be raised if the operation is not
    completed by the expiration time.
    """
    pending = data
    while pending:
        try:
            sent = sock.send(pending)
        except (BlockingIOError, ssl.SSLWantWriteError):
            _wait_for_writable(sock, expiration)
        except ssl.SSLWantReadError:  # pragma: no cover
            _wait_for_readable(sock, expiration)
        else:
            pending = pending[sent:]
def send_tcp(
    sock: Any,
    what: Union[dns.message.Message, bytes],
    expiration: Optional[float] = None,
) -> Tuple[int, float]:
    """Send a DNS message to the specified TCP socket.

    *sock*, a ``socket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send. It is
    framed with the two-byte big-endian length prefix before sending.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
    """
    tcpmsg = (
        what.to_wire(prepend_length=True)
        if isinstance(what, dns.message.Message)
        # Copying the wire into one buffer is inefficient, but it avoids
        # writev() or a short write that would get pushed onto the net.
        else len(what).to_bytes(2, "big") + what
    )
    sent_time = time.time()
    _net_write(sock, tcpmsg, expiration)
    return (len(tcpmsg), sent_time)
def receive_tcp(
    sock: Any,
    expiration: Optional[float] = None,
    one_rr_per_rrset: bool = False,
    keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
    request_mac: Optional[bytes] = b"",
    ignore_trailing: bool = False,
) -> Tuple[dns.message.Message, float]:
    """Read a DNS message from a TCP socket.

    *sock*, a ``socket``.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.

    Returns a ``(dns.message.Message, float)`` tuple of the received message
    and the received time.
    """
    # The message is preceded by a two-byte, big-endian length field.
    (length,) = struct.unpack("!H", _net_read(sock, 2, expiration))
    payload = _net_read(sock, length, expiration)
    received_time = time.time()
    message = dns.message.from_wire(
        payload,
        keyring=keyring,
        request_mac=request_mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    return (message, received_time)
def _connect(s, address, expiration):
    """Connect the non-blocking socket *s* to *address*.

    If the connect is still in progress, waits for writability (bounded by
    *expiration*) and then reads the final status from ``SO_ERROR``. Any
    non-zero status is raised as ``OSError``.
    """
    err = s.connect_ex(address)
    if err in (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY):
        _wait_for_writable(s, expiration)
        err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
    if err:
        raise OSError(err, os.strerror(err))
def tcp(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 53,
    source: Optional[str] = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    sock: Optional[Any] = None,
) -> dns.message.Message:
    """Return the response obtained after sending a query via TCP.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
    query. If ``None``, the default, a socket is created. Note that
    if a socket is provided, it must be a nonblocking connected stream
    socket, and *where*, *port*, *source* and *source_port* are ignored.

    Returns a ``dns.message.Message``.
    """
    wire = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    if not sock:
        (af, destination, source) = _destination_and_source(
            where, port, source, source_port
        )
        cm: contextlib.AbstractContextManager = _make_socket(
            af, socket.SOCK_STREAM, source
        )
    else:
        # Borrow the caller's connected socket; it is not closed here.
        cm = contextlib.nullcontext(sock)
    with cm as s:
        if not sock:
            # Only a socket we created ourselves still needs connecting.
            _connect(s, destination, expiration)
        send_tcp(s, wire, expiration)
        (r, received_time) = receive_tcp(
            s,
            expiration,
            one_rr_per_rrset=one_rr_per_rrset,
            keyring=q.keyring,
            request_mac=q.mac,
            ignore_trailing=ignore_trailing,
        )
        r.time = received_time - begin_time
        if not q.is_response(r):
            raise BadResponse
        return r
    assert (
        False  # help mypy figure out we can't get here lgtm[py/unreachable-statement]
    )
def _tls_handshake(s, expiration):
    """Drive the TLS handshake on the non-blocking SSL socket *s* to completion.

    Each want-read/want-write condition is handled by waiting for the socket,
    bounded by *expiration*.
    """
    done = False
    while not done:
        try:
            s.do_handshake()
        except ssl.SSLWantReadError:
            _wait_for_readable(s, expiration)
        except ssl.SSLWantWriteError:  # pragma: no cover
            _wait_for_writable(s, expiration)
        else:
            done = True
def _make_dot_ssl_context(
    server_hostname: Optional[str], verify: Union[bool, str]
) -> ssl.SSLContext:
    """Build the SSL context used for DNS-over-TLS.

    *server_hostname*, a ``str`` or ``None``. If ``None``, hostname checking is
    disabled.

    *verify*, a ``bool`` or ``str``. ``True`` verifies the server certificate
    against the default CA bundle; ``False`` disables certificate and hostname
    verification; a ``str`` names a CA file or directory to verify against.

    The context requires at least TLS 1.2 and advertises the ``dot`` ALPN
    protocol.

    Raises ``ValueError`` if *verify* is a string that is neither an existing
    file nor an existing directory.
    """
    cafile: Optional[str] = None
    capath: Optional[str] = None
    if isinstance(verify, str):
        if os.path.isfile(verify):
            cafile = verify
        elif os.path.isdir(verify):
            capath = verify
        else:
            raise ValueError("invalid verify string")
    ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
    ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
    if server_hostname is None or verify is False:
        # Hostname checking must be off before verify_mode may be set to
        # CERT_NONE: SSLContext rejects CERT_NONE while check_hostname is on,
        # which previously broke server_hostname=... together with verify=False.
        ssl_context.check_hostname = False
    ssl_context.set_alpn_protocols(["dot"])
    if verify is False:
        ssl_context.verify_mode = ssl.CERT_NONE
    return ssl_context
def tls(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 853,
    source: Optional[str] = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    sock: Optional[ssl.SSLSocket] = None,
    ssl_context: Optional[ssl.SSLContext] = None,
    server_hostname: Optional[str] = None,
    verify: Union[bool, str] = True,
) -> dns.message.Message:
    """Return the response obtained after sending a query via TLS.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 853.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for
    the query. If ``None``, the default, a socket is created. Note
    that if a socket is provided, it must be a nonblocking connected
    SSL stream socket, and *where*, *port*, *source*, *source_port*,
    and *ssl_context* are ignored.

    *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
    a TLS connection. If ``None``, the default, creates one with the default
    configuration.

    *server_hostname*, a ``str`` containing the server's hostname. The
    default is ``None``, which means that no hostname is known, and if an
    SSL context is created, hostname checking will be disabled.

    *verify*, a ``bool`` or ``str``. If ``True``, then TLS certificate
    verification of the server is done using the default CA bundle; if
    ``False``, then no verification is done; if a ``str`` then it specifies the
    path to a certificate file or directory which will be used for
    verification.  Only used when an SSL context is created here.

    Returns a ``dns.message.Message``.
    """
    if sock:
        # A caller-supplied socket is used as-is; no TLS handshake is
        # performed here, so the exchange is identical to plain TCP.
        return tcp(
            q,
            where,
            timeout,
            port,
            source,
            source_port,
            one_rr_per_rrset,
            ignore_trailing,
            sock,
        )

    wire = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    (af, destination, source) = _destination_and_source(
        where, port, source, source_port
    )
    # *sock* is always falsy past the early return above, so only the
    # context needs to be checked before building a default one.
    if ssl_context is None:
        ssl_context = _make_dot_ssl_context(server_hostname, verify)

    with _make_socket(
        af,
        socket.SOCK_STREAM,
        source,
        ssl_context=ssl_context,
        server_hostname=server_hostname,
    ) as s:
        _connect(s, destination, expiration)
        _tls_handshake(s, expiration)
        send_tcp(s, wire, expiration)
        (r, received_time) = receive_tcp(
            s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
        )
        r.time = received_time - begin_time
        if not q.is_response(r):
            raise BadResponse
        return r
    assert (
        False  # help mypy figure out we can't get here  lgtm[py/unreachable-statement]
    )
def quic(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 853,
    source: Optional[str] = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    connection: Optional[dns.quic.SyncQuicConnection] = None,
    verify: Union[bool, str] = True,
    server_hostname: Optional[str] = None,
) -> dns.message.Message:
    """Return the response obtained after sending a query via DNS-over-QUIC.

    *q*, a ``dns.message.Message``, the query to send.  Its ``id`` is set
    to 0 before it is rendered.

    *where*, a ``str``, the nameserver IP address.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
    times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the query to. The default is 853.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
    address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message. The default is
    0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
    received message.

    *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the
    connection to use to send the query.

    *verify*, a ``bool`` or ``str``. If ``True``, then TLS certificate verification
    of the server is done using the default CA bundle; if ``False``, then no
    verification is done; if a ``str`` then it specifies the path to a certificate file or
    directory which will be used for verification.

    *server_hostname*, a ``str`` containing the server's hostname. The
    default is ``None``, which means that no hostname is known, and if an
    SSL context is created, hostname checking will be disabled.

    Raises ``NoDOQ`` if DNS-over-QUIC support is not available, and
    ``BadResponse`` if the reply does not answer *q*.

    Returns a ``dns.message.Message``.
    """
    if not dns.quic.have_quic:
        raise NoDOQ("DNS-over-QUIC is not available.")  # pragma: no cover

    # DNS-over-QUIC uses a message ID of 0 (RFC 9250).
    q.id = 0
    wire = q.to_wire()
    the_connection: dns.quic.SyncQuicConnection
    the_manager: dns.quic.SyncQuicManager
    if connection:
        manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
        the_connection = connection
    else:
        manager = dns.quic.SyncQuicManager(
            verify_mode=verify, server_name=server_hostname
        )
        the_manager = manager  # for type checking happiness

    with manager:
        if not connection:
            the_connection = the_manager.connect(where, port, source, source_port)
        (start, expiration) = _compute_times(timeout)
        with the_connection.make_stream(timeout) as stream:
            stream.send(wire, True)
            wire = stream.receive(_remaining(expiration))
        finish = time.time()
    # A TSIG on the response is computed over the *request's* MAC, i.e. the
    # MAC produced when to_wire() signed q (q.mac), exactly as the TCP and
    # TLS paths pass it.  q.request_mac is only meaningful on responses.
    r = dns.message.from_wire(
        wire,
        keyring=q.keyring,
        request_mac=q.mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    r.time = max(finish - start, 0.0)
    if not q.is_response(r):
        raise BadResponse
    return r
def xfr(
    where: str,
    zone: Union[dns.name.Name, str],
    rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.AXFR,
    rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
    timeout: Optional[float] = None,
    port: int = 53,
    keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
    keyname: Optional[Union[dns.name.Name, str]] = None,
    relativize: bool = True,
    lifetime: Optional[float] = None,
    source: Optional[str] = None,
    source_port: int = 0,
    serial: int = 0,
    use_udp: bool = False,
    keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm,
) -> Any:
    """Return a generator for the responses to a zone transfer.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *zone*, a ``dns.name.Name`` or ``str``, the name of the zone to transfer.

    *rdtype*, an ``int`` or ``str``, the type of zone transfer. The
    default is ``dns.rdatatype.AXFR``. ``dns.rdatatype.IXFR`` can be
    used to do an incremental transfer instead.

    *rdclass*, an ``int`` or ``str``, the class of the zone transfer.
    The default is ``dns.rdataclass.IN``.

    *timeout*, a ``float``, the number of seconds to wait for each
    response message. If None, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG
    key to use.

    *relativize*, a ``bool``. If ``True``, all names in the zone will be
    relativized to the zone origin. It is essential that the
    relativize setting matches the one specified to
    ``dns.zone.from_xfr()`` if using this generator to make a zone.

    *lifetime*, a ``float``, the total number of seconds to spend
    doing the transfer. If ``None``, the default, then there is no
    limit on the time the transfer may take.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *serial*, an ``int``, the SOA serial number to use as the base for
    an IXFR diff sequence (only meaningful if *rdtype* is
    ``dns.rdatatype.IXFR``).

    *use_udp*, a ``bool``. If ``True``, use UDP (only meaningful for IXFR).

    *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use.

    Raises on errors, and so does the generator.

    Returns a generator of ``dns.message.Message`` objects.
    """
    if isinstance(zone, str):
        zone = dns.name.from_text(zone)
    rdtype = dns.rdatatype.RdataType.make(rdtype)
    q = dns.message.make_query(zone, rdtype, rdclass)
    if rdtype == dns.rdatatype.IXFR:
        # An IXFR request carries the client's current SOA (only the serial
        # matters) in the authority section; it belongs to the same class
        # as the question rather than always to IN.
        rrset = dns.rrset.from_text(
            zone, 0, rdclass, "SOA", ". . %u 0 0 0 0" % serial
        )
        q.authority.append(rrset)
    if keyring is not None:
        q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(
        where, port, source, source_port
    )
    if use_udp and rdtype != dns.rdatatype.IXFR:
        raise ValueError("cannot do a UDP AXFR")
    sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
    with _make_socket(af, sock_type, source) as s:
        (_, expiration) = _compute_times(lifetime)
        _connect(s, destination, expiration)
        if use_udp:
            _udp_send(s, wire, None, expiration)
        else:
            # Over TCP every DNS message is prefixed with its length as a
            # 2-byte big-endian integer.
            tcpmsg = struct.pack("!H", len(wire)) + wire
            _net_write(s, tcpmsg, expiration)
        done = False
        delete_mode = True
        expecting_SOA = False
        soa_rrset = None
        if relativize:
            origin = zone
            oname = dns.name.empty
        else:
            origin = None
            oname = zone
        tsig_ctx = None
        while not done:
            # The per-message deadline is the earlier of the per-message
            # timeout and the overall lifetime deadline.
            (_, mexpiration) = _compute_times(timeout)
            if mexpiration is None or (
                expiration is not None and mexpiration > expiration
            ):
                mexpiration = expiration
            if use_udp:
                (wire, _) = _udp_recv(s, 65535, mexpiration)
            else:
                ldata = _net_read(s, 2, mexpiration)
                (msg_len,) = struct.unpack("!H", ldata)
                wire = _net_read(s, msg_len, mexpiration)
            # rdtype may have been downgraded to AXFR below, in which case
            # later messages are parsed with normal RRset merging.
            is_ixfr = rdtype == dns.rdatatype.IXFR
            r = dns.message.from_wire(
                wire,
                keyring=q.keyring,
                request_mac=q.mac,
                xfr=True,
                origin=origin,
                tsig_ctx=tsig_ctx,
                multi=True,
                one_rr_per_rrset=is_ixfr,
            )
            rcode = r.rcode()
            if rcode != dns.rcode.NOERROR:
                raise TransferError(rcode)
            tsig_ctx = r.tsig_ctx
            answer_index = 0
            if soa_rrset is None:
                # The first answer RRset of the transfer must be the zone SOA.
                if not r.answer or r.answer[0].name != oname:
                    raise dns.exception.FormError("No answer or RRset not for qname")
                rrset = r.answer[0]
                if rrset.rdtype != dns.rdatatype.SOA:
                    raise dns.exception.FormError("first RRset is not an SOA")
                answer_index = 1
                soa_rrset = rrset.copy()
                if rdtype == dns.rdatatype.IXFR:
                    if dns.serial.Serial(soa_rrset[0].serial) <= serial:
                        #
                        # We're already up-to-date.
                        #
                        done = True
                    else:
                        expecting_SOA = True
            #
            # Process SOAs in the answer section (other than the initial
            # SOA in the first message).
            #
            for rrset in r.answer[answer_index:]:
                if done:
                    raise dns.exception.FormError("answers after final SOA")
                if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
                    if expecting_SOA:
                        if rrset[0].serial != serial:
                            raise dns.exception.FormError("IXFR base serial mismatch")
                        expecting_SOA = False
                    elif rdtype == dns.rdatatype.IXFR:
                        # Each further SOA in an IXFR alternates between the
                        # deletion and addition parts of a difference sequence.
                        delete_mode = not delete_mode
                    #
                    # If this SOA RRset is equal to the first we saw then we're
                    # finished. If this is an IXFR we also check that we're
                    # seeing the record in the expected part of the response.
                    #
                    if rrset == soa_rrset and (
                        rdtype == dns.rdatatype.AXFR
                        or (rdtype == dns.rdatatype.IXFR and delete_mode)
                    ):
                        done = True
                elif expecting_SOA:
                    #
                    # We made an IXFR request and are expecting another
                    # SOA RR, but saw something else, so this must be an
                    # AXFR response.
                    #
                    rdtype = dns.rdatatype.AXFR
                    expecting_SOA = False
            if done and q.keyring and not r.had_tsig:
                raise dns.exception.FormError("missing TSIG")
            yield r
class UDPMode(enum.IntEnum):
    """Policy for using UDP for an IXFR done by :py:func:`inbound_xfr()`.

    NEVER means "never use UDP; always use TCP".
    TRY_FIRST means "try UDP, falling back to TCP if needed".
    ONLY means "use UDP only, raising ``dns.xfr.UseTCP`` if it does not
    succeed".
    """

    # Always use TCP.
    NEVER = 0
    # Attempt UDP first; retry over TCP when required.
    TRY_FIRST = 1
    # UDP only; never fall back to TCP.
    ONLY = 2
def inbound_xfr(
    where: str,
    txn_manager: dns.transaction.TransactionManager,
    query: Optional[dns.message.Message] = None,
    port: int = 53,
    timeout: Optional[float] = None,
    lifetime: Optional[float] = None,
    source: Optional[str] = None,
    source_port: int = 0,
    udp_mode: UDPMode = UDPMode.NEVER,
) -> None:
    """Conduct an inbound transfer and apply it via a transaction from the
    txn_manager.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager
    for this transfer (typically a ``dns.zone.Zone``).

    *query*, the query to send. If not supplied, a default query is
    constructed using information from the *txn_manager*.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *timeout*, a ``float``, the number of seconds to wait for each
    response message. If None, the default, wait forever.

    *lifetime*, a ``float``, the total number of seconds to spend
    doing the transfer. If ``None``, the default, then there is no
    limit on the time the transfer may take.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used
    for IXFRs. The default is ``dns.query.UDPMode.NEVER``, i.e. only use
    TCP. Other possibilities are ``dns.query.UDPMode.TRY_FIRST``, which
    means "try UDP but fall back to TCP if needed", and
    ``dns.query.UDPMode.ONLY``, which means "try UDP and raise
    ``dns.xfr.UseTCP`` if it does not succeed".

    Raises on errors.
    """
    if query is None:
        (query, serial) = dns.xfr.make_query(txn_manager)
    else:
        serial = dns.xfr.extract_serial_from_query(query)
    rdtype = query.question[0].rdtype
    is_ixfr = rdtype == dns.rdatatype.IXFR
    origin = txn_manager.from_wire_origin()
    wire = query.to_wire()
    (af, destination, source) = _destination_and_source(
        where, port, source, source_port
    )
    (_, expiration) = _compute_times(lifetime)
    # The outer loop runs a second time only when a UDP attempt is told to
    # switch to TCP (dns.xfr.UseTCP) and udp_mode permits the fallback.
    retry = True
    while retry:
        retry = False
        if is_ixfr and udp_mode != UDPMode.NEVER:
            sock_type = socket.SOCK_DGRAM
            is_udp = True
        else:
            sock_type = socket.SOCK_STREAM
            is_udp = False
        with _make_socket(af, sock_type, source) as s:
            _connect(s, destination, expiration)
            if is_udp:
                _udp_send(s, wire, None, expiration)
            else:
                # Over TCP every DNS message is prefixed with its length as a
                # 2-byte big-endian integer.
                tcpmsg = struct.pack("!H", len(wire)) + wire
                _net_write(s, tcpmsg, expiration)
            with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
                done = False
                tsig_ctx = None
                while not done:
                    # The per-message deadline is the earlier of the
                    # per-message timeout and the overall lifetime deadline.
                    (_, mexpiration) = _compute_times(timeout)
                    if mexpiration is None or (
                        expiration is not None and mexpiration > expiration
                    ):
                        mexpiration = expiration
                    if is_udp:
                        (rwire, _) = _udp_recv(s, 65535, mexpiration)
                    else:
                        ldata = _net_read(s, 2, mexpiration)
                        (msg_len,) = struct.unpack("!H", ldata)
                        rwire = _net_read(s, msg_len, mexpiration)
                    r = dns.message.from_wire(
                        rwire,
                        keyring=query.keyring,
                        request_mac=query.mac,
                        xfr=True,
                        origin=origin,
                        tsig_ctx=tsig_ctx,
                        multi=(not is_udp),
                        one_rr_per_rrset=is_ixfr,
                    )
                    try:
                        done = inbound.process_message(r)
                    except dns.xfr.UseTCP:
                        assert is_udp  # should not happen if we used TCP!
                        if udp_mode == UDPMode.ONLY:
                            raise
                        done = True
                        retry = True
                        udp_mode = UDPMode.NEVER
                        continue
                    tsig_ctx = r.tsig_ctx
                    # Only the final message must carry a TSIG: RFC 8945
                    # requires clients to accept unsigned intermediate
                    # messages in a multi-message TCP transfer, and the
                    # running tsig_ctx already covers them.  This matches
                    # the check in xfr().
                    if done and query.keyring and not r.had_tsig:
                        raise dns.exception.FormError("missing TSIG")