diff --git a/channels/testing/live.py b/channels/testing/live.py index aa1a7880..3e91092e 100644 --- a/channels/testing/live.py +++ b/channels/testing/live.py @@ -1,4 +1,5 @@ from functools import partial +import multiprocessing from daphne.testing import DaphneProcess from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler @@ -11,9 +12,15 @@ from channels.routing import get_default_application -def make_application(*, static_wrapper): +# Global queue for commands from test process to server process +_server_command_queue = None + + +def make_application(*, static_wrapper, commands={}): # Module-level function for pickle-ability application = get_default_application() + # Wrap the application with our command processing middleware + application = ServerCommandMiddleware(application, commands) if static_wrapper is not None: application = static_wrapper(application) return application @@ -28,6 +35,34 @@ def set_database_connection(): settings.DATABASES["default"]["NAME"] = test_db_name +class ServerCommandMiddleware: + """ + Middleware that processes commands from the test process. + This is automatically added to the ASGI application in test mode. + """ + def __init__(self, app, commands): + self.app = app + self.commands = commands + + async def __call__(self, scope, receive, send): + # Process any pending server commands before handling the request + self.process_server_commands() + return await self.app(scope, receive, send) + + def process_server_commands(self): + global _server_command_queue + if _server_command_queue is None: + return + + while not _server_command_queue.empty(): + try: + command = _server_command_queue.get_nowait() + if command in self.commands: + self.commands[command]() + except: + break + + class ChannelsLiveServerTestCase(TransactionTestCase): """ Does basically the same as TransactionTestCase but also launches a @@ -40,6 +75,7 @@ class ChannelsLiveServerTestCase(TransactionTestCase): ProtocolServerProcess = DaphneProcess static_wrapper = ASGIStaticFilesHandler serve_static = True + commands = {} @property def live_server_url(self): @@ -51,6 +87,8 @@ def live_server_ws_url(self): @classmethod def setUpClass(cls): + global _server_command_queue + for connection in connections.all(): if cls._is_in_memory_db(connection): raise ImproperlyConfigured( @@ -64,9 +102,14 @@ def setUpClass(cls): ) cls._live_server_modified_settings.enable() + # Create a command queue for communication with the server process + _server_command_queue = multiprocessing.Queue() + cls._server_command_queue = _server_command_queue + get_application = partial( make_application, static_wrapper=cls.static_wrapper if cls.serve_static else None, + commands=cls.commands, ) cls._server_process = cls.ProtocolServerProcess( cls.host, @@ -89,6 +132,13 @@ def tearDownClass(cls): cls._live_server_modified_settings.disable() super().tearDownClass() + def run_server_command(self, command): + """ + Add command to server command queue. + """ + if hasattr(self.__class__, '_server_command_queue'): + self._server_command_queue.put(command) + @classmethod def _is_in_memory_db(cls, connection): """