vinissimus / async-asgi-testclient

A framework-agnostic library for testing ASGI web applications


Allow setting the full request URL including a hostname / port

shevron opened this issue

I am testing a FastAPI application that actually cares about the request hostname and port. I find async-asgi-testclient much better suited to our needs than Starlette's TestClient or httpx; however, simulating requests with different hostnames and ports is currently quite hard.

It would be really cool if, instead of just a path, I could pass in a full URL including the hostname and port, and have these taken into account when constructing the request (e.g. via the Host header).
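When the URL carries an explicit port, the Host header normally has to carry it too, while clients commonly omit the port when it is the scheme's default. The sketch below is one hypothetical way to derive the Host value from a full URL; the helper name `host_header_for` and the default-port table are assumptions for illustration, not part of async-asgi-testclient.

```python
from typing import Optional
from urllib.parse import urlsplit

# Assumed default ports; a port equal to the scheme's default is left out.
DEFAULT_PORTS = {"http": 80, "ws": 80, "https": 443, "wss": 443}


def host_header_for(url: str) -> Optional[str]:
    """Return the Host header value implied by an absolute URL, or None."""
    parts = urlsplit(url)
    if not parts.hostname:
        return None  # a bare path such as "/items" carries no host
    host = parts.hostname  # note: urlsplit lowercases the hostname
    if ":" in host:
        host = f"[{host}]"  # IPv6 literals must stay bracketed in Host
    port = parts.port
    if port is not None and port != DEFAULT_PORTS.get(parts.scheme):
        host = f"{host}:{port}"
    return host


print(host_header_for("http://example.com:8000/items"))  # example.com:8000
print(host_header_for("https://example.com:443/items"))  # example.com
print(host_header_for("http://[::1]:8080/"))             # [::1]:8080
print(host_header_for("/items"))                         # None
```

Whether to drop default ports is a design choice; the override code in this issue always keeps an explicit port, which is also valid.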

I was able to work around this in my tests by subclassing TestClient and overriding open and websocket_connect, but it would be nice to have this as a built-in option (possibly also with a way to set a default base_url, which is a common option in other test clients).
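A default base_url can be combined with relative paths in more than one way, and the choice matters. The override code in this issue uses `urljoin`, which follows RFC 3986 reference resolution: without a trailing slash on the base, the last path segment is replaced, and a leading slash discards the base path entirely. Some clients instead append the request path to the base path. The sketch contrasts the two; `append_to_base` is a hypothetical helper, not an existing API.

```python
from urllib.parse import urljoin, urlsplit, urlunsplit


def join_rfc3986(base_url: str, path: str) -> str:
    # Standard reference resolution, as done by the workaround's urljoin call.
    return urljoin(base_url, path)


def append_to_base(base_url: str, path: str) -> str:
    # Hypothetical alternative: always keep the base path and append to it.
    if "://" in path:
        return path  # already an absolute URL; leave it alone
    base = urlsplit(base_url)
    rel = urlsplit(path)
    merged = base.path.rstrip("/") + "/" + rel.path.lstrip("/")
    return urlunsplit((base.scheme, base.netloc, merged, rel.query, rel.fragment))


base = "http://api.example.com:8000/v1"
print(join_rfc3986(base, "users"))            # http://api.example.com:8000/users
print(join_rfc3986(base + "/", "users"))      # http://api.example.com:8000/v1/users
print(join_rfc3986(base, "/users"))           # http://api.example.com:8000/users
print(append_to_base(base, "users"))          # http://api.example.com:8000/v1/users
print(append_to_base(base, "/users?page=2"))  # http://api.example.com:8000/v1/users?page=2
```

With `urljoin`, a base_url that includes a path prefix should therefore end with a slash, and request paths should be written without a leading slash.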

For reference, here is my overriding code:

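from typing import Any, Dict, Optional, Tuple
from urllib.parse import urljoin, urlsplit, urlunsplit

from async_asgi_testclient import TestClient as BaseTestClient


class TestClient(BaseTestClient):
    """TestClient that also accepts full URLs and an optional base_url."""

    base_url: Optional[str] = None

    async def open(self, path: str, **kwargs: Any):
        path, kwargs = self._fix_args(path, kwargs)
        return await super().open(path, **kwargs)

    def websocket_connect(self, path: str, *args, **kwargs):
        path, kwargs = self._fix_args(path, kwargs)
        # Drop the scheme before delegating to the base websocket_connect().
        kwargs.pop("scheme", None)  # TODO: deal with `wss://` connections somehow? - this is a separate issue...
        return super().websocket_connect(path, *args, **kwargs)

    def _fix_args(
        self, path: str, kwargs: Dict[str, Any]
    ) -> Tuple[str, Dict[str, Any]]:
        path, scheme, hostname = self._parse_path_or_url(path)
        # Copy so the caller's dict is not mutated; also tolerates headers=None.
        headers = dict(kwargs.get("headers") or {})
        # Header names are case-insensitive, so an explicit "Host" also wins.
        if hostname and not any(k.lower() == "host" for k in headers):
            headers["host"] = hostname
            kwargs["headers"] = headers
        if scheme:
            kwargs["scheme"] = scheme

        return path, kwargs

    def _parse_path_or_url(self, path_or_url: str) -> Tuple[str, str, Optional[str]]:
        # Resolve relative paths against base_url, if one is configured.
        if self.base_url and "://" not in path_or_url:
            path_or_url = urljoin(self.base_url, path_or_url)

        # Still a bare path: there is no host to set.
        if "://" not in path_or_url:
            return path_or_url, "https", None

        parts = urlsplit(path_or_url)
        scheme = parts.scheme
        hostname = parts.hostname
        # Keep an explicit port in the Host header value.
        if hostname and parts.port:
            hostname += ":" + str(parts.port)
        # Reduce the URL to path + query + fragment; an empty path becomes "/".
        path = urlunsplit(("", "", parts.path or "/", parts.query, parts.fragment))
        return path, scheme, hostname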
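To see what the override has to get right, it helps to look at what an ASGI application actually receives. Per the ASGI HTTP connection scope, the Host header arrives as a lowercase byte-string pair in `scope["headers"]` and the scheme as `scope["scheme"]`; Starlette (and therefore FastAPI) builds `request.url` from these fields. The sketch below drives a minimal, hypothetical `echo_host_app` by hand with asyncio, without any test client; `call_app` and the scope it builds are assumptions for illustration, not the library's internals.

```python
import asyncio
from typing import Any, Dict, List


async def echo_host_app(scope: Dict[str, Any], receive, send) -> None:
    # Minimal ASGI HTTP app: reply with "<scheme>://<host><path>".
    assert scope["type"] == "http"
    headers = dict(scope["headers"])
    host = headers.get(b"host", b"").decode("latin-1")
    body = f"{scope['scheme']}://{host}{scope['path']}".encode()
    await send({"type": "http.response.start", "status": 200,
                "headers": [(b"content-type", b"text/plain")]})
    await send({"type": "http.response.body", "body": body})


async def call_app(app, scheme: str, host: str, path: str) -> bytes:
    # Build roughly the scope a test client would construct for one GET request.
    scope = {
        "type": "http",
        "asgi": {"version": "3.0"},
        "http_version": "1.1",
        "method": "GET",
        "scheme": scheme,
        "path": path,
        "raw_path": path.encode(),
        "query_string": b"",
        "root_path": "",
        "headers": [(b"host", host.encode("latin-1"))],
        "client": ("127.0.0.1", 12345),
        "server": None,
    }
    sent: List[Dict[str, Any]] = []

    async def receive() -> Dict[str, Any]:
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message: Dict[str, Any]) -> None:
        sent.append(message)

    await app(scope, receive, send)
    # Concatenate all response body chunks the app sent.
    return b"".join(m.get("body", b"") for m in sent
                    if m["type"] == "http.response.body")


print(asyncio.run(call_app(echo_host_app, "https", "example.com:8443", "/items")))
# b'https://example.com:8443/items'
```

A test client that accepts full URLs essentially has to fill in `scheme` and the `host` header this way from the URL it is given.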