bazarr/libs/waitress/server.py
Louis Vézina 1b0e721a9d WIP
2020-04-15 00:02:44 -04:00

436 lines
14 KiB
Python

##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
import os
import os.path
import socket
import time
from waitress import trigger
from waitress.adjustments import Adjustments
from waitress.channel import HTTPChannel
from waitress.task import ThreadedTaskDispatcher
from waitress.utilities import cleanup_unix_socket
from waitress.compat import (
IPPROTO_IPV6,
IPV6_V6ONLY,
)
from . import wasyncore
from .proxy_headers import proxy_headers_middleware
def create_server(
    application,
    map=None,
    _start=True,  # test shim
    _sock=None,  # test shim
    _dispatcher=None,  # test shim
    **kw  # adjustments
):
    """Create the server object(s) that will serve ``application``.

    Typical use::

        if __name__ == '__main__':
            server = create_server(app)
            server.run()

    Keyword arguments in ``kw`` are passed to ``Adjustments``.

    * If ``adj.unix_socket`` is set and the platform has ``AF_UNIX``, a single
      ``UnixWSGIServer`` is returned.
    * Otherwise one ``TcpWSGIServer`` is built per entry in ``adj.listen``
      (only when ``adj.sockets`` is empty). One more server is built for each
      AF_INET/AF_INET6/AF_UNIX socket in ``adj.sockets``. Sockets of other
      families are ignored.

    Every server registers itself in ``map``. When exactly one listener
    results, that server is returned directly. Otherwise a
    ``MultiSocketServer`` is returned that runs them all on one loop.

    Raises ValueError when ``application`` is None.
    """
    if application is None:
        raise ValueError(
            'The "app" passed to ``create_server`` was ``None``. You forgot '
            "to return a WSGI app within your application."
        )

    adj = Adjustments(**kw)

    if map is None:  # pragma: nocover
        map = {}

    dispatcher = _dispatcher
    if dispatcher is None:
        dispatcher = ThreadedTaskDispatcher()
        dispatcher.set_thread_count(adj.threads)

    unix_supported = hasattr(socket, "AF_UNIX")

    if adj.unix_socket and unix_supported:
        return UnixWSGIServer(
            application,
            map,
            _start,
            _sock,
            dispatcher=dispatcher,
            adj=adj,
            sockinfo=(socket.AF_UNIX, socket.SOCK_STREAM, None, None),
        )

    servers = []

    def build(server_cls, sock, info, **extra):
        # Instantiating a server registers it in ``map``. The object is only
        # kept here so its effective address can be read back afterwards.
        servers.append(
            server_cls(
                application,
                map,
                _start,
                sock,
                dispatcher=dispatcher,
                adj=adj,
                sockinfo=info,
                **extra
            )
        )

    if not adj.sockets:
        for info in adj.listen:
            build(TcpWSGIServer, _sock, info)

    for sock in adj.sockets:
        info = (sock.family, sock.type, sock.proto, sock.getsockname())
        if sock.family in (socket.AF_INET, socket.AF_INET6):
            build(TcpWSGIServer, sock, info, bind_socket=False)
        elif unix_supported and sock.family == socket.AF_UNIX:
            build(UnixWSGIServer, sock, info, bind_socket=False)
        # Sockets of any other family are ignored.

    effective_listen = [
        (server.effective_host, server.effective_port) for server in servers
    ]

    # A single listener needs no wrapper object.
    if len(effective_listen) == 1:
        return servers[-1]

    # Several listeners share ``map``. The wrapper provides .run() for the
    # shared loop and print_listen() for reporting every address.
    return MultiSocketServer(map, adj, effective_listen, dispatcher)
