Feature proposal: `RetryContextManager` (for retrying `__enter__` and NOT the contents of the `with` block)
rudolfbyker opened this issue · comments
Rudolf Byker commented
Goal
A context manager that retries if the `__enter__` method of a context manager raises an exception, but does NOT retry when the contents of the `with` block raise an exception.
Example use case
The context we are trying to set up is very unreliable (COM server black magic in our case), but once it is set up, it works well. We need to retry setting up the context, without squelching errors raised in the with block.
Proposed new `RetryContextManager` class
from types import TracebackType
from typing import (
TypedDict,
Type,
ContextManager,
TypeVar,
Optional,
Callable,
Generic,
Union,
Any,
)
from tenacity import retry, RetryCallState, RetryError
from tenacity.stop import StopBaseT
from tenacity.wait import WaitBaseT
from tenacity.retry import RetryBaseT
class RetryKwargs(TypedDict, total=False):
    """
    Copied from the arguments of `BaseRetrying.__init__` in the `tenacity` library.

    Typed dictionary of keyword arguments that `RetryContextManager` unpacks into tenacity's `retry` decorator.
    Because the class is declared with `total=False`, every key is optional.
    """

    # NOTE(review): the meaning of each key is defined by tenacity, not by this file -- consult the tenacity
    # documentation for `BaseRetrying.__init__` before relying on any particular semantics.

    # Callable taking a single int/float argument and returning nothing.
    sleep: Callable[[Union[int, float]], None]
    stop: StopBaseT
    wait: WaitBaseT
    retry: RetryBaseT
    # Callbacks that receive the current `RetryCallState`.
    before: Callable[[RetryCallState], None]
    after: Callable[[RetryCallState], None]
    # May be explicitly `None`.
    before_sleep: Optional[Callable[[RetryCallState], None]]
    reraise: bool
    # An exception class (a subclass of `RetryError`), not an instance.
    retry_error_cls: Type[RetryError]
    # May be explicitly `None`; unlike the other callbacks its return type is unconstrained (`Any`).
    retry_error_callback: Optional[Callable[[RetryCallState], Any]]
# Type of the resource produced by the wrapped context manager's `__enter__`.
T = TypeVar("T")


class RetryContextManager(Generic[T]):
    """
    A context manager that retries if the `__enter__` method of a context manager raises an exception, but does NOT
    retry when the contents of the `with` block raise an exception.

    A new inner context manager is created (by calling the `cm` factory) for every attempt. The instance whose
    `__enter__` succeeded is remembered and receives the matching `__exit__` call.
    """

    def __init__(
        self,
        cm: Callable[[], ContextManager[T]],
        retry_kwargs: "RetryKwargs",
    ) -> None:
        """
        Create the RetryContextManager.

        Args:
            cm: A callable that returns a context manager. It is called once per attempt.
            retry_kwargs: The arguments to pass to the `retry` decorator.
        """
        self._cm: Callable[[], ContextManager[T]] = cm
        # The inner context manager that was entered successfully, or None when nothing is currently entered.
        self._cm_instance: Optional[ContextManager[T]] = None
        self._retry_kwargs: RetryKwargs = retry_kwargs

    def __enter__(self) -> T:
        """
        Create and enter the inner context manager, retrying according to `retry_kwargs`.

        Returns:
            Whatever the inner context manager's `__enter__` returned on the successful attempt.

        Raises:
            Whatever the `retry` decorator raises once it gives up (this depends on `retry_kwargs`).
        """

        # The decorator is applied at call time because the retry configuration lives on the instance.
        @retry(**self._retry_kwargs)
        def _enter() -> T:
            # Create a new instance of the context manager for this attempt.
            cm_instance = self._cm()
            try:
                # Enter the context manager.
                managed_resource = cm_instance.__enter__()
            except BaseException as e:
                # Clean up the failed context.
                # NOTE(review): a plain `with` statement does NOT call `__exit__` when `__enter__` raises, so this
                # relies on the inner context manager tolerating `__exit__` after a failed `__enter__`. The value
                # returned by `__exit__` is intentionally ignored here: the failure must propagate to trigger a
                # retry.
                cm_instance.__exit__(type(e), e, e.__traceback__)
                # Re-raise so that we can retry.
                raise
            # Success: remember the instance so that `__exit__` can delegate to it.
            self._cm_instance = cm_instance
            return managed_resource

        return _enter()

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        """
        Exit the inner context manager that was entered successfully, if any.

        Args:
            exc_type: Type of the exception raised in the `with` block, or None.
            exc_val: The exception raised in the `with` block, or None.
            exc_tb: Traceback of the exception raised in the `with` block, or None.

        Returns:
            The inner context manager's `__exit__` result, so that a truthy value still suppresses the exception
            raised in the `with` block. None when there is no entered inner context manager.
        """
        # Detach the instance first so that it can never be exited twice, even if its `__exit__` raises.
        cm_instance, self._cm_instance = self._cm_instance, None
        if cm_instance is None:
            return None
        return cm_instance.__exit__(exc_type, exc_val, exc_tb)
