bazarr/libs/waitress/task.py
2022-11-07 13:08:27 -05:00

571 lines
21 KiB
Python

##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
from collections import deque
import socket
import sys
import threading
import time
from .buffers import ReadOnlyFileBasedBuffer
from .utilities import build_http_date, logger, queue_logger
# Request headers copied into the WSGI environ under their bare CGI name
# instead of the "HTTP_"-prefixed form used for every other header.
rename_headers = {name: name for name in ("CONTENT_LENGTH", "CONTENT_TYPE")}
# Lower-cased hop-by-hop header names (RFC 2616 section 13.5.1). They only
# describe a single transport connection, so a WSGI application may not set
# them (PEP 3333); start_response rejects any of these names.
hop_by_hop = frozenset(
    "connection keep-alive proxy-authenticate proxy-authorization "
    "te trailers transfer-encoding upgrade".split()
)
class ThreadedTaskDispatcher:
    """A task dispatcher that services queued tasks on a pool of worker threads.

    Workers are identified by small integers stored in ``self.threads`` and
    pull tasks from the shared ``self.queue``. A task must provide
    ``service()``; :meth:`shutdown` additionally calls ``cancel()`` on tasks
    still queued when it gives up on them.

    ``self.lock`` guards the mutable state and backs both condition
    variables: ``queue_cv`` wakes idle workers when work or a stop request
    arrives, and ``thread_exit_cv`` is notified whenever a worker exits.
    """

    stop_count = 0  # Number of threads that will stop soon.
    active_count = 0  # Number of currently active threads
    logger = logger
    queue_logger = queue_logger

    def __init__(self):
        self.threads = set()
        self.queue = deque()
        self.lock = threading.Lock()
        self.queue_cv = threading.Condition(self.lock)
        self.thread_exit_cv = threading.Condition(self.lock)

    def start_new_thread(self, target, thread_no):
        """Start ``target(thread_no)`` on a daemon thread named waitress-N."""
        t = threading.Thread(
            target=target, name=f"waitress-{thread_no}", args=(thread_no,)
        )
        t.daemon = True
        t.start()

    def handler_thread(self, thread_no):
        """Worker loop: service queued tasks until asked to stop.

        The worker exits when it observes ``stop_count > 0``. On exit it
        removes ``thread_no`` from ``self.threads`` and notifies
        ``thread_exit_cv``. Any exception raised by ``task.service()`` is
        logged and the loop keeps running.
        """
        while True:
            with self.lock:
                while not self.queue and self.stop_count == 0:
                    # Mark ourselves as idle before waiting to be
                    # woken up, then we will once again be active
                    self.active_count -= 1
                    self.queue_cv.wait()
                    self.active_count += 1

                # A pending stop request wins over queued work.
                if self.stop_count > 0:
                    self.active_count -= 1
                    self.stop_count -= 1
                    self.threads.discard(thread_no)
                    self.thread_exit_cv.notify()
                    break

                task = self.queue.popleft()
            # Service outside the lock so other workers and add_task() can run.
            try:
                task.service()
            except BaseException:
                self.logger.exception("Exception when servicing %r", task)

    def set_thread_count(self, count):
        """Grow or shrink the worker pool to ``count`` threads.

        Growing starts new workers immediately, reusing the lowest free
        thread numbers. Shrinking only raises ``stop_count`` and wakes all
        workers; each worker exits the next time it checks ``stop_count``.
        """
        with self.lock:
            threads = self.threads
            thread_no = 0
            # Threads already scheduled to stop do not count as running.
            running = len(threads) - self.stop_count
            while running < count:
                # Start threads.
                while thread_no in threads:
                    thread_no = thread_no + 1
                threads.add(thread_no)
                running += 1
                self.start_new_thread(self.handler_thread, thread_no)
                # New workers start out counted as active; handler_thread
                # decrements the count when the worker first goes idle.
                self.active_count += 1
                thread_no = thread_no + 1
            if running > count:
                # Stop threads.
                self.stop_count += running - count
                self.queue_cv.notify_all()

    def add_task(self, task):
        """Queue ``task`` and wake one idle worker.

        Logs a warning on ``queue_logger`` when the queue holds more tasks
        than there are idle workers available to take them.
        """
        with self.lock:
            self.queue.append(task)
            self.queue_cv.notify()
            queue_size = len(self.queue)
            idle_threads = len(self.threads) - self.stop_count - self.active_count
            if queue_size > idle_threads:
                self.queue_logger.warning(
                    "Task queue depth is %d", queue_size - idle_threads
                )

    def shutdown(self, cancel_pending=True, timeout=5):
        """Stop all workers, waiting up to ``timeout`` seconds for them to exit.

        If ``cancel_pending`` is true, every task still queued is removed and
        its ``cancel()`` is called, and True is returned. Otherwise queued
        tasks are left alone and False is returned.
        """
        self.set_thread_count(0)
        # Ensure the threads shut down.
        threads = self.threads
        # Measure the deadline on a monotonic clock. Wall-clock adjustments
        # (NTP, manual changes) then cannot stretch or cut short the wait.
        expiration = time.monotonic() + timeout
        with self.lock:
            while threads:
                if time.monotonic() >= expiration:
                    self.logger.warning("%d thread(s) still running", len(threads))
                    break
                # Short waits so the deadline is rechecked even if no worker
                # exits and notifies us.
                self.thread_exit_cv.wait(0.1)
            if cancel_pending:
                # Cancel remaining tasks.
                queue = self.queue
                if len(queue) > 0:
                    self.logger.warning("Canceling %d pending task(s)", len(queue))
                while queue:
                    task = queue.popleft()
                    task.cancel()
                self.queue_cv.notify_all()
                return True
        return False