# This class is only ever used if we have multiple listen sockets. It allows
# the serve() API to call .run() which starts the wasyncore loop, and catches
# SystemExit/KeyboardInterrupt so that it can attempt to cleanly shut down.
class MultiSocketServer(object):
    """Front object for several listening servers that share one socket map.

    The individual servers are expected to already be registered in ``map``.
    This object only drives the shared wasyncore loop, reports the listening
    addresses, and tears everything down on shutdown.
    """

    asyncore = wasyncore  # test shim

    def __init__(
        self, map=None, adj=None, effective_listen=None, dispatcher=None,
    ):
        """Store the shared state.

        :param map: wasyncore socket map shared by the servers.
        :param adj: Adjustments providing the loop timeout and poll choice.
        :param effective_listen: sequence of ``(host, port)`` entries.
        :param dispatcher: task dispatcher that ``close`` shuts down.
        """
        self.adj, self.map = adj, map
        self.effective_listen = effective_listen
        self.task_dispatcher = dispatcher

    def print_listen(self, format_str):  # pragma: nocover
        """Print one line per listener via ``format_str.format(host, port)``.

        Hosts containing ``:`` (IPv6 literals) are wrapped in brackets.
        """
        for entry in self.effective_listen:
            host = entry[0]
            if ":" in host:
                host = "[{}]".format(host)
            print(format_str.format(host, *entry[1:]))

    def run(self):
        """Run the shared loop; SystemExit/KeyboardInterrupt trigger close()."""
        adj = self.adj
        try:
            self.asyncore.loop(
                timeout=adj.asyncore_loop_timeout,
                map=self.map,
                use_poll=adj.asyncore_use_poll,
            )
        except (SystemExit, KeyboardInterrupt):
            self.close()

    def close(self):
        """Shut down the task dispatcher, then close everything in ``map``."""
        self.task_dispatcher.shutdown()
        wasyncore.close_all(self.map)
