From e8fd670087ab7c87f3fc25710351c24a3f93607f Mon Sep 17 00:00:00 2001 From: Paul Martin Date: Fri, 14 Oct 2022 16:09:35 +0100 Subject: [PATCH 1/2] Run tasks consumer tasks concurrently --- channels/utils.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/channels/utils.py b/channels/utils.py index 72cd9ca30..df8058c1a 100644 --- a/channels/utils.py +++ b/channels/utils.py @@ -33,12 +33,25 @@ async def await_many_dispatch(consumer_callables, dispatch): """ Given a set of consumer callables, awaits on them all and passes results from them to the dispatch awaitable as they come in. + If a dispatch awaitable raises an exception, + this coroutine will fail with that exception. """ # Call all callables, and ensure all return types are Futures tasks = [ asyncio.ensure_future(consumer_callable()) for consumer_callable in consumer_callables ] + + dispatch_tasks = [] + fut = asyncio.Future() # For tasks to report an exception + tasks.append(fut) + + def on_dispatch_task_complete(task): + dispatch_tasks.remove(task) + exc = task.exception() + if exc: + fut.set_exception(exc) + try: while True: # Wait for any of them to complete @@ -46,9 +59,16 @@ async def await_many_dispatch(consumer_callables, dispatch): # Find the completed one(s), yield results, and replace them for i, task in enumerate(tasks): if task.done(): - result = task.result() - await dispatch(result) - tasks[i] = asyncio.ensure_future(consumer_callables[i]()) + if task == fut: + exc = fut.exception() + if exc: + raise exc + else: + result = task.result() + task = asyncio.create_task(dispatch(result)) + dispatch_tasks.append(task) + task.add_done_callback(on_dispatch_task_complete) + tasks[i] = asyncio.ensure_future(consumer_callables[i]()) finally: # Make sure we clean up tasks on exit for task in tasks: @@ -57,3 +77,15 @@ async def await_many_dispatch(consumer_callables, dispatch): await task except asyncio.CancelledError: pass + if dispatch_tasks: + """ + This may be needed if the consumer task running this coroutine + is cancelled and one of the subtasks raises an exception after cancellation. + """ + done, pending = await asyncio.wait(dispatch_tasks) + for task in done: + exc = task.exception() + if exc: + raise exc + if not fut.done(): + fut.set_result(None) From d65a5d3a37caf495ae5133b05bcc71ebe0eca32d Mon Sep 17 00:00:00 2001 From: Paul Martin Date: Fri, 14 Oct 2022 19:13:41 +0100 Subject: [PATCH 2/2] Don't raise child task exception on CancelledError --- channels/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/channels/utils.py b/channels/utils.py index df8058c1a..dcd8e15bc 100644 --- a/channels/utils.py +++ b/channels/utils.py @@ -43,13 +43,13 @@ async def await_many_dispatch(consumer_callables, dispatch): ] dispatch_tasks = [] - fut = asyncio.Future() # For tasks to report an exception + fut = asyncio.Future() # For child task to report an exception tasks.append(fut) def on_dispatch_task_complete(task): dispatch_tasks.remove(task) exc = task.exception() - if exc: + if exc and not isinstance(exc, asyncio.CancelledError) and not fut.done(): fut.set_exception(exc) try: @@ -60,7 +60,7 @@ def on_dispatch_task_complete(task): for i, task in enumerate(tasks): if task.done(): if task == fut: - exc = fut.exception() + exc = fut.exception() # Child task has reported an exception if exc: raise exc else: @@ -85,7 +85,7 @@ def on_dispatch_task_complete(task): done, pending = await asyncio.wait(dispatch_tasks) for task in done: exc = task.exception() - if exc: + if exc and not isinstance(exc, asyncio.CancelledError): raise exc if not fut.done(): fut.set_result(None)