Unit tests
import unittest
from contextlib import contextmanager
from logging import getLogger, DEBUG
from queue import Queue
from typing import Generator
from tenacity import stop_after_attempt, before_log, after_log, before_sleep_log, RetryContextManager, RetryKwargs
def create_test_context_manager(
    *,
    n_failures: int,
    thing_to_yield: int,
    history: Queue[str],
    retry_kwargs: RetryKwargs,
) -> RetryContextManager[int]:
    """
    Build a (possibly unreliable) context manager wrapped in a `RetryContextManager`, for testing.

    Args:
        n_failures: How many times entering the context fails before it starts succeeding.
        thing_to_yield: The value the context manager yields once it succeeds.
        history: Queue that receives one message per step ("entering", "raising", "yielding", "exiting"), each
            prefixed with the zero-based attempt number. Used for test assertions.
        retry_kwargs: The arguments to pass to the `retry` decorator.

    Returns:
        A `RetryContextManager` wrapping the unreliable context manager.
    """
    # Number of completed attempts; shared by every context created from the factory below.
    attempt_index = 0

    @contextmanager
    def flaky_context() -> Generator[int, None, None]:
        nonlocal attempt_index
        history.put(f"{attempt_index} entering")
        try:
            if attempt_index >= n_failures:
                history.put(f"{attempt_index} yielding")
                yield thing_to_yield
            else:
                history.put(f"{attempt_index} raising")
                raise RuntimeError(f"{attempt_index} failed")
        finally:
            # Runs for failed and successful attempts alike, so every attempt is counted exactly once.
            history.put(f"{attempt_index} exiting")
            attempt_index = attempt_index + 1

    return RetryContextManager(cm=flaky_context, retry_kwargs=retry_kwargs)