class BaseWSGIServer(wasyncore.dispatcher, object):
    """Listening-socket dispatcher shared by the TCP and UNIX servers.

    An instance registers itself and a trigger in a wasyncore socket map. It
    accepts incoming connections and wraps each one in ``channel_class``.
    Concrete subclasses must provide ``bind_server_socket`` and
    ``getsockname``.
    """

    channel_class = HTTPChannel
    next_channel_cleanup = 0
    socketmod = socket  # test shim
    asyncore = wasyncore  # test shim

    def __init__(
        self,
        application,
        map=None,
        _start=True,  # test shim
        _sock=None,  # test shim
        dispatcher=None,  # dispatcher
        adj=None,  # adjustments
        sockinfo=None,  # opaque object
        bind_socket=True,
        **kw
    ):
        """Prepare the listening socket and optionally start listening.

        :param application: the WSGI callable to serve.
        :param map: wasyncore socket map. A private dict is used when omitted.
        :param _start: when true, ``accept_connections`` is called at the end.
        :param _sock: an existing socket to adopt instead of creating one.
        :param dispatcher: task dispatcher. When omitted, a
            ThreadedTaskDispatcher sized by ``adj.threads`` is created.
        :param adj: Adjustments instance. Built from ``**kw`` when omitted.
        :param sockinfo: ``(family, socktype, proto, sockaddr)``. Defaults to
            ``adj.listen[0]``.
        :param bind_socket: when false, ``bind_server_socket`` is skipped.
        """
        if adj is None:
            adj = Adjustments(**kw)

        proxy_aware = adj.trusted_proxy or adj.clear_untrusted_proxy_headers
        if proxy_aware:
            # Wrapping happens here rather than in create_server because
            # webtest subclasses TcpWSGIServer directly and never calls it.
            application = proxy_headers_middleware(
                application,
                trusted_proxy=adj.trusted_proxy,
                trusted_proxy_count=adj.trusted_proxy_count,
                trusted_proxy_headers=adj.trusted_proxy_headers,
                clear_untrusted=adj.clear_untrusted_proxy_headers,
                log_untrusted=adj.log_untrusted_proxy_headers,
                logger=self.logger,
            )

        # A non-global map avoids conflicts with apps and libraries that use
        # the global wasyncore map (https://github.com/Pylons/waitress/issues/63).
        map = {} if map is None else map
        sockinfo = adj.listen[0] if sockinfo is None else sockinfo

        self.sockinfo = sockinfo
        self.family = sockinfo[0]
        self.socktype = sockinfo[1]
        self.application = application
        self.adj = adj
        self.trigger = trigger.trigger(map)

        if dispatcher is None:
            dispatcher = ThreadedTaskDispatcher()
            dispatcher.set_thread_count(adj.threads)
        self.task_dispatcher = dispatcher

        self.asyncore.dispatcher.__init__(self, _sock, map=map)
        if _sock is None:
            self.create_socket(self.family, self.socktype)
            if self.family == socket.AF_INET6:  # pragma: nocover
                self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1)
        self.set_reuse_addr()

        if bind_socket:
            self.bind_server_socket()

        host, port = self.getsockname()
        self.effective_host, self.effective_port = host, port
        self.server_name = self.get_server_name(host)
        self.active_channels = {}

        if _start:
            self.accept_connections()

    def bind_server_socket(self):
        """Bind ``self.socket``; must be implemented by subclasses."""
        raise NotImplementedError  # pragma: no cover

    def get_server_name(self, ip):
        """Given an IP or hostname, try to determine the server name."""
        if not ip:
            raise ValueError("Requires an IP to get the server name")

        name = str(ip)

        # Bound to every interface: report this machine's hostname, falling
        # back to "localhost" if the lookup fails.
        if name in ("0.0.0.0", "::"):
            try:
                return str(self.socketmod.gethostname())
            except (socket.error, UnicodeDecodeError):  # pragma: no cover
                # UnicodeDecodeError covers Windows hosts with non-ascii names
                return "localhost"

        # Best-effort reverse lookup; keep the literal address on failure.
        try:
            name = self.socketmod.gethostbyaddr(name)[0]
        except (socket.error, UnicodeDecodeError):  # pragma: no cover
            # UnicodeDecodeError covers Windows hosts with non-ascii names
            pass

        # Surround a bare IPv6 literal with brackets.
        needs_brackets = ":" in name and "[" not in name
        return "[{}]".format(name) if needs_brackets else name

    def getsockname(self):
        """Return ``(host, port)`` for the socket; implemented by subclasses."""
        raise NotImplementedError  # pragma: no cover

    def accept_connections(self):
        """Flag the dispatcher as accepting and call ``listen``."""
        self.accepting = True
        self.socket.listen(self.adj.backlog)  # Get around asyncore NT limit

    def add_task(self, task):
        """Forward ``task`` to the task dispatcher."""
        self.task_dispatcher.add_task(task)

    def readable(self):
        """Run periodic maintenance, then say whether to accept more.

        Maintenance runs at most once per ``adj.cleanup_interval``. Accepting
        pauses while the socket map holds ``adj.connection_limit`` or more
        entries.
        """
        now = time.time()
        cleanup_due = now >= self.next_channel_cleanup
        if cleanup_due:
            self.next_channel_cleanup = now + self.adj.cleanup_interval
            self.maintenance(now)
        if not self.accepting:
            return self.accepting
        return len(self._map) < self.adj.connection_limit

    def writable(self):
        """Never writable: this dispatcher only accepts."""
        return False

    def handle_read(self):
        """No-op."""
        pass

    def handle_connect(self):
        """No-op."""
        pass

    def handle_accept(self):
        """Accept one pending connection and wrap it in ``channel_class``."""
        try:
            accepted = self.accept()
            if accepted is None:
                return
            conn, addr = accepted
        except socket.error:
            # Linux occasionally returns a bogus socket from accept()
            # (socketmodule.c:makesockaddr reports an unknown address
            # family). One bad accept must not shut the whole server down.
            if self.adj.log_socket_errors:
                self.logger.warning(
                    "server accept() threw an exception", exc_info=True
                )
            return
        self.set_socket_options(conn)
        peer = self.fix_addr(addr)
        self.channel_class(self, conn, peer, self.adj, map=self._map)

    def run(self):
        """Drive the wasyncore loop for this server's map.

        On SystemExit/KeyboardInterrupt only the task dispatcher is shut
        down; ``close`` is not called here.
        """
        adj = self.adj
        try:
            self.asyncore.loop(
                timeout=adj.asyncore_loop_timeout,
                map=self._map,
                use_poll=adj.asyncore_use_poll,
            )
        except (SystemExit, KeyboardInterrupt):
            self.task_dispatcher.shutdown()

    def pull_trigger(self):
        """Pull the trigger created in ``__init__``."""
        self.trigger.pull_trigger()

    def set_socket_options(self, conn):
        """Hook for per-connection socket options; no-op by default."""
        pass

    def fix_addr(self, addr):
        """Hook to normalise the peer address; identity by default."""
        return addr

    def maintenance(self, now):
        """
        Flag idle channels for closing.

        A channel with no pending requests whose ``last_activity`` is older
        than ``now - adj.channel_timeout`` (seconds) gets ``will_close`` set.
        """
        cutoff = now - self.adj.channel_timeout
        idle = (
            ch
            for ch in self.active_channels.values()
            if not ch.requests and ch.last_activity < cutoff
        )
        for ch in idle:
            ch.will_close = True

    def print_listen(self, format_str):  # pragma: nocover
        """Print ``format_str`` filled with the effective host and port."""
        print(format_str.format(self.effective_host, self.effective_port))

    def close(self):
        """Close the trigger first, then the listening dispatcher itself."""
        self.trigger.close()
        return wasyncore.dispatcher.close(self)