class Task:
    """Base class that writes one HTTP response for a request to a channel.

    Subclasses supply ``execute()``; :meth:`service` runs it between
    :meth:`start` and :meth:`finish`. ``complete`` must be true before
    :meth:`write` is called. It signals that the status and response headers
    have been supplied.
    """

    close_on_finish = False  # close the connection once this response is sent
    status = "200 OK"
    wrote_header = False  # status line and headers already queued
    start_time = 0  # set by start(); used for the default Date header
    content_length = None  # declared body length in bytes, or None if unknown
    content_bytes_written = 0  # body bytes counted against content_length
    logged_write_excess = False  # excess-body warning already emitted
    logged_write_no_body = False  # ignored-body warning already emitted
    complete = False  # status and headers have been supplied
    chunked_response = False  # body is framed with Transfer-Encoding: chunked
    logger = logger

    def __init__(self, channel, request):
        self.channel = channel
        self.request = request
        self.response_headers = []
        version = request.version
        if version not in ("1.0", "1.1"):
            # fall back to a version we support.
            version = "1.0"
        self.version = version

    def service(self):
        """Produce the whole response: start(), execute(), then finish().

        An OSError marks the connection to be closed. It is re-raised only
        when ``channel.adj.log_socket_errors`` is true; otherwise it is
        swallowed here.
        """
        try:
            self.start()
            self.execute()
            self.finish()
        except OSError:
            self.close_on_finish = True
            if self.channel.adj.log_socket_errors:
                raise

    @property
    def has_body(self):
        """Whether the status allows a message body (not 1xx, 204 or 304)."""
        return not self.status.startswith(("1", "204", "304"))

    def build_response_header(self):
        """Finalize ``self.response_headers`` and return the encoded header block.

        Header names are normalized to Capitalized-Dash-Form. Content-Length,
        Connection, Transfer-Encoding, Server/Via and Date are added or
        adjusted as needed. This may set ``self.close_on_finish`` and
        ``self.chunked_response`` as side effects.

        Returns the status line plus headers, CRLF-separated and terminated by
        an empty line, encoded as latin-1 bytes.

        Raises AssertionError if ``self.version`` is neither "1.0" nor "1.1".
        """
        version = self.version
        # Figure out whether the connection should be closed. The Connection
        # header is a comma-separated list of options (RFC 7230 section 6.1),
        # e.g. "keep-alive, Upgrade". Test individual tokens rather than
        # comparing the whole header value.
        connection = self.request.headers.get("CONNECTION", "").lower()
        connection_tokens = {token.strip() for token in connection.split(",")}
        client_wants_close = "close" in connection_tokens
        client_wants_keep_alive = (
            "keep-alive" in connection_tokens and not client_wants_close
        )
        response_headers = []
        content_length_header = None
        date_header = None
        server_header = None
        connection_close_header = None

        for (headername, headerval) in self.response_headers:
            # Normalize e.g. "content-type" or "CONTENT-TYPE" to "Content-Type".
            headername = "-".join([x.capitalize() for x in headername.split("-")])

            if headername == "Content-Length":
                if self.has_body:
                    content_length_header = headerval
                else:
                    # Drop Content-Length when the status forbids a body.
                    continue  # pragma: no cover

            if headername == "Date":
                date_header = headerval

            if headername == "Server":
                server_header = headerval

            if headername == "Connection":
                connection_close_header = headerval.lower()

            # replace with properly capitalized version
            response_headers.append((headername, headerval))

        if (
            content_length_header is None
            and self.content_length is not None
            and self.has_body
        ):
            content_length_header = str(self.content_length)
            response_headers.append(("Content-Length", content_length_header))

        def close_on_finish():
            # Announce the close only if the response does not already carry
            # a Connection header of its own.
            if connection_close_header is None:
                response_headers.append(("Connection", "close"))
            self.close_on_finish = True

        if version == "1.0":
            # HTTP/1.0 closes by default; keep the connection open only when
            # the client asked for it and the body length is known.
            if client_wants_keep_alive:
                if not content_length_header:
                    close_on_finish()
                else:
                    response_headers.append(("Connection", "Keep-Alive"))
            else:
                close_on_finish()

        elif version == "1.1":
            if client_wants_close:
                close_on_finish()

            if not content_length_header:
                # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
                # for any response with a status code of 1xx, 204 or 304.
                if self.has_body:
                    response_headers.append(("Transfer-Encoding", "chunked"))
                    self.chunked_response = True

                if not self.close_on_finish:
                    close_on_finish()

            # under HTTP 1.1 keep-alive is default, no need to set the header
        else:
            raise AssertionError("neither HTTP/1.0 or HTTP/1.1")

        # Set the Server and Date field, if not yet specified. This is needed
        # if the server is used as a proxy.
        ident = self.channel.server.adj.ident

        if not server_header:
            if ident:
                response_headers.append(("Server", ident))
        else:
            response_headers.append(("Via", ident or "waitress"))

        if not date_header:
            response_headers.append(("Date", build_http_date(self.start_time)))

        self.response_headers = response_headers

        first_line = f"HTTP/{self.version} {self.status}"
        # NB: sorting headers needs to preserve same-named-header order
        # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here;
        # rely on stable sort to keep relative position of same-named headers
        next_lines = [
            "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0])
        ]
        lines = [first_line] + next_lines
        res = "%s\r\n\r\n" % "\r\n".join(lines)

        return res.encode("latin-1")

    def remove_content_length_header(self):
        """Drop every Content-Length entry (any letter case) from the headers."""
        self.response_headers = [
            (header_name, header_value)
            for header_name, header_value in self.response_headers
            if header_name.lower() != "content-length"
        ]

    def start(self):
        # Recorded before execute(); the default Date header uses this time.
        self.start_time = time.time()

    def finish(self):
        """Send the headers if nothing was written, then end a chunked body."""
        if not self.wrote_header:
            self.write(b"")
        if self.chunked_response:
            # not self.write, it will chunk it!
            self.channel.write_soon(b"0\r\n\r\n")

    def write(self, data):
        """Queue ``data`` as response body, sending the headers first if needed.

        With chunked framing each call becomes one chunk. With a known
        ``content_length``, bytes beyond that length are discarded, and a
        warning is logged once. When the status forbids a body, ``data`` is
        dropped but still counted as written.

        Raises RuntimeError if the status and headers were not supplied yet.
        """
        if not self.complete:
            raise RuntimeError("start_response was not called before body written")
        channel = self.channel
        if not self.wrote_header:
            rh = self.build_response_header()
            channel.write_soon(rh)
            self.wrote_header = True

        if data and self.has_body:
            towrite = data
            cl = self.content_length
            if self.chunked_response:
                # use chunked encoding response
                towrite = hex(len(data))[2:].upper().encode("latin-1") + b"\r\n"
                towrite += data + b"\r\n"
            elif cl is not None:
                # Never send more than the advertised Content-Length.
                towrite = data[: cl - self.content_bytes_written]
                self.content_bytes_written += len(towrite)
                if towrite != data and not self.logged_write_excess:
                    self.logger.warning(
                        "application-written content exceeded the number of "
                        "bytes specified by Content-Length header (%s)" % cl
                    )
                    self.logged_write_excess = True
            if towrite:
                channel.write_soon(towrite)
        elif data:
            # Cheat, and tell the application we have written all of the bytes,
            # even though the response shouldn't have a body and we are
            # ignoring it entirely.
            self.content_bytes_written += len(data)

            if not self.logged_write_no_body:
                self.logger.warning(
                    "application-written content was ignored due to HTTP "
                    "response that may not contain a message-body: (%s)" % self.status
                )
                self.logged_write_no_body = True