class TestRetryContextManager(unittest.TestCase):
    """
    Tests for `RetryContextManager`: entering is retried, the body of the `with` block is not.

    Every test records two things and asserts on both: the steps taken by the unreliable inner context manager
    (`self.history`) and the records that tenacity's logging callbacks emit on the "test" logger.
    """

    # Name under which tenacity's logging callbacks report the retried inner function.
    _ENTER_NAME = "retry_context_manager.RetryContextManager.__enter__.<locals>._enter"

    def setUp(self) -> None:
        """Create a fresh history queue and fetch the logger shared by the retry callbacks."""
        self.history: Queue[str] = Queue()
        self.logger = getLogger("test")

    def _make_cm(self, *, n_failures: int) -> RetryContextManager[int]:
        """Return a context manager that fails `n_failures` times, then yields 1; at most 3 attempts are made."""
        return create_test_context_manager(
            n_failures=n_failures,
            thing_to_yield=1,
            history=self.history,
            retry_kwargs=RetryKwargs(
                stop=stop_after_attempt(3),
                reraise=True,
                before=before_log(logger=self.logger, log_level=DEBUG),
                after=after_log(logger=self.logger, log_level=DEBUG),
                before_sleep=before_sleep_log(logger=self.logger, log_level=DEBUG),
            ),
        )

    @classmethod
    def _starting(cls, nth: str) -> str:
        """Return the expected `before_log` record for the attempt with ordinal `nth` (e.g. "1st")."""
        return f"DEBUG:test:Starting call to '{cls._ENTER_NAME}', this is the {nth} time calling it."

    @classmethod
    def _finished(cls, nth: str) -> str:
        """Return the expected `after_log` record for the failed attempt with ordinal `nth`."""
        # NOTE(review): the elapsed time is matched literally as 0.000 -- presumably this could become flaky on a
        # very slow machine; confirm before adopting.
        return f"DEBUG:test:Finished call to '{cls._ENTER_NAME}' after 0.000(s), this was the {nth} time calling it."

    @classmethod
    def _retrying(cls, failed_attempt: int) -> str:
        """Return the expected `before_sleep_log` record after the zero-based attempt `failed_attempt` failed."""
        return (
            f"DEBUG:test:Retrying {cls._ENTER_NAME} in 0.0 seconds "
            f"as it raised RuntimeError: {failed_attempt} failed."
        )

    def test_no_retries_necessary(self) -> None:
        """A context that can be entered on the first attempt is entered exactly once."""
        with self.assertLogs(logger=self.logger, level=DEBUG) as logs:
            self.logger.info("Before the context")
            with self._make_cm(n_failures=0) as value:
                self.logger.info("Inside the context")
                self.assertEqual(value, 1)
            self.logger.info("After the context")

        self.assertEqual(["0 entering", "0 yielding", "0 exiting"], list(self.history.queue))
        self.assertEqual(
            [
                "INFO:test:Before the context",
                self._starting("1st"),
                "INFO:test:Inside the context",
                "INFO:test:After the context",
            ],
            logs.output,
        )

    def test_retry_then_succeed(self) -> None:
        """Two failed attempts are retried and the third attempt provides the context."""
        with self.assertLogs(logger=self.logger, level=DEBUG) as logs:
            self.logger.info("Before the context")
            with self._make_cm(n_failures=2) as value:
                self.logger.info("Inside the context")
                self.assertEqual(value, 1)
            self.logger.info("After the context")

        self.assertEqual(
            [
                "0 entering",
                "0 raising",
                "0 exiting",
                "1 entering",
                "1 raising",
                "1 exiting",
                "2 entering",
                "2 yielding",
                "2 exiting",
            ],
            list(self.history.queue),
        )
        self.assertEqual(
            [
                "INFO:test:Before the context",
                self._starting("1st"),
                self._finished("1st"),
                self._retrying(0),
                self._starting("2nd"),
                self._finished("2nd"),
                self._retrying(1),
                self._starting("3rd"),
                "INFO:test:Inside the context",
                "INFO:test:After the context",
            ],
            logs.output,
        )

    def test_retry_then_give_up(self) -> None:
        """After the third failed attempt the last error is re-raised and the body never runs."""
        with self.assertLogs(logger=self.logger, level=DEBUG) as logs:
            self.logger.info("Before the context")
            with self.assertRaisesRegex(RuntimeError, "2 failed"):
                with self._make_cm(n_failures=5):
                    self.logger.info("Inside the context")
            self.logger.info("After the context")

        self.assertEqual(
            [
                "0 entering",
                "0 raising",
                "0 exiting",
                "1 entering",
                "1 raising",
                "1 exiting",
                "2 entering",
                "2 raising",
                "2 exiting",
            ],
            list(self.history.queue),
        )
        self.assertEqual(
            [
                "INFO:test:Before the context",
                self._starting("1st"),
                self._finished("1st"),
                self._retrying(0),
                self._starting("2nd"),
                self._finished("2nd"),
                self._retrying(1),
                self._starting("3rd"),
                self._finished("3rd"),
                "INFO:test:After the context",
            ],
            logs.output,
        )

    def test_with_body_does_not_cause_retry(self) -> None:
        """An exception raised inside the `with` block propagates without entering the context again."""
        with self.assertLogs(logger=self.logger, level=DEBUG) as logs:
            self.logger.info("Before the context")
            with self.assertRaises(RuntimeError):
                with self._make_cm(n_failures=0):
                    self.logger.info("Inside the context")
                    raise RuntimeError("This should not cause a retry")
            # Logged after the expected error has been caught, mirroring `test_retry_then_give_up`.
            self.logger.info("After the context")

        self.assertEqual(["0 entering", "0 yielding", "0 exiting"], list(self.history.queue))
        self.assertEqual(
            [
                "INFO:test:Before the context",
                self._starting("1st"),
                "INFO:test:Inside the context",
                "INFO:test:After the context",
            ],
            logs.output,
        )