class TcpWSGIServer(BaseWSGIServer):
    """WSGI server listening on an AF_INET or AF_INET6 stream socket."""

    def bind_server_socket(self):
        """Bind to the ``sockaddr`` element of ``self.sockinfo``."""
        (_, _, _, sockaddr) = self.sockinfo
        self.bind(sockaddr)

    def getsockname(self):
        """Return the ``(host, port)`` pair from ``getnameinfo``.

        The first attempt lets ``getnameinfo`` resolve the host name (only
        the service is forced numeric). If that lookup fails, the numeric
        host and port are returned instead.
        """
        sockaddr = self.socket.getsockname()
        try:
            return self.socketmod.getnameinfo(
                sockaddr, self.socketmod.NI_NUMERICSERV
            )
        except (socket.error, UnicodeDecodeError):  # pragma: no cover
            # On Linux a DNS issue is reported as a temporary failure
            # (socket.gaierror, a socket.error subclass) even when
            # NI_NAMEREQD is not set. Fall back to the numeric host and
            # port, which is better than nothing. Only lookup errors are
            # caught here: a bare ``except`` would also swallow
            # KeyboardInterrupt/SystemExit and mask programming errors.
            return self.socketmod.getnameinfo(
                sockaddr,
                self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV,
            )

    def set_socket_options(self, conn):
        """Apply every ``(level, optname, value)`` in ``adj.socket_options``."""
        for level, optname, value in self.adj.socket_options:
            conn.setsockopt(level, optname, value)
if hasattr(socket, "AF_UNIX"):

    class UnixWSGIServer(BaseWSGIServer):
        """WSGI server listening on a UNIX domain socket.

        Only defined when the ``socket`` module exposes ``AF_UNIX``.
        """

        def __init__(
            self,
            application,
            map=None,
            _start=True,  # test shim
            _sock=None,  # test shim
            dispatcher=None,  # dispatcher
            adj=None,  # adjustments
            sockinfo=None,  # opaque object
            **kw
        ):
            """Default ``sockinfo`` to an AF_UNIX stream socket description.

            The bind path is not part of ``sockinfo``; it is read from
            ``adj.unix_socket`` in ``bind_server_socket``. Any extra keyword
            arguments (e.g. ``bind_socket``) are forwarded to the base class.
            """
            unix_info = (
                sockinfo
                if sockinfo is not None
                else (socket.AF_UNIX, socket.SOCK_STREAM, None, None)
            )
            super(UnixWSGIServer, self).__init__(
                application,
                map=map,
                _start=_start,
                _sock=_sock,
                dispatcher=dispatcher,
                adj=adj,
                sockinfo=unix_info,
                **kw
            )

        def bind_server_socket(self):
            """Bind to ``adj.unix_socket`` and apply ``adj.unix_socket_perms``.

            ``cleanup_unix_socket`` is called on the path first (presumably to
            remove a stale socket file; confirm in waitress.utilities).
            """
            path = self.adj.unix_socket
            cleanup_unix_socket(path)
            self.bind(path)
            if os.path.exists(path):
                os.chmod(path, self.adj.unix_socket_perms)

        def getsockname(self):
            """Report ``("unix", <socket address>)`` as the effective address."""
            return "unix", self.socket.getsockname()

        def fix_addr(self, addr):
            """Peers on a UNIX socket are always reported as localhost."""
            return ("localhost", None)

        def get_server_name(self, ip):
            """The server name is always ``"localhost"`` for UNIX sockets."""
            return "localhost"
# Compatibility alias: ``WSGIServer`` refers to ``TcpWSGIServer`` (presumably
# kept so code importing the older name keeps working).
WSGIServer = TcpWSGIServer