python-trio / trio-asyncio

a re-implementation of the asyncio mainloop on top of Trio

wrapping async contextmanager functions in a "yield safe" way

graingert opened this issue

Currently, `@aio_as_trio` runs `__aenter__` and `__aexit__` in two different tasks, whereas some asyncio context managers only work correctly when both run in the same task.
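To make the problem concrete, here is a small pure-asyncio sketch that does not involve trio-asyncio at all. `TaskBoundCM`, `split_across_tasks`, `single_task` and `demo` are hypothetical names invented for illustration: `TaskBoundCM` stands in for a task-sensitive asyncio context manager and simply raises if `__aexit__` runs in a different task from `__aenter__`, and `split_across_tasks` only mimics the enter/exit split described above.

```python
import asyncio


class TaskBoundCM:
    # Hypothetical toy context manager (not part of asyncio or trio-asyncio)
    # that only works if it is exited in the task that entered it.
    async def __aenter__(self):
        self._task = asyncio.current_task()  # remember the entering task
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if asyncio.current_task() is not self._task:
            raise RuntimeError("exited in a different task than it was entered in")
        return False


async def split_across_tasks():
    # Mimics the behaviour described above: __aenter__ and __aexit__
    # each run in a task of their own.
    cm = TaskBoundCM()
    await asyncio.create_task(cm.__aenter__())
    await asyncio.create_task(cm.__aexit__(None, None, None))


async def single_task():
    # Entering and exiting inside one task works as intended.
    async def body():
        async with TaskBoundCM():
            pass

    await asyncio.create_task(body())


def demo():
    asyncio.run(single_task())
    try:
        asyncio.run(split_across_tasks())
    except RuntimeError as e:
        return str(e)
    return "no error"


print(demo())  # prints "exited in a different task than it was entered in"
```

The single-task run completes silently; only the split run trips the check.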

So I've written a small utility function that enters and exits the wrapped asyncio context manager inside a single asyncio task:

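```python
import contextlib
import functools

import outcome
import trio
import trio_asyncio


def aio_acmgr_fn_as_trio_acmgr_fn(acmgr_fn):
    @functools.wraps(acmgr_fn)
    @contextlib.asynccontextmanager
    async def wrapper(*args, **kwargs):
        event = trio.Event()
        # Outcome of the caller's `async with` body; rebound below and
        # read by `run` through its closure.
        oc = outcome.Value(None)

        # Enter and exit the asyncio context manager in one asyncio task.
        @trio_asyncio.aio_as_trio
        async def run(*, task_status):
            async with acmgr_fn(*args, **kwargs) as result:
                task_status.started(result)
                # Keep the context open until the Trio side is done.
                await trio_asyncio.trio_as_aio(event.wait)()
                # Re-raise the caller's exception (if any) so __aexit__ sees it.
                oc.unwrap()

        async with trio.open_nursery() as nursery:
            result = await nursery.start(run)
            try:
                yield result
            except BaseException as e:
                oc = outcome.Error(e)
            finally:
                event.set()

    return wrapper
```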

I think it would be worth including a helper like this in trio-asyncio as well.
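The same idea can be exercised without Trio. The sketch below is an asyncio-only analogue, assuming nothing from trio-asyncio: `hold_in_one_task` (a hypothetical name) starts a dedicated holder task that runs both `__aenter__` and `__aexit__`, uses an `asyncio.Event` where the helper above uses `trio.Event`, and stores the caller's exception in a closure variable where the helper uses an `outcome` object. Like `oc` in the helper, `caller_exc` is rebound in `wrapper` and read by the nested coroutine through its closure, so the new value is visible once the event is set.

```python
import asyncio
import contextlib


def hold_in_one_task(acmgr_fn):
    # Hypothetical asyncio-only analogue of the helper above (no Trio):
    # a dedicated "holder" task runs both __aenter__ and __aexit__.
    @contextlib.asynccontextmanager
    async def wrapper(*args, **kwargs):
        exit_event = asyncio.Event()
        started = asyncio.get_running_loop().create_future()
        caller_exc = None  # rebound below; `holder` reads it via its closure

        async def holder():
            try:
                async with acmgr_fn(*args, **kwargs) as value:
                    started.set_result(value)  # hand the value to the caller
                    await exit_event.wait()    # keep the context open
                    if caller_exc is not None:
                        raise caller_exc       # let __aexit__ see the caller's error
            except BaseException as e:
                if not started.done():
                    started.set_exception(e)   # __aenter__ itself failed
                    return
                raise

        task = asyncio.create_task(holder())
        value = await started  # re-raises if __aenter__ failed
        try:
            yield value
        except BaseException as e:
            caller_exc = e
        finally:
            exit_event.set()
            # Propagates the caller's error unless __aexit__ suppressed it.
            await task

    return wrapper
```

Wrapping an `@contextlib.asynccontextmanager` function this way runs its setup and teardown in the holder task, and an exception raised in the `async with` body is re-raised inside that task, so the wrapped manager can observe or suppress it.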

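I think `trio.Cancelled` needs special handling. In this version it is re-raised on the Trio side instead of being forwarded into the asyncio context manager, and a single asyncio future replaces the `trio.Event` and `outcome` pair: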

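```python
import asyncio
import contextlib
import functools

import trio
import trio_asyncio


def aio_acmgr_fn_as_trio_acmgr_fn(acmgr_fn):
    @functools.wraps(acmgr_fn)
    @contextlib.asynccontextmanager
    async def wrapper(*args, **kwargs):
        @trio_asyncio.aio_as_trio
        async def run(*, task_status):
            # The future is both the exit signal and the exception carrier.
            fut = asyncio.get_running_loop().create_future()
            async with acmgr_fn(*args, **kwargs) as result:
                task_status.started((fut, result))
                await fut  # raises here if the Trio side called set_exception()

        async with trio.open_nursery() as nursery:
            fut, result = await nursery.start(run)
            try:
                yield result
            except trio.Cancelled:
                # Don't throw Trio's cancellation into asyncio code;
                # let it propagate on the Trio side instead.
                raise
            except BaseException as e:
                fut.set_exception(e)
            finally:
                # Release the asyncio side normally if nothing was forwarded.
                if not fut.done():
                    fut.set_result(None)

    return wrapper
```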
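A pure-asyncio sketch of the future-based variant, again without Trio. `hold_in_one_task_fut` is a hypothetical name, and its `keep_local` parameter is an assumption standing in for `trio.Cancelled`: exceptions of those types propagate in the caller, while the wrapped context manager is exited as if the body had finished normally. A single `release` future serves as both the exit signal and the carrier for forwarded exceptions, and waiting on `started` and the holder task together lets a failing `__aenter__` surface to the caller.

```python
import asyncio
import contextlib


def hold_in_one_task_fut(acmgr_fn, keep_local=()):
    # Hypothetical asyncio-only analogue of the future-based helper.
    # `keep_local` (an assumed parameter) plays the role of trio.Cancelled:
    # those exceptions are re-raised in the caller, NOT thrown into acmgr_fn.
    @contextlib.asynccontextmanager
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        started = loop.create_future()
        release = loop.create_future()  # exit signal *and* exception carrier

        async def holder():
            async with acmgr_fn(*args, **kwargs) as value:
                started.set_result(value)
                await release  # raises here if set_exception() was called

        task = asyncio.create_task(holder())
        # Wait until __aenter__ has finished, or the holder has died trying.
        await asyncio.wait({started, task}, return_when=asyncio.FIRST_COMPLETED)
        if not started.done():
            await task  # __aenter__ failed: re-raise its exception here
        try:
            yield started.result()
        except keep_local:
            raise  # keep it on the caller's side; holder exits cleanly
        except BaseException as e:
            release.set_exception(e)  # forward into the wrapped manager
        finally:
            if not release.done():
                release.set_result(None)
            await task

    return wrapper
```

Whether an exception type belongs in `keep_local` is a design decision: forwarding it lets the wrapped manager react to it, while keeping it local avoids leaking one framework's control-flow exception into the other framework's code.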