class ErrorTask(Task):
    """Task that renders the request's parse error as the HTTP response.

    The status, headers and body come from ``request.error.to_response()``.
    The connection is always closed afterwards.
    """

    complete = True

    def execute(self):
        status, headers, body = self.request.error.to_response()
        self.status = status
        # Tell the client explicitly that we are closing: close_on_finish is
        # set below, so the connection is dropped right after this response.
        self.response_headers += [*headers, ("Connection", "close")]
        self.close_on_finish = True
        self.content_length = len(body)
        self.write(body)
class WSGITask(Task):
    """A WSGI task produces a response from a WSGI application."""

    environ = None  # cached result of get_environment()

    def execute(self):
        """Call the WSGI application and send its response.

        The ``start_response`` callable given to the application validates
        and records the status and headers. Two cases are handled after that:

        * If the application returns a ``ReadOnlyFileBasedBuffer`` (the
          ``wsgi.file_wrapper``) with a non-zero prepared size, it is handed
          to the channel whole.
        * Any other iterable is written chunk by chunk.

        The iterable's ``close()`` is called afterwards, unless the channel
        took ownership of it.
        """
        environ = self.get_environment()

        def start_response(status, headers, exc_info=None):
            if self.complete and not exc_info:
                raise AssertionError(
                    "start_response called a second time without providing exc_info."
                )
            if exc_info:
                try:
                    if self.wrote_header:
                        # higher levels will catch and handle raised exception:
                        # 1. "service" method in task.py
                        # 2. "service" method in channel.py
                        # 3. "handler_thread" method in task.py
                        raise exc_info[1]
                    else:
                        # As per WSGI spec existing headers must be cleared.
                        # Also forget the Content-Length parsed from them;
                        # otherwise the replacement response would advertise
                        # the discarded body length and be cut short.
                        self.response_headers = []
                        self.content_length = None
                finally:
                    # Break the reference cycle to the traceback.
                    exc_info = None

            self.complete = True

            if not status.__class__ is str:
                raise AssertionError("status %s is not a string" % status)
            if "\n" in status or "\r" in status:
                raise ValueError(
                    "carriage return/line feed character present in status"
                )

            self.status = status

            # Prepare the headers for output
            for k, v in headers:
                if not k.__class__ is str:
                    raise AssertionError(
                        f"Header name {k!r} is not a string in {(k, v)!r}"
                    )
                if not v.__class__ is str:
                    raise AssertionError(
                        f"Header value {v!r} is not a string in {(k, v)!r}"
                    )

                # Reject CR/LF to prevent response header injection.
                if "\n" in v or "\r" in v:
                    raise ValueError(
                        "carriage return/line feed character present in header value"
                    )
                if "\n" in k or "\r" in k:
                    raise ValueError(
                        "carriage return/line feed character present in header name"
                    )

                kl = k.lower()
                if kl == "content-length":
                    content_length = int(v)
                    # int() accepts a leading "-". A negative length is never
                    # valid and would make write() slice the body from its end.
                    if content_length < 0:
                        raise ValueError(
                            f"Content-Length header value {v!r} is negative"
                        )
                    self.content_length = content_length
                elif kl in hop_by_hop:
                    raise AssertionError(
                        '%s is a "hop-by-hop" header; it cannot be used by '
                        "a WSGI application (see PEP 3333)" % k
                    )

            self.response_headers.extend(headers)

            # Return a method used to write the response data.
            return self.write

        # Call the application to handle the request and write a response
        app_iter = self.channel.server.application(environ, start_response)

        can_close_app_iter = True
        try:
            if app_iter.__class__ is ReadOnlyFileBasedBuffer:
                cl = self.content_length
                size = app_iter.prepare(cl)
                if size:
                    if cl != size:
                        if cl is not None:
                            self.remove_content_length_header()
                        self.content_length = size
                    self.write(b"")  # generate headers
                    # if the write_soon below succeeds then the channel will
                    # take over closing the underlying file via the channel's
                    # _flush_some or handle_close so we intentionally avoid
                    # calling close in the finally block
                    self.channel.write_soon(app_iter)
                    can_close_app_iter = False
                    return

            first_chunk_len = None
            for chunk in app_iter:
                if first_chunk_len is None:
                    first_chunk_len = len(chunk)
                    # Set a Content-Length header if one is not supplied.
                    # start_response may not have been called until first
                    # iteration as per PEP, so we must reinterrogate
                    # self.content_length here
                    if self.content_length is None:
                        app_iter_len = None
                        if hasattr(app_iter, "__len__"):
                            app_iter_len = len(app_iter)
                        if app_iter_len == 1:
                            self.content_length = first_chunk_len
                # transmit headers only after first iteration of the iterable
                # that returns a non-empty bytestring (PEP 3333)
                if chunk:
                    self.write(chunk)

            cl = self.content_length
            if cl is not None:
                if self.content_bytes_written != cl:
                    # close the connection so the client isn't sitting around
                    # waiting for more data when there are too few bytes
                    # to service content-length
                    self.close_on_finish = True
                    if self.request.command != "HEAD":
                        self.logger.warning(
                            "application returned too few bytes (%s) "
                            "for specified Content-Length (%s) via app_iter"
                            % (self.content_bytes_written, cl),
                        )
        finally:
            if can_close_app_iter and hasattr(app_iter, "close"):
                app_iter.close()

    def get_environment(self):
        """Returns a WSGI environment.

        The dict is built on the first call and cached on ``self.environ``.
        With a configured ``url_prefix``, SCRIPT_NAME is set to the prefix.
        The prefix is stripped from PATH_INFO when the path equals it or
        starts with it followed by "/". Request headers become ``HTTP_*``
        keys, except those listed in ``rename_headers``. A header never
        overwrites a key that is already present.
        """
        environ = self.environ
        if environ is not None:
            # Return the cached copy.
            return environ

        request = self.request
        path = request.path
        channel = self.channel
        server = channel.server
        url_prefix = server.adj.url_prefix

        if path.startswith("/"):
            # strip extra slashes at the beginning of a path that starts
            # with any number of slashes
            path = "/" + path.lstrip("/")

        if url_prefix:
            # NB: url_prefix is guaranteed by the configuration machinery to
            # be either the empty string or a string that starts with a single
            # slash and ends without any slashes
            if path == url_prefix:
                # if the path is the same as the url prefix, the SCRIPT_NAME
                # should be the url_prefix and PATH_INFO should be empty
                path = ""
            else:
                # if the path starts with the url prefix plus a slash,
                # the SCRIPT_NAME should be the url_prefix and PATH_INFO should
                # the value of path from the slash until its end
                url_prefix_with_trailing_slash = url_prefix + "/"
                if path.startswith(url_prefix_with_trailing_slash):
                    path = path[len(url_prefix) :]

        environ = {
            "REMOTE_ADDR": channel.addr[0],
            # Nah, we aren't actually going to look up the reverse DNS for
            # REMOTE_ADDR, but we will happily set this environment variable
            # for the WSGI application. Spec says we can just set this to
            # REMOTE_ADDR, so we do.
            "REMOTE_HOST": channel.addr[0],
            # try and set the REMOTE_PORT to something useful, but maybe None
            "REMOTE_PORT": str(channel.addr[1]),
            "REQUEST_METHOD": request.command.upper(),
            "SERVER_PORT": str(server.effective_port),
            "SERVER_NAME": server.server_name,
            "SERVER_SOFTWARE": server.adj.ident,
            "SERVER_PROTOCOL": "HTTP/%s" % self.version,
            "SCRIPT_NAME": url_prefix,
            "PATH_INFO": path,
            "REQUEST_URI": request.request_uri,
            "QUERY_STRING": request.query,
            "wsgi.url_scheme": request.url_scheme,
            # the following environment variables are required by the WSGI spec
            "wsgi.version": (1, 0),
            # apps should use the logging module
            "wsgi.errors": sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "wsgi.input": request.get_body_stream(),
            "wsgi.file_wrapper": ReadOnlyFileBasedBuffer,
            "wsgi.input_terminated": True,  # wsgi.input is EOF terminated
        }

        for key, value in dict(request.headers).items():
            value = value.strip()
            mykey = rename_headers.get(key, None)
            if mykey is None:
                mykey = "HTTP_" + key
            # Never let a request header clobber a server-provided key.
            if mykey not in environ:
                environ[mykey] = value

        # Insert a callable into the environment that allows the application to
        # check if the client disconnected. Only works with
        # channel_request_lookahead larger than 0.
        environ["waitress.client_disconnected"] = self.channel.check_client_disconnected

        # cache the environ for this request
        self.environ = environ
        return environ