This commit is contained in:
Louis Vézina 2020-01-29 20:07:26 -05:00
parent 95b8aadb23
commit 83c95cc77d
81 changed files with 24572 additions and 0 deletions

616
libs/bs4/__init__.py Normal file
View file

@ -0,0 +1,616 @@
"""Beautiful Soup
Elixir and Tonic
"The Screen-Scraper's Friend"
http://www.crummy.com/software/BeautifulSoup/
Beautiful Soup uses a pluggable XML or HTML parser to parse a
(possibly invalid) document into a tree representation. Beautiful Soup
provides methods and Pythonic idioms that make it easy to navigate,
search, and modify the parse tree.
Beautiful Soup works with Python 2.7 and up. It works better if lxml
and/or html5lib is installed.
For more than you ever wanted to know about Beautiful Soup, see the
documentation:
http://www.crummy.com/software/BeautifulSoup/bs4/doc/
"""
__author__ = "Leonard Richardson (leonardr@segfault.org)"
__version__ = "4.8.0"
__copyright__ = "Copyright (c) 2004-2019 Leonard Richardson"
# Use of this source code is governed by the MIT license.
__license__ = "MIT"
__all__ = ['BeautifulSoup']
import os
import re
import sys
import traceback
import warnings
from .builder import builder_registry, ParserRejectedMarkup
from .dammit import UnicodeDammit
from .element import (
CData,
Comment,
DEFAULT_OUTPUT_ENCODING,
Declaration,
Doctype,
NavigableString,
PageElement,
ProcessingInstruction,
ResultSet,
SoupStrainer,
Tag,
)
# The very first thing we do is give a useful error if someone is
# running this code under Python 3 without converting it.
'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'!='You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).'
class BeautifulSoup(Tag):
"""
This class defines the basic interface called by the tree builders.
These methods will be called by the parser:
reset()
feed(markup)
The tree builder may call these methods from its feed() implementation:
handle_starttag(name, attrs) # See note about return value
handle_endtag(name)
handle_data(data) # Appends to the current data node
endData(containerClass=NavigableString) # Ends the current data node
No matter how complicated the underlying parser is, you should be
able to build a tree using 'start tag' events, 'end tag' events,
'data' events, and "done with data" events.
If you encounter an empty-element tag (aka a self-closing tag,
like HTML's <br> tag), call handle_starttag and then
handle_endtag.
"""
ROOT_TAG_NAME = '[document]'
# If the end-user gives no indication which tree builder they
# want, look for one with these features.
DEFAULT_BUILDER_FEATURES = ['html', 'fast']
ASCII_SPACES = '\x20\x0a\x09\x0c\x0d'
NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nThe code that caused this warning is on line %(line_number)s of the file %(filename)s. To get rid of this warning, pass the additional argument 'features=\"%(parser)s\"' to the BeautifulSoup constructor.\n"
def __init__(self, markup="", features=None, builder=None,
parse_only=None, from_encoding=None, exclude_encodings=None,
**kwargs):
"""Constructor.
:param markup: A string or a file-like object representing
markup to be parsed.
:param features: Desirable features of the parser to be used. This
may be the name of a specific parser ("lxml", "lxml-xml",
"html.parser", or "html5lib") or it may be the type of markup
to be used ("html", "html5", "xml"). It's recommended that you
name a specific parser, so that Beautiful Soup gives you the
same results across platforms and virtual environments.
:param builder: A TreeBuilder subclass to instantiate (or
instance to use) instead of looking one up based on
`features`. You only need to use this if you've implemented a
custom TreeBuilder.
:param parse_only: A SoupStrainer. Only parts of the document
matching the SoupStrainer will be considered. This is useful
when parsing part of a document that would otherwise be too
large to fit into memory.
:param from_encoding: A string indicating the encoding of the
document to be parsed. Pass this in if Beautiful Soup is
guessing wrongly about the document's encoding.
:param exclude_encodings: A list of strings indicating
encodings known to be wrong. Pass this in if you don't know
the document's encoding but you know Beautiful Soup's guess is
wrong.
:param kwargs: For backwards compatibility purposes, the
constructor accepts certain keyword arguments used in
Beautiful Soup 3. None of these arguments do anything in
Beautiful Soup 4; they will result in a warning and then be ignored.
Apart from this, any keyword arguments passed into the BeautifulSoup
constructor are propagated to the TreeBuilder constructor. This
makes it possible to configure a TreeBuilder beyond saying
which one to use.
"""
if 'convertEntities' in kwargs:
del kwargs['convertEntities']
warnings.warn(
"BS4 does not respect the convertEntities argument to the "
"BeautifulSoup constructor. Entities are always converted "
"to Unicode characters.")
if 'markupMassage' in kwargs:
del kwargs['markupMassage']
warnings.warn(
"BS4 does not respect the markupMassage argument to the "
"BeautifulSoup constructor. The tree builder is responsible "
"for any necessary markup massage.")
if 'smartQuotesTo' in kwargs:
del kwargs['smartQuotesTo']
warnings.warn(
"BS4 does not respect the smartQuotesTo argument to the "
"BeautifulSoup constructor. Smart quotes are always converted "
"to Unicode characters.")
if 'selfClosingTags' in kwargs:
del kwargs['selfClosingTags']
warnings.warn(
"BS4 does not respect the selfClosingTags argument to the "
"BeautifulSoup constructor. The tree builder is responsible "
"for understanding self-closing tags.")
if 'isHTML' in kwargs:
del kwargs['isHTML']
warnings.warn(
"BS4 does not respect the isHTML argument to the "
"BeautifulSoup constructor. Suggest you use "
"features='lxml' for HTML and features='lxml-xml' for "
"XML.")
def deprecated_argument(old_name, new_name):
if old_name in kwargs:
warnings.warn(
'The "%s" argument to the BeautifulSoup constructor '
'has been renamed to "%s."' % (old_name, new_name))
value = kwargs[old_name]
del kwargs[old_name]
return value
return None
parse_only = parse_only or deprecated_argument(
"parseOnlyThese", "parse_only")
from_encoding = from_encoding or deprecated_argument(
"fromEncoding", "from_encoding")
if from_encoding and isinstance(markup, str):
warnings.warn("You provided Unicode markup but also provided a value for from_encoding. Your from_encoding will be ignored.")
from_encoding = None
# We need this information to track whether or not the builder
# was specified well enough that we can omit the 'you need to
# specify a parser' warning.
original_builder = builder
original_features = features
if isinstance(builder, type):
# A builder class was passed in; it needs to be instantiated.
builder_class = builder
builder = None
elif builder is None:
if isinstance(features, str):
features = [features]
if features is None or len(features) == 0:
features = self.DEFAULT_BUILDER_FEATURES
builder_class = builder_registry.lookup(*features)
if builder_class is None:
raise FeatureNotFound(
"Couldn't find a tree builder with the features you "
"requested: %s. Do you need to install a parser library?"
% ",".join(features))
# At this point either we have a TreeBuilder instance in
# builder, or we have a builder_class that we can instantiate
# with the remaining **kwargs.
if builder is None:
builder = builder_class(**kwargs)
if not original_builder and not (
original_features == builder.NAME or
original_features in builder.ALTERNATE_NAMES
):
if builder.is_xml:
markup_type = "XML"
else:
markup_type = "HTML"
# This code adapted from warnings.py so that we get the same line
# of code as our warnings.warn() call gets, even if the answer is wrong
# (as it may be in a multithreading situation).
caller = None
try:
caller = sys._getframe(1)
except ValueError:
pass
if caller:
globals = caller.f_globals
line_number = caller.f_lineno
else:
globals = sys.__dict__
line_number= 1
filename = globals.get('__file__')
if filename:
fnl = filename.lower()
if fnl.endswith((".pyc", ".pyo")):
filename = filename[:-1]
if filename:
# If there is no filename at all, the user is most likely in a REPL,
# and the warning is not necessary.
values = dict(
filename=filename,
line_number=line_number,
parser=builder.NAME,
markup_type=markup_type
)
warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % values, stacklevel=2)
else:
if kwargs:
warnings.warn("Keyword arguments to the BeautifulSoup constructor will be ignored. These would normally be passed into the TreeBuilder constructor, but a TreeBuilder instance was passed in as `builder`.")
self.builder = builder
self.is_xml = builder.is_xml
self.known_xml = self.is_xml
self._namespaces = dict()
self.parse_only = parse_only
self.builder.initialize_soup(self)
if hasattr(markup, 'read'): # It's a file-type object.
markup = markup.read()
elif len(markup) <= 256 and (
(isinstance(markup, bytes) and not b'<' in markup)
or (isinstance(markup, str) and not '<' in markup)
):
# Print out warnings for a couple beginner problems
# involving passing non-markup to Beautiful Soup.
# Beautiful Soup will still parse the input as markup,
# just in case that's what the user really wants.
if (isinstance(markup, str)
and not os.path.supports_unicode_filenames):
possible_filename = markup.encode("utf8")
else:
possible_filename = markup
is_file = False
try:
is_file = os.path.exists(possible_filename)
except Exception as e:
# This is almost certainly a problem involving
# characters not valid in filenames on this
# system. Just let it go.
pass
if is_file:
if isinstance(markup, str):
markup = markup.encode("utf8")
warnings.warn(
'"%s" looks like a filename, not markup. You should'
' probably open this file and pass the filehandle into'
' Beautiful Soup.' % markup)
self._check_markup_is_url(markup)
for (self.markup, self.original_encoding, self.declared_html_encoding,
self.contains_replacement_characters) in (
self.builder.prepare_markup(
markup, from_encoding, exclude_encodings=exclude_encodings)):
self.reset()
try:
self._feed()
break
except ParserRejectedMarkup:
pass
# Clear out the markup and remove the builder's circular
# reference to this object.
self.markup = None
self.builder.soup = None
def __copy__(self):
copy = type(self)(
self.encode('utf-8'), builder=self.builder, from_encoding='utf-8'
)
# Although we encoded the tree to UTF-8, that may not have
# been the encoding of the original markup. Set the copy's
# .original_encoding to reflect the original object's
# .original_encoding.
copy.original_encoding = self.original_encoding
return copy
def __getstate__(self):
# Frequently a tree builder can't be pickled.
d = dict(self.__dict__)
if 'builder' in d and not self.builder.picklable:
d['builder'] = None
return d
@staticmethod
def _check_markup_is_url(markup):
"""
Check if markup looks like it's actually a url and raise a warning
if so. Markup can be unicode or str (py2) / bytes (py3).
"""
if isinstance(markup, bytes):
space = b' '
cant_start_with = (b"http:", b"https:")
elif isinstance(markup, str):
space = ' '
cant_start_with = ("http:", "https:")
else:
return
if any(markup.startswith(prefix) for prefix in cant_start_with):
if not space in markup:
if isinstance(markup, bytes):
decoded_markup = markup.decode('utf-8', 'replace')
else:
decoded_markup = markup
warnings.warn(
'"%s" looks like a URL. Beautiful Soup is not an'
' HTTP client. You should probably use an HTTP client like'
' requests to get the document behind the URL, and feed'
' that document to Beautiful Soup.' % decoded_markup
)
def _feed(self):
# Convert the document to Unicode.
self.builder.reset()
self.builder.feed(self.markup)
# Close out any unfinished strings and close all the open tags.
self.endData()
while self.currentTag.name != self.ROOT_TAG_NAME:
self.popTag()
def reset(self):
Tag.__init__(self, self, self.builder, self.ROOT_TAG_NAME)
self.hidden = 1
self.builder.reset()
self.current_data = []
self.currentTag = None
self.tagStack = []
self.preserve_whitespace_tag_stack = []
self.pushTag(self)
def new_tag(self, name, namespace=None, nsprefix=None, attrs={}, **kwattrs):
"""Create a new tag associated with this soup."""
kwattrs.update(attrs)
return Tag(None, self.builder, name, namespace, nsprefix, kwattrs)
def new_string(self, s, subclass=NavigableString):
"""Create a new NavigableString associated with this soup."""
return subclass(s)
def insert_before(self, successor):
raise NotImplementedError("BeautifulSoup objects don't support insert_before().")
def insert_after(self, successor):
raise NotImplementedError("BeautifulSoup objects don't support insert_after().")
def popTag(self):
tag = self.tagStack.pop()
if self.preserve_whitespace_tag_stack and tag == self.preserve_whitespace_tag_stack[-1]:
self.preserve_whitespace_tag_stack.pop()
#print "Pop", tag.name
if self.tagStack:
self.currentTag = self.tagStack[-1]
return self.currentTag
def pushTag(self, tag):
#print "Push", tag.name
if self.currentTag is not None:
self.currentTag.contents.append(tag)
self.tagStack.append(tag)
self.currentTag = self.tagStack[-1]
if tag.name in self.builder.preserve_whitespace_tags:
self.preserve_whitespace_tag_stack.append(tag)
def endData(self, containerClass=NavigableString):
if self.current_data:
current_data = ''.join(self.current_data)
# If whitespace is not preserved, and this string contains
# nothing but ASCII spaces, replace it with a single space
# or newline.
if not self.preserve_whitespace_tag_stack:
strippable = True
for i in current_data:
if i not in self.ASCII_SPACES:
strippable = False
break
if strippable:
if '\n' in current_data:
current_data = '\n'
else:
current_data = ' '
# Reset the data collector.
self.current_data = []
# Should we add this string to the tree at all?
if self.parse_only and len(self.tagStack) <= 1 and \
(not self.parse_only.text or \
not self.parse_only.search(current_data)):
return
o = containerClass(current_data)
self.object_was_parsed(o)
def object_was_parsed(self, o, parent=None, most_recent_element=None):
"""Add an object to the parse tree."""
if parent is None:
parent = self.currentTag
if most_recent_element is not None:
previous_element = most_recent_element
else:
previous_element = self._most_recent_element
next_element = previous_sibling = next_sibling = None
if isinstance(o, Tag):
next_element = o.next_element
next_sibling = o.next_sibling
previous_sibling = o.previous_sibling
if previous_element is None:
previous_element = o.previous_element
fix = parent.next_element is not None
o.setup(parent, previous_element, next_element, previous_sibling, next_sibling)
self._most_recent_element = o
parent.contents.append(o)
# Check if we are inserting into an already parsed node.
if fix:
self._linkage_fixer(parent)
def _linkage_fixer(self, el):
"""Make sure linkage of this fragment is sound."""
first = el.contents[0]
child = el.contents[-1]
descendant = child
if child is first and el.parent is not None:
# Parent should be linked to first child
el.next_element = child
# We are no longer linked to whatever this element is
prev_el = child.previous_element
if prev_el is not None and prev_el is not el:
prev_el.next_element = None
# First child should be linked to the parent, and no previous siblings.
child.previous_element = el
child.previous_sibling = None
# We have no sibling as we've been appended as the last.
child.next_sibling = None
# This index is a tag, dig deeper for a "last descendant"
if isinstance(child, Tag) and child.contents:
descendant = child._last_descendant(False)
# As the final step, link last descendant. It should be linked
# to the parent's next sibling (if found), else walk up the chain
# and find a parent with a sibling. It should have no next sibling.
descendant.next_element = None
descendant.next_sibling = None
target = el
while True:
if target is None:
break
elif target.next_sibling is not None:
descendant.next_element = target.next_sibling
target.next_sibling.previous_element = child
break
target = target.parent
def _popToTag(self, name, nsprefix=None, inclusivePop=True):
"""Pops the tag stack up to and including the most recent
instance of the given tag. If inclusivePop is false, pops the tag
stack up to but *not* including the most recent instqance of
the given tag."""
#print "Popping to %s" % name
if name == self.ROOT_TAG_NAME:
# The BeautifulSoup object itself can never be popped.
return
most_recently_popped = None
stack_size = len(self.tagStack)
for i in range(stack_size - 1, 0, -1):
t = self.tagStack[i]
if (name == t.name and nsprefix == t.prefix):
if inclusivePop:
most_recently_popped = self.popTag()
break
most_recently_popped = self.popTag()
return most_recently_popped
def handle_starttag(self, name, namespace, nsprefix, attrs):
"""Push a start tag on to the stack.
If this method returns None, the tag was rejected by the
SoupStrainer. You should proceed as if the tag had not occurred
in the document. For instance, if this was a self-closing tag,
don't call handle_endtag.
"""
# print "Start tag %s: %s" % (name, attrs)
self.endData()
if (self.parse_only and len(self.tagStack) <= 1
and (self.parse_only.text
or not self.parse_only.search_tag(name, attrs))):
return None
tag = Tag(self, self.builder, name, namespace, nsprefix, attrs,
self.currentTag, self._most_recent_element)
if tag is None:
return tag
if self._most_recent_element is not None:
self._most_recent_element.next_element = tag
self._most_recent_element = tag
self.pushTag(tag)
return tag
def handle_endtag(self, name, nsprefix=None):
#print "End tag: " + name
self.endData()
self._popToTag(name, nsprefix)
def handle_data(self, data):
self.current_data.append(data)
def decode(self, pretty_print=False,
eventual_encoding=DEFAULT_OUTPUT_ENCODING,
formatter="minimal"):
"""Returns a string or Unicode representation of this document.
To get Unicode, pass None for encoding."""
if self.is_xml:
# Print the XML declaration
encoding_part = ''
if eventual_encoding != None:
encoding_part = ' encoding="%s"' % eventual_encoding
prefix = '<?xml version="1.0"%s?>\n' % encoding_part
else:
prefix = ''
if not pretty_print:
indent_level = None
else:
indent_level = 0
return prefix + super(BeautifulSoup, self).decode(
indent_level, eventual_encoding, formatter)
# Alias to make it easier to type import: 'from bs4 import _soup'
_s = BeautifulSoup
_soup = BeautifulSoup
class BeautifulStoneSoup(BeautifulSoup):
"""Deprecated interface to an XML parser."""
def __init__(self, *args, **kwargs):
kwargs['features'] = 'xml'
warnings.warn(
'The BeautifulStoneSoup class is deprecated. Instead of using '
'it, pass features="xml" into the BeautifulSoup constructor.')
super(BeautifulStoneSoup, self).__init__(*args, **kwargs)
class StopParsing(Exception):
pass
class FeatureNotFound(ValueError):
pass
#By default, act as an HTML pretty-printer.
if __name__ == '__main__':
import sys
soup = BeautifulSoup(sys.stdin)
print(soup.prettify())

View file

@ -0,0 +1,367 @@
# Use of this source code is governed by the MIT license.
__license__ = "MIT"
from collections import defaultdict
import itertools
import sys
from bs4.element import (
CharsetMetaAttributeValue,
ContentMetaAttributeValue,
nonwhitespace_re
)
__all__ = [
'HTMLTreeBuilder',
'SAXTreeBuilder',
'TreeBuilder',
'TreeBuilderRegistry',
]
# Some useful features for a TreeBuilder to have.
FAST = 'fast'
PERMISSIVE = 'permissive'
STRICT = 'strict'
XML = 'xml'
HTML = 'html'
HTML_5 = 'html5'
class TreeBuilderRegistry(object):
def __init__(self):
self.builders_for_feature = defaultdict(list)
self.builders = []
def register(self, treebuilder_class):
"""Register a treebuilder based on its advertised features."""
for feature in treebuilder_class.features:
self.builders_for_feature[feature].insert(0, treebuilder_class)
self.builders.insert(0, treebuilder_class)
def lookup(self, *features):
if len(self.builders) == 0:
# There are no builders at all.
return None
if len(features) == 0:
# They didn't ask for any features. Give them the most
# recently registered builder.
return self.builders[0]
# Go down the list of features in order, and eliminate any builders
# that don't match every feature.
features = list(features)
features.reverse()
candidates = None
candidate_set = None
while len(features) > 0:
feature = features.pop()
we_have_the_feature = self.builders_for_feature.get(feature, [])
if len(we_have_the_feature) > 0:
if candidates is None:
candidates = we_have_the_feature
candidate_set = set(candidates)
else:
# Eliminate any candidates that don't have this feature.
candidate_set = candidate_set.intersection(
set(we_have_the_feature))
# The only valid candidates are the ones in candidate_set.
# Go through the original list of candidates and pick the first one
# that's in candidate_set.
if candidate_set is None:
return None
for candidate in candidates:
if candidate in candidate_set:
return candidate
return None
# The BeautifulSoup class will take feature lists from developers and use them
# to look up builders in this registry.
builder_registry = TreeBuilderRegistry()
class TreeBuilder(object):
"""Turn a document into a Beautiful Soup object tree."""
NAME = "[Unknown tree builder]"
ALTERNATE_NAMES = []
features = []
is_xml = False
picklable = False
empty_element_tags = None # A tag will be considered an empty-element
# tag when and only when it has no contents.
# A value for these tag/attribute combinations is a space- or
# comma-separated list of CDATA, rather than a single CDATA.
DEFAULT_CDATA_LIST_ATTRIBUTES = {}
DEFAULT_PRESERVE_WHITESPACE_TAGS = set()
USE_DEFAULT = object()
def __init__(self, multi_valued_attributes=USE_DEFAULT, preserve_whitespace_tags=USE_DEFAULT):
"""Constructor.
:param multi_valued_attributes: If this is set to None, the
TreeBuilder will not turn any values for attributes like
'class' into lists. Setting this do a dictionary will
customize this behavior; look at DEFAULT_CDATA_LIST_ATTRIBUTES
for an example.
Internally, these are called "CDATA list attributes", but that
probably doesn't make sense to an end-user, so the argument name
is `multi_valued_attributes`.
:param preserve_whitespace_tags:
"""
self.soup = None
if multi_valued_attributes is self.USE_DEFAULT:
multi_valued_attributes = self.DEFAULT_CDATA_LIST_ATTRIBUTES
self.cdata_list_attributes = multi_valued_attributes
if preserve_whitespace_tags is self.USE_DEFAULT:
preserve_whitespace_tags = self.DEFAULT_PRESERVE_WHITESPACE_TAGS
self.preserve_whitespace_tags = preserve_whitespace_tags
def initialize_soup(self, soup):
"""The BeautifulSoup object has been initialized and is now
being associated with the TreeBuilder.
"""
self.soup = soup
def reset(self):
pass
def can_be_empty_element(self, tag_name):
"""Might a tag with this name be an empty-element tag?
The final markup may or may not actually present this tag as
self-closing.
For instance: an HTMLBuilder does not consider a <p> tag to be
an empty-element tag (it's not in
HTMLBuilder.empty_element_tags). This means an empty <p> tag
will be presented as "<p></p>", not "<p />".
The default implementation has no opinion about which tags are
empty-element tags, so a tag will be presented as an
empty-element tag if and only if it has no contents.
"<foo></foo>" will become "<foo />", and "<foo>bar</foo>" will
be left alone.
"""
if self.empty_element_tags is None:
return True
return tag_name in self.empty_element_tags
def feed(self, markup):
raise NotImplementedError()
def prepare_markup(self, markup, user_specified_encoding=None,
document_declared_encoding=None):
return markup, None, None, False
def test_fragment_to_document(self, fragment):
"""Wrap an HTML fragment to make it look like a document.
Different parsers do this differently. For instance, lxml
introduces an empty <head> tag, and html5lib
doesn't. Abstracting this away lets us write simple tests
which run HTML fragments through the parser and compare the
results against other HTML fragments.
This method should not be used outside of tests.
"""
return fragment
def set_up_substitutions(self, tag):
return False
def _replace_cdata_list_attribute_values(self, tag_name, attrs):
"""Replaces class="foo bar" with class=["foo", "bar"]
Modifies its input in place.
"""
if not attrs:
return attrs
if self.cdata_list_attributes:
universal = self.cdata_list_attributes.get('*', [])
tag_specific = self.cdata_list_attributes.get(
tag_name.lower(), None)
for attr in list(attrs.keys()):
if attr in universal or (tag_specific and attr in tag_specific):
# We have a "class"-type attribute whose string
# value is a whitespace-separated list of
# values. Split it into a list.
value = attrs[attr]
if isinstance(value, str):
values = nonwhitespace_re.findall(value)
else:
# html5lib sometimes calls setAttributes twice
# for the same tag when rearranging the parse
# tree. On the second call the attribute value
# here is already a list. If this happens,
# leave the value alone rather than trying to
# split it again.
values = value
attrs[attr] = values
return attrs
class SAXTreeBuilder(TreeBuilder):
"""A Beautiful Soup treebuilder that listens for SAX events."""
def feed(self, markup):
raise NotImplementedError()
def close(self):
pass
def startElement(self, name, attrs):
attrs = dict((key[1], value) for key, value in list(attrs.items()))
#print "Start %s, %r" % (name, attrs)
self.soup.handle_starttag(name, attrs)
def endElement(self, name):
#print "End %s" % name
self.soup.handle_endtag(name)
def startElementNS(self, nsTuple, nodeName, attrs):
# Throw away (ns, nodeName) for now.
self.startElement(nodeName, attrs)
def endElementNS(self, nsTuple, nodeName):
# Throw away (ns, nodeName) for now.
self.endElement(nodeName)
#handler.endElementNS((ns, node.nodeName), node.nodeName)
def startPrefixMapping(self, prefix, nodeValue):
# Ignore the prefix for now.
pass
def endPrefixMapping(self, prefix):
# Ignore the prefix for now.
# handler.endPrefixMapping(prefix)
pass
def characters(self, content):
self.soup.handle_data(content)
def startDocument(self):
pass
def endDocument(self):
pass
class HTMLTreeBuilder(TreeBuilder):
"""This TreeBuilder knows facts about HTML.
Such as which tags are empty-element tags.
"""
empty_element_tags = set([
# These are from HTML5.
'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'keygen', 'link', 'menuitem', 'meta', 'param', 'source', 'track', 'wbr',
# These are from earlier versions of HTML and are removed in HTML5.
'basefont', 'bgsound', 'command', 'frame', 'image', 'isindex', 'nextid', 'spacer'
])
# The HTML standard defines these as block-level elements. Beautiful
# Soup does not treat these elements differently from other elements,
# but it may do so eventually, and this information is available if
# you need to use it.
block_elements = set(["address", "article", "aside", "blockquote", "canvas", "dd", "div", "dl", "dt", "fieldset", "figcaption", "figure", "footer", "form", "h1", "h2", "h3", "h4", "h5", "h6", "header", "hr", "li", "main", "nav", "noscript", "ol", "output", "p", "pre", "section", "table", "tfoot", "ul", "video"])
# The HTML standard defines these attributes as containing a
# space-separated list of values, not a single value. That is,
# class="foo bar" means that the 'class' attribute has two values,
# 'foo' and 'bar', not the single value 'foo bar'. When we
# encounter one of these attributes, we will parse its value into
# a list of values if possible. Upon output, the list will be
# converted back into a string.
DEFAULT_CDATA_LIST_ATTRIBUTES = {
"*" : ['class', 'accesskey', 'dropzone'],
"a" : ['rel', 'rev'],
"link" : ['rel', 'rev'],
"td" : ["headers"],
"th" : ["headers"],
"td" : ["headers"],
"form" : ["accept-charset"],
"object" : ["archive"],
# These are HTML5 specific, as are *.accesskey and *.dropzone above.
"area" : ["rel"],
"icon" : ["sizes"],
"iframe" : ["sandbox"],
"output" : ["for"],
}
DEFAULT_PRESERVE_WHITESPACE_TAGS = set(['pre', 'textarea'])
def set_up_substitutions(self, tag):
# We are only interested in <meta> tags
if tag.name != 'meta':
return False
http_equiv = tag.get('http-equiv')
content = tag.get('content')
charset = tag.get('charset')
# We are interested in <meta> tags that say what encoding the
# document was originally in. This means HTML 5-style <meta>
# tags that provide the "charset" attribute. It also means
# HTML 4-style <meta> tags that provide the "content"
# attribute and have "http-equiv" set to "content-type".
#
# In both cases we will replace the value of the appropriate
# attribute with a standin object that can take on any
# encoding.
meta_encoding = None
if charset is not None:
# HTML 5 style:
# <meta charset="utf8">
meta_encoding = charset
tag['charset'] = CharsetMetaAttributeValue(charset)
elif (content is not None and http_equiv is not None
and http_equiv.lower() == 'content-type'):
# HTML 4 style:
# <meta http-equiv="content-type" content="text/html; charset=utf8">
tag['content'] = ContentMetaAttributeValue(content)
return (meta_encoding is not None)
def register_treebuilders_from(module):
"""Copy TreeBuilders from the given module into this module."""
# I'm fairly sure this is not the best way to do this.
this_module = sys.modules['bs4.builder']
for name in module.__all__:
obj = getattr(module, name)
if issubclass(obj, TreeBuilder):
setattr(this_module, name, obj)
this_module.__all__.append(name)
# Register the builder while we're at it.
this_module.builder_registry.register(obj)
class ParserRejectedMarkup(Exception):
pass
# Builders are registered in reverse order of priority, so that custom
# builder registrations will take precedence. In general, we want lxml
# to take precedence over html5lib, because it's faster. And we only
# want to use HTMLParser as a last result.
from . import _htmlparser
register_treebuilders_from(_htmlparser)
try:
from . import _html5lib
register_treebuilders_from(_html5lib)
except ImportError:
# They don't have html5lib installed.
pass
try:
from . import _lxml
register_treebuilders_from(_lxml)
except ImportError:
# They don't have lxml installed.
pass

View file

@ -0,0 +1,426 @@
# Use of this source code is governed by the MIT license.
__license__ = "MIT"
__all__ = [
'HTML5TreeBuilder',
]
import warnings
import re
from bs4.builder import (
PERMISSIVE,
HTML,
HTML_5,
HTMLTreeBuilder,
)
from bs4.element import (
NamespacedAttribute,
nonwhitespace_re,
)
import html5lib
from html5lib.constants import (
namespaces,
prefixes,
)
from bs4.element import (
Comment,
Doctype,
NavigableString,
Tag,
)
try:
# Pre-0.99999999
from html5lib.treebuilders import _base as treebuilder_base
new_html5lib = False
except ImportError as e:
# 0.99999999 and up
from html5lib.treebuilders import base as treebuilder_base
new_html5lib = True
class HTML5TreeBuilder(HTMLTreeBuilder):
"""Use html5lib to build a tree."""
NAME = "html5lib"
features = [NAME, PERMISSIVE, HTML_5, HTML]
def prepare_markup(self, markup, user_specified_encoding,
document_declared_encoding=None, exclude_encodings=None):
# Store the user-specified encoding for use later on.
self.user_specified_encoding = user_specified_encoding
# document_declared_encoding and exclude_encodings aren't used
# ATM because the html5lib TreeBuilder doesn't use
# UnicodeDammit.
if exclude_encodings:
warnings.warn("You provided a value for exclude_encoding, but the html5lib tree builder doesn't support exclude_encoding.")
yield (markup, None, None, False)
# These methods are defined by Beautiful Soup.
def feed(self, markup):
if self.soup.parse_only is not None:
warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.")
parser = html5lib.HTMLParser(tree=self.create_treebuilder)
extra_kwargs = dict()
if not isinstance(markup, str):
if new_html5lib:
extra_kwargs['override_encoding'] = self.user_specified_encoding
else:
extra_kwargs['encoding'] = self.user_specified_encoding
doc = parser.parse(markup, **extra_kwargs)
# Set the character encoding detected by the tokenizer.
if isinstance(markup, str):
# We need to special-case this because html5lib sets
# charEncoding to UTF-8 if it gets Unicode input.
doc.original_encoding = None
else:
original_encoding = parser.tokenizer.stream.charEncoding[0]
if not isinstance(original_encoding, str):
# In 0.99999999 and up, the encoding is an html5lib
# Encoding object. We want to use a string for compatibility
# with other tree builders.
original_encoding = original_encoding.name
doc.original_encoding = original_encoding
def create_treebuilder(self, namespaceHTMLElements):
self.underlying_builder = TreeBuilderForHtml5lib(
namespaceHTMLElements, self.soup)
return self.underlying_builder
def test_fragment_to_document(self, fragment):
"""See `TreeBuilder`."""
return '<html><head></head><body>%s</body></html>' % fragment
class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):
def __init__(self, namespaceHTMLElements, soup=None):
if soup:
self.soup = soup
else:
from bs4 import BeautifulSoup
self.soup = BeautifulSoup("", "html.parser")
super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements)
def documentClass(self):
self.soup.reset()
return Element(self.soup, self.soup, None)
def insertDoctype(self, token):
name = token["name"]
publicId = token["publicId"]
systemId = token["systemId"]
doctype = Doctype.for_name_and_ids(name, publicId, systemId)
self.soup.object_was_parsed(doctype)
def elementClass(self, name, namespace):
tag = self.soup.new_tag(name, namespace)
return Element(tag, self.soup, namespace)
def commentClass(self, data):
return TextNode(Comment(data), self.soup)
def fragmentClass(self):
from bs4 import BeautifulSoup
self.soup = BeautifulSoup("", "html.parser")
self.soup.name = "[document_fragment]"
return Element(self.soup, self.soup, None)
def appendChild(self, node):
# XXX This code is not covered by the BS4 tests.
self.soup.append(node.element)
def getDocument(self):
return self.soup
def getFragment(self):
return treebuilder_base.TreeBuilder.getFragment(self).element
def testSerializer(self, element):
from bs4 import BeautifulSoup
rv = []
doctype_re = re.compile(r'^(.*?)(?: PUBLIC "(.*?)"(?: "(.*?)")?| SYSTEM "(.*?)")?$')
def serializeElement(element, indent=0):
if isinstance(element, BeautifulSoup):
pass
if isinstance(element, Doctype):
m = doctype_re.match(element)
if m:
name = m.group(1)
if m.lastindex > 1:
publicId = m.group(2) or ""
systemId = m.group(3) or m.group(4) or ""
rv.append("""|%s<!DOCTYPE %s "%s" "%s">""" %
(' ' * indent, name, publicId, systemId))
else:
rv.append("|%s<!DOCTYPE %s>" % (' ' * indent, name))
else:
rv.append("|%s<!DOCTYPE >" % (' ' * indent,))
elif isinstance(element, Comment):
rv.append("|%s<!-- %s -->" % (' ' * indent, element))
elif isinstance(element, NavigableString):
rv.append("|%s\"%s\"" % (' ' * indent, element))
else:
if element.namespace:
name = "%s %s" % (prefixes[element.namespace],
element.name)
else:
name = element.name
rv.append("|%s<%s>" % (' ' * indent, name))
if element.attrs:
attributes = []
for name, value in list(element.attrs.items()):
if isinstance(name, NamespacedAttribute):
name = "%s %s" % (prefixes[name.namespace], name.name)
if isinstance(value, list):
value = " ".join(value)
attributes.append((name, value))
for name, value in sorted(attributes):
rv.append('|%s%s="%s"' % (' ' * (indent + 2), name, value))
indent += 2
for child in element.children:
serializeElement(child, indent)
serializeElement(element, 0)
return "\n".join(rv)
class AttrList(object):
def __init__(self, element):
self.element = element
self.attrs = dict(self.element.attrs)
def __iter__(self):
return list(self.attrs.items()).__iter__()
def __setitem__(self, name, value):
# If this attribute is a multi-valued attribute for this element,
# turn its value into a list.
list_attr = self.element.cdata_list_attributes
if (name in list_attr['*']
or (self.element.name in list_attr
and name in list_attr[self.element.name])):
# A node that is being cloned may have already undergone
# this procedure.
if not isinstance(value, list):
value = nonwhitespace_re.findall(value)
self.element[name] = value
def items(self):
return list(self.attrs.items())
def keys(self):
return list(self.attrs.keys())
def __len__(self):
return len(self.attrs)
def __getitem__(self, name):
return self.attrs[name]
def __contains__(self, name):
return name in list(self.attrs.keys())
class Element(treebuilder_base.Node):
def __init__(self, element, soup, namespace):
treebuilder_base.Node.__init__(self, element.name)
self.element = element
self.soup = soup
self.namespace = namespace
def appendChild(self, node):
string_child = child = None
if isinstance(node, str):
# Some other piece of code decided to pass in a string
# instead of creating a TextElement object to contain the
# string.
string_child = child = node
elif isinstance(node, Tag):
# Some other piece of code decided to pass in a Tag
# instead of creating an Element object to contain the
# Tag.
child = node
elif node.element.__class__ == NavigableString:
string_child = child = node.element
node.parent = self
else:
child = node.element
node.parent = self
if not isinstance(child, str) and child.parent is not None:
node.element.extract()
if (string_child is not None and self.element.contents
and self.element.contents[-1].__class__ == NavigableString):
# We are appending a string onto another string.
# TODO This has O(n^2) performance, for input like
# "a</a>a</a>a</a>..."
old_element = self.element.contents[-1]
new_element = self.soup.new_string(old_element + string_child)
old_element.replace_with(new_element)
self.soup._most_recent_element = new_element
else:
if isinstance(node, str):
# Create a brand new NavigableString from this string.
child = self.soup.new_string(node)
# Tell Beautiful Soup to act as if it parsed this element
# immediately after the parent's last descendant. (Or
# immediately after the parent, if it has no children.)
if self.element.contents:
most_recent_element = self.element._last_descendant(False)
elif self.element.next_element is not None:
# Something from further ahead in the parse tree is
# being inserted into this earlier element. This is
# very annoying because it means an expensive search
# for the last element in the tree.
most_recent_element = self.soup._last_descendant()
else:
most_recent_element = self.element
self.soup.object_was_parsed(
child, parent=self.element,
most_recent_element=most_recent_element)
def getAttributes(self):
if isinstance(self.element, Comment):
return {}
return AttrList(self.element)
def setAttributes(self, attributes):
if attributes is not None and len(attributes) > 0:
converted_attributes = []
for name, value in list(attributes.items()):
if isinstance(name, tuple):
new_name = NamespacedAttribute(*name)
del attributes[name]
attributes[new_name] = value
self.soup.builder._replace_cdata_list_attribute_values(
self.name, attributes)
for name, value in list(attributes.items()):
self.element[name] = value
# The attributes may contain variables that need substitution.
# Call set_up_substitutions manually.
#
# The Tag constructor called this method when the Tag was created,
# but we just set/changed the attributes, so call it again.
self.soup.builder.set_up_substitutions(self.element)
attributes = property(getAttributes, setAttributes)
def insertText(self, data, insertBefore=None):
text = TextNode(self.soup.new_string(data), self.soup)
if insertBefore:
self.insertBefore(text, insertBefore)
else:
self.appendChild(text)
def insertBefore(self, node, refNode):
index = self.element.index(refNode.element)
if (node.element.__class__ == NavigableString and self.element.contents
and self.element.contents[index-1].__class__ == NavigableString):
# (See comments in appendChild)
old_node = self.element.contents[index-1]
new_str = self.soup.new_string(old_node + node.element)
old_node.replace_with(new_str)
else:
self.element.insert(index, node.element)
node.parent = self
def removeChild(self, node):
node.element.extract()
def reparentChildren(self, new_parent):
"""Move all of this tag's children into another tag."""
# print "MOVE", self.element.contents
# print "FROM", self.element
# print "TO", new_parent.element
element = self.element
new_parent_element = new_parent.element
# Determine what this tag's next_element will be once all the children
# are removed.
final_next_element = element.next_sibling
new_parents_last_descendant = new_parent_element._last_descendant(False, False)
if len(new_parent_element.contents) > 0:
# The new parent already contains children. We will be
# appending this tag's children to the end.
new_parents_last_child = new_parent_element.contents[-1]
new_parents_last_descendant_next_element = new_parents_last_descendant.next_element
else:
# The new parent contains no children.
new_parents_last_child = None
new_parents_last_descendant_next_element = new_parent_element.next_element
to_append = element.contents
if len(to_append) > 0:
# Set the first child's previous_element and previous_sibling
# to elements within the new parent
first_child = to_append[0]
if new_parents_last_descendant is not None:
first_child.previous_element = new_parents_last_descendant
else:
first_child.previous_element = new_parent_element
first_child.previous_sibling = new_parents_last_child
if new_parents_last_descendant is not None:
new_parents_last_descendant.next_element = first_child
else:
new_parent_element.next_element = first_child
if new_parents_last_child is not None:
new_parents_last_child.next_sibling = first_child
# Find the very last element being moved. It is now the
# parent's last descendant. It has no .next_sibling and
# its .next_element is whatever the previous last
# descendant had.
last_childs_last_descendant = to_append[-1]._last_descendant(False, True)
last_childs_last_descendant.next_element = new_parents_last_descendant_next_element
if new_parents_last_descendant_next_element is not None:
# TODO: This code has no test coverage and I'm not sure
# how to get html5lib to go through this path, but it's
# just the other side of the previous line.
new_parents_last_descendant_next_element.previous_element = last_childs_last_descendant
last_childs_last_descendant.next_sibling = None
for child in to_append:
child.parent = new_parent_element
new_parent_element.contents.append(child)
# Now that this element has no children, change its .next_element.
element.contents = []
element.next_element = final_next_element
# print "DONE WITH MOVE"
# print "FROM", self.element
# print "TO", new_parent_element
def cloneNode(self):
tag = self.soup.new_tag(self.element.name, self.namespace)
node = Element(tag, self.soup, self.namespace)
for key,value in self.attributes:
node.attributes[key] = value
return node
def hasContent(self):
return self.element.contents
def getNameTuple(self):
if self.namespace == None:
return namespaces["html"], self.name
else:
return self.namespace, self.name
nameTuple = property(getNameTuple)
class TextNode(Element):
def __init__(self, element, soup):
treebuilder_base.Node.__init__(self, None)
self.element = element
self.soup = soup
def cloneNode(self):
raise NotImplementedError

View file

@ -0,0 +1,350 @@
# encoding: utf-8
"""Use the HTMLParser library to parse HTML files that aren't too bad."""
# Use of this source code is governed by the MIT license.
__license__ = "MIT"
__all__ = [
'HTMLParserTreeBuilder',
]
from html.parser import HTMLParser
try:
from html.parser import HTMLParseError
except ImportError as e:
# HTMLParseError is removed in Python 3.5. Since it can never be
# thrown in 3.5, we can just define our own class as a placeholder.
class HTMLParseError(Exception):
pass
import sys
import warnings
# Starting in Python 3.2, the HTMLParser constructor takes a 'strict'
# argument, which we'd like to set to False. Unfortunately,
# http://bugs.python.org/issue13273 makes strict=True a better bet
# before Python 3.2.3.
#
# At the end of this file, we monkeypatch HTMLParser so that
# strict=True works well on Python 3.2.2.
major, minor, release = sys.version_info[:3]
CONSTRUCTOR_TAKES_STRICT = major == 3 and minor == 2 and release >= 3
CONSTRUCTOR_STRICT_IS_DEPRECATED = major == 3 and minor == 3
CONSTRUCTOR_TAKES_CONVERT_CHARREFS = major == 3 and minor >= 4
from bs4.element import (
CData,
Comment,
Declaration,
Doctype,
ProcessingInstruction,
)
from bs4.dammit import EntitySubstitution, UnicodeDammit
from bs4.builder import (
HTML,
HTMLTreeBuilder,
STRICT,
)
HTMLPARSER = 'html.parser'
class BeautifulSoupHTMLParser(HTMLParser):
def __init__(self, *args, **kwargs):
HTMLParser.__init__(self, *args, **kwargs)
# Keep a list of empty-element tags that were encountered
# without an explicit closing tag. If we encounter a closing tag
# of this type, we'll associate it with one of those entries.
#
# This isn't a stack because we don't care about the
# order. It's a list of closing tags we've already handled and
# will ignore, assuming they ever show up.
self.already_closed_empty_element = []
def error(self, msg):
"""In Python 3, HTMLParser subclasses must implement error(), although this
requirement doesn't appear to be documented.
In Python 2, HTMLParser implements error() as raising an exception.
In any event, this method is called only on very strange markup and our best strategy
is to pretend it didn't happen and keep going.
"""
warnings.warn(msg)
def handle_startendtag(self, name, attrs):
# This is only called when the markup looks like
# <tag/>.
# is_startend() tells handle_starttag not to close the tag
# just because its name matches a known empty-element tag. We
# know that this is an empty-element tag and we want to call
# handle_endtag ourselves.
tag = self.handle_starttag(name, attrs, handle_empty_element=False)
self.handle_endtag(name)
def handle_starttag(self, name, attrs, handle_empty_element=True):
# XXX namespace
attr_dict = {}
for key, value in attrs:
# Change None attribute values to the empty string
# for consistency with the other tree builders.
if value is None:
value = ''
attr_dict[key] = value
attrvalue = '""'
#print "START", name
tag = self.soup.handle_starttag(name, None, None, attr_dict)
if tag and tag.is_empty_element and handle_empty_element:
# Unlike other parsers, html.parser doesn't send separate end tag
# events for empty-element tags. (It's handled in
# handle_startendtag, but only if the original markup looked like
# <tag/>.)
#
# So we need to call handle_endtag() ourselves. Since we
# know the start event is identical to the end event, we
# don't want handle_endtag() to cross off any previous end
# events for tags of this name.
self.handle_endtag(name, check_already_closed=False)
# But we might encounter an explicit closing tag for this tag
# later on. If so, we want to ignore it.
self.already_closed_empty_element.append(name)
def handle_endtag(self, name, check_already_closed=True):
#print "END", name
if check_already_closed and name in self.already_closed_empty_element:
# This is a redundant end tag for an empty-element tag.
# We've already called handle_endtag() for it, so just
# check it off the list.
# print "ALREADY CLOSED", name
self.already_closed_empty_element.remove(name)
else:
self.soup.handle_endtag(name)
def handle_data(self, data):
self.soup.handle_data(data)
def handle_charref(self, name):
# XXX workaround for a bug in HTMLParser. Remove this once
# it's fixed in all supported versions.
# http://bugs.python.org/issue13633
if name.startswith('x'):
real_name = int(name.lstrip('x'), 16)
elif name.startswith('X'):
real_name = int(name.lstrip('X'), 16)
else:
real_name = int(name)
data = None
if real_name < 256:
# HTML numeric entities are supposed to reference Unicode
# code points, but sometimes they reference code points in
# some other encoding (ahem, Windows-1252). E.g. &#147;
# instead of &#201; for LEFT DOUBLE QUOTATION MARK. This
# code tries to detect this situation and compensate.
for encoding in (self.soup.original_encoding, 'windows-1252'):
if not encoding:
continue
try:
data = bytearray([real_name]).decode(encoding)
except UnicodeDecodeError as e:
pass
if not data:
try:
data = chr(real_name)
except (ValueError, OverflowError) as e:
pass
data = data or "\N{REPLACEMENT CHARACTER}"
self.handle_data(data)
def handle_entityref(self, name):
character = EntitySubstitution.HTML_ENTITY_TO_CHARACTER.get(name)
if character is not None:
data = character
else:
# If this were XML, it would be ambiguous whether "&foo"
# was an character entity reference with a missing
# semicolon or the literal string "&foo". Since this is
# HTML, we have a complete list of all character entity references,
# and this one wasn't found, so assume it's the literal string "&foo".
data = "&%s" % name
self.handle_data(data)
def handle_comment(self, data):
self.soup.endData()
self.soup.handle_data(data)
self.soup.endData(Comment)
def handle_decl(self, data):
self.soup.endData()
if data.startswith("DOCTYPE "):
data = data[len("DOCTYPE "):]
elif data == 'DOCTYPE':
# i.e. "<!DOCTYPE>"
data = ''
self.soup.handle_data(data)
self.soup.endData(Doctype)
def unknown_decl(self, data):
if data.upper().startswith('CDATA['):
cls = CData
data = data[len('CDATA['):]
else:
cls = Declaration
self.soup.endData()
self.soup.handle_data(data)
self.soup.endData(cls)
def handle_pi(self, data):
self.soup.endData()
self.soup.handle_data(data)
self.soup.endData(ProcessingInstruction)
class HTMLParserTreeBuilder(HTMLTreeBuilder):
is_xml = False
picklable = True
NAME = HTMLPARSER
features = [NAME, HTML, STRICT]
def __init__(self, parser_args=None, parser_kwargs=None, **kwargs):
super(HTMLParserTreeBuilder, self).__init__(**kwargs)
parser_args = parser_args or []
parser_kwargs = parser_kwargs or {}
if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED:
parser_kwargs['strict'] = False
if CONSTRUCTOR_TAKES_CONVERT_CHARREFS:
parser_kwargs['convert_charrefs'] = False
self.parser_args = (parser_args, parser_kwargs)
def prepare_markup(self, markup, user_specified_encoding=None,
document_declared_encoding=None, exclude_encodings=None):
"""
:return: A 4-tuple (markup, original encoding, encoding
declared within markup, whether any characters had to be
replaced with REPLACEMENT CHARACTER).
"""
if isinstance(markup, str):
yield (markup, None, None, False)
return
try_encodings = [user_specified_encoding, document_declared_encoding]
dammit = UnicodeDammit(markup, try_encodings, is_html=True,
exclude_encodings=exclude_encodings)
yield (dammit.markup, dammit.original_encoding,
dammit.declared_html_encoding,
dammit.contains_replacement_characters)
def feed(self, markup):
args, kwargs = self.parser_args
parser = BeautifulSoupHTMLParser(*args, **kwargs)
parser.soup = self.soup
try:
parser.feed(markup)
parser.close()
except HTMLParseError as e:
warnings.warn(RuntimeWarning(
"Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help."))
raise e
parser.already_closed_empty_element = []
# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some
# 3.2.3 code. This ensures they don't treat markup like <p></p> as a
# string.
#
# XXX This code can be removed once most Python 3 users are on 3.2.3.
if major == 3 and minor == 2 and not CONSTRUCTOR_TAKES_STRICT:
import re
attrfind_tolerant = re.compile(
r'\s*((?<=[\'"\s])[^\s/>][^\s/=>]*)(\s*=+\s*'
r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?')
HTMLParserTreeBuilder.attrfind_tolerant = attrfind_tolerant
locatestarttagend = re.compile(r"""
<[a-zA-Z][-.a-zA-Z0-9:_]* # tag name
(?:\s+ # whitespace before attribute name
(?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name
(?:\s*=\s* # value indicator
(?:'[^']*' # LITA-enclosed value
|\"[^\"]*\" # LIT-enclosed value
|[^'\">\s]+ # bare value
)
)?
)
)*
\s* # trailing whitespace
""", re.VERBOSE)
BeautifulSoupHTMLParser.locatestarttagend = locatestarttagend
from html.parser import tagfind, attrfind
def parse_starttag(self, i):
self.__starttag_text = None
endpos = self.check_for_whole_start_tag(i)
if endpos < 0:
return endpos
rawdata = self.rawdata
self.__starttag_text = rawdata[i:endpos]
# Now parse the data between i+1 and j into a tag and attrs
attrs = []
match = tagfind.match(rawdata, i+1)
assert match, 'unexpected call to parse_starttag()'
k = match.end()
self.lasttag = tag = rawdata[i+1:k].lower()
while k < endpos:
if self.strict:
m = attrfind.match(rawdata, k)
else:
m = attrfind_tolerant.match(rawdata, k)
if not m:
break
attrname, rest, attrvalue = m.group(1, 2, 3)
if not rest:
attrvalue = None
elif attrvalue[:1] == '\'' == attrvalue[-1:] or \
attrvalue[:1] == '"' == attrvalue[-1:]:
attrvalue = attrvalue[1:-1]
if attrvalue:
attrvalue = self.unescape(attrvalue)
attrs.append((attrname.lower(), attrvalue))
k = m.end()
end = rawdata[k:endpos].strip()
if end not in (">", "/>"):
lineno, offset = self.getpos()
if "\n" in self.__starttag_text:
lineno = lineno + self.__starttag_text.count("\n")
offset = len(self.__starttag_text) \
- self.__starttag_text.rfind("\n")
else:
offset = offset + len(self.__starttag_text)
if self.strict:
self.error("junk characters in start tag: %r"
% (rawdata[k:endpos][:20],))
self.handle_data(rawdata[i:endpos])
return endpos
if end.endswith('/>'):
# XHTML-style empty tag: <span attr="value" />
self.handle_startendtag(tag, attrs)
else:
self.handle_starttag(tag, attrs)
if tag in self.CDATA_CONTENT_ELEMENTS:
self.set_cdata_mode(tag)
return endpos
def set_cdata_mode(self, elem):
self.cdata_elem = elem.lower()
self.interesting = re.compile(r'</\s*%s\s*>' % self.cdata_elem, re.I)
BeautifulSoupHTMLParser.parse_starttag = parse_starttag
BeautifulSoupHTMLParser.set_cdata_mode = set_cdata_mode
CONSTRUCTOR_TAKES_STRICT = True

296
libs/bs4/builder/_lxml.py Normal file
View file

@ -0,0 +1,296 @@
# Use of this source code is governed by the MIT license.
__license__ = "MIT"
__all__ = [
'LXMLTreeBuilderForXML',
'LXMLTreeBuilder',
]
try:
from collections.abc import Callable # Python 3.6
except ImportError as e:
from collections import Callable
from io import BytesIO
from io import StringIO
from lxml import etree
from bs4.element import (
Comment,
Doctype,
NamespacedAttribute,
ProcessingInstruction,
XMLProcessingInstruction,
)
from bs4.builder import (
FAST,
HTML,
HTMLTreeBuilder,
PERMISSIVE,
ParserRejectedMarkup,
TreeBuilder,
XML)
from bs4.dammit import EncodingDetector
LXML = 'lxml'
def _invert(d):
"Invert a dictionary."
return dict((v,k) for k, v in list(d.items()))
class LXMLTreeBuilderForXML(TreeBuilder):
DEFAULT_PARSER_CLASS = etree.XMLParser
is_xml = True
processing_instruction_class = XMLProcessingInstruction
NAME = "lxml-xml"
ALTERNATE_NAMES = ["xml"]
# Well, it's permissive by XML parser standards.
features = [NAME, LXML, XML, FAST, PERMISSIVE]
CHUNK_SIZE = 512
# This namespace mapping is specified in the XML Namespace
# standard.
DEFAULT_NSMAPS = dict(xml='http://www.w3.org/XML/1998/namespace')
DEFAULT_NSMAPS_INVERTED = _invert(DEFAULT_NSMAPS)
def initialize_soup(self, soup):
"""Let the BeautifulSoup object know about the standard namespace
mapping.
"""
super(LXMLTreeBuilderForXML, self).initialize_soup(soup)
self._register_namespaces(self.DEFAULT_NSMAPS)
def _register_namespaces(self, mapping):
"""Let the BeautifulSoup object know about namespaces encountered
while parsing the document.
This might be useful later on when creating CSS selectors.
"""
for key, value in list(mapping.items()):
if key and key not in self.soup._namespaces:
# Let the BeautifulSoup object know about a new namespace.
# If there are multiple namespaces defined with the same
# prefix, the first one in the document takes precedence.
self.soup._namespaces[key] = value
def default_parser(self, encoding):
# This can either return a parser object or a class, which
# will be instantiated with default arguments.
if self._default_parser is not None:
return self._default_parser
return etree.XMLParser(
target=self, strip_cdata=False, recover=True, encoding=encoding)
def parser_for(self, encoding):
# Use the default parser.
parser = self.default_parser(encoding)
if isinstance(parser, Callable):
# Instantiate the parser with default arguments
parser = parser(target=self, strip_cdata=False, encoding=encoding)
return parser
def __init__(self, parser=None, empty_element_tags=None, **kwargs):
# TODO: Issue a warning if parser is present but not a
# callable, since that means there's no way to create new
# parsers for different encodings.
self._default_parser = parser
if empty_element_tags is not None:
self.empty_element_tags = set(empty_element_tags)
self.soup = None
self.nsmaps = [self.DEFAULT_NSMAPS_INVERTED]
super(LXMLTreeBuilderForXML, self).__init__(**kwargs)
def _getNsTag(self, tag):
# Split the namespace URL out of a fully-qualified lxml tag
# name. Copied from lxml's src/lxml/sax.py.
if tag[0] == '{':
return tuple(tag[1:].split('}', 1))
else:
return (None, tag)
def prepare_markup(self, markup, user_specified_encoding=None,
exclude_encodings=None,
document_declared_encoding=None):
"""
:yield: A series of 4-tuples.
(markup, encoding, declared encoding,
has undergone character replacement)
Each 4-tuple represents a strategy for parsing the document.
"""
# Instead of using UnicodeDammit to convert the bytestring to
# Unicode using different encodings, use EncodingDetector to
# iterate over the encodings, and tell lxml to try to parse
# the document as each one in turn.
is_html = not self.is_xml
if is_html:
self.processing_instruction_class = ProcessingInstruction
else:
self.processing_instruction_class = XMLProcessingInstruction
if isinstance(markup, str):
# We were given Unicode. Maybe lxml can parse Unicode on
# this system?
yield markup, None, document_declared_encoding, False
if isinstance(markup, str):
# No, apparently not. Convert the Unicode to UTF-8 and
# tell lxml to parse it as UTF-8.
yield (markup.encode("utf8"), "utf8",
document_declared_encoding, False)
try_encodings = [user_specified_encoding, document_declared_encoding]
detector = EncodingDetector(
markup, try_encodings, is_html, exclude_encodings)
for encoding in detector.encodings:
yield (detector.markup, encoding, document_declared_encoding, False)
def feed(self, markup):
if isinstance(markup, bytes):
markup = BytesIO(markup)
elif isinstance(markup, str):
markup = StringIO(markup)
# Call feed() at least once, even if the markup is empty,
# or the parser won't be initialized.
data = markup.read(self.CHUNK_SIZE)
try:
self.parser = self.parser_for(self.soup.original_encoding)
self.parser.feed(data)
while len(data) != 0:
# Now call feed() on the rest of the data, chunk by chunk.
data = markup.read(self.CHUNK_SIZE)
if len(data) != 0:
self.parser.feed(data)
self.parser.close()
except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
raise ParserRejectedMarkup(str(e))
def close(self):
self.nsmaps = [self.DEFAULT_NSMAPS_INVERTED]
def start(self, name, attrs, nsmap={}):
# Make sure attrs is a mutable dict--lxml may send an immutable dictproxy.
attrs = dict(attrs)
nsprefix = None
# Invert each namespace map as it comes in.
if len(nsmap) == 0 and len(self.nsmaps) > 1:
# There are no new namespaces for this tag, but
# non-default namespaces are in play, so we need a
# separate tag stack to know when they end.
self.nsmaps.append(None)
elif len(nsmap) > 0:
# A new namespace mapping has come into play.
# First, Let the BeautifulSoup object know about it.
self._register_namespaces(nsmap)
# Then, add it to our running list of inverted namespace
# mappings.
self.nsmaps.append(_invert(nsmap))
# Also treat the namespace mapping as a set of attributes on the
# tag, so we can recreate it later.
attrs = attrs.copy()
for prefix, namespace in list(nsmap.items()):
attribute = NamespacedAttribute(
"xmlns", prefix, "http://www.w3.org/2000/xmlns/")
attrs[attribute] = namespace
# Namespaces are in play. Find any attributes that came in
# from lxml with namespaces attached to their names, and
# turn then into NamespacedAttribute objects.
new_attrs = {}
for attr, value in list(attrs.items()):
namespace, attr = self._getNsTag(attr)
if namespace is None:
new_attrs[attr] = value
else:
nsprefix = self._prefix_for_namespace(namespace)
attr = NamespacedAttribute(nsprefix, attr, namespace)
new_attrs[attr] = value
attrs = new_attrs
namespace, name = self._getNsTag(name)
nsprefix = self._prefix_for_namespace(namespace)
self.soup.handle_starttag(name, namespace, nsprefix, attrs)
def _prefix_for_namespace(self, namespace):
"""Find the currently active prefix for the given namespace."""
if namespace is None:
return None
for inverted_nsmap in reversed(self.nsmaps):
if inverted_nsmap is not None and namespace in inverted_nsmap:
return inverted_nsmap[namespace]
return None
def end(self, name):
self.soup.endData()
completed_tag = self.soup.tagStack[-1]
namespace, name = self._getNsTag(name)
nsprefix = None
if namespace is not None:
for inverted_nsmap in reversed(self.nsmaps):
if inverted_nsmap is not None and namespace in inverted_nsmap:
nsprefix = inverted_nsmap[namespace]
break
self.soup.handle_endtag(name, nsprefix)
if len(self.nsmaps) > 1:
# This tag, or one of its parents, introduced a namespace
# mapping, so pop it off the stack.
self.nsmaps.pop()
def pi(self, target, data):
self.soup.endData()
self.soup.handle_data(target + ' ' + data)
self.soup.endData(self.processing_instruction_class)
def data(self, content):
self.soup.handle_data(content)
def doctype(self, name, pubid, system):
self.soup.endData()
doctype = Doctype.for_name_and_ids(name, pubid, system)
self.soup.object_was_parsed(doctype)
def comment(self, content):
"Handle comments as Comment objects."
self.soup.endData()
self.soup.handle_data(content)
self.soup.endData(Comment)
def test_fragment_to_document(self, fragment):
"""See `TreeBuilder`."""
return '<?xml version="1.0" encoding="utf-8"?>\n%s' % fragment
class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):
NAME = LXML
ALTERNATE_NAMES = ["lxml-html"]
features = ALTERNATE_NAMES + [NAME, HTML, FAST, PERMISSIVE]
is_xml = False
processing_instruction_class = ProcessingInstruction
def default_parser(self, encoding):
return etree.HTMLParser
def feed(self, markup):
encoding = self.soup.original_encoding
try:
self.parser = self.parser_for(encoding)
self.parser.feed(markup)
self.parser.close()
except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
raise ParserRejectedMarkup(str(e))
def test_fragment_to_document(self, fragment):
"""See `TreeBuilder`."""
return '<html><body>%s</body></html>' % fragment

850
libs/bs4/dammit.py Normal file
View file

@ -0,0 +1,850 @@
# -*- coding: utf-8 -*-
"""Beautiful Soup bonus library: Unicode, Dammit
This library converts a bytestream to Unicode through any means
necessary. It is heavily based on code from Mark Pilgrim's Universal
Feed Parser. It works best on XML and HTML, but it does not rewrite the
XML or HTML to reflect a new encoding; that's the tree builder's job.
"""
# Use of this source code is governed by the MIT license.
__license__ = "MIT"
import codecs
from html.entities import codepoint2name
import re
import logging
import string
# Import a library to autodetect character encodings.
chardet_type = None
try:
# First try the fast C implementation.
# PyPI package: cchardet
import cchardet
def chardet_dammit(s):
return cchardet.detect(s)['encoding']
except ImportError:
try:
# Fall back to the pure Python implementation
# Debian package: python-chardet
# PyPI package: chardet
import chardet
def chardet_dammit(s):
return chardet.detect(s)['encoding']
#import chardet.constants
#chardet.constants._debug = 1
except ImportError:
# No chardet available.
def chardet_dammit(s):
return None
# Available from http://cjkpython.i18n.org/.
try:
import iconv_codec
except ImportError:
pass
xml_encoding_re = re.compile(
'^<\\?.*encoding=[\'"](.*?)[\'"].*\\?>'.encode(), re.I)
html_meta_re = re.compile(
'<\\s*meta[^>]+charset\\s*=\\s*["\']?([^>]*?)[ /;\'">]'.encode(), re.I)
class EntitySubstitution(object):
"""Substitute XML or HTML entities for the corresponding characters."""
def _populate_class_variables():
lookup = {}
reverse_lookup = {}
characters_for_re = []
# &apos is an XHTML entity and an HTML 5, but not an HTML 4
# entity. We don't want to use it, but we want to recognize it on the way in.
#
# TODO: Ideally we would be able to recognize all HTML 5 named
# entities, but that's a little tricky.
extra = [(39, 'apos')]
for codepoint, name in list(codepoint2name.items()) + extra:
character = chr(codepoint)
if codepoint not in (34, 39):
# There's no point in turning the quotation mark into
# &quot; or the single quote into &apos;, unless it
# happens within an attribute value, which is handled
# elsewhere.
characters_for_re.append(character)
lookup[character] = name
# But we do want to recognize those entities on the way in and
# convert them to Unicode characters.
reverse_lookup[name] = character
re_definition = "[%s]" % "".join(characters_for_re)
return lookup, reverse_lookup, re.compile(re_definition)
(CHARACTER_TO_HTML_ENTITY, HTML_ENTITY_TO_CHARACTER,
CHARACTER_TO_HTML_ENTITY_RE) = _populate_class_variables()
CHARACTER_TO_XML_ENTITY = {
"'": "apos",
'"': "quot",
"&": "amp",
"<": "lt",
">": "gt",
}
BARE_AMPERSAND_OR_BRACKET = re.compile("([<>]|"
"&(?!#\\d+;|#x[0-9a-fA-F]+;|\\w+;)"
")")
AMPERSAND_OR_BRACKET = re.compile("([<>&])")
@classmethod
def _substitute_html_entity(cls, matchobj):
entity = cls.CHARACTER_TO_HTML_ENTITY.get(matchobj.group(0))
return "&%s;" % entity
@classmethod
def _substitute_xml_entity(cls, matchobj):
"""Used with a regular expression to substitute the
appropriate XML entity for an XML special character."""
entity = cls.CHARACTER_TO_XML_ENTITY[matchobj.group(0)]
return "&%s;" % entity
@classmethod
def quoted_attribute_value(self, value):
"""Make a value into a quoted XML attribute, possibly escaping it.
Most strings will be quoted using double quotes.
Bob's Bar -> "Bob's Bar"
If a string contains double quotes, it will be quoted using
single quotes.
Welcome to "my bar" -> 'Welcome to "my bar"'
If a string contains both single and double quotes, the
double quotes will be escaped, and the string will be quoted
using double quotes.
Welcome to "Bob's Bar" -> "Welcome to &quot;Bob's bar&quot;
"""
quote_with = '"'
if '"' in value:
if "'" in value:
# The string contains both single and double
# quotes. Turn the double quotes into
# entities. We quote the double quotes rather than
# the single quotes because the entity name is
# "&quot;" whether this is HTML or XML. If we
# quoted the single quotes, we'd have to decide
# between &apos; and &squot;.
replace_with = "&quot;"
value = value.replace('"', replace_with)
else:
# There are double quotes but no single quotes.
# We can use single quotes to quote the attribute.
quote_with = "'"
return quote_with + value + quote_with
@classmethod
def substitute_xml(cls, value, make_quoted_attribute=False):
"""Substitute XML entities for special XML characters.
:param value: A string to be substituted. The less-than sign
will become &lt;, the greater-than sign will become &gt;,
and any ampersands will become &amp;. If you want ampersands
that appear to be part of an entity definition to be left
alone, use substitute_xml_containing_entities() instead.
:param make_quoted_attribute: If True, then the string will be
quoted, as befits an attribute value.
"""
# Escape angle brackets and ampersands.
value = cls.AMPERSAND_OR_BRACKET.sub(
cls._substitute_xml_entity, value)
if make_quoted_attribute:
value = cls.quoted_attribute_value(value)
return value
@classmethod
def substitute_xml_containing_entities(
cls, value, make_quoted_attribute=False):
"""Substitute XML entities for special XML characters.
:param value: A string to be substituted. The less-than sign will
become &lt;, the greater-than sign will become &gt;, and any
ampersands that are not part of an entity defition will
become &amp;.
:param make_quoted_attribute: If True, then the string will be
quoted, as befits an attribute value.
"""
# Escape angle brackets, and ampersands that aren't part of
# entities.
value = cls.BARE_AMPERSAND_OR_BRACKET.sub(
cls._substitute_xml_entity, value)
if make_quoted_attribute:
value = cls.quoted_attribute_value(value)
return value
@classmethod
def substitute_html(cls, s):
"""Replace certain Unicode characters with named HTML entities.
This differs from data.encode(encoding, 'xmlcharrefreplace')
in that the goal is to make the result more readable (to those
with ASCII displays) rather than to recover from
errors. There's absolutely nothing wrong with a UTF-8 string
containg a LATIN SMALL LETTER E WITH ACUTE, but replacing that
character with "&eacute;" will make it more readable to some
people.
"""
return cls.CHARACTER_TO_HTML_ENTITY_RE.sub(
cls._substitute_html_entity, s)
class EncodingDetector:
"""Suggests a number of possible encodings for a bytestring.
Order of precedence:
1. Encodings you specifically tell EncodingDetector to try first
(the override_encodings argument to the constructor).
2. An encoding declared within the bytestring itself, either in an
XML declaration (if the bytestring is to be interpreted as an XML
document), or in a <meta> tag (if the bytestring is to be
interpreted as an HTML document.)
3. An encoding detected through textual analysis by chardet,
cchardet, or a similar external library.
4. UTF-8.
5. Windows-1252.
"""
def __init__(self, markup, override_encodings=None, is_html=False,
exclude_encodings=None):
self.override_encodings = override_encodings or []
exclude_encodings = exclude_encodings or []
self.exclude_encodings = set([x.lower() for x in exclude_encodings])
self.chardet_encoding = None
self.is_html = is_html
self.declared_encoding = None
# First order of business: strip a byte-order mark.
self.markup, self.sniffed_encoding = self.strip_byte_order_mark(markup)
def _usable(self, encoding, tried):
if encoding is not None:
encoding = encoding.lower()
if encoding in self.exclude_encodings:
return False
if encoding not in tried:
tried.add(encoding)
return True
return False
@property
def encodings(self):
"""Yield a number of encodings that might work for this markup."""
tried = set()
for e in self.override_encodings:
if self._usable(e, tried):
yield e
# Did the document originally start with a byte-order mark
# that indicated its encoding?
if self._usable(self.sniffed_encoding, tried):
yield self.sniffed_encoding
# Look within the document for an XML or HTML encoding
# declaration.
if self.declared_encoding is None:
self.declared_encoding = self.find_declared_encoding(
self.markup, self.is_html)
if self._usable(self.declared_encoding, tried):
yield self.declared_encoding
# Use third-party character set detection to guess at the
# encoding.
if self.chardet_encoding is None:
self.chardet_encoding = chardet_dammit(self.markup)
if self._usable(self.chardet_encoding, tried):
yield self.chardet_encoding
# As a last-ditch effort, try utf-8 and windows-1252.
for e in ('utf-8', 'windows-1252'):
if self._usable(e, tried):
yield e
@classmethod
def strip_byte_order_mark(cls, data):
"""If a byte-order mark is present, strip it and return the encoding it implies."""
encoding = None
if isinstance(data, str):
# Unicode data cannot have a byte-order mark.
return data, encoding
if (len(data) >= 4) and (data[:2] == b'\xfe\xff') \
and (data[2:4] != '\x00\x00'):
encoding = 'utf-16be'
data = data[2:]
elif (len(data) >= 4) and (data[:2] == b'\xff\xfe') \
and (data[2:4] != '\x00\x00'):
encoding = 'utf-16le'
data = data[2:]
elif data[:3] == b'\xef\xbb\xbf':
encoding = 'utf-8'
data = data[3:]
elif data[:4] == b'\x00\x00\xfe\xff':
encoding = 'utf-32be'
data = data[4:]
elif data[:4] == b'\xff\xfe\x00\x00':
encoding = 'utf-32le'
data = data[4:]
return data, encoding
@classmethod
def find_declared_encoding(cls, markup, is_html=False, search_entire_document=False):
"""Given a document, tries to find its declared encoding.
An XML encoding is declared at the beginning of the document.
An HTML encoding is declared in a <meta> tag, hopefully near the
beginning of the document.
"""
if search_entire_document:
xml_endpos = html_endpos = len(markup)
else:
xml_endpos = 1024
html_endpos = max(2048, int(len(markup) * 0.05))
declared_encoding = None
declared_encoding_match = xml_encoding_re.search(markup, endpos=xml_endpos)
if not declared_encoding_match and is_html:
declared_encoding_match = html_meta_re.search(markup, endpos=html_endpos)
if declared_encoding_match is not None:
declared_encoding = declared_encoding_match.groups()[0].decode(
'ascii', 'replace')
if declared_encoding:
return declared_encoding.lower()
return None
class UnicodeDammit:
"""A class for detecting the encoding of a *ML document and
converting it to a Unicode string. If the source encoding is
windows-1252, can replace MS smart quotes with their HTML or XML
equivalents."""
# This dictionary maps commonly seen values for "charset" in HTML
# meta tags to the corresponding Python codec names. It only covers
# values that aren't in Python's aliases and can't be determined
# by the heuristics in find_codec.
CHARSET_ALIASES = {"macintosh": "mac-roman",
"x-sjis": "shift-jis"}
ENCODINGS_WITH_SMART_QUOTES = [
"windows-1252",
"iso-8859-1",
"iso-8859-2",
]
def __init__(self, markup, override_encodings=[],
smart_quotes_to=None, is_html=False, exclude_encodings=[]):
self.smart_quotes_to = smart_quotes_to
self.tried_encodings = []
self.contains_replacement_characters = False
self.is_html = is_html
self.log = logging.getLogger(__name__)
self.detector = EncodingDetector(
markup, override_encodings, is_html, exclude_encodings)
# Short-circuit if the data is in Unicode to begin with.
if isinstance(markup, str) or markup == '':
self.markup = markup
self.unicode_markup = str(markup)
self.original_encoding = None
return
# The encoding detector may have stripped a byte-order mark.
# Use the stripped markup from this point on.
self.markup = self.detector.markup
u = None
for encoding in self.detector.encodings:
markup = self.detector.markup
u = self._convert_from(encoding)
if u is not None:
break
if not u:
# None of the encodings worked. As an absolute last resort,
# try them again with character replacement.
for encoding in self.detector.encodings:
if encoding != "ascii":
u = self._convert_from(encoding, "replace")
if u is not None:
self.log.warning(
"Some characters could not be decoded, and were "
"replaced with REPLACEMENT CHARACTER."
)
self.contains_replacement_characters = True
break
# If none of that worked, we could at this point force it to
# ASCII, but that would destroy so much data that I think
# giving up is better.
self.unicode_markup = u
if not u:
self.original_encoding = None
def _sub_ms_char(self, match):
"""Changes a MS smart quote character to an XML or HTML
entity, or an ASCII character."""
orig = match.group(1)
if self.smart_quotes_to == 'ascii':
sub = self.MS_CHARS_TO_ASCII.get(orig).encode()
else:
sub = self.MS_CHARS.get(orig)
if type(sub) == tuple:
if self.smart_quotes_to == 'xml':
sub = '&#x'.encode() + sub[1].encode() + ';'.encode()
else:
sub = '&'.encode() + sub[0].encode() + ';'.encode()
else:
sub = sub.encode()
return sub
def _convert_from(self, proposed, errors="strict"):
proposed = self.find_codec(proposed)
if not proposed or (proposed, errors) in self.tried_encodings:
return None
self.tried_encodings.append((proposed, errors))
markup = self.markup
# Convert smart quotes to HTML if coming from an encoding
# that might have them.
if (self.smart_quotes_to is not None
and proposed in self.ENCODINGS_WITH_SMART_QUOTES):
smart_quotes_re = b"([\x80-\x9f])"
smart_quotes_compiled = re.compile(smart_quotes_re)
markup = smart_quotes_compiled.sub(self._sub_ms_char, markup)
try:
#print "Trying to convert document to %s (errors=%s)" % (
# proposed, errors)
u = self._to_unicode(markup, proposed, errors)
self.markup = u
self.original_encoding = proposed
except Exception as e:
#print "That didn't work!"
#print e
return None
#print "Correct encoding: %s" % proposed
return self.markup
def _to_unicode(self, data, encoding, errors="strict"):
'''Given a string and its encoding, decodes the string into Unicode.
%encoding is a string recognized by encodings.aliases'''
return str(data, encoding, errors)
@property
def declared_html_encoding(self):
if not self.is_html:
return None
return self.detector.declared_encoding
def find_codec(self, charset):
value = (self._codec(self.CHARSET_ALIASES.get(charset, charset))
or (charset and self._codec(charset.replace("-", "")))
or (charset and self._codec(charset.replace("-", "_")))
or (charset and charset.lower())
or charset
)
if value:
return value.lower()
return None
def _codec(self, charset):
if not charset:
return charset
codec = None
try:
codecs.lookup(charset)
codec = charset
except (LookupError, ValueError):
pass
return codec
# A partial mapping of ISO-Latin-1 to HTML entities/XML numeric entities.
MS_CHARS = {b'\x80': ('euro', '20AC'),
b'\x81': ' ',
b'\x82': ('sbquo', '201A'),
b'\x83': ('fnof', '192'),
b'\x84': ('bdquo', '201E'),
b'\x85': ('hellip', '2026'),
b'\x86': ('dagger', '2020'),
b'\x87': ('Dagger', '2021'),
b'\x88': ('circ', '2C6'),
b'\x89': ('permil', '2030'),
b'\x8A': ('Scaron', '160'),
b'\x8B': ('lsaquo', '2039'),
b'\x8C': ('OElig', '152'),
b'\x8D': '?',
b'\x8E': ('#x17D', '17D'),
b'\x8F': '?',
b'\x90': '?',
b'\x91': ('lsquo', '2018'),
b'\x92': ('rsquo', '2019'),
b'\x93': ('ldquo', '201C'),
b'\x94': ('rdquo', '201D'),
b'\x95': ('bull', '2022'),
b'\x96': ('ndash', '2013'),
b'\x97': ('mdash', '2014'),
b'\x98': ('tilde', '2DC'),
b'\x99': ('trade', '2122'),
b'\x9a': ('scaron', '161'),
b'\x9b': ('rsaquo', '203A'),
b'\x9c': ('oelig', '153'),
b'\x9d': '?',
b'\x9e': ('#x17E', '17E'),
b'\x9f': ('Yuml', ''),}
# A parochial partial mapping of ISO-Latin-1 to ASCII. Contains
# horrors like stripping diacritical marks to turn á into a, but also
# contains non-horrors like turning “ into ".
MS_CHARS_TO_ASCII = {
b'\x80' : 'EUR',
b'\x81' : ' ',
b'\x82' : ',',
b'\x83' : 'f',
b'\x84' : ',,',
b'\x85' : '...',
b'\x86' : '+',
b'\x87' : '++',
b'\x88' : '^',
b'\x89' : '%',
b'\x8a' : 'S',
b'\x8b' : '<',
b'\x8c' : 'OE',
b'\x8d' : '?',
b'\x8e' : 'Z',
b'\x8f' : '?',
b'\x90' : '?',
b'\x91' : "'",
b'\x92' : "'",
b'\x93' : '"',
b'\x94' : '"',
b'\x95' : '*',
b'\x96' : '-',
b'\x97' : '--',
b'\x98' : '~',
b'\x99' : '(TM)',
b'\x9a' : 's',
b'\x9b' : '>',
b'\x9c' : 'oe',
b'\x9d' : '?',
b'\x9e' : 'z',
b'\x9f' : 'Y',
b'\xa0' : ' ',
b'\xa1' : '!',
b'\xa2' : 'c',
b'\xa3' : 'GBP',
b'\xa4' : '$', #This approximation is especially parochial--this is the
#generic currency symbol.
b'\xa5' : 'YEN',
b'\xa6' : '|',
b'\xa7' : 'S',
b'\xa8' : '..',
b'\xa9' : '',
b'\xaa' : '(th)',
b'\xab' : '<<',
b'\xac' : '!',
b'\xad' : ' ',
b'\xae' : '(R)',
b'\xaf' : '-',
b'\xb0' : 'o',
b'\xb1' : '+-',
b'\xb2' : '2',
b'\xb3' : '3',
b'\xb4' : ("'", 'acute'),
b'\xb5' : 'u',
b'\xb6' : 'P',
b'\xb7' : '*',
b'\xb8' : ',',
b'\xb9' : '1',
b'\xba' : '(th)',
b'\xbb' : '>>',
b'\xbc' : '1/4',
b'\xbd' : '1/2',
b'\xbe' : '3/4',
b'\xbf' : '?',
b'\xc0' : 'A',
b'\xc1' : 'A',
b'\xc2' : 'A',
b'\xc3' : 'A',
b'\xc4' : 'A',
b'\xc5' : 'A',
b'\xc6' : 'AE',
b'\xc7' : 'C',
b'\xc8' : 'E',
b'\xc9' : 'E',
b'\xca' : 'E',
b'\xcb' : 'E',
b'\xcc' : 'I',
b'\xcd' : 'I',
b'\xce' : 'I',
b'\xcf' : 'I',
b'\xd0' : 'D',
b'\xd1' : 'N',
b'\xd2' : 'O',
b'\xd3' : 'O',
b'\xd4' : 'O',
b'\xd5' : 'O',
b'\xd6' : 'O',
b'\xd7' : '*',
b'\xd8' : 'O',
b'\xd9' : 'U',
b'\xda' : 'U',
b'\xdb' : 'U',
b'\xdc' : 'U',
b'\xdd' : 'Y',
b'\xde' : 'b',
b'\xdf' : 'B',
b'\xe0' : 'a',
b'\xe1' : 'a',
b'\xe2' : 'a',
b'\xe3' : 'a',
b'\xe4' : 'a',
b'\xe5' : 'a',
b'\xe6' : 'ae',
b'\xe7' : 'c',
b'\xe8' : 'e',
b'\xe9' : 'e',
b'\xea' : 'e',
b'\xeb' : 'e',
b'\xec' : 'i',
b'\xed' : 'i',
b'\xee' : 'i',
b'\xef' : 'i',
b'\xf0' : 'o',
b'\xf1' : 'n',
b'\xf2' : 'o',
b'\xf3' : 'o',
b'\xf4' : 'o',
b'\xf5' : 'o',
b'\xf6' : 'o',
b'\xf7' : '/',
b'\xf8' : 'o',
b'\xf9' : 'u',
b'\xfa' : 'u',
b'\xfb' : 'u',
b'\xfc' : 'u',
b'\xfd' : 'y',
b'\xfe' : 'b',
b'\xff' : 'y',
}
# A map used when removing rogue Windows-1252/ISO-8859-1
# characters in otherwise UTF-8 documents.
#
# Note that \x81, \x8d, \x8f, \x90, and \x9d are undefined in
# Windows-1252.
WINDOWS_1252_TO_UTF8 = {
0x80 : b'\xe2\x82\xac', # €
0x82 : b'\xe2\x80\x9a', #
0x83 : b'\xc6\x92', # ƒ
0x84 : b'\xe2\x80\x9e', # „
0x85 : b'\xe2\x80\xa6', # …
0x86 : b'\xe2\x80\xa0', # †
0x87 : b'\xe2\x80\xa1', # ‡
0x88 : b'\xcb\x86', # ˆ
0x89 : b'\xe2\x80\xb0', # ‰
0x8a : b'\xc5\xa0', # Š
0x8b : b'\xe2\x80\xb9', #
0x8c : b'\xc5\x92', # Œ
0x8e : b'\xc5\xbd', # Ž
0x91 : b'\xe2\x80\x98', #
0x92 : b'\xe2\x80\x99', #
0x93 : b'\xe2\x80\x9c', # “
0x94 : b'\xe2\x80\x9d', # ”
0x95 : b'\xe2\x80\xa2', # •
0x96 : b'\xe2\x80\x93', #
0x97 : b'\xe2\x80\x94', # —
0x98 : b'\xcb\x9c', # ˜
0x99 : b'\xe2\x84\xa2', # ™
0x9a : b'\xc5\xa1', # š
0x9b : b'\xe2\x80\xba', #
0x9c : b'\xc5\x93', # œ
0x9e : b'\xc5\xbe', # ž
0x9f : b'\xc5\xb8', # Ÿ
0xa0 : b'\xc2\xa0', #  
0xa1 : b'\xc2\xa1', # ¡
0xa2 : b'\xc2\xa2', # ¢
0xa3 : b'\xc2\xa3', # £
0xa4 : b'\xc2\xa4', # ¤
0xa5 : b'\xc2\xa5', # ¥
0xa6 : b'\xc2\xa6', # ¦
0xa7 : b'\xc2\xa7', # §
0xa8 : b'\xc2\xa8', # ¨
0xa9 : b'\xc2\xa9', # ©
0xaa : b'\xc2\xaa', # ª
0xab : b'\xc2\xab', # «
0xac : b'\xc2\xac', # ¬
0xad : b'\xc2\xad', # ­
0xae : b'\xc2\xae', # ®
0xaf : b'\xc2\xaf', # ¯
0xb0 : b'\xc2\xb0', # °
0xb1 : b'\xc2\xb1', # ±
0xb2 : b'\xc2\xb2', # ²
0xb3 : b'\xc2\xb3', # ³
0xb4 : b'\xc2\xb4', # ´
0xb5 : b'\xc2\xb5', # µ
0xb6 : b'\xc2\xb6', # ¶
0xb7 : b'\xc2\xb7', # ·
0xb8 : b'\xc2\xb8', # ¸
0xb9 : b'\xc2\xb9', # ¹
0xba : b'\xc2\xba', # º
0xbb : b'\xc2\xbb', # »
0xbc : b'\xc2\xbc', # ¼
0xbd : b'\xc2\xbd', # ½
0xbe : b'\xc2\xbe', # ¾
0xbf : b'\xc2\xbf', # ¿
0xc0 : b'\xc3\x80', # À
0xc1 : b'\xc3\x81', # Á
0xc2 : b'\xc3\x82', # Â
0xc3 : b'\xc3\x83', # Ã
0xc4 : b'\xc3\x84', # Ä
0xc5 : b'\xc3\x85', # Å
0xc6 : b'\xc3\x86', # Æ
0xc7 : b'\xc3\x87', # Ç
0xc8 : b'\xc3\x88', # È
0xc9 : b'\xc3\x89', # É
0xca : b'\xc3\x8a', # Ê
0xcb : b'\xc3\x8b', # Ë
0xcc : b'\xc3\x8c', # Ì
0xcd : b'\xc3\x8d', # Í
0xce : b'\xc3\x8e', # Î
0xcf : b'\xc3\x8f', # Ï
0xd0 : b'\xc3\x90', # Ð
0xd1 : b'\xc3\x91', # Ñ
0xd2 : b'\xc3\x92', # Ò
0xd3 : b'\xc3\x93', # Ó
0xd4 : b'\xc3\x94', # Ô
0xd5 : b'\xc3\x95', # Õ
0xd6 : b'\xc3\x96', # Ö
0xd7 : b'\xc3\x97', # ×
0xd8 : b'\xc3\x98', # Ø
0xd9 : b'\xc3\x99', # Ù
0xda : b'\xc3\x9a', # Ú
0xdb : b'\xc3\x9b', # Û
0xdc : b'\xc3\x9c', # Ü
0xdd : b'\xc3\x9d', # Ý
0xde : b'\xc3\x9e', # Þ
0xdf : b'\xc3\x9f', # ß
0xe0 : b'\xc3\xa0', # à
0xe1 : b'\xa1', # á
0xe2 : b'\xc3\xa2', # â
0xe3 : b'\xc3\xa3', # ã
0xe4 : b'\xc3\xa4', # ä
0xe5 : b'\xc3\xa5', # å
0xe6 : b'\xc3\xa6', # æ
0xe7 : b'\xc3\xa7', # ç
0xe8 : b'\xc3\xa8', # è
0xe9 : b'\xc3\xa9', # é
0xea : b'\xc3\xaa', # ê
0xeb : b'\xc3\xab', # ë
0xec : b'\xc3\xac', # ì
0xed : b'\xc3\xad', # í
0xee : b'\xc3\xae', # î
0xef : b'\xc3\xaf', # ï
0xf0 : b'\xc3\xb0', # ð
0xf1 : b'\xc3\xb1', # ñ
0xf2 : b'\xc3\xb2', # ò
0xf3 : b'\xc3\xb3', # ó
0xf4 : b'\xc3\xb4', # ô
0xf5 : b'\xc3\xb5', # õ
0xf6 : b'\xc3\xb6', # ö
0xf7 : b'\xc3\xb7', # ÷
0xf8 : b'\xc3\xb8', # ø
0xf9 : b'\xc3\xb9', # ù
0xfa : b'\xc3\xba', # ú
0xfb : b'\xc3\xbb', # û
0xfc : b'\xc3\xbc', # ü
0xfd : b'\xc3\xbd', # ý
0xfe : b'\xc3\xbe', # þ
}
MULTIBYTE_MARKERS_AND_SIZES = [
(0xc2, 0xdf, 2), # 2-byte characters start with a byte C2-DF
(0xe0, 0xef, 3), # 3-byte characters start with E0-EF
(0xf0, 0xf4, 4), # 4-byte characters start with F0-F4
]
FIRST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[0][0]
LAST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[-1][1]
@classmethod
def detwingle(cls, in_bytes, main_encoding="utf8",
embedded_encoding="windows-1252"):
"""Fix characters from one encoding embedded in some other encoding.
Currently the only situation supported is Windows-1252 (or its
subset ISO-8859-1), embedded in UTF-8.
The input must be a bytestring. If you've already converted
the document to Unicode, you're too late.
The output is a bytestring in which `embedded_encoding`
characters have been converted to their `main_encoding`
equivalents.
"""
if embedded_encoding.replace('_', '-').lower() not in (
'windows-1252', 'windows_1252'):
raise NotImplementedError(
"Windows-1252 and ISO-8859-1 are the only currently supported "
"embedded encodings.")
if main_encoding.lower() not in ('utf8', 'utf-8'):
raise NotImplementedError(
"UTF-8 is the only currently supported main encoding.")
byte_chunks = []
chunk_start = 0
pos = 0
while pos < len(in_bytes):
byte = in_bytes[pos]
if not isinstance(byte, int):
# Python 2.x
byte = ord(byte)
if (byte >= cls.FIRST_MULTIBYTE_MARKER
and byte <= cls.LAST_MULTIBYTE_MARKER):
# This is the start of a UTF-8 multibyte character. Skip
# to the end.
for start, end, size in cls.MULTIBYTE_MARKERS_AND_SIZES:
if byte >= start and byte <= end:
pos += size
break
elif byte >= 0x80 and byte in cls.WINDOWS_1252_TO_UTF8:
# We found a Windows-1252 character!
# Save the string up to this point as a chunk.
byte_chunks.append(in_bytes[chunk_start:pos])
# Now translate the Windows-1252 character into UTF-8
# and add it as another, one-byte chunk.
byte_chunks.append(cls.WINDOWS_1252_TO_UTF8[byte])
pos += 1
chunk_start = pos
else:
# Go on to the next character.
pos += 1
if chunk_start == 0:
# The string is unchanged.
return in_bytes
else:
# Store the final chunk.
byte_chunks.append(in_bytes[chunk_start:])
return b''.join(byte_chunks)

224
libs/bs4/diagnose.py Normal file
View file

@ -0,0 +1,224 @@
"""Diagnostic functions, mainly for use when doing tech support."""
# Use of this source code is governed by the MIT license.
__license__ = "MIT"
import cProfile
from io import StringIO
from html.parser import HTMLParser
import bs4
from bs4 import BeautifulSoup, __version__
from bs4.builder import builder_registry
import os
import pstats
import random
import tempfile
import time
import traceback
import sys
import cProfile
def diagnose(data):
"""Diagnostic suite for isolating common problems."""
print("Diagnostic running on Beautiful Soup %s" % __version__)
print("Python version %s" % sys.version)
basic_parsers = ["html.parser", "html5lib", "lxml"]
for name in basic_parsers:
for builder in builder_registry.builders:
if name in builder.features:
break
else:
basic_parsers.remove(name)
print((
"I noticed that %s is not installed. Installing it may help." %
name))
if 'lxml' in basic_parsers:
basic_parsers.append("lxml-xml")
try:
from lxml import etree
print("Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION)))
except ImportError as e:
print (
"lxml is not installed or couldn't be imported.")
if 'html5lib' in basic_parsers:
try:
import html5lib
print("Found html5lib version %s" % html5lib.__version__)
except ImportError as e:
print (
"html5lib is not installed or couldn't be imported.")
if hasattr(data, 'read'):
data = data.read()
elif data.startswith("http:") or data.startswith("https:"):
print('"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data)
print("You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup.")
return
else:
try:
if os.path.exists(data):
print('"%s" looks like a filename. Reading data from the file.' % data)
with open(data) as fp:
data = fp.read()
except ValueError:
# This can happen on some platforms when the 'filename' is
# too long. Assume it's data and not a filename.
pass
print()
for parser in basic_parsers:
print("Trying to parse your markup with %s" % parser)
success = False
try:
soup = BeautifulSoup(data, features=parser)
success = True
except Exception as e:
print("%s could not parse the markup." % parser)
traceback.print_exc()
if success:
print("Here's what %s did with the markup:" % parser)
print(soup.prettify())
print("-" * 80)
def lxml_trace(data, html=True, **kwargs):
"""Print out the lxml events that occur during parsing.
This lets you see how lxml parses a document when no Beautiful
Soup code is running.
"""
from lxml import etree
for event, element in etree.iterparse(StringIO(data), html=html, **kwargs):
print(("%s, %4s, %s" % (event, element.tag, element.text)))
class AnnouncingParser(HTMLParser):
"""Announces HTMLParser parse events, without doing anything else."""
def _p(self, s):
print(s)
def handle_starttag(self, name, attrs):
self._p("%s START" % name)
def handle_endtag(self, name):
self._p("%s END" % name)
def handle_data(self, data):
self._p("%s DATA" % data)
def handle_charref(self, name):
self._p("%s CHARREF" % name)
def handle_entityref(self, name):
self._p("%s ENTITYREF" % name)
def handle_comment(self, data):
self._p("%s COMMENT" % data)
def handle_decl(self, data):
self._p("%s DECL" % data)
def unknown_decl(self, data):
self._p("%s UNKNOWN-DECL" % data)
def handle_pi(self, data):
self._p("%s PI" % data)
def htmlparser_trace(data):
"""Print out the HTMLParser events that occur during parsing.
This lets you see how HTMLParser parses a document when no
Beautiful Soup code is running.
"""
parser = AnnouncingParser()
parser.feed(data)
_vowels = "aeiou"
_consonants = "bcdfghjklmnpqrstvwxyz"
def rword(length=5):
"Generate a random word-like string."
s = ''
for i in range(length):
if i % 2 == 0:
t = _consonants
else:
t = _vowels
s += random.choice(t)
return s
def rsentence(length=4):
"Generate a random sentence-like string."
return " ".join(rword(random.randint(4,9)) for i in list(range(length)))
def rdoc(num_elements=1000):
"""Randomly generate an invalid HTML document."""
tag_names = ['p', 'div', 'span', 'i', 'b', 'script', 'table']
elements = []
for i in range(num_elements):
choice = random.randint(0,3)
if choice == 0:
# New tag.
tag_name = random.choice(tag_names)
elements.append("<%s>" % tag_name)
elif choice == 1:
elements.append(rsentence(random.randint(1,4)))
elif choice == 2:
# Close a tag.
tag_name = random.choice(tag_names)
elements.append("</%s>" % tag_name)
return "<html>" + "\n".join(elements) + "</html>"
def benchmark_parsers(num_elements=100000):
"""Very basic head-to-head performance benchmark."""
print("Comparative parser benchmark on Beautiful Soup %s" % __version__)
data = rdoc(num_elements)
print("Generated a large invalid HTML document (%d bytes)." % len(data))
for parser in ["lxml", ["lxml", "html"], "html5lib", "html.parser"]:
success = False
try:
a = time.time()
soup = BeautifulSoup(data, parser)
b = time.time()
success = True
except Exception as e:
print("%s could not parse the markup." % parser)
traceback.print_exc()
if success:
print("BS4+%s parsed the markup in %.2fs." % (parser, b-a))
from lxml import etree
a = time.time()
etree.HTML(data)
b = time.time()
print("Raw lxml parsed the markup in %.2fs." % (b-a))
import html5lib
parser = html5lib.HTMLParser()
a = time.time()
parser.parse(data)
b = time.time()
print("Raw html5lib parsed the markup in %.2fs." % (b-a))
def profile(num_elements=100000, parser="lxml"):
filehandle = tempfile.NamedTemporaryFile()
filename = filehandle.name
data = rdoc(num_elements)
vars = dict(bs4=bs4, data=data, parser=parser)
cProfile.runctx('bs4.BeautifulSoup(data, parser)' , vars, vars, filename)
stats = pstats.Stats(filename)
# stats.strip_dirs()
stats.sort_stats("cumulative")
stats.print_stats('_html5lib|bs4', 50)
if __name__ == '__main__':
diagnose(sys.stdin.read())

1579
libs/bs4/element.py Normal file

File diff suppressed because it is too large Load diff

99
libs/bs4/formatter.py Normal file
View file

@ -0,0 +1,99 @@
from bs4.dammit import EntitySubstitution
class Formatter(EntitySubstitution):
"""Describes a strategy to use when outputting a parse tree to a string.
Some parts of this strategy come from the distinction between
HTML4, HTML5, and XML. Others are configurable by the user.
"""
# Registries of XML and HTML formatters.
XML_FORMATTERS = {}
HTML_FORMATTERS = {}
HTML = 'html'
XML = 'xml'
HTML_DEFAULTS = dict(
cdata_containing_tags=set(["script", "style"]),
)
def _default(self, language, value, kwarg):
if value is not None:
return value
if language == self.XML:
return set()
return self.HTML_DEFAULTS[kwarg]
def __init__(
self, language=None, entity_substitution=None,
void_element_close_prefix='/', cdata_containing_tags=None,
):
"""
:param void_element_close_prefix: By default, represent void
elements as <tag/> rather than <tag>
"""
self.language = language
self.entity_substitution = entity_substitution
self.void_element_close_prefix = void_element_close_prefix
self.cdata_containing_tags = self._default(
language, cdata_containing_tags, 'cdata_containing_tags'
)
def substitute(self, ns):
"""Process a string that needs to undergo entity substitution."""
if not self.entity_substitution:
return ns
from .element import NavigableString
if (isinstance(ns, NavigableString)
and ns.parent is not None
and ns.parent.name in self.cdata_containing_tags):
# Do nothing.
return ns
# Substitute.
return self.entity_substitution(ns)
def attribute_value(self, value):
"""Process the value of an attribute."""
return self.substitute(value)
def attributes(self, tag):
"""Reorder a tag's attributes however you want."""
return sorted(tag.attrs.items())
class HTMLFormatter(Formatter):
REGISTRY = {}
def __init__(self, *args, **kwargs):
return super(HTMLFormatter, self).__init__(self.HTML, *args, **kwargs)
class XMLFormatter(Formatter):
REGISTRY = {}
def __init__(self, *args, **kwargs):
return super(XMLFormatter, self).__init__(self.XML, *args, **kwargs)
# Set up aliases for the default formatters.
HTMLFormatter.REGISTRY['html'] = HTMLFormatter(
entity_substitution=EntitySubstitution.substitute_html
)
HTMLFormatter.REGISTRY["html5"] = HTMLFormatter(
entity_substitution=EntitySubstitution.substitute_html,
void_element_close_prefix = None
)
HTMLFormatter.REGISTRY["minimal"] = HTMLFormatter(
entity_substitution=EntitySubstitution.substitute_xml
)
HTMLFormatter.REGISTRY[None] = HTMLFormatter(
entity_substitution=None
)
XMLFormatter.REGISTRY["html"] = XMLFormatter(
entity_substitution=EntitySubstitution.substitute_html
)
XMLFormatter.REGISTRY["minimal"] = XMLFormatter(
entity_substitution=EntitySubstitution.substitute_xml
)
XMLFormatter.REGISTRY[None] = Formatter(
Formatter(Formatter.XML, entity_substitution=None)
)

992
libs/bs4/testing.py Normal file
View file

@ -0,0 +1,992 @@
# encoding: utf-8
"""Helper classes for tests."""
# Use of this source code is governed by the MIT license.
__license__ = "MIT"
import pickle
import copy
import functools
import unittest
from unittest import TestCase
from bs4 import BeautifulSoup
from bs4.element import (
CharsetMetaAttributeValue,
Comment,
ContentMetaAttributeValue,
Doctype,
SoupStrainer,
Tag
)
from bs4.builder import HTMLParserTreeBuilder
default_builder = HTMLParserTreeBuilder
BAD_DOCUMENT = """A bare string
<!DOCTYPE xsl:stylesheet SYSTEM "htmlent.dtd">
<!DOCTYPE xsl:stylesheet PUBLIC "htmlent.dtd">
<div><![CDATA[A CDATA section where it doesn't belong]]></div>
<div><svg><![CDATA[HTML5 does allow CDATA sections in SVG]]></svg></div>
<div>A <meta> tag</div>
<div>A <br> tag that supposedly has contents.</br></div>
<div>AT&T</div>
<div><textarea>Within a textarea, markup like <b> tags and <&<&amp; should be treated as literal</textarea></div>
<div><script>if (i < 2) { alert("<b>Markup within script tags should be treated as literal.</b>"); }</script></div>
<div>This numeric entity is missing the final semicolon: <x t="pi&#241ata"></div>
<div><a href="http://example.com/</a> that attribute value never got closed</div>
<div><a href="foo</a>, </a><a href="bar">that attribute value was closed by the subsequent tag</a></div>
<! This document starts with a bogus declaration ><div>a</div>
<div>This document contains <!an incomplete declaration <div>(do you see it?)</div>
<div>This document ends with <!an incomplete declaration
<div><a style={height:21px;}>That attribute value was bogus</a></div>
<! DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">The doctype is invalid because it contains extra whitespace
<div><table><td nowrap>That boolean attribute had no value</td></table></div>
<div>Here's a nonexistent entity: &#foo; (do you see it?)</div>
<div>This document ends before the entity finishes: &gt
<div><p>Paragraphs shouldn't contain block display elements, but this one does: <dl><dt>you see?</dt></p>
<b b="20" a="1" b="10" a="2" a="3" a="4">Multiple values for the same attribute.</b>
<div><table><tr><td>Here's a table</td></tr></table></div>
<div><table id="1"><tr><td>Here's a nested table:<table id="2"><tr><td>foo</td></tr></table></td></div>
<div>This tag contains nothing but whitespace: <b> </b></div>
<div><blockquote><p><b>This p tag is cut off by</blockquote></p>the end of the blockquote tag</div>
<div><table><div>This table contains bare markup</div></table></div>
<div><div id="1">\n <a href="link1">This link is never closed.\n</div>\n<div id="2">\n <div id="3">\n <a href="link2">This link is closed.</a>\n </div>\n</div></div>
<div>This document contains a <!DOCTYPE surprise>surprise doctype</div>
<div><a><B><Cd><EFG>Mixed case tags are folded to lowercase</efg></CD></b></A></div>
<div><our\u2603>Tag name contains Unicode characters</our\u2603></div>
<div><a \u2603="snowman">Attribute name contains Unicode characters</a></div>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
"""
class SoupTest(unittest.TestCase):
@property
def default_builder(self):
return default_builder
def soup(self, markup, **kwargs):
"""Build a Beautiful Soup object from markup."""
builder = kwargs.pop('builder', self.default_builder)
return BeautifulSoup(markup, builder=builder, **kwargs)
def document_for(self, markup, **kwargs):
"""Turn an HTML fragment into a document.
The details depend on the builder.
"""
return self.default_builder(**kwargs).test_fragment_to_document(markup)
def assertSoupEquals(self, to_parse, compare_parsed_to=None):
builder = self.default_builder
obj = BeautifulSoup(to_parse, builder=builder)
if compare_parsed_to is None:
compare_parsed_to = to_parse
self.assertEqual(obj.decode(), self.document_for(compare_parsed_to))
def assertConnectedness(self, element):
"""Ensure that next_element and previous_element are properly
set for all descendants of the given element.
"""
earlier = None
for e in element.descendants:
if earlier:
self.assertEqual(e, earlier.next_element)
self.assertEqual(earlier, e.previous_element)
earlier = e
def linkage_validator(self, el, _recursive_call=False):
"""Ensure proper linkage throughout the document."""
descendant = None
# Document element should have no previous element or previous sibling.
# It also shouldn't have a next sibling.
if el.parent is None:
assert el.previous_element is None,\
"Bad previous_element\nNODE: {}\nPREV: {}\nEXPECTED: {}".format(
el, el.previous_element, None
)
assert el.previous_sibling is None,\
"Bad previous_sibling\nNODE: {}\nPREV: {}\nEXPECTED: {}".format(
el, el.previous_sibling, None
)
assert el.next_sibling is None,\
"Bad next_sibling\nNODE: {}\nNEXT: {}\nEXPECTED: {}".format(
el, el.next_sibling, None
)
idx = 0
child = None
last_child = None
last_idx = len(el.contents) - 1
for child in el.contents:
descendant = None
# Parent should link next element to their first child
# That child should have no previous sibling
if idx == 0:
if el.parent is not None:
assert el.next_element is child,\
"Bad next_element\nNODE: {}\nNEXT: {}\nEXPECTED: {}".format(
el, el.next_element, child
)
assert child.previous_element is el,\
"Bad previous_element\nNODE: {}\nPREV: {}\nEXPECTED: {}".format(
child, child.previous_element, el
)
assert child.previous_sibling is None,\
"Bad previous_sibling\nNODE: {}\nPREV {}\nEXPECTED: {}".format(
child, child.previous_sibling, None
)
# If not the first child, previous index should link as sibling to this index
# Previous element should match the last index or the last bubbled up descendant
else:
assert child.previous_sibling is el.contents[idx - 1],\
"Bad previous_sibling\nNODE: {}\nPREV {}\nEXPECTED {}".format(
child, child.previous_sibling, el.contents[idx - 1]
)
assert el.contents[idx - 1].next_sibling is child,\
"Bad next_sibling\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
el.contents[idx - 1], el.contents[idx - 1].next_sibling, child
)
if last_child is not None:
assert child.previous_element is last_child,\
"Bad previous_element\nNODE: {}\nPREV {}\nEXPECTED {}\nCONTENTS {}".format(
child, child.previous_element, last_child, child.parent.contents
)
assert last_child.next_element is child,\
"Bad next_element\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
last_child, last_child.next_element, child
)
if isinstance(child, Tag) and child.contents:
descendant = self.linkage_validator(child, True)
# A bubbled up descendant should have no next siblings
assert descendant.next_sibling is None,\
"Bad next_sibling\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
descendant, descendant.next_sibling, None
)
# Mark last child as either the bubbled up descendant or the current child
if descendant is not None:
last_child = descendant
else:
last_child = child
# If last child, there are non next siblings
if idx == last_idx:
assert child.next_sibling is None,\
"Bad next_sibling\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
child, child.next_sibling, None
)
idx += 1
child = descendant if descendant is not None else child
if child is None:
child = el
if not _recursive_call and child is not None:
target = el
while True:
if target is None:
assert child.next_element is None, \
"Bad next_element\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
child, child.next_element, None
)
break
elif target.next_sibling is not None:
assert child.next_element is target.next_sibling, \
"Bad next_element\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
child, child.next_element, target.next_sibling
)
break
target = target.parent
# We are done, so nothing to return
return None
else:
# Return the child to the recursive caller
return child
class HTMLTreeBuilderSmokeTest(object):
"""A basic test of a treebuilder's competence.
Any HTML treebuilder, present or future, should be able to pass
these tests. With invalid markup, there's room for interpretation,
and different parsers can handle it differently. But with the
markup in these tests, there's not much room for interpretation.
"""
def test_empty_element_tags(self):
"""Verify that all HTML4 and HTML5 empty element (aka void element) tags
are handled correctly.
"""
for name in [
'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'keygen', 'link', 'menuitem', 'meta', 'param', 'source', 'track', 'wbr',
'spacer', 'frame'
]:
soup = self.soup("")
new_tag = soup.new_tag(name)
self.assertEqual(True, new_tag.is_empty_element)
def test_pickle_and_unpickle_identity(self):
# Pickling a tree, then unpickling it, yields a tree identical
# to the original.
tree = self.soup("<a><b>foo</a>")
dumped = pickle.dumps(tree, 2)
loaded = pickle.loads(dumped)
self.assertEqual(loaded.__class__, BeautifulSoup)
self.assertEqual(loaded.decode(), tree.decode())
def assertDoctypeHandled(self, doctype_fragment):
"""Assert that a given doctype string is handled correctly."""
doctype_str, soup = self._document_with_doctype(doctype_fragment)
# Make sure a Doctype object was created.
doctype = soup.contents[0]
self.assertEqual(doctype.__class__, Doctype)
self.assertEqual(doctype, doctype_fragment)
self.assertEqual(str(soup)[:len(doctype_str)], doctype_str)
# Make sure that the doctype was correctly associated with the
# parse tree and that the rest of the document parsed.
self.assertEqual(soup.p.contents[0], 'foo')
def _document_with_doctype(self, doctype_fragment):
"""Generate and parse a document with the given doctype."""
doctype = '<!DOCTYPE %s>' % doctype_fragment
markup = doctype + '\n<p>foo</p>'
soup = self.soup(markup)
return doctype, soup
def test_normal_doctypes(self):
"""Make sure normal, everyday HTML doctypes are handled correctly."""
self.assertDoctypeHandled("html")
self.assertDoctypeHandled(
'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"')
def test_empty_doctype(self):
soup = self.soup("<!DOCTYPE>")
doctype = soup.contents[0]
self.assertEqual("", doctype.strip())
def test_public_doctype_with_url(self):
doctype = 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"'
self.assertDoctypeHandled(doctype)
def test_system_doctype(self):
self.assertDoctypeHandled('foo SYSTEM "http://www.example.com/"')
def test_namespaced_system_doctype(self):
# We can handle a namespaced doctype with a system ID.
self.assertDoctypeHandled('xsl:stylesheet SYSTEM "htmlent.dtd"')
def test_namespaced_public_doctype(self):
# Test a namespaced doctype with a public id.
self.assertDoctypeHandled('xsl:stylesheet PUBLIC "htmlent.dtd"')
def test_real_xhtml_document(self):
"""A real XHTML document should come out more or less the same as it went in."""
markup = b"""<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">
<html xmlns="http://www.w3.org/1999/xhtml">
<head><title>Hello.</title></head>
<body>Goodbye.</body>
</html>"""
soup = self.soup(markup)
self.assertEqual(
soup.encode("utf-8").replace(b"\n", b""),
markup.replace(b"\n", b""))
def test_namespaced_html(self):
"""When a namespaced XML document is parsed as HTML it should
be treated as HTML with weird tag names.
"""
markup = b"""<ns1:foo>content</ns1:foo><ns1:foo/><ns2:foo/>"""
soup = self.soup(markup)
self.assertEqual(2, len(soup.find_all("ns1:foo")))
def test_processing_instruction(self):
# We test both Unicode and bytestring to verify that
# process_markup correctly sets processing_instruction_class
# even when the markup is already Unicode and there is no
# need to process anything.
markup = """<?PITarget PIContent?>"""
soup = self.soup(markup)
self.assertEqual(markup, soup.decode())
markup = b"""<?PITarget PIContent?>"""
soup = self.soup(markup)
self.assertEqual(markup, soup.encode("utf8"))
def test_deepcopy(self):
"""Make sure you can copy the tree builder.
This is important because the builder is part of a
BeautifulSoup object, and we want to be able to copy that.
"""
copy.deepcopy(self.default_builder)
def test_p_tag_is_never_empty_element(self):
"""A <p> tag is never designated as an empty-element tag.
Even if the markup shows it as an empty-element tag, it
shouldn't be presented that way.
"""
soup = self.soup("<p/>")
self.assertFalse(soup.p.is_empty_element)
self.assertEqual(str(soup.p), "<p></p>")
def test_unclosed_tags_get_closed(self):
"""A tag that's not closed by the end of the document should be closed.
This applies to all tags except empty-element tags.
"""
self.assertSoupEquals("<p>", "<p></p>")
self.assertSoupEquals("<b>", "<b></b>")
self.assertSoupEquals("<br>", "<br/>")
def test_br_is_always_empty_element_tag(self):
"""A <br> tag is designated as an empty-element tag.
Some parsers treat <br></br> as one <br/> tag, some parsers as
two tags, but it should always be an empty-element tag.
"""
soup = self.soup("<br></br>")
self.assertTrue(soup.br.is_empty_element)
self.assertEqual(str(soup.br), "<br/>")
def test_nested_formatting_elements(self):
self.assertSoupEquals("<em><em></em></em>")
def test_double_head(self):
html = '''<!DOCTYPE html>
<html>
<head>
<title>Ordinary HEAD element test</title>
</head>
<script type="text/javascript">
alert("Help!");
</script>
<body>
Hello, world!
</body>
</html>
'''
soup = self.soup(html)
self.assertEqual("text/javascript", soup.find('script')['type'])
def test_comment(self):
# Comments are represented as Comment objects.
markup = "<p>foo<!--foobar-->baz</p>"
self.assertSoupEquals(markup)
soup = self.soup(markup)
comment = soup.find(text="foobar")
self.assertEqual(comment.__class__, Comment)
# The comment is properly integrated into the tree.
foo = soup.find(text="foo")
self.assertEqual(comment, foo.next_element)
baz = soup.find(text="baz")
self.assertEqual(comment, baz.previous_element)
def test_preserved_whitespace_in_pre_and_textarea(self):
"""Whitespace must be preserved in <pre> and <textarea> tags,
even if that would mean not prettifying the markup.
"""
pre_markup = "<pre> </pre>"
textarea_markup = "<textarea> woo\nwoo </textarea>"
self.assertSoupEquals(pre_markup)
self.assertSoupEquals(textarea_markup)
soup = self.soup(pre_markup)
self.assertEqual(soup.pre.prettify(), pre_markup)
soup = self.soup(textarea_markup)
self.assertEqual(soup.textarea.prettify(), textarea_markup)
soup = self.soup("<textarea></textarea>")
self.assertEqual(soup.textarea.prettify(), "<textarea></textarea>")
def test_nested_inline_elements(self):
"""Inline elements can be nested indefinitely."""
b_tag = "<b>Inside a B tag</b>"
self.assertSoupEquals(b_tag)
nested_b_tag = "<p>A <i>nested <b>tag</b></i></p>"
self.assertSoupEquals(nested_b_tag)
double_nested_b_tag = "<p>A <a>doubly <i>nested <b>tag</b></i></a></p>"
self.assertSoupEquals(nested_b_tag)
def test_nested_block_level_elements(self):
"""Block elements can be nested."""
soup = self.soup('<blockquote><p><b>Foo</b></p></blockquote>')
blockquote = soup.blockquote
self.assertEqual(blockquote.p.b.string, 'Foo')
self.assertEqual(blockquote.b.string, 'Foo')
def test_correctly_nested_tables(self):
"""One table can go inside another one."""
markup = ('<table id="1">'
'<tr>'
"<td>Here's another table:"
'<table id="2">'
'<tr><td>foo</td></tr>'
'</table></td>')
self.assertSoupEquals(
markup,
'<table id="1"><tr><td>Here\'s another table:'
'<table id="2"><tr><td>foo</td></tr></table>'
'</td></tr></table>')
self.assertSoupEquals(
"<table><thead><tr><td>Foo</td></tr></thead>"
"<tbody><tr><td>Bar</td></tr></tbody>"
"<tfoot><tr><td>Baz</td></tr></tfoot></table>")
def test_multivalued_attribute_with_whitespace(self):
# Whitespace separating the values of a multi-valued attribute
# should be ignored.
markup = '<div class=" foo bar "></a>'
soup = self.soup(markup)
self.assertEqual(['foo', 'bar'], soup.div['class'])
# If you search by the literal name of the class it's like the whitespace
# wasn't there.
self.assertEqual(soup.div, soup.find('div', class_="foo bar"))
def test_deeply_nested_multivalued_attribute(self):
# html5lib can set the attributes of the same tag many times
# as it rearranges the tree. This has caused problems with
# multivalued attributes.
markup = '<table><div><div class="css"></div></div></table>'
soup = self.soup(markup)
self.assertEqual(["css"], soup.div.div['class'])
def test_multivalued_attribute_on_html(self):
# html5lib uses a different API to set the attributes ot the
# <html> tag. This has caused problems with multivalued
# attributes.
markup = '<html class="a b"></html>'
soup = self.soup(markup)
self.assertEqual(["a", "b"], soup.html['class'])
def test_angle_brackets_in_attribute_values_are_escaped(self):
self.assertSoupEquals('<a b="<a>"></a>', '<a b="&lt;a&gt;"></a>')
def test_strings_resembling_character_entity_references(self):
# "&T" and "&p" look like incomplete character entities, but they are
# not.
self.assertSoupEquals(
"<p>&bull; AT&T is in the s&p 500</p>",
"<p>\u2022 AT&amp;T is in the s&amp;p 500</p>"
)
def test_apos_entity(self):
self.assertSoupEquals(
"<p>Bob&apos;s Bar</p>",
"<p>Bob's Bar</p>",
)
def test_entities_in_foreign_document_encoding(self):
# &#147; and &#148; are invalid numeric entities referencing
# Windows-1252 characters. &#45; references a character common
# to Windows-1252 and Unicode, and &#9731; references a
# character only found in Unicode.
#
# All of these entities should be converted to Unicode
# characters.
markup = "<p>&#147;Hello&#148; &#45;&#9731;</p>"
soup = self.soup(markup)
self.assertEqual("“Hello” -☃", soup.p.string)
def test_entities_in_attributes_converted_to_unicode(self):
expect = '<p id="pi\N{LATIN SMALL LETTER N WITH TILDE}ata"></p>'
self.assertSoupEquals('<p id="pi&#241;ata"></p>', expect)
self.assertSoupEquals('<p id="pi&#xf1;ata"></p>', expect)
self.assertSoupEquals('<p id="pi&#Xf1;ata"></p>', expect)
self.assertSoupEquals('<p id="pi&ntilde;ata"></p>', expect)
def test_entities_in_text_converted_to_unicode(self):
expect = '<p>pi\N{LATIN SMALL LETTER N WITH TILDE}ata</p>'
self.assertSoupEquals("<p>pi&#241;ata</p>", expect)
self.assertSoupEquals("<p>pi&#xf1;ata</p>", expect)
self.assertSoupEquals("<p>pi&#Xf1;ata</p>", expect)
self.assertSoupEquals("<p>pi&ntilde;ata</p>", expect)
def test_quot_entity_converted_to_quotation_mark(self):
self.assertSoupEquals("<p>I said &quot;good day!&quot;</p>",
'<p>I said "good day!"</p>')
def test_out_of_range_entity(self):
expect = "\N{REPLACEMENT CHARACTER}"
self.assertSoupEquals("&#10000000000000;", expect)
self.assertSoupEquals("&#x10000000000000;", expect)
self.assertSoupEquals("&#1000000000;", expect)
def test_multipart_strings(self):
"Mostly to prevent a recurrence of a bug in the html5lib treebuilder."
soup = self.soup("<html><h2>\nfoo</h2><p></p></html>")
self.assertEqual("p", soup.h2.string.next_element.name)
self.assertEqual("p", soup.p.name)
self.assertConnectedness(soup)
def test_empty_element_tags(self):
"""Verify consistent handling of empty-element tags,
no matter how they come in through the markup.
"""
self.assertSoupEquals('<br/><br/><br/>', "<br/><br/><br/>")
self.assertSoupEquals('<br /><br /><br />', "<br/><br/><br/>")
def test_head_tag_between_head_and_body(self):
"Prevent recurrence of a bug in the html5lib treebuilder."
content = """<html><head></head>
<link></link>
<body>foo</body>
</html>
"""
soup = self.soup(content)
self.assertNotEqual(None, soup.html.body)
self.assertConnectedness(soup)
def test_multiple_copies_of_a_tag(self):
"Prevent recurrence of a bug in the html5lib treebuilder."
content = """<!DOCTYPE html>
<html>
<body>
<article id="a" >
<div><a href="1"></div>
<footer>
<a href="2"></a>
</footer>
</article>
</body>
</html>
"""
soup = self.soup(content)
self.assertConnectedness(soup.article)
def test_basic_namespaces(self):
"""Parsers don't need to *understand* namespaces, but at the
very least they should not choke on namespaces or lose
data."""
markup = b'<html xmlns="http://www.w3.org/1999/xhtml" xmlns:mathml="http://www.w3.org/1998/Math/MathML" xmlns:svg="http://www.w3.org/2000/svg"><head></head><body><mathml:msqrt>4</mathml:msqrt><b svg:fill="red"></b></body></html>'
soup = self.soup(markup)
self.assertEqual(markup, soup.encode())
html = soup.html
self.assertEqual('http://www.w3.org/1999/xhtml', soup.html['xmlns'])
self.assertEqual(
'http://www.w3.org/1998/Math/MathML', soup.html['xmlns:mathml'])
self.assertEqual(
'http://www.w3.org/2000/svg', soup.html['xmlns:svg'])
def test_multivalued_attribute_value_becomes_list(self):
markup = b'<a class="foo bar">'
soup = self.soup(markup)
self.assertEqual(['foo', 'bar'], soup.a['class'])
#
# Generally speaking, tests below this point are more tests of
# Beautiful Soup than tests of the tree builders. But parsers are
# weird, so we run these tests separately for every tree builder
# to detect any differences between them.
#
def test_can_parse_unicode_document(self):
# A seemingly innocuous document... but it's in Unicode! And
# it contains characters that can't be represented in the
# encoding found in the declaration! The horror!
markup = '<html><head><meta encoding="euc-jp"></head><body>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</body>'
soup = self.soup(markup)
self.assertEqual('Sacr\xe9 bleu!', soup.body.string)
def test_soupstrainer(self):
"""Parsers should be able to work with SoupStrainers."""
strainer = SoupStrainer("b")
soup = self.soup("A <b>bold</b> <meta/> <i>statement</i>",
parse_only=strainer)
self.assertEqual(soup.decode(), "<b>bold</b>")
def test_single_quote_attribute_values_become_double_quotes(self):
self.assertSoupEquals("<foo attr='bar'></foo>",
'<foo attr="bar"></foo>')
def test_attribute_values_with_nested_quotes_are_left_alone(self):
text = """<foo attr='bar "brawls" happen'>a</foo>"""
self.assertSoupEquals(text)
def test_attribute_values_with_double_nested_quotes_get_quoted(self):
text = """<foo attr='bar "brawls" happen'>a</foo>"""
soup = self.soup(text)
soup.foo['attr'] = 'Brawls happen at "Bob\'s Bar"'
self.assertSoupEquals(
soup.foo.decode(),
"""<foo attr="Brawls happen at &quot;Bob\'s Bar&quot;">a</foo>""")
def test_ampersand_in_attribute_value_gets_escaped(self):
self.assertSoupEquals('<this is="really messed up & stuff"></this>',
'<this is="really messed up &amp; stuff"></this>')
self.assertSoupEquals(
'<a href="http://example.org?a=1&b=2;3">foo</a>',
'<a href="http://example.org?a=1&amp;b=2;3">foo</a>')
def test_escaped_ampersand_in_attribute_value_is_left_alone(self):
self.assertSoupEquals('<a href="http://example.org?a=1&amp;b=2;3"></a>')
def test_entities_in_strings_converted_during_parsing(self):
# Both XML and HTML entities are converted to Unicode characters
# during parsing.
text = "<p>&lt;&lt;sacr&eacute;&#32;bleu!&gt;&gt;</p>"
expected = "<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>"
self.assertSoupEquals(text, expected)
def test_smart_quotes_converted_on_the_way_in(self):
# Microsoft smart quotes are converted to Unicode characters during
# parsing.
quote = b"<p>\x91Foo\x92</p>"
soup = self.soup(quote)
self.assertEqual(
soup.p.string,
"\N{LEFT SINGLE QUOTATION MARK}Foo\N{RIGHT SINGLE QUOTATION MARK}")
def test_non_breaking_spaces_converted_on_the_way_in(self):
soup = self.soup("<a>&nbsp;&nbsp;</a>")
self.assertEqual(soup.a.string, "\N{NO-BREAK SPACE}" * 2)
def test_entities_converted_on_the_way_out(self):
text = "<p>&lt;&lt;sacr&eacute;&#32;bleu!&gt;&gt;</p>"
expected = "<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>".encode("utf-8")
soup = self.soup(text)
self.assertEqual(soup.p.encode("utf-8"), expected)
def test_real_iso_latin_document(self):
# Smoke test of interrelated functionality, using an
# easy-to-understand document.
# Here it is in Unicode. Note that it claims to be in ISO-Latin-1.
unicode_html = '<html><head><meta content="text/html; charset=ISO-Latin-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>'
# That's because we're going to encode it into ISO-Latin-1, and use
# that to test.
iso_latin_html = unicode_html.encode("iso-8859-1")
# Parse the ISO-Latin-1 HTML.
soup = self.soup(iso_latin_html)
# Encode it to UTF-8.
result = soup.encode("utf-8")
# What do we expect the result to look like? Well, it would
# look like unicode_html, except that the META tag would say
# UTF-8 instead of ISO-Latin-1.
expected = unicode_html.replace("ISO-Latin-1", "utf-8")
# And, of course, it would be in UTF-8, not Unicode.
expected = expected.encode("utf-8")
# Ta-da!
self.assertEqual(result, expected)
def test_real_shift_jis_document(self):
# Smoke test to make sure the parser can handle a document in
# Shift-JIS encoding, without choking.
shift_jis_html = (
b'<html><head></head><body><pre>'
b'\x82\xb1\x82\xea\x82\xcdShift-JIS\x82\xc5\x83R\x81[\x83f'
b'\x83B\x83\x93\x83O\x82\xb3\x82\xea\x82\xbd\x93\xfa\x96{\x8c'
b'\xea\x82\xcc\x83t\x83@\x83C\x83\x8b\x82\xc5\x82\xb7\x81B'
b'</pre></body></html>')
unicode_html = shift_jis_html.decode("shift-jis")
soup = self.soup(unicode_html)
# Make sure the parse tree is correctly encoded to various
# encodings.
self.assertEqual(soup.encode("utf-8"), unicode_html.encode("utf-8"))
self.assertEqual(soup.encode("euc_jp"), unicode_html.encode("euc_jp"))
def test_real_hebrew_document(self):
# A real-world test to make sure we can convert ISO-8859-9 (a
# Hebrew encoding) to UTF-8.
hebrew_document = b'<html><head><title>Hebrew (ISO 8859-8) in Visual Directionality</title></head><body><h1>Hebrew (ISO 8859-8) in Visual Directionality</h1>\xed\xe5\xec\xf9</body></html>'
soup = self.soup(
hebrew_document, from_encoding="iso8859-8")
# Some tree builders call it iso8859-8, others call it iso-8859-9.
# That's not a difference we really care about.
assert soup.original_encoding in ('iso8859-8', 'iso-8859-8')
self.assertEqual(
soup.encode('utf-8'),
hebrew_document.decode("iso8859-8").encode("utf-8"))
def test_meta_tag_reflects_current_encoding(self):
# Here's the <meta> tag saying that a document is
# encoded in Shift-JIS.
meta_tag = ('<meta content="text/html; charset=x-sjis" '
'http-equiv="Content-type"/>')
# Here's a document incorporating that meta tag.
shift_jis_html = (
'<html><head>\n%s\n'
'<meta http-equiv="Content-language" content="ja"/>'
'</head><body>Shift-JIS markup goes here.') % meta_tag
soup = self.soup(shift_jis_html)
# Parse the document, and the charset is seemingly unaffected.
parsed_meta = soup.find('meta', {'http-equiv': 'Content-type'})
content = parsed_meta['content']
self.assertEqual('text/html; charset=x-sjis', content)
# But that value is actually a ContentMetaAttributeValue object.
self.assertTrue(isinstance(content, ContentMetaAttributeValue))
# And it will take on a value that reflects its current
# encoding.
self.assertEqual('text/html; charset=utf8', content.encode("utf8"))
# For the rest of the story, see TestSubstitutions in
# test_tree.py.
def test_html5_style_meta_tag_reflects_current_encoding(self):
# Here's the <meta> tag saying that a document is
# encoded in Shift-JIS.
meta_tag = ('<meta id="encoding" charset="x-sjis" />')
# Here's a document incorporating that meta tag.
shift_jis_html = (
'<html><head>\n%s\n'
'<meta http-equiv="Content-language" content="ja"/>'
'</head><body>Shift-JIS markup goes here.') % meta_tag
soup = self.soup(shift_jis_html)
# Parse the document, and the charset is seemingly unaffected.
parsed_meta = soup.find('meta', id="encoding")
charset = parsed_meta['charset']
self.assertEqual('x-sjis', charset)
# But that value is actually a CharsetMetaAttributeValue object.
self.assertTrue(isinstance(charset, CharsetMetaAttributeValue))
# And it will take on a value that reflects its current
# encoding.
self.assertEqual('utf8', charset.encode("utf8"))
def test_tag_with_no_attributes_can_have_attributes_added(self):
data = self.soup("<a>text</a>")
data.a['foo'] = 'bar'
self.assertEqual('<a foo="bar">text</a>', data.a.decode())
def test_worst_case(self):
"""Test the worst case (currently) for linking issues."""
soup = self.soup(BAD_DOCUMENT)
self.linkage_validator(soup)
class XMLTreeBuilderSmokeTest(object):
def test_pickle_and_unpickle_identity(self):
# Pickling a tree, then unpickling it, yields a tree identical
# to the original.
tree = self.soup("<a><b>foo</a>")
dumped = pickle.dumps(tree, 2)
loaded = pickle.loads(dumped)
self.assertEqual(loaded.__class__, BeautifulSoup)
self.assertEqual(loaded.decode(), tree.decode())
def test_docstring_generated(self):
soup = self.soup("<root/>")
self.assertEqual(
soup.encode(), b'<?xml version="1.0" encoding="utf-8"?>\n<root/>')
def test_xml_declaration(self):
markup = b"""<?xml version="1.0" encoding="utf8"?>\n<foo/>"""
soup = self.soup(markup)
self.assertEqual(markup, soup.encode("utf8"))
def test_processing_instruction(self):
markup = b"""<?xml version="1.0" encoding="utf8"?>\n<?PITarget PIContent?>"""
soup = self.soup(markup)
self.assertEqual(markup, soup.encode("utf8"))
def test_real_xhtml_document(self):
"""A real XHTML document should come out *exactly* the same as it went in."""
markup = b"""<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">
<html xmlns="http://www.w3.org/1999/xhtml">
<head><title>Hello.</title></head>
<body>Goodbye.</body>
</html>"""
soup = self.soup(markup)
self.assertEqual(
soup.encode("utf-8"), markup)
def test_nested_namespaces(self):
doc = b"""<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd">
<parent xmlns="http://ns1/">
<child xmlns="http://ns2/" xmlns:ns3="http://ns3/">
<grandchild ns3:attr="value" xmlns="http://ns4/"/>
</child>
</parent>"""
soup = self.soup(doc)
self.assertEqual(doc, soup.encode())
def test_formatter_processes_script_tag_for_xml_documents(self):
doc = """
<script type="text/javascript">
</script>
"""
soup = BeautifulSoup(doc, "lxml-xml")
# lxml would have stripped this while parsing, but we can add
# it later.
soup.script.string = 'console.log("< < hey > > ");'
encoded = soup.encode()
self.assertTrue(b"&lt; &lt; hey &gt; &gt;" in encoded)
def test_can_parse_unicode_document(self):
markup = '<?xml version="1.0" encoding="euc-jp"><root>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</root>'
soup = self.soup(markup)
self.assertEqual('Sacr\xe9 bleu!', soup.root.string)
def test_popping_namespaced_tag(self):
markup = '<rss xmlns:dc="foo"><dc:creator>b</dc:creator><dc:date>2012-07-02T20:33:42Z</dc:date><dc:rights>c</dc:rights><image>d</image></rss>'
soup = self.soup(markup)
self.assertEqual(
str(soup.rss), markup)
def test_docstring_includes_correct_encoding(self):
soup = self.soup("<root/>")
self.assertEqual(
soup.encode("latin1"),
b'<?xml version="1.0" encoding="latin1"?>\n<root/>')
def test_large_xml_document(self):
"""A large XML document should come out the same as it went in."""
markup = (b'<?xml version="1.0" encoding="utf-8"?>\n<root>'
+ b'0' * (2**12)
+ b'</root>')
soup = self.soup(markup)
self.assertEqual(soup.encode("utf-8"), markup)
def test_tags_are_empty_element_if_and_only_if_they_are_empty(self):
self.assertSoupEquals("<p>", "<p/>")
self.assertSoupEquals("<p>foo</p>")
def test_namespaces_are_preserved(self):
markup = '<root xmlns:a="http://example.com/" xmlns:b="http://example.net/"><a:foo>This tag is in the a namespace</a:foo><b:foo>This tag is in the b namespace</b:foo></root>'
soup = self.soup(markup)
root = soup.root
self.assertEqual("http://example.com/", root['xmlns:a'])
self.assertEqual("http://example.net/", root['xmlns:b'])
def test_closing_namespaced_tag(self):
markup = '<p xmlns:dc="http://purl.org/dc/elements/1.1/"><dc:date>20010504</dc:date></p>'
soup = self.soup(markup)
self.assertEqual(str(soup.p), markup)
def test_namespaced_attributes(self):
markup = '<foo xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"><bar xsi:schemaLocation="http://www.example.com"/></foo>'
soup = self.soup(markup)
self.assertEqual(str(soup.foo), markup)
def test_namespaced_attributes_xml_namespace(self):
markup = '<foo xml:lang="fr">bar</foo>'
soup = self.soup(markup)
self.assertEqual(str(soup.foo), markup)
def test_find_by_prefixed_name(self):
doc = """<?xml version="1.0" encoding="utf-8"?>
<Document xmlns="http://example.com/ns0"
xmlns:ns1="http://example.com/ns1"
xmlns:ns2="http://example.com/ns2"
<ns1:tag>foo</ns1:tag>
<ns1:tag>bar</ns1:tag>
<ns2:tag key="value">baz</ns2:tag>
</Document>
"""
soup = self.soup(doc)
# There are three <tag> tags.
self.assertEqual(3, len(soup.find_all('tag')))
# But two of them are ns1:tag and one of them is ns2:tag.
self.assertEqual(2, len(soup.find_all('ns1:tag')))
self.assertEqual(1, len(soup.find_all('ns2:tag')))
self.assertEqual(1, len(soup.find_all('ns2:tag', key='value')))
self.assertEqual(3, len(soup.find_all(['ns1:tag', 'ns2:tag'])))
def test_copy_tag_preserves_namespace(self):
xml = """<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<w:document xmlns:w="http://example.com/ns0"/>"""
soup = self.soup(xml)
tag = soup.document
duplicate = copy.copy(tag)
# The two tags have the same namespace prefix.
self.assertEqual(tag.prefix, duplicate.prefix)
def test_worst_case(self):
"""Test the worst case (currently) for linking issues."""
soup = self.soup(BAD_DOCUMENT)
self.linkage_validator(soup)
class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest):
"""Smoke test for a tree builder that supports HTML5."""
def test_real_xhtml_document(self):
# Since XHTML is not HTML5, HTML5 parsers are not tested to handle
# XHTML documents in any particular way.
pass
def test_html_tags_have_namespace(self):
markup = "<a>"
soup = self.soup(markup)
self.assertEqual("http://www.w3.org/1999/xhtml", soup.a.namespace)
def test_svg_tags_have_namespace(self):
markup = '<svg><circle/></svg>'
soup = self.soup(markup)
namespace = "http://www.w3.org/2000/svg"
self.assertEqual(namespace, soup.svg.namespace)
self.assertEqual(namespace, soup.circle.namespace)
def test_mathml_tags_have_namespace(self):
markup = '<math><msqrt>5</msqrt></math>'
soup = self.soup(markup)
namespace = 'http://www.w3.org/1998/Math/MathML'
self.assertEqual(namespace, soup.math.namespace)
self.assertEqual(namespace, soup.msqrt.namespace)
def test_xml_declaration_becomes_comment(self):
markup = '<?xml version="1.0" encoding="utf-8"?><html></html>'
soup = self.soup(markup)
self.assertTrue(isinstance(soup.contents[0], Comment))
self.assertEqual(soup.contents[0], '?xml version="1.0" encoding="utf-8"?')
self.assertEqual("html", soup.contents[0].next_element.name)
def skipIf(condition, reason):
def nothing(test, *args, **kwargs):
return None
def decorator(test_item):
if condition:
return nothing
else:
return test_item
return decorator

View file

@ -0,0 +1 @@
"The beautifulsoup tests."

View file

@ -0,0 +1,147 @@
"""Tests of the builder registry."""
import unittest
import warnings
from bs4 import BeautifulSoup
from bs4.builder import (
builder_registry as registry,
HTMLParserTreeBuilder,
TreeBuilderRegistry,
)
try:
from bs4.builder import HTML5TreeBuilder
HTML5LIB_PRESENT = True
except ImportError:
HTML5LIB_PRESENT = False
try:
from bs4.builder import (
LXMLTreeBuilderForXML,
LXMLTreeBuilder,
)
LXML_PRESENT = True
except ImportError:
LXML_PRESENT = False
class BuiltInRegistryTest(unittest.TestCase):
"""Test the built-in registry with the default builders registered."""
def test_combination(self):
if LXML_PRESENT:
self.assertEqual(registry.lookup('fast', 'html'),
LXMLTreeBuilder)
if LXML_PRESENT:
self.assertEqual(registry.lookup('permissive', 'xml'),
LXMLTreeBuilderForXML)
self.assertEqual(registry.lookup('strict', 'html'),
HTMLParserTreeBuilder)
if HTML5LIB_PRESENT:
self.assertEqual(registry.lookup('html5lib', 'html'),
HTML5TreeBuilder)
def test_lookup_by_markup_type(self):
if LXML_PRESENT:
self.assertEqual(registry.lookup('html'), LXMLTreeBuilder)
self.assertEqual(registry.lookup('xml'), LXMLTreeBuilderForXML)
else:
self.assertEqual(registry.lookup('xml'), None)
if HTML5LIB_PRESENT:
self.assertEqual(registry.lookup('html'), HTML5TreeBuilder)
else:
self.assertEqual(registry.lookup('html'), HTMLParserTreeBuilder)
def test_named_library(self):
if LXML_PRESENT:
self.assertEqual(registry.lookup('lxml', 'xml'),
LXMLTreeBuilderForXML)
self.assertEqual(registry.lookup('lxml', 'html'),
LXMLTreeBuilder)
if HTML5LIB_PRESENT:
self.assertEqual(registry.lookup('html5lib'),
HTML5TreeBuilder)
self.assertEqual(registry.lookup('html.parser'),
HTMLParserTreeBuilder)
def test_beautifulsoup_constructor_does_lookup(self):
with warnings.catch_warnings(record=True) as w:
# This will create a warning about not explicitly
# specifying a parser, but we'll ignore it.
# You can pass in a string.
BeautifulSoup("", features="html")
# Or a list of strings.
BeautifulSoup("", features=["html", "fast"])
# You'll get an exception if BS can't find an appropriate
# builder.
self.assertRaises(ValueError, BeautifulSoup,
"", features="no-such-feature")
class RegistryTest(unittest.TestCase):
"""Test the TreeBuilderRegistry class in general."""
def setUp(self):
self.registry = TreeBuilderRegistry()
def builder_for_features(self, *feature_list):
cls = type('Builder_' + '_'.join(feature_list),
(object,), {'features' : feature_list})
self.registry.register(cls)
return cls
def test_register_with_no_features(self):
builder = self.builder_for_features()
# Since the builder advertises no features, you can't find it
# by looking up features.
self.assertEqual(self.registry.lookup('foo'), None)
# But you can find it by doing a lookup with no features, if
# this happens to be the only registered builder.
self.assertEqual(self.registry.lookup(), builder)
def test_register_with_features_makes_lookup_succeed(self):
builder = self.builder_for_features('foo', 'bar')
self.assertEqual(self.registry.lookup('foo'), builder)
self.assertEqual(self.registry.lookup('bar'), builder)
def test_lookup_fails_when_no_builder_implements_feature(self):
builder = self.builder_for_features('foo', 'bar')
self.assertEqual(self.registry.lookup('baz'), None)
def test_lookup_gets_most_recent_registration_when_no_feature_specified(self):
builder1 = self.builder_for_features('foo')
builder2 = self.builder_for_features('bar')
self.assertEqual(self.registry.lookup(), builder2)
def test_lookup_fails_when_no_tree_builders_registered(self):
self.assertEqual(self.registry.lookup(), None)
def test_lookup_gets_most_recent_builder_supporting_all_features(self):
has_one = self.builder_for_features('foo')
has_the_other = self.builder_for_features('bar')
has_both_early = self.builder_for_features('foo', 'bar', 'baz')
has_both_late = self.builder_for_features('foo', 'bar', 'quux')
lacks_one = self.builder_for_features('bar')
has_the_other = self.builder_for_features('foo')
# There are two builders featuring 'foo' and 'bar', but
# the one that also features 'quux' was registered later.
self.assertEqual(self.registry.lookup('foo', 'bar'),
has_both_late)
# There is only one builder featuring 'foo', 'bar', and 'baz'.
self.assertEqual(self.registry.lookup('foo', 'bar', 'baz'),
has_both_early)
def test_lookup_fails_when_cannot_reconcile_requested_features(self):
builder1 = self.builder_for_features('foo', 'bar')
builder2 = self.builder_for_features('foo', 'baz')
self.assertEqual(self.registry.lookup('bar', 'baz'), None)

View file

@ -0,0 +1,36 @@
"Test harness for doctests."
# pylint: disable-msg=E0611,W0142
__metaclass__ = type
__all__ = [
'additional_tests',
]
import atexit
import doctest
import os
#from pkg_resources import (
# resource_filename, resource_exists, resource_listdir, cleanup_resources)
import unittest
DOCTEST_FLAGS = (
doctest.ELLIPSIS |
doctest.NORMALIZE_WHITESPACE |
doctest.REPORT_NDIFF)
# def additional_tests():
# "Run the doc tests (README.txt and docs/*, if any exist)"
# doctest_files = [
# os.path.abspath(resource_filename('bs4', 'README.txt'))]
# if resource_exists('bs4', 'docs'):
# for name in resource_listdir('bs4', 'docs'):
# if name.endswith('.txt'):
# doctest_files.append(
# os.path.abspath(
# resource_filename('bs4', 'docs/%s' % name)))
# kwargs = dict(module_relative=False, optionflags=DOCTEST_FLAGS)
# atexit.register(cleanup_resources)
# return unittest.TestSuite((
# doctest.DocFileSuite(*doctest_files, **kwargs)))

View file

@ -0,0 +1,170 @@
"""Tests to ensure that the html5lib tree builder generates good trees."""
import warnings
try:
from bs4.builder import HTML5TreeBuilder
HTML5LIB_PRESENT = True
except ImportError as e:
HTML5LIB_PRESENT = False
from bs4.element import SoupStrainer
from bs4.testing import (
HTML5TreeBuilderSmokeTest,
SoupTest,
skipIf,
)
@skipIf(
not HTML5LIB_PRESENT,
"html5lib seems not to be present, not testing its tree builder.")
class HTML5LibBuilderSmokeTest(SoupTest, HTML5TreeBuilderSmokeTest):
"""See ``HTML5TreeBuilderSmokeTest``."""
@property
def default_builder(self):
return HTML5TreeBuilder
def test_soupstrainer(self):
# The html5lib tree builder does not support SoupStrainers.
strainer = SoupStrainer("b")
markup = "<p>A <b>bold</b> statement.</p>"
with warnings.catch_warnings(record=True) as w:
soup = self.soup(markup, parse_only=strainer)
self.assertEqual(
soup.decode(), self.document_for(markup))
self.assertTrue(
"the html5lib tree builder doesn't support parse_only" in
str(w[0].message))
def test_correctly_nested_tables(self):
"""html5lib inserts <tbody> tags where other parsers don't."""
markup = ('<table id="1">'
'<tr>'
"<td>Here's another table:"
'<table id="2">'
'<tr><td>foo</td></tr>'
'</table></td>')
self.assertSoupEquals(
markup,
'<table id="1"><tbody><tr><td>Here\'s another table:'
'<table id="2"><tbody><tr><td>foo</td></tr></tbody></table>'
'</td></tr></tbody></table>')
self.assertSoupEquals(
"<table><thead><tr><td>Foo</td></tr></thead>"
"<tbody><tr><td>Bar</td></tr></tbody>"
"<tfoot><tr><td>Baz</td></tr></tfoot></table>")
def test_xml_declaration_followed_by_doctype(self):
markup = '''<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html>
<html>
<head>
</head>
<body>
<p>foo</p>
</body>
</html>'''
soup = self.soup(markup)
# Verify that we can reach the <p> tag; this means the tree is connected.
self.assertEqual(b"<p>foo</p>", soup.p.encode())
def test_reparented_markup(self):
markup = '<p><em>foo</p>\n<p>bar<a></a></em></p>'
soup = self.soup(markup)
self.assertEqual("<body><p><em>foo</em></p><em>\n</em><p><em>bar<a></a></em></p></body>", soup.body.decode())
self.assertEqual(2, len(soup.find_all('p')))
def test_reparented_markup_ends_with_whitespace(self):
markup = '<p><em>foo</p>\n<p>bar<a></a></em></p>\n'
soup = self.soup(markup)
self.assertEqual("<body><p><em>foo</em></p><em>\n</em><p><em>bar<a></a></em></p>\n</body>", soup.body.decode())
self.assertEqual(2, len(soup.find_all('p')))
def test_reparented_markup_containing_identical_whitespace_nodes(self):
"""Verify that we keep the two whitespace nodes in this
document distinct when reparenting the adjacent <tbody> tags.
"""
markup = '<table> <tbody><tbody><ims></tbody> </table>'
soup = self.soup(markup)
space1, space2 = soup.find_all(string=' ')
tbody1, tbody2 = soup.find_all('tbody')
assert space1.next_element is tbody1
assert tbody2.next_element is space2
def test_reparented_markup_containing_children(self):
markup = '<div><a>aftermath<p><noscript>target</noscript>aftermath</a></p></div>'
soup = self.soup(markup)
noscript = soup.noscript
self.assertEqual("target", noscript.next_element)
target = soup.find(string='target')
# The 'aftermath' string was duplicated; we want the second one.
final_aftermath = soup.find_all(string='aftermath')[-1]
# The <noscript> tag was moved beneath a copy of the <a> tag,
# but the 'target' string within is still connected to the
# (second) 'aftermath' string.
self.assertEqual(final_aftermath, target.next_element)
self.assertEqual(target, final_aftermath.previous_element)
def test_processing_instruction(self):
"""Processing instructions become comments."""
markup = b"""<?PITarget PIContent?>"""
soup = self.soup(markup)
assert str(soup).startswith("<!--?PITarget PIContent?-->")
def test_cloned_multivalue_node(self):
markup = b"""<a class="my_class"><p></a>"""
soup = self.soup(markup)
a1, a2 = soup.find_all('a')
self.assertEqual(a1, a2)
assert a1 is not a2
def test_foster_parenting(self):
markup = b"""<table><td></tbody>A"""
soup = self.soup(markup)
self.assertEqual("<body>A<table><tbody><tr><td></td></tr></tbody></table></body>", soup.body.decode())
def test_extraction(self):
"""
Test that extraction does not destroy the tree.
https://bugs.launchpad.net/beautifulsoup/+bug/1782928
"""
markup = """
<html><head></head>
<style>
</style><script></script><body><p>hello</p></body></html>
"""
soup = self.soup(markup)
[s.extract() for s in soup('script')]
[s.extract() for s in soup('style')]
self.assertEqual(len(soup.find_all("p")), 1)
def test_empty_comment(self):
"""
Test that empty comment does not break structure.
https://bugs.launchpad.net/beautifulsoup/+bug/1806598
"""
markup = """
<html>
<body>
<form>
<!----><input type="text">
</form>
</body>
</html>
"""
soup = self.soup(markup)
inputs = []
for form in soup.find_all('form'):
inputs.extend(form.find_all('input'))
self.assertEqual(len(inputs), 1)

View file

@ -0,0 +1,47 @@
"""Tests to ensure that the html.parser tree builder generates good
trees."""
from pdb import set_trace
import pickle
from bs4.testing import SoupTest, HTMLTreeBuilderSmokeTest
from bs4.builder import HTMLParserTreeBuilder
from bs4.builder._htmlparser import BeautifulSoupHTMLParser
class HTMLParserTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
default_builder = HTMLParserTreeBuilder
def test_namespaced_system_doctype(self):
# html.parser can't handle namespaced doctypes, so skip this one.
pass
def test_namespaced_public_doctype(self):
# html.parser can't handle namespaced doctypes, so skip this one.
pass
def test_builder_is_pickled(self):
"""Unlike most tree builders, HTMLParserTreeBuilder and will
be restored after pickling.
"""
tree = self.soup("<a><b>foo</a>")
dumped = pickle.dumps(tree, 2)
loaded = pickle.loads(dumped)
self.assertTrue(isinstance(loaded.builder, type(tree.builder)))
def test_redundant_empty_element_closing_tags(self):
self.assertSoupEquals('<br></br><br></br><br></br>', "<br/><br/><br/>")
self.assertSoupEquals('</br></br></br>', "")
def test_empty_element(self):
# This verifies that any buffered data present when the parser
# finishes working is handled.
self.assertSoupEquals("foo &# bar", "foo &amp;# bar")
class TestHTMLParserSubclass(SoupTest):
def test_error(self):
"""Verify that our HTMLParser subclass implements error() in a way
that doesn't cause a crash.
"""
parser = BeautifulSoupHTMLParser()
parser.error("don't crash")

100
libs/bs4/tests/test_lxml.py Normal file
View file

@ -0,0 +1,100 @@
"""Tests to ensure that the lxml tree builder generates good trees."""
import re
import warnings
try:
import lxml.etree
LXML_PRESENT = True
LXML_VERSION = lxml.etree.LXML_VERSION
except ImportError as e:
LXML_PRESENT = False
LXML_VERSION = (0,)
if LXML_PRESENT:
from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML
from bs4 import (
BeautifulSoup,
BeautifulStoneSoup,
)
from bs4.element import Comment, Doctype, SoupStrainer
from bs4.testing import skipIf
from bs4.tests import test_htmlparser
from bs4.testing import (
HTMLTreeBuilderSmokeTest,
XMLTreeBuilderSmokeTest,
SoupTest,
skipIf,
)
@skipIf(
not LXML_PRESENT,
"lxml seems not to be present, not testing its tree builder.")
class LXMLTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
"""See ``HTMLTreeBuilderSmokeTest``."""
@property
def default_builder(self):
return LXMLTreeBuilder
def test_out_of_range_entity(self):
self.assertSoupEquals(
"<p>foo&#10000000000000;bar</p>", "<p>foobar</p>")
self.assertSoupEquals(
"<p>foo&#x10000000000000;bar</p>", "<p>foobar</p>")
self.assertSoupEquals(
"<p>foo&#1000000000;bar</p>", "<p>foobar</p>")
def test_entities_in_foreign_document_encoding(self):
# We can't implement this case correctly because by the time we
# hear about markup like "&#147;", it's been (incorrectly) converted into
# a string like u'\x93'
pass
# In lxml < 2.3.5, an empty doctype causes a segfault. Skip this
# test if an old version of lxml is installed.
@skipIf(
not LXML_PRESENT or LXML_VERSION < (2,3,5,0),
"Skipping doctype test for old version of lxml to avoid segfault.")
def test_empty_doctype(self):
soup = self.soup("<!DOCTYPE>")
doctype = soup.contents[0]
self.assertEqual("", doctype.strip())
def test_beautifulstonesoup_is_xml_parser(self):
# Make sure that the deprecated BSS class uses an xml builder
# if one is installed.
with warnings.catch_warnings(record=True) as w:
soup = BeautifulStoneSoup("<b />")
self.assertEqual("<b/>", str(soup.b))
self.assertTrue("BeautifulStoneSoup class is deprecated" in str(w[0].message))
@skipIf(
not LXML_PRESENT,
"lxml seems not to be present, not testing its XML tree builder.")
class LXMLXMLTreeBuilderSmokeTest(SoupTest, XMLTreeBuilderSmokeTest):
"""See ``HTMLTreeBuilderSmokeTest``."""
@property
def default_builder(self):
return LXMLTreeBuilderForXML
def test_namespace_indexing(self):
# We should not track un-prefixed namespaces as we can only hold one
# and it will be recognized as the default namespace by soupsieve,
# which may be confusing in some situations. When no namespace is provided
# for a selector, the default namespace (if defined) is assumed.
soup = self.soup(
'<?xml version="1.1"?>\n'
'<root>'
'<tag xmlns="http://unprefixed-namespace.com">content</tag>'
'<prefix:tag xmlns:prefix="http://prefixed-namespace.com">content</tag>'
'</root>'
)
self.assertEqual(
soup._namespaces,
{'xml': 'http://www.w3.org/XML/1998/namespace', 'prefix': 'http://prefixed-namespace.com'}
)

567
libs/bs4/tests/test_soup.py Normal file
View file

@ -0,0 +1,567 @@
# -*- coding: utf-8 -*-
"""Tests of Beautiful Soup as a whole."""
from pdb import set_trace
import logging
import unittest
import sys
import tempfile
from bs4 import (
BeautifulSoup,
BeautifulStoneSoup,
)
from bs4.element import (
CharsetMetaAttributeValue,
ContentMetaAttributeValue,
SoupStrainer,
NamespacedAttribute,
)
import bs4.dammit
from bs4.dammit import (
EntitySubstitution,
UnicodeDammit,
EncodingDetector,
)
from bs4.testing import (
default_builder,
SoupTest,
skipIf,
)
import warnings
try:
from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML
LXML_PRESENT = True
except ImportError as e:
LXML_PRESENT = False
PYTHON_3_PRE_3_2 = (sys.version_info[0] == 3 and sys.version_info < (3,2))
class TestConstructor(SoupTest):
def test_short_unicode_input(self):
data = "<h1>éé</h1>"
soup = self.soup(data)
self.assertEqual("éé", soup.h1.string)
def test_embedded_null(self):
data = "<h1>foo\0bar</h1>"
soup = self.soup(data)
self.assertEqual("foo\0bar", soup.h1.string)
def test_exclude_encodings(self):
utf8_data = "Räksmörgås".encode("utf-8")
soup = self.soup(utf8_data, exclude_encodings=["utf-8"])
self.assertEqual("windows-1252", soup.original_encoding)
def test_custom_builder_class(self):
# Verify that you can pass in a custom Builder class and
# it'll be instantiated with the appropriate keyword arguments.
class Mock(object):
def __init__(self, **kwargs):
self.called_with = kwargs
self.is_xml = True
def initialize_soup(self, soup):
pass
def prepare_markup(self, *args, **kwargs):
return ''
kwargs = dict(
var="value",
# This is a deprecated BS3-era keyword argument, which
# will be stripped out.
convertEntities=True,
)
with warnings.catch_warnings(record=True):
soup = BeautifulSoup('', builder=Mock, **kwargs)
assert isinstance(soup.builder, Mock)
self.assertEqual(dict(var="value"), soup.builder.called_with)
# You can also instantiate the TreeBuilder yourself. In this
# case, that specific object is used and any keyword arguments
# to the BeautifulSoup constructor are ignored.
builder = Mock(**kwargs)
with warnings.catch_warnings(record=True) as w:
soup = BeautifulSoup(
'', builder=builder, ignored_value=True,
)
msg = str(w[0].message)
assert msg.startswith("Keyword arguments to the BeautifulSoup constructor will be ignored.")
self.assertEqual(builder, soup.builder)
self.assertEqual(kwargs, builder.called_with)
def test_cdata_list_attributes(self):
# Most attribute values are represented as scalars, but the
# HTML standard says that some attributes, like 'class' have
# space-separated lists as values.
markup = '<a id=" an id " class=" a class "></a>'
soup = self.soup(markup)
# Note that the spaces are stripped for 'class' but not for 'id'.
a = soup.a
self.assertEqual(" an id ", a['id'])
self.assertEqual(["a", "class"], a['class'])
# TreeBuilder takes an argument called 'mutli_valued_attributes' which lets
# you customize or disable this. As always, you can customize the TreeBuilder
# by passing in a keyword argument to the BeautifulSoup constructor.
soup = self.soup(markup, builder=default_builder, multi_valued_attributes=None)
self.assertEqual(" a class ", soup.a['class'])
# Here are two ways of saying that `id` is a multi-valued
# attribute in this context, but 'class' is not.
for switcheroo in ({'*': 'id'}, {'a': 'id'}):
with warnings.catch_warnings(record=True) as w:
# This will create a warning about not explicitly
# specifying a parser, but we'll ignore it.
soup = self.soup(markup, builder=None, multi_valued_attributes=switcheroo)
a = soup.a
self.assertEqual(["an", "id"], a['id'])
self.assertEqual(" a class ", a['class'])
class TestWarnings(SoupTest):
def _no_parser_specified(self, s, is_there=True):
v = s.startswith(BeautifulSoup.NO_PARSER_SPECIFIED_WARNING[:80])
self.assertTrue(v)
def test_warning_if_no_parser_specified(self):
with warnings.catch_warnings(record=True) as w:
soup = self.soup("<a><b></b></a>")
msg = str(w[0].message)
self._assert_no_parser_specified(msg)
def test_warning_if_parser_specified_too_vague(self):
with warnings.catch_warnings(record=True) as w:
soup = self.soup("<a><b></b></a>", "html")
msg = str(w[0].message)
self._assert_no_parser_specified(msg)
def test_no_warning_if_explicit_parser_specified(self):
with warnings.catch_warnings(record=True) as w:
soup = self.soup("<a><b></b></a>", "html.parser")
self.assertEqual([], w)
def test_parseOnlyThese_renamed_to_parse_only(self):
with warnings.catch_warnings(record=True) as w:
soup = self.soup("<a><b></b></a>", parseOnlyThese=SoupStrainer("b"))
msg = str(w[0].message)
self.assertTrue("parseOnlyThese" in msg)
self.assertTrue("parse_only" in msg)
self.assertEqual(b"<b></b>", soup.encode())
def test_fromEncoding_renamed_to_from_encoding(self):
with warnings.catch_warnings(record=True) as w:
utf8 = b"\xc3\xa9"
soup = self.soup(utf8, fromEncoding="utf8")
msg = str(w[0].message)
self.assertTrue("fromEncoding" in msg)
self.assertTrue("from_encoding" in msg)
self.assertEqual("utf8", soup.original_encoding)
def test_unrecognized_keyword_argument(self):
self.assertRaises(
TypeError, self.soup, "<a>", no_such_argument=True)
class TestWarnings(SoupTest):
def test_disk_file_warning(self):
filehandle = tempfile.NamedTemporaryFile()
filename = filehandle.name
try:
with warnings.catch_warnings(record=True) as w:
soup = self.soup(filename)
msg = str(w[0].message)
self.assertTrue("looks like a filename" in msg)
finally:
filehandle.close()
# The file no longer exists, so Beautiful Soup will no longer issue the warning.
with warnings.catch_warnings(record=True) as w:
soup = self.soup(filename)
self.assertEqual(0, len(w))
def test_url_warning_with_bytes_url(self):
with warnings.catch_warnings(record=True) as warning_list:
soup = self.soup(b"http://www.crummybytes.com/")
# Be aware this isn't the only warning that can be raised during
# execution..
self.assertTrue(any("looks like a URL" in str(w.message)
for w in warning_list))
def test_url_warning_with_unicode_url(self):
with warnings.catch_warnings(record=True) as warning_list:
# note - this url must differ from the bytes one otherwise
# python's warnings system swallows the second warning
soup = self.soup("http://www.crummyunicode.com/")
self.assertTrue(any("looks like a URL" in str(w.message)
for w in warning_list))
def test_url_warning_with_bytes_and_space(self):
with warnings.catch_warnings(record=True) as warning_list:
soup = self.soup(b"http://www.crummybytes.com/ is great")
self.assertFalse(any("looks like a URL" in str(w.message)
for w in warning_list))
def test_url_warning_with_unicode_and_space(self):
with warnings.catch_warnings(record=True) as warning_list:
soup = self.soup("http://www.crummyuncode.com/ is great")
self.assertFalse(any("looks like a URL" in str(w.message)
for w in warning_list))
class TestSelectiveParsing(SoupTest):
def test_parse_with_soupstrainer(self):
markup = "No<b>Yes</b><a>No<b>Yes <c>Yes</c></b>"
strainer = SoupStrainer("b")
soup = self.soup(markup, parse_only=strainer)
self.assertEqual(soup.encode(), b"<b>Yes</b><b>Yes <c>Yes</c></b>")
class TestEntitySubstitution(unittest.TestCase):
"""Standalone tests of the EntitySubstitution class."""
def setUp(self):
self.sub = EntitySubstitution
def test_simple_html_substitution(self):
# Unicode characters corresponding to named HTML entites
# are substituted, and no others.
s = "foo\u2200\N{SNOWMAN}\u00f5bar"
self.assertEqual(self.sub.substitute_html(s),
"foo&forall;\N{SNOWMAN}&otilde;bar")
def test_smart_quote_substitution(self):
# MS smart quotes are a common source of frustration, so we
# give them a special test.
quotes = b"\x91\x92foo\x93\x94"
dammit = UnicodeDammit(quotes)
self.assertEqual(self.sub.substitute_html(dammit.markup),
"&lsquo;&rsquo;foo&ldquo;&rdquo;")
def test_xml_converstion_includes_no_quotes_if_make_quoted_attribute_is_false(self):
s = 'Welcome to "my bar"'
self.assertEqual(self.sub.substitute_xml(s, False), s)
def test_xml_attribute_quoting_normally_uses_double_quotes(self):
self.assertEqual(self.sub.substitute_xml("Welcome", True),
'"Welcome"')
self.assertEqual(self.sub.substitute_xml("Bob's Bar", True),
'"Bob\'s Bar"')
def test_xml_attribute_quoting_uses_single_quotes_when_value_contains_double_quotes(self):
s = 'Welcome to "my bar"'
self.assertEqual(self.sub.substitute_xml(s, True),
"'Welcome to \"my bar\"'")
def test_xml_attribute_quoting_escapes_single_quotes_when_value_contains_both_single_and_double_quotes(self):
s = 'Welcome to "Bob\'s Bar"'
self.assertEqual(
self.sub.substitute_xml(s, True),
'"Welcome to &quot;Bob\'s Bar&quot;"')
def test_xml_quotes_arent_escaped_when_value_is_not_being_quoted(self):
quoted = 'Welcome to "Bob\'s Bar"'
self.assertEqual(self.sub.substitute_xml(quoted), quoted)
def test_xml_quoting_handles_angle_brackets(self):
self.assertEqual(
self.sub.substitute_xml("foo<bar>"),
"foo&lt;bar&gt;")
def test_xml_quoting_handles_ampersands(self):
self.assertEqual(self.sub.substitute_xml("AT&T"), "AT&amp;T")
def test_xml_quoting_including_ampersands_when_they_are_part_of_an_entity(self):
self.assertEqual(
self.sub.substitute_xml("&Aacute;T&T"),
"&amp;Aacute;T&amp;T")
def test_xml_quoting_ignoring_ampersands_when_they_are_part_of_an_entity(self):
self.assertEqual(
self.sub.substitute_xml_containing_entities("&Aacute;T&T"),
"&Aacute;T&amp;T")
def test_quotes_not_html_substituted(self):
"""There's no need to do this except inside attribute values."""
text = 'Bob\'s "bar"'
self.assertEqual(self.sub.substitute_html(text), text)
class TestEncodingConversion(SoupTest):
# Test Beautiful Soup's ability to decode and encode from various
# encodings.
def setUp(self):
super(TestEncodingConversion, self).setUp()
self.unicode_data = '<html><head><meta charset="utf-8"/></head><body><foo>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</foo></body></html>'
self.utf8_data = self.unicode_data.encode("utf-8")
# Just so you know what it looks like.
self.assertEqual(
self.utf8_data,
b'<html><head><meta charset="utf-8"/></head><body><foo>Sacr\xc3\xa9 bleu!</foo></body></html>')
def test_ascii_in_unicode_out(self):
# ASCII input is converted to Unicode. The original_encoding
# attribute is set to 'utf-8', a superset of ASCII.
chardet = bs4.dammit.chardet_dammit
logging.disable(logging.WARNING)
try:
def noop(str):
return None
# Disable chardet, which will realize that the ASCII is ASCII.
bs4.dammit.chardet_dammit = noop
ascii = b"<foo>a</foo>"
soup_from_ascii = self.soup(ascii)
unicode_output = soup_from_ascii.decode()
self.assertTrue(isinstance(unicode_output, str))
self.assertEqual(unicode_output, self.document_for(ascii.decode()))
self.assertEqual(soup_from_ascii.original_encoding.lower(), "utf-8")
finally:
logging.disable(logging.NOTSET)
bs4.dammit.chardet_dammit = chardet
def test_unicode_in_unicode_out(self):
# Unicode input is left alone. The original_encoding attribute
# is not set.
soup_from_unicode = self.soup(self.unicode_data)
self.assertEqual(soup_from_unicode.decode(), self.unicode_data)
self.assertEqual(soup_from_unicode.foo.string, 'Sacr\xe9 bleu!')
self.assertEqual(soup_from_unicode.original_encoding, None)
def test_utf8_in_unicode_out(self):
# UTF-8 input is converted to Unicode. The original_encoding
# attribute is set.
soup_from_utf8 = self.soup(self.utf8_data)
self.assertEqual(soup_from_utf8.decode(), self.unicode_data)
self.assertEqual(soup_from_utf8.foo.string, 'Sacr\xe9 bleu!')
def test_utf8_out(self):
# The internal data structures can be encoded as UTF-8.
soup_from_unicode = self.soup(self.unicode_data)
self.assertEqual(soup_from_unicode.encode('utf-8'), self.utf8_data)
@skipIf(
PYTHON_3_PRE_3_2,
"Bad HTMLParser detected; skipping test of non-ASCII characters in attribute name.")
def test_attribute_name_containing_unicode_characters(self):
markup = '<div><a \N{SNOWMAN}="snowman"></a></div>'
self.assertEqual(self.soup(markup).div.encode("utf8"), markup.encode("utf8"))
class TestUnicodeDammit(unittest.TestCase):
"""Standalone tests of UnicodeDammit."""
def test_unicode_input(self):
markup = "I'm already Unicode! \N{SNOWMAN}"
dammit = UnicodeDammit(markup)
self.assertEqual(dammit.unicode_markup, markup)
def test_smart_quotes_to_unicode(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup)
self.assertEqual(
dammit.unicode_markup, "<foo>\u2018\u2019\u201c\u201d</foo>")
def test_smart_quotes_to_xml_entities(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="xml")
self.assertEqual(
dammit.unicode_markup, "<foo>&#x2018;&#x2019;&#x201C;&#x201D;</foo>")
def test_smart_quotes_to_html_entities(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="html")
self.assertEqual(
dammit.unicode_markup, "<foo>&lsquo;&rsquo;&ldquo;&rdquo;</foo>")
def test_smart_quotes_to_ascii(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="ascii")
self.assertEqual(
dammit.unicode_markup, """<foo>''""</foo>""")
def test_detect_utf8(self):
utf8 = b"Sacr\xc3\xa9 bleu! \xe2\x98\x83"
dammit = UnicodeDammit(utf8)
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
self.assertEqual(dammit.unicode_markup, 'Sacr\xe9 bleu! \N{SNOWMAN}')
def test_convert_hebrew(self):
hebrew = b"\xed\xe5\xec\xf9"
dammit = UnicodeDammit(hebrew, ["iso-8859-8"])
self.assertEqual(dammit.original_encoding.lower(), 'iso-8859-8')
self.assertEqual(dammit.unicode_markup, '\u05dd\u05d5\u05dc\u05e9')
def test_dont_see_smart_quotes_where_there_are_none(self):
utf_8 = b"\343\202\261\343\203\274\343\202\277\343\202\244 Watch"
dammit = UnicodeDammit(utf_8)
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
self.assertEqual(dammit.unicode_markup.encode("utf-8"), utf_8)
def test_ignore_inappropriate_codecs(self):
utf8_data = "Räksmörgås".encode("utf-8")
dammit = UnicodeDammit(utf8_data, ["iso-8859-8"])
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
def test_ignore_invalid_codecs(self):
utf8_data = "Räksmörgås".encode("utf-8")
for bad_encoding in ['.utf8', '...', 'utF---16.!']:
dammit = UnicodeDammit(utf8_data, [bad_encoding])
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
def test_exclude_encodings(self):
# This is UTF-8.
utf8_data = "Räksmörgås".encode("utf-8")
# But if we exclude UTF-8 from consideration, the guess is
# Windows-1252.
dammit = UnicodeDammit(utf8_data, exclude_encodings=["utf-8"])
self.assertEqual(dammit.original_encoding.lower(), 'windows-1252')
# And if we exclude that, there is no valid guess at all.
dammit = UnicodeDammit(
utf8_data, exclude_encodings=["utf-8", "windows-1252"])
self.assertEqual(dammit.original_encoding, None)
def test_encoding_detector_replaces_junk_in_encoding_name_with_replacement_character(self):
detected = EncodingDetector(
b'<?xml version="1.0" encoding="UTF-\xdb" ?>')
encodings = list(detected.encodings)
assert 'utf-\N{REPLACEMENT CHARACTER}' in encodings
def test_detect_html5_style_meta_tag(self):
for data in (
b'<html><meta charset="euc-jp" /></html>',
b"<html><meta charset='euc-jp' /></html>",
b"<html><meta charset=euc-jp /></html>",
b"<html><meta charset=euc-jp/></html>"):
dammit = UnicodeDammit(data, is_html=True)
self.assertEqual(
"euc-jp", dammit.original_encoding)
def test_last_ditch_entity_replacement(self):
# This is a UTF-8 document that contains bytestrings
# completely incompatible with UTF-8 (ie. encoded with some other
# encoding).
#
# Since there is no consistent encoding for the document,
# Unicode, Dammit will eventually encode the document as UTF-8
# and encode the incompatible characters as REPLACEMENT
# CHARACTER.
#
# If chardet is installed, it will detect that the document
# can be converted into ISO-8859-1 without errors. This happens
# to be the wrong encoding, but it is a consistent encoding, so the
# code we're testing here won't run.
#
# So we temporarily disable chardet if it's present.
doc = b"""\357\273\277<?xml version="1.0" encoding="UTF-8"?>
<html><b>\330\250\330\252\330\261</b>
<i>\310\322\321\220\312\321\355\344</i></html>"""
chardet = bs4.dammit.chardet_dammit
logging.disable(logging.WARNING)
try:
def noop(str):
return None
bs4.dammit.chardet_dammit = noop
dammit = UnicodeDammit(doc)
self.assertEqual(True, dammit.contains_replacement_characters)
self.assertTrue("\ufffd" in dammit.unicode_markup)
soup = BeautifulSoup(doc, "html.parser")
self.assertTrue(soup.contains_replacement_characters)
finally:
logging.disable(logging.NOTSET)
bs4.dammit.chardet_dammit = chardet
def test_byte_order_mark_removed(self):
# A document written in UTF-16LE will have its byte order marker stripped.
data = b'\xff\xfe<\x00a\x00>\x00\xe1\x00\xe9\x00<\x00/\x00a\x00>\x00'
dammit = UnicodeDammit(data)
self.assertEqual("<a>áé</a>", dammit.unicode_markup)
self.assertEqual("utf-16le", dammit.original_encoding)
def test_detwingle(self):
# Here's a UTF8 document.
utf8 = ("\N{SNOWMAN}" * 3).encode("utf8")
# Here's a Windows-1252 document.
windows_1252 = (
"\N{LEFT DOUBLE QUOTATION MARK}Hi, I like Windows!"
"\N{RIGHT DOUBLE QUOTATION MARK}").encode("windows_1252")
# Through some unholy alchemy, they've been stuck together.
doc = utf8 + windows_1252 + utf8
# The document can't be turned into UTF-8:
self.assertRaises(UnicodeDecodeError, doc.decode, "utf8")
# Unicode, Dammit thinks the whole document is Windows-1252,
# and decodes it into "☃☃☃“Hi, I like Windows!”☃☃☃"
# But if we run it through fix_embedded_windows_1252, it's fixed:
fixed = UnicodeDammit.detwingle(doc)
self.assertEqual(
"☃☃☃“Hi, I like Windows!”☃☃☃", fixed.decode("utf8"))
def test_detwingle_ignores_multibyte_characters(self):
# Each of these characters has a UTF-8 representation ending
# in \x93. \x93 is a smart quote if interpreted as
# Windows-1252. But our code knows to skip over multibyte
# UTF-8 characters, so they'll survive the process unscathed.
for tricky_unicode_char in (
"\N{LATIN SMALL LIGATURE OE}", # 2-byte char '\xc5\x93'
"\N{LATIN SUBSCRIPT SMALL LETTER X}", # 3-byte char '\xe2\x82\x93'
"\xf0\x90\x90\x93", # This is a CJK character, not sure which one.
):
input = tricky_unicode_char.encode("utf8")
self.assertTrue(input.endswith(b'\x93'))
output = UnicodeDammit.detwingle(input)
self.assertEqual(output, input)
class TestNamedspacedAttribute(SoupTest):
def test_name_may_be_none(self):
a = NamespacedAttribute("xmlns", None)
self.assertEqual(a, "xmlns")
def test_attribute_is_equivalent_to_colon_separated_string(self):
a = NamespacedAttribute("a", "b")
self.assertEqual("a:b", a)
def test_attributes_are_equivalent_if_prefix_and_name_identical(self):
a = NamespacedAttribute("a", "b", "c")
b = NamespacedAttribute("a", "b", "c")
self.assertEqual(a, b)
# The actual namespace is not considered.
c = NamespacedAttribute("a", "b", None)
self.assertEqual(a, c)
# But name and prefix are important.
d = NamespacedAttribute("a", "z", "c")
self.assertNotEqual(a, d)
e = NamespacedAttribute("z", "b", "c")
self.assertNotEqual(a, e)
class TestAttributeValueWithCharsetSubstitution(unittest.TestCase):
def test_content_meta_attribute_value(self):
value = CharsetMetaAttributeValue("euc-jp")
self.assertEqual("euc-jp", value)
self.assertEqual("euc-jp", value.original_value)
self.assertEqual("utf8", value.encode("utf8"))
def test_content_meta_attribute_value(self):
value = ContentMetaAttributeValue("text/html; charset=euc-jp")
self.assertEqual("text/html; charset=euc-jp", value)
self.assertEqual("text/html; charset=euc-jp", value.original_value)
self.assertEqual("text/html; charset=utf8", value.encode("utf8"))

2205
libs/bs4/tests/test_tree.py Normal file

File diff suppressed because it is too large Load diff

25
libs/engineio/__init__.py Normal file
View file

@ -0,0 +1,25 @@
import sys
from .client import Client
from .middleware import WSGIApp, Middleware
from .server import Server
if sys.version_info >= (3, 5): # pragma: no cover
from .asyncio_server import AsyncServer
from .asyncio_client import AsyncClient
from .async_drivers.asgi import ASGIApp
try:
from .async_drivers.tornado import get_tornado_handler
except ImportError:
get_tornado_handler = None
else: # pragma: no cover
AsyncServer = None
AsyncClient = None
get_tornado_handler = None
ASGIApp = None
__version__ = '3.11.2'
__all__ = ['__version__', 'Server', 'WSGIApp', 'Middleware', 'Client']
if AsyncServer is not None: # pragma: no cover
__all__ += ['AsyncServer', 'ASGIApp', 'get_tornado_handler',
'AsyncClient'],

View file

View file

@ -0,0 +1,128 @@
import asyncio
import sys
from urllib.parse import urlsplit
from aiohttp.web import Response, WebSocketResponse
import six
def create_route(app, engineio_server, engineio_endpoint):
"""This function sets up the engine.io endpoint as a route for the
application.
Note that both GET and POST requests must be hooked up on the engine.io
endpoint.
"""
app.router.add_get(engineio_endpoint, engineio_server.handle_request)
app.router.add_post(engineio_endpoint, engineio_server.handle_request)
app.router.add_route('OPTIONS', engineio_endpoint,
engineio_server.handle_request)
def translate_request(request):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
message = request._message
payload = request._payload
uri_parts = urlsplit(message.path)
environ = {
'wsgi.input': payload,
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'aiohttp',
'REQUEST_METHOD': message.method,
'QUERY_STRING': uri_parts.query or '',
'RAW_URI': message.path,
'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'aiohttp',
'SERVER_PORT': '0',
'aiohttp.request': request
}
for hdr_name, hdr_value in message.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
return Response(body=payload, status=int(status.split()[0]),
headers=headers)
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides a aiohttp WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self._sock = None
async def __call__(self, environ):
request = environ['aiohttp.request']
self._sock = WebSocketResponse()
await self._sock.prepare(request)
self.environ = environ
await self.handler(self)
return self._sock
async def close(self):
await self._sock.close()
async def send(self, message):
if isinstance(message, bytes):
f = self._sock.send_bytes
else:
f = self._sock.send_str
if asyncio.iscoroutinefunction(f):
await f(message)
else:
f(message)
async def wait(self):
msg = await self._sock.receive()
if not isinstance(msg.data, six.binary_type) and \
not isinstance(msg.data, six.text_type):
raise IOError()
return msg.data
_async = {
'asyncio': True,
'create_route': create_route,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket,
}

View file

@ -0,0 +1,214 @@
import os
import sys
from engineio.static_files import get_static_file
class ASGIApp:
"""ASGI application middleware for Engine.IO.
This middleware dispatches traffic to an Engine.IO application. It can
also serve a list of static files to the client, or forward unrelated
HTTP traffic to another ASGI application.
:param engineio_server: The Engine.IO server. Must be an instance of the
``engineio.AsyncServer`` class.
:param static_files: A dictionary with static file mapping rules. See the
documentation for details on this argument.
:param other_asgi_app: A separate ASGI app that receives all other traffic.
:param engineio_path: The endpoint where the Engine.IO application should
be installed. The default value is appropriate for
most cases.
Example usage::
import engineio
import uvicorn
eio = engineio.AsyncServer()
app = engineio.ASGIApp(eio, static_files={
'/': {'content_type': 'text/html', 'filename': 'index.html'},
'/index.html': {'content_type': 'text/html',
'filename': 'index.html'},
})
uvicorn.run(app, '127.0.0.1', 5000)
"""
def __init__(self, engineio_server, other_asgi_app=None,
static_files=None, engineio_path='engine.io'):
self.engineio_server = engineio_server
self.other_asgi_app = other_asgi_app
self.engineio_path = engineio_path.strip('/')
self.static_files = static_files or {}
async def __call__(self, scope, receive, send):
if scope['type'] in ['http', 'websocket'] and \
scope['path'].startswith('/{0}/'.format(self.engineio_path)):
await self.engineio_server.handle_request(scope, receive, send)
else:
static_file = get_static_file(scope['path'], self.static_files) \
if scope['type'] == 'http' and self.static_files else None
if static_file:
await self.serve_static_file(static_file, receive, send)
elif self.other_asgi_app is not None:
await self.other_asgi_app(scope, receive, send)
elif scope['type'] == 'lifespan':
await self.lifespan(receive, send)
else:
await self.not_found(receive, send)
async def serve_static_file(self, static_file, receive,
send): # pragma: no cover
event = await receive()
if event['type'] == 'http.request':
if os.path.exists(static_file['filename']):
with open(static_file['filename'], 'rb') as f:
payload = f.read()
await send({'type': 'http.response.start',
'status': 200,
'headers': [(b'Content-Type', static_file[
'content_type'].encode('utf-8'))]})
await send({'type': 'http.response.body',
'body': payload})
else:
await self.not_found(receive, send)
async def lifespan(self, receive, send):
event = await receive()
if event['type'] == 'lifespan.startup':
await send({'type': 'lifespan.startup.complete'})
elif event['type'] == 'lifespan.shutdown':
await send({'type': 'lifespan.shutdown.complete'})
async def not_found(self, receive, send):
"""Return a 404 Not Found error to the client."""
await send({'type': 'http.response.start',
'status': 404,
'headers': [(b'Content-Type', b'text/plain')]})
await send({'type': 'http.response.body',
'body': b'Not Found'})
async def translate_request(scope, receive, send):
class AwaitablePayload(object): # pragma: no cover
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
event = await receive()
payload = b''
if event['type'] == 'http.request':
payload += event.get('body') or b''
while event.get('more_body'):
event = await receive()
if event['type'] == 'http.request':
payload += event.get('body') or b''
elif event['type'] == 'websocket.connect':
await send({'type': 'websocket.accept'})
else:
return {}
raw_uri = scope['path'].encode('utf-8')
if 'query_string' in scope and scope['query_string']:
raw_uri += b'?' + scope['query_string']
environ = {
'wsgi.input': AwaitablePayload(payload),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'asgi',
'REQUEST_METHOD': scope.get('method', 'GET'),
'PATH_INFO': scope['path'],
'QUERY_STRING': scope.get('query_string', b'').decode('utf-8'),
'RAW_URI': raw_uri.decode('utf-8'),
'SCRIPT_NAME': '',
'SERVER_PROTOCOL': 'HTTP/1.1',
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'asgi',
'SERVER_PORT': '0',
'asgi.receive': receive,
'asgi.send': send,
}
for hdr_name, hdr_value in scope['headers']:
hdr_name = hdr_name.upper().decode('utf-8')
hdr_value = hdr_value.decode('utf-8')
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
return environ
async def make_response(status, headers, payload, environ):
headers = [(h[0].encode('utf-8'), h[1].encode('utf-8')) for h in headers]
await environ['asgi.send']({'type': 'http.response.start',
'status': int(status.split(' ')[0]),
'headers': headers})
await environ['asgi.send']({'type': 'http.response.body',
'body': payload})
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides an asgi WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self.asgi_receive = None
self.asgi_send = None
async def __call__(self, environ):
self.asgi_receive = environ['asgi.receive']
self.asgi_send = environ['asgi.send']
await self.handler(self)
async def close(self):
await self.asgi_send({'type': 'websocket.close'})
async def send(self, message):
msg_bytes = None
msg_text = None
if isinstance(message, bytes):
msg_bytes = message
else:
msg_text = message
await self.asgi_send({'type': 'websocket.send',
'bytes': msg_bytes,
'text': msg_text})
async def wait(self):
event = await self.asgi_receive()
if event['type'] != 'websocket.receive':
raise IOError()
return event.get('bytes') or event.get('text')
_async = {
'asyncio': True,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket,
}

View file

@ -0,0 +1,30 @@
from __future__ import absolute_import
from eventlet.green.threading import Thread, Event
from eventlet import queue
from eventlet import sleep
from eventlet.websocket import WebSocketWSGI as _WebSocketWSGI
class WebSocketWSGI(_WebSocketWSGI):
def __init__(self, *args, **kwargs):
super(WebSocketWSGI, self).__init__(*args, **kwargs)
self._sock = None
def __call__(self, environ, start_response):
if 'eventlet.input' not in environ:
raise RuntimeError('You need to use the eventlet server. '
'See the Deployment section of the '
'documentation for more information.')
self._sock = environ['eventlet.input'].get_socket()
return super(WebSocketWSGI, self).__call__(environ, start_response)
_async = {
'thread': Thread,
'queue': queue.Queue,
'queue_empty': queue.Empty,
'event': Event,
'websocket': WebSocketWSGI,
'sleep': sleep,
}

View file

@ -0,0 +1,63 @@
from __future__ import absolute_import
import gevent
from gevent import queue
from gevent.event import Event
try:
import geventwebsocket # noqa
_websocket_available = True
except ImportError:
_websocket_available = False
class Thread(gevent.Greenlet): # pragma: no cover
"""
This wrapper class provides gevent Greenlet interface that is compatible
with the standard library's Thread class.
"""
def __init__(self, target, args=[], kwargs={}):
super(Thread, self).__init__(target, *args, **kwargs)
def _run(self):
return self.run()
class WebSocketWSGI(object): # pragma: no cover
"""
This wrapper class provides a gevent WebSocket interface that is
compatible with eventlet's implementation.
"""
def __init__(self, app):
self.app = app
def __call__(self, environ, start_response):
if 'wsgi.websocket' not in environ:
raise RuntimeError('You need to use the gevent-websocket server. '
'See the Deployment section of the '
'documentation for more information.')
self._sock = environ['wsgi.websocket']
self.environ = environ
self.version = self._sock.version
self.path = self._sock.path
self.origin = self._sock.origin
self.protocol = self._sock.protocol
return self.app(self)
def close(self):
return self._sock.close()
def send(self, message):
return self._sock.send(message)
def wait(self):
return self._sock.receive()
_async = {
'thread': Thread,
'queue': queue.JoinableQueue,
'queue_empty': queue.Empty,
'event': Event,
'websocket': WebSocketWSGI if _websocket_available else None,
'sleep': gevent.sleep,
}

View file

@ -0,0 +1,156 @@
from __future__ import absolute_import
import six
import gevent
from gevent import queue
from gevent.event import Event
import uwsgi
_websocket_available = hasattr(uwsgi, 'websocket_handshake')
class Thread(gevent.Greenlet): # pragma: no cover
"""
This wrapper class provides gevent Greenlet interface that is compatible
with the standard library's Thread class.
"""
def __init__(self, target, args=[], kwargs={}):
super(Thread, self).__init__(target, *args, **kwargs)
def _run(self):
return self.run()
class uWSGIWebSocket(object): # pragma: no cover
"""
This wrapper class provides a uWSGI WebSocket interface that is
compatible with eventlet's implementation.
"""
def __init__(self, app):
self.app = app
self._sock = None
def __call__(self, environ, start_response):
self._sock = uwsgi.connection_fd()
self.environ = environ
uwsgi.websocket_handshake()
self._req_ctx = None
if hasattr(uwsgi, 'request_context'):
# uWSGI >= 2.1.x with support for api access across-greenlets
self._req_ctx = uwsgi.request_context()
else:
# use event and queue for sending messages
from gevent.event import Event
from gevent.queue import Queue
from gevent.select import select
self._event = Event()
self._send_queue = Queue()
# spawn a select greenlet
def select_greenlet_runner(fd, event):
"""Sets event when data becomes available to read on fd."""
while True:
event.set()
try:
select([fd], [], [])[0]
except ValueError:
break
self._select_greenlet = gevent.spawn(
select_greenlet_runner,
self._sock,
self._event)
self.app(self)
def close(self):
"""Disconnects uWSGI from the client."""
uwsgi.disconnect()
if self._req_ctx is None:
# better kill it here in case wait() is not called again
self._select_greenlet.kill()
self._event.set()
def _send(self, msg):
"""Transmits message either in binary or UTF-8 text mode,
depending on its type."""
if isinstance(msg, six.binary_type):
method = uwsgi.websocket_send_binary
else:
method = uwsgi.websocket_send
if self._req_ctx is not None:
method(msg, request_context=self._req_ctx)
else:
method(msg)
def _decode_received(self, msg):
"""Returns either bytes or str, depending on message type."""
if not isinstance(msg, six.binary_type):
# already decoded - do nothing
return msg
# only decode from utf-8 if message is not binary data
type = six.byte2int(msg[0:1])
if type >= 48: # no binary
return msg.decode('utf-8')
# binary message, don't try to decode
return msg
def send(self, msg):
"""Queues a message for sending. Real transmission is done in
wait method.
Sends directly if uWSGI version is new enough."""
if self._req_ctx is not None:
self._send(msg)
else:
self._send_queue.put(msg)
self._event.set()
def wait(self):
"""Waits and returns received messages.
If running in compatibility mode for older uWSGI versions,
it also sends messages that have been queued by send().
A return value of None means that connection was closed.
This must be called repeatedly. For uWSGI < 2.1.x it must
be called from the main greenlet."""
while True:
if self._req_ctx is not None:
try:
msg = uwsgi.websocket_recv(request_context=self._req_ctx)
except IOError: # connection closed
return None
return self._decode_received(msg)
else:
# we wake up at least every 3 seconds to let uWSGI
# do its ping/ponging
event_set = self._event.wait(timeout=3)
if event_set:
self._event.clear()
# maybe there is something to send
msgs = []
while True:
try:
msgs.append(self._send_queue.get(block=False))
except gevent.queue.Empty:
break
for msg in msgs:
self._send(msg)
# maybe there is something to receive, if not, at least
# ensure uWSGI does its ping/ponging
try:
msg = uwsgi.websocket_recv_nb()
except IOError: # connection closed
self._select_greenlet.kill()
return None
if msg: # message available
return self._decode_received(msg)
_async = {
'thread': Thread,
'queue': queue.JoinableQueue,
'queue_empty': queue.Empty,
'event': Event,
'websocket': uWSGIWebSocket if _websocket_available else None,
'sleep': gevent.sleep,
}

View file

@ -0,0 +1,144 @@
import sys
from urllib.parse import urlsplit
from sanic.response import HTTPResponse
try:
from sanic.websocket import WebSocketProtocol
except ImportError:
# the installed version of sanic does not have websocket support
WebSocketProtocol = None
import six
def create_route(app, engineio_server, engineio_endpoint):
"""This function sets up the engine.io endpoint as a route for the
application.
Note that both GET and POST requests must be hooked up on the engine.io
endpoint.
"""
app.add_route(engineio_server.handle_request, engineio_endpoint,
methods=['GET', 'POST', 'OPTIONS'])
try:
app.enable_websocket()
except AttributeError:
# ignore, this version does not support websocket
pass
def translate_request(request):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
class AwaitablePayload(object):
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
uri_parts = urlsplit(request.url)
environ = {
'wsgi.input': AwaitablePayload(request.body),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'sanic',
'REQUEST_METHOD': request.method,
'QUERY_STRING': uri_parts.query or '',
'RAW_URI': request.url,
'SERVER_PROTOCOL': 'HTTP/' + request.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'sanic',
'SERVER_PORT': '0',
'sanic.request': request
}
for hdr_name, hdr_value in request.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
headers_dict = {}
content_type = None
for h in headers:
if h[0].lower() == 'content-type':
content_type = h[1]
else:
headers_dict[h[0]] = h[1]
return HTTPResponse(body_bytes=payload, content_type=content_type,
status=int(status.split()[0]), headers=headers_dict)
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides a sanic WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self._sock = None
async def __call__(self, environ):
request = environ['sanic.request']
protocol = request.transport.get_protocol()
self._sock = await protocol.websocket_handshake(request)
self.environ = environ
await self.handler(self)
async def close(self):
await self._sock.close()
async def send(self, message):
await self._sock.send(message)
async def wait(self):
data = await self._sock.recv()
if not isinstance(data, six.binary_type) and \
not isinstance(data, six.text_type):
raise IOError()
return data
_async = {
'asyncio': True,
'create_route': create_route,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket if WebSocketProtocol else None,
}

View file

@ -0,0 +1,17 @@
from __future__ import absolute_import
import threading
import time
try:
import queue
except ImportError: # pragma: no cover
import Queue as queue
_async = {
'thread': threading.Thread,
'queue': queue.Queue,
'queue_empty': queue.Empty,
'event': threading.Event,
'websocket': None,
'sleep': time.sleep,
}

View file

@ -0,0 +1,184 @@
import asyncio
import sys
from urllib.parse import urlsplit
from .. import exceptions
import tornado.web
import tornado.websocket
import six
def get_tornado_handler(engineio_server):
class Handler(tornado.websocket.WebSocketHandler): # pragma: no cover
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if isinstance(engineio_server.cors_allowed_origins,
six.string_types):
if engineio_server.cors_allowed_origins == '*':
self.allowed_origins = None
else:
self.allowed_origins = [
engineio_server.cors_allowed_origins]
else:
self.allowed_origins = engineio_server.cors_allowed_origins
self.receive_queue = asyncio.Queue()
async def get(self, *args, **kwargs):
if self.request.headers.get('Upgrade', '').lower() == 'websocket':
ret = super().get(*args, **kwargs)
if asyncio.iscoroutine(ret):
await ret
else:
await engineio_server.handle_request(self)
async def open(self, *args, **kwargs):
# this is the handler for the websocket request
asyncio.ensure_future(engineio_server.handle_request(self))
async def post(self, *args, **kwargs):
await engineio_server.handle_request(self)
async def options(self, *args, **kwargs):
await engineio_server.handle_request(self)
async def on_message(self, message):
await self.receive_queue.put(message)
async def get_next_message(self):
return await self.receive_queue.get()
def on_close(self):
self.receive_queue.put_nowait(None)
def check_origin(self, origin):
if self.allowed_origins is None or origin in self.allowed_origins:
return True
return super().check_origin(origin)
def get_compression_options(self):
# enable compression
return {}
return Handler
def translate_request(handler):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
class AwaitablePayload(object):
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
payload = handler.request.body
uri_parts = urlsplit(handler.request.path)
full_uri = handler.request.path
if handler.request.query: # pragma: no cover
full_uri += '?' + handler.request.query
environ = {
'wsgi.input': AwaitablePayload(payload),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'aiohttp',
'REQUEST_METHOD': handler.request.method,
'QUERY_STRING': handler.request.query or '',
'RAW_URI': full_uri,
'SERVER_PROTOCOL': 'HTTP/%s' % handler.request.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'aiohttp',
'SERVER_PORT': '0',
'tornado.handler': handler
}
for hdr_name, hdr_value in handler.request.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
tornado_handler = environ['tornado.handler']
try:
tornado_handler.set_status(int(status.split()[0]))
except RuntimeError: # pragma: no cover
# for websocket connections Tornado does not accept a response, since
# it already emitted the 101 status code
return
for header, value in headers:
tornado_handler.set_header(header, value)
tornado_handler.write(payload)
tornado_handler.finish()
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides a tornado WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self.tornado_handler = None
async def __call__(self, environ):
self.tornado_handler = environ['tornado.handler']
self.environ = environ
await self.handler(self)
async def close(self):
self.tornado_handler.close()
async def send(self, message):
try:
self.tornado_handler.write_message(
message, binary=isinstance(message, bytes))
except tornado.websocket.WebSocketClosedError:
raise exceptions.EngineIOError()
async def wait(self):
msg = await self.tornado_handler.get_next_message()
if not isinstance(msg, six.binary_type) and \
not isinstance(msg, six.text_type):
raise IOError()
return msg
_async = {
'asyncio': True,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket,
}

View file

@ -0,0 +1,585 @@
import asyncio
import ssl
try:
import aiohttp
except ImportError: # pragma: no cover
aiohttp = None
import six
from . import client
from . import exceptions
from . import packet
from . import payload
class AsyncClient(client.Client):
"""An Engine.IO client for asyncio.
This class implements a fully compliant Engine.IO web client with support
for websocket and long-polling transports, compatible with the asyncio
framework on Python 3.5 or newer.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param request_timeout: A timeout in seconds for requests. The default is
5 seconds.
:param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
skip SSL certificate verification, allowing
connections to servers with self signed certificates.
The default is ``True``.
"""
def is_asyncio_based(self):
return True
async def connect(self, url, headers={}, transports=None,
engineio_path='engine.io'):
"""Connect to an Engine.IO server.
:param url: The URL of the Engine.IO server. It can include custom
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
then an upgrade to websocket is attempted.
:param engineio_path: The endpoint where the Engine.IO server is
installed. The default value is appropriate for
most cases.
Note: this method is a coroutine.
Example usage::
eio = engineio.Client()
await eio.connect('http://localhost:5000')
"""
if self.state != 'disconnected':
raise ValueError('Client is not in a disconnected state')
valid_transports = ['polling', 'websocket']
if transports is not None:
if isinstance(transports, six.text_type):
transports = [transports]
transports = [transport for transport in transports
if transport in valid_transports]
if not transports:
raise ValueError('No valid transports provided')
self.transports = transports or valid_transports
self.queue = self.create_queue()
return await getattr(self, '_connect_' + self.transports[0])(
url, headers, engineio_path)
async def wait(self):
"""Wait until the connection with the server ends.
Client applications can use this function to block the main thread
during the life of the connection.
Note: this method is a coroutine.
"""
if self.read_loop_task:
await self.read_loop_task
async def send(self, data, binary=None):
"""Send a message to a client.
:param data: The data to send to the client. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
:param binary: ``True`` to send packet as binary, ``False`` to send
as text. If not given, unicode (Python 2) and str
(Python 3) are sent as text, and str (Python 2) and
bytes (Python 3) are sent as binary.
Note: this method is a coroutine.
"""
await self._send_packet(packet.Packet(packet.MESSAGE, data=data,
binary=binary))
async def disconnect(self, abort=False):
"""Disconnect from the server.
:param abort: If set to ``True``, do not wait for background tasks
associated with the connection to end.
Note: this method is a coroutine.
"""
if self.state == 'connected':
await self._send_packet(packet.Packet(packet.CLOSE))
await self.queue.put(None)
self.state = 'disconnecting'
await self._trigger_event('disconnect', run_async=False)
if self.current_transport == 'websocket':
await self.ws.close()
if not abort:
await self.read_loop_task
self.state = 'disconnected'
try:
client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
def start_background_task(self, target, *args, **kwargs):
"""Start a background task.
This is a utility function that applications can use to start a
background task.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
Note: this method is a coroutine.
"""
return asyncio.ensure_future(target(*args, **kwargs))
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time.
Note: this method is a coroutine.
"""
return await asyncio.sleep(seconds)
def create_queue(self):
"""Create a queue object."""
q = asyncio.Queue()
q.Empty = asyncio.QueueEmpty
return q
def create_event(self):
"""Create an event object."""
return asyncio.Event()
def _reset(self):
if self.http: # pragma: no cover
asyncio.ensure_future(self.http.close())
super()._reset()
async def _connect_polling(self, url, headers, engineio_path):
"""Establish a long-polling connection to the Engine.IO server."""
if aiohttp is None: # pragma: no cover
self.logger.error('aiohttp not installed -- cannot make HTTP '
'requests!')
return
self.base_url = self._get_engineio_url(url, engineio_path, 'polling')
self.logger.info('Attempting polling connection to ' + self.base_url)
r = await self._send_request(
'GET', self.base_url + self._get_url_timestamp(), headers=headers,
timeout=self.request_timeout)
if r is None:
self._reset()
raise exceptions.ConnectionError(
'Connection refused by the server')
if r.status < 200 or r.status >= 300:
raise exceptions.ConnectionError(
'Unexpected status code {} in server response'.format(
r.status))
try:
p = payload.Payload(encoded_payload=await r.read())
except ValueError:
six.raise_from(exceptions.ConnectionError(
'Unexpected response from server'), None)
open_packet = p.packets[0]
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError(
'OPEN packet not returned by server')
self.logger.info(
'Polling connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = open_packet.data['pingInterval'] / 1000.0
self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
self.current_transport = 'polling'
self.base_url += '&sid=' + self.sid
self.state = 'connected'
client.connected_clients.append(self)
await self._trigger_event('connect', run_async=False)
for pkt in p.packets[1:]:
await self._receive_packet(pkt)
if 'websocket' in self.upgrades and 'websocket' in self.transports:
# attempt to upgrade to websocket
if await self._connect_websocket(url, headers, engineio_path):
# upgrade to websocket succeeded, we're done here
return
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_polling)
async def _connect_websocket(self, url, headers, engineio_path):
"""Establish or upgrade to a WebSocket connection with the server."""
if aiohttp is None: # pragma: no cover
self.logger.error('aiohttp package not installed')
return False
websocket_url = self._get_engineio_url(url, engineio_path,
'websocket')
if self.sid:
self.logger.info(
'Attempting WebSocket upgrade to ' + websocket_url)
upgrade = True
websocket_url += '&sid=' + self.sid
else:
upgrade = False
self.base_url = websocket_url
self.logger.info(
'Attempting WebSocket connection to ' + websocket_url)
if self.http is None or self.http.closed: # pragma: no cover
self.http = aiohttp.ClientSession()
try:
if not self.ssl_verify:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
ws = await self.http.ws_connect(
websocket_url + self._get_url_timestamp(),
headers=headers, ssl=ssl_context)
else:
ws = await self.http.ws_connect(
websocket_url + self._get_url_timestamp(),
headers=headers)
except (aiohttp.client_exceptions.WSServerHandshakeError,
aiohttp.client_exceptions.ServerConnectionError):
if upgrade:
self.logger.warning(
'WebSocket upgrade failed: connection error')
return False
else:
raise exceptions.ConnectionError('Connection error')
if upgrade:
p = packet.Packet(packet.PING, data='probe').encode(
always_bytes=False)
try:
await ws.send_str(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
try:
p = (await ws.receive()).data
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected recv exception: %s',
str(e))
return False
pkt = packet.Packet(encoded_packet=p)
if pkt.packet_type != packet.PONG or pkt.data != 'probe':
self.logger.warning(
'WebSocket upgrade failed: no PONG packet')
return False
p = packet.Packet(packet.UPGRADE).encode(always_bytes=False)
try:
await ws.send_str(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
self.current_transport = 'websocket'
self.logger.info('WebSocket upgrade was successful')
else:
try:
p = (await ws.receive()).data
except Exception as e: # pragma: no cover
raise exceptions.ConnectionError(
'Unexpected recv exception: ' + str(e))
open_packet = packet.Packet(encoded_packet=p)
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError('no OPEN packet')
self.logger.info(
'WebSocket connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = open_packet.data['pingInterval'] / 1000.0
self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
self.current_transport = 'websocket'
self.state = 'connected'
client.connected_clients.append(self)
await self._trigger_event('connect', run_async=False)
self.ws = ws
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_websocket)
return True
async def _receive_packet(self, pkt):
"""Handle incoming packets from the server."""
packet_name = packet.packet_names[pkt.packet_type] \
if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
self.logger.info(
'Received packet %s data %s', packet_name,
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
if pkt.packet_type == packet.MESSAGE:
await self._trigger_event('message', pkt.data, run_async=True)
elif pkt.packet_type == packet.PONG:
self.pong_received = True
elif pkt.packet_type == packet.CLOSE:
await self.disconnect(abort=True)
elif pkt.packet_type == packet.NOOP:
pass
else:
self.logger.error('Received unexpected packet of type %s',
pkt.packet_type)
async def _send_packet(self, pkt):
"""Queue a packet to be sent to the server."""
if self.state != 'connected':
return
await self.queue.put(pkt)
self.logger.info(
'Sending packet %s data %s',
packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
async def _send_request(
self, method, url, headers=None, body=None,
timeout=None): # pragma: no cover
if self.http is None or self.http.closed:
self.http = aiohttp.ClientSession()
http_method = getattr(self.http, method.lower())
try:
if not self.ssl_verify:
return await http_method(
url, headers=headers, data=body,
timeout=aiohttp.ClientTimeout(total=timeout), ssl=False)
else:
return await http_method(
url, headers=headers, data=body,
timeout=aiohttp.ClientTimeout(total=timeout))
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
self.logger.info('HTTP %s request to %s failed with error %s.',
method, url, exc)
async def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
ret = None
if event in self.handlers:
if asyncio.iscoroutinefunction(self.handlers[event]) is True:
if run_async:
return self.start_background_task(self.handlers[event],
*args)
else:
try:
ret = await self.handlers[event](*args)
except asyncio.CancelledError: # pragma: no cover
pass
except:
self.logger.exception(event + ' async handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
else:
if run_async:
async def async_handler():
return self.handlers[event](*args)
return self.start_background_task(async_handler)
else:
try:
ret = self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
return ret
async def _ping_loop(self):
"""This background task sends a PING to the server at the requested
interval.
"""
self.pong_received = True
if self.ping_loop_event is None:
self.ping_loop_event = self.create_event()
else:
self.ping_loop_event.clear()
while self.state == 'connected':
if not self.pong_received:
self.logger.info(
'PONG response has not been received, aborting')
if self.ws:
await self.ws.close()
await self.queue.put(None)
break
self.pong_received = False
await self._send_packet(packet.Packet(packet.PING))
try:
await asyncio.wait_for(self.ping_loop_event.wait(),
self.ping_interval)
except (asyncio.TimeoutError,
asyncio.CancelledError): # pragma: no cover
pass
self.logger.info('Exiting ping task')
async def _read_loop_polling(self):
"""Read packets by polling the Engine.IO server."""
while self.state == 'connected':
self.logger.info(
'Sending polling GET request to ' + self.base_url)
r = await self._send_request(
'GET', self.base_url + self._get_url_timestamp(),
timeout=max(self.ping_interval, self.ping_timeout) + 5)
if r is None:
self.logger.warning(
'Connection refused by the server, aborting')
await self.queue.put(None)
break
if r.status < 200 or r.status >= 300:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status)
await self.queue.put(None)
break
try:
p = payload.Payload(encoded_payload=await r.read())
except ValueError:
self.logger.warning(
'Unexpected packet from server, aborting')
await self.queue.put(None)
break
for pkt in p.packets:
await self._receive_packet(pkt)
self.logger.info('Waiting for write loop task to end')
await self.write_loop_task
self.logger.info('Waiting for ping loop task to end')
if self.ping_loop_event: # pragma: no cover
self.ping_loop_event.set()
await self.ping_loop_task
if self.state == 'connected':
await self._trigger_event('disconnect', run_async=False)
try:
client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')
async def _read_loop_websocket(self):
"""Read packets from the Engine.IO WebSocket connection."""
while self.state == 'connected':
p = None
try:
p = (await self.ws.receive()).data
if p is None: # pragma: no cover
raise RuntimeError('WebSocket read returned None')
except aiohttp.client_exceptions.ServerDisconnectedError:
self.logger.info(
'Read loop: WebSocket connection was closed, aborting')
await self.queue.put(None)
break
except Exception as e:
self.logger.info(
'Unexpected error "%s", aborting', str(e))
await self.queue.put(None)
break
if isinstance(p, six.text_type): # pragma: no cover
p = p.encode('utf-8')
pkt = packet.Packet(encoded_packet=p)
await self._receive_packet(pkt)
self.logger.info('Waiting for write loop task to end')
await self.write_loop_task
self.logger.info('Waiting for ping loop task to end')
if self.ping_loop_event: # pragma: no cover
self.ping_loop_event.set()
await self.ping_loop_task
if self.state == 'connected':
await self._trigger_event('disconnect', run_async=False)
try:
client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')
async def _write_loop(self):
"""This background task sends packages to the server as they are
pushed to the send queue.
"""
while self.state == 'connected':
# to simplify the timeout handling, use the maximum of the
# ping interval and ping timeout as timeout, with an extra 5
# seconds grace period
timeout = max(self.ping_interval, self.ping_timeout) + 5
packets = None
try:
packets = [await asyncio.wait_for(self.queue.get(), timeout)]
except (self.queue.Empty, asyncio.TimeoutError,
asyncio.CancelledError):
self.logger.error('packet queue is empty, aborting')
break
if packets == [None]:
self.queue.task_done()
packets = []
else:
while True:
try:
packets.append(self.queue.get_nowait())
except self.queue.Empty:
break
if packets[-1] is None:
packets = packets[:-1]
self.queue.task_done()
break
if not packets:
# empty packet list returned -> connection closed
break
if self.current_transport == 'polling':
p = payload.Payload(packets=packets)
r = await self._send_request(
'POST', self.base_url, body=p.encode(),
headers={'Content-Type': 'application/octet-stream'},
timeout=self.request_timeout)
for pkt in packets:
self.queue.task_done()
if r is None:
self.logger.warning(
'Connection refused by the server, aborting')
break
if r.status < 200 or r.status >= 300:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status)
self._reset()
break
else:
# websocket
try:
for pkt in packets:
if pkt.binary:
await self.ws.send_bytes(pkt.encode(
always_bytes=False))
else:
await self.ws.send_str(pkt.encode(
always_bytes=False))
self.queue.task_done()
except aiohttp.client_exceptions.ServerDisconnectedError:
self.logger.info(
'Write loop: WebSocket connection was closed, '
'aborting')
break
self.logger.info('Exiting write loop task')

View file

@ -0,0 +1,472 @@
import asyncio
import six
from six.moves import urllib
from . import exceptions
from . import packet
from . import server
from . import asyncio_socket
class AsyncServer(server.Server):
"""An Engine.IO server for asyncio.
This class implements a fully compliant Engine.IO web server with support
for websocket and long-polling transports, compatible with the asyncio
framework on Python 3.5 or newer.
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "aiohttp",
"sanic", "tornado" and "asgi". If this argument is not
given, "aiohttp" is tried first, followed by "sanic",
"tornado", and finally "asgi". The first async mode that
has all its dependencies installed is the one that is
chosen.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting.
:param ping_interval: The interval in seconds at which the client pings
the server. The default is 25 seconds. For advanced
control, a two element tuple can be given, where
the first number is the ping interval and the second
is a grace period added by the server. The default
grace period is 5 seconds.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport.
:param allow_upgrades: Whether to allow transport upgrades or not.
:param http_compression: Whether to compress packages when using the
polling transport.
:param compression_threshold: Only compress messages when their byte size
is greater than this value.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same origin
is allowed by default. Set this argument to
``'*'`` to allow all origins, or to ``[]`` to
disable CORS handling.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, run message event handlers in
non-blocking threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
def is_asyncio_based(self):
return True
def async_modes(self):
return ['aiohttp', 'sanic', 'tornado', 'asgi']
def attach(self, app, engineio_path='engine.io'):
"""Attach the Engine.IO server to an application."""
engineio_path = engineio_path.strip('/')
self._async['create_route'](app, self, '/{}/'.format(engineio_path))
async def send(self, sid, data, binary=None):
"""Send a message to a client.
:param sid: The session id of the recipient client.
:param data: The data to send to the client. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
:param binary: ``True`` to send packet as binary, ``False`` to send
as text. If not given, unicode (Python 2) and str
(Python 3) are sent as text, and str (Python 2) and
bytes (Python 3) are sent as binary.
Note: this method is a coroutine.
"""
try:
socket = self._get_socket(sid)
except KeyError:
# the socket is not available
self.logger.warning('Cannot send to sid %s', sid)
return
await socket.send(packet.Packet(packet.MESSAGE, data=data,
binary=binary))
async def get_session(self, sid):
"""Return the user session for a client.
:param sid: The session id of the client.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved. If you want to modify
the user session, use the ``session`` context manager instead.
"""
socket = self._get_socket(sid)
return socket.session
async def save_session(self, sid, session):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
"""
socket = self._get_socket(sid)
socket.session = session
def session(self, sid):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
@eio.on('connect')
def on_connect(sid, environ):
username = authenticate_user(environ)
if not username:
return False
with eio.session(sid) as session:
session['username'] = username
@eio.on('message')
def on_message(sid, msg):
async with eio.session(sid) as session:
print('received message from ', session['username'])
"""
class _session_context_manager(object):
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.session = None
async def __aenter__(self):
self.session = await self.server.get_session(sid)
return self.session
async def __aexit__(self, *args):
await self.server.save_session(sid, self.session)
return _session_context_manager(self, sid)
async def disconnect(self, sid=None):
"""Disconnect a client.
:param sid: The session id of the client to close. If this parameter
is not given, then all clients are closed.
Note: this method is a coroutine.
"""
if sid is not None:
try:
socket = self._get_socket(sid)
except KeyError: # pragma: no cover
# the socket was already closed or gone
pass
else:
await socket.close()
if sid in self.sockets: # pragma: no cover
del self.sockets[sid]
else:
await asyncio.wait([client.close()
for client in six.itervalues(self.sockets)])
self.sockets = {}
async def handle_request(self, *args, **kwargs):
"""Handle an HTTP request from the client.
This is the entry point of the Engine.IO application. This function
returns the HTTP response to deliver to the client.
Note: this method is a coroutine.
"""
translate_request = self._async['translate_request']
if asyncio.iscoroutinefunction(translate_request):
environ = await translate_request(*args, **kwargs)
else:
environ = translate_request(*args, **kwargs)
if self.cors_allowed_origins != []:
# Validate the origin header if present
# This is important for WebSocket more than for HTTP, since
# browsers only apply CORS controls to HTTP.
origin = environ.get('HTTP_ORIGIN')
if origin:
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is not None and origin not in \
allowed_origins:
self.logger.info(origin + ' is not an accepted origin.')
r = self._bad_request()
make_response = self._async['make_response']
if asyncio.iscoroutinefunction(make_response):
response = await make_response(
r['status'], r['headers'], r['response'], environ)
else:
response = make_response(r['status'], r['headers'],
r['response'], environ)
return response
method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))
sid = query['sid'][0] if 'sid' in query else None
b64 = False
jsonp = False
jsonp_index = None
if 'b64' in query:
if query['b64'][0] == "1" or query['b64'][0].lower() == "true":
b64 = True
if 'j' in query:
jsonp = True
try:
jsonp_index = int(query['j'][0])
except (ValueError, KeyError, IndexError):
# Invalid JSONP index number
pass
if jsonp and jsonp_index is None:
self.logger.warning('Invalid JSONP index number')
r = self._bad_request()
elif method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
if transport != 'polling' and transport != 'websocket':
self.logger.warning('Invalid transport %s', transport)
r = self._bad_request()
else:
r = await self._handle_connect(environ, transport,
b64, jsonp_index)
else:
if sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
packets = await socket.handle_get_request(environ)
if isinstance(packets, list):
r = self._ok(packets, b64=b64,
jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
await socket.handle_post_request(environ)
r = self._ok(jsonp_index=jsonp_index)
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok(jsonp_index=jsonp_index)
elif method == 'OPTIONS':
r = self._ok()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
if not isinstance(r, dict):
return r
if self.http_compression and \
len(r['response']) >= self.compression_threshold:
encodings = [e.split(';')[0].strip() for e in
environ.get('HTTP_ACCEPT_ENCODING', '').split(',')]
for encoding in encodings:
if encoding in self.compression_methods:
r['response'] = \
getattr(self, '_' + encoding)(r['response'])
r['headers'] += [('Content-Encoding', encoding)]
break
cors_headers = self._cors_headers(environ)
make_response = self._async['make_response']
if asyncio.iscoroutinefunction(make_response):
response = await make_response(r['status'],
r['headers'] + cors_headers,
r['response'], environ)
else:
response = make_response(r['status'], r['headers'] + cors_headers,
r['response'], environ)
return response
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
The return value is a ``asyncio.Task`` object.
"""
return asyncio.ensure_future(target(*args, **kwargs))
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
Note: this method is a coroutine.
"""
return await asyncio.sleep(seconds)
def create_queue(self, *args, **kwargs):
"""Create a queue object using the appropriate async model.
This is a utility function that applications can use to create a queue
without having to worry about using the correct call for the selected
async mode. For asyncio based async modes, this returns an instance of
``asyncio.Queue``.
"""
return asyncio.Queue(*args, **kwargs)
def get_queue_empty_exception(self):
"""Return the queue empty exception for the appropriate async model.
This is a utility function that applications can use to work with a
queue without having to worry about using the correct call for the
selected async mode. For asyncio based async modes, this returns an
instance of ``asyncio.QueueEmpty``.
"""
return asyncio.QueueEmpty
def create_event(self, *args, **kwargs):
"""Create an event object using the appropriate async model.
This is a utility function that applications can use to create an
event without having to worry about using the correct call for the
selected async mode. For asyncio based async modes, this returns
an instance of ``asyncio.Event``.
"""
return asyncio.Event(*args, **kwargs)
async def _handle_connect(self, environ, transport, b64=False,
jsonp_index=None):
"""Handle a client connection request."""
if self.start_service_task:
# start the service task to monitor connected clients
self.start_service_task = False
self.start_background_task(self._service_task)
sid = self._generate_id()
s = asyncio_socket.AsyncSocket(self, sid)
self.sockets[sid] = s
pkt = packet.Packet(
packet.OPEN, {'sid': sid,
'upgrades': self._upgrades(sid, transport),
'pingTimeout': int(self.ping_timeout * 1000),
'pingInterval': int(self.ping_interval * 1000)})
await s.send(pkt)
ret = await self._trigger_event('connect', sid, environ,
run_async=False)
if ret is False:
del self.sockets[sid]
self.logger.warning('Application rejected connection')
return self._unauthorized()
if transport == 'websocket':
ret = await s.handle_get_request(environ)
if s.closed:
# websocket connection ended, so we are done
del self.sockets[sid]
return ret
else:
s.connected = True
headers = None
if self.cookie:
headers = [('Set-Cookie', self.cookie + '=' + sid)]
try:
return self._ok(await s.poll(), headers=headers, b64=b64,
jsonp_index=jsonp_index)
except exceptions.QueueEmpty:
return self._bad_request()
async def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
ret = None
if event in self.handlers:
if asyncio.iscoroutinefunction(self.handlers[event]) is True:
if run_async:
return self.start_background_task(self.handlers[event],
*args)
else:
try:
ret = await self.handlers[event](*args)
except asyncio.CancelledError: # pragma: no cover
pass
except:
self.logger.exception(event + ' async handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
else:
if run_async:
async def async_handler():
return self.handlers[event](*args)
return self.start_background_task(async_handler)
else:
try:
ret = self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
return ret
async def _service_task(self): # pragma: no cover
"""Monitor connected clients and clean up those that time out."""
while True:
if len(self.sockets) == 0:
# nothing to do
await self.sleep(self.ping_timeout)
continue
# go through the entire client list in a ping interval cycle
sleep_interval = self.ping_timeout / len(self.sockets)
try:
# iterate over the current clients
for socket in self.sockets.copy().values():
if not socket.closing and not socket.closed:
await socket.check_ping_timeout()
await self.sleep(sleep_interval)
except (SystemExit, KeyboardInterrupt, asyncio.CancelledError):
self.logger.info('service task canceled')
break
except:
if asyncio.get_event_loop().is_closed():
self.logger.info('event loop is closed, exiting service '
'task')
break
# an unexpected exception has occurred, log it and continue
self.logger.exception('service task exception')

View file

@ -0,0 +1,236 @@
import asyncio
import six
import sys
import time
from . import exceptions
from . import packet
from . import payload
from . import socket
class AsyncSocket(socket.Socket):
async def poll(self):
"""Wait for packets to send to the client."""
try:
packets = [await asyncio.wait_for(self.queue.get(),
self.server.ping_timeout)]
self.queue.task_done()
except (asyncio.TimeoutError, asyncio.CancelledError):
raise exceptions.QueueEmpty()
if packets == [None]:
return []
try:
packets.append(self.queue.get_nowait())
self.queue.task_done()
except asyncio.QueueEmpty:
pass
return packets
async def receive(self, pkt):
"""Receive packet from the client."""
self.server.logger.info('%s: Received packet %s data %s',
self.sid, packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
if pkt.packet_type == packet.PING:
self.last_ping = time.time()
await self.send(packet.Packet(packet.PONG, pkt.data))
elif pkt.packet_type == packet.MESSAGE:
await self.server._trigger_event(
'message', self.sid, pkt.data,
run_async=self.server.async_handlers)
elif pkt.packet_type == packet.UPGRADE:
await self.send(packet.Packet(packet.NOOP))
elif pkt.packet_type == packet.CLOSE:
await self.close(wait=False, abort=True)
else:
raise exceptions.UnknownPacketError()
async def check_ping_timeout(self):
"""Make sure the client is still sending pings.
This helps detect disconnections for long-polling clients.
"""
if self.closed:
raise exceptions.SocketIsClosedError()
if time.time() - self.last_ping > self.server.ping_interval + \
self.server.ping_interval_grace_period:
self.server.logger.info('%s: Client is gone, closing socket',
self.sid)
# Passing abort=False here will cause close() to write a
# CLOSE packet. This has the effect of updating half-open sockets
# to their correct state of disconnected
await self.close(wait=False, abort=False)
return False
return True
async def send(self, pkt):
"""Send a packet to the client."""
if not await self.check_ping_timeout():
return
if self.upgrading:
self.packet_backlog.append(pkt)
else:
await self.queue.put(pkt)
self.server.logger.info('%s: Sending packet %s data %s',
self.sid, packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
async def handle_get_request(self, environ):
"""Handle a long-polling GET request from the client."""
connections = [
s.strip()
for s in environ.get('HTTP_CONNECTION', '').lower().split(',')]
transport = environ.get('HTTP_UPGRADE', '').lower()
if 'upgrade' in connections and transport in self.upgrade_protocols:
self.server.logger.info('%s: Received request to upgrade to %s',
self.sid, transport)
return await getattr(self, '_upgrade_' + transport)(environ)
try:
packets = await self.poll()
except exceptions.QueueEmpty:
exc = sys.exc_info()
await self.close(wait=False)
six.reraise(*exc)
return packets
async def handle_post_request(self, environ):
"""Handle a long-polling POST request from the client."""
length = int(environ.get('CONTENT_LENGTH', '0'))
if length > self.server.max_http_buffer_size:
raise exceptions.ContentTooLongError()
else:
body = await environ['wsgi.input'].read(length)
p = payload.Payload(encoded_payload=body)
for pkt in p.packets:
await self.receive(pkt)
async def close(self, wait=True, abort=False):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
await self.server._trigger_event('disconnect', self.sid)
if not abort:
await self.send(packet.Packet(packet.CLOSE))
self.closed = True
if wait:
await self.queue.join()
async def _upgrade_websocket(self, environ):
"""Upgrade the connection from polling to websocket."""
if self.upgraded:
raise IOError('Socket has been upgraded already')
if self.server._async['websocket'] is None:
# the selected async mode does not support websocket
return self.server._bad_request()
ws = self.server._async['websocket'](self._websocket_handler)
return await ws(environ)
async def _websocket_handler(self, ws):
"""Engine.IO handler for websocket transport."""
if self.connected:
# the socket was already connected, so this is an upgrade
self.upgrading = True # hold packet sends during the upgrade
try:
pkt = await ws.wait()
except IOError: # pragma: no cover
return
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.PING or \
decoded_pkt.data != 'probe':
self.server.logger.info(
'%s: Failed websocket upgrade, no PING packet', self.sid)
return
await ws.send(packet.Packet(
packet.PONG,
data=six.text_type('probe')).encode(always_bytes=False))
await self.queue.put(packet.Packet(packet.NOOP)) # end poll
try:
pkt = await ws.wait()
except IOError: # pragma: no cover
return
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.UPGRADE:
self.upgraded = False
self.server.logger.info(
('%s: Failed websocket upgrade, expected UPGRADE packet, '
'received %s instead.'),
self.sid, pkt)
return
self.upgraded = True
# flush any packets that were sent during the upgrade
for pkt in self.packet_backlog:
await self.queue.put(pkt)
self.packet_backlog = []
self.upgrading = False
else:
self.connected = True
self.upgraded = True
# start separate writer thread
async def writer():
while True:
packets = None
try:
packets = await self.poll()
except exceptions.QueueEmpty:
break
if not packets:
# empty packet list returned -> connection closed
break
try:
for pkt in packets:
await ws.send(pkt.encode(always_bytes=False))
except:
break
writer_task = asyncio.ensure_future(writer())
self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)
while True:
p = None
wait_task = asyncio.ensure_future(ws.wait())
try:
p = await asyncio.wait_for(wait_task, self.server.ping_timeout)
except asyncio.CancelledError: # pragma: no cover
# there is a bug (https://bugs.python.org/issue30508) in
# asyncio that causes a "Task exception never retrieved" error
# to appear when wait_task raises an exception before it gets
# cancelled. Calling wait_task.exception() prevents the error
# from being issued in Python 3.6, but causes other errors in
# other versions, so we run it with all errors suppressed and
# hope for the best.
try:
wait_task.exception()
except:
pass
break
except:
break
if p is None:
# connection closed by client
break
if isinstance(p, six.text_type): # pragma: no cover
p = p.encode('utf-8')
pkt = packet.Packet(encoded_packet=p)
try:
await self.receive(pkt)
except exceptions.UnknownPacketError: # pragma: no cover
pass
except exceptions.SocketIsClosedError: # pragma: no cover
self.server.logger.info('Receive error -- socket is closed')
break
except: # pragma: no cover
# if we get an unexpected exception we log the error and exit
# the connection properly
self.server.logger.exception('Unknown receive error')
await self.queue.put(None) # unlock the writer task so it can exit
await asyncio.wait_for(writer_task, timeout=None)
await self.close(wait=False, abort=True)

680
libs/engineio/client.py Normal file
View file

@ -0,0 +1,680 @@
import logging
try:
import queue
except ImportError: # pragma: no cover
import Queue as queue
import signal
import ssl
import threading
import time
import six
from six.moves import urllib
try:
import requests
except ImportError: # pragma: no cover
requests = None
try:
import websocket
except ImportError: # pragma: no cover
websocket = None
from . import exceptions
from . import packet
from . import payload
default_logger = logging.getLogger('engineio.client')
connected_clients = []
if six.PY2: # pragma: no cover
ConnectionError = OSError
def signal_handler(sig, frame):
"""SIGINT handler.
Disconnect all active clients and then invoke the original signal handler.
"""
for client in connected_clients[:]:
if client.is_asyncio_based():
client.start_background_task(client.disconnect, abort=True)
else:
client.disconnect(abort=True)
if callable(original_signal_handler):
return original_signal_handler(sig, frame)
else: # pragma: no cover
# Handle case where no original SIGINT handler was present.
return signal.default_int_handler(sig, frame)
original_signal_handler = None
class Client(object):
"""An Engine.IO client.
This class implements a fully compliant Engine.IO web client with support
for websocket and long-polling transports.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param request_timeout: A timeout in seconds for requests. The default is
5 seconds.
:param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
skip SSL certificate verification, allowing
connections to servers with self signed certificates.
The default is ``True``.
"""
event_names = ['connect', 'disconnect', 'message']
def __init__(self,
logger=False,
json=None,
request_timeout=5,
ssl_verify=True):
global original_signal_handler
if original_signal_handler is None:
original_signal_handler = signal.signal(signal.SIGINT,
signal_handler)
self.handlers = {}
self.base_url = None
self.transports = None
self.current_transport = None
self.sid = None
self.upgrades = None
self.ping_interval = None
self.ping_timeout = None
self.pong_received = True
self.http = None
self.ws = None
self.read_loop_task = None
self.write_loop_task = None
self.ping_loop_task = None
self.ping_loop_event = None
self.queue = None
self.state = 'disconnected'
self.ssl_verify = ssl_verify
if json is not None:
packet.Packet.json = json
if not isinstance(logger, bool):
self.logger = logger
else:
self.logger = default_logger
if not logging.root.handlers and \
self.logger.level == logging.NOTSET:
if logger:
self.logger.setLevel(logging.INFO)
else:
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())
self.request_timeout = request_timeout
def is_asyncio_based(self):
return False
def on(self, event, handler=None):
"""Register an event handler.
:param event: The event name. Can be ``'connect'``, ``'message'`` or
``'disconnect'``.
:param handler: The function that should be invoked to handle the
event. When this parameter is not given, the method
acts as a decorator for the handler function.
Example usage::
# as a decorator:
@eio.on('connect')
def connect_handler():
print('Connection request')
# as a method:
def message_handler(msg):
print('Received message: ', msg)
eio.send('response')
eio.on('message', message_handler)
"""
if event not in self.event_names:
raise ValueError('Invalid event')
def set_handler(handler):
self.handlers[event] = handler
return handler
if handler is None:
return set_handler
set_handler(handler)
def connect(self, url, headers={}, transports=None,
engineio_path='engine.io'):
"""Connect to an Engine.IO server.
:param url: The URL of the Engine.IO server. It can include custom
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
then an upgrade to websocket is attempted.
:param engineio_path: The endpoint where the Engine.IO server is
installed. The default value is appropriate for
most cases.
Example usage::
eio = engineio.Client()
eio.connect('http://localhost:5000')
"""
if self.state != 'disconnected':
raise ValueError('Client is not in a disconnected state')
valid_transports = ['polling', 'websocket']
if transports is not None:
if isinstance(transports, six.string_types):
transports = [transports]
transports = [transport for transport in transports
if transport in valid_transports]
if not transports:
raise ValueError('No valid transports provided')
self.transports = transports or valid_transports
self.queue = self.create_queue()
return getattr(self, '_connect_' + self.transports[0])(
url, headers, engineio_path)
def wait(self):
"""Wait until the connection with the server ends.
Client applications can use this function to block the main thread
during the life of the connection.
"""
if self.read_loop_task:
self.read_loop_task.join()
def send(self, data, binary=None):
"""Send a message to a client.
:param data: The data to send to the client. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
:param binary: ``True`` to send packet as binary, ``False`` to send
as text. If not given, unicode (Python 2) and str
(Python 3) are sent as text, and str (Python 2) and
bytes (Python 3) are sent as binary.
"""
self._send_packet(packet.Packet(packet.MESSAGE, data=data,
binary=binary))
def disconnect(self, abort=False):
"""Disconnect from the server.
:param abort: If set to ``True``, do not wait for background tasks
associated with the connection to end.
"""
if self.state == 'connected':
self._send_packet(packet.Packet(packet.CLOSE))
self.queue.put(None)
self.state = 'disconnecting'
self._trigger_event('disconnect', run_async=False)
if self.current_transport == 'websocket':
self.ws.close()
if not abort:
self.read_loop_task.join()
self.state = 'disconnected'
try:
connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
def transport(self):
"""Return the name of the transport currently in use.
The possible values returned by this function are ``'polling'`` and
``'websocket'``.
"""
return self.current_transport
def start_background_task(self, target, *args, **kwargs):
"""Start a background task.
This is a utility function that applications can use to start a
background task.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
th = threading.Thread(target=target, args=args, kwargs=kwargs)
th.start()
return th
def sleep(self, seconds=0):
"""Sleep for the requested amount of time."""
return time.sleep(seconds)
def create_queue(self, *args, **kwargs):
"""Create a queue object."""
q = queue.Queue(*args, **kwargs)
q.Empty = queue.Empty
return q
def create_event(self, *args, **kwargs):
"""Create an event object."""
return threading.Event(*args, **kwargs)
def _reset(self):
self.state = 'disconnected'
self.sid = None
def _connect_polling(self, url, headers, engineio_path):
"""Establish a long-polling connection to the Engine.IO server."""
if requests is None: # pragma: no cover
# not installed
self.logger.error('requests package is not installed -- cannot '
'send HTTP requests!')
return
self.base_url = self._get_engineio_url(url, engineio_path, 'polling')
self.logger.info('Attempting polling connection to ' + self.base_url)
r = self._send_request(
'GET', self.base_url + self._get_url_timestamp(), headers=headers,
timeout=self.request_timeout)
if r is None:
self._reset()
raise exceptions.ConnectionError(
'Connection refused by the server')
if r.status_code < 200 or r.status_code >= 300:
raise exceptions.ConnectionError(
'Unexpected status code {} in server response'.format(
r.status_code))
try:
p = payload.Payload(encoded_payload=r.content)
except ValueError:
six.raise_from(exceptions.ConnectionError(
'Unexpected response from server'), None)
open_packet = p.packets[0]
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError(
'OPEN packet not returned by server')
self.logger.info(
'Polling connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = open_packet.data['pingInterval'] / 1000.0
self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
self.current_transport = 'polling'
self.base_url += '&sid=' + self.sid
self.state = 'connected'
connected_clients.append(self)
self._trigger_event('connect', run_async=False)
for pkt in p.packets[1:]:
self._receive_packet(pkt)
if 'websocket' in self.upgrades and 'websocket' in self.transports:
# attempt to upgrade to websocket
if self._connect_websocket(url, headers, engineio_path):
# upgrade to websocket succeeded, we're done here
return
# start background tasks associated with this client
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_polling)
def _connect_websocket(self, url, headers, engineio_path):
"""Establish or upgrade to a WebSocket connection with the server."""
if websocket is None: # pragma: no cover
# not installed
self.logger.warning('websocket-client package not installed, only '
'polling transport is available')
return False
websocket_url = self._get_engineio_url(url, engineio_path, 'websocket')
if self.sid:
self.logger.info(
'Attempting WebSocket upgrade to ' + websocket_url)
upgrade = True
websocket_url += '&sid=' + self.sid
else:
upgrade = False
self.base_url = websocket_url
self.logger.info(
'Attempting WebSocket connection to ' + websocket_url)
# get the cookies from the long-polling connection so that they can
# also be sent the the WebSocket route
cookies = None
if self.http:
cookies = '; '.join(["{}={}".format(cookie.name, cookie.value)
for cookie in self.http.cookies])
try:
if not self.ssl_verify:
ws = websocket.create_connection(
websocket_url + self._get_url_timestamp(), header=headers,
cookie=cookies, sslopt={"cert_reqs": ssl.CERT_NONE})
else:
ws = websocket.create_connection(
websocket_url + self._get_url_timestamp(), header=headers,
cookie=cookies)
except (ConnectionError, IOError, websocket.WebSocketException):
if upgrade:
self.logger.warning(
'WebSocket upgrade failed: connection error')
return False
else:
raise exceptions.ConnectionError('Connection error')
if upgrade:
p = packet.Packet(packet.PING,
data=six.text_type('probe')).encode()
try:
ws.send(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
try:
p = ws.recv()
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected recv exception: %s',
str(e))
return False
pkt = packet.Packet(encoded_packet=p)
if pkt.packet_type != packet.PONG or pkt.data != 'probe':
self.logger.warning(
'WebSocket upgrade failed: no PONG packet')
return False
p = packet.Packet(packet.UPGRADE).encode()
try:
ws.send(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
self.current_transport = 'websocket'
self.logger.info('WebSocket upgrade was successful')
else:
try:
p = ws.recv()
except Exception as e: # pragma: no cover
raise exceptions.ConnectionError(
'Unexpected recv exception: ' + str(e))
open_packet = packet.Packet(encoded_packet=p)
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError('no OPEN packet')
self.logger.info(
'WebSocket connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = open_packet.data['pingInterval'] / 1000.0
self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
self.current_transport = 'websocket'
self.state = 'connected'
connected_clients.append(self)
self._trigger_event('connect', run_async=False)
self.ws = ws
# start background tasks associated with this client
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_websocket)
return True
def _receive_packet(self, pkt):
"""Handle incoming packets from the server."""
packet_name = packet.packet_names[pkt.packet_type] \
if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
self.logger.info(
'Received packet %s data %s', packet_name,
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
if pkt.packet_type == packet.MESSAGE:
self._trigger_event('message', pkt.data, run_async=True)
elif pkt.packet_type == packet.PONG:
self.pong_received = True
elif pkt.packet_type == packet.CLOSE:
self.disconnect(abort=True)
elif pkt.packet_type == packet.NOOP:
pass
else:
self.logger.error('Received unexpected packet of type %s',
pkt.packet_type)
def _send_packet(self, pkt):
"""Queue a packet to be sent to the server."""
if self.state != 'connected':
return
self.queue.put(pkt)
self.logger.info(
'Sending packet %s data %s',
packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
def _send_request(
self, method, url, headers=None, body=None,
timeout=None): # pragma: no cover
if self.http is None:
self.http = requests.Session()
try:
return self.http.request(method, url, headers=headers, data=body,
timeout=timeout, verify=self.ssl_verify)
except requests.exceptions.RequestException as exc:
self.logger.info('HTTP %s request to %s failed with error %s.',
method, url, exc)
def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
if event in self.handlers:
if run_async:
return self.start_background_task(self.handlers[event], *args)
else:
try:
return self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
def _get_engineio_url(self, url, engineio_path, transport):
"""Generate the Engine.IO connection URL."""
engineio_path = engineio_path.strip('/')
parsed_url = urllib.parse.urlparse(url)
if transport == 'polling':
scheme = 'http'
elif transport == 'websocket':
scheme = 'ws'
else: # pragma: no cover
raise ValueError('invalid transport')
if parsed_url.scheme in ['https', 'wss']:
scheme += 's'
return ('{scheme}://{netloc}/{path}/?{query}'
'{sep}transport={transport}&EIO=3').format(
scheme=scheme, netloc=parsed_url.netloc,
path=engineio_path, query=parsed_url.query,
sep='&' if parsed_url.query else '',
transport=transport)
def _get_url_timestamp(self):
"""Generate the Engine.IO query string timestamp."""
return '&t=' + str(time.time())
def _ping_loop(self):
"""This background task sends a PING to the server at the requested
interval.
"""
self.pong_received = True
if self.ping_loop_event is None:
self.ping_loop_event = self.create_event()
else:
self.ping_loop_event.clear()
while self.state == 'connected':
if not self.pong_received:
self.logger.info(
'PONG response has not been received, aborting')
if self.ws:
self.ws.close(timeout=0)
self.queue.put(None)
break
self.pong_received = False
self._send_packet(packet.Packet(packet.PING))
self.ping_loop_event.wait(timeout=self.ping_interval)
self.logger.info('Exiting ping task')
def _read_loop_polling(self):
"""Read packets by polling the Engine.IO server."""
while self.state == 'connected':
self.logger.info(
'Sending polling GET request to ' + self.base_url)
r = self._send_request(
'GET', self.base_url + self._get_url_timestamp(),
timeout=max(self.ping_interval, self.ping_timeout) + 5)
if r is None:
self.logger.warning(
'Connection refused by the server, aborting')
self.queue.put(None)
break
if r.status_code < 200 or r.status_code >= 300:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status_code)
self.queue.put(None)
break
try:
p = payload.Payload(encoded_payload=r.content)
except ValueError:
self.logger.warning(
'Unexpected packet from server, aborting')
self.queue.put(None)
break
for pkt in p.packets:
self._receive_packet(pkt)
self.logger.info('Waiting for write loop task to end')
self.write_loop_task.join()
self.logger.info('Waiting for ping loop task to end')
if self.ping_loop_event: # pragma: no cover
self.ping_loop_event.set()
self.ping_loop_task.join()
if self.state == 'connected':
self._trigger_event('disconnect', run_async=False)
try:
connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')
def _read_loop_websocket(self):
"""Read packets from the Engine.IO WebSocket connection."""
while self.state == 'connected':
p = None
try:
p = self.ws.recv()
except websocket.WebSocketConnectionClosedException:
self.logger.warning(
'WebSocket connection was closed, aborting')
self.queue.put(None)
break
except Exception as e:
self.logger.info(
'Unexpected error "%s", aborting', str(e))
self.queue.put(None)
break
if isinstance(p, six.text_type): # pragma: no cover
p = p.encode('utf-8')
pkt = packet.Packet(encoded_packet=p)
self._receive_packet(pkt)
self.logger.info('Waiting for write loop task to end')
self.write_loop_task.join()
self.logger.info('Waiting for ping loop task to end')
if self.ping_loop_event: # pragma: no cover
self.ping_loop_event.set()
self.ping_loop_task.join()
if self.state == 'connected':
self._trigger_event('disconnect', run_async=False)
try:
connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')
def _write_loop(self):
"""This background task sends packages to the server as they are
pushed to the send queue.
"""
while self.state == 'connected':
# to simplify the timeout handling, use the maximum of the
# ping interval and ping timeout as timeout, with an extra 5
# seconds grace period
timeout = max(self.ping_interval, self.ping_timeout) + 5
packets = None
try:
packets = [self.queue.get(timeout=timeout)]
except self.queue.Empty:
self.logger.error('packet queue is empty, aborting')
break
if packets == [None]:
self.queue.task_done()
packets = []
else:
while True:
try:
packets.append(self.queue.get(block=False))
except self.queue.Empty:
break
if packets[-1] is None:
packets = packets[:-1]
self.queue.task_done()
break
if not packets:
# empty packet list returned -> connection closed
break
if self.current_transport == 'polling':
p = payload.Payload(packets=packets)
r = self._send_request(
'POST', self.base_url, body=p.encode(),
headers={'Content-Type': 'application/octet-stream'},
timeout=self.request_timeout)
for pkt in packets:
self.queue.task_done()
if r is None:
self.logger.warning(
'Connection refused by the server, aborting')
break
if r.status_code < 200 or r.status_code >= 300:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status_code)
self._reset()
break
else:
# websocket
try:
for pkt in packets:
encoded_packet = pkt.encode(always_bytes=False)
if pkt.binary:
self.ws.send_binary(encoded_packet)
else:
self.ws.send(encoded_packet)
self.queue.task_done()
except websocket.WebSocketConnectionClosedException:
self.logger.warning(
'WebSocket connection was closed, aborting')
break
self.logger.info('Exiting write loop task')

View file

@ -0,0 +1,22 @@
class EngineIOError(Exception):
pass
class ContentTooLongError(EngineIOError):
pass
class UnknownPacketError(EngineIOError):
pass
class QueueEmpty(EngineIOError):
pass
class SocketIsClosedError(EngineIOError):
pass
class ConnectionError(EngineIOError):
pass

View file

@ -0,0 +1,87 @@
import os
from engineio.static_files import get_static_file
class WSGIApp(object):
"""WSGI application middleware for Engine.IO.
This middleware dispatches traffic to an Engine.IO application. It can
also serve a list of static files to the client, or forward unrelated
HTTP traffic to another WSGI application.
:param engineio_app: The Engine.IO server. Must be an instance of the
``engineio.Server`` class.
:param wsgi_app: The WSGI app that receives all other traffic.
:param static_files: A dictionary with static file mapping rules. See the
documentation for details on this argument.
:param engineio_path: The endpoint where the Engine.IO application should
be installed. The default value is appropriate for
most cases.
Example usage::
import engineio
import eventlet
eio = engineio.Server()
app = engineio.WSGIApp(eio, static_files={
'/': {'content_type': 'text/html', 'filename': 'index.html'},
'/index.html': {'content_type': 'text/html',
'filename': 'index.html'},
})
eventlet.wsgi.server(eventlet.listen(('', 8000)), app)
"""
def __init__(self, engineio_app, wsgi_app=None, static_files=None,
engineio_path='engine.io'):
self.engineio_app = engineio_app
self.wsgi_app = wsgi_app
self.engineio_path = engineio_path.strip('/')
self.static_files = static_files or {}
def __call__(self, environ, start_response):
if 'gunicorn.socket' in environ:
# gunicorn saves the socket under environ['gunicorn.socket'], while
# eventlet saves it under environ['eventlet.input']. Eventlet also
# stores the socket inside a wrapper class, while gunicon writes it
# directly into the environment. To give eventlet's WebSocket
# module access to this socket when running under gunicorn, here we
# copy the socket to the eventlet format.
class Input(object):
def __init__(self, socket):
self.socket = socket
def get_socket(self):
return self.socket
environ['eventlet.input'] = Input(environ['gunicorn.socket'])
path = environ['PATH_INFO']
if path is not None and \
path.startswith('/{0}/'.format(self.engineio_path)):
return self.engineio_app.handle_request(environ, start_response)
else:
static_file = get_static_file(path, self.static_files) \
if self.static_files else None
if static_file:
if os.path.exists(static_file['filename']):
start_response(
'200 OK',
[('Content-Type', static_file['content_type'])])
with open(static_file['filename'], 'rb') as f:
return [f.read()]
else:
return self.not_found(start_response)
elif self.wsgi_app is not None:
return self.wsgi_app(environ, start_response)
return self.not_found(start_response)
def not_found(self, start_response):
start_response("404 Not Found", [('Content-Type', 'text/plain')])
return [b'Not Found']
class Middleware(WSGIApp):
"""This class has been renamed to ``WSGIApp`` and is now deprecated."""
def __init__(self, engineio_app, wsgi_app=None,
engineio_path='engine.io'):
super(Middleware, self).__init__(engineio_app, wsgi_app,
engineio_path=engineio_path)

92
libs/engineio/packet.py Normal file
View file

@ -0,0 +1,92 @@
import base64
import json as _json
import six
(OPEN, CLOSE, PING, PONG, MESSAGE, UPGRADE, NOOP) = (0, 1, 2, 3, 4, 5, 6)
packet_names = ['OPEN', 'CLOSE', 'PING', 'PONG', 'MESSAGE', 'UPGRADE', 'NOOP']
binary_types = (six.binary_type, bytearray)
class Packet(object):
"""Engine.IO packet."""
json = _json
def __init__(self, packet_type=NOOP, data=None, binary=None,
encoded_packet=None):
self.packet_type = packet_type
self.data = data
if binary is not None:
self.binary = binary
elif isinstance(data, six.text_type):
self.binary = False
elif isinstance(data, binary_types):
self.binary = True
else:
self.binary = False
if encoded_packet:
self.decode(encoded_packet)
def encode(self, b64=False, always_bytes=True):
"""Encode the packet for transmission."""
if self.binary and not b64:
encoded_packet = six.int2byte(self.packet_type)
else:
encoded_packet = six.text_type(self.packet_type)
if self.binary and b64:
encoded_packet = 'b' + encoded_packet
if self.binary:
if b64:
encoded_packet += base64.b64encode(self.data).decode('utf-8')
else:
encoded_packet += self.data
elif isinstance(self.data, six.string_types):
encoded_packet += self.data
elif isinstance(self.data, dict) or isinstance(self.data, list):
encoded_packet += self.json.dumps(self.data,
separators=(',', ':'))
elif self.data is not None:
encoded_packet += str(self.data)
if always_bytes and not isinstance(encoded_packet, binary_types):
encoded_packet = encoded_packet.encode('utf-8')
return encoded_packet
def decode(self, encoded_packet):
"""Decode a transmitted package."""
b64 = False
if not isinstance(encoded_packet, binary_types):
encoded_packet = encoded_packet.encode('utf-8')
elif not isinstance(encoded_packet, bytes):
encoded_packet = bytes(encoded_packet)
self.packet_type = six.byte2int(encoded_packet[0:1])
if self.packet_type == 98: # 'b' --> binary base64 encoded packet
self.binary = True
encoded_packet = encoded_packet[1:]
self.packet_type = six.byte2int(encoded_packet[0:1])
self.packet_type -= 48
b64 = True
elif self.packet_type >= 48:
self.packet_type -= 48
self.binary = False
else:
self.binary = True
self.data = None
if len(encoded_packet) > 1:
if self.binary:
if b64:
self.data = base64.b64decode(encoded_packet[1:])
else:
self.data = encoded_packet[1:]
else:
try:
self.data = self.json.loads(
encoded_packet[1:].decode('utf-8'))
if isinstance(self.data, int):
# do not allow integer payloads, see
# github.com/miguelgrinberg/python-engineio/issues/75
# for background on this decision
raise ValueError
except ValueError:
self.data = encoded_packet[1:].decode('utf-8')

81
libs/engineio/payload.py Normal file
View file

@ -0,0 +1,81 @@
import six
from . import packet
from six.moves import urllib
class Payload(object):
"""Engine.IO payload."""
max_decode_packets = 16
def __init__(self, packets=None, encoded_payload=None):
self.packets = packets or []
if encoded_payload is not None:
self.decode(encoded_payload)
def encode(self, b64=False, jsonp_index=None):
"""Encode the payload for transmission."""
encoded_payload = b''
for pkt in self.packets:
encoded_packet = pkt.encode(b64=b64)
packet_len = len(encoded_packet)
if b64:
encoded_payload += str(packet_len).encode('utf-8') + b':' + \
encoded_packet
else:
binary_len = b''
while packet_len != 0:
binary_len = six.int2byte(packet_len % 10) + binary_len
packet_len = int(packet_len / 10)
if not pkt.binary:
encoded_payload += b'\0'
else:
encoded_payload += b'\1'
encoded_payload += binary_len + b'\xff' + encoded_packet
if jsonp_index is not None:
encoded_payload = b'___eio[' + \
str(jsonp_index).encode() + \
b']("' + \
encoded_payload.replace(b'"', b'\\"') + \
b'");'
return encoded_payload
def decode(self, encoded_payload):
"""Decode a transmitted payload."""
self.packets = []
if len(encoded_payload) == 0:
return
# JSONP POST payload starts with 'd='
if encoded_payload.startswith(b'd='):
encoded_payload = urllib.parse.parse_qs(
encoded_payload)[b'd'][0]
i = 0
if six.byte2int(encoded_payload[0:1]) <= 1:
# binary encoding
while i < len(encoded_payload):
if len(self.packets) >= self.max_decode_packets:
raise ValueError('Too many packets in payload')
packet_len = 0
i += 1
while six.byte2int(encoded_payload[i:i + 1]) != 255:
packet_len = packet_len * 10 + six.byte2int(
encoded_payload[i:i + 1])
i += 1
self.packets.append(packet.Packet(
encoded_packet=encoded_payload[i + 1:i + 1 + packet_len]))
i += packet_len + 1
else:
# assume text encoding
encoded_payload = encoded_payload.decode('utf-8')
while i < len(encoded_payload):
if len(self.packets) >= self.max_decode_packets:
raise ValueError('Too many packets in payload')
j = encoded_payload.find(':', i)
packet_len = int(encoded_payload[i:j])
pkt = encoded_payload[j + 1:j + 1 + packet_len]
self.packets.append(packet.Packet(encoded_packet=pkt))
i = j + 1 + packet_len

675
libs/engineio/server.py Normal file
View file

@ -0,0 +1,675 @@
import gzip
import importlib
import logging
import uuid
import zlib
import six
from six.moves import urllib
from . import exceptions
from . import packet
from . import payload
from . import socket
default_logger = logging.getLogger('engineio.server')
class Server(object):
"""An Engine.IO server.
This class implements a fully compliant Engine.IO web server with support
for websocket and long-polling transports.
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "threading",
"eventlet", "gevent" and "gevent_uwsgi". If this
argument is not given, "eventlet" is tried first, then
"gevent_uwsgi", then "gevent", and finally "threading".
The first async mode that has all its dependencies
installed is the one that is chosen.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting. The default
is 60 seconds.
:param ping_interval: The interval in seconds at which the client pings
the server. The default is 25 seconds. For advanced
control, a two element tuple can be given, where
the first number is the ping interval and the second
is a grace period added by the server. The default
grace period is 5 seconds.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport. The default is 100,000,000
bytes.
:param allow_upgrades: Whether to allow transport upgrades or not. The
default is ``True``.
:param http_compression: Whether to compress packages when using the
polling transport. The default is ``True``.
:param compression_threshold: Only compress messages when their byte size
is greater than this value. The default is
1024 bytes.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
The default is ``'io'``.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same origin
is allowed by default. Set this argument to
``'*'`` to allow all origins, or to ``[]`` to
disable CORS handling.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server. The default
is ``True``.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, run message event handlers in
non-blocking threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param monitor_clients: If set to ``True``, a background task will ensure
inactive clients are closed. Set to ``False`` to
disable the monitoring task (not recommended). The
default is ``True``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
compression_methods = ['gzip', 'deflate']
event_names = ['connect', 'disconnect', 'message']
_default_monitor_clients = True
def __init__(self, async_mode=None, ping_timeout=60, ping_interval=25,
max_http_buffer_size=100000000, allow_upgrades=True,
http_compression=True, compression_threshold=1024,
cookie='io', cors_allowed_origins=None,
cors_credentials=True, logger=False, json=None,
async_handlers=True, monitor_clients=None, **kwargs):
self.ping_timeout = ping_timeout
if isinstance(ping_interval, tuple):
self.ping_interval = ping_interval[0]
self.ping_interval_grace_period = ping_interval[1]
else:
self.ping_interval = ping_interval
self.ping_interval_grace_period = 5
self.max_http_buffer_size = max_http_buffer_size
self.allow_upgrades = allow_upgrades
self.http_compression = http_compression
self.compression_threshold = compression_threshold
self.cookie = cookie
self.cors_allowed_origins = cors_allowed_origins
self.cors_credentials = cors_credentials
self.async_handlers = async_handlers
self.sockets = {}
self.handlers = {}
self.start_service_task = monitor_clients \
if monitor_clients is not None else self._default_monitor_clients
if json is not None:
packet.Packet.json = json
if not isinstance(logger, bool):
self.logger = logger
else:
self.logger = default_logger
if not logging.root.handlers and \
self.logger.level == logging.NOTSET:
if logger:
self.logger.setLevel(logging.INFO)
else:
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())
modes = self.async_modes()
if async_mode is not None:
modes = [async_mode] if async_mode in modes else []
self._async = None
self.async_mode = None
for mode in modes:
try:
self._async = importlib.import_module(
'engineio.async_drivers.' + mode)._async
asyncio_based = self._async['asyncio'] \
if 'asyncio' in self._async else False
if asyncio_based != self.is_asyncio_based():
continue # pragma: no cover
self.async_mode = mode
break
except ImportError:
pass
if self.async_mode is None:
raise ValueError('Invalid async_mode specified')
if self.is_asyncio_based() and \
('asyncio' not in self._async or not
self._async['asyncio']): # pragma: no cover
raise ValueError('The selected async_mode is not asyncio '
'compatible')
if not self.is_asyncio_based() and 'asyncio' in self._async and \
self._async['asyncio']: # pragma: no cover
raise ValueError('The selected async_mode requires asyncio and '
'must use the AsyncServer class')
self.logger.info('Server initialized for %s.', self.async_mode)
def is_asyncio_based(self):
return False
def async_modes(self):
return ['eventlet', 'gevent_uwsgi', 'gevent', 'threading']
def on(self, event, handler=None):
"""Register an event handler.
:param event: The event name. Can be ``'connect'``, ``'message'`` or
``'disconnect'``.
:param handler: The function that should be invoked to handle the
event. When this parameter is not given, the method
acts as a decorator for the handler function.
Example usage::
# as a decorator:
@eio.on('connect')
def connect_handler(sid, environ):
print('Connection request')
if environ['REMOTE_ADDR'] in blacklisted:
return False # reject
# as a method:
def message_handler(sid, msg):
print('Received message: ', msg)
eio.send(sid, 'response')
eio.on('message', message_handler)
The handler function receives the ``sid`` (session ID) for the
client as first argument. The ``'connect'`` event handler receives the
WSGI environment as a second argument, and can return ``False`` to
reject the connection. The ``'message'`` handler receives the message
payload as a second argument. The ``'disconnect'`` handler does not
take a second argument.
"""
if event not in self.event_names:
raise ValueError('Invalid event')
def set_handler(handler):
self.handlers[event] = handler
return handler
if handler is None:
return set_handler
set_handler(handler)
def send(self, sid, data, binary=None):
"""Send a message to a client.
:param sid: The session id of the recipient client.
:param data: The data to send to the client. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
:param binary: ``True`` to send packet as binary, ``False`` to send
as text. If not given, unicode (Python 2) and str
(Python 3) are sent as text, and str (Python 2) and
bytes (Python 3) are sent as binary.
"""
try:
socket = self._get_socket(sid)
except KeyError:
# the socket is not available
self.logger.warning('Cannot send to sid %s', sid)
return
socket.send(packet.Packet(packet.MESSAGE, data=data, binary=binary))
def get_session(self, sid):
"""Return the user session for a client.
:param sid: The session id of the client.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved unless
``save_session()`` is called, or when the ``session`` context manager
is used.
"""
socket = self._get_socket(sid)
return socket.session
def save_session(self, sid, session):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
"""
socket = self._get_socket(sid)
socket.session = session
def session(self, sid):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
@eio.on('connect')
def on_connect(sid, environ):
username = authenticate_user(environ)
if not username:
return False
with eio.session(sid) as session:
session['username'] = username
@eio.on('message')
def on_message(sid, msg):
with eio.session(sid) as session:
print('received message from ', session['username'])
"""
class _session_context_manager(object):
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.session = None
def __enter__(self):
self.session = self.server.get_session(sid)
return self.session
def __exit__(self, *args):
self.server.save_session(sid, self.session)
return _session_context_manager(self, sid)
def disconnect(self, sid=None):
"""Disconnect a client.
:param sid: The session id of the client to close. If this parameter
is not given, then all clients are closed.
"""
if sid is not None:
try:
socket = self._get_socket(sid)
except KeyError: # pragma: no cover
# the socket was already closed or gone
pass
else:
socket.close()
if sid in self.sockets: # pragma: no cover
del self.sockets[sid]
else:
for client in six.itervalues(self.sockets):
client.close()
self.sockets = {}
def transport(self, sid):
"""Return the name of the transport used by the client.
The two possible values returned by this function are ``'polling'``
and ``'websocket'``.
:param sid: The session of the client.
"""
return 'websocket' if self._get_socket(sid).upgraded else 'polling'
def handle_request(self, environ, start_response):
"""Handle an HTTP request from the client.
This is the entry point of the Engine.IO application, using the same
interface as a WSGI application. For the typical usage, this function
is invoked by the :class:`Middleware` instance, but it can be invoked
directly when the middleware is not used.
:param environ: The WSGI environment.
:param start_response: The WSGI ``start_response`` function.
This function returns the HTTP response body to deliver to the client
as a byte sequence.
"""
if self.cors_allowed_origins != []:
# Validate the origin header if present
# This is important for WebSocket more than for HTTP, since
# browsers only apply CORS controls to HTTP.
origin = environ.get('HTTP_ORIGIN')
if origin:
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is not None and origin not in \
allowed_origins:
self.logger.info(origin + ' is not an accepted origin.')
r = self._bad_request()
start_response(r['status'], r['headers'])
return [r['response']]
method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))
sid = query['sid'][0] if 'sid' in query else None
b64 = False
jsonp = False
jsonp_index = None
if 'b64' in query:
if query['b64'][0] == "1" or query['b64'][0].lower() == "true":
b64 = True
if 'j' in query:
jsonp = True
try:
jsonp_index = int(query['j'][0])
except (ValueError, KeyError, IndexError):
# Invalid JSONP index number
pass
if jsonp and jsonp_index is None:
self.logger.warning('Invalid JSONP index number')
r = self._bad_request()
elif method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
if transport != 'polling' and transport != 'websocket':
self.logger.warning('Invalid transport %s', transport)
r = self._bad_request()
else:
r = self._handle_connect(environ, start_response,
transport, b64, jsonp_index)
else:
if sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
packets = socket.handle_get_request(
environ, start_response)
if isinstance(packets, list):
r = self._ok(packets, b64=b64,
jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
socket.handle_post_request(environ)
r = self._ok(jsonp_index=jsonp_index)
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok(jsonp_index=jsonp_index)
elif method == 'OPTIONS':
r = self._ok()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
if not isinstance(r, dict):
return r or []
if self.http_compression and \
len(r['response']) >= self.compression_threshold:
encodings = [e.split(';')[0].strip() for e in
environ.get('HTTP_ACCEPT_ENCODING', '').split(',')]
for encoding in encodings:
if encoding in self.compression_methods:
r['response'] = \
getattr(self, '_' + encoding)(r['response'])
r['headers'] += [('Content-Encoding', encoding)]
break
cors_headers = self._cors_headers(environ)
start_response(r['status'], r['headers'] + cors_headers)
return [r['response']]
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
th = self._async['thread'](target=target, args=args, kwargs=kwargs)
th.start()
return th # pragma: no cover
def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
"""
return self._async['sleep'](seconds)
def create_queue(self, *args, **kwargs):
"""Create a queue object using the appropriate async model.
This is a utility function that applications can use to create a queue
without having to worry about using the correct call for the selected
async mode.
"""
return self._async['queue'](*args, **kwargs)
def get_queue_empty_exception(self):
"""Return the queue empty exception for the appropriate async model.
This is a utility function that applications can use to work with a
queue without having to worry about using the correct call for the
selected async mode.
"""
return self._async['queue_empty']
def create_event(self, *args, **kwargs):
"""Create an event object using the appropriate async model.
This is a utility function that applications can use to create an
event without having to worry about using the correct call for the
selected async mode.
"""
return self._async['event'](*args, **kwargs)
def _generate_id(self):
"""Generate a unique session id."""
return uuid.uuid4().hex
def _handle_connect(self, environ, start_response, transport, b64=False,
jsonp_index=None):
"""Handle a client connection request."""
if self.start_service_task:
# start the service task to monitor connected clients
self.start_service_task = False
self.start_background_task(self._service_task)
sid = self._generate_id()
s = socket.Socket(self, sid)
self.sockets[sid] = s
pkt = packet.Packet(
packet.OPEN, {'sid': sid,
'upgrades': self._upgrades(sid, transport),
'pingTimeout': int(self.ping_timeout * 1000),
'pingInterval': int(self.ping_interval * 1000)})
s.send(pkt)
ret = self._trigger_event('connect', sid, environ, run_async=False)
if ret is False:
del self.sockets[sid]
self.logger.warning('Application rejected connection')
return self._unauthorized()
if transport == 'websocket':
ret = s.handle_get_request(environ, start_response)
if s.closed:
# websocket connection ended, so we are done
del self.sockets[sid]
return ret
else:
s.connected = True
headers = None
if self.cookie:
headers = [('Set-Cookie', self.cookie + '=' + sid)]
try:
return self._ok(s.poll(), headers=headers, b64=b64,
jsonp_index=jsonp_index)
except exceptions.QueueEmpty:
return self._bad_request()
def _upgrades(self, sid, transport):
"""Return the list of possible upgrades for a client connection."""
if not self.allow_upgrades or self._get_socket(sid).upgraded or \
self._async['websocket'] is None or transport == 'websocket':
return []
return ['websocket']
def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
if event in self.handlers:
if run_async:
return self.start_background_task(self.handlers[event], *args)
else:
try:
return self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
def _get_socket(self, sid):
"""Return the socket object for a given session."""
try:
s = self.sockets[sid]
except KeyError:
raise KeyError('Session not found')
if s.closed:
del self.sockets[sid]
raise KeyError('Session is disconnected')
return s
def _ok(self, packets=None, headers=None, b64=False, jsonp_index=None):
"""Generate a successful HTTP response."""
if packets is not None:
if headers is None:
headers = []
if b64:
headers += [('Content-Type', 'text/plain; charset=UTF-8')]
else:
headers += [('Content-Type', 'application/octet-stream')]
return {'status': '200 OK',
'headers': headers,
'response': payload.Payload(packets=packets).encode(
b64=b64, jsonp_index=jsonp_index)}
else:
return {'status': '200 OK',
'headers': [('Content-Type', 'text/plain')],
'response': b'OK'}
def _bad_request(self):
"""Generate a bad request HTTP error response."""
return {'status': '400 BAD REQUEST',
'headers': [('Content-Type', 'text/plain')],
'response': b'Bad Request'}
def _method_not_found(self):
"""Generate a method not found HTTP error response."""
return {'status': '405 METHOD NOT FOUND',
'headers': [('Content-Type', 'text/plain')],
'response': b'Method Not Found'}
def _unauthorized(self):
"""Generate a unauthorized HTTP error response."""
return {'status': '401 UNAUTHORIZED',
'headers': [('Content-Type', 'text/plain')],
'response': b'Unauthorized'}
def _cors_allowed_origins(self, environ):
default_origins = []
if 'wsgi.url_scheme' in environ and 'HTTP_HOST' in environ:
default_origins.append('{scheme}://{host}'.format(
scheme=environ['wsgi.url_scheme'], host=environ['HTTP_HOST']))
if 'HTTP_X_FORWARDED_HOST' in environ:
scheme = environ.get(
'HTTP_X_FORWARDED_PROTO',
environ['wsgi.url_scheme']).split(',')[0].strip()
default_origins.append('{scheme}://{host}'.format(
scheme=scheme, host=environ['HTTP_X_FORWARDED_HOST'].split(
',')[0].strip()))
if self.cors_allowed_origins is None:
allowed_origins = default_origins
elif self.cors_allowed_origins == '*':
allowed_origins = None
elif isinstance(self.cors_allowed_origins, six.string_types):
allowed_origins = [self.cors_allowed_origins]
else:
allowed_origins = self.cors_allowed_origins
return allowed_origins
def _cors_headers(self, environ):
"""Return the cross-origin-resource-sharing headers."""
if self.cors_allowed_origins == []:
# special case, CORS handling is completely disabled
return []
headers = []
allowed_origins = self._cors_allowed_origins(environ)
if 'HTTP_ORIGIN' in environ and \
(allowed_origins is None or environ['HTTP_ORIGIN'] in
allowed_origins):
headers = [('Access-Control-Allow-Origin', environ['HTTP_ORIGIN'])]
if environ['REQUEST_METHOD'] == 'OPTIONS':
headers += [('Access-Control-Allow-Methods', 'OPTIONS, GET, POST')]
if 'HTTP_ACCESS_CONTROL_REQUEST_HEADERS' in environ:
headers += [('Access-Control-Allow-Headers',
environ['HTTP_ACCESS_CONTROL_REQUEST_HEADERS'])]
if self.cors_credentials:
headers += [('Access-Control-Allow-Credentials', 'true')]
return headers
def _gzip(self, response):
"""Apply gzip compression to a response."""
bytesio = six.BytesIO()
with gzip.GzipFile(fileobj=bytesio, mode='w') as gz:
gz.write(response)
return bytesio.getvalue()
def _deflate(self, response):
"""Apply deflate compression to a response."""
return zlib.compress(response)
def _service_task(self): # pragma: no cover
"""Monitor connected clients and clean up those that time out."""
while True:
if len(self.sockets) == 0:
# nothing to do
self.sleep(self.ping_timeout)
continue
# go through the entire client list in a ping interval cycle
sleep_interval = float(self.ping_timeout) / len(self.sockets)
try:
# iterate over the current clients
for s in self.sockets.copy().values():
if not s.closing and not s.closed:
s.check_ping_timeout()
self.sleep(sleep_interval)
except (SystemExit, KeyboardInterrupt):
self.logger.info('service task canceled')
break
except:
# an unexpected exception has occurred, log it and continue
self.logger.exception('service task exception')

248
libs/engineio/socket.py Normal file
View file

@ -0,0 +1,248 @@
import six
import sys
import time
from . import exceptions
from . import packet
from . import payload
class Socket(object):
"""An Engine.IO socket."""
upgrade_protocols = ['websocket']
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.queue = self.server.create_queue()
self.last_ping = time.time()
self.connected = False
self.upgrading = False
self.upgraded = False
self.packet_backlog = []
self.closing = False
self.closed = False
self.session = {}
def poll(self):
"""Wait for packets to send to the client."""
queue_empty = self.server.get_queue_empty_exception()
try:
packets = [self.queue.get(timeout=self.server.ping_timeout)]
self.queue.task_done()
except queue_empty:
raise exceptions.QueueEmpty()
if packets == [None]:
return []
while True:
try:
packets.append(self.queue.get(block=False))
self.queue.task_done()
except queue_empty:
break
return packets
def receive(self, pkt):
"""Receive packet from the client."""
packet_name = packet.packet_names[pkt.packet_type] \
if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
self.server.logger.info('%s: Received packet %s data %s',
self.sid, packet_name,
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
if pkt.packet_type == packet.PING:
self.last_ping = time.time()
self.send(packet.Packet(packet.PONG, pkt.data))
elif pkt.packet_type == packet.MESSAGE:
self.server._trigger_event('message', self.sid, pkt.data,
run_async=self.server.async_handlers)
elif pkt.packet_type == packet.UPGRADE:
self.send(packet.Packet(packet.NOOP))
elif pkt.packet_type == packet.CLOSE:
self.close(wait=False, abort=True)
else:
raise exceptions.UnknownPacketError()
def check_ping_timeout(self):
"""Make sure the client is still sending pings.
This helps detect disconnections for long-polling clients.
"""
if self.closed:
raise exceptions.SocketIsClosedError()
if time.time() - self.last_ping > self.server.ping_interval + \
self.server.ping_interval_grace_period:
self.server.logger.info('%s: Client is gone, closing socket',
self.sid)
# Passing abort=False here will cause close() to write a
# CLOSE packet. This has the effect of updating half-open sockets
# to their correct state of disconnected
self.close(wait=False, abort=False)
return False
return True
def send(self, pkt):
"""Send a packet to the client."""
if not self.check_ping_timeout():
return
if self.upgrading:
self.packet_backlog.append(pkt)
else:
self.queue.put(pkt)
self.server.logger.info('%s: Sending packet %s data %s',
self.sid, packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
def handle_get_request(self, environ, start_response):
"""Handle a long-polling GET request from the client."""
connections = [
s.strip()
for s in environ.get('HTTP_CONNECTION', '').lower().split(',')]
transport = environ.get('HTTP_UPGRADE', '').lower()
if 'upgrade' in connections and transport in self.upgrade_protocols:
self.server.logger.info('%s: Received request to upgrade to %s',
self.sid, transport)
return getattr(self, '_upgrade_' + transport)(environ,
start_response)
try:
packets = self.poll()
except exceptions.QueueEmpty:
exc = sys.exc_info()
self.close(wait=False)
six.reraise(*exc)
return packets
def handle_post_request(self, environ):
"""Handle a long-polling POST request from the client."""
length = int(environ.get('CONTENT_LENGTH', '0'))
if length > self.server.max_http_buffer_size:
raise exceptions.ContentTooLongError()
else:
body = environ['wsgi.input'].read(length)
p = payload.Payload(encoded_payload=body)
for pkt in p.packets:
self.receive(pkt)
def close(self, wait=True, abort=False):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
self.server._trigger_event('disconnect', self.sid, run_async=False)
if not abort:
self.send(packet.Packet(packet.CLOSE))
self.closed = True
self.queue.put(None)
if wait:
self.queue.join()
def _upgrade_websocket(self, environ, start_response):
"""Upgrade the connection from polling to websocket."""
if self.upgraded:
raise IOError('Socket has been upgraded already')
if self.server._async['websocket'] is None:
# the selected async mode does not support websocket
return self.server._bad_request()
ws = self.server._async['websocket'](self._websocket_handler)
return ws(environ, start_response)
def _websocket_handler(self, ws):
"""Engine.IO handler for websocket transport."""
# try to set a socket timeout matching the configured ping interval
for attr in ['_sock', 'socket']: # pragma: no cover
if hasattr(ws, attr) and hasattr(getattr(ws, attr), 'settimeout'):
getattr(ws, attr).settimeout(self.server.ping_timeout)
if self.connected:
# the socket was already connected, so this is an upgrade
self.upgrading = True # hold packet sends during the upgrade
pkt = ws.wait()
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.PING or \
decoded_pkt.data != 'probe':
self.server.logger.info(
'%s: Failed websocket upgrade, no PING packet', self.sid)
return []
ws.send(packet.Packet(
packet.PONG,
data=six.text_type('probe')).encode(always_bytes=False))
self.queue.put(packet.Packet(packet.NOOP)) # end poll
pkt = ws.wait()
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.UPGRADE:
self.upgraded = False
self.server.logger.info(
('%s: Failed websocket upgrade, expected UPGRADE packet, '
'received %s instead.'),
self.sid, pkt)
return []
self.upgraded = True
# flush any packets that were sent during the upgrade
for pkt in self.packet_backlog:
self.queue.put(pkt)
self.packet_backlog = []
self.upgrading = False
else:
self.connected = True
self.upgraded = True
# start separate writer thread
def writer():
while True:
packets = None
try:
packets = self.poll()
except exceptions.QueueEmpty:
break
if not packets:
# empty packet list returned -> connection closed
break
try:
for pkt in packets:
ws.send(pkt.encode(always_bytes=False))
except:
break
writer_task = self.server.start_background_task(writer)
self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)
while True:
p = None
try:
p = ws.wait()
except Exception as e:
# if the socket is already closed, we can assume this is a
# downstream error of that
if not self.closed: # pragma: no cover
self.server.logger.info(
'%s: Unexpected error "%s", closing connection',
self.sid, str(e))
break
if p is None:
# connection closed by client
break
if isinstance(p, six.text_type): # pragma: no cover
p = p.encode('utf-8')
pkt = packet.Packet(encoded_packet=p)
try:
self.receive(pkt)
except exceptions.UnknownPacketError: # pragma: no cover
pass
except exceptions.SocketIsClosedError: # pragma: no cover
self.server.logger.info('Receive error -- socket is closed')
break
except: # pragma: no cover
# if we get an unexpected exception we log the error and exit
# the connection properly
self.server.logger.exception('Unknown receive error')
break
self.queue.put(None) # unlock the writer task so that it can exit
writer_task.join()
self.close(wait=False, abort=True)
return []

View file

@ -0,0 +1,55 @@
content_types = {
'css': 'text/css',
'gif': 'image/gif',
'html': 'text/html',
'jpg': 'image/jpeg',
'js': 'application/javascript',
'json': 'application/json',
'png': 'image/png',
'txt': 'text/plain',
}
def get_static_file(path, static_files):
"""Return the local filename and content type for the requested static
file URL.
:param path: the path portion of the requested URL.
:param static_files: a static file configuration dictionary.
This function returns a dictionary with two keys, "filename" and
"content_type". If the requested URL does not match any static file, the
return value is None.
"""
if path in static_files:
f = static_files[path]
else:
f = None
rest = ''
while path != '':
path, last = path.rsplit('/', 1)
rest = '/' + last + rest
if path in static_files:
f = static_files[path] + rest
break
elif path + '/' in static_files:
f = static_files[path + '/'] + rest[1:]
break
if f:
if isinstance(f, str):
f = {'filename': f}
if f['filename'].endswith('/'):
if '' in static_files:
if isinstance(static_files[''], str):
f['filename'] += static_files['']
else:
f['filename'] += static_files['']['filename']
if 'content_type' in static_files['']:
f['content_type'] = static_files['']['content_type']
else:
f['filename'] += 'index.html'
if 'content_type' not in f:
ext = f['filename'].rsplit('.')[-1]
f['content_type'] = content_types.get(
ext, 'application/octet-stream')
return f

View file

@ -0,0 +1,922 @@
from functools import wraps
import os
import sys
# make sure gevent-socketio is not installed, as it conflicts with
# python-socketio
gevent_socketio_found = True
try:
from socketio import socketio_manage
except ImportError:
gevent_socketio_found = False
if gevent_socketio_found:
print('The gevent-socketio package is incompatible with this version of '
'the Flask-SocketIO extension. Please uninstall it, and then '
'install the latest version of python-socketio in its place.')
sys.exit(1)
import flask
from flask import _request_ctx_stack, json as flask_json
from flask.sessions import SessionMixin
import socketio
from socketio.exceptions import ConnectionRefusedError
from werkzeug.debug import DebuggedApplication
from werkzeug.serving import run_with_reloader
from .namespace import Namespace
from .test_client import SocketIOTestClient
__version__ = '4.2.1'
class _SocketIOMiddleware(socketio.WSGIApp):
"""This WSGI middleware simply exposes the Flask application in the WSGI
environment before executing the request.
"""
def __init__(self, socketio_app, flask_app, socketio_path='socket.io'):
self.flask_app = flask_app
super(_SocketIOMiddleware, self).__init__(socketio_app,
flask_app.wsgi_app,
socketio_path=socketio_path)
def __call__(self, environ, start_response):
environ = environ.copy()
environ['flask.app'] = self.flask_app
return super(_SocketIOMiddleware, self).__call__(environ,
start_response)
class _ManagedSession(dict, SessionMixin):
"""This class is used for user sessions that are managed by
Flask-SocketIO. It is simple dict, expanded with the Flask session
attributes."""
pass
class SocketIO(object):
"""Create a Flask-SocketIO server.
:param app: The flask application instance. If the application instance
isn't known at the time this class is instantiated, then call
``socketio.init_app(app)`` once the application instance is
available.
:param manage_session: If set to ``True``, this extension manages the user
session for Socket.IO events. If set to ``False``,
Flask's own session management is used. When using
Flask's cookie based sessions it is recommended that
you leave this set to the default of ``True``. When
using server-side sessions, a ``False`` setting
enables sharing the user session between HTTP routes
and Socket.IO events.
:param message_queue: A connection URL for a message queue service the
server can use for multi-process communication. A
message queue is not required when using a single
server process.
:param channel: The channel name, when using a message queue. If a channel
isn't specified, a default channel will be used. If
multiple clusters of SocketIO processes need to use the
same message queue without interfering with each other, then
each cluster should use a different channel.
:param path: The path where the Socket.IO server is exposed. Defaults to
``'socket.io'``. Leave this as is unless you know what you are
doing.
:param resource: Alias to ``path``.
:param kwargs: Socket.IO and Engine.IO server options.
The Socket.IO server options are detailed below:
:param client_manager: The client manager instance that will manage the
client list. When this is omitted, the client list
is stored in an in-memory structure, so the use of
multiple connected servers is not possible. In most
cases, this argument does not need to be set
explicitly.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param binary: ``True`` to support binary payloads, ``False`` to treat all
payloads as text. On Python 2, if this is set to ``True``,
``unicode`` values are treated as text, and ``str`` and
``bytes`` values are treated as binary. This option has no
effect on Python 3, where text and binary payloads are
always automatically discovered.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions. To use the same json encoder and decoder as a Flask
application, use ``flask.json``.
:param async_handlers: If set to ``True``, event handlers for a client are
executed in separate threads. To run handlers for a
client synchronously, set to ``False``. The default
is ``True``.
:param always_connect: When set to ``False``, new connections are
provisory until the connect handler returns
something other than ``False``, at which point they
are accepted. When set to ``True``, connections are
immediately accepted, and then if the connect
handler returns ``False`` a disconnect is issued.
Set to ``True`` if you need to emit events from the
connect handler and your client is confused when it
receives events before the connection acceptance.
In any other case use the default of ``False``.
The Engine.IO server configuration supports the following settings:
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are
``threading``, ``eventlet``, ``gevent`` and
``gevent_uwsgi``. If this argument is not given,
``eventlet`` is tried first, then ``gevent_uwsgi``,
then ``gevent``, and finally ``threading``. The
first async mode that has all its dependencies installed
is then one that is chosen.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting. The default is
60 seconds.
:param ping_interval: The interval in seconds at which the client pings
the server. The default is 25 seconds.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport. The default is 100,000,000
bytes.
:param allow_upgrades: Whether to allow transport upgrades or not. The
default is ``True``.
:param http_compression: Whether to compress packages when using the
polling transport. The default is ``True``.
:param compression_threshold: Only compress messages when their byte size
is greater than this value. The default is
1024 bytes.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
The default is ``'io'``.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same origin
is allowed by default. Set this argument to
``'*'`` to allow all origins, or to ``[]`` to
disable CORS handling.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server. The default is
``True``.
:param monitor_clients: If set to ``True``, a background task will ensure
inactive clients are closed. Set to ``False`` to
disable the monitoring task (not recommended). The
default is ``True``.
:param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
a logger object to use. To disable logging set to
``False``. The default is ``False``.
"""
def __init__(self, app=None, **kwargs):
self.server = None
self.server_options = {}
self.wsgi_server = None
self.handlers = []
self.namespace_handlers = []
self.exception_handlers = {}
self.default_exception_handler = None
self.manage_session = True
# We can call init_app when:
# - we were given the Flask app instance (standard initialization)
# - we were not given the app, but we were given a message_queue
# (standard initialization for auxiliary process)
# In all other cases we collect the arguments and assume the client
# will call init_app from an app factory function.
if app is not None or 'message_queue' in kwargs:
self.init_app(app, **kwargs)
else:
self.server_options.update(kwargs)
def init_app(self, app, **kwargs):
if app is not None:
if not hasattr(app, 'extensions'):
app.extensions = {} # pragma: no cover
app.extensions['socketio'] = self
self.server_options.update(kwargs)
self.manage_session = self.server_options.pop('manage_session',
self.manage_session)
if 'client_manager' not in self.server_options:
url = self.server_options.pop('message_queue', None)
channel = self.server_options.pop('channel', 'flask-socketio')
write_only = app is None
if url:
if url.startswith(('redis://', "rediss://")):
queue_class = socketio.RedisManager
elif url.startswith(('kafka://')):
queue_class = socketio.KafkaManager
elif url.startswith('zmq'):
queue_class = socketio.ZmqManager
else:
queue_class = socketio.KombuManager
queue = queue_class(url, channel=channel,
write_only=write_only)
self.server_options['client_manager'] = queue
if 'json' in self.server_options and \
self.server_options['json'] == flask_json:
# flask's json module is tricky to use because its output
# changes when it is invoked inside or outside the app context
# so here to prevent any ambiguities we replace it with wrappers
# that ensure that the app context is always present
class FlaskSafeJSON(object):
@staticmethod
def dumps(*args, **kwargs):
with app.app_context():
return flask_json.dumps(*args, **kwargs)
@staticmethod
def loads(*args, **kwargs):
with app.app_context():
return flask_json.loads(*args, **kwargs)
self.server_options['json'] = FlaskSafeJSON
resource = self.server_options.pop('path', None) or \
self.server_options.pop('resource', None) or 'socket.io'
if resource.startswith('/'):
resource = resource[1:]
if os.environ.get('FLASK_RUN_FROM_CLI'):
if self.server_options.get('async_mode') is None:
if app is not None:
app.logger.warning(
'Flask-SocketIO is Running under Werkzeug, WebSocket '
'is not available.')
self.server_options['async_mode'] = 'threading'
self.server = socketio.Server(**self.server_options)
self.async_mode = self.server.async_mode
for handler in self.handlers:
self.server.on(handler[0], handler[1], namespace=handler[2])
for namespace_handler in self.namespace_handlers:
self.server.register_namespace(namespace_handler)
if app is not None:
# here we attach the SocketIO middlware to the SocketIO object so it
# can be referenced later if debug middleware needs to be inserted
self.sockio_mw = _SocketIOMiddleware(self.server, app,
socketio_path=resource)
app.wsgi_app = self.sockio_mw
def on(self, message, namespace=None):
"""Decorator to register a SocketIO event handler.
This decorator must be applied to SocketIO event handlers. Example::
@socketio.on('my event', namespace='/chat')
def handle_my_custom_event(json):
print('received json: ' + str(json))
:param message: The name of the event. This is normally a user defined
string, but a few event names are already defined. Use
``'message'`` to define a handler that takes a string
payload, ``'json'`` to define a handler that takes a
JSON blob payload, ``'connect'`` or ``'disconnect'``
to create handlers for connection and disconnection
events.
:param namespace: The namespace on which the handler is to be
registered. Defaults to the global namespace.
"""
namespace = namespace or '/'
def decorator(handler):
@wraps(handler)
def _handler(sid, *args):
return self._handle_event(handler, message, namespace, sid,
*args)
if self.server:
self.server.on(message, _handler, namespace=namespace)
else:
self.handlers.append((message, _handler, namespace))
return handler
return decorator
def on_error(self, namespace=None):
"""Decorator to define a custom error handler for SocketIO events.
This decorator can be applied to a function that acts as an error
handler for a namespace. This handler will be invoked when a SocketIO
event handler raises an exception. The handler function must accept one
argument, which is the exception raised. Example::
@socketio.on_error(namespace='/chat')
def chat_error_handler(e):
print('An error has occurred: ' + str(e))
:param namespace: The namespace for which to register the error
handler. Defaults to the global namespace.
"""
namespace = namespace or '/'
def decorator(exception_handler):
if not callable(exception_handler):
raise ValueError('exception_handler must be callable')
self.exception_handlers[namespace] = exception_handler
return exception_handler
return decorator
def on_error_default(self, exception_handler):
"""Decorator to define a default error handler for SocketIO events.
This decorator can be applied to a function that acts as a default
error handler for any namespaces that do not have a specific handler.
Example::
@socketio.on_error_default
def error_handler(e):
print('An error has occurred: ' + str(e))
"""
if not callable(exception_handler):
raise ValueError('exception_handler must be callable')
self.default_exception_handler = exception_handler
return exception_handler
def on_event(self, message, handler, namespace=None):
"""Register a SocketIO event handler.
``on_event`` is the non-decorator version of ``'on'``.
Example::
def on_foo_event(json):
print('received json: ' + str(json))
socketio.on_event('my event', on_foo_event, namespace='/chat')
:param message: The name of the event. This is normally a user defined
string, but a few event names are already defined. Use
``'message'`` to define a handler that takes a string
payload, ``'json'`` to define a handler that takes a
JSON blob payload, ``'connect'`` or ``'disconnect'``
to create handlers for connection and disconnection
events.
:param handler: The function that handles the event.
:param namespace: The namespace on which the handler is to be
registered. Defaults to the global namespace.
"""
self.on(message, namespace=namespace)(handler)
def on_namespace(self, namespace_handler):
if not isinstance(namespace_handler, Namespace):
raise ValueError('Not a namespace instance.')
namespace_handler._set_socketio(self)
if self.server:
self.server.register_namespace(namespace_handler)
else:
self.namespace_handlers.append(namespace_handler)
def emit(self, event, *args, **kwargs):
"""Emit a server generated SocketIO event.
This function emits a SocketIO event to one or more connected clients.
A JSON blob can be attached to the event as payload. This function can
be used outside of a SocketIO event context, so it is appropriate to
use when the server is the originator of an event, outside of any
client context, such as in a regular HTTP request handler or a
background task. Example::
@app.route('/ping')
def ping():
socketio.emit('ping event', {'data': 42}, namespace='/chat')
:param event: The name of the user event to emit.
:param args: A dictionary with the JSON data to send as payload.
:param namespace: The namespace under which the message is to be sent.
Defaults to the global namespace.
:param room: Send the message to all the users in the given room. If
this parameter is not included, the event is sent to
all connected users.
:param skip_sid: The session id of a client to ignore when broadcasting
or addressing a room. This is typically set to the
originator of the message, so that everyone except
that client receive the message. To skip multiple sids
pass a list.
:param callback: If given, this function will be called to acknowledge
that the client has received the message. The
arguments that will be passed to the function are
those provided by the client. Callback functions can
only be used when addressing an individual client.
"""
namespace = kwargs.pop('namespace', '/')
room = kwargs.pop('room', None)
include_self = kwargs.pop('include_self', True)
skip_sid = kwargs.pop('skip_sid', None)
if not include_self and not skip_sid:
skip_sid = flask.request.sid
callback = kwargs.pop('callback', None)
if callback:
# wrap the callback so that it sets app app and request contexts
sid = flask.request.sid
original_callback = callback
def _callback_wrapper(*args):
return self._handle_event(original_callback, None, namespace,
sid, *args)
callback = _callback_wrapper
self.server.emit(event, *args, namespace=namespace, room=room,
skip_sid=skip_sid, callback=callback, **kwargs)
def send(self, data, json=False, namespace=None, room=None,
callback=None, include_self=True, skip_sid=None, **kwargs):
"""Send a server-generated SocketIO message.
This function sends a simple SocketIO message to one or more connected
clients. The message can be a string or a JSON blob. This is a simpler
version of ``emit()``, which should be preferred. This function can be
used outside of a SocketIO event context, so it is appropriate to use
when the server is the originator of an event.
:param data: The message to send, either a string or a JSON blob.
:param json: ``True`` if ``message`` is a JSON blob, ``False``
otherwise.
:param namespace: The namespace under which the message is to be sent.
Defaults to the global namespace.
:param room: Send the message only to the users in the given room. If
this parameter is not included, the message is sent to
all connected users.
:param skip_sid: The session id of a client to ignore when broadcasting
or addressing a room. This is typically set to the
originator of the message, so that everyone except
that client receive the message. To skip multiple sids
pass a list.
:param callback: If given, this function will be called to acknowledge
that the client has received the message. The
arguments that will be passed to the function are
those provided by the client. Callback functions can
only be used when addressing an individual client.
"""
skip_sid = flask.request.sid if not include_self else skip_sid
if json:
self.emit('json', data, namespace=namespace, room=room,
skip_sid=skip_sid, callback=callback, **kwargs)
else:
self.emit('message', data, namespace=namespace, room=room,
skip_sid=skip_sid, callback=callback, **kwargs)
def close_room(self, room, namespace=None):
"""Close a room.
This function removes any users that are in the given room and then
deletes the room from the server. This function can be used outside
of a SocketIO event context.
:param room: The name of the room to close.
:param namespace: The namespace under which the room exists. Defaults
to the global namespace.
"""
self.server.close_room(room, namespace)
def run(self, app, host=None, port=None, **kwargs):
"""Run the SocketIO web server.
:param app: The Flask application instance.
:param host: The hostname or IP address for the server to listen on.
Defaults to 127.0.0.1.
:param port: The port number for the server to listen on. Defaults to
5000.
:param debug: ``True`` to start the server in debug mode, ``False`` to
start in normal mode.
:param use_reloader: ``True`` to enable the Flask reloader, ``False``
to disable it.
:param extra_files: A list of additional files that the Flask
reloader should watch. Defaults to ``None``
:param log_output: If ``True``, the server logs all incomming
connections. If ``False`` logging is disabled.
Defaults to ``True`` in debug mode, ``False``
in normal mode. Unused when the threading async
mode is used.
:param kwargs: Additional web server options. The web server options
are specific to the server used in each of the supported
async modes. Note that options provided here will
not be seen when using an external web server such
as gunicorn, since this method is not called in that
case.
"""
if host is None:
host = '127.0.0.1'
if port is None:
server_name = app.config['SERVER_NAME']
if server_name and ':' in server_name:
port = int(server_name.rsplit(':', 1)[1])
else:
port = 5000
debug = kwargs.pop('debug', app.debug)
log_output = kwargs.pop('log_output', debug)
use_reloader = kwargs.pop('use_reloader', debug)
extra_files = kwargs.pop('extra_files', None)
app.debug = debug
if app.debug and self.server.eio.async_mode != 'threading':
# put the debug middleware between the SocketIO middleware
# and the Flask application instance
#
# mw1 mw2 mw3 Flask app
# o ---- o ---- o ---- o
# /
# o Flask-SocketIO
# \ middleware
# o
# Flask-SocketIO WebSocket handler
#
# BECOMES
#
# dbg-mw mw1 mw2 mw3 Flask app
# o ---- o ---- o ---- o ---- o
# /
# o Flask-SocketIO
# \ middleware
# o
# Flask-SocketIO WebSocket handler
#
self.sockio_mw.wsgi_app = DebuggedApplication(self.sockio_mw.wsgi_app,
evalex=True)
if self.server.eio.async_mode == 'threading':
from werkzeug._internal import _log
_log('warning', 'WebSocket transport not available. Install '
'eventlet or gevent and gevent-websocket for '
'improved performance.')
app.run(host=host, port=port, threaded=True,
use_reloader=use_reloader, **kwargs)
elif self.server.eio.async_mode == 'eventlet':
def run_server():
import eventlet
import eventlet.wsgi
import eventlet.green
addresses = eventlet.green.socket.getaddrinfo(host, port)
if not addresses:
raise RuntimeError('Could not resolve host to a valid address')
eventlet_socket = eventlet.listen(addresses[0][4], addresses[0][0])
# If provided an SSL argument, use an SSL socket
ssl_args = ['keyfile', 'certfile', 'server_side', 'cert_reqs',
'ssl_version', 'ca_certs',
'do_handshake_on_connect', 'suppress_ragged_eofs',
'ciphers']
ssl_params = {k: kwargs[k] for k in kwargs if k in ssl_args}
if len(ssl_params) > 0:
for k in ssl_params:
kwargs.pop(k)
ssl_params['server_side'] = True # Listening requires true
eventlet_socket = eventlet.wrap_ssl(eventlet_socket,
**ssl_params)
eventlet.wsgi.server(eventlet_socket, app,
log_output=log_output, **kwargs)
if use_reloader:
run_with_reloader(run_server, extra_files=extra_files)
else:
run_server()
elif self.server.eio.async_mode == 'gevent':
from gevent import pywsgi
try:
from geventwebsocket.handler import WebSocketHandler
websocket = True
except ImportError:
websocket = False
log = 'default'
if not log_output:
log = None
if websocket:
self.wsgi_server = pywsgi.WSGIServer(
(host, port), app, handler_class=WebSocketHandler,
log=log, **kwargs)
else:
self.wsgi_server = pywsgi.WSGIServer((host, port), app,
log=log, **kwargs)
if use_reloader:
# monkey patching is required by the reloader
from gevent import monkey
monkey.patch_thread()
monkey.patch_time()
def run_server():
self.wsgi_server.serve_forever()
run_with_reloader(run_server, extra_files=extra_files)
else:
self.wsgi_server.serve_forever()
def stop(self):
"""Stop a running SocketIO web server.
This method must be called from a HTTP or SocketIO handler function.
"""
if self.server.eio.async_mode == 'threading':
func = flask.request.environ.get('werkzeug.server.shutdown')
if func:
func()
else:
raise RuntimeError('Cannot stop unknown web server')
elif self.server.eio.async_mode == 'eventlet':
raise SystemExit
elif self.server.eio.async_mode == 'gevent':
self.wsgi_server.stop()
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
return self.server.start_background_task(target, *args, **kwargs)
def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
"""
return self.server.sleep(seconds)
def test_client(self, app, namespace=None, query_string=None,
headers=None, flask_test_client=None):
"""The Socket.IO test client is useful for testing a Flask-SocketIO
server. It works in a similar way to the Flask Test Client, but
adapted to the Socket.IO server.
:param app: The Flask application instance.
:param namespace: The namespace for the client. If not provided, the
client connects to the server on the global
namespace.
:param query_string: A string with custom query string arguments.
:param headers: A dictionary with custom HTTP headers.
:param flask_test_client: The instance of the Flask test client
currently in use. Passing the Flask test
client is optional, but is necessary if you
want the Flask user session and any other
cookies set in HTTP routes accessible from
Socket.IO events.
"""
return SocketIOTestClient(app, self, namespace=namespace,
query_string=query_string, headers=headers,
flask_test_client=flask_test_client)
def _handle_event(self, handler, message, namespace, sid, *args):
if sid not in self.server.environ:
# we don't have record of this client, ignore this event
return '', 400
app = self.server.environ[sid]['flask.app']
with app.request_context(self.server.environ[sid]):
if self.manage_session:
# manage a separate session for this client's Socket.IO events
# created as a copy of the regular user session
if 'saved_session' not in self.server.environ[sid]:
self.server.environ[sid]['saved_session'] = \
_ManagedSession(flask.session)
session_obj = self.server.environ[sid]['saved_session']
else:
# let Flask handle the user session
# for cookie based sessions, this effectively freezes the
# session to its state at connection time
# for server-side sessions, this allows HTTP and Socket.IO to
# share the session, with both having read/write access to it
session_obj = flask.session._get_current_object()
_request_ctx_stack.top.session = session_obj
flask.request.sid = sid
flask.request.namespace = namespace
flask.request.event = {'message': message, 'args': args}
try:
if message == 'connect':
ret = handler()
else:
ret = handler(*args)
except:
err_handler = self.exception_handlers.get(
namespace, self.default_exception_handler)
if err_handler is None:
raise
type, value, traceback = sys.exc_info()
return err_handler(value)
if not self.manage_session:
# when Flask is managing the user session, it needs to save it
if not hasattr(session_obj, 'modified') or session_obj.modified:
resp = app.response_class()
app.session_interface.save_session(app, session_obj, resp)
return ret
def emit(event, *args, **kwargs):
"""Emit a SocketIO event.
This function emits a SocketIO event to one or more connected clients. A
JSON blob can be attached to the event as payload. This is a function that
can only be called from a SocketIO event handler, as in obtains some
information from the current client context. Example::
@socketio.on('my event')
def handle_my_custom_event(json):
emit('my response', {'data': 42})
:param event: The name of the user event to emit.
:param args: A dictionary with the JSON data to send as payload.
:param namespace: The namespace under which the message is to be sent.
Defaults to the namespace used by the originating event.
A ``'/'`` can be used to explicitly specify the global
namespace.
:param callback: Callback function to invoke with the client's
acknowledgement.
:param broadcast: ``True`` to send the message to all clients, or ``False``
to only reply to the sender of the originating event.
:param room: Send the message to all the users in the given room. If this
argument is set, then broadcast is implied to be ``True``.
:param include_self: ``True`` to include the sender when broadcasting or
addressing a room, or ``False`` to send to everyone
but the sender.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
clients directly, without going through the queue.
This is more efficient, but only works when a
single server process is used, or when there is a
single addresee. It is recommended to always leave
this parameter with its default value of ``False``.
"""
if 'namespace' in kwargs:
namespace = kwargs['namespace']
else:
namespace = flask.request.namespace
callback = kwargs.get('callback')
broadcast = kwargs.get('broadcast')
room = kwargs.get('room')
if room is None and not broadcast:
room = flask.request.sid
include_self = kwargs.get('include_self', True)
ignore_queue = kwargs.get('ignore_queue', False)
socketio = flask.current_app.extensions['socketio']
return socketio.emit(event, *args, namespace=namespace, room=room,
include_self=include_self, callback=callback,
ignore_queue=ignore_queue)
def send(message, **kwargs):
"""Send a SocketIO message.
This function sends a simple SocketIO message to one or more connected
clients. The message can be a string or a JSON blob. This is a simpler
version of ``emit()``, which should be preferred. This is a function that
can only be called from a SocketIO event handler.
:param message: The message to send, either a string or a JSON blob.
:param json: ``True`` if ``message`` is a JSON blob, ``False``
otherwise.
:param namespace: The namespace under which the message is to be sent.
Defaults to the namespace used by the originating event.
An empty string can be used to use the global namespace.
:param callback: Callback function to invoke with the client's
acknowledgement.
:param broadcast: ``True`` to send the message to all connected clients, or
``False`` to only reply to the sender of the originating
event.
:param room: Send the message to all the users in the given room.
:param include_self: ``True`` to include the sender when broadcasting or
addressing a room, or ``False`` to send to everyone
but the sender.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
clients directly, without going through the queue.
This is more efficient, but only works when a
single server process is used, or when there is a
single addresee. It is recommended to always leave
this parameter with its default value of ``False``.
"""
json = kwargs.get('json', False)
if 'namespace' in kwargs:
namespace = kwargs['namespace']
else:
namespace = flask.request.namespace
callback = kwargs.get('callback')
broadcast = kwargs.get('broadcast')
room = kwargs.get('room')
if room is None and not broadcast:
room = flask.request.sid
include_self = kwargs.get('include_self', True)
ignore_queue = kwargs.get('ignore_queue', False)
socketio = flask.current_app.extensions['socketio']
return socketio.send(message, json=json, namespace=namespace, room=room,
include_self=include_self, callback=callback,
ignore_queue=ignore_queue)
def join_room(room, sid=None, namespace=None):
"""Join a room.
This function puts the user in a room, under the current namespace. The
user and the namespace are obtained from the event context. This is a
function that can only be called from a SocketIO event handler. Example::
@socketio.on('join')
def on_join(data):
username = session['username']
room = data['room']
join_room(room)
send(username + ' has entered the room.', room=room)
:param room: The name of the room to join.
:param sid: The session id of the client. If not provided, the client is
obtained from the request context.
:param namespace: The namespace for the room. If not provided, the
namespace is obtained from the request context.
"""
socketio = flask.current_app.extensions['socketio']
sid = sid or flask.request.sid
namespace = namespace or flask.request.namespace
socketio.server.enter_room(sid, room, namespace=namespace)
def leave_room(room, sid=None, namespace=None):
"""Leave a room.
This function removes the user from a room, under the current namespace.
The user and the namespace are obtained from the event context. Example::
@socketio.on('leave')
def on_leave(data):
username = session['username']
room = data['room']
leave_room(room)
send(username + ' has left the room.', room=room)
:param room: The name of the room to leave.
:param sid: The session id of the client. If not provided, the client is
obtained from the request context.
:param namespace: The namespace for the room. If not provided, the
namespace is obtained from the request context.
"""
socketio = flask.current_app.extensions['socketio']
sid = sid or flask.request.sid
namespace = namespace or flask.request.namespace
socketio.server.leave_room(sid, room, namespace=namespace)
def close_room(room, namespace=None):
"""Close a room.
This function removes any users that are in the given room and then deletes
the room from the server.
:param room: The name of the room to close.
:param namespace: The namespace for the room. If not provided, the
namespace is obtained from the request context.
"""
socketio = flask.current_app.extensions['socketio']
namespace = namespace or flask.request.namespace
socketio.server.close_room(room, namespace=namespace)
def rooms(sid=None, namespace=None):
"""Return a list of the rooms the client is in.
This function returns all the rooms the client has entered, including its
own room, assigned by the Socket.IO server.
:param sid: The session id of the client. If not provided, the client is
obtained from the request context.
:param namespace: The namespace for the room. If not provided, the
namespace is obtained from the request context.
"""
socketio = flask.current_app.extensions['socketio']
sid = sid or flask.request.sid
namespace = namespace or flask.request.namespace
return socketio.server.rooms(sid, namespace=namespace)
def disconnect(sid=None, namespace=None, silent=False):
"""Disconnect the client.
This function terminates the connection with the client. As a result of
this call the client will receive a disconnect event. Example::
@socketio.on('message')
def receive_message(msg):
if is_banned(session['username']):
disconnect()
else:
# ...
:param sid: The session id of the client. If not provided, the client is
obtained from the request context.
:param namespace: The namespace for the room. If not provided, the
namespace is obtained from the request context.
:param silent: this option is deprecated.
"""
socketio = flask.current_app.extensions['socketio']
sid = sid or flask.request.sid
namespace = namespace or flask.request.namespace
return socketio.server.disconnect(sid, namespace=namespace)

View file

@ -0,0 +1,47 @@
from socketio import Namespace as _Namespace
class Namespace(_Namespace):
def __init__(self, namespace=None):
super(Namespace, self).__init__(namespace)
self.socketio = None
def _set_socketio(self, socketio):
self.socketio = socketio
def trigger_event(self, event, *args):
"""Dispatch an event to the proper handler method.
In the most common usage, this method is not overloaded by subclasses,
as it performs the routing of events to methods. However, this
method can be overriden if special dispatching rules are needed, or if
having a single method that catches all events is desired.
"""
handler_name = 'on_' + event
if not hasattr(self, handler_name):
# there is no handler for this event, so we ignore it
return
handler = getattr(self, handler_name)
return self.socketio._handle_event(handler, event, self.namespace,
*args)
def emit(self, event, data=None, room=None, include_self=True,
namespace=None, callback=None):
"""Emit a custom event to one or more connected clients."""
return self.socketio.emit(event, data, room=room,
include_self=include_self,
namespace=namespace or self.namespace,
callback=callback)
def send(self, data, room=None, include_self=True, namespace=None,
callback=None):
"""Send a message to one or more connected clients."""
return self.socketio.send(data, room=room, include_self=include_self,
namespace=namespace or self.namespace,
callback=callback)
def close_room(self, room, namespace=None):
"""Close a room."""
return self.socketio.close_room(room=room,
namespace=namespace or self.namespace)

View file

@ -0,0 +1,205 @@
import uuid
from socketio import packet
from socketio.pubsub_manager import PubSubManager
from werkzeug.test import EnvironBuilder
class SocketIOTestClient(object):
"""
This class is useful for testing a Flask-SocketIO server. It works in a
similar way to the Flask Test Client, but adapted to the Socket.IO server.
:param app: The Flask application instance.
:param socketio: The application's ``SocketIO`` instance.
:param namespace: The namespace for the client. If not provided, the client
connects to the server on the global namespace.
:param query_string: A string with custom query string arguments.
:param headers: A dictionary with custom HTTP headers.
:param flask_test_client: The instance of the Flask test client
currently in use. Passing the Flask test
client is optional, but is necessary if you
want the Flask user session and any other
cookies set in HTTP routes accessible from
Socket.IO events.
"""
queue = {}
acks = {}
def __init__(self, app, socketio, namespace=None, query_string=None,
headers=None, flask_test_client=None):
def _mock_send_packet(sid, pkt):
if pkt.packet_type == packet.EVENT or \
pkt.packet_type == packet.BINARY_EVENT:
if sid not in self.queue:
self.queue[sid] = []
if pkt.data[0] == 'message' or pkt.data[0] == 'json':
self.queue[sid].append({'name': pkt.data[0],
'args': pkt.data[1],
'namespace': pkt.namespace or '/'})
else:
self.queue[sid].append({'name': pkt.data[0],
'args': pkt.data[1:],
'namespace': pkt.namespace or '/'})
elif pkt.packet_type == packet.ACK or \
pkt.packet_type == packet.BINARY_ACK:
self.acks[sid] = {'args': pkt.data,
'namespace': pkt.namespace or '/'}
elif pkt.packet_type == packet.DISCONNECT:
self.connected[pkt.namespace or '/'] = False
self.app = app
self.flask_test_client = flask_test_client
self.sid = uuid.uuid4().hex
self.queue[self.sid] = []
self.acks[self.sid] = None
self.callback_counter = 0
self.socketio = socketio
self.connected = {}
socketio.server._send_packet = _mock_send_packet
socketio.server.environ[self.sid] = {}
socketio.server.async_handlers = False # easier to test when
socketio.server.eio.async_handlers = False # events are sync
if isinstance(socketio.server.manager, PubSubManager):
raise RuntimeError('Test client cannot be used with a message '
'queue. Disable the queue on your test '
'configuration.')
socketio.server.manager.initialize()
self.connect(namespace=namespace, query_string=query_string,
headers=headers)
def is_connected(self, namespace=None):
"""Check if a namespace is connected.
:param namespace: The namespace to check. The global namespace is
assumed if this argument is not provided.
"""
return self.connected.get(namespace or '/', False)
def connect(self, namespace=None, query_string=None, headers=None):
"""Connect the client.
:param namespace: The namespace for the client. If not provided, the
client connects to the server on the global
namespace.
:param query_string: A string with custom query string arguments.
:param headers: A dictionary with custom HTTP headers.
Note that it is usually not necessary to explicitly call this method,
since a connection is automatically established when an instance of
this class is created. An example where it this method would be useful
is when the application accepts multiple namespace connections.
"""
url = '/socket.io'
if query_string:
if query_string[0] != '?':
query_string = '?' + query_string
url += query_string
environ = EnvironBuilder(url, headers=headers).get_environ()
environ['flask.app'] = self.app
if self.flask_test_client:
# inject cookies from Flask
self.flask_test_client.cookie_jar.inject_wsgi(environ)
self.connected['/'] = True
if self.socketio.server._handle_eio_connect(
self.sid, environ) is False:
del self.connected['/']
if namespace is not None and namespace != '/':
self.connected[namespace] = True
pkt = packet.Packet(packet.CONNECT, namespace=namespace)
with self.app.app_context():
if self.socketio.server._handle_eio_message(
self.sid, pkt.encode()) is False:
del self.connected[namespace]
def disconnect(self, namespace=None):
"""Disconnect the client.
:param namespace: The namespace to disconnect. The global namespace is
assumed if this argument is not provided.
"""
if not self.is_connected(namespace):
raise RuntimeError('not connected')
pkt = packet.Packet(packet.DISCONNECT, namespace=namespace)
with self.app.app_context():
self.socketio.server._handle_eio_message(self.sid, pkt.encode())
del self.connected[namespace or '/']
def emit(self, event, *args, **kwargs):
"""Emit an event to the server.
:param event: The event name.
:param *args: The event arguments.
:param callback: ``True`` if the client requests a callback, ``False``
if not. Note that client-side callbacks are not
implemented, a callback request will just tell the
server to provide the arguments to invoke the
callback, but no callback is invoked. Instead, the
arguments that the server provided for the callback
are returned by this function.
:param namespace: The namespace of the event. The global namespace is
assumed if this argument is not provided.
"""
namespace = kwargs.pop('namespace', None)
if not self.is_connected(namespace):
raise RuntimeError('not connected')
callback = kwargs.pop('callback', False)
id = None
if callback:
self.callback_counter += 1
id = self.callback_counter
pkt = packet.Packet(packet.EVENT, data=[event] + list(args),
namespace=namespace, id=id)
with self.app.app_context():
encoded_pkt = pkt.encode()
if isinstance(encoded_pkt, list):
for epkt in encoded_pkt:
self.socketio.server._handle_eio_message(self.sid, epkt)
else:
self.socketio.server._handle_eio_message(self.sid, encoded_pkt)
ack = self.acks.pop(self.sid, None)
if ack is not None:
return ack['args'][0] if len(ack['args']) == 1 \
else ack['args']
def send(self, data, json=False, callback=False, namespace=None):
"""Send a text or JSON message to the server.
:param data: A string, dictionary or list to send to the server.
:param json: ``True`` to send a JSON message, ``False`` to send a text
message.
:param callback: ``True`` if the client requests a callback, ``False``
if not. Note that client-side callbacks are not
implemented, a callback request will just tell the
server to provide the arguments to invoke the
callback, but no callback is invoked. Instead, the
arguments that the server provided for the callback
are returned by this function.
:param namespace: The namespace of the event. The global namespace is
assumed if this argument is not provided.
"""
if json:
msg = 'json'
else:
msg = 'message'
return self.emit(msg, data, callback=callback, namespace=namespace)
def get_received(self, namespace=None):
"""Return the list of messages received from the server.
Since this is not a real client, any time the server emits an event,
the event is simply stored. The test code can invoke this method to
obtain the list of events that were received since the last call.
:param namespace: The namespace to get events from. The global
namespace is assumed if this argument is not
provided.
"""
if not self.is_connected(namespace):
raise RuntimeError('not connected')
namespace = namespace or '/'
r = [pkt for pkt in self.queue[self.sid]
if pkt['namespace'] == namespace]
self.queue[self.sid] = [pkt for pkt in self.queue[self.sid]
if pkt not in r]
return r

38
libs/socketio/__init__.py Normal file
View file

@ -0,0 +1,38 @@
import sys
from .client import Client
from .base_manager import BaseManager
from .pubsub_manager import PubSubManager
from .kombu_manager import KombuManager
from .redis_manager import RedisManager
from .kafka_manager import KafkaManager
from .zmq_manager import ZmqManager
from .server import Server
from .namespace import Namespace, ClientNamespace
from .middleware import WSGIApp, Middleware
from .tornado import get_tornado_handler
if sys.version_info >= (3, 5): # pragma: no cover
from .asyncio_client import AsyncClient
from .asyncio_server import AsyncServer
from .asyncio_manager import AsyncManager
from .asyncio_namespace import AsyncNamespace, AsyncClientNamespace
from .asyncio_redis_manager import AsyncRedisManager
from .asyncio_aiopika_manager import AsyncAioPikaManager
from .asgi import ASGIApp
else: # pragma: no cover
AsyncClient = None
AsyncServer = None
AsyncManager = None
AsyncNamespace = None
AsyncRedisManager = None
AsyncAioPikaManager = None
__version__ = '4.4.0'
__all__ = ['__version__', 'Client', 'Server', 'BaseManager', 'PubSubManager',
'KombuManager', 'RedisManager', 'ZmqManager', 'KafkaManager',
'Namespace', 'ClientNamespace', 'WSGIApp', 'Middleware']
if AsyncServer is not None: # pragma: no cover
__all__ += ['AsyncClient', 'AsyncServer', 'AsyncNamespace',
'AsyncClientNamespace', 'AsyncManager', 'AsyncRedisManager',
'ASGIApp', 'get_tornado_handler', 'AsyncAioPikaManager']

36
libs/socketio/asgi.py Normal file
View file

@ -0,0 +1,36 @@
import engineio
class ASGIApp(engineio.ASGIApp): # pragma: no cover
"""ASGI application middleware for Socket.IO.
This middleware dispatches traffic to an Socket.IO application. It can
also serve a list of static files to the client, or forward unrelated
HTTP traffic to another ASGI application.
:param socketio_server: The Socket.IO server. Must be an instance of the
``socketio.AsyncServer`` class.
:param static_files: A dictionary with static file mapping rules. See the
documentation for details on this argument.
:param other_asgi_app: A separate ASGI app that receives all other traffic.
:param socketio_path: The endpoint where the Socket.IO application should
be installed. The default value is appropriate for
most cases.
Example usage::
import socketio
import uvicorn
sio = socketio.AsyncServer()
app = engineio.ASGIApp(sio, static_files={
'/': 'index.html',
'/static': './public',
})
uvicorn.run(app, host='127.0.0.1', port=5000)
"""
def __init__(self, socketio_server, other_asgi_app=None,
static_files=None, socketio_path='socket.io'):
super().__init__(socketio_server, other_asgi_app,
static_files=static_files,
engineio_path=socketio_path)

View file

@ -0,0 +1,105 @@
import asyncio
import pickle
from socketio.asyncio_pubsub_manager import AsyncPubSubManager
try:
import aio_pika
except ImportError:
aio_pika = None
class AsyncAioPikaManager(AsyncPubSubManager): # pragma: no cover
"""Client manager that uses aio_pika for inter-process messaging under
asyncio.
This class implements a client manager backend for event sharing across
multiple processes, using RabbitMQ
To use a aio_pika backend, initialize the :class:`Server` instance as
follows::
url = 'amqp://user:password@hostname:port//'
server = socketio.Server(client_manager=socketio.AsyncAioPikaManager(
url))
:param url: The connection URL for the backend messaging queue. Example
connection URLs are ``'amqp://guest:guest@localhost:5672//'``
for RabbitMQ.
:param channel: The channel name on which the server sends and receives
notifications. Must be the same in all the servers.
With this manager, the channel name is the exchange name
in rabbitmq
:param write_only: If set ot ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
"""
name = 'asyncaiopika'
def __init__(self, url='amqp://guest:guest@localhost:5672//',
channel='socketio', write_only=False, logger=None):
if aio_pika is None:
raise RuntimeError('aio_pika package is not installed '
'(Run "pip install aio_pika" in your '
'virtualenv).')
self.url = url
self.listener_connection = None
self.listener_channel = None
self.listener_queue = None
super().__init__(channel=channel, write_only=write_only, logger=logger)
async def _connection(self):
return await aio_pika.connect_robust(self.url)
async def _channel(self, connection):
return await connection.channel()
async def _exchange(self, channel):
return await channel.declare_exchange(self.channel,
aio_pika.ExchangeType.FANOUT)
async def _queue(self, channel, exchange):
queue = await channel.declare_queue(durable=False,
arguments={'x-expires': 300000})
await queue.bind(exchange)
return queue
async def _publish(self, data):
connection = await self._connection()
channel = await self._channel(connection)
exchange = await self._exchange(channel)
await exchange.publish(
aio_pika.Message(body=pickle.dumps(data),
delivery_mode=aio_pika.DeliveryMode.PERSISTENT),
routing_key='*'
)
async def _listen(self):
retry_sleep = 1
while True:
try:
if self.listener_connection is None:
self.listener_connection = await self._connection()
self.listener_channel = await self._channel(
self.listener_connection
)
await self.listener_channel.set_qos(prefetch_count=1)
exchange = await self._exchange(self.listener_channel)
self.listener_queue = await self._queue(
self.listener_channel, exchange
)
async with self.listener_queue.iterator() as queue_iter:
async for message in queue_iter:
with message.process():
return pickle.loads(message.body)
except Exception:
self._get_logger().error('Cannot receive from rabbitmq... '
'retrying in '
'{} secs'.format(retry_sleep))
self.listener_connection = None
await asyncio.sleep(retry_sleep)
retry_sleep *= 2
if retry_sleep > 60:
retry_sleep = 60

View file

@ -0,0 +1,475 @@
import asyncio
import logging
import random
import engineio
import six
from . import client
from . import exceptions
from . import packet
default_logger = logging.getLogger('socketio.client')
class AsyncClient(client.Client):
"""A Socket.IO client for asyncio.
This class implements a fully compliant Socket.IO web client with support
for websocket and long-polling transports.
:param reconnection: ``True`` if the client should automatically attempt to
reconnect to the server after an interruption, or
``False`` to not reconnect. The default is ``True``.
:param reconnection_attempts: How many reconnection attempts to issue
before giving up, or 0 for infinity attempts.
The default is 0.
:param reconnection_delay: How long to wait in seconds before the first
reconnection attempt. Each successive attempt
doubles this delay.
:param reconnection_delay_max: The maximum delay between reconnection
attempts.
:param randomization_factor: Randomization amount for each delay between
reconnection attempts. The default is 0.5,
which means that each delay is randomly
adjusted by +/- 50%.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param binary: ``True`` to support binary payloads, ``False`` to treat all
payloads as text. On Python 2, if this is set to ``True``,
``unicode`` values are treated as text, and ``str`` and
``bytes`` values are treated as binary. This option has no
effect on Python 3, where text and binary payloads are
always automatically discovered.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
The Engine.IO configuration supports the following settings:
:param request_timeout: A timeout in seconds for requests. The default is
5 seconds.
:param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
skip SSL certificate verification, allowing
connections to servers with self signed certificates.
The default is ``True``.
:param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
a logger object to use. To disable logging set to
``False``. The default is ``False``.
"""
def is_asyncio_based(self):
return True
async def connect(self, url, headers={}, transports=None,
namespaces=None, socketio_path='socket.io'):
"""Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
then an upgrade to websocket is attempted.
:param namespaces: The list of custom namespaces to connect, in
addition to the default namespace. If not given,
the namespace list is obtained from the registered
event handlers.
:param socketio_path: The endpoint where the Socket.IO server is
installed. The default value is appropriate for
most cases.
Note: this method is a coroutine.
Example usage::
sio = socketio.Client()
sio.connect('http://localhost:5000')
"""
self.connection_url = url
self.connection_headers = headers
self.connection_transports = transports
self.connection_namespaces = namespaces
self.socketio_path = socketio_path
if namespaces is None:
namespaces = set(self.handlers.keys()).union(
set(self.namespace_handlers.keys()))
elif isinstance(namespaces, six.string_types):
namespaces = [namespaces]
self.connection_namespaces = namespaces
self.namespaces = [n for n in namespaces if n != '/']
try:
await self.eio.connect(url, headers=headers,
transports=transports,
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
self.connected = True
async def wait(self):
"""Wait until the connection with the server ends.
Client applications can use this function to block the main thread
during the life of the connection.
Note: this method is a coroutine.
"""
while True:
await self.eio.wait()
await self.sleep(1) # give the reconnect task time to start up
if not self._reconnect_task:
break
await self._reconnect_task
if self.eio.state != 'connected':
break
async def emit(self, event, data=None, namespace=None, callback=None):
"""Emit a custom event to one or more connected clients.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
the the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
Note: this method is a coroutine.
"""
namespace = namespace or '/'
if namespace != '/' and namespace not in self.namespaces:
raise exceptions.BadNamespaceError(
namespace + ' is not a connected namespace.')
self.logger.info('Emitting event "%s" [%s]', event, namespace)
if callback is not None:
id = self._generate_ack_id(namespace, callback)
else:
id = None
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
# tuples are expanded to multiple arguments, everything else is sent
# as a single argument
if isinstance(data, tuple):
data = list(data)
elif data is not None:
data = [data]
else:
data = []
await self._send_packet(packet.Packet(
packet.EVENT, namespace=namespace, data=[event] + data, id=id,
binary=binary))
async def send(self, data, namespace=None, callback=None):
"""Send a message to one or more connected clients.
This function emits an event with the name ``'message'``. Use
:func:`emit` to issue custom event names.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
the the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
Note: this method is a coroutine.
"""
await self.emit('message', data=data, namespace=namespace,
callback=callback)
async def call(self, event, data=None, namespace=None, timeout=60):
"""Emit a custom event to a client and wait for the response.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param timeout: The waiting timeout. If the timeout is reached before
the client acknowledges the event, then a
``TimeoutError`` exception is raised.
Note: this method is a coroutine.
"""
callback_event = self.eio.create_event()
callback_args = []
def event_callback(*args):
callback_args.append(args)
callback_event.set()
await self.emit(event, data=data, namespace=namespace,
callback=event_callback)
try:
await asyncio.wait_for(callback_event.wait(), timeout)
except asyncio.TimeoutError:
six.raise_from(exceptions.TimeoutError(), None)
return callback_args[0] if len(callback_args[0]) > 1 \
else callback_args[0][0] if len(callback_args[0]) == 1 \
else None
async def disconnect(self):
"""Disconnect from the server.
Note: this method is a coroutine.
"""
# here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces:
await self._send_packet(packet.Packet(packet.DISCONNECT,
namespace=n))
await self._send_packet(packet.Packet(
packet.DISCONNECT, namespace='/'))
self.connected = False
await self.eio.disconnect(abort=True)
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
return self.eio.start_background_task(target, *args, **kwargs)
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
Note: this method is a coroutine.
"""
return await self.eio.sleep(seconds)
async def _send_packet(self, pkt):
"""Send a Socket.IO packet to the server."""
encoded_packet = pkt.encode()
if isinstance(encoded_packet, list):
binary = False
for ep in encoded_packet:
await self.eio.send(ep, binary=binary)
binary = True
else:
await self.eio.send(encoded_packet, binary=False)
async def _handle_connect(self, namespace):
namespace = namespace or '/'
self.logger.info('Namespace {} is connected'.format(namespace))
await self._trigger_event('connect', namespace=namespace)
if namespace == '/':
for n in self.namespaces:
await self._send_packet(packet.Packet(packet.CONNECT,
namespace=n))
elif namespace not in self.namespaces:
self.namespaces.append(namespace)
async def _handle_disconnect(self, namespace):
if not self.connected:
return
namespace = namespace or '/'
if namespace == '/':
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
self.namespaces = []
await self._trigger_event('disconnect', namespace=namespace)
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.connected = False
async def _handle_event(self, namespace, id, data):
namespace = namespace or '/'
self.logger.info('Received event "%s" [%s]', data[0], namespace)
r = await self._trigger_event(data[0], namespace, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
# tuples are expanded as multiple arguments
if r is None:
data = []
elif isinstance(r, tuple):
data = list(r)
else:
data = [r]
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
await self._send_packet(packet.Packet(
packet.ACK, namespace=namespace, id=id, data=data,
binary=binary))
async def _handle_ack(self, namespace, id, data):
namespace = namespace or '/'
self.logger.info('Received ack [%s]', namespace)
callback = None
try:
callback = self.callbacks[namespace][id]
except KeyError:
# if we get an unknown callback we just ignore it
self.logger.warning('Unknown callback received, ignoring.')
else:
del self.callbacks[namespace][id]
if callback is not None:
if asyncio.iscoroutinefunction(callback):
await callback(*data)
else:
callback(*data)
async def _handle_error(self, namespace, data):
namespace = namespace or '/'
self.logger.info('Connection to namespace {} was rejected'.format(
namespace))
if data is None:
data = tuple()
elif not isinstance(data, (tuple, list)):
data = (data,)
await self._trigger_event('connect_error', namespace, *data)
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.namespaces = []
self.connected = False
async def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
# first see if we have an explicit handler for the event
if namespace in self.handlers and event in self.handlers[namespace]:
if asyncio.iscoroutinefunction(self.handlers[namespace][event]):
try:
ret = await self.handlers[namespace][event](*args)
except asyncio.CancelledError: # pragma: no cover
ret = None
else:
ret = self.handlers[namespace][event](*args)
return ret
# or else, forward the event to a namepsace handler if one exists
elif namespace in self.namespace_handlers:
return await self.namespace_handlers[namespace].trigger_event(
event, *args)
async def _handle_reconnect(self):
self._reconnect_abort.clear()
client.reconnecting_clients.append(self)
attempt_count = 0
current_delay = self.reconnection_delay
while True:
delay = current_delay
current_delay *= 2
if delay > self.reconnection_delay_max:
delay = self.reconnection_delay_max
delay += self.randomization_factor * (2 * random.random() - 1)
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
try:
await asyncio.wait_for(self._reconnect_abort.wait(), delay)
self.logger.info('Reconnect task aborted')
break
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
attempt_count += 1
try:
await self.connect(self.connection_url,
headers=self.connection_headers,
transports=self.connection_transports,
namespaces=self.connection_namespaces,
socketio_path=self.socketio_path)
except (exceptions.ConnectionError, ValueError):
pass
else:
self.logger.info('Reconnection successful')
self._reconnect_task = None
break
if self.reconnection_attempts and \
attempt_count >= self.reconnection_attempts:
self.logger.info(
'Maximum reconnection attempts reached, giving up')
break
client.reconnecting_clients.remove(self)
def _handle_eio_connect(self):
"""Handle the Engine.IO connection event."""
self.logger.info('Engine.IO connection established')
self.sid = self.eio.sid
async def _handle_eio_message(self, data):
"""Dispatch Engine.IO messages."""
if self._binary_packet:
pkt = self._binary_packet
if pkt.add_attachment(data):
self._binary_packet = None
if pkt.packet_type == packet.BINARY_EVENT:
await self._handle_event(pkt.namespace, pkt.id, pkt.data)
else:
await self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT:
await self._handle_disconnect(pkt.namespace)
elif pkt.packet_type == packet.EVENT:
await self._handle_event(pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.ACK:
await self._handle_ack(pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK:
self._binary_packet = pkt
elif pkt.packet_type == packet.ERROR:
await self._handle_error(pkt.namespace, pkt.data)
else:
raise ValueError('Unknown packet type.')
async def _handle_eio_disconnect(self):
"""Handle the Engine.IO disconnection event."""
self.logger.info('Engine.IO connection dropped')
self._reconnect_abort.set()
if self.connected:
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
await self._trigger_event('disconnect', namespace='/')
self.namespaces = []
self.connected = False
self.callbacks = {}
self._binary_packet = None
self.sid = None
if self.eio.state == 'connected' and self.reconnection:
self._reconnect_task = self.start_background_task(
self._handle_reconnect)
def _engineio_client_class(self):
return engineio.AsyncClient

View file

@ -0,0 +1,58 @@
import asyncio
from .base_manager import BaseManager
class AsyncManager(BaseManager):
"""Manage a client list for an asyncio server."""
async def emit(self, event, data, namespace, room=None, skip_sid=None,
callback=None, **kwargs):
"""Emit a message to a single client, a room, or all the clients
connected to the namespace.
Note: this method is a coroutine.
"""
if namespace not in self.rooms or room not in self.rooms[namespace]:
return
tasks = []
if not isinstance(skip_sid, list):
skip_sid = [skip_sid]
for sid in self.get_participants(namespace, room):
if sid not in skip_sid:
if callback is not None:
id = self._generate_ack_id(sid, namespace, callback)
else:
id = None
tasks.append(self.server._emit_internal(sid, event, data,
namespace, id))
if tasks == []: # pragma: no cover
return
await asyncio.wait(tasks)
async def close_room(self, room, namespace):
"""Remove all participants from a room.
Note: this method is a coroutine.
"""
return super().close_room(room, namespace)
async def trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback.
Note: this method is a coroutine.
"""
callback = None
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
# if we get an unknown callback we just ignore it
self._get_logger().warning('Unknown callback received, ignoring.')
else:
del self.callbacks[sid][namespace][id]
if callback is not None:
ret = callback(*data)
if asyncio.iscoroutine(ret):
try:
await ret
except asyncio.CancelledError: # pragma: no cover
pass

View file

@ -0,0 +1,204 @@
import asyncio
from socketio import namespace
class AsyncNamespace(namespace.Namespace):
"""Base class for asyncio server-side class-based namespaces.
A class-based namespace is a class that contains all the event handlers
for a Socket.IO namespace. The event handlers are methods of the class
with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
``on_message``, ``on_json``, and so on. These can be regular functions or
coroutines.
:param namespace: The Socket.IO namespace to be used with all the event
handlers defined in this class. If this argument is
omitted, the default namespace is used.
"""
def is_asyncio_based(self):
return True
async def trigger_event(self, event, *args):
"""Dispatch an event to the proper handler method.
In the most common usage, this method is not overloaded by subclasses,
as it performs the routing of events to methods. However, this
method can be overriden if special dispatching rules are needed, or if
having a single method that catches all events is desired.
Note: this method is a coroutine.
"""
handler_name = 'on_' + event
if hasattr(self, handler_name):
handler = getattr(self, handler_name)
if asyncio.iscoroutinefunction(handler) is True:
try:
ret = await handler(*args)
except asyncio.CancelledError: # pragma: no cover
ret = None
else:
ret = handler(*args)
return ret
async def emit(self, event, data=None, room=None, skip_sid=None,
namespace=None, callback=None):
"""Emit a custom event to one or more connected clients.
The only difference with the :func:`socketio.Server.emit` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.emit(event, data=data, room=room,
skip_sid=skip_sid,
namespace=namespace or self.namespace,
callback=callback)
async def send(self, data, room=None, skip_sid=None, namespace=None,
callback=None):
"""Send a message to one or more connected clients.
The only difference with the :func:`socketio.Server.send` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.send(data, room=room, skip_sid=skip_sid,
namespace=namespace or self.namespace,
callback=callback)
async def close_room(self, room, namespace=None):
"""Close a room.
The only difference with the :func:`socketio.Server.close_room` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.close_room(
room, namespace=namespace or self.namespace)
async def get_session(self, sid, namespace=None):
"""Return the user session for a client.
The only difference with the :func:`socketio.Server.get_session`
method is that when the ``namespace`` argument is not given the
namespace associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.get_session(
sid, namespace=namespace or self.namespace)
async def save_session(self, sid, session, namespace=None):
"""Store the user session for a client.
The only difference with the :func:`socketio.Server.save_session`
method is that when the ``namespace`` argument is not given the
namespace associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.save_session(
sid, session, namespace=namespace or self.namespace)
def session(self, sid, namespace=None):
"""Return the user session for a client with context manager syntax.
The only difference with the :func:`socketio.Server.session` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.session(sid, namespace=namespace or self.namespace)
async def disconnect(self, sid, namespace=None):
"""Disconnect a client.
The only difference with the :func:`socketio.Server.disconnect` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.disconnect(
sid, namespace=namespace or self.namespace)
class AsyncClientNamespace(namespace.ClientNamespace):
"""Base class for asyncio client-side class-based namespaces.
A class-based namespace is a class that contains all the event handlers
for a Socket.IO namespace. The event handlers are methods of the class
with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
``on_message``, ``on_json``, and so on. These can be regular functions or
coroutines.
:param namespace: The Socket.IO namespace to be used with all the event
handlers defined in this class. If this argument is
omitted, the default namespace is used.
"""
def is_asyncio_based(self):
return True
async def trigger_event(self, event, *args):
"""Dispatch an event to the proper handler method.
In the most common usage, this method is not overloaded by subclasses,
as it performs the routing of events to methods. However, this
method can be overriden if special dispatching rules are needed, or if
having a single method that catches all events is desired.
Note: this method is a coroutine.
"""
handler_name = 'on_' + event
if hasattr(self, handler_name):
handler = getattr(self, handler_name)
if asyncio.iscoroutinefunction(handler) is True:
try:
ret = await handler(*args)
except asyncio.CancelledError: # pragma: no cover
ret = None
else:
ret = handler(*args)
return ret
async def emit(self, event, data=None, namespace=None, callback=None):
"""Emit a custom event to the server.
The only difference with the :func:`socketio.Client.emit` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.client.emit(event, data=data,
namespace=namespace or self.namespace,
callback=callback)
async def send(self, data, namespace=None, callback=None):
"""Send a message to the server.
The only difference with the :func:`socketio.Client.send` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.client.send(data,
namespace=namespace or self.namespace,
callback=callback)
async def disconnect(self):
"""Disconnect a client.
The only difference with the :func:`socketio.Client.disconnect` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.client.disconnect()

View file

@ -0,0 +1,163 @@
from functools import partial
import uuid
import json
import pickle
import six
from .asyncio_manager import AsyncManager
class AsyncPubSubManager(AsyncManager):
"""Manage a client list attached to a pub/sub backend under asyncio.
This is a base class that enables multiple servers to share the list of
clients, with the servers communicating events through a pub/sub backend.
The use of a pub/sub backend also allows any client connected to the
backend to emit events addressed to Socket.IO clients.
The actual backends must be implemented by subclasses, this class only
provides a pub/sub generic framework for asyncio applications.
:param channel: The channel name on which the server sends and receives
notifications.
"""
name = 'asyncpubsub'
def __init__(self, channel='socketio', write_only=False, logger=None):
super().__init__()
self.channel = channel
self.write_only = write_only
self.host_id = uuid.uuid4().hex
self.logger = logger
def initialize(self):
super().initialize()
if not self.write_only:
self.thread = self.server.start_background_task(self._thread)
self._get_logger().info(self.name + ' backend initialized.')
async def emit(self, event, data, namespace=None, room=None, skip_sid=None,
callback=None, **kwargs):
"""Emit a message to a single client, a room, or all the clients
connected to the namespace.
This method takes care or propagating the message to all the servers
that are connected through the message queue.
The parameters are the same as in :meth:`.Server.emit`.
Note: this method is a coroutine.
"""
if kwargs.get('ignore_queue'):
return await super().emit(
event, data, namespace=namespace, room=room, skip_sid=skip_sid,
callback=callback)
namespace = namespace or '/'
if callback is not None:
if self.server is None:
raise RuntimeError('Callbacks can only be issued from the '
'context of a server.')
if room is None:
raise ValueError('Cannot use callback without a room set.')
id = self._generate_ack_id(room, namespace, callback)
callback = (room, namespace, id)
else:
callback = None
await self._publish({'method': 'emit', 'event': event, 'data': data,
'namespace': namespace, 'room': room,
'skip_sid': skip_sid, 'callback': callback,
'host_id': self.host_id})
async def close_room(self, room, namespace=None):
await self._publish({'method': 'close_room', 'room': room,
'namespace': namespace or '/'})
async def _publish(self, data):
"""Publish a message on the Socket.IO channel.
This method needs to be implemented by the different subclasses that
support pub/sub backends.
"""
raise NotImplementedError('This method must be implemented in a '
'subclass.') # pragma: no cover
async def _listen(self):
"""Return the next message published on the Socket.IO channel,
blocking until a message is available.
This method needs to be implemented by the different subclasses that
support pub/sub backends.
"""
raise NotImplementedError('This method must be implemented in a '
'subclass.') # pragma: no cover
async def _handle_emit(self, message):
# Events with callbacks are very tricky to handle across hosts
# Here in the receiving end we set up a local callback that preserves
# the callback host and id from the sender
remote_callback = message.get('callback')
remote_host_id = message.get('host_id')
if remote_callback is not None and len(remote_callback) == 3:
callback = partial(self._return_callback, remote_host_id,
*remote_callback)
else:
callback = None
await super().emit(message['event'], message['data'],
namespace=message.get('namespace'),
room=message.get('room'),
skip_sid=message.get('skip_sid'),
callback=callback)
async def _handle_callback(self, message):
if self.host_id == message.get('host_id'):
try:
sid = message['sid']
namespace = message['namespace']
id = message['id']
args = message['args']
except KeyError:
return
await self.trigger_callback(sid, namespace, id, args)
async def _return_callback(self, host_id, sid, namespace, callback_id,
*args):
# When an event callback is received, the callback is returned back
# the sender, which is identified by the host_id
await self._publish({'method': 'callback', 'host_id': host_id,
'sid': sid, 'namespace': namespace,
'id': callback_id, 'args': args})
async def _handle_close_room(self, message):
await super().close_room(
room=message.get('room'), namespace=message.get('namespace'))
async def _thread(self):
while True:
try:
message = await self._listen()
except:
import traceback
traceback.print_exc()
break
data = None
if isinstance(message, dict):
data = message
else:
if isinstance(message, six.binary_type): # pragma: no cover
try:
data = pickle.loads(message)
except:
pass
if data is None:
try:
data = json.loads(message)
except:
pass
if data and 'method' in data:
if data['method'] == 'emit':
await self._handle_emit(data)
elif data['method'] == 'callback':
await self._handle_callback(data)
elif data['method'] == 'close_room':
await self._handle_close_room(data)

View file

@ -0,0 +1,107 @@
import asyncio
import pickle
from urllib.parse import urlparse
try:
import aioredis
except ImportError:
aioredis = None
from .asyncio_pubsub_manager import AsyncPubSubManager
def _parse_redis_url(url):
p = urlparse(url)
if p.scheme not in {'redis', 'rediss'}:
raise ValueError('Invalid redis url')
ssl = p.scheme == 'rediss'
host = p.hostname or 'localhost'
port = p.port or 6379
password = p.password
if p.path:
db = int(p.path[1:])
else:
db = 0
return host, port, password, db, ssl
class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover
"""Redis based client manager for asyncio servers.
This class implements a Redis backend for event sharing across multiple
processes. Only kept here as one more example of how to build a custom
backend, since the kombu backend is perfectly adequate to support a Redis
message queue.
To use a Redis backend, initialize the :class:`Server` instance as
follows::
server = socketio.Server(client_manager=socketio.AsyncRedisManager(
'redis://hostname:port/0'))
:param url: The connection URL for the Redis server. For a default Redis
store running on the same host, use ``redis://``. To use an
SSL connection, use ``rediss://``.
:param channel: The channel name on which the server sends and receives
notifications. Must be the same in all the servers.
:param write_only: If set ot ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
"""
name = 'aioredis'
def __init__(self, url='redis://localhost:6379/0', channel='socketio',
write_only=False, logger=None):
if aioredis is None:
raise RuntimeError('Redis package is not installed '
'(Run "pip install aioredis" in your '
'virtualenv).')
(
self.host, self.port, self.password, self.db, self.ssl
) = _parse_redis_url(url)
self.pub = None
self.sub = None
super().__init__(channel=channel, write_only=write_only, logger=logger)
async def _publish(self, data):
retry = True
while True:
try:
if self.pub is None:
self.pub = await aioredis.create_redis(
(self.host, self.port), db=self.db,
password=self.password, ssl=self.ssl
)
return await self.pub.publish(self.channel,
pickle.dumps(data))
except (aioredis.RedisError, OSError):
if retry:
self._get_logger().error('Cannot publish to redis... '
'retrying')
self.pub = None
retry = False
else:
self._get_logger().error('Cannot publish to redis... '
'giving up')
break
async def _listen(self):
retry_sleep = 1
while True:
try:
if self.sub is None:
self.sub = await aioredis.create_redis(
(self.host, self.port), db=self.db,
password=self.password, ssl=self.ssl
)
self.ch = (await self.sub.subscribe(self.channel))[0]
return await self.ch.get()
except (aioredis.RedisError, OSError):
self._get_logger().error('Cannot receive from redis... '
'retrying in '
'{} secs'.format(retry_sleep))
self.sub = None
await asyncio.sleep(retry_sleep)
retry_sleep *= 2
if retry_sleep > 60:
retry_sleep = 60

View file

@ -0,0 +1,526 @@
import asyncio
import engineio
import six
from . import asyncio_manager
from . import exceptions
from . import packet
from . import server
class AsyncServer(server.Server):
"""A Socket.IO server for asyncio.
This class implements a fully compliant Socket.IO web server with support
for websocket and long-polling transports, compatible with the asyncio
framework on Python 3.5 or newer.
:param client_manager: The client manager instance that will manage the
client list. When this is omitted, the client list
is stored in an in-memory structure, so the use of
multiple connected servers is not possible.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, event handlers are executed in
separate threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings:
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "aiohttp". If
this argument is not given, an async mode is chosen
based on the installed packages.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting.
:param ping_interval: The interval in seconds at which the client pings
the server.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport.
:param allow_upgrades: Whether to allow transport upgrades or not.
:param http_compression: Whether to compress packages when using the
polling transport.
:param compression_threshold: Only compress messages when their byte size
is greater than this value.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same origin
is allowed by default. Set this argument to
``'*'`` to allow all origins, or to ``[]`` to
disable CORS handling.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server.
:param monitor_clients: If set to ``True``, a background task will ensure
inactive clients are closed. Set to ``False`` to
disable the monitoring task (not recommended). The
default is ``True``.
:param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
a logger object to use. To disable logging set to
``False``.
"""
def __init__(self, client_manager=None, logger=False, json=None,
async_handlers=True, **kwargs):
if client_manager is None:
client_manager = asyncio_manager.AsyncManager()
super().__init__(client_manager=client_manager, logger=logger,
binary=False, json=json,
async_handlers=async_handlers, **kwargs)
def is_asyncio_based(self):
return True
def attach(self, app, socketio_path='socket.io'):
"""Attach the Socket.IO server to an application."""
self.eio.attach(app, socketio_path)
async def emit(self, event, data=None, to=None, room=None, skip_sid=None,
namespace=None, callback=None, **kwargs):
"""Emit a custom event to one or more connected clients.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The recipient of the message. This can be set to the
session ID of a client to address only that client, or to
to any custom room created by the application to address all
the clients in that room, If this argument is omitted the
event is broadcasted to all connected clients.
:param room: Alias for the ``to`` parameter.
:param skip_sid: The session ID of a client to skip when broadcasting
to a room or to all clients. This can be used to
prevent a message from being sent to the sender.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
the the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
clients directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
Note: this method is a coroutine.
"""
namespace = namespace or '/'
room = to or room
self.logger.info('emitting event "%s" to %s [%s]', event,
room or 'all', namespace)
await self.manager.emit(event, data, namespace, room=room,
skip_sid=skip_sid, callback=callback,
**kwargs)
async def send(self, data, to=None, room=None, skip_sid=None,
namespace=None, callback=None, **kwargs):
"""Send a message to one or more connected clients.
This function emits an event with the name ``'message'``. Use
:func:`emit` to issue custom event names.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The recipient of the message. This can be set to the
session ID of a client to address only that client, or to
to any custom room created by the application to address all
the clients in that room, If this argument is omitted the
event is broadcasted to all connected clients.
:param room: Alias for the ``to`` parameter.
:param skip_sid: The session ID of a client to skip when broadcasting
to a room or to all clients. This can be used to
prevent a message from being sent to the sender.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
the the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
clients directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
Note: this method is a coroutine.
"""
await self.emit('message', data=data, to=to, room=room,
skip_sid=skip_sid, namespace=namespace,
callback=callback, **kwargs)
async def call(self, event, data=None, to=None, sid=None, namespace=None,
timeout=60, **kwargs):
"""Emit a custom event to a client and wait for the response.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The session ID of the recipient client.
:param sid: Alias for the ``to`` parameter.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param timeout: The waiting timeout. If the timeout is reached before
the client acknowledges the event, then a
``TimeoutError`` exception is raised.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
client directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
"""
if not self.async_handlers:
raise RuntimeError(
'Cannot use call() when async_handlers is False.')
callback_event = self.eio.create_event()
callback_args = []
def event_callback(*args):
callback_args.append(args)
callback_event.set()
await self.emit(event, data=data, room=to or sid, namespace=namespace,
callback=event_callback, **kwargs)
try:
await asyncio.wait_for(callback_event.wait(), timeout)
except asyncio.TimeoutError:
six.raise_from(exceptions.TimeoutError(), None)
return callback_args[0] if len(callback_args[0]) > 1 \
else callback_args[0][0] if len(callback_args[0]) == 1 \
else None
async def close_room(self, room, namespace=None):
"""Close a room.
This function removes all the clients from the given room.
:param room: Room name.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
Note: this method is a coroutine.
"""
namespace = namespace or '/'
self.logger.info('room %s is closing [%s]', room, namespace)
await self.manager.close_room(room, namespace)
async def get_session(self, sid, namespace=None):
"""Return the user session for a client.
:param sid: The session id of the client.
:param namespace: The Socket.IO namespace. If this argument is omitted
the default namespace is used.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved. If you want to modify
the user session, use the ``session`` context manager instead.
"""
namespace = namespace or '/'
eio_session = await self.eio.get_session(sid)
return eio_session.setdefault(namespace, {})
async def save_session(self, sid, session, namespace=None):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
:param namespace: The Socket.IO namespace. If this argument is omitted
the default namespace is used.
"""
namespace = namespace or '/'
eio_session = await self.eio.get_session(sid)
eio_session[namespace] = session
def session(self, sid, namespace=None):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
@eio.on('connect')
def on_connect(sid, environ):
username = authenticate_user(environ)
if not username:
return False
with eio.session(sid) as session:
session['username'] = username
@eio.on('message')
def on_message(sid, msg):
async with eio.session(sid) as session:
print('received message from ', session['username'])
"""
class _session_context_manager(object):
def __init__(self, server, sid, namespace):
self.server = server
self.sid = sid
self.namespace = namespace
self.session = None
async def __aenter__(self):
self.session = await self.server.get_session(
sid, namespace=self.namespace)
return self.session
async def __aexit__(self, *args):
await self.server.save_session(sid, self.session,
namespace=self.namespace)
return _session_context_manager(self, sid, namespace)
async def disconnect(self, sid, namespace=None):
"""Disconnect a client.
:param sid: Session ID of the client.
:param namespace: The Socket.IO namespace to disconnect. If this
argument is omitted the default namespace is used.
Note: this method is a coroutine.
"""
namespace = namespace or '/'
if self.manager.is_connected(sid, namespace=namespace):
self.logger.info('Disconnecting %s [%s]', sid, namespace)
self.manager.pre_disconnect(sid, namespace=namespace)
await self._send_packet(sid, packet.Packet(packet.DISCONNECT,
namespace=namespace))
await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
if namespace == '/':
await self.eio.disconnect(sid)
async def handle_request(self, *args, **kwargs):
"""Handle an HTTP request from the client.
This is the entry point of the Socket.IO application. This function
returns the HTTP response body to deliver to the client.
Note: this method is a coroutine.
"""
return await self.eio.handle_request(*args, **kwargs)
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute. Must be a coroutine.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
The return value is a ``asyncio.Task`` object.
Note: this method is a coroutine.
"""
return self.eio.start_background_task(target, *args, **kwargs)
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
Note: this method is a coroutine.
"""
return await self.eio.sleep(seconds)
async def _emit_internal(self, sid, event, data, namespace=None, id=None):
"""Send a message to a client."""
# tuples are expanded to multiple arguments, everything else is sent
# as a single argument
if isinstance(data, tuple):
data = list(data)
else:
data = [data]
await self._send_packet(sid, packet.Packet(
packet.EVENT, namespace=namespace, data=[event] + data, id=id,
binary=None))
async def _send_packet(self, sid, pkt):
"""Send a Socket.IO packet to a client."""
encoded_packet = pkt.encode()
if isinstance(encoded_packet, list):
binary = False
for ep in encoded_packet:
await self.eio.send(sid, ep, binary=binary)
binary = True
else:
await self.eio.send(sid, encoded_packet, binary=False)
async def _handle_connect(self, sid, namespace):
"""Handle a client connection request."""
namespace = namespace or '/'
self.manager.connect(sid, namespace)
if self.always_connect:
await self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
fail_reason = None
try:
success = await self._trigger_event('connect', namespace, sid,
self.environ[sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
await self._send_packet(sid, packet.Packet(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
self.manager.disconnect(sid, namespace)
if not self.always_connect:
await self._send_packet(sid, packet.Packet(
packet.ERROR, data=fail_reason, namespace=namespace))
if sid in self.environ: # pragma: no cover
del self.environ[sid]
elif not self.always_connect:
await self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
async def _handle_disconnect(self, sid, namespace):
"""Handle a client disconnect."""
namespace = namespace or '/'
if namespace == '/':
namespace_list = list(self.manager.get_namespaces())
else:
namespace_list = [namespace]
for n in namespace_list:
if n != '/' and self.manager.is_connected(sid, n):
await self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n)
if namespace == '/' and self.manager.is_connected(sid, namespace):
await self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/')
async def _handle_event(self, sid, namespace, id, data):
"""Handle an incoming client event."""
namespace = namespace or '/'
self.logger.info('received event "%s" from %s [%s]', data[0], sid,
namespace)
if not self.manager.is_connected(sid, namespace):
self.logger.warning('%s is not connected to namespace %s',
sid, namespace)
return
if self.async_handlers:
self.start_background_task(self._handle_event_internal, self, sid,
data, namespace, id)
else:
await self._handle_event_internal(self, sid, data, namespace, id)
async def _handle_event_internal(self, server, sid, data, namespace, id):
r = await server._trigger_event(data[0], namespace, sid, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
# tuples are expanded as multiple arguments
if r is None:
data = []
elif isinstance(r, tuple):
data = list(r)
else:
data = [r]
await server._send_packet(sid, packet.Packet(packet.ACK,
namespace=namespace,
id=id, data=data,
binary=None))
async def _handle_ack(self, sid, namespace, id, data):
"""Handle ACK packets from the client."""
namespace = namespace or '/'
self.logger.info('received ack from %s [%s]', sid, namespace)
await self.manager.trigger_callback(sid, namespace, id, data)
async def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
# first see if we have an explicit handler for the event
if namespace in self.handlers and event in self.handlers[namespace]:
if asyncio.iscoroutinefunction(self.handlers[namespace][event]) \
is True:
try:
ret = await self.handlers[namespace][event](*args)
except asyncio.CancelledError: # pragma: no cover
ret = None
else:
ret = self.handlers[namespace][event](*args)
return ret
# or else, forward the event to a namepsace handler if one exists
elif namespace in self.namespace_handlers:
return await self.namespace_handlers[namespace].trigger_event(
event, *args)
async def _handle_eio_connect(self, sid, environ):
"""Handle the Engine.IO connection event."""
if not self.manager_initialized:
self.manager_initialized = True
self.manager.initialize()
self.environ[sid] = environ
return await self._handle_connect(sid, '/')
async def _handle_eio_message(self, sid, data):
"""Dispatch Engine.IO messages."""
if sid in self._binary_packet:
pkt = self._binary_packet[sid]
if pkt.add_attachment(data):
del self._binary_packet[sid]
if pkt.packet_type == packet.BINARY_EVENT:
await self._handle_event(sid, pkt.namespace, pkt.id,
pkt.data)
else:
await self._handle_ack(sid, pkt.namespace, pkt.id,
pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(sid, pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT:
await self._handle_disconnect(sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:
await self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.ACK:
await self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK:
self._binary_packet[sid] = pkt
elif pkt.packet_type == packet.ERROR:
raise ValueError('Unexpected ERROR packet.')
else:
raise ValueError('Unknown packet type.')
async def _handle_eio_disconnect(self, sid):
"""Handle Engine.IO disconnect event."""
await self._handle_disconnect(sid, '/')
if sid in self.environ:
del self.environ[sid]
def _engineio_server_class(self):
return engineio.AsyncServer

View file

@ -0,0 +1,178 @@
import itertools
import logging
import six
default_logger = logging.getLogger('socketio')
class BaseManager(object):
"""Manage client connections.
This class keeps track of all the clients and the rooms they are in, to
support the broadcasting of messages. The data used by this class is
stored in a memory structure, making it appropriate only for single process
services. More sophisticated storage backends can be implemented by
subclasses.
"""
def __init__(self):
self.logger = None
self.server = None
self.rooms = {}
self.callbacks = {}
self.pending_disconnect = {}
def set_server(self, server):
self.server = server
def initialize(self):
"""Invoked before the first request is received. Subclasses can add
their initialization code here.
"""
pass
def get_namespaces(self):
"""Return an iterable with the active namespace names."""
return six.iterkeys(self.rooms)
def get_participants(self, namespace, room):
"""Return an iterable with the active participants in a room."""
for sid, active in six.iteritems(self.rooms[namespace][room].copy()):
yield sid
def connect(self, sid, namespace):
"""Register a client connection to a namespace."""
self.enter_room(sid, namespace, None)
self.enter_room(sid, namespace, sid)
def is_connected(self, sid, namespace):
if namespace in self.pending_disconnect and \
sid in self.pending_disconnect[namespace]:
# the client is in the process of being disconnected
return False
try:
return self.rooms[namespace][None][sid]
except KeyError:
pass
def pre_disconnect(self, sid, namespace):
"""Put the client in the to-be-disconnected list.
This allows the client data structures to be present while the
disconnect handler is invoked, but still recognize the fact that the
client is soon going away.
"""
if namespace not in self.pending_disconnect:
self.pending_disconnect[namespace] = []
self.pending_disconnect[namespace].append(sid)
def disconnect(self, sid, namespace):
"""Register a client disconnect from a namespace."""
if namespace not in self.rooms:
return
rooms = []
for room_name, room in six.iteritems(self.rooms[namespace].copy()):
if sid in room:
rooms.append(room_name)
for room in rooms:
self.leave_room(sid, namespace, room)
if sid in self.callbacks and namespace in self.callbacks[sid]:
del self.callbacks[sid][namespace]
if len(self.callbacks[sid]) == 0:
del self.callbacks[sid]
if namespace in self.pending_disconnect and \
sid in self.pending_disconnect[namespace]:
self.pending_disconnect[namespace].remove(sid)
if len(self.pending_disconnect[namespace]) == 0:
del self.pending_disconnect[namespace]
def enter_room(self, sid, namespace, room):
"""Add a client to a room."""
if namespace not in self.rooms:
self.rooms[namespace] = {}
if room not in self.rooms[namespace]:
self.rooms[namespace][room] = {}
self.rooms[namespace][room][sid] = True
def leave_room(self, sid, namespace, room):
"""Remove a client from a room."""
try:
del self.rooms[namespace][room][sid]
if len(self.rooms[namespace][room]) == 0:
del self.rooms[namespace][room]
if len(self.rooms[namespace]) == 0:
del self.rooms[namespace]
except KeyError:
pass
def close_room(self, room, namespace):
"""Remove all participants from a room."""
try:
for sid in self.get_participants(namespace, room):
self.leave_room(sid, namespace, room)
except KeyError:
pass
def get_rooms(self, sid, namespace):
"""Return the rooms a client is in."""
r = []
try:
for room_name, room in six.iteritems(self.rooms[namespace]):
if room_name is not None and sid in room and room[sid]:
r.append(room_name)
except KeyError:
pass
return r
def emit(self, event, data, namespace, room=None, skip_sid=None,
callback=None, **kwargs):
"""Emit a message to a single client, a room, or all the clients
connected to the namespace."""
if namespace not in self.rooms or room not in self.rooms[namespace]:
return
if not isinstance(skip_sid, list):
skip_sid = [skip_sid]
for sid in self.get_participants(namespace, room):
if sid not in skip_sid:
if callback is not None:
id = self._generate_ack_id(sid, namespace, callback)
else:
id = None
self.server._emit_internal(sid, event, data, namespace, id)
def trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback."""
callback = None
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
# if we get an unknown callback we just ignore it
self._get_logger().warning('Unknown callback received, ignoring.')
else:
del self.callbacks[sid][namespace][id]
if callback is not None:
callback(*data)
def _generate_ack_id(self, sid, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
namespace = namespace or '/'
if sid not in self.callbacks:
self.callbacks[sid] = {}
if namespace not in self.callbacks[sid]:
self.callbacks[sid][namespace] = {0: itertools.count(1)}
id = six.next(self.callbacks[sid][namespace][0])
self.callbacks[sid][namespace][id] = callback
return id
def _get_logger(self):
"""Get the appropriate logger
Prevents uninitialized servers in write-only mode from failing.
"""
if self.logger:
return self.logger
elif self.server:
return self.server.logger
else:
return default_logger

620
libs/socketio/client.py Normal file
View file

@ -0,0 +1,620 @@
import itertools
import logging
import random
import signal
import engineio
import six
from . import exceptions
from . import namespace
from . import packet
default_logger = logging.getLogger('socketio.client')
reconnecting_clients = []
def signal_handler(sig, frame): # pragma: no cover
"""SIGINT handler.
Notify any clients that are in a reconnect loop to abort. Other
disconnection tasks are handled at the engine.io level.
"""
for client in reconnecting_clients[:]:
client._reconnect_abort.set()
return original_signal_handler(sig, frame)
original_signal_handler = signal.signal(signal.SIGINT, signal_handler)
class Client(object):
"""A Socket.IO client.
This class implements a fully compliant Socket.IO web client with support
for websocket and long-polling transports.
:param reconnection: ``True`` if the client should automatically attempt to
reconnect to the server after an interruption, or
``False`` to not reconnect. The default is ``True``.
:param reconnection_attempts: How many reconnection attempts to issue
before giving up, or 0 for infinity attempts.
The default is 0.
:param reconnection_delay: How long to wait in seconds before the first
reconnection attempt. Each successive attempt
doubles this delay.
:param reconnection_delay_max: The maximum delay between reconnection
attempts.
:param randomization_factor: Randomization amount for each delay between
reconnection attempts. The default is 0.5,
which means that each delay is randomly
adjusted by +/- 50%.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param binary: ``True`` to support binary payloads, ``False`` to treat all
payloads as text. On Python 2, if this is set to ``True``,
``unicode`` values are treated as text, and ``str`` and
``bytes`` values are treated as binary. This option has no
effect on Python 3, where text and binary payloads are
always automatically discovered.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
The Engine.IO configuration supports the following settings:
:param request_timeout: A timeout in seconds for requests. The default is
5 seconds.
:param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
skip SSL certificate verification, allowing
connections to servers with self signed certificates.
The default is ``True``.
:param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
a logger object to use. To disable logging set to
``False``. The default is ``False``.
"""
def __init__(self, reconnection=True, reconnection_attempts=0,
reconnection_delay=1, reconnection_delay_max=5,
randomization_factor=0.5, logger=False, binary=False,
json=None, **kwargs):
self.reconnection = reconnection
self.reconnection_attempts = reconnection_attempts
self.reconnection_delay = reconnection_delay
self.reconnection_delay_max = reconnection_delay_max
self.randomization_factor = randomization_factor
self.binary = binary
engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None:
engineio_options['logger'] = engineio_logger
if json is not None:
packet.Packet.json = json
engineio_options['json'] = json
self.eio = self._engineio_client_class()(**engineio_options)
self.eio.on('connect', self._handle_eio_connect)
self.eio.on('message', self._handle_eio_message)
self.eio.on('disconnect', self._handle_eio_disconnect)
if not isinstance(logger, bool):
self.logger = logger
else:
self.logger = default_logger
if not logging.root.handlers and \
self.logger.level == logging.NOTSET:
if logger:
self.logger.setLevel(logging.INFO)
else:
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())
self.connection_url = None
self.connection_headers = None
self.connection_transports = None
self.connection_namespaces = None
self.socketio_path = None
self.sid = None
self.connected = False
self.namespaces = []
self.handlers = {}
self.namespace_handlers = {}
self.callbacks = {}
self._binary_packet = None
self._reconnect_task = None
self._reconnect_abort = self.eio.create_event()
def is_asyncio_based(self):
return False
def on(self, event, handler=None, namespace=None):
"""Register an event handler.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param handler: The function that should be invoked to handle the
event. When this parameter is not given, the method
acts as a decorator for the handler function.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the handler is associated with
the default namespace.
Example usage::
# as a decorator:
@sio.on('connect')
def connect_handler():
print('Connected!')
# as a method:
def message_handler(msg):
print('Received message: ', msg)
sio.send( 'response')
sio.on('message', message_handler)
The ``'connect'`` event handler receives no arguments. The
``'message'`` handler and handlers for custom event names receive the
message payload as only argument. Any values returned from a message
handler will be passed to the client's acknowledgement callback
function if it exists. The ``'disconnect'`` handler does not take
arguments.
"""
namespace = namespace or '/'
def set_handler(handler):
if namespace not in self.handlers:
self.handlers[namespace] = {}
self.handlers[namespace][event] = handler
return handler
if handler is None:
return set_handler
set_handler(handler)
def event(self, *args, **kwargs):
"""Decorator to register an event handler.
This is a simplified version of the ``on()`` method that takes the
event name from the decorated function.
Example usage::
@sio.event
def my_event(data):
print('Received data: ', data)
The above example is equivalent to::
@sio.on('my_event')
def my_event(data):
print('Received data: ', data)
A custom namespace can be given as an argument to the decorator::
@sio.event(namespace='/test')
def my_event(data):
print('Received data: ', data)
"""
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# the decorator was invoked without arguments
# args[0] is the decorated function
return self.on(args[0].__name__)(args[0])
else:
# the decorator was invoked with arguments
def set_handler(handler):
return self.on(handler.__name__, *args, **kwargs)(handler)
return set_handler
def register_namespace(self, namespace_handler):
"""Register a namespace handler object.
:param namespace_handler: An instance of a :class:`Namespace`
subclass that handles all the event traffic
for a namespace.
"""
if not isinstance(namespace_handler, namespace.ClientNamespace):
raise ValueError('Not a namespace instance')
if self.is_asyncio_based() != namespace_handler.is_asyncio_based():
raise ValueError('Not a valid namespace class for this client')
namespace_handler._set_client(self)
self.namespace_handlers[namespace_handler.namespace] = \
namespace_handler
def connect(self, url, headers={}, transports=None,
namespaces=None, socketio_path='socket.io'):
"""Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
then an upgrade to websocket is attempted.
:param namespaces: The list of custom namespaces to connect, in
addition to the default namespace. If not given,
the namespace list is obtained from the registered
event handlers.
:param socketio_path: The endpoint where the Socket.IO server is
installed. The default value is appropriate for
most cases.
Example usage::
sio = socketio.Client()
sio.connect('http://localhost:5000')
"""
self.connection_url = url
self.connection_headers = headers
self.connection_transports = transports
self.connection_namespaces = namespaces
self.socketio_path = socketio_path
if namespaces is None:
namespaces = set(self.handlers.keys()).union(
set(self.namespace_handlers.keys()))
elif isinstance(namespaces, six.string_types):
namespaces = [namespaces]
self.connection_namespaces = namespaces
self.namespaces = [n for n in namespaces if n != '/']
try:
self.eio.connect(url, headers=headers, transports=transports,
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
self.connected = True
def wait(self):
"""Wait until the connection with the server ends.
Client applications can use this function to block the main thread
during the life of the connection.
"""
while True:
self.eio.wait()
self.sleep(1) # give the reconnect task time to start up
if not self._reconnect_task:
break
self._reconnect_task.join()
if self.eio.state != 'connected':
break
def emit(self, event, data=None, namespace=None, callback=None):
"""Emit a custom event to one or more connected clients.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
the the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
"""
namespace = namespace or '/'
if namespace != '/' and namespace not in self.namespaces:
raise exceptions.BadNamespaceError(
namespace + ' is not a connected namespace.')
self.logger.info('Emitting event "%s" [%s]', event, namespace)
if callback is not None:
id = self._generate_ack_id(namespace, callback)
else:
id = None
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
# tuples are expanded to multiple arguments, everything else is sent
# as a single argument
if isinstance(data, tuple):
data = list(data)
elif data is not None:
data = [data]
else:
data = []
self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
data=[event] + data, id=id,
binary=binary))
def send(self, data, namespace=None, callback=None):
"""Send a message to one or more connected clients.
This function emits an event with the name ``'message'``. Use
:func:`emit` to issue custom event names.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
the the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
"""
self.emit('message', data=data, namespace=namespace,
callback=callback)
def call(self, event, data=None, namespace=None, timeout=60):
"""Emit a custom event to a client and wait for the response.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param timeout: The waiting timeout. If the timeout is reached before
the client acknowledges the event, then a
``TimeoutError`` exception is raised.
"""
callback_event = self.eio.create_event()
callback_args = []
def event_callback(*args):
callback_args.append(args)
callback_event.set()
self.emit(event, data=data, namespace=namespace,
callback=event_callback)
if not callback_event.wait(timeout=timeout):
raise exceptions.TimeoutError()
return callback_args[0] if len(callback_args[0]) > 1 \
else callback_args[0][0] if len(callback_args[0]) == 1 \
else None
def disconnect(self):
"""Disconnect from the server."""
# here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces:
self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n))
self._send_packet(packet.Packet(
packet.DISCONNECT, namespace='/'))
self.connected = False
self.eio.disconnect(abort=True)
def transport(self):
"""Return the name of the transport used by the client.
The two possible values returned by this function are ``'polling'``
and ``'websocket'``.
"""
return self.eio.transport()
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
return self.eio.start_background_task(target, *args, **kwargs)
def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
"""
return self.eio.sleep(seconds)
def _send_packet(self, pkt):
"""Send a Socket.IO packet to the server."""
encoded_packet = pkt.encode()
if isinstance(encoded_packet, list):
binary = False
for ep in encoded_packet:
self.eio.send(ep, binary=binary)
binary = True
else:
self.eio.send(encoded_packet, binary=False)
def _generate_ack_id(self, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
namespace = namespace or '/'
if namespace not in self.callbacks:
self.callbacks[namespace] = {0: itertools.count(1)}
id = six.next(self.callbacks[namespace][0])
self.callbacks[namespace][id] = callback
return id
def _handle_connect(self, namespace):
namespace = namespace or '/'
self.logger.info('Namespace {} is connected'.format(namespace))
self._trigger_event('connect', namespace=namespace)
if namespace == '/':
for n in self.namespaces:
self._send_packet(packet.Packet(packet.CONNECT, namespace=n))
elif namespace not in self.namespaces:
self.namespaces.append(namespace)
def _handle_disconnect(self, namespace):
if not self.connected:
return
namespace = namespace or '/'
if namespace == '/':
for n in self.namespaces:
self._trigger_event('disconnect', namespace=n)
self.namespaces = []
self._trigger_event('disconnect', namespace=namespace)
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.connected = False
def _handle_event(self, namespace, id, data):
namespace = namespace or '/'
self.logger.info('Received event "%s" [%s]', data[0], namespace)
r = self._trigger_event(data[0], namespace, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
# tuples are expanded as multiple arguments
if r is None:
data = []
elif isinstance(r, tuple):
data = list(r)
else:
data = [r]
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
self._send_packet(packet.Packet(packet.ACK, namespace=namespace,
id=id, data=data, binary=binary))
def _handle_ack(self, namespace, id, data):
namespace = namespace or '/'
self.logger.info('Received ack [%s]', namespace)
callback = None
try:
callback = self.callbacks[namespace][id]
except KeyError:
# if we get an unknown callback we just ignore it
self.logger.warning('Unknown callback received, ignoring.')
else:
del self.callbacks[namespace][id]
if callback is not None:
callback(*data)
def _handle_error(self, namespace, data):
namespace = namespace or '/'
self.logger.info('Connection to namespace {} was rejected'.format(
namespace))
if data is None:
data = tuple()
elif not isinstance(data, (tuple, list)):
data = (data,)
self._trigger_event('connect_error', namespace, *data)
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.namespaces = []
self.connected = False
def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
# first see if we have an explicit handler for the event
if namespace in self.handlers and event in self.handlers[namespace]:
return self.handlers[namespace][event](*args)
# or else, forward the event to a namespace handler if one exists
elif namespace in self.namespace_handlers:
return self.namespace_handlers[namespace].trigger_event(
event, *args)
def _handle_reconnect(self):
self._reconnect_abort.clear()
reconnecting_clients.append(self)
attempt_count = 0
current_delay = self.reconnection_delay
while True:
delay = current_delay
current_delay *= 2
if delay > self.reconnection_delay_max:
delay = self.reconnection_delay_max
delay += self.randomization_factor * (2 * random.random() - 1)
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
if self._reconnect_abort.wait(delay):
self.logger.info('Reconnect task aborted')
break
attempt_count += 1
try:
self.connect(self.connection_url,
headers=self.connection_headers,
transports=self.connection_transports,
namespaces=self.connection_namespaces,
socketio_path=self.socketio_path)
except (exceptions.ConnectionError, ValueError):
pass
else:
self.logger.info('Reconnection successful')
self._reconnect_task = None
break
if self.reconnection_attempts and \
attempt_count >= self.reconnection_attempts:
self.logger.info(
'Maximum reconnection attempts reached, giving up')
break
reconnecting_clients.remove(self)
def _handle_eio_connect(self):
"""Handle the Engine.IO connection event."""
self.logger.info('Engine.IO connection established')
self.sid = self.eio.sid
def _handle_eio_message(self, data):
"""Dispatch Engine.IO messages."""
if self._binary_packet:
pkt = self._binary_packet
if pkt.add_attachment(data):
self._binary_packet = None
if pkt.packet_type == packet.BINARY_EVENT:
self._handle_event(pkt.namespace, pkt.id, pkt.data)
else:
self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT:
self._handle_disconnect(pkt.namespace)
elif pkt.packet_type == packet.EVENT:
self._handle_event(pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.ACK:
self._handle_ack(pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK:
self._binary_packet = pkt
elif pkt.packet_type == packet.ERROR:
self._handle_error(pkt.namespace, pkt.data)
else:
raise ValueError('Unknown packet type.')
def _handle_eio_disconnect(self):
"""Handle the Engine.IO disconnection event."""
self.logger.info('Engine.IO connection dropped')
if self.connected:
for n in self.namespaces:
self._trigger_event('disconnect', namespace=n)
self._trigger_event('disconnect', namespace='/')
self.namespaces = []
self.connected = False
self.callbacks = {}
self._binary_packet = None
self.sid = None
if self.eio.state == 'connected' and self.reconnection:
self._reconnect_task = self.start_background_task(
self._handle_reconnect)
def _engineio_client_class(self):
return engineio.Client

View file

@ -0,0 +1,30 @@
class SocketIOError(Exception):
pass
class ConnectionError(SocketIOError):
pass
class ConnectionRefusedError(ConnectionError):
"""Connection refused exception.
This exception can be raised from a connect handler when the connection
is not accepted. The positional arguments provided with the exception are
returned with the error packet to the client.
"""
def __init__(self, *args):
if len(args) == 0:
self.error_args = None
elif len(args) == 1 and not isinstance(args[0], list):
self.error_args = args[0]
else:
self.error_args = args
class TimeoutError(SocketIOError):
pass
class BadNamespaceError(SocketIOError):
pass

View file

@ -0,0 +1,63 @@
import logging
import pickle
try:
import kafka
except ImportError:
kafka = None
from .pubsub_manager import PubSubManager
logger = logging.getLogger('socketio')
class KafkaManager(PubSubManager): # pragma: no cover
"""Kafka based client manager.
This class implements a Kafka backend for event sharing across multiple
processes.
To use a Kafka backend, initialize the :class:`Server` instance as
follows::
url = 'kafka://hostname:port'
server = socketio.Server(client_manager=socketio.KafkaManager(url))
:param url: The connection URL for the Kafka server. For a default Kafka
store running on the same host, use ``kafka://``.
:param channel: The channel name (topic) on which the server sends and
receives notifications. Must be the same in all the
servers.
:param write_only: If set ot ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
"""
name = 'kafka'
def __init__(self, url='kafka://localhost:9092', channel='socketio',
write_only=False):
if kafka is None:
raise RuntimeError('kafka-python package is not installed '
'(Run "pip install kafka-python" in your '
'virtualenv).')
super(KafkaManager, self).__init__(channel=channel,
write_only=write_only)
self.kafka_url = url[8:] if url != 'kafka://' else 'localhost:9092'
self.producer = kafka.KafkaProducer(bootstrap_servers=self.kafka_url)
self.consumer = kafka.KafkaConsumer(self.channel,
bootstrap_servers=self.kafka_url)
def _publish(self, data):
self.producer.send(self.channel, value=pickle.dumps(data))
self.producer.flush()
def _kafka_listen(self):
for message in self.consumer:
yield message
def _listen(self):
for message in self._kafka_listen():
if message.topic == self.channel:
yield pickle.loads(message.value)

View file

@ -0,0 +1,122 @@
import pickle
import uuid
try:
import kombu
except ImportError:
kombu = None
from .pubsub_manager import PubSubManager
class KombuManager(PubSubManager): # pragma: no cover
"""Client manager that uses kombu for inter-process messaging.
This class implements a client manager backend for event sharing across
multiple processes, using RabbitMQ, Redis or any other messaging mechanism
supported by `kombu <http://kombu.readthedocs.org/en/latest/>`_.
To use a kombu backend, initialize the :class:`Server` instance as
follows::
url = 'amqp://user:password@hostname:port//'
server = socketio.Server(client_manager=socketio.KombuManager(url))
:param url: The connection URL for the backend messaging queue. Example
connection URLs are ``'amqp://guest:guest@localhost:5672//'``
and ``'redis://localhost:6379/'`` for RabbitMQ and Redis
respectively. Consult the `kombu documentation
<http://kombu.readthedocs.org/en/latest/userguide\
/connections.html#urls>`_ for more on how to construct
connection URLs.
:param channel: The channel name on which the server sends and receives
notifications. Must be the same in all the servers.
:param write_only: If set ot ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
:param connection_options: additional keyword arguments to be passed to
``kombu.Connection()``.
:param exchange_options: additional keyword arguments to be passed to
``kombu.Exchange()``.
:param queue_options: additional keyword arguments to be passed to
``kombu.Queue()``.
:param producer_options: additional keyword arguments to be passed to
``kombu.Producer()``.
"""
name = 'kombu'
def __init__(self, url='amqp://guest:guest@localhost:5672//',
channel='socketio', write_only=False, logger=None,
connection_options=None, exchange_options=None,
queue_options=None, producer_options=None):
if kombu is None:
raise RuntimeError('Kombu package is not installed '
'(Run "pip install kombu" in your '
'virtualenv).')
super(KombuManager, self).__init__(channel=channel,
write_only=write_only,
logger=logger)
self.url = url
self.connection_options = connection_options or {}
self.exchange_options = exchange_options or {}
self.queue_options = queue_options or {}
self.producer_options = producer_options or {}
self.producer = self._producer()
def initialize(self):
super(KombuManager, self).initialize()
monkey_patched = True
if self.server.async_mode == 'eventlet':
from eventlet.patcher import is_monkey_patched
monkey_patched = is_monkey_patched('socket')
elif 'gevent' in self.server.async_mode:
from gevent.monkey import is_module_patched
monkey_patched = is_module_patched('socket')
if not monkey_patched:
raise RuntimeError(
'Kombu requires a monkey patched socket library to work '
'with ' + self.server.async_mode)
def _connection(self):
return kombu.Connection(self.url, **self.connection_options)
def _exchange(self):
options = {'type': 'fanout', 'durable': False}
options.update(self.exchange_options)
return kombu.Exchange(self.channel, **options)
def _queue(self):
queue_name = 'flask-socketio.' + str(uuid.uuid4())
options = {'durable': False, 'queue_arguments': {'x-expires': 300000}}
options.update(self.queue_options)
return kombu.Queue(queue_name, self._exchange(), **options)
def _producer(self):
return self._connection().Producer(exchange=self._exchange(),
**self.producer_options)
def __error_callback(self, exception, interval):
self._get_logger().exception('Sleeping {}s'.format(interval))
def _publish(self, data):
connection = self._connection()
publish = connection.ensure(self.producer, self.producer.publish,
errback=self.__error_callback)
publish(pickle.dumps(data))
def _listen(self):
reader_queue = self._queue()
while True:
connection = self._connection().ensure_connection(
errback=self.__error_callback)
try:
with connection.SimpleQueue(reader_queue) as queue:
while True:
message = queue.get(block=True)
message.ack()
yield message.payload
except connection.connection_errors:
self._get_logger().exception("Connection error "
"while reading from queue")

View file

@ -0,0 +1,42 @@
import engineio
class WSGIApp(engineio.WSGIApp):
"""WSGI middleware for Socket.IO.
This middleware dispatches traffic to a Socket.IO application. It can also
serve a list of static files to the client, or forward unrelated HTTP
traffic to another WSGI application.
:param socketio_app: The Socket.IO server. Must be an instance of the
``socketio.Server`` class.
:param wsgi_app: The WSGI app that receives all other traffic.
:param static_files: A dictionary with static file mapping rules. See the
documentation for details on this argument.
:param socketio_path: The endpoint where the Socket.IO application should
be installed. The default value is appropriate for
most cases.
Example usage::
import socketio
import eventlet
from . import wsgi_app
sio = socketio.Server()
app = socketio.WSGIApp(sio, wsgi_app)
eventlet.wsgi.server(eventlet.listen(('', 8000)), app)
"""
def __init__(self, socketio_app, wsgi_app=None, static_files=None,
socketio_path='socket.io'):
super(WSGIApp, self).__init__(socketio_app, wsgi_app,
static_files=static_files,
engineio_path=socketio_path)
class Middleware(WSGIApp):
"""This class has been renamed to WSGIApp and is now deprecated."""
def __init__(self, socketio_app, wsgi_app=None,
socketio_path='socket.io'):
super(Middleware, self).__init__(socketio_app, wsgi_app,
socketio_path=socketio_path)

191
libs/socketio/namespace.py Normal file
View file

@ -0,0 +1,191 @@
class BaseNamespace(object):
def __init__(self, namespace=None):
self.namespace = namespace or '/'
def is_asyncio_based(self):
return False
def trigger_event(self, event, *args):
"""Dispatch an event to the proper handler method.
In the most common usage, this method is not overloaded by subclasses,
as it performs the routing of events to methods. However, this
method can be overriden if special dispatching rules are needed, or if
having a single method that catches all events is desired.
"""
handler_name = 'on_' + event
if hasattr(self, handler_name):
return getattr(self, handler_name)(*args)
class Namespace(BaseNamespace):
"""Base class for server-side class-based namespaces.
A class-based namespace is a class that contains all the event handlers
for a Socket.IO namespace. The event handlers are methods of the class
with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
``on_message``, ``on_json``, and so on.
:param namespace: The Socket.IO namespace to be used with all the event
handlers defined in this class. If this argument is
omitted, the default namespace is used.
"""
def __init__(self, namespace=None):
super(Namespace, self).__init__(namespace=namespace)
self.server = None
def _set_server(self, server):
self.server = server
def emit(self, event, data=None, room=None, skip_sid=None, namespace=None,
callback=None):
"""Emit a custom event to one or more connected clients.
The only difference with the :func:`socketio.Server.emit` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.emit(event, data=data, room=room, skip_sid=skip_sid,
namespace=namespace or self.namespace,
callback=callback)
def send(self, data, room=None, skip_sid=None, namespace=None,
callback=None):
"""Send a message to one or more connected clients.
The only difference with the :func:`socketio.Server.send` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.send(data, room=room, skip_sid=skip_sid,
namespace=namespace or self.namespace,
callback=callback)
def enter_room(self, sid, room, namespace=None):
"""Enter a room.
The only difference with the :func:`socketio.Server.enter_room` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.enter_room(sid, room,
namespace=namespace or self.namespace)
def leave_room(self, sid, room, namespace=None):
"""Leave a room.
The only difference with the :func:`socketio.Server.leave_room` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.leave_room(sid, room,
namespace=namespace or self.namespace)
def close_room(self, room, namespace=None):
"""Close a room.
The only difference with the :func:`socketio.Server.close_room` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.close_room(room,
namespace=namespace or self.namespace)
def rooms(self, sid, namespace=None):
"""Return the rooms a client is in.
The only difference with the :func:`socketio.Server.rooms` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.rooms(sid, namespace=namespace or self.namespace)
def get_session(self, sid, namespace=None):
"""Return the user session for a client.
The only difference with the :func:`socketio.Server.get_session`
method is that when the ``namespace`` argument is not given the
namespace associated with the class is used.
"""
return self.server.get_session(
sid, namespace=namespace or self.namespace)
def save_session(self, sid, session, namespace=None):
"""Store the user session for a client.
The only difference with the :func:`socketio.Server.save_session`
method is that when the ``namespace`` argument is not given the
namespace associated with the class is used.
"""
return self.server.save_session(
sid, session, namespace=namespace or self.namespace)
def session(self, sid, namespace=None):
"""Return the user session for a client with context manager syntax.
The only difference with the :func:`socketio.Server.session` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.session(sid, namespace=namespace or self.namespace)
def disconnect(self, sid, namespace=None):
"""Disconnect a client.
The only difference with the :func:`socketio.Server.disconnect` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.disconnect(sid,
namespace=namespace or self.namespace)
class ClientNamespace(BaseNamespace):
"""Base class for client-side class-based namespaces.
A class-based namespace is a class that contains all the event handlers
for a Socket.IO namespace. The event handlers are methods of the class
with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
``on_message``, ``on_json``, and so on.
:param namespace: The Socket.IO namespace to be used with all the event
handlers defined in this class. If this argument is
omitted, the default namespace is used.
"""
def __init__(self, namespace=None):
super(ClientNamespace, self).__init__(namespace=namespace)
self.client = None
def _set_client(self, client):
self.client = client
def emit(self, event, data=None, namespace=None, callback=None):
"""Emit a custom event to the server.
The only difference with the :func:`socketio.Client.emit` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.client.emit(event, data=data,
namespace=namespace or self.namespace,
callback=callback)
def send(self, data, room=None, skip_sid=None, namespace=None,
callback=None):
"""Send a message to the server.
The only difference with the :func:`socketio.Client.send` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.client.send(data, namespace=namespace or self.namespace,
callback=callback)
def disconnect(self):
"""Disconnect from the server.
The only difference with the :func:`socketio.Client.disconnect` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.client.disconnect()

179
libs/socketio/packet.py Normal file
View file

@ -0,0 +1,179 @@
import functools
import json as _json
import six
(CONNECT, DISCONNECT, EVENT, ACK, ERROR, BINARY_EVENT, BINARY_ACK) = \
(0, 1, 2, 3, 4, 5, 6)
packet_names = ['CONNECT', 'DISCONNECT', 'EVENT', 'ACK', 'ERROR',
'BINARY_EVENT', 'BINARY_ACK']
class Packet(object):
"""Socket.IO packet."""
# the format of the Socket.IO packet is as follows:
#
# packet type: 1 byte, values 0-6
# num_attachments: ASCII encoded, only if num_attachments != 0
# '-': only if num_attachments != 0
# namespace: only if namespace != '/'
# ',': only if namespace and one of id and data are defined in this packet
# id: ASCII encoded, only if id is not None
# data: JSON dump of data payload
json = _json
def __init__(self, packet_type=EVENT, data=None, namespace=None, id=None,
binary=None, encoded_packet=None):
self.packet_type = packet_type
self.data = data
self.namespace = namespace
self.id = id
if binary or (binary is None and self._data_is_binary(self.data)):
if self.packet_type == EVENT:
self.packet_type = BINARY_EVENT
elif self.packet_type == ACK:
self.packet_type = BINARY_ACK
else:
raise ValueError('Packet does not support binary payload.')
self.attachment_count = 0
self.attachments = []
if encoded_packet:
self.attachment_count = self.decode(encoded_packet)
def encode(self):
"""Encode the packet for transmission.
If the packet contains binary elements, this function returns a list
of packets where the first is the original packet with placeholders for
the binary components and the remaining ones the binary attachments.
"""
encoded_packet = six.text_type(self.packet_type)
if self.packet_type == BINARY_EVENT or self.packet_type == BINARY_ACK:
data, attachments = self._deconstruct_binary(self.data)
encoded_packet += six.text_type(len(attachments)) + '-'
else:
data = self.data
attachments = None
needs_comma = False
if self.namespace is not None and self.namespace != '/':
encoded_packet += self.namespace
needs_comma = True
if self.id is not None:
if needs_comma:
encoded_packet += ','
needs_comma = False
encoded_packet += six.text_type(self.id)
if data is not None:
if needs_comma:
encoded_packet += ','
encoded_packet += self.json.dumps(data, separators=(',', ':'))
if attachments is not None:
encoded_packet = [encoded_packet] + attachments
return encoded_packet
def decode(self, encoded_packet):
"""Decode a transmitted package.
The return value indicates how many binary attachment packets are
necessary to fully decode the packet.
"""
ep = encoded_packet
try:
self.packet_type = int(ep[0:1])
except TypeError:
self.packet_type = ep
ep = ''
self.namespace = None
self.data = None
ep = ep[1:]
dash = ep.find('-')
attachment_count = 0
if dash > 0 and ep[0:dash].isdigit():
attachment_count = int(ep[0:dash])
ep = ep[dash + 1:]
if ep and ep[0:1] == '/':
sep = ep.find(',')
if sep == -1:
self.namespace = ep
ep = ''
else:
self.namespace = ep[0:sep]
ep = ep[sep + 1:]
q = self.namespace.find('?')
if q != -1:
self.namespace = self.namespace[0:q]
if ep and ep[0].isdigit():
self.id = 0
while ep and ep[0].isdigit():
self.id = self.id * 10 + int(ep[0])
ep = ep[1:]
if ep:
self.data = self.json.loads(ep)
return attachment_count
def add_attachment(self, attachment):
if self.attachment_count <= len(self.attachments):
raise ValueError('Unexpected binary attachment')
self.attachments.append(attachment)
if self.attachment_count == len(self.attachments):
self.reconstruct_binary(self.attachments)
return True
return False
def reconstruct_binary(self, attachments):
"""Reconstruct a decoded packet using the given list of binary
attachments.
"""
self.data = self._reconstruct_binary_internal(self.data,
self.attachments)
def _reconstruct_binary_internal(self, data, attachments):
if isinstance(data, list):
return [self._reconstruct_binary_internal(item, attachments)
for item in data]
elif isinstance(data, dict):
if data.get('_placeholder') and 'num' in data:
return attachments[data['num']]
else:
return {key: self._reconstruct_binary_internal(value,
attachments)
for key, value in six.iteritems(data)}
else:
return data
def _deconstruct_binary(self, data):
"""Extract binary components in the packet."""
attachments = []
data = self._deconstruct_binary_internal(data, attachments)
return data, attachments
def _deconstruct_binary_internal(self, data, attachments):
if isinstance(data, six.binary_type):
attachments.append(data)
return {'_placeholder': True, 'num': len(attachments) - 1}
elif isinstance(data, list):
return [self._deconstruct_binary_internal(item, attachments)
for item in data]
elif isinstance(data, dict):
return {key: self._deconstruct_binary_internal(value, attachments)
for key, value in six.iteritems(data)}
else:
return data
def _data_is_binary(self, data):
"""Check if the data contains binary components."""
if isinstance(data, six.binary_type):
return True
elif isinstance(data, list):
return functools.reduce(
lambda a, b: a or b, [self._data_is_binary(item)
for item in data], False)
elif isinstance(data, dict):
return functools.reduce(
lambda a, b: a or b, [self._data_is_binary(item)
for item in six.itervalues(data)],
False)
else:
return False

View file

@ -0,0 +1,154 @@
from functools import partial
import uuid
import json
import pickle
import six
from .base_manager import BaseManager
class PubSubManager(BaseManager):
"""Manage a client list attached to a pub/sub backend.
This is a base class that enables multiple servers to share the list of
clients, with the servers communicating events through a pub/sub backend.
The use of a pub/sub backend also allows any client connected to the
backend to emit events addressed to Socket.IO clients.
The actual backends must be implemented by subclasses, this class only
provides a pub/sub generic framework.
:param channel: The channel name on which the server sends and receives
notifications.
"""
name = 'pubsub'
def __init__(self, channel='socketio', write_only=False, logger=None):
super(PubSubManager, self).__init__()
self.channel = channel
self.write_only = write_only
self.host_id = uuid.uuid4().hex
self.logger = logger
def initialize(self):
super(PubSubManager, self).initialize()
if not self.write_only:
self.thread = self.server.start_background_task(self._thread)
self._get_logger().info(self.name + ' backend initialized.')
def emit(self, event, data, namespace=None, room=None, skip_sid=None,
callback=None, **kwargs):
"""Emit a message to a single client, a room, or all the clients
connected to the namespace.
This method takes care or propagating the message to all the servers
that are connected through the message queue.
The parameters are the same as in :meth:`.Server.emit`.
"""
if kwargs.get('ignore_queue'):
return super(PubSubManager, self).emit(
event, data, namespace=namespace, room=room, skip_sid=skip_sid,
callback=callback)
namespace = namespace or '/'
if callback is not None:
if self.server is None:
raise RuntimeError('Callbacks can only be issued from the '
'context of a server.')
if room is None:
raise ValueError('Cannot use callback without a room set.')
id = self._generate_ack_id(room, namespace, callback)
callback = (room, namespace, id)
else:
callback = None
self._publish({'method': 'emit', 'event': event, 'data': data,
'namespace': namespace, 'room': room,
'skip_sid': skip_sid, 'callback': callback,
'host_id': self.host_id})
def close_room(self, room, namespace=None):
self._publish({'method': 'close_room', 'room': room,
'namespace': namespace or '/'})
def _publish(self, data):
"""Publish a message on the Socket.IO channel.
This method needs to be implemented by the different subclasses that
support pub/sub backends.
"""
raise NotImplementedError('This method must be implemented in a '
'subclass.') # pragma: no cover
def _listen(self):
"""Return the next message published on the Socket.IO channel,
blocking until a message is available.
This method needs to be implemented by the different subclasses that
support pub/sub backends.
"""
raise NotImplementedError('This method must be implemented in a '
'subclass.') # pragma: no cover
def _handle_emit(self, message):
# Events with callbacks are very tricky to handle across hosts
# Here in the receiving end we set up a local callback that preserves
# the callback host and id from the sender
remote_callback = message.get('callback')
remote_host_id = message.get('host_id')
if remote_callback is not None and len(remote_callback) == 3:
callback = partial(self._return_callback, remote_host_id,
*remote_callback)
else:
callback = None
super(PubSubManager, self).emit(message['event'], message['data'],
namespace=message.get('namespace'),
room=message.get('room'),
skip_sid=message.get('skip_sid'),
callback=callback)
def _handle_callback(self, message):
if self.host_id == message.get('host_id'):
try:
sid = message['sid']
namespace = message['namespace']
id = message['id']
args = message['args']
except KeyError:
return
self.trigger_callback(sid, namespace, id, args)
def _return_callback(self, host_id, sid, namespace, callback_id, *args):
# When an event callback is received, the callback is returned back
# the sender, which is identified by the host_id
self._publish({'method': 'callback', 'host_id': host_id,
'sid': sid, 'namespace': namespace, 'id': callback_id,
'args': args})
def _handle_close_room(self, message):
super(PubSubManager, self).close_room(
room=message.get('room'), namespace=message.get('namespace'))
def _thread(self):
for message in self._listen():
data = None
if isinstance(message, dict):
data = message
else:
if isinstance(message, six.binary_type): # pragma: no cover
try:
data = pickle.loads(message)
except:
pass
if data is None:
try:
data = json.loads(message)
except:
pass
if data and 'method' in data:
if data['method'] == 'emit':
self._handle_emit(data)
elif data['method'] == 'callback':
self._handle_callback(data)
elif data['method'] == 'close_room':
self._handle_close_room(data)

View file

@ -0,0 +1,115 @@
import logging
import pickle
import time
try:
import redis
except ImportError:
redis = None
from .pubsub_manager import PubSubManager
logger = logging.getLogger('socketio')
class RedisManager(PubSubManager): # pragma: no cover
"""Redis based client manager.
This class implements a Redis backend for event sharing across multiple
processes. Only kept here as one more example of how to build a custom
backend, since the kombu backend is perfectly adequate to support a Redis
message queue.
To use a Redis backend, initialize the :class:`Server` instance as
follows::
url = 'redis://hostname:port/0'
server = socketio.Server(client_manager=socketio.RedisManager(url))
:param url: The connection URL for the Redis server. For a default Redis
store running on the same host, use ``redis://``.
:param channel: The channel name on which the server sends and receives
notifications. Must be the same in all the servers.
:param write_only: If set ot ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
:param redis_options: additional keyword arguments to be passed to
``Redis.from_url()``.
"""
name = 'redis'
def __init__(self, url='redis://localhost:6379/0', channel='socketio',
write_only=False, logger=None, redis_options=None):
if redis is None:
raise RuntimeError('Redis package is not installed '
'(Run "pip install redis" in your '
'virtualenv).')
self.redis_url = url
self.redis_options = redis_options or {}
self._redis_connect()
super(RedisManager, self).__init__(channel=channel,
write_only=write_only,
logger=logger)
def initialize(self):
super(RedisManager, self).initialize()
monkey_patched = True
if self.server.async_mode == 'eventlet':
from eventlet.patcher import is_monkey_patched
monkey_patched = is_monkey_patched('socket')
elif 'gevent' in self.server.async_mode:
from gevent.monkey import is_module_patched
monkey_patched = is_module_patched('socket')
if not monkey_patched:
raise RuntimeError(
'Redis requires a monkey patched socket library to work '
'with ' + self.server.async_mode)
def _redis_connect(self):
self.redis = redis.Redis.from_url(self.redis_url,
**self.redis_options)
self.pubsub = self.redis.pubsub()
def _publish(self, data):
retry = True
while True:
try:
if not retry:
self._redis_connect()
return self.redis.publish(self.channel, pickle.dumps(data))
except redis.exceptions.ConnectionError:
if retry:
logger.error('Cannot publish to redis... retrying')
retry = False
else:
logger.error('Cannot publish to redis... giving up')
break
def _redis_listen_with_retries(self):
retry_sleep = 1
connect = False
while True:
try:
if connect:
self._redis_connect()
self.pubsub.subscribe(self.channel)
for message in self.pubsub.listen():
yield message
except redis.exceptions.ConnectionError:
logger.error('Cannot receive from redis... '
'retrying in {} secs'.format(retry_sleep))
connect = True
time.sleep(retry_sleep)
retry_sleep *= 2
if retry_sleep > 60:
retry_sleep = 60
def _listen(self):
channel = self.channel.encode('utf-8')
self.pubsub.subscribe(self.channel)
for message in self._redis_listen_with_retries():
if message['channel'] == channel and \
message['type'] == 'message' and 'data' in message:
yield message['data']
self.pubsub.unsubscribe(self.channel)

730
libs/socketio/server.py Normal file
View file

@ -0,0 +1,730 @@
import logging
import engineio
import six
from . import base_manager
from . import exceptions
from . import namespace
from . import packet
default_logger = logging.getLogger('socketio.server')
class Server(object):
"""A Socket.IO server.
This class implements a fully compliant Socket.IO web server with support
for websocket and long-polling transports.
:param client_manager: The client manager instance that will manage the
client list. When this is omitted, the client list
is stored in an in-memory structure, so the use of
multiple connected servers is not possible.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param binary: ``True`` to support binary payloads, ``False`` to treat all
payloads as text. On Python 2, if this is set to ``True``,
``unicode`` values are treated as text, and ``str`` and
``bytes`` values are treated as binary. This option has no
effect on Python 3, where text and binary payloads are
always automatically discovered.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, event handlers for a client are
executed in separate threads. To run handlers for a
client synchronously, set to ``False``. The default
is ``True``.
:param always_connect: When set to ``False``, new connections are
provisory until the connect handler returns
something other than ``False``, at which point they
are accepted. When set to ``True``, connections are
immediately accepted, and then if the connect
handler returns ``False`` a disconnect is issued.
Set to ``True`` if you need to emit events from the
connect handler and your client is confused when it
receives events before the connection acceptance.
In any other case use the default of ``False``.
:param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings:
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "threading",
"eventlet", "gevent" and "gevent_uwsgi". If this
argument is not given, "eventlet" is tried first, then
"gevent_uwsgi", then "gevent", and finally "threading".
The first async mode that has all its dependencies
installed is then one that is chosen.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting. The default
is 60 seconds.
:param ping_interval: The interval in seconds at which the client pings
the server. The default is 25 seconds.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport. The default is 100,000,000
bytes.
:param allow_upgrades: Whether to allow transport upgrades or not. The
default is ``True``.
:param http_compression: Whether to compress packages when using the
polling transport. The default is ``True``.
:param compression_threshold: Only compress messages when their byte size
is greater than this value. The default is
1024 bytes.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
The default is ``'io'``.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same origin
is allowed by default. Set this argument to
``'*'`` to allow all origins, or to ``[]`` to
disable CORS handling.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server. The default is
``True``.
:param monitor_clients: If set to ``True``, a background task will ensure
inactive clients are closed. Set to ``False`` to
disable the monitoring task (not recommended). The
default is ``True``.
:param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
a logger object to use. To disable logging set to
``False``. The default is ``False``.
"""
def __init__(self, client_manager=None, logger=False, binary=False,
json=None, async_handlers=True, always_connect=False,
**kwargs):
engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None:
engineio_options['logger'] = engineio_logger
if json is not None:
packet.Packet.json = json
engineio_options['json'] = json
engineio_options['async_handlers'] = False
self.eio = self._engineio_server_class()(**engineio_options)
self.eio.on('connect', self._handle_eio_connect)
self.eio.on('message', self._handle_eio_message)
self.eio.on('disconnect', self._handle_eio_disconnect)
self.binary = binary
self.environ = {}
self.handlers = {}
self.namespace_handlers = {}
self._binary_packet = {}
if not isinstance(logger, bool):
self.logger = logger
else:
self.logger = default_logger
if not logging.root.handlers and \
self.logger.level == logging.NOTSET:
if logger:
self.logger.setLevel(logging.INFO)
else:
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())
if client_manager is None:
client_manager = base_manager.BaseManager()
self.manager = client_manager
self.manager.set_server(self)
self.manager_initialized = False
self.async_handlers = async_handlers
self.always_connect = always_connect
self.async_mode = self.eio.async_mode
def is_asyncio_based(self):
return False
def on(self, event, handler=None, namespace=None):
"""Register an event handler.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param handler: The function that should be invoked to handle the
event. When this parameter is not given, the method
acts as a decorator for the handler function.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the handler is associated with
the default namespace.
Example usage::
# as a decorator:
@socket_io.on('connect', namespace='/chat')
def connect_handler(sid, environ):
print('Connection request')
if environ['REMOTE_ADDR'] in blacklisted:
return False # reject
# as a method:
def message_handler(sid, msg):
print('Received message: ', msg)
eio.send(sid, 'response')
socket_io.on('message', namespace='/chat', message_handler)
The handler function receives the ``sid`` (session ID) for the
client as first argument. The ``'connect'`` event handler receives the
WSGI environment as a second argument, and can return ``False`` to
reject the connection. The ``'message'`` handler and handlers for
custom event names receive the message payload as a second argument.
Any values returned from a message handler will be passed to the
client's acknowledgement callback function if it exists. The
``'disconnect'`` handler does not take a second argument.
"""
namespace = namespace or '/'
def set_handler(handler):
if namespace not in self.handlers:
self.handlers[namespace] = {}
self.handlers[namespace][event] = handler
return handler
if handler is None:
return set_handler
set_handler(handler)
def event(self, *args, **kwargs):
"""Decorator to register an event handler.
This is a simplified version of the ``on()`` method that takes the
event name from the decorated function.
Example usage::
@sio.event
def my_event(data):
print('Received data: ', data)
The above example is equivalent to::
@sio.on('my_event')
def my_event(data):
print('Received data: ', data)
A custom namespace can be given as an argument to the decorator::
@sio.event(namespace='/test')
def my_event(data):
print('Received data: ', data)
"""
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# the decorator was invoked without arguments
# args[0] is the decorated function
return self.on(args[0].__name__)(args[0])
else:
# the decorator was invoked with arguments
def set_handler(handler):
return self.on(handler.__name__, *args, **kwargs)(handler)
return set_handler
def register_namespace(self, namespace_handler):
"""Register a namespace handler object.
:param namespace_handler: An instance of a :class:`Namespace`
subclass that handles all the event traffic
for a namespace.
"""
if not isinstance(namespace_handler, namespace.Namespace):
raise ValueError('Not a namespace instance')
if self.is_asyncio_based() != namespace_handler.is_asyncio_based():
raise ValueError('Not a valid namespace class for this server')
namespace_handler._set_server(self)
self.namespace_handlers[namespace_handler.namespace] = \
namespace_handler
def emit(self, event, data=None, to=None, room=None, skip_sid=None,
namespace=None, callback=None, **kwargs):
"""Emit a custom event to one or more connected clients.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The recipient of the message. This can be set to the
session ID of a client to address only that client, or to
to any custom room created by the application to address all
the clients in that room, If this argument is omitted the
event is broadcasted to all connected clients.
:param room: Alias for the ``to`` parameter.
:param skip_sid: The session ID of a client to skip when broadcasting
to a room or to all clients. This can be used to
prevent a message from being sent to the sender. To
skip multiple sids, pass a list.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
the the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
clients directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
"""
namespace = namespace or '/'
room = to or room
self.logger.info('emitting event "%s" to %s [%s]', event,
room or 'all', namespace)
self.manager.emit(event, data, namespace, room=room,
skip_sid=skip_sid, callback=callback, **kwargs)
def send(self, data, to=None, room=None, skip_sid=None, namespace=None,
callback=None, **kwargs):
"""Send a message to one or more connected clients.
This function emits an event with the name ``'message'``. Use
:func:`emit` to issue custom event names.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The recipient of the message. This can be set to the
session ID of a client to address only that client, or to
to any custom room created by the application to address all
the clients in that room, If this argument is omitted the
event is broadcasted to all connected clients.
:param room: Alias for the ``to`` parameter.
:param skip_sid: The session ID of a client to skip when broadcasting
to a room or to all clients. This can be used to
prevent a message from being sent to the sender. To
skip multiple sids, pass a list.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
the the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
clients directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
"""
self.emit('message', data=data, to=to, room=room, skip_sid=skip_sid,
namespace=namespace, callback=callback, **kwargs)
def call(self, event, data=None, to=None, sid=None, namespace=None,
timeout=60, **kwargs):
"""Emit a custom event to a client and wait for the response.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The session ID of the recipient client.
:param sid: Alias for the ``to`` parameter.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param timeout: The waiting timeout. If the timeout is reached before
the client acknowledges the event, then a
``TimeoutError`` exception is raised.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
client directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
"""
if not self.async_handlers:
raise RuntimeError(
'Cannot use call() when async_handlers is False.')
callback_event = self.eio.create_event()
callback_args = []
def event_callback(*args):
callback_args.append(args)
callback_event.set()
self.emit(event, data=data, room=to or sid, namespace=namespace,
callback=event_callback, **kwargs)
if not callback_event.wait(timeout=timeout):
raise exceptions.TimeoutError()
return callback_args[0] if len(callback_args[0]) > 1 \
else callback_args[0][0] if len(callback_args[0]) == 1 \
else None
def enter_room(self, sid, room, namespace=None):
"""Enter a room.
This function adds the client to a room. The :func:`emit` and
:func:`send` functions can optionally broadcast events to all the
clients in a room.
:param sid: Session ID of the client.
:param room: Room name. If the room does not exist it is created.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
"""
namespace = namespace or '/'
self.logger.info('%s is entering room %s [%s]', sid, room, namespace)
self.manager.enter_room(sid, namespace, room)
def leave_room(self, sid, room, namespace=None):
"""Leave a room.
This function removes the client from a room.
:param sid: Session ID of the client.
:param room: Room name.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
"""
namespace = namespace or '/'
self.logger.info('%s is leaving room %s [%s]', sid, room, namespace)
self.manager.leave_room(sid, namespace, room)
def close_room(self, room, namespace=None):
"""Close a room.
This function removes all the clients from the given room.
:param room: Room name.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
"""
namespace = namespace or '/'
self.logger.info('room %s is closing [%s]', room, namespace)
self.manager.close_room(room, namespace)
def rooms(self, sid, namespace=None):
"""Return the rooms a client is in.
:param sid: Session ID of the client.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
"""
namespace = namespace or '/'
return self.manager.get_rooms(sid, namespace)
def get_session(self, sid, namespace=None):
"""Return the user session for a client.
:param sid: The session id of the client.
:param namespace: The Socket.IO namespace. If this argument is omitted
the default namespace is used.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved unless
``save_session()`` is called, or when the ``session`` context manager
is used.
"""
namespace = namespace or '/'
eio_session = self.eio.get_session(sid)
return eio_session.setdefault(namespace, {})
def save_session(self, sid, session, namespace=None):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
:param namespace: The Socket.IO namespace. If this argument is omitted
the default namespace is used.
"""
namespace = namespace or '/'
eio_session = self.eio.get_session(sid)
eio_session[namespace] = session
def session(self, sid, namespace=None):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
@sio.on('connect')
def on_connect(sid, environ):
username = authenticate_user(environ)
if not username:
return False
with sio.session(sid) as session:
session['username'] = username
@sio.on('message')
def on_message(sid, msg):
with sio.session(sid) as session:
print('received message from ', session['username'])
"""
class _session_context_manager(object):
def __init__(self, server, sid, namespace):
self.server = server
self.sid = sid
self.namespace = namespace
self.session = None
def __enter__(self):
self.session = self.server.get_session(sid,
namespace=namespace)
return self.session
def __exit__(self, *args):
self.server.save_session(sid, self.session,
namespace=namespace)
return _session_context_manager(self, sid, namespace)
def disconnect(self, sid, namespace=None):
"""Disconnect a client.
:param sid: Session ID of the client.
:param namespace: The Socket.IO namespace to disconnect. If this
argument is omitted the default namespace is used.
"""
namespace = namespace or '/'
if self.manager.is_connected(sid, namespace=namespace):
self.logger.info('Disconnecting %s [%s]', sid, namespace)
self.manager.pre_disconnect(sid, namespace=namespace)
self._send_packet(sid, packet.Packet(packet.DISCONNECT,
namespace=namespace))
self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
if namespace == '/':
self.eio.disconnect(sid)
def transport(self, sid):
"""Return the name of the transport used by the client.
The two possible values returned by this function are ``'polling'``
and ``'websocket'``.
:param sid: The session of the client.
"""
return self.eio.transport(sid)
def handle_request(self, environ, start_response):
"""Handle an HTTP request from the client.
This is the entry point of the Socket.IO application, using the same
interface as a WSGI application. For the typical usage, this function
is invoked by the :class:`Middleware` instance, but it can be invoked
directly when the middleware is not used.
:param environ: The WSGI environment.
:param start_response: The WSGI ``start_response`` function.
This function returns the HTTP response body to deliver to the client
as a byte sequence.
"""
return self.eio.handle_request(environ, start_response)
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
return self.eio.start_background_task(target, *args, **kwargs)
def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
"""
return self.eio.sleep(seconds)
def _emit_internal(self, sid, event, data, namespace=None, id=None):
"""Send a message to a client."""
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
# tuples are expanded to multiple arguments, everything else is sent
# as a single argument
if isinstance(data, tuple):
data = list(data)
else:
data = [data]
self._send_packet(sid, packet.Packet(packet.EVENT, namespace=namespace,
data=[event] + data, id=id,
binary=binary))
def _send_packet(self, sid, pkt):
"""Send a Socket.IO packet to a client."""
encoded_packet = pkt.encode()
if isinstance(encoded_packet, list):
binary = False
for ep in encoded_packet:
self.eio.send(sid, ep, binary=binary)
binary = True
else:
self.eio.send(sid, encoded_packet, binary=False)
def _handle_connect(self, sid, namespace):
"""Handle a client connection request."""
namespace = namespace or '/'
self.manager.connect(sid, namespace)
if self.always_connect:
self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
fail_reason = None
try:
success = self._trigger_event('connect', namespace, sid,
self.environ[sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
self._send_packet(sid, packet.Packet(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
self.manager.disconnect(sid, namespace)
if not self.always_connect:
self._send_packet(sid, packet.Packet(
packet.ERROR, data=fail_reason, namespace=namespace))
if sid in self.environ: # pragma: no cover
del self.environ[sid]
elif not self.always_connect:
self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
def _handle_disconnect(self, sid, namespace):
"""Handle a client disconnect."""
namespace = namespace or '/'
if namespace == '/':
namespace_list = list(self.manager.get_namespaces())
else:
namespace_list = [namespace]
for n in namespace_list:
if n != '/' and self.manager.is_connected(sid, n):
self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n)
if namespace == '/' and self.manager.is_connected(sid, namespace):
self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/')
def _handle_event(self, sid, namespace, id, data):
"""Handle an incoming client event."""
namespace = namespace or '/'
self.logger.info('received event "%s" from %s [%s]', data[0], sid,
namespace)
if not self.manager.is_connected(sid, namespace):
self.logger.warning('%s is not connected to namespace %s',
sid, namespace)
return
if self.async_handlers:
self.start_background_task(self._handle_event_internal, self, sid,
data, namespace, id)
else:
self._handle_event_internal(self, sid, data, namespace, id)
def _handle_event_internal(self, server, sid, data, namespace, id):
r = server._trigger_event(data[0], namespace, sid, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
# tuples are expanded as multiple arguments
if r is None:
data = []
elif isinstance(r, tuple):
data = list(r)
else:
data = [r]
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
server._send_packet(sid, packet.Packet(packet.ACK,
namespace=namespace,
id=id, data=data,
binary=binary))
def _handle_ack(self, sid, namespace, id, data):
"""Handle ACK packets from the client."""
namespace = namespace or '/'
self.logger.info('received ack from %s [%s]', sid, namespace)
self.manager.trigger_callback(sid, namespace, id, data)
def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
# first see if we have an explicit handler for the event
if namespace in self.handlers and event in self.handlers[namespace]:
return self.handlers[namespace][event](*args)
# or else, forward the event to a namespace handler if one exists
elif namespace in self.namespace_handlers:
return self.namespace_handlers[namespace].trigger_event(
event, *args)
def _handle_eio_connect(self, sid, environ):
"""Handle the Engine.IO connection event."""
if not self.manager_initialized:
self.manager_initialized = True
self.manager.initialize()
self.environ[sid] = environ
return self._handle_connect(sid, '/')
def _handle_eio_message(self, sid, data):
"""Dispatch Engine.IO messages."""
if sid in self._binary_packet:
pkt = self._binary_packet[sid]
if pkt.add_attachment(data):
del self._binary_packet[sid]
if pkt.packet_type == packet.BINARY_EVENT:
self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
else:
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(sid, pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT:
self._handle_disconnect(sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:
self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.ACK:
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK:
self._binary_packet[sid] = pkt
elif pkt.packet_type == packet.ERROR:
raise ValueError('Unexpected ERROR packet.')
else:
raise ValueError('Unknown packet type.')
def _handle_eio_disconnect(self, sid):
"""Handle Engine.IO disconnect event."""
self._handle_disconnect(sid, '/')
if sid in self.environ:
del self.environ[sid]
def _engineio_server_class(self):
return engineio.Server

11
libs/socketio/tornado.py Normal file
View file

@ -0,0 +1,11 @@
import sys
if sys.version_info >= (3, 5):
try:
from engineio.async_drivers.tornado import get_tornado_handler as \
get_engineio_handler
except ImportError: # pragma: no cover
get_engineio_handler = None
def get_tornado_handler(socketio_server): # pragma: no cover
return get_engineio_handler(socketio_server.eio)

View file

@ -0,0 +1,111 @@
import pickle
import re
try:
import eventlet.green.zmq as zmq
except ImportError:
zmq = None
import six
from .pubsub_manager import PubSubManager
class ZmqManager(PubSubManager): # pragma: no cover
"""zmq based client manager.
NOTE: this zmq implementation should be considered experimental at this
time. At this time, eventlet is required to use zmq.
This class implements a zmq backend for event sharing across multiple
processes. To use a zmq backend, initialize the :class:`Server` instance as
follows::
url = 'zmq+tcp://hostname:port1+port2'
server = socketio.Server(client_manager=socketio.ZmqManager(url))
:param url: The connection URL for the zmq message broker,
which will need to be provided and running.
:param channel: The channel name on which the server sends and receives
notifications. Must be the same in all the servers.
:param write_only: If set to ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
A zmq message broker must be running for the zmq_manager to work.
you can write your own or adapt one from the following simple broker
below::
import zmq
receiver = zmq.Context().socket(zmq.PULL)
receiver.bind("tcp://*:5555")
publisher = zmq.Context().socket(zmq.PUB)
publisher.bind("tcp://*:5556")
while True:
publisher.send(receiver.recv())
"""
name = 'zmq'
def __init__(self, url='zmq+tcp://localhost:5555+5556',
channel='socketio',
write_only=False,
logger=None):
if zmq is None:
raise RuntimeError('zmq package is not installed '
'(Run "pip install pyzmq" in your '
'virtualenv).')
r = re.compile(r':\d+\+\d+$')
if not (url.startswith('zmq+tcp://') and r.search(url)):
raise RuntimeError('unexpected connection string: ' + url)
url = url.replace('zmq+', '')
(sink_url, sub_port) = url.split('+')
sink_port = sink_url.split(':')[-1]
sub_url = sink_url.replace(sink_port, sub_port)
sink = zmq.Context().socket(zmq.PUSH)
sink.connect(sink_url)
sub = zmq.Context().socket(zmq.SUB)
sub.setsockopt_string(zmq.SUBSCRIBE, u'')
sub.connect(sub_url)
self.sink = sink
self.sub = sub
self.channel = channel
super(ZmqManager, self).__init__(channel=channel,
write_only=write_only,
logger=logger)
def _publish(self, data):
pickled_data = pickle.dumps(
{
'type': 'message',
'channel': self.channel,
'data': data
}
)
return self.sink.send(pickled_data)
def zmq_listen(self):
while True:
response = self.sub.recv()
if response is not None:
yield response
def _listen(self):
for message in self.zmq_listen():
if isinstance(message, six.binary_type):
try:
message = pickle.loads(message)
except Exception:
pass
if isinstance(message, dict) and \
message['type'] == 'message' and \
message['channel'] == self.channel and \
'data' in message:
yield message['data']
return

402
libs/yaml/__init__.py Normal file
View file

@ -0,0 +1,402 @@
from .error import *
from .tokens import *
from .events import *
from .nodes import *
from .loader import *
from .dumper import *
__version__ = '5.1'
try:
from .cyaml import *
__with_libyaml__ = True
except ImportError:
__with_libyaml__ = False
import io
#------------------------------------------------------------------------------
# Warnings control
#------------------------------------------------------------------------------
# 'Global' warnings state:
_warnings_enabled = {
'YAMLLoadWarning': True,
}
# Get or set global warnings' state
def warnings(settings=None):
if settings is None:
return _warnings_enabled
if type(settings) is dict:
for key in settings:
if key in _warnings_enabled:
_warnings_enabled[key] = settings[key]
# Warn when load() is called without Loader=...
class YAMLLoadWarning(RuntimeWarning):
pass
def load_warning(method):
if _warnings_enabled['YAMLLoadWarning'] is False:
return
import warnings
message = (
"calling yaml.%s() without Loader=... is deprecated, as the "
"default Loader is unsafe. Please read "
"https://msg.pyyaml.org/load for full details."
) % method
warnings.warn(message, YAMLLoadWarning, stacklevel=3)
#------------------------------------------------------------------------------
def scan(stream, Loader=Loader):
"""
Scan a YAML stream and produce scanning tokens.
"""
loader = Loader(stream)
try:
while loader.check_token():
yield loader.get_token()
finally:
loader.dispose()
def parse(stream, Loader=Loader):
"""
Parse a YAML stream and produce parsing events.
"""
loader = Loader(stream)
try:
while loader.check_event():
yield loader.get_event()
finally:
loader.dispose()
def compose(stream, Loader=Loader):
"""
Parse the first YAML document in a stream
and produce the corresponding representation tree.
"""
loader = Loader(stream)
try:
return loader.get_single_node()
finally:
loader.dispose()
def compose_all(stream, Loader=Loader):
"""
Parse all YAML documents in a stream
and produce corresponding representation trees.
"""
loader = Loader(stream)
try:
while loader.check_node():
yield loader.get_node()
finally:
loader.dispose()
def load(stream, Loader=None):
"""
Parse the first YAML document in a stream
and produce the corresponding Python object.
"""
if Loader is None:
load_warning('load')
Loader = FullLoader
loader = Loader(stream)
try:
return loader.get_single_data()
finally:
loader.dispose()
def load_all(stream, Loader=None):
"""
Parse all YAML documents in a stream
and produce corresponding Python objects.
"""
if Loader is None:
load_warning('load_all')
Loader = FullLoader
loader = Loader(stream)
try:
while loader.check_data():
yield loader.get_data()
finally:
loader.dispose()
def full_load(stream):
"""
Parse the first YAML document in a stream
and produce the corresponding Python object.
Resolve all tags except those known to be
unsafe on untrusted input.
"""
return load(stream, FullLoader)
def full_load_all(stream):
"""
Parse all YAML documents in a stream
and produce corresponding Python objects.
Resolve all tags except those known to be
unsafe on untrusted input.
"""
return load_all(stream, FullLoader)
def safe_load(stream):
"""
Parse the first YAML document in a stream
and produce the corresponding Python object.
Resolve only basic YAML tags. This is known
to be safe for untrusted input.
"""
return load(stream, SafeLoader)
def safe_load_all(stream):
"""
Parse all YAML documents in a stream
and produce corresponding Python objects.
Resolve only basic YAML tags. This is known
to be safe for untrusted input.
"""
return load_all(stream, SafeLoader)
def unsafe_load(stream):
"""
Parse the first YAML document in a stream
and produce the corresponding Python object.
Resolve all tags, even those known to be
unsafe on untrusted input.
"""
return load(stream, UnsafeLoader)
def unsafe_load_all(stream):
"""
Parse all YAML documents in a stream
and produce corresponding Python objects.
Resolve all tags, even those known to be
unsafe on untrusted input.
"""
return load_all(stream, UnsafeLoader)
def emit(events, stream=None, Dumper=Dumper,
canonical=None, indent=None, width=None,
allow_unicode=None, line_break=None):
"""
Emit YAML parsing events into a stream.
If stream is None, return the produced string instead.
"""
getvalue = None
if stream is None:
stream = io.StringIO()
getvalue = stream.getvalue
dumper = Dumper(stream, canonical=canonical, indent=indent, width=width,
allow_unicode=allow_unicode, line_break=line_break)
try:
for event in events:
dumper.emit(event)
finally:
dumper.dispose()
if getvalue:
return getvalue()
def serialize_all(nodes, stream=None, Dumper=Dumper,
canonical=None, indent=None, width=None,
allow_unicode=None, line_break=None,
encoding=None, explicit_start=None, explicit_end=None,
version=None, tags=None):
"""
Serialize a sequence of representation trees into a YAML stream.
If stream is None, return the produced string instead.
"""
getvalue = None
if stream is None:
if encoding is None:
stream = io.StringIO()
else:
stream = io.BytesIO()
getvalue = stream.getvalue
dumper = Dumper(stream, canonical=canonical, indent=indent, width=width,
allow_unicode=allow_unicode, line_break=line_break,
encoding=encoding, version=version, tags=tags,
explicit_start=explicit_start, explicit_end=explicit_end)
try:
dumper.open()
for node in nodes:
dumper.serialize(node)
dumper.close()
finally:
dumper.dispose()
if getvalue:
return getvalue()
def serialize(node, stream=None, Dumper=Dumper, **kwds):
"""
Serialize a representation tree into a YAML stream.
If stream is None, return the produced string instead.
"""
return serialize_all([node], stream, Dumper=Dumper, **kwds)
def dump_all(documents, stream=None, Dumper=Dumper,
default_style=None, default_flow_style=False,
canonical=None, indent=None, width=None,
allow_unicode=None, line_break=None,
encoding=None, explicit_start=None, explicit_end=None,
version=None, tags=None, sort_keys=True):
"""
Serialize a sequence of Python objects into a YAML stream.
If stream is None, return the produced string instead.
"""
getvalue = None
if stream is None:
if encoding is None:
stream = io.StringIO()
else:
stream = io.BytesIO()
getvalue = stream.getvalue
dumper = Dumper(stream, default_style=default_style,
default_flow_style=default_flow_style,
canonical=canonical, indent=indent, width=width,
allow_unicode=allow_unicode, line_break=line_break,
encoding=encoding, version=version, tags=tags,
explicit_start=explicit_start, explicit_end=explicit_end, sort_keys=sort_keys)
try:
dumper.open()
for data in documents:
dumper.represent(data)
dumper.close()
finally:
dumper.dispose()
if getvalue:
return getvalue()
def dump(data, stream=None, Dumper=Dumper, **kwds):
"""
Serialize a Python object into a YAML stream.
If stream is None, return the produced string instead.
"""
return dump_all([data], stream, Dumper=Dumper, **kwds)
def safe_dump_all(documents, stream=None, **kwds):
"""
Serialize a sequence of Python objects into a YAML stream.
Produce only basic YAML tags.
If stream is None, return the produced string instead.
"""
return dump_all(documents, stream, Dumper=SafeDumper, **kwds)
def safe_dump(data, stream=None, **kwds):
"""
Serialize a Python object into a YAML stream.
Produce only basic YAML tags.
If stream is None, return the produced string instead.
"""
return dump_all([data], stream, Dumper=SafeDumper, **kwds)
def add_implicit_resolver(tag, regexp, first=None,
Loader=Loader, Dumper=Dumper):
"""
Add an implicit scalar detector.
If an implicit scalar value matches the given regexp,
the corresponding tag is assigned to the scalar.
first is a sequence of possible initial characters or None.
"""
Loader.add_implicit_resolver(tag, regexp, first)
Dumper.add_implicit_resolver(tag, regexp, first)
def add_path_resolver(tag, path, kind=None, Loader=Loader, Dumper=Dumper):
"""
Add a path based resolver for the given tag.
A path is a list of keys that forms a path
to a node in the representation tree.
Keys can be string values, integers, or None.
"""
Loader.add_path_resolver(tag, path, kind)
Dumper.add_path_resolver(tag, path, kind)
def add_constructor(tag, constructor, Loader=Loader):
"""
Add a constructor for the given tag.
Constructor is a function that accepts a Loader instance
and a node object and produces the corresponding Python object.
"""
Loader.add_constructor(tag, constructor)
def add_multi_constructor(tag_prefix, multi_constructor, Loader=Loader):
"""
Add a multi-constructor for the given tag prefix.
Multi-constructor is called for a node if its tag starts with tag_prefix.
Multi-constructor accepts a Loader instance, a tag suffix,
and a node object and produces the corresponding Python object.
"""
Loader.add_multi_constructor(tag_prefix, multi_constructor)
def add_representer(data_type, representer, Dumper=Dumper):
"""
Add a representer for the given type.
Representer is a function accepting a Dumper instance
and an instance of the given data type
and producing the corresponding representation node.
"""
Dumper.add_representer(data_type, representer)
def add_multi_representer(data_type, multi_representer, Dumper=Dumper):
"""
Add a representer for the given type.
Multi-representer is a function accepting a Dumper instance
and an instance of the given data type or subtype
and producing the corresponding representation node.
"""
Dumper.add_multi_representer(data_type, multi_representer)
class YAMLObjectMetaclass(type):
"""
The metaclass for YAMLObject.
"""
def __init__(cls, name, bases, kwds):
super(YAMLObjectMetaclass, cls).__init__(name, bases, kwds)
if 'yaml_tag' in kwds and kwds['yaml_tag'] is not None:
cls.yaml_loader.add_constructor(cls.yaml_tag, cls.from_yaml)
cls.yaml_dumper.add_representer(cls, cls.to_yaml)
class YAMLObject(metaclass=YAMLObjectMetaclass):
"""
An object that can dump itself to a YAML stream
and load itself from a YAML stream.
"""
__slots__ = () # no direct instantiation, so allow immutable subclasses
yaml_loader = Loader
yaml_dumper = Dumper
yaml_tag = None
yaml_flow_style = None
@classmethod
def from_yaml(cls, loader, node):
"""
Convert a representation node to a Python object.
"""
return loader.construct_yaml_object(node, cls)
@classmethod
def to_yaml(cls, dumper, data):
"""
Convert a Python object to a representation node.
"""
return dumper.represent_yaml_object(cls.yaml_tag, data, cls,
flow_style=cls.yaml_flow_style)

139
libs/yaml/composer.py Normal file
View file

@ -0,0 +1,139 @@
__all__ = ['Composer', 'ComposerError']
from .error import MarkedYAMLError
from .events import *
from .nodes import *
class ComposerError(MarkedYAMLError):
pass
class Composer:
def __init__(self):
self.anchors = {}
def check_node(self):
# Drop the STREAM-START event.
if self.check_event(StreamStartEvent):
self.get_event()
# If there are more documents available?
return not self.check_event(StreamEndEvent)
def get_node(self):
# Get the root node of the next document.
if not self.check_event(StreamEndEvent):
return self.compose_document()
def get_single_node(self):
# Drop the STREAM-START event.
self.get_event()
# Compose a document if the stream is not empty.
document = None
if not self.check_event(StreamEndEvent):
document = self.compose_document()
# Ensure that the stream contains no more documents.
if not self.check_event(StreamEndEvent):
event = self.get_event()
raise ComposerError("expected a single document in the stream",
document.start_mark, "but found another document",
event.start_mark)
# Drop the STREAM-END event.
self.get_event()
return document
def compose_document(self):
# Drop the DOCUMENT-START event.
self.get_event()
# Compose the root node.
node = self.compose_node(None, None)
# Drop the DOCUMENT-END event.
self.get_event()
self.anchors = {}
return node
def compose_node(self, parent, index):
if self.check_event(AliasEvent):
event = self.get_event()
anchor = event.anchor
if anchor not in self.anchors:
raise ComposerError(None, None, "found undefined alias %r"
% anchor, event.start_mark)
return self.anchors[anchor]
event = self.peek_event()
anchor = event.anchor
if anchor is not None:
if anchor in self.anchors:
raise ComposerError("found duplicate anchor %r; first occurrence"
% anchor, self.anchors[anchor].start_mark,
"second occurrence", event.start_mark)
self.descend_resolver(parent, index)
if self.check_event(ScalarEvent):
node = self.compose_scalar_node(anchor)
elif self.check_event(SequenceStartEvent):
node = self.compose_sequence_node(anchor)
elif self.check_event(MappingStartEvent):
node = self.compose_mapping_node(anchor)
self.ascend_resolver()
return node
def compose_scalar_node(self, anchor):
event = self.get_event()
tag = event.tag
if tag is None or tag == '!':
tag = self.resolve(ScalarNode, event.value, event.implicit)
node = ScalarNode(tag, event.value,
event.start_mark, event.end_mark, style=event.style)
if anchor is not None:
self.anchors[anchor] = node
return node
def compose_sequence_node(self, anchor):
start_event = self.get_event()
tag = start_event.tag
if tag is None or tag == '!':
tag = self.resolve(SequenceNode, None, start_event.implicit)
node = SequenceNode(tag, [],
start_event.start_mark, None,
flow_style=start_event.flow_style)
if anchor is not None:
self.anchors[anchor] = node
index = 0
while not self.check_event(SequenceEndEvent):
node.value.append(self.compose_node(node, index))
index += 1
end_event = self.get_event()
node.end_mark = end_event.end_mark
return node
def compose_mapping_node(self, anchor):
start_event = self.get_event()
tag = start_event.tag
if tag is None or tag == '!':
tag = self.resolve(MappingNode, None, start_event.implicit)
node = MappingNode(tag, [],
start_event.start_mark, None,
flow_style=start_event.flow_style)
if anchor is not None:
self.anchors[anchor] = node
while not self.check_event(MappingEndEvent):
#key_event = self.peek_event()
item_key = self.compose_node(node, None)
#if item_key in node.value:
# raise ComposerError("while composing a mapping", start_event.start_mark,
# "found duplicate key", key_event.start_mark)
item_value = self.compose_node(node, item_key)
#node.value[item_key] = item_value
node.value.append((item_key, item_value))
end_event = self.get_event()
node.end_mark = end_event.end_mark
return node

720
libs/yaml/constructor.py Normal file
View file

@ -0,0 +1,720 @@
__all__ = [
'BaseConstructor',
'SafeConstructor',
'FullConstructor',
'UnsafeConstructor',
'Constructor',
'ConstructorError'
]
from .error import *
from .nodes import *
import collections.abc, datetime, base64, binascii, re, sys, types
class ConstructorError(MarkedYAMLError):
pass
class BaseConstructor:
yaml_constructors = {}
yaml_multi_constructors = {}
def __init__(self):
self.constructed_objects = {}
self.recursive_objects = {}
self.state_generators = []
self.deep_construct = False
def check_data(self):
# If there are more documents available?
return self.check_node()
def get_data(self):
# Construct and return the next document.
if self.check_node():
return self.construct_document(self.get_node())
def get_single_data(self):
# Ensure that the stream contains a single document and construct it.
node = self.get_single_node()
if node is not None:
return self.construct_document(node)
return None
def construct_document(self, node):
data = self.construct_object(node)
while self.state_generators:
state_generators = self.state_generators
self.state_generators = []
for generator in state_generators:
for dummy in generator:
pass
self.constructed_objects = {}
self.recursive_objects = {}
self.deep_construct = False
return data
def construct_object(self, node, deep=False):
if node in self.constructed_objects:
return self.constructed_objects[node]
if deep:
old_deep = self.deep_construct
self.deep_construct = True
if node in self.recursive_objects:
raise ConstructorError(None, None,
"found unconstructable recursive node", node.start_mark)
self.recursive_objects[node] = None
constructor = None
tag_suffix = None
if node.tag in self.yaml_constructors:
constructor = self.yaml_constructors[node.tag]
else:
for tag_prefix in self.yaml_multi_constructors:
if node.tag.startswith(tag_prefix):
tag_suffix = node.tag[len(tag_prefix):]
constructor = self.yaml_multi_constructors[tag_prefix]
break
else:
if None in self.yaml_multi_constructors:
tag_suffix = node.tag
constructor = self.yaml_multi_constructors[None]
elif None in self.yaml_constructors:
constructor = self.yaml_constructors[None]
elif isinstance(node, ScalarNode):
constructor = self.__class__.construct_scalar
elif isinstance(node, SequenceNode):
constructor = self.__class__.construct_sequence
elif isinstance(node, MappingNode):
constructor = self.__class__.construct_mapping
if tag_suffix is None:
data = constructor(self, node)
else:
data = constructor(self, tag_suffix, node)
if isinstance(data, types.GeneratorType):
generator = data
data = next(generator)
if self.deep_construct:
for dummy in generator:
pass
else:
self.state_generators.append(generator)
self.constructed_objects[node] = data
del self.recursive_objects[node]
if deep:
self.deep_construct = old_deep
return data
def construct_scalar(self, node):
if not isinstance(node, ScalarNode):
raise ConstructorError(None, None,
"expected a scalar node, but found %s" % node.id,
node.start_mark)
return node.value
def construct_sequence(self, node, deep=False):
if not isinstance(node, SequenceNode):
raise ConstructorError(None, None,
"expected a sequence node, but found %s" % node.id,
node.start_mark)
return [self.construct_object(child, deep=deep)
for child in node.value]
def construct_mapping(self, node, deep=False):
if not isinstance(node, MappingNode):
raise ConstructorError(None, None,
"expected a mapping node, but found %s" % node.id,
node.start_mark)
mapping = {}
for key_node, value_node in node.value:
key = self.construct_object(key_node, deep=deep)
if not isinstance(key, collections.abc.Hashable):
raise ConstructorError("while constructing a mapping", node.start_mark,
"found unhashable key", key_node.start_mark)
value = self.construct_object(value_node, deep=deep)
mapping[key] = value
return mapping
def construct_pairs(self, node, deep=False):
if not isinstance(node, MappingNode):
raise ConstructorError(None, None,
"expected a mapping node, but found %s" % node.id,
node.start_mark)
pairs = []
for key_node, value_node in node.value:
key = self.construct_object(key_node, deep=deep)
value = self.construct_object(value_node, deep=deep)
pairs.append((key, value))
return pairs
@classmethod
def add_constructor(cls, tag, constructor):
if not 'yaml_constructors' in cls.__dict__:
cls.yaml_constructors = cls.yaml_constructors.copy()
cls.yaml_constructors[tag] = constructor
@classmethod
def add_multi_constructor(cls, tag_prefix, multi_constructor):
if not 'yaml_multi_constructors' in cls.__dict__:
cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy()
cls.yaml_multi_constructors[tag_prefix] = multi_constructor
class SafeConstructor(BaseConstructor):
def construct_scalar(self, node):
if isinstance(node, MappingNode):
for key_node, value_node in node.value:
if key_node.tag == 'tag:yaml.org,2002:value':
return self.construct_scalar(value_node)
return super().construct_scalar(node)
def flatten_mapping(self, node):
merge = []
index = 0
while index < len(node.value):
key_node, value_node = node.value[index]
if key_node.tag == 'tag:yaml.org,2002:merge':
del node.value[index]
if isinstance(value_node, MappingNode):
self.flatten_mapping(value_node)
merge.extend(value_node.value)
elif isinstance(value_node, SequenceNode):
submerge = []
for subnode in value_node.value:
if not isinstance(subnode, MappingNode):
raise ConstructorError("while constructing a mapping",
node.start_mark,
"expected a mapping for merging, but found %s"
% subnode.id, subnode.start_mark)
self.flatten_mapping(subnode)
submerge.append(subnode.value)
submerge.reverse()
for value in submerge:
merge.extend(value)
else:
raise ConstructorError("while constructing a mapping", node.start_mark,
"expected a mapping or list of mappings for merging, but found %s"
% value_node.id, value_node.start_mark)
elif key_node.tag == 'tag:yaml.org,2002:value':
key_node.tag = 'tag:yaml.org,2002:str'
index += 1
else:
index += 1
if merge:
node.value = merge + node.value
def construct_mapping(self, node, deep=False):
if isinstance(node, MappingNode):
self.flatten_mapping(node)
return super().construct_mapping(node, deep=deep)
def construct_yaml_null(self, node):
self.construct_scalar(node)
return None
bool_values = {
'yes': True,
'no': False,
'true': True,
'false': False,
'on': True,
'off': False,
}
def construct_yaml_bool(self, node):
value = self.construct_scalar(node)
return self.bool_values[value.lower()]
def construct_yaml_int(self, node):
value = self.construct_scalar(node)
value = value.replace('_', '')
sign = +1
if value[0] == '-':
sign = -1
if value[0] in '+-':
value = value[1:]
if value == '0':
return 0
elif value.startswith('0b'):
return sign*int(value[2:], 2)
elif value.startswith('0x'):
return sign*int(value[2:], 16)
elif value[0] == '0':
return sign*int(value, 8)
elif ':' in value:
digits = [int(part) for part in value.split(':')]
digits.reverse()
base = 1
value = 0
for digit in digits:
value += digit*base
base *= 60
return sign*value
else:
return sign*int(value)
inf_value = 1e300
while inf_value != inf_value*inf_value:
inf_value *= inf_value
nan_value = -inf_value/inf_value # Trying to make a quiet NaN (like C99).
def construct_yaml_float(self, node):
value = self.construct_scalar(node)
value = value.replace('_', '').lower()
sign = +1
if value[0] == '-':
sign = -1
if value[0] in '+-':
value = value[1:]
if value == '.inf':
return sign*self.inf_value
elif value == '.nan':
return self.nan_value
elif ':' in value:
digits = [float(part) for part in value.split(':')]
digits.reverse()
base = 1
value = 0.0
for digit in digits:
value += digit*base
base *= 60
return sign*value
else:
return sign*float(value)
def construct_yaml_binary(self, node):
try:
value = self.construct_scalar(node).encode('ascii')
except UnicodeEncodeError as exc:
raise ConstructorError(None, None,
"failed to convert base64 data into ascii: %s" % exc,
node.start_mark)
try:
if hasattr(base64, 'decodebytes'):
return base64.decodebytes(value)
else:
return base64.decodestring(value)
except binascii.Error as exc:
raise ConstructorError(None, None,
"failed to decode base64 data: %s" % exc, node.start_mark)
timestamp_regexp = re.compile(
r'''^(?P<year>[0-9][0-9][0-9][0-9])
-(?P<month>[0-9][0-9]?)
-(?P<day>[0-9][0-9]?)
(?:(?:[Tt]|[ \t]+)
(?P<hour>[0-9][0-9]?)
:(?P<minute>[0-9][0-9])
:(?P<second>[0-9][0-9])
(?:\.(?P<fraction>[0-9]*))?
(?:[ \t]*(?P<tz>Z|(?P<tz_sign>[-+])(?P<tz_hour>[0-9][0-9]?)
(?::(?P<tz_minute>[0-9][0-9]))?))?)?$''', re.X)
def construct_yaml_timestamp(self, node):
value = self.construct_scalar(node)
match = self.timestamp_regexp.match(node.value)
values = match.groupdict()
year = int(values['year'])
month = int(values['month'])
day = int(values['day'])
if not values['hour']:
return datetime.date(year, month, day)
hour = int(values['hour'])
minute = int(values['minute'])
second = int(values['second'])
fraction = 0
if values['fraction']:
fraction = values['fraction'][:6]
while len(fraction) < 6:
fraction += '0'
fraction = int(fraction)
delta = None
if values['tz_sign']:
tz_hour = int(values['tz_hour'])
tz_minute = int(values['tz_minute'] or 0)
delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute)
if values['tz_sign'] == '-':
delta = -delta
data = datetime.datetime(year, month, day, hour, minute, second, fraction)
if delta:
data -= delta
return data
def construct_yaml_omap(self, node):
# Note: we do not check for duplicate keys, because it's too
# CPU-expensive.
omap = []
yield omap
if not isinstance(node, SequenceNode):
raise ConstructorError("while constructing an ordered map", node.start_mark,
"expected a sequence, but found %s" % node.id, node.start_mark)
for subnode in node.value:
if not isinstance(subnode, MappingNode):
raise ConstructorError("while constructing an ordered map", node.start_mark,
"expected a mapping of length 1, but found %s" % subnode.id,
subnode.start_mark)
if len(subnode.value) != 1:
raise ConstructorError("while constructing an ordered map", node.start_mark,
"expected a single mapping item, but found %d items" % len(subnode.value),
subnode.start_mark)
key_node, value_node = subnode.value[0]
key = self.construct_object(key_node)
value = self.construct_object(value_node)
omap.append((key, value))
def construct_yaml_pairs(self, node):
# Note: the same code as `construct_yaml_omap`.
pairs = []
yield pairs
if not isinstance(node, SequenceNode):
raise ConstructorError("while constructing pairs", node.start_mark,
"expected a sequence, but found %s" % node.id, node.start_mark)
for subnode in node.value:
if not isinstance(subnode, MappingNode):
raise ConstructorError("while constructing pairs", node.start_mark,
"expected a mapping of length 1, but found %s" % subnode.id,
subnode.start_mark)
if len(subnode.value) != 1:
raise ConstructorError("while constructing pairs", node.start_mark,
"expected a single mapping item, but found %d items" % len(subnode.value),
subnode.start_mark)
key_node, value_node = subnode.value[0]
key = self.construct_object(key_node)
value = self.construct_object(value_node)
pairs.append((key, value))
def construct_yaml_set(self, node):
data = set()
yield data
value = self.construct_mapping(node)
data.update(value)
def construct_yaml_str(self, node):
return self.construct_scalar(node)
def construct_yaml_seq(self, node):
data = []
yield data
data.extend(self.construct_sequence(node))
def construct_yaml_map(self, node):
data = {}
yield data
value = self.construct_mapping(node)
data.update(value)
def construct_yaml_object(self, node, cls):
data = cls.__new__(cls)
yield data
if hasattr(data, '__setstate__'):
state = self.construct_mapping(node, deep=True)
data.__setstate__(state)
else:
state = self.construct_mapping(node)
data.__dict__.update(state)
def construct_undefined(self, node):
raise ConstructorError(None, None,
"could not determine a constructor for the tag %r" % node.tag,
node.start_mark)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:null',
SafeConstructor.construct_yaml_null)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:bool',
SafeConstructor.construct_yaml_bool)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:int',
SafeConstructor.construct_yaml_int)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:float',
SafeConstructor.construct_yaml_float)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:binary',
SafeConstructor.construct_yaml_binary)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:timestamp',
SafeConstructor.construct_yaml_timestamp)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:omap',
SafeConstructor.construct_yaml_omap)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:pairs',
SafeConstructor.construct_yaml_pairs)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:set',
SafeConstructor.construct_yaml_set)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:str',
SafeConstructor.construct_yaml_str)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:seq',
SafeConstructor.construct_yaml_seq)
SafeConstructor.add_constructor(
'tag:yaml.org,2002:map',
SafeConstructor.construct_yaml_map)
SafeConstructor.add_constructor(None,
SafeConstructor.construct_undefined)
class FullConstructor(SafeConstructor):
def construct_python_str(self, node):
return self.construct_scalar(node)
def construct_python_unicode(self, node):
return self.construct_scalar(node)
def construct_python_bytes(self, node):
try:
value = self.construct_scalar(node).encode('ascii')
except UnicodeEncodeError as exc:
raise ConstructorError(None, None,
"failed to convert base64 data into ascii: %s" % exc,
node.start_mark)
try:
if hasattr(base64, 'decodebytes'):
return base64.decodebytes(value)
else:
return base64.decodestring(value)
except binascii.Error as exc:
raise ConstructorError(None, None,
"failed to decode base64 data: %s" % exc, node.start_mark)
def construct_python_long(self, node):
return self.construct_yaml_int(node)
def construct_python_complex(self, node):
return complex(self.construct_scalar(node))
def construct_python_tuple(self, node):
return tuple(self.construct_sequence(node))
def find_python_module(self, name, mark, unsafe=False):
if not name:
raise ConstructorError("while constructing a Python module", mark,
"expected non-empty name appended to the tag", mark)
if unsafe:
try:
__import__(name)
except ImportError as exc:
raise ConstructorError("while constructing a Python module", mark,
"cannot find module %r (%s)" % (name, exc), mark)
if not name in sys.modules:
raise ConstructorError("while constructing a Python module", mark,
"module %r is not imported" % name, mark)
return sys.modules[name]
def find_python_name(self, name, mark, unsafe=False):
if not name:
raise ConstructorError("while constructing a Python object", mark,
"expected non-empty name appended to the tag", mark)
if '.' in name:
module_name, object_name = name.rsplit('.', 1)
else:
module_name = 'builtins'
object_name = name
if unsafe:
try:
__import__(module_name)
except ImportError as exc:
raise ConstructorError("while constructing a Python object", mark,
"cannot find module %r (%s)" % (module_name, exc), mark)
if not module_name in sys.modules:
raise ConstructorError("while constructing a Python object", mark,
"module %r is not imported" % module_name, mark)
module = sys.modules[module_name]
if not hasattr(module, object_name):
raise ConstructorError("while constructing a Python object", mark,
"cannot find %r in the module %r"
% (object_name, module.__name__), mark)
return getattr(module, object_name)
def construct_python_name(self, suffix, node):
value = self.construct_scalar(node)
if value:
raise ConstructorError("while constructing a Python name", node.start_mark,
"expected the empty value, but found %r" % value, node.start_mark)
return self.find_python_name(suffix, node.start_mark)
def construct_python_module(self, suffix, node):
value = self.construct_scalar(node)
if value:
raise ConstructorError("while constructing a Python module", node.start_mark,
"expected the empty value, but found %r" % value, node.start_mark)
return self.find_python_module(suffix, node.start_mark)
def make_python_instance(self, suffix, node,
args=None, kwds=None, newobj=False, unsafe=False):
if not args:
args = []
if not kwds:
kwds = {}
cls = self.find_python_name(suffix, node.start_mark)
if not (unsafe or isinstance(cls, type)):
raise ConstructorError("while constructing a Python instance", node.start_mark,
"expected a class, but found %r" % type(cls),
node.start_mark)
if newobj and isinstance(cls, type):
return cls.__new__(cls, *args, **kwds)
else:
return cls(*args, **kwds)
def set_python_instance_state(self, instance, state):
if hasattr(instance, '__setstate__'):
instance.__setstate__(state)
else:
slotstate = {}
if isinstance(state, tuple) and len(state) == 2:
state, slotstate = state
if hasattr(instance, '__dict__'):
instance.__dict__.update(state)
elif state:
slotstate.update(state)
for key, value in slotstate.items():
setattr(object, key, value)
def construct_python_object(self, suffix, node):
# Format:
# !!python/object:module.name { ... state ... }
instance = self.make_python_instance(suffix, node, newobj=True)
yield instance
deep = hasattr(instance, '__setstate__')
state = self.construct_mapping(node, deep=deep)
self.set_python_instance_state(instance, state)
def construct_python_object_apply(self, suffix, node, newobj=False):
# Format:
# !!python/object/apply # (or !!python/object/new)
# args: [ ... arguments ... ]
# kwds: { ... keywords ... }
# state: ... state ...
# listitems: [ ... listitems ... ]
# dictitems: { ... dictitems ... }
# or short format:
# !!python/object/apply [ ... arguments ... ]
# The difference between !!python/object/apply and !!python/object/new
# is how an object is created, check make_python_instance for details.
if isinstance(node, SequenceNode):
args = self.construct_sequence(node, deep=True)
kwds = {}
state = {}
listitems = []
dictitems = {}
else:
value = self.construct_mapping(node, deep=True)
args = value.get('args', [])
kwds = value.get('kwds', {})
state = value.get('state', {})
listitems = value.get('listitems', [])
dictitems = value.get('dictitems', {})
instance = self.make_python_instance(suffix, node, args, kwds, newobj)
if state:
self.set_python_instance_state(instance, state)
if listitems:
instance.extend(listitems)
if dictitems:
for key in dictitems:
instance[key] = dictitems[key]
return instance
def construct_python_object_new(self, suffix, node):
return self.construct_python_object_apply(suffix, node, newobj=True)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/none',
FullConstructor.construct_yaml_null)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/bool',
FullConstructor.construct_yaml_bool)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/str',
FullConstructor.construct_python_str)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/unicode',
FullConstructor.construct_python_unicode)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/bytes',
FullConstructor.construct_python_bytes)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/int',
FullConstructor.construct_yaml_int)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/long',
FullConstructor.construct_python_long)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/float',
FullConstructor.construct_yaml_float)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/complex',
FullConstructor.construct_python_complex)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/list',
FullConstructor.construct_yaml_seq)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/tuple',
FullConstructor.construct_python_tuple)
FullConstructor.add_constructor(
'tag:yaml.org,2002:python/dict',
FullConstructor.construct_yaml_map)
FullConstructor.add_multi_constructor(
'tag:yaml.org,2002:python/name:',
FullConstructor.construct_python_name)
FullConstructor.add_multi_constructor(
'tag:yaml.org,2002:python/module:',
FullConstructor.construct_python_module)
FullConstructor.add_multi_constructor(
'tag:yaml.org,2002:python/object:',
FullConstructor.construct_python_object)
FullConstructor.add_multi_constructor(
'tag:yaml.org,2002:python/object/apply:',
FullConstructor.construct_python_object_apply)
FullConstructor.add_multi_constructor(
'tag:yaml.org,2002:python/object/new:',
FullConstructor.construct_python_object_new)
class UnsafeConstructor(FullConstructor):
def find_python_module(self, name, mark):
return super(UnsafeConstructor, self).find_python_module(name, mark, unsafe=True)
def find_python_name(self, name, mark):
return super(UnsafeConstructor, self).find_python_name(name, mark, unsafe=True)
def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False):
return super(UnsafeConstructor, self).make_python_instance(
suffix, node, args, kwds, newobj, unsafe=True)
# Constructor is same as UnsafeConstructor. Need to leave this in place in case
# people have extended it directly.
class Constructor(UnsafeConstructor):
pass

101
libs/yaml/cyaml.py Normal file
View file

@ -0,0 +1,101 @@
__all__ = [
'CBaseLoader', 'CSafeLoader', 'CFullLoader', 'CUnsafeLoader', 'CLoader',
'CBaseDumper', 'CSafeDumper', 'CDumper'
]
from _yaml import CParser, CEmitter
from .constructor import *
from .serializer import *
from .representer import *
from .resolver import *
class CBaseLoader(CParser, BaseConstructor, BaseResolver):
def __init__(self, stream):
CParser.__init__(self, stream)
BaseConstructor.__init__(self)
BaseResolver.__init__(self)
class CSafeLoader(CParser, SafeConstructor, Resolver):
def __init__(self, stream):
CParser.__init__(self, stream)
SafeConstructor.__init__(self)
Resolver.__init__(self)
class CFullLoader(CParser, FullConstructor, Resolver):
def __init__(self, stream):
CParser.__init__(self, stream)
FullConstructor.__init__(self)
Resolver.__init__(self)
class CUnsafeLoader(CParser, UnsafeConstructor, Resolver):
def __init__(self, stream):
CParser.__init__(self, stream)
UnsafeConstructor.__init__(self)
Resolver.__init__(self)
class CLoader(CParser, Constructor, Resolver):
def __init__(self, stream):
CParser.__init__(self, stream)
Constructor.__init__(self)
Resolver.__init__(self)
class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver):
def __init__(self, stream,
default_style=None, default_flow_style=False,
canonical=None, indent=None, width=None,
allow_unicode=None, line_break=None,
encoding=None, explicit_start=None, explicit_end=None,
version=None, tags=None, sort_keys=True):
CEmitter.__init__(self, stream, canonical=canonical,
indent=indent, width=width, encoding=encoding,
allow_unicode=allow_unicode, line_break=line_break,
explicit_start=explicit_start, explicit_end=explicit_end,
version=version, tags=tags)
Representer.__init__(self, default_style=default_style,
default_flow_style=default_flow_style, sort_keys=sort_keys)
Resolver.__init__(self)
class CSafeDumper(CEmitter, SafeRepresenter, Resolver):
def __init__(self, stream,
default_style=None, default_flow_style=False,
canonical=None, indent=None, width=None,
allow_unicode=None, line_break=None,
encoding=None, explicit_start=None, explicit_end=None,
version=None, tags=None, sort_keys=True):
CEmitter.__init__(self, stream, canonical=canonical,
indent=indent, width=width, encoding=encoding,
allow_unicode=allow_unicode, line_break=line_break,
explicit_start=explicit_start, explicit_end=explicit_end,
version=version, tags=tags)
SafeRepresenter.__init__(self, default_style=default_style,
default_flow_style=default_flow_style, sort_keys=sort_keys)
Resolver.__init__(self)
class CDumper(CEmitter, Serializer, Representer, Resolver):
def __init__(self, stream,
default_style=None, default_flow_style=False,
canonical=None, indent=None, width=None,
allow_unicode=None, line_break=None,
encoding=None, explicit_start=None, explicit_end=None,
version=None, tags=None, sort_keys=True):
CEmitter.__init__(self, stream, canonical=canonical,
indent=indent, width=width, encoding=encoding,
allow_unicode=allow_unicode, line_break=line_break,
explicit_start=explicit_start, explicit_end=explicit_end,
version=version, tags=tags)
Representer.__init__(self, default_style=default_style,
default_flow_style=default_flow_style, sort_keys=sort_keys)
Resolver.__init__(self)

62
libs/yaml/dumper.py Normal file
View file

@ -0,0 +1,62 @@
__all__ = ['BaseDumper', 'SafeDumper', 'Dumper']
from .emitter import *
from .serializer import *
from .representer import *
from .resolver import *
class BaseDumper(Emitter, Serializer, BaseRepresenter, BaseResolver):
def __init__(self, stream,
default_style=None, default_flow_style=False,
canonical=None, indent=None, width=None,
allow_unicode=None, line_break=None,
encoding=None, explicit_start=None, explicit_end=None,
version=None, tags=None, sort_keys=True):
Emitter.__init__(self, stream, canonical=canonical,
indent=indent, width=width,
allow_unicode=allow_unicode, line_break=line_break)
Serializer.__init__(self, encoding=encoding,
explicit_start=explicit_start, explicit_end=explicit_end,
version=version, tags=tags)
Representer.__init__(self, default_style=default_style,
default_flow_style=default_flow_style, sort_keys=sort_keys)
Resolver.__init__(self)
class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver):
def __init__(self, stream,
default_style=None, default_flow_style=False,
canonical=None, indent=None, width=None,
allow_unicode=None, line_break=None,
encoding=None, explicit_start=None, explicit_end=None,
version=None, tags=None, sort_keys=True):
Emitter.__init__(self, stream, canonical=canonical,
indent=indent, width=width,
allow_unicode=allow_unicode, line_break=line_break)
Serializer.__init__(self, encoding=encoding,
explicit_start=explicit_start, explicit_end=explicit_end,
version=version, tags=tags)
SafeRepresenter.__init__(self, default_style=default_style,
default_flow_style=default_flow_style, sort_keys=sort_keys)
Resolver.__init__(self)
class Dumper(Emitter, Serializer, Representer, Resolver):
def __init__(self, stream,
default_style=None, default_flow_style=False,
canonical=None, indent=None, width=None,
allow_unicode=None, line_break=None,
encoding=None, explicit_start=None, explicit_end=None,
version=None, tags=None, sort_keys=True):
Emitter.__init__(self, stream, canonical=canonical,
indent=indent, width=width,
allow_unicode=allow_unicode, line_break=line_break)
Serializer.__init__(self, encoding=encoding,
explicit_start=explicit_start, explicit_end=explicit_end,
version=version, tags=tags)
Representer.__init__(self, default_style=default_style,
default_flow_style=default_flow_style, sort_keys=sort_keys)
Resolver.__init__(self)

1137
libs/yaml/emitter.py Normal file

File diff suppressed because it is too large Load diff

75
libs/yaml/error.py Normal file
View file

@ -0,0 +1,75 @@
__all__ = ['Mark', 'YAMLError', 'MarkedYAMLError']
class Mark:
def __init__(self, name, index, line, column, buffer, pointer):
self.name = name
self.index = index
self.line = line
self.column = column
self.buffer = buffer
self.pointer = pointer
def get_snippet(self, indent=4, max_length=75):
if self.buffer is None:
return None
head = ''
start = self.pointer
while start > 0 and self.buffer[start-1] not in '\0\r\n\x85\u2028\u2029':
start -= 1
if self.pointer-start > max_length/2-1:
head = ' ... '
start += 5
break
tail = ''
end = self.pointer
while end < len(self.buffer) and self.buffer[end] not in '\0\r\n\x85\u2028\u2029':
end += 1
if end-self.pointer > max_length/2-1:
tail = ' ... '
end -= 5
break
snippet = self.buffer[start:end]
return ' '*indent + head + snippet + tail + '\n' \
+ ' '*(indent+self.pointer-start+len(head)) + '^'
def __str__(self):
snippet = self.get_snippet()
where = " in \"%s\", line %d, column %d" \
% (self.name, self.line+1, self.column+1)
if snippet is not None:
where += ":\n"+snippet
return where
class YAMLError(Exception):
pass
class MarkedYAMLError(YAMLError):
def __init__(self, context=None, context_mark=None,
problem=None, problem_mark=None, note=None):
self.context = context
self.context_mark = context_mark
self.problem = problem
self.problem_mark = problem_mark
self.note = note
def __str__(self):
lines = []
if self.context is not None:
lines.append(self.context)
if self.context_mark is not None \
and (self.problem is None or self.problem_mark is None
or self.context_mark.name != self.problem_mark.name
or self.context_mark.line != self.problem_mark.line
or self.context_mark.column != self.problem_mark.column):
lines.append(str(self.context_mark))
if self.problem is not None:
lines.append(self.problem)
if self.problem_mark is not None:
lines.append(str(self.problem_mark))
if self.note is not None:
lines.append(self.note)
return '\n'.join(lines)

86
libs/yaml/events.py Normal file
View file

@ -0,0 +1,86 @@
# Abstract classes.
class Event(object):
def __init__(self, start_mark=None, end_mark=None):
self.start_mark = start_mark
self.end_mark = end_mark
def __repr__(self):
attributes = [key for key in ['anchor', 'tag', 'implicit', 'value']
if hasattr(self, key)]
arguments = ', '.join(['%s=%r' % (key, getattr(self, key))
for key in attributes])
return '%s(%s)' % (self.__class__.__name__, arguments)
class NodeEvent(Event):
def __init__(self, anchor, start_mark=None, end_mark=None):
self.anchor = anchor
self.start_mark = start_mark
self.end_mark = end_mark
class CollectionStartEvent(NodeEvent):
def __init__(self, anchor, tag, implicit, start_mark=None, end_mark=None,
flow_style=None):
self.anchor = anchor
self.tag = tag
self.implicit = implicit
self.start_mark = start_mark
self.end_mark = end_mark
self.flow_style = flow_style
class CollectionEndEvent(Event):
pass
# Implementations.
class StreamStartEvent(Event):
def __init__(self, start_mark=None, end_mark=None, encoding=None):
self.start_mark = start_mark
self.end_mark = end_mark
self.encoding = encoding
class StreamEndEvent(Event):
pass
class DocumentStartEvent(Event):
def __init__(self, start_mark=None, end_mark=None,
explicit=None, version=None, tags=None):
self.start_mark = start_mark
self.end_mark = end_mark
self.explicit = explicit
self.version = version
self.tags = tags
class DocumentEndEvent(Event):
def __init__(self, start_mark=None, end_mark=None,
explicit=None):
self.start_mark = start_mark
self.end_mark = end_mark
self.explicit = explicit
class AliasEvent(NodeEvent):
pass
class ScalarEvent(NodeEvent):
def __init__(self, anchor, tag, implicit, value,
start_mark=None, end_mark=None, style=None):
self.anchor = anchor
self.tag = tag
self.implicit = implicit
self.value = value
self.start_mark = start_mark
self.end_mark = end_mark
self.style = style
class SequenceStartEvent(CollectionStartEvent):
pass
class SequenceEndEvent(CollectionEndEvent):
pass
class MappingStartEvent(CollectionStartEvent):
pass
class MappingEndEvent(CollectionEndEvent):
pass

63
libs/yaml/loader.py Normal file
View file

@ -0,0 +1,63 @@
__all__ = ['BaseLoader', 'FullLoader', 'SafeLoader', 'Loader', 'UnsafeLoader']
from .reader import *
from .scanner import *
from .parser import *
from .composer import *
from .constructor import *
from .resolver import *
class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, BaseResolver):
def __init__(self, stream):
Reader.__init__(self, stream)
Scanner.__init__(self)
Parser.__init__(self)
Composer.__init__(self)
BaseConstructor.__init__(self)
BaseResolver.__init__(self)
class FullLoader(Reader, Scanner, Parser, Composer, FullConstructor, Resolver):
def __init__(self, stream):
Reader.__init__(self, stream)
Scanner.__init__(self)
Parser.__init__(self)
Composer.__init__(self)
FullConstructor.__init__(self)
Resolver.__init__(self)
class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, Resolver):
def __init__(self, stream):
Reader.__init__(self, stream)
Scanner.__init__(self)
Parser.__init__(self)
Composer.__init__(self)
SafeConstructor.__init__(self)
Resolver.__init__(self)
class Loader(Reader, Scanner, Parser, Composer, Constructor, Resolver):
def __init__(self, stream):
Reader.__init__(self, stream)
Scanner.__init__(self)
Parser.__init__(self)
Composer.__init__(self)
Constructor.__init__(self)
Resolver.__init__(self)
# UnsafeLoader is the same as Loader (which is and was always unsafe on
# untrusted input). Use of either Loader or UnsafeLoader should be rare, since
# FullLoad should be able to load almost all YAML safely. Loader is left intact
# to ensure backwards compatability.
class UnsafeLoader(Reader, Scanner, Parser, Composer, Constructor, Resolver):
def __init__(self, stream):
Reader.__init__(self, stream)
Scanner.__init__(self)
Parser.__init__(self)
Composer.__init__(self)
Constructor.__init__(self)
Resolver.__init__(self)

49
libs/yaml/nodes.py Normal file
View file

@ -0,0 +1,49 @@
class Node(object):
def __init__(self, tag, value, start_mark, end_mark):
self.tag = tag
self.value = value
self.start_mark = start_mark
self.end_mark = end_mark
def __repr__(self):
value = self.value
#if isinstance(value, list):
# if len(value) == 0:
# value = '<empty>'
# elif len(value) == 1:
# value = '<1 item>'
# else:
# value = '<%d items>' % len(value)
#else:
# if len(value) > 75:
# value = repr(value[:70]+u' ... ')
# else:
# value = repr(value)
value = repr(value)
return '%s(tag=%r, value=%s)' % (self.__class__.__name__, self.tag, value)
class ScalarNode(Node):
id = 'scalar'
def __init__(self, tag, value,
start_mark=None, end_mark=None, style=None):
self.tag = tag
self.value = value
self.start_mark = start_mark
self.end_mark = end_mark
self.style = style
class CollectionNode(Node):
def __init__(self, tag, value,
start_mark=None, end_mark=None, flow_style=None):
self.tag = tag
self.value = value
self.start_mark = start_mark
self.end_mark = end_mark
self.flow_style = flow_style
class SequenceNode(CollectionNode):
id = 'sequence'
class MappingNode(CollectionNode):
id = 'mapping'

589
libs/yaml/parser.py Normal file
View file

@ -0,0 +1,589 @@
# The following YAML grammar is LL(1) and is parsed by a recursive descent
# parser.
#
# stream ::= STREAM-START implicit_document? explicit_document* STREAM-END
# implicit_document ::= block_node DOCUMENT-END*
# explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END*
# block_node_or_indentless_sequence ::=
# ALIAS
# | properties (block_content | indentless_block_sequence)?
# | block_content
# | indentless_block_sequence
# block_node ::= ALIAS
# | properties block_content?
# | block_content
# flow_node ::= ALIAS
# | properties flow_content?
# | flow_content
# properties ::= TAG ANCHOR? | ANCHOR TAG?
# block_content ::= block_collection | flow_collection | SCALAR
# flow_content ::= flow_collection | SCALAR
# block_collection ::= block_sequence | block_mapping
# flow_collection ::= flow_sequence | flow_mapping
# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END
# indentless_sequence ::= (BLOCK-ENTRY block_node?)+
# block_mapping ::= BLOCK-MAPPING_START
# ((KEY block_node_or_indentless_sequence?)?
# (VALUE block_node_or_indentless_sequence?)?)*
# BLOCK-END
# flow_sequence ::= FLOW-SEQUENCE-START
# (flow_sequence_entry FLOW-ENTRY)*
# flow_sequence_entry?
# FLOW-SEQUENCE-END
# flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
# flow_mapping ::= FLOW-MAPPING-START
# (flow_mapping_entry FLOW-ENTRY)*
# flow_mapping_entry?
# FLOW-MAPPING-END
# flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
#
# FIRST sets:
#
# stream: { STREAM-START }
# explicit_document: { DIRECTIVE DOCUMENT-START }
# implicit_document: FIRST(block_node)
# block_node: { ALIAS TAG ANCHOR SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START }
# flow_node: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START }
# block_content: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR }
# flow_content: { FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR }
# block_collection: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START }
# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START }
# block_sequence: { BLOCK-SEQUENCE-START }
# block_mapping: { BLOCK-MAPPING-START }
# block_node_or_indentless_sequence: { ALIAS ANCHOR TAG SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START BLOCK-ENTRY }
# indentless_sequence: { ENTRY }
# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START }
# flow_sequence: { FLOW-SEQUENCE-START }
# flow_mapping: { FLOW-MAPPING-START }
# flow_sequence_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY }
# flow_mapping_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY }
__all__ = ['Parser', 'ParserError']
from .error import MarkedYAMLError
from .tokens import *
from .events import *
from .scanner import *
class ParserError(MarkedYAMLError):
pass
class Parser:
# Since writing a recursive-descendant parser is a straightforward task, we
# do not give many comments here.
DEFAULT_TAGS = {
'!': '!',
'!!': 'tag:yaml.org,2002:',
}
def __init__(self):
self.current_event = None
self.yaml_version = None
self.tag_handles = {}
self.states = []
self.marks = []
self.state = self.parse_stream_start
def dispose(self):
# Reset the state attributes (to clear self-references)
self.states = []
self.state = None
def check_event(self, *choices):
# Check the type of the next event.
if self.current_event is None:
if self.state:
self.current_event = self.state()
if self.current_event is not None:
if not choices:
return True
for choice in choices:
if isinstance(self.current_event, choice):
return True
return False
def peek_event(self):
# Get the next event.
if self.current_event is None:
if self.state:
self.current_event = self.state()
return self.current_event
def get_event(self):
# Get the next event and proceed further.
if self.current_event is None:
if self.state:
self.current_event = self.state()
value = self.current_event
self.current_event = None
return value
# stream ::= STREAM-START implicit_document? explicit_document* STREAM-END
# implicit_document ::= block_node DOCUMENT-END*
# explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END*
def parse_stream_start(self):
# Parse the stream start.
token = self.get_token()
event = StreamStartEvent(token.start_mark, token.end_mark,
encoding=token.encoding)
# Prepare the next state.
self.state = self.parse_implicit_document_start
return event
def parse_implicit_document_start(self):
# Parse an implicit document.
if not self.check_token(DirectiveToken, DocumentStartToken,
StreamEndToken):
self.tag_handles = self.DEFAULT_TAGS
token = self.peek_token()
start_mark = end_mark = token.start_mark
event = DocumentStartEvent(start_mark, end_mark,
explicit=False)
# Prepare the next state.
self.states.append(self.parse_document_end)
self.state = self.parse_block_node
return event
else:
return self.parse_document_start()
def parse_document_start(self):
# Parse any extra document end indicators.
while self.check_token(DocumentEndToken):
self.get_token()
# Parse an explicit document.
if not self.check_token(StreamEndToken):
token = self.peek_token()
start_mark = token.start_mark
version, tags = self.process_directives()
if not self.check_token(DocumentStartToken):
raise ParserError(None, None,
"expected '<document start>', but found %r"
% self.peek_token().id,
self.peek_token().start_mark)
token = self.get_token()
end_mark = token.end_mark
event = DocumentStartEvent(start_mark, end_mark,
explicit=True, version=version, tags=tags)
self.states.append(self.parse_document_end)
self.state = self.parse_document_content
else:
# Parse the end of the stream.
token = self.get_token()
event = StreamEndEvent(token.start_mark, token.end_mark)
assert not self.states
assert not self.marks
self.state = None
return event
def parse_document_end(self):
# Parse the document end.
token = self.peek_token()
start_mark = end_mark = token.start_mark
explicit = False
if self.check_token(DocumentEndToken):
token = self.get_token()
end_mark = token.end_mark
explicit = True
event = DocumentEndEvent(start_mark, end_mark,
explicit=explicit)
# Prepare the next state.
self.state = self.parse_document_start
return event
def parse_document_content(self):
if self.check_token(DirectiveToken,
DocumentStartToken, DocumentEndToken, StreamEndToken):
event = self.process_empty_scalar(self.peek_token().start_mark)
self.state = self.states.pop()
return event
else:
return self.parse_block_node()
def process_directives(self):
self.yaml_version = None
self.tag_handles = {}
while self.check_token(DirectiveToken):
token = self.get_token()
if token.name == 'YAML':
if self.yaml_version is not None:
raise ParserError(None, None,
"found duplicate YAML directive", token.start_mark)
major, minor = token.value
if major != 1:
raise ParserError(None, None,
"found incompatible YAML document (version 1.* is required)",
token.start_mark)
self.yaml_version = token.value
elif token.name == 'TAG':
handle, prefix = token.value
if handle in self.tag_handles:
raise ParserError(None, None,
"duplicate tag handle %r" % handle,
token.start_mark)
self.tag_handles[handle] = prefix
if self.tag_handles:
value = self.yaml_version, self.tag_handles.copy()
else:
value = self.yaml_version, None
for key in self.DEFAULT_TAGS:
if key not in self.tag_handles:
self.tag_handles[key] = self.DEFAULT_TAGS[key]
return value
# block_node_or_indentless_sequence ::= ALIAS
# | properties (block_content | indentless_block_sequence)?
# | block_content
# | indentless_block_sequence
# block_node ::= ALIAS
# | properties block_content?
# | block_content
# flow_node ::= ALIAS
# | properties flow_content?
# | flow_content
# properties ::= TAG ANCHOR? | ANCHOR TAG?
# block_content ::= block_collection | flow_collection | SCALAR
# flow_content ::= flow_collection | SCALAR
# block_collection ::= block_sequence | block_mapping
# flow_collection ::= flow_sequence | flow_mapping
def parse_block_node(self):
return self.parse_node(block=True)
def parse_flow_node(self):
return self.parse_node()
def parse_block_node_or_indentless_sequence(self):
return self.parse_node(block=True, indentless_sequence=True)
def parse_node(self, block=False, indentless_sequence=False):
if self.check_token(AliasToken):
token = self.get_token()
event = AliasEvent(token.value, token.start_mark, token.end_mark)
self.state = self.states.pop()
else:
anchor = None
tag = None
start_mark = end_mark = tag_mark = None
if self.check_token(AnchorToken):
token = self.get_token()
start_mark = token.start_mark
end_mark = token.end_mark
anchor = token.value
if self.check_token(TagToken):
token = self.get_token()
tag_mark = token.start_mark
end_mark = token.end_mark
tag = token.value
elif self.check_token(TagToken):
token = self.get_token()
start_mark = tag_mark = token.start_mark
end_mark = token.end_mark
tag = token.value
if self.check_token(AnchorToken):
token = self.get_token()
end_mark = token.end_mark
anchor = token.value
if tag is not None:
handle, suffix = tag
if handle is not None:
if handle not in self.tag_handles:
raise ParserError("while parsing a node", start_mark,
"found undefined tag handle %r" % handle,
tag_mark)
tag = self.tag_handles[handle]+suffix
else:
tag = suffix
#if tag == '!':
# raise ParserError("while parsing a node", start_mark,
# "found non-specific tag '!'", tag_mark,
# "Please check 'http://pyyaml.org/wiki/YAMLNonSpecificTag' and share your opinion.")
if start_mark is None:
start_mark = end_mark = self.peek_token().start_mark
event = None
implicit = (tag is None or tag == '!')
if indentless_sequence and self.check_token(BlockEntryToken):
end_mark = self.peek_token().end_mark
event = SequenceStartEvent(anchor, tag, implicit,
start_mark, end_mark)
self.state = self.parse_indentless_sequence_entry
else:
if self.check_token(ScalarToken):
token = self.get_token()
end_mark = token.end_mark
if (token.plain and tag is None) or tag == '!':
implicit = (True, False)
elif tag is None:
implicit = (False, True)
else:
implicit = (False, False)
event = ScalarEvent(anchor, tag, implicit, token.value,
start_mark, end_mark, style=token.style)
self.state = self.states.pop()
elif self.check_token(FlowSequenceStartToken):
end_mark = self.peek_token().end_mark
event = SequenceStartEvent(anchor, tag, implicit,
start_mark, end_mark, flow_style=True)
self.state = self.parse_flow_sequence_first_entry
elif self.check_token(FlowMappingStartToken):
end_mark = self.peek_token().end_mark
event = MappingStartEvent(anchor, tag, implicit,
start_mark, end_mark, flow_style=True)
self.state = self.parse_flow_mapping_first_key
elif block and self.check_token(BlockSequenceStartToken):
end_mark = self.peek_token().start_mark
event = SequenceStartEvent(anchor, tag, implicit,
start_mark, end_mark, flow_style=False)
self.state = self.parse_block_sequence_first_entry
elif block and self.check_token(BlockMappingStartToken):
end_mark = self.peek_token().start_mark
event = MappingStartEvent(anchor, tag, implicit,
start_mark, end_mark, flow_style=False)
self.state = self.parse_block_mapping_first_key
elif anchor is not None or tag is not None:
# Empty scalars are allowed even if a tag or an anchor is
# specified.
event = ScalarEvent(anchor, tag, (implicit, False), '',
start_mark, end_mark)
self.state = self.states.pop()
else:
if block:
node = 'block'
else:
node = 'flow'
token = self.peek_token()
raise ParserError("while parsing a %s node" % node, start_mark,
"expected the node content, but found %r" % token.id,
token.start_mark)
return event
# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END
def parse_block_sequence_first_entry(self):
token = self.get_token()
self.marks.append(token.start_mark)
return self.parse_block_sequence_entry()
def parse_block_sequence_entry(self):
if self.check_token(BlockEntryToken):
token = self.get_token()
if not self.check_token(BlockEntryToken, BlockEndToken):
self.states.append(self.parse_block_sequence_entry)
return self.parse_block_node()
else:
self.state = self.parse_block_sequence_entry
return self.process_empty_scalar(token.end_mark)
if not self.check_token(BlockEndToken):
token = self.peek_token()
raise ParserError("while parsing a block collection", self.marks[-1],
"expected <block end>, but found %r" % token.id, token.start_mark)
token = self.get_token()
event = SequenceEndEvent(token.start_mark, token.end_mark)
self.state = self.states.pop()
self.marks.pop()
return event
# indentless_sequence ::= (BLOCK-ENTRY block_node?)+
def parse_indentless_sequence_entry(self):
if self.check_token(BlockEntryToken):
token = self.get_token()
if not self.check_token(BlockEntryToken,
KeyToken, ValueToken, BlockEndToken):
self.states.append(self.parse_indentless_sequence_entry)
return self.parse_block_node()
else:
self.state = self.parse_indentless_sequence_entry
return self.process_empty_scalar(token.end_mark)
token = self.peek_token()
event = SequenceEndEvent(token.start_mark, token.start_mark)
self.state = self.states.pop()
return event
# block_mapping ::= BLOCK-MAPPING_START
# ((KEY block_node_or_indentless_sequence?)?
# (VALUE block_node_or_indentless_sequence?)?)*
# BLOCK-END
def parse_block_mapping_first_key(self):
token = self.get_token()
self.marks.append(token.start_mark)
return self.parse_block_mapping_key()
def parse_block_mapping_key(self):
if self.check_token(KeyToken):
token = self.get_token()
if not self.check_token(KeyToken, ValueToken, BlockEndToken):
self.states.append(self.parse_block_mapping_value)
return self.parse_block_node_or_indentless_sequence()
else:
self.state = self.parse_block_mapping_value
return self.process_empty_scalar(token.end_mark)
if not self.check_token(BlockEndToken):
token = self.peek_token()
raise ParserError("while parsing a block mapping", self.marks[-1],
"expected <block end>, but found %r" % token.id, token.start_mark)
token = self.get_token()
event = MappingEndEvent(token.start_mark, token.end_mark)
self.state = self.states.pop()
self.marks.pop()
return event
def parse_block_mapping_value(self):
if self.check_token(ValueToken):
token = self.get_token()
if not self.check_token(KeyToken, ValueToken, BlockEndToken):
self.states.append(self.parse_block_mapping_key)
return self.parse_block_node_or_indentless_sequence()
else:
self.state = self.parse_block_mapping_key
return self.process_empty_scalar(token.end_mark)
else:
self.state = self.parse_block_mapping_key
token = self.peek_token()
return self.process_empty_scalar(token.start_mark)
# flow_sequence ::= FLOW-SEQUENCE-START
# (flow_sequence_entry FLOW-ENTRY)*
# flow_sequence_entry?
# FLOW-SEQUENCE-END
# flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
#
# Note that while production rules for both flow_sequence_entry and
# flow_mapping_entry are equal, their interpretations are different.
# For `flow_sequence_entry`, the part `KEY flow_node? (VALUE flow_node?)?`
# generate an inline mapping (set syntax).
def parse_flow_sequence_first_entry(self):
token = self.get_token()
self.marks.append(token.start_mark)
return self.parse_flow_sequence_entry(first=True)
def parse_flow_sequence_entry(self, first=False):
if not self.check_token(FlowSequenceEndToken):
if not first:
if self.check_token(FlowEntryToken):
self.get_token()
else:
token = self.peek_token()
raise ParserError("while parsing a flow sequence", self.marks[-1],
"expected ',' or ']', but got %r" % token.id, token.start_mark)
if self.check_token(KeyToken):
token = self.peek_token()
event = MappingStartEvent(None, None, True,
token.start_mark, token.end_mark,
flow_style=True)
self.state = self.parse_flow_sequence_entry_mapping_key
return event
elif not self.check_token(FlowSequenceEndToken):
self.states.append(self.parse_flow_sequence_entry)
return self.parse_flow_node()
token = self.get_token()
event = SequenceEndEvent(token.start_mark, token.end_mark)
self.state = self.states.pop()
self.marks.pop()
return event
def parse_flow_sequence_entry_mapping_key(self):
token = self.get_token()
if not self.check_token(ValueToken,
FlowEntryToken, FlowSequenceEndToken):
self.states.append(self.parse_flow_sequence_entry_mapping_value)
return self.parse_flow_node()
else:
self.state = self.parse_flow_sequence_entry_mapping_value
return self.process_empty_scalar(token.end_mark)
def parse_flow_sequence_entry_mapping_value(self):
if self.check_token(ValueToken):
token = self.get_token()
if not self.check_token(FlowEntryToken, FlowSequenceEndToken):
self.states.append(self.parse_flow_sequence_entry_mapping_end)
return self.parse_flow_node()
else:
self.state = self.parse_flow_sequence_entry_mapping_end
return self.process_empty_scalar(token.end_mark)
else:
self.state = self.parse_flow_sequence_entry_mapping_end
token = self.peek_token()
return self.process_empty_scalar(token.start_mark)
def parse_flow_sequence_entry_mapping_end(self):
self.state = self.parse_flow_sequence_entry
token = self.peek_token()
return MappingEndEvent(token.start_mark, token.start_mark)
# flow_mapping ::= FLOW-MAPPING-START
# (flow_mapping_entry FLOW-ENTRY)*
# flow_mapping_entry?
# FLOW-MAPPING-END
# flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
def parse_flow_mapping_first_key(self):
token = self.get_token()
self.marks.append(token.start_mark)
return self.parse_flow_mapping_key(first=True)
def parse_flow_mapping_key(self, first=False):
if not self.check_token(FlowMappingEndToken):
if not first:
if self.check_token(FlowEntryToken):
self.get_token()
else:
token = self.peek_token()
raise ParserError("while parsing a flow mapping", self.marks[-1],
"expected ',' or '}', but got %r" % token.id, token.start_mark)
if self.check_token(KeyToken):
token = self.get_token()
if not self.check_token(ValueToken,
FlowEntryToken, FlowMappingEndToken):
self.states.append(self.parse_flow_mapping_value)
return self.parse_flow_node()
else:
self.state = self.parse_flow_mapping_value
return self.process_empty_scalar(token.end_mark)
elif not self.check_token(FlowMappingEndToken):
self.states.append(self.parse_flow_mapping_empty_value)
return self.parse_flow_node()
token = self.get_token()
event = MappingEndEvent(token.start_mark, token.end_mark)
self.state = self.states.pop()
self.marks.pop()
return event
def parse_flow_mapping_value(self):
if self.check_token(ValueToken):
token = self.get_token()
if not self.check_token(FlowEntryToken, FlowMappingEndToken):
self.states.append(self.parse_flow_mapping_key)
return self.parse_flow_node()
else:
self.state = self.parse_flow_mapping_key
return self.process_empty_scalar(token.end_mark)
else:
self.state = self.parse_flow_mapping_key
token = self.peek_token()
return self.process_empty_scalar(token.start_mark)
def parse_flow_mapping_empty_value(self):
self.state = self.parse_flow_mapping_key
return self.process_empty_scalar(self.peek_token().start_mark)
def process_empty_scalar(self, mark):
return ScalarEvent(None, None, (True, False), '', mark, mark)

185
libs/yaml/reader.py Normal file
View file

@ -0,0 +1,185 @@
# This module contains abstractions for the input stream. You don't have to
# looks further, there are no pretty code.
#
# We define two classes here.
#
# Mark(source, line, column)
# It's just a record and its only use is producing nice error messages.
# Parser does not use it for any other purposes.
#
# Reader(source, data)
# Reader determines the encoding of `data` and converts it to unicode.
# Reader provides the following methods and attributes:
# reader.peek(length=1) - return the next `length` characters
# reader.forward(length=1) - move the current position to `length` characters.
# reader.index - the number of the current character.
# reader.line, stream.column - the line and the column of the current character.
__all__ = ['Reader', 'ReaderError']
from .error import YAMLError, Mark
import codecs, re
class ReaderError(YAMLError):
def __init__(self, name, position, character, encoding, reason):
self.name = name
self.character = character
self.position = position
self.encoding = encoding
self.reason = reason
def __str__(self):
if isinstance(self.character, bytes):
return "'%s' codec can't decode byte #x%02x: %s\n" \
" in \"%s\", position %d" \
% (self.encoding, ord(self.character), self.reason,
self.name, self.position)
else:
return "unacceptable character #x%04x: %s\n" \
" in \"%s\", position %d" \
% (self.character, self.reason,
self.name, self.position)
class Reader(object):
# Reader:
# - determines the data encoding and converts it to a unicode string,
# - checks if characters are in allowed range,
# - adds '\0' to the end.
# Reader accepts
# - a `bytes` object,
# - a `str` object,
# - a file-like object with its `read` method returning `str`,
# - a file-like object with its `read` method returning `unicode`.
# Yeah, it's ugly and slow.
def __init__(self, stream):
self.name = None
self.stream = None
self.stream_pointer = 0
self.eof = True
self.buffer = ''
self.pointer = 0
self.raw_buffer = None
self.raw_decode = None
self.encoding = None
self.index = 0
self.line = 0
self.column = 0
if isinstance(stream, str):
self.name = "<unicode string>"
self.check_printable(stream)
self.buffer = stream+'\0'
elif isinstance(stream, bytes):
self.name = "<byte string>"
self.raw_buffer = stream
self.determine_encoding()
else:
self.stream = stream
self.name = getattr(stream, 'name', "<file>")
self.eof = False
self.raw_buffer = None
self.determine_encoding()
def peek(self, index=0):
try:
return self.buffer[self.pointer+index]
except IndexError:
self.update(index+1)
return self.buffer[self.pointer+index]
def prefix(self, length=1):
if self.pointer+length >= len(self.buffer):
self.update(length)
return self.buffer[self.pointer:self.pointer+length]
def forward(self, length=1):
if self.pointer+length+1 >= len(self.buffer):
self.update(length+1)
while length:
ch = self.buffer[self.pointer]
self.pointer += 1
self.index += 1
if ch in '\n\x85\u2028\u2029' \
or (ch == '\r' and self.buffer[self.pointer] != '\n'):
self.line += 1
self.column = 0
elif ch != '\uFEFF':
self.column += 1
length -= 1
def get_mark(self):
if self.stream is None:
return Mark(self.name, self.index, self.line, self.column,
self.buffer, self.pointer)
else:
return Mark(self.name, self.index, self.line, self.column,
None, None)
def determine_encoding(self):
while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2):
self.update_raw()
if isinstance(self.raw_buffer, bytes):
if self.raw_buffer.startswith(codecs.BOM_UTF16_LE):
self.raw_decode = codecs.utf_16_le_decode
self.encoding = 'utf-16-le'
elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE):
self.raw_decode = codecs.utf_16_be_decode
self.encoding = 'utf-16-be'
else:
self.raw_decode = codecs.utf_8_decode
self.encoding = 'utf-8'
self.update(1)
NON_PRINTABLE = re.compile('[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]')
def check_printable(self, data):
match = self.NON_PRINTABLE.search(data)
if match:
character = match.group()
position = self.index+(len(self.buffer)-self.pointer)+match.start()
raise ReaderError(self.name, position, ord(character),
'unicode', "special characters are not allowed")
def update(self, length):
if self.raw_buffer is None:
return
self.buffer = self.buffer[self.pointer:]
self.pointer = 0
while len(self.buffer) < length:
if not self.eof:
self.update_raw()
if self.raw_decode is not None:
try:
data, converted = self.raw_decode(self.raw_buffer,
'strict', self.eof)
except UnicodeDecodeError as exc:
character = self.raw_buffer[exc.start]
if self.stream is not None:
position = self.stream_pointer-len(self.raw_buffer)+exc.start
else:
position = exc.start
raise ReaderError(self.name, position, character,
exc.encoding, exc.reason)
else:
data = self.raw_buffer
converted = len(data)
self.check_printable(data)
self.buffer += data
self.raw_buffer = self.raw_buffer[converted:]
if self.eof:
self.buffer += '\0'
self.raw_buffer = None
break
def update_raw(self, size=4096):
data = self.stream.read(size)
if self.raw_buffer is None:
self.raw_buffer = data
else:
self.raw_buffer += data
self.stream_pointer += len(data)
if not data:
self.eof = True

389
libs/yaml/representer.py Normal file
View file

@ -0,0 +1,389 @@
__all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer',
'RepresenterError']
from .error import *
from .nodes import *
import datetime, sys, copyreg, types, base64, collections
class RepresenterError(YAMLError):
pass
class BaseRepresenter:
yaml_representers = {}
yaml_multi_representers = {}
def __init__(self, default_style=None, default_flow_style=False, sort_keys=True):
self.default_style = default_style
self.sort_keys = sort_keys
self.default_flow_style = default_flow_style
self.represented_objects = {}
self.object_keeper = []
self.alias_key = None
def represent(self, data):
node = self.represent_data(data)
self.serialize(node)
self.represented_objects = {}
self.object_keeper = []
self.alias_key = None
def represent_data(self, data):
if self.ignore_aliases(data):
self.alias_key = None
else:
self.alias_key = id(data)
if self.alias_key is not None:
if self.alias_key in self.represented_objects:
node = self.represented_objects[self.alias_key]
#if node is None:
# raise RepresenterError("recursive objects are not allowed: %r" % data)
return node
#self.represented_objects[alias_key] = None
self.object_keeper.append(data)
data_types = type(data).__mro__
if data_types[0] in self.yaml_representers:
node = self.yaml_representers[data_types[0]](self, data)
else:
for data_type in data_types:
if data_type in self.yaml_multi_representers:
node = self.yaml_multi_representers[data_type](self, data)
break
else:
if None in self.yaml_multi_representers:
node = self.yaml_multi_representers[None](self, data)
elif None in self.yaml_representers:
node = self.yaml_representers[None](self, data)
else:
node = ScalarNode(None, str(data))
#if alias_key is not None:
# self.represented_objects[alias_key] = node
return node
@classmethod
def add_representer(cls, data_type, representer):
if not 'yaml_representers' in cls.__dict__:
cls.yaml_representers = cls.yaml_representers.copy()
cls.yaml_representers[data_type] = representer
@classmethod
def add_multi_representer(cls, data_type, representer):
if not 'yaml_multi_representers' in cls.__dict__:
cls.yaml_multi_representers = cls.yaml_multi_representers.copy()
cls.yaml_multi_representers[data_type] = representer
def represent_scalar(self, tag, value, style=None):
if style is None:
style = self.default_style
node = ScalarNode(tag, value, style=style)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
return node
def represent_sequence(self, tag, sequence, flow_style=None):
value = []
node = SequenceNode(tag, value, flow_style=flow_style)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
for item in sequence:
node_item = self.represent_data(item)
if not (isinstance(node_item, ScalarNode) and not node_item.style):
best_style = False
value.append(node_item)
if flow_style is None:
if self.default_flow_style is not None:
node.flow_style = self.default_flow_style
else:
node.flow_style = best_style
return node
def represent_mapping(self, tag, mapping, flow_style=None):
value = []
node = MappingNode(tag, value, flow_style=flow_style)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
if hasattr(mapping, 'items'):
mapping = list(mapping.items())
if self.sort_keys:
try:
mapping = sorted(mapping)
except TypeError:
pass
for item_key, item_value in mapping:
node_key = self.represent_data(item_key)
node_value = self.represent_data(item_value)
if not (isinstance(node_key, ScalarNode) and not node_key.style):
best_style = False
if not (isinstance(node_value, ScalarNode) and not node_value.style):
best_style = False
value.append((node_key, node_value))
if flow_style is None:
if self.default_flow_style is not None:
node.flow_style = self.default_flow_style
else:
node.flow_style = best_style
return node
def ignore_aliases(self, data):
return False
class SafeRepresenter(BaseRepresenter):
def ignore_aliases(self, data):
if data is None:
return True
if isinstance(data, tuple) and data == ():
return True
if isinstance(data, (str, bytes, bool, int, float)):
return True
def represent_none(self, data):
return self.represent_scalar('tag:yaml.org,2002:null', 'null')
def represent_str(self, data):
return self.represent_scalar('tag:yaml.org,2002:str', data)
def represent_binary(self, data):
if hasattr(base64, 'encodebytes'):
data = base64.encodebytes(data).decode('ascii')
else:
data = base64.encodestring(data).decode('ascii')
return self.represent_scalar('tag:yaml.org,2002:binary', data, style='|')
def represent_bool(self, data):
if data:
value = 'true'
else:
value = 'false'
return self.represent_scalar('tag:yaml.org,2002:bool', value)
def represent_int(self, data):
return self.represent_scalar('tag:yaml.org,2002:int', str(data))
inf_value = 1e300
while repr(inf_value) != repr(inf_value*inf_value):
inf_value *= inf_value
def represent_float(self, data):
if data != data or (data == 0.0 and data == 1.0):
value = '.nan'
elif data == self.inf_value:
value = '.inf'
elif data == -self.inf_value:
value = '-.inf'
else:
value = repr(data).lower()
# Note that in some cases `repr(data)` represents a float number
# without the decimal parts. For instance:
# >>> repr(1e17)
# '1e17'
# Unfortunately, this is not a valid float representation according
# to the definition of the `!!float` tag. We fix this by adding
# '.0' before the 'e' symbol.
if '.' not in value and 'e' in value:
value = value.replace('e', '.0e', 1)
return self.represent_scalar('tag:yaml.org,2002:float', value)
def represent_list(self, data):
#pairs = (len(data) > 0 and isinstance(data, list))
#if pairs:
# for item in data:
# if not isinstance(item, tuple) or len(item) != 2:
# pairs = False
# break
#if not pairs:
return self.represent_sequence('tag:yaml.org,2002:seq', data)
#value = []
#for item_key, item_value in data:
# value.append(self.represent_mapping(u'tag:yaml.org,2002:map',
# [(item_key, item_value)]))
#return SequenceNode(u'tag:yaml.org,2002:pairs', value)
def represent_dict(self, data):
return self.represent_mapping('tag:yaml.org,2002:map', data)
def represent_set(self, data):
value = {}
for key in data:
value[key] = None
return self.represent_mapping('tag:yaml.org,2002:set', value)
def represent_date(self, data):
value = data.isoformat()
return self.represent_scalar('tag:yaml.org,2002:timestamp', value)
def represent_datetime(self, data):
value = data.isoformat(' ')
return self.represent_scalar('tag:yaml.org,2002:timestamp', value)
def represent_yaml_object(self, tag, data, cls, flow_style=None):
if hasattr(data, '__getstate__'):
state = data.__getstate__()
else:
state = data.__dict__.copy()
return self.represent_mapping(tag, state, flow_style=flow_style)
def represent_undefined(self, data):
raise RepresenterError("cannot represent an object", data)
SafeRepresenter.add_representer(type(None),
SafeRepresenter.represent_none)
SafeRepresenter.add_representer(str,
SafeRepresenter.represent_str)
SafeRepresenter.add_representer(bytes,
SafeRepresenter.represent_binary)
SafeRepresenter.add_representer(bool,
SafeRepresenter.represent_bool)
SafeRepresenter.add_representer(int,
SafeRepresenter.represent_int)
SafeRepresenter.add_representer(float,
SafeRepresenter.represent_float)
SafeRepresenter.add_representer(list,
SafeRepresenter.represent_list)
SafeRepresenter.add_representer(tuple,
SafeRepresenter.represent_list)
SafeRepresenter.add_representer(dict,
SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(set,
SafeRepresenter.represent_set)
SafeRepresenter.add_representer(datetime.date,
SafeRepresenter.represent_date)
SafeRepresenter.add_representer(datetime.datetime,
SafeRepresenter.represent_datetime)
SafeRepresenter.add_representer(None,
SafeRepresenter.represent_undefined)
class Representer(SafeRepresenter):
def represent_complex(self, data):
if data.imag == 0.0:
data = '%r' % data.real
elif data.real == 0.0:
data = '%rj' % data.imag
elif data.imag > 0:
data = '%r+%rj' % (data.real, data.imag)
else:
data = '%r%rj' % (data.real, data.imag)
return self.represent_scalar('tag:yaml.org,2002:python/complex', data)
def represent_tuple(self, data):
return self.represent_sequence('tag:yaml.org,2002:python/tuple', data)
def represent_name(self, data):
name = '%s.%s' % (data.__module__, data.__name__)
return self.represent_scalar('tag:yaml.org,2002:python/name:'+name, '')
def represent_module(self, data):
return self.represent_scalar(
'tag:yaml.org,2002:python/module:'+data.__name__, '')
def represent_object(self, data):
# We use __reduce__ API to save the data. data.__reduce__ returns
# a tuple of length 2-5:
# (function, args, state, listitems, dictitems)
# For reconstructing, we calls function(*args), then set its state,
# listitems, and dictitems if they are not None.
# A special case is when function.__name__ == '__newobj__'. In this
# case we create the object with args[0].__new__(*args).
# Another special case is when __reduce__ returns a string - we don't
# support it.
# We produce a !!python/object, !!python/object/new or
# !!python/object/apply node.
cls = type(data)
if cls in copyreg.dispatch_table:
reduce = copyreg.dispatch_table[cls](data)
elif hasattr(data, '__reduce_ex__'):
reduce = data.__reduce_ex__(2)
elif hasattr(data, '__reduce__'):
reduce = data.__reduce__()
else:
raise RepresenterError("cannot represent an object", data)
reduce = (list(reduce)+[None]*5)[:5]
function, args, state, listitems, dictitems = reduce
args = list(args)
if state is None:
state = {}
if listitems is not None:
listitems = list(listitems)
if dictitems is not None:
dictitems = dict(dictitems)
if function.__name__ == '__newobj__':
function = args[0]
args = args[1:]
tag = 'tag:yaml.org,2002:python/object/new:'
newobj = True
else:
tag = 'tag:yaml.org,2002:python/object/apply:'
newobj = False
function_name = '%s.%s' % (function.__module__, function.__name__)
if not args and not listitems and not dictitems \
and isinstance(state, dict) and newobj:
return self.represent_mapping(
'tag:yaml.org,2002:python/object:'+function_name, state)
if not listitems and not dictitems \
and isinstance(state, dict) and not state:
return self.represent_sequence(tag+function_name, args)
value = {}
if args:
value['args'] = args
if state or not isinstance(state, dict):
value['state'] = state
if listitems:
value['listitems'] = listitems
if dictitems:
value['dictitems'] = dictitems
return self.represent_mapping(tag+function_name, value)
def represent_ordered_dict(self, data):
# Provide uniform representation across different Python versions.
data_type = type(data)
tag = 'tag:yaml.org,2002:python/object/apply:%s.%s' \
% (data_type.__module__, data_type.__name__)
items = [[key, value] for key, value in data.items()]
return self.represent_sequence(tag, [items])
Representer.add_representer(complex,
Representer.represent_complex)
Representer.add_representer(tuple,
Representer.represent_tuple)
Representer.add_representer(type,
Representer.represent_name)
Representer.add_representer(collections.OrderedDict,
Representer.represent_ordered_dict)
Representer.add_representer(types.FunctionType,
Representer.represent_name)
Representer.add_representer(types.BuiltinFunctionType,
Representer.represent_name)
Representer.add_representer(types.ModuleType,
Representer.represent_module)
Representer.add_multi_representer(object,
Representer.represent_object)

227
libs/yaml/resolver.py Normal file
View file

@ -0,0 +1,227 @@
__all__ = ['BaseResolver', 'Resolver']
from .error import *
from .nodes import *
import re
class ResolverError(YAMLError):
pass
class BaseResolver:
DEFAULT_SCALAR_TAG = 'tag:yaml.org,2002:str'
DEFAULT_SEQUENCE_TAG = 'tag:yaml.org,2002:seq'
DEFAULT_MAPPING_TAG = 'tag:yaml.org,2002:map'
yaml_implicit_resolvers = {}
yaml_path_resolvers = {}
def __init__(self):
self.resolver_exact_paths = []
self.resolver_prefix_paths = []
@classmethod
def add_implicit_resolver(cls, tag, regexp, first):
if not 'yaml_implicit_resolvers' in cls.__dict__:
implicit_resolvers = {}
for key in cls.yaml_implicit_resolvers:
implicit_resolvers[key] = cls.yaml_implicit_resolvers[key][:]
cls.yaml_implicit_resolvers = implicit_resolvers
if first is None:
first = [None]
for ch in first:
cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp))
@classmethod
def add_path_resolver(cls, tag, path, kind=None):
# Note: `add_path_resolver` is experimental. The API could be changed.
# `new_path` is a pattern that is matched against the path from the
# root to the node that is being considered. `node_path` elements are
# tuples `(node_check, index_check)`. `node_check` is a node class:
# `ScalarNode`, `SequenceNode`, `MappingNode` or `None`. `None`
# matches any kind of a node. `index_check` could be `None`, a boolean
# value, a string value, or a number. `None` and `False` match against
# any _value_ of sequence and mapping nodes. `True` matches against
# any _key_ of a mapping node. A string `index_check` matches against
# a mapping value that corresponds to a scalar key which content is
# equal to the `index_check` value. An integer `index_check` matches
# against a sequence value with the index equal to `index_check`.
if not 'yaml_path_resolvers' in cls.__dict__:
cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy()
new_path = []
for element in path:
if isinstance(element, (list, tuple)):
if len(element) == 2:
node_check, index_check = element
elif len(element) == 1:
node_check = element[0]
index_check = True
else:
raise ResolverError("Invalid path element: %s" % element)
else:
node_check = None
index_check = element
if node_check is str:
node_check = ScalarNode
elif node_check is list:
node_check = SequenceNode
elif node_check is dict:
node_check = MappingNode
elif node_check not in [ScalarNode, SequenceNode, MappingNode] \
and not isinstance(node_check, str) \
and node_check is not None:
raise ResolverError("Invalid node checker: %s" % node_check)
if not isinstance(index_check, (str, int)) \
and index_check is not None:
raise ResolverError("Invalid index checker: %s" % index_check)
new_path.append((node_check, index_check))
if kind is str:
kind = ScalarNode
elif kind is list:
kind = SequenceNode
elif kind is dict:
kind = MappingNode
elif kind not in [ScalarNode, SequenceNode, MappingNode] \
and kind is not None:
raise ResolverError("Invalid node kind: %s" % kind)
cls.yaml_path_resolvers[tuple(new_path), kind] = tag
def descend_resolver(self, current_node, current_index):
if not self.yaml_path_resolvers:
return
exact_paths = {}
prefix_paths = []
if current_node:
depth = len(self.resolver_prefix_paths)
for path, kind in self.resolver_prefix_paths[-1]:
if self.check_resolver_prefix(depth, path, kind,
current_node, current_index):
if len(path) > depth:
prefix_paths.append((path, kind))
else:
exact_paths[kind] = self.yaml_path_resolvers[path, kind]
else:
for path, kind in self.yaml_path_resolvers:
if not path:
exact_paths[kind] = self.yaml_path_resolvers[path, kind]
else:
prefix_paths.append((path, kind))
self.resolver_exact_paths.append(exact_paths)
self.resolver_prefix_paths.append(prefix_paths)
def ascend_resolver(self):
if not self.yaml_path_resolvers:
return
self.resolver_exact_paths.pop()
self.resolver_prefix_paths.pop()
def check_resolver_prefix(self, depth, path, kind,
current_node, current_index):
node_check, index_check = path[depth-1]
if isinstance(node_check, str):
if current_node.tag != node_check:
return
elif node_check is not None:
if not isinstance(current_node, node_check):
return
if index_check is True and current_index is not None:
return
if (index_check is False or index_check is None) \
and current_index is None:
return
if isinstance(index_check, str):
if not (isinstance(current_index, ScalarNode)
and index_check == current_index.value):
return
elif isinstance(index_check, int) and not isinstance(index_check, bool):
if index_check != current_index:
return
return True
def resolve(self, kind, value, implicit):
if kind is ScalarNode and implicit[0]:
if value == '':
resolvers = self.yaml_implicit_resolvers.get('', [])
else:
resolvers = self.yaml_implicit_resolvers.get(value[0], [])
resolvers += self.yaml_implicit_resolvers.get(None, [])
for tag, regexp in resolvers:
if regexp.match(value):
return tag
implicit = implicit[1]
if self.yaml_path_resolvers:
exact_paths = self.resolver_exact_paths[-1]
if kind in exact_paths:
return exact_paths[kind]
if None in exact_paths:
return exact_paths[None]
if kind is ScalarNode:
return self.DEFAULT_SCALAR_TAG
elif kind is SequenceNode:
return self.DEFAULT_SEQUENCE_TAG
elif kind is MappingNode:
return self.DEFAULT_MAPPING_TAG
class Resolver(BaseResolver):
pass
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:bool',
re.compile(r'''^(?:yes|Yes|YES|no|No|NO
|true|True|TRUE|false|False|FALSE
|on|On|ON|off|Off|OFF)$''', re.X),
list('yYnNtTfFoO'))
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:float',
re.compile(r'''^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+][0-9]+)?
|\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
|[-+]?\.(?:inf|Inf|INF)
|\.(?:nan|NaN|NAN))$''', re.X),
list('-+0123456789.'))
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:int',
re.compile(r'''^(?:[-+]?0b[0-1_]+
|[-+]?0[0-7_]+
|[-+]?(?:0|[1-9][0-9_]*)
|[-+]?0x[0-9a-fA-F_]+
|[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$''', re.X),
list('-+0123456789'))
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:merge',
re.compile(r'^(?:<<)$'),
['<'])
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:null',
re.compile(r'''^(?: ~
|null|Null|NULL
| )$''', re.X),
['~', 'n', 'N', ''])
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:timestamp',
re.compile(r'''^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]
|[0-9][0-9][0-9][0-9] -[0-9][0-9]? -[0-9][0-9]?
(?:[Tt]|[ \t]+)[0-9][0-9]?
:[0-9][0-9] :[0-9][0-9] (?:\.[0-9]*)?
(?:[ \t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$''', re.X),
list('0123456789'))
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:value',
re.compile(r'^(?:=)$'),
['='])
# The following resolver is only for documentation purposes. It cannot work
# because plain scalars cannot start with '!', '&', or '*'.
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:yaml',
re.compile(r'^(?:!|&|\*)$'),
list('!&*'))

1435
libs/yaml/scanner.py Normal file

File diff suppressed because it is too large Load diff

111
libs/yaml/serializer.py Normal file
View file

@ -0,0 +1,111 @@
__all__ = ['Serializer', 'SerializerError']
from .error import YAMLError
from .events import *
from .nodes import *
class SerializerError(YAMLError):
pass
class Serializer:
ANCHOR_TEMPLATE = 'id%03d'
def __init__(self, encoding=None,
explicit_start=None, explicit_end=None, version=None, tags=None):
self.use_encoding = encoding
self.use_explicit_start = explicit_start
self.use_explicit_end = explicit_end
self.use_version = version
self.use_tags = tags
self.serialized_nodes = {}
self.anchors = {}
self.last_anchor_id = 0
self.closed = None
def open(self):
if self.closed is None:
self.emit(StreamStartEvent(encoding=self.use_encoding))
self.closed = False
elif self.closed:
raise SerializerError("serializer is closed")
else:
raise SerializerError("serializer is already opened")
def close(self):
if self.closed is None:
raise SerializerError("serializer is not opened")
elif not self.closed:
self.emit(StreamEndEvent())
self.closed = True
#def __del__(self):
# self.close()
def serialize(self, node):
if self.closed is None:
raise SerializerError("serializer is not opened")
elif self.closed:
raise SerializerError("serializer is closed")
self.emit(DocumentStartEvent(explicit=self.use_explicit_start,
version=self.use_version, tags=self.use_tags))
self.anchor_node(node)
self.serialize_node(node, None, None)
self.emit(DocumentEndEvent(explicit=self.use_explicit_end))
self.serialized_nodes = {}
self.anchors = {}
self.last_anchor_id = 0
def anchor_node(self, node):
if node in self.anchors:
if self.anchors[node] is None:
self.anchors[node] = self.generate_anchor(node)
else:
self.anchors[node] = None
if isinstance(node, SequenceNode):
for item in node.value:
self.anchor_node(item)
elif isinstance(node, MappingNode):
for key, value in node.value:
self.anchor_node(key)
self.anchor_node(value)
def generate_anchor(self, node):
self.last_anchor_id += 1
return self.ANCHOR_TEMPLATE % self.last_anchor_id
def serialize_node(self, node, parent, index):
alias = self.anchors[node]
if node in self.serialized_nodes:
self.emit(AliasEvent(alias))
else:
self.serialized_nodes[node] = True
self.descend_resolver(parent, index)
if isinstance(node, ScalarNode):
detected_tag = self.resolve(ScalarNode, node.value, (True, False))
default_tag = self.resolve(ScalarNode, node.value, (False, True))
implicit = (node.tag == detected_tag), (node.tag == default_tag)
self.emit(ScalarEvent(alias, node.tag, implicit, node.value,
style=node.style))
elif isinstance(node, SequenceNode):
implicit = (node.tag
== self.resolve(SequenceNode, node.value, True))
self.emit(SequenceStartEvent(alias, node.tag, implicit,
flow_style=node.flow_style))
index = 0
for item in node.value:
self.serialize_node(item, node, index)
index += 1
self.emit(SequenceEndEvent())
elif isinstance(node, MappingNode):
implicit = (node.tag
== self.resolve(MappingNode, node.value, True))
self.emit(MappingStartEvent(alias, node.tag, implicit,
flow_style=node.flow_style))
for key, value in node.value:
self.serialize_node(key, node, None)
self.serialize_node(value, node, key)
self.emit(MappingEndEvent())
self.ascend_resolver()

104
libs/yaml/tokens.py Normal file
View file

@ -0,0 +1,104 @@
class Token(object):
def __init__(self, start_mark, end_mark):
self.start_mark = start_mark
self.end_mark = end_mark
def __repr__(self):
attributes = [key for key in self.__dict__
if not key.endswith('_mark')]
attributes.sort()
arguments = ', '.join(['%s=%r' % (key, getattr(self, key))
for key in attributes])
return '%s(%s)' % (self.__class__.__name__, arguments)
#class BOMToken(Token):
# id = '<byte order mark>'
class DirectiveToken(Token):
id = '<directive>'
def __init__(self, name, value, start_mark, end_mark):
self.name = name
self.value = value
self.start_mark = start_mark
self.end_mark = end_mark
class DocumentStartToken(Token):
id = '<document start>'
class DocumentEndToken(Token):
id = '<document end>'
class StreamStartToken(Token):
id = '<stream start>'
def __init__(self, start_mark=None, end_mark=None,
encoding=None):
self.start_mark = start_mark
self.end_mark = end_mark
self.encoding = encoding
class StreamEndToken(Token):
id = '<stream end>'
class BlockSequenceStartToken(Token):
id = '<block sequence start>'
class BlockMappingStartToken(Token):
id = '<block mapping start>'
class BlockEndToken(Token):
id = '<block end>'
class FlowSequenceStartToken(Token):
id = '['
class FlowMappingStartToken(Token):
id = '{'
class FlowSequenceEndToken(Token):
id = ']'
class FlowMappingEndToken(Token):
id = '}'
class KeyToken(Token):
id = '?'
class ValueToken(Token):
id = ':'
class BlockEntryToken(Token):
id = '-'
class FlowEntryToken(Token):
id = ','
class AliasToken(Token):
id = '<alias>'
def __init__(self, value, start_mark, end_mark):
self.value = value
self.start_mark = start_mark
self.end_mark = end_mark
class AnchorToken(Token):
id = '<anchor>'
def __init__(self, value, start_mark, end_mark):
self.value = value
self.start_mark = start_mark
self.end_mark = end_mark
class TagToken(Token):
id = '<tag>'
def __init__(self, value, start_mark, end_mark):
self.value = value
self.start_mark = start_mark
self.end_mark = end_mark
class ScalarToken(Token):
id = '<scalar>'
def __init__(self, value, plain, start_mark, end_mark, style=None):
self.value = value
self.plain = plain
self.start_mark = start_mark
self.end_mark = end_mark
self.style = style