Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Doc/library/asyncio-task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -847,9 +847,10 @@ Task Object
APIs except :meth:`Future.set_result` and
:meth:`Future.set_exception`.

Tasks support the :mod:`contextvars` module. When a Task
is created it copies the current context and later runs its
coroutine in the copied context.
Tasks support the :mod:`contextvars` module. Tasks can be run under
any context, defaulting to a copy of the context that created them. This
context will later be used to run its coroutines. The context associated
with a task can be modified using `:meth:`asyncio.run_in_context`.

.. versionchanged:: 3.7
Added support for the :mod:`contextvars` module.
Expand Down
4 changes: 2 additions & 2 deletions Lib/asyncio/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from . import tasks


def run(main, *, debug=None):
def run(main, *, debug=None, **task_kwargs):
"""Execute the coroutine and return the result.

This function runs the passed coroutine, taking care of
Expand Down Expand Up @@ -41,7 +41,7 @@ async def main():
events.set_event_loop(loop)
if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
return loop.run_until_complete(tasks.Task(main, loop=loop, **task_kwargs))
finally:
try:
_cancel_all_tasks(loop)
Expand Down
33 changes: 30 additions & 3 deletions Lib/asyncio/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED',
'wait', 'wait_for', 'as_completed', 'sleep',
'gather', 'shield', 'ensure_future', 'run_coroutine_threadsafe',
'current_task', 'all_tasks',
'current_task', 'all_tasks', 'run_in_context',
'_register_task', '_unregister_task', '_enter_task', '_leave_task',
)

Expand Down Expand Up @@ -71,6 +71,24 @@ def _set_task_name(task, name):
set_name(name)


async def run_in_context(context, coro):
"""Run the coroutine coro in the passed context.

This method can be used to run coro in an alternate context within the
calling Task. This is the asyncio analog of contextvars.Context.run.
"""
task = current_task()
if task is None:
raise RuntimeError("No running task")
prev_context = task._set_context(context)
await __sleep0()
try:
return await coro
finally:
task._set_context(prev_context)
await __sleep0()


class Task(futures._PyFuture): # Inherit Python Task implementation
# from a Python Future implementation.

Expand All @@ -89,7 +107,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
# status is still pending
_log_destroy_pending = True

def __init__(self, coro, *, loop=None, name=None):
def __init__(self, coro, *, loop=None, name=None, context=None):
super().__init__(loop=loop)
if self._source_traceback:
del self._source_traceback[-1]
Expand All @@ -107,7 +125,11 @@ def __init__(self, coro, *, loop=None, name=None):
self._must_cancel = False
self._fut_waiter = None
self._coro = coro
self._context = contextvars.copy_context()

if context is None:
self._context = contextvars.copy_context()
else:
self._context = context

self._loop.call_soon(self.__step, context=self._context)
_register_task(self)
Expand All @@ -129,6 +151,11 @@ def __class_getitem__(cls, type):
def _repr_info(self):
return base_tasks._task_repr_info(self)

def _set_context(self, context):
prev_context = self._context
self._context = context
return prev_context

def get_coro(self):
return self._coro

Expand Down
26 changes: 26 additions & 0 deletions Lib/test/test_asyncio/test_runners.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextvars
import unittest

from unittest import mock
Expand Down Expand Up @@ -180,3 +181,28 @@ async def main():

self.assertIsNone(spinner.ag_frame)
self.assertFalse(spinner.ag_running)

def test_asyncio_run_task_creation(self):
cvar = contextvars.ContextVar('cvar', default='nope')

context = contextvars.Context()
context.run(cvar.set, 'maybe')

async def check_explicit(name_expect, var_expect, var_update):
self.assertEqual(name_expect, asyncio.current_task().get_name())
self.assertEqual(var_expect, cvar.get())
cvar.set(var_update)

# Verify name and context passed to task
asyncio.run(check_explicit('my-task', 'maybe', 'sometimes'), name='my-task', context=context)
self.assertEqual(context.run(cvar.get), 'sometimes')

async def check_default(var_expect, var_update):
self.assertTrue(asyncio.current_task().get_name().startswith("Task-"))
self.assertEqual(var_expect, cvar.get())
cvar.set(var_update)

# Verify default name and context (copy of current context) used otherwise
cvar.set('seldom')
asyncio.run(check_default('seldom', 'often'))
self.assertEqual(cvar.get(), 'seldom')
70 changes: 68 additions & 2 deletions Lib/test/test_asyncio/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ class BaseTaskTests:
Task = None
Future = None

def new_task(self, loop, coro, name='TestTask'):
return self.__class__.Task(coro, loop=loop, name=name)
def new_task(self, loop, coro, name='TestTask', context=None):
return self.__class__.Task(coro, loop=loop, name=name, context=context)

def new_future(self, loop):
return self.__class__.Future(loop=loop)
Expand Down Expand Up @@ -2860,6 +2860,72 @@ async def main():

self.assertEqual(cvar.get(), -1)

def test_context_4(self):
# Test specifying context
cvar = contextvars.ContextVar('cvar', default='nope')

context = contextvars.Context()
context.run(cvar.set, 'maybe')

async def sub(expect, update):
self.assertEqual(cvar.get(), expect)
cvar.set(update)

async def main():
self.assertEqual(cvar.get(), 'maybe')

await sub('maybe', 'always')
self.assertEqual(cvar.get(), 'always')

await self.new_task(loop, sub('always', 'never'))
self.assertEqual(cvar.get(), 'always')

await self.new_task(loop, sub('always', 'never'), context=context)
self.assertEqual(cvar.get(), 'never')

loop = asyncio.new_event_loop()
try:
task = self.new_task(loop, main(), context=context)
loop.run_until_complete(task)
finally:
loop.close()

self.assertEqual(cvar.get(), 'nope')
self.assertEqual(context.run(cvar.get), 'never')

def test_run_in_context(self):
# Test run_in_context behavior
cvar = contextvars.ContextVar('cvar', default='nope')

context = contextvars.Context()
context.run(cvar.set, 'maybe')

async def sub(update, parent_task):
self.assertIs(parent_task, asyncio.current_task())
value = cvar.get()
cvar.set(update)
return value


async def main():
self.assertEqual(cvar.get(), 'maybe')
sub_context = context.copy()

cvar.set('never')
self.assertEqual(await asyncio.run_in_context(sub_context, sub('always', asyncio.current_task())), 'maybe')
self.assertEqual(cvar.get(), 'never')
self.assertEqual(sub_context.run(cvar.get), 'always')

loop = asyncio.new_event_loop()
try:
task = self.new_task(loop, main(), context=context)
loop.run_until_complete(task)
finally:
loop.close()

self.assertEqual(cvar.get(), 'nope')
self.assertEqual(context.run(cvar.get), 'never')

def test_get_coro(self):
loop = asyncio.new_event_loop()
coro = coroutine_function()
Expand Down
44 changes: 39 additions & 5 deletions Modules/_asynciomodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -2009,14 +2009,15 @@ _asyncio.Task.__init__
*
loop: object = None
name: object = None
context: object = None

A coroutine wrapped in a Future.
[clinic start generated code]*/

static int
_asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
PyObject *name)
/*[clinic end generated code: output=88b12b83d570df50 input=352a3137fe60091d]*/
PyObject *name, PyObject *context)
/*[clinic end generated code: output=49ac96fe33d0e5c7 input=924522490c8ce825]*/
{
if (future_init((FutureObj*)self, loop)) {
return -1;
Expand All @@ -2034,9 +2035,14 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
return -1;
}

Py_XSETREF(self->task_context, PyContext_CopyCurrent());
if (self->task_context == NULL) {
return -1;
if (context == Py_None) {
Py_XSETREF(self->task_context, PyContext_CopyCurrent());
if (self->task_context == NULL) {
return -1;
}
} else {
Py_INCREF(context);
Py_XSETREF(self->task_context, context);
}

Py_CLEAR(self->task_fut_waiter);
Expand Down Expand Up @@ -2379,6 +2385,33 @@ _asyncio_Task_set_name(TaskObj *self, PyObject *value)
Py_RETURN_NONE;
}

/*[clinic input]
_asyncio.Task._set_context

context: object

Set the context associated with the task.

This does not change the current thread context and only affects the thread
context of later callbacks. Returns the previously context attached to the task.
[clinic start generated code]*/

static PyObject *
_asyncio_Task__set_context_impl(TaskObj *self, PyObject *context)
/*[clinic end generated code: output=46ac5ea28ccfac3b input=36a0652ac2b5f671]*/
{
if (context == Py_None) {
PyErr_SetString(
PyExc_RuntimeError, "expected valid context");
return NULL;
}

PyObject *prev_context = self->task_context;
Py_INCREF(context);
self->task_context = context;
return prev_context;
}

static void
TaskObj_finalize(TaskObj *task)
{
Expand Down Expand Up @@ -2471,6 +2504,7 @@ static PyMethodDef TaskType_methods[] = {
_ASYNCIO_TASK__REPR_INFO_METHODDEF
_ASYNCIO_TASK_GET_NAME_METHODDEF
_ASYNCIO_TASK_SET_NAME_METHODDEF
_ASYNCIO_TASK__SET_CONTEXT_METHODDEF
_ASYNCIO_TASK_GET_CORO_METHODDEF
{"__class_getitem__", task_cls_getitem, METH_O|METH_CLASS, NULL},
{NULL, NULL} /* Sentinel */
Expand Down
56 changes: 49 additions & 7 deletions Modules/clinic/_asynciomodule.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.