bazarr/libs/flask_socketio/test_client.py
Louis Vézina 83c95cc77d WIP
2020-01-29 20:07:26 -05:00

205 lines
9.5 KiB
Python

import uuid
import weakref

from socketio import packet
from socketio.pubsub_manager import PubSubManager
from werkzeug.test import EnvironBuilder
class SocketIOTestClient(object):
    """
    This class is useful for testing a Flask-SocketIO server. It works in a
    similar way to the Flask Test Client, but adapted to the Socket.IO server.

    :param app: The Flask application instance.
    :param socketio: The application's ``SocketIO`` instance.
    :param namespace: The namespace for the client. If not provided, the client
                      connects to the server on the global namespace.
    :param query_string: A string with custom query string arguments.
    :param headers: A dictionary with custom HTTP headers.
    :param flask_test_client: The instance of the Flask test client
                              currently in use. Passing the Flask test
                              client is optional, but is necessary if you
                              want the Flask user session and any other
                              cookies set in HTTP routes accessible from
                              Socket.IO events.
    :raises RuntimeError: if the server is configured with a message queue
                          (a ``PubSubManager``), which this client does not
                          support.
    """
    # Received events and pending acks, keyed by client sid. These are
    # intentionally shared by all instances: the server only holds the packet
    # hook installed by the most recently created client, so that hook must
    # be able to store packets addressed to any client.
    queue = {}
    acks = {}
    # sid -> live client instance. The packet hook uses it to update the
    # connection state of the client a DISCONNECT is addressed to (which is
    # not necessarily the client whose hook is installed). Weak references
    # let finished clients, and the apps they reference, be collected.
    clients = weakref.WeakValueDictionary()

    def __init__(self, app, socketio, namespace=None, query_string=None,
                 headers=None, flask_test_client=None):
        def _mock_send_packet(sid, pkt):
            # Replacement for the server's ``_send_packet``: rather than
            # transmitting, record what the server would have sent to ``sid``.
            if pkt.packet_type in (packet.EVENT, packet.BINARY_EVENT):
                if sid not in self.queue:
                    self.queue[sid] = []
                if pkt.data[0] == 'message' or pkt.data[0] == 'json':
                    # send()-style messages carry a single payload value.
                    args = pkt.data[1]
                else:
                    args = pkt.data[1:]
                self.queue[sid].append({'name': pkt.data[0],
                                        'args': args,
                                        'namespace': pkt.namespace or '/'})
            elif pkt.packet_type in (packet.ACK, packet.BINARY_ACK):
                self.acks[sid] = {'args': pkt.data,
                                  'namespace': pkt.namespace or '/'}
            elif pkt.packet_type == packet.DISCONNECT:
                # Update the client that owns ``sid``, not blindly ``self``.
                client = self.clients.get(sid)
                if client is not None:
                    client.connected[pkt.namespace or '/'] = False

        self.app = app
        self.flask_test_client = flask_test_client
        self.sid = uuid.uuid4().hex
        self.queue[self.sid] = []
        self.acks[self.sid] = None
        self.callback_counter = 0
        self.socketio = socketio
        self.connected = {}
        self.clients[self.sid] = self
        socketio.server._send_packet = _mock_send_packet
        socketio.server.environ[self.sid] = {}
        socketio.server.async_handlers = False  # easier to test when
        socketio.server.eio.async_handlers = False  # events are sync
        if isinstance(socketio.server.manager, PubSubManager):
            raise RuntimeError('Test client cannot be used with a message '
                               'queue. Disable the queue on your test '
                               'configuration.')
        socketio.server.manager.initialize()
        self.connect(namespace=namespace, query_string=query_string,
                     headers=headers)

    def is_connected(self, namespace=None):
        """Check if a namespace is connected.

        :param namespace: The namespace to check. The global namespace is
                          assumed if this argument is not provided.
        :returns: ``True`` if the namespace is marked connected, else
                  ``False`` (including for namespaces never connected).
        """
        return self.connected.get(namespace or '/', False)

    def connect(self, namespace=None, query_string=None, headers=None):
        """Connect the client.

        :param namespace: The namespace for the client. If not provided, the
                          client connects to the server on the global
                          namespace.
        :param query_string: A string with custom query string arguments.
        :param headers: A dictionary with custom HTTP headers.

        Note that it is usually not necessary to explicitly call this method,
        since a connection is automatically established when an instance of
        this class is created. An example where this method would be useful
        is when the application accepts multiple namespace connections.
        """
        url = '/socket.io'
        if query_string:
            if query_string[0] != '?':
                query_string = '?' + query_string
            url += query_string
        environ = EnvironBuilder(url, headers=headers).get_environ()
        environ['flask.app'] = self.app
        if self.flask_test_client:
            # inject cookies from Flask
            self.flask_test_client.cookie_jar.inject_wsgi(environ)
        # Optimistically mark the global namespace connected, and roll back
        # if the server's connect handling returns False (rejection).
        self.connected['/'] = True
        if self.socketio.server._handle_eio_connect(
                self.sid, environ) is False:
            del self.connected['/']
        if namespace is not None and namespace != '/':
            self.connected[namespace] = True
            pkt = packet.Packet(packet.CONNECT, namespace=namespace)
            with self.app.app_context():
                if self.socketio.server._handle_eio_message(
                        self.sid, pkt.encode()) is False:
                    del self.connected[namespace]

    def disconnect(self, namespace=None):
        """Disconnect the client.

        :param namespace: The namespace to disconnect. The global namespace is
                          assumed if this argument is not provided.
        :raises RuntimeError: if the namespace is not connected.
        """
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        pkt = packet.Packet(packet.DISCONNECT, namespace=namespace)
        with self.app.app_context():
            self.socketio.server._handle_eio_message(self.sid, pkt.encode())
        del self.connected[namespace or '/']

    def emit(self, event, *args, **kwargs):
        """Emit an event to the server.

        :param event: The event name.
        :param *args: The event arguments.
        :param callback: ``True`` if the client requests a callback, ``False``
                         if not. Note that client-side callbacks are not
                         implemented, a callback request will just tell the
                         server to provide the arguments to invoke the
                         callback, but no callback is invoked. Instead, the
                         arguments that the server provided for the callback
                         are returned by this function.
        :param namespace: The namespace of the event. The global namespace is
                          assumed if this argument is not provided.
        :returns: The server's ack arguments (a single value when the server
                  acked with exactly one argument, otherwise the list), or
                  ``None`` if no ack was recorded.
        :raises RuntimeError: if the namespace is not connected.
        """
        namespace = kwargs.pop('namespace', None)
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        callback = kwargs.pop('callback', False)
        ack_id = None
        if callback:
            # A packet id asks the server to acknowledge this event.
            self.callback_counter += 1
            ack_id = self.callback_counter
        pkt = packet.Packet(packet.EVENT, data=[event] + list(args),
                            namespace=namespace, id=ack_id)
        with self.app.app_context():
            encoded_pkt = pkt.encode()
            # encode() may yield a list of parts (e.g. for binary payloads);
            # feed each part to the server in order.
            if isinstance(encoded_pkt, list):
                for epkt in encoded_pkt:
                    self.socketio.server._handle_eio_message(self.sid, epkt)
            else:
                self.socketio.server._handle_eio_message(self.sid, encoded_pkt)
        ack = self.acks.pop(self.sid, None)
        if ack is not None:
            return ack['args'][0] if len(ack['args']) == 1 \
                else ack['args']
        return None

    def send(self, data, json=False, callback=False, namespace=None):
        """Send a text or JSON message to the server.

        :param data: A string, dictionary or list to send to the server.
        :param json: ``True`` to send a JSON message, ``False`` to send a text
                     message.
        :param callback: ``True`` if the client requests a callback, ``False``
                         if not. Note that client-side callbacks are not
                         implemented, a callback request will just tell the
                         server to provide the arguments to invoke the
                         callback, but no callback is invoked. Instead, the
                         arguments that the server provided for the callback
                         are returned by this function.
        :param namespace: The namespace of the event. The global namespace is
                          assumed if this argument is not provided.
        """
        if json:
            msg = 'json'
        else:
            msg = 'message'
        return self.emit(msg, data, callback=callback, namespace=namespace)

    def get_received(self, namespace=None):
        """Return the list of messages received from the server.

        Since this is not a real client, any time the server emits an event,
        the event is simply stored. The test code can invoke this method to
        obtain the list of events that were received since the last call.
        Events from other namespaces are left in the queue.

        :param namespace: The namespace to get events from. The global
                          namespace is assumed if this argument is not
                          provided.
        :raises RuntimeError: if the namespace is not connected.
        """
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        namespace = namespace or '/'
        # Single pass: split the queue into events for this namespace and
        # everything else, instead of an O(n*m) "not in" filter.
        received = []
        remaining = []
        for pkt in self.queue[self.sid]:
            if pkt['namespace'] == namespace:
                received.append(pkt)
            else:
                remaining.append(pkt)
        self.queue[self.sid] = remaining
        return received