bazarr/libs/waitress/task.py
2018-09-16 20:33:04 -04:00

529 lines
19 KiB
Python

##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
import socket
import sys
import threading
import time
from waitress.buffers import ReadOnlyFileBasedBuffer
from waitress.compat import (
tobytes,
Queue,
Empty,
reraise,
)
from waitress.utilities import (
build_http_date,
logger,
)
rename_headers = { # or keep them without the HTTP_ prefix added
'CONTENT_LENGTH': 'CONTENT_LENGTH',
'CONTENT_TYPE': 'CONTENT_TYPE',
}
hop_by_hop = frozenset((
'connection',
'keep-alive',
'proxy-authenticate',
'proxy-authorization',
'te',
'trailers',
'transfer-encoding',
'upgrade'
))
class JustTesting(Exception):
pass
class ThreadedTaskDispatcher(object):
"""A Task Dispatcher that creates a thread for each task.
"""
stop_count = 0 # Number of threads that will stop soon.
logger = logger
def __init__(self):
self.threads = {} # { thread number -> 1 }
self.queue = Queue()
self.thread_mgmt_lock = threading.Lock()
def start_new_thread(self, target, args):
t = threading.Thread(target=target, name='waitress', args=args)
t.daemon = True
t.start()
def handler_thread(self, thread_no):
threads = self.threads
try:
while threads.get(thread_no):
task = self.queue.get()
if task is None:
# Special value: kill this thread.
break
try:
task.service()
except Exception as e:
self.logger.exception(
'Exception when servicing %r' % task)
if isinstance(e, JustTesting):
break
finally:
with self.thread_mgmt_lock:
self.stop_count -= 1
threads.pop(thread_no, None)
def set_thread_count(self, count):
with self.thread_mgmt_lock:
threads = self.threads
thread_no = 0
running = len(threads) - self.stop_count
while running < count:
# Start threads.
while thread_no in threads:
thread_no = thread_no + 1
threads[thread_no] = 1
running += 1
self.start_new_thread(self.handler_thread, (thread_no,))
thread_no = thread_no + 1
if running > count:
# Stop threads.
to_stop = running - count
self.stop_count += to_stop
for n in range(to_stop):
self.queue.put(None)
running -= 1
def add_task(self, task):
try:
task.defer()
self.queue.put(task)
except:
task.cancel()
raise
def shutdown(self, cancel_pending=True, timeout=5):
self.set_thread_count(0)
# Ensure the threads shut down.
threads = self.threads
expiration = time.time() + timeout
while threads:
if time.time() >= expiration:
self.logger.warning(
"%d thread(s) still running" %
len(threads))
break
time.sleep(0.1)
if cancel_pending:
# Cancel remaining tasks.
try:
queue = self.queue
while not queue.empty():
task = queue.get()
if task is not None:
task.cancel()
except Empty: # pragma: no cover
pass
return True
return False
class Task(object):
close_on_finish = False
status = '200 OK'
wrote_header = False
start_time = 0
content_length = None
content_bytes_written = 0
logged_write_excess = False
complete = False
chunked_response = False
logger = logger
def __init__(self, channel, request):
self.channel = channel
self.request = request
self.response_headers = []
version = request.version
if version not in ('1.0', '1.1'):
# fall back to a version we support.
version = '1.0'
self.version = version
def service(self):
try:
try:
self.start()
self.execute()
self.finish()
except socket.error:
self.close_on_finish = True
if self.channel.adj.log_socket_errors:
raise
finally:
pass
def cancel(self):
self.close_on_finish = True
def defer(self):
pass
def build_response_header(self):
version = self.version
# Figure out whether the connection should be closed.
connection = self.request.headers.get('CONNECTION', '').lower()
response_headers = self.response_headers
content_length_header = None
date_header = None
server_header = None
connection_close_header = None
for i, (headername, headerval) in enumerate(response_headers):
headername = '-'.join(
[x.capitalize() for x in headername.split('-')]
)
if headername == 'Content-Length':
content_length_header = headerval
if headername == 'Date':
date_header = headerval
if headername == 'Server':
server_header = headerval
if headername == 'Connection':
connection_close_header = headerval.lower()
# replace with properly capitalized version
response_headers[i] = (headername, headerval)
if content_length_header is None and self.content_length is not None:
content_length_header = str(self.content_length)
self.response_headers.append(
('Content-Length', content_length_header)
)
def close_on_finish():
if connection_close_header is None:
response_headers.append(('Connection', 'close'))
self.close_on_finish = True
if version == '1.0':
if connection == 'keep-alive':
if not content_length_header:
close_on_finish()
else:
response_headers.append(('Connection', 'Keep-Alive'))
else:
close_on_finish()
elif version == '1.1':
if connection == 'close':
close_on_finish()
if not content_length_header:
response_headers.append(('Transfer-Encoding', 'chunked'))
self.chunked_response = True
if not self.close_on_finish:
close_on_finish()
# under HTTP 1.1 keep-alive is default, no need to set the header
else:
raise AssertionError('neither HTTP/1.0 or HTTP/1.1')
# Set the Server and Date field, if not yet specified. This is needed
# if the server is used as a proxy.
ident = self.channel.server.adj.ident
if not server_header:
response_headers.append(('Server', ident))
else:
response_headers.append(('Via', ident))
if not date_header:
response_headers.append(('Date', build_http_date(self.start_time)))
first_line = 'HTTP/%s %s' % (self.version, self.status)
# NB: sorting headers needs to preserve same-named-header order
# as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here;
# rely on stable sort to keep relative position of same-named headers
next_lines = ['%s: %s' % hv for hv in sorted(
self.response_headers, key=lambda x: x[0])]
lines = [first_line] + next_lines
res = '%s\r\n\r\n' % '\r\n'.join(lines)
return tobytes(res)
def remove_content_length_header(self):
for i, (header_name, header_value) in enumerate(self.response_headers):
if header_name.lower() == 'content-length':
del self.response_headers[i]
def start(self):
self.start_time = time.time()
def finish(self):
if not self.wrote_header:
self.write(b'')
if self.chunked_response:
# not self.write, it will chunk it!
self.channel.write_soon(b'0\r\n\r\n')
def write(self, data):
if not self.complete:
raise RuntimeError('start_response was not called before body '
'written')
channel = self.channel
if not self.wrote_header:
rh = self.build_response_header()
channel.write_soon(rh)
self.wrote_header = True
if data:
towrite = data
cl = self.content_length
if self.chunked_response:
# use chunked encoding response
towrite = tobytes(hex(len(data))[2:].upper()) + b'\r\n'
towrite += data + b'\r\n'
elif cl is not None:
towrite = data[:cl - self.content_bytes_written]
self.content_bytes_written += len(towrite)
if towrite != data and not self.logged_write_excess:
self.logger.warning(
'application-written content exceeded the number of '
'bytes specified by Content-Length header (%s)' % cl)
self.logged_write_excess = True
if towrite:
channel.write_soon(towrite)
class ErrorTask(Task):
""" An error task produces an error response
"""
complete = True
def execute(self):
e = self.request.error
body = '%s\r\n\r\n%s' % (e.reason, e.body)
tag = '\r\n\r\n(generated by waitress)'
body = body + tag
self.status = '%s %s' % (e.code, e.reason)
cl = len(body)
self.content_length = cl
self.response_headers.append(('Content-Length', str(cl)))
self.response_headers.append(('Content-Type', 'text/plain'))
if self.version == '1.1':
connection = self.request.headers.get('CONNECTION', '').lower()
if connection == 'close':
self.response_headers.append(('Connection', 'close'))
# under HTTP 1.1 keep-alive is default, no need to set the header
else:
# HTTP 1.0
self.response_headers.append(('Connection', 'close'))
self.close_on_finish = True
self.write(tobytes(body))
class WSGITask(Task):
"""A WSGI task produces a response from a WSGI application.
"""
environ = None
def execute(self):
env = self.get_environment()
def start_response(status, headers, exc_info=None):
if self.complete and not exc_info:
raise AssertionError("start_response called a second time "
"without providing exc_info.")
if exc_info:
try:
if self.wrote_header:
# higher levels will catch and handle raised exception:
# 1. "service" method in task.py
# 2. "service" method in channel.py
# 3. "handler_thread" method in task.py
reraise(exc_info[0], exc_info[1], exc_info[2])
else:
# As per WSGI spec existing headers must be cleared
self.response_headers = []
finally:
exc_info = None
self.complete = True
if not status.__class__ is str:
raise AssertionError('status %s is not a string' % status)
if '\n' in status or '\r' in status:
raise ValueError("carriage return/line "
"feed character present in status")
self.status = status
# Prepare the headers for output
for k, v in headers:
if not k.__class__ is str:
raise AssertionError(
'Header name %r is not a string in %r' % (k, (k, v))
)
if not v.__class__ is str:
raise AssertionError(
'Header value %r is not a string in %r' % (v, (k, v))
)
if '\n' in v or '\r' in v:
raise ValueError("carriage return/line "
"feed character present in header value")
if '\n' in k or '\r' in k:
raise ValueError("carriage return/line "
"feed character present in header name")
kl = k.lower()
if kl == 'content-length':
self.content_length = int(v)
elif kl in hop_by_hop:
raise AssertionError(
'%s is a "hop-by-hop" header; it cannot be used by '
'a WSGI application (see PEP 3333)' % k)
self.response_headers.extend(headers)
# Return a method used to write the response data.
return self.write
# Call the application to handle the request and write a response
app_iter = self.channel.server.application(env, start_response)
if app_iter.__class__ is ReadOnlyFileBasedBuffer:
# NB: do not put this inside the below try: finally: which closes
# the app_iter; we need to defer closing the underlying file. It's
# intention that we don't want to call ``close`` here if the
# app_iter is a ROFBB; the buffer (and therefore the file) will
# eventually be closed within channel.py's _flush_some or
# handle_close instead.
cl = self.content_length
size = app_iter.prepare(cl)
if size:
if cl != size:
if cl is not None:
self.remove_content_length_header()
self.content_length = size
self.write(b'') # generate headers
self.channel.write_soon(app_iter)
return
try:
first_chunk_len = None
for chunk in app_iter:
if first_chunk_len is None:
first_chunk_len = len(chunk)
# Set a Content-Length header if one is not supplied.
# start_response may not have been called until first
# iteration as per PEP, so we must reinterrogate
# self.content_length here
if self.content_length is None:
app_iter_len = None
if hasattr(app_iter, '__len__'):
app_iter_len = len(app_iter)
if app_iter_len == 1:
self.content_length = first_chunk_len
# transmit headers only after first iteration of the iterable
# that returns a non-empty bytestring (PEP 3333)
if chunk:
self.write(chunk)
cl = self.content_length
if cl is not None:
if self.content_bytes_written != cl:
# close the connection so the client isn't sitting around
# waiting for more data when there are too few bytes
# to service content-length
self.close_on_finish = True
if self.request.command != 'HEAD':
self.logger.warning(
'application returned too few bytes (%s) '
'for specified Content-Length (%s) via app_iter' % (
self.content_bytes_written, cl),
)
finally:
if hasattr(app_iter, 'close'):
app_iter.close()
def get_environment(self):
"""Returns a WSGI environment."""
environ = self.environ
if environ is not None:
# Return the cached copy.
return environ
request = self.request
path = request.path
channel = self.channel
server = channel.server
url_prefix = server.adj.url_prefix
if path.startswith('/'):
# strip extra slashes at the beginning of a path that starts
# with any number of slashes
path = '/' + path.lstrip('/')
if url_prefix:
# NB: url_prefix is guaranteed by the configuration machinery to
# be either the empty string or a string that starts with a single
# slash and ends without any slashes
if path == url_prefix:
# if the path is the same as the url prefix, the SCRIPT_NAME
# should be the url_prefix and PATH_INFO should be empty
path = ''
else:
# if the path starts with the url prefix plus a slash,
# the SCRIPT_NAME should be the url_prefix and PATH_INFO should
# the value of path from the slash until its end
url_prefix_with_trailing_slash = url_prefix + '/'
if path.startswith(url_prefix_with_trailing_slash):
path = path[len(url_prefix):]
environ = {}
environ['REQUEST_METHOD'] = request.command.upper()
environ['SERVER_PORT'] = str(server.effective_port)
environ['SERVER_NAME'] = server.server_name
environ['SERVER_SOFTWARE'] = server.adj.ident
environ['SERVER_PROTOCOL'] = 'HTTP/%s' % self.version
environ['SCRIPT_NAME'] = url_prefix
environ['PATH_INFO'] = path
environ['QUERY_STRING'] = request.query
host = environ['REMOTE_ADDR'] = channel.addr[0]
headers = dict(request.headers)
if host == server.adj.trusted_proxy:
wsgi_url_scheme = headers.pop('X_FORWARDED_PROTO',
request.url_scheme)
else:
wsgi_url_scheme = request.url_scheme
if wsgi_url_scheme not in ('http', 'https'):
raise ValueError('Invalid X_FORWARDED_PROTO value')
for key, value in headers.items():
value = value.strip()
mykey = rename_headers.get(key, None)
if mykey is None:
mykey = 'HTTP_%s' % key
if mykey not in environ:
environ[mykey] = value
# the following environment variables are required by the WSGI spec
environ['wsgi.version'] = (1, 0)
environ['wsgi.url_scheme'] = wsgi_url_scheme
environ['wsgi.errors'] = sys.stderr # apps should use the logging module
environ['wsgi.multithread'] = True
environ['wsgi.multiprocess'] = False
environ['wsgi.run_once'] = False
environ['wsgi.input'] = request.get_body_stream()
environ['wsgi.file_wrapper'] = ReadOnlyFileBasedBuffer
self.environ = environ
return environ