Annotated[schema(description=...)] erased when using Generic types
thomascobb opened this issue · comments
#280 went back to using `get_type_hints`, but it appears to have lost the `description` annotations on Generic classes. Here's a snippet that works on 0.16.4:
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Generic, TypeVar
from apischema import deserializer, schema, serializer
from apischema.conversions import Conversion
from apischema.graphql import graphql_schema
from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged
from apischema.type_names import type_name
from graphql import print_schema
from typing_extensions import Annotated
# Type parameter shared by the generic classes and tagged unions below.
T = TypeVar("T")
# Type variable restricted to classes, so class decorators keep the exact
# decorated class type.
Cls = TypeVar("Cls", bound=type)


def _bare_class_name(cls, *args):
    # Naming callable: ignores any type arguments and keeps only the class name.
    return cls.__name__


# apischema type-name registrar naming every generic class after its bare name.
generic_name = type_name(_bare_class_name)
def _register_generic_names(cls: type) -> None:
    """Apply ``generic_name`` to ``cls`` and to each of its direct subclasses.

    ``generic_name``'s naming callable ignores the type arguments, so every
    parametrization of these classes should be named after the bare class.
    """
    generic_name(cls)
    for sub in cls.__subclasses__():
        generic_name(sub)


def _generic_tagged_union(cls: type) -> Any:
    """Create a ``TaggedUnion`` subclass, generic in ``T``, mirroring ``cls``.

    The new class is named after ``cls`` and gets one ``Tagged[sub[T]]`` field
    per direct subclass of ``cls``, keyed by the subclass ``__name__``.
    Parametrizing every alternative with ``T`` lets the union's own type
    argument be substituted into the subclasses' fields.
    """
    namespace: Dict[str, Any] = {
        "__annotations__": {
            sub.__name__: Tagged[sub[T]] for sub in cls.__subclasses__()
        }
    }
    return new_class(
        cls.__name__,
        (TaggedUnion, Generic[T]),
        exec_body=lambda ns: ns.update(namespace),
    )


def as_tagged_union(cls: Cls) -> Cls:
    """Register conversions that (de)serialize ``cls`` as a tagged union.

    Each direct subclass of ``cls`` becomes one alternative of a generated
    ``TaggedUnion``, tagged with the subclass ``__name__``. Both conversions
    are registered as ``lazy=`` factories. Presumably this means
    ``cls.__subclasses__()`` is only read once every subclass has been defined.

    Serialization and deserialization build the union the same way, with
    ``Tagged[sub[T]]`` alternatives. The previous serialization side used the
    unparametrized ``Tagged[sub]``, which dropped the type argument.

    Args:
        cls: generic class (parametrizable as ``cls[T]``) to decorate.

    Returns:
        ``cls`` itself, so the function can be used as a class decorator.
    """

    def serialization() -> Conversion:
        _register_generic_names(cls)
        tagged_union = _generic_tagged_union(cls)
        return Conversion(
            # Wrap the instance in the union under its own class name.
            lambda obj: tagged_union(**{obj.__class__.__name__: obj}),
            source=cls[T],
            target=tagged_union[T],
            # Conversion must not be inherited because it would lead to
            # infinite recursion otherwise
            inherited=False,
        )

    def deserialization() -> Conversion:
        _register_generic_names(cls)
        tagged_union = _generic_tagged_union(cls)
        return Conversion(
            # get_tagged returns a (tag, value) pair; keep only the value.
            lambda obj: get_tagged(obj)[1],
            source=tagged_union[T],
            target=cls[T],
        )

    deserializer(lazy=deserialization, target=cls)
    serializer(lazy=serialization, source=cls)
    return cls
# Generic base class; its direct subclasses are exposed as the alternatives
# of a tagged union by the ``as_tagged_union`` decorator.
@as_tagged_union
class Base(Generic[T]):
    bar: T  # payload shared by every variant
# ``Foo`` variant of ``Base``; the ``schema`` description on ``bar`` is
# expected to appear as a field description in the generated schema.
@dataclass
class Foo(Base[T]):
    bar: Annotated[T, schema(description="A Foo bar")]
# ``Bat`` variant of ``Base``; it carries the same ``bar`` description as
# ``Foo``, and the expected schema in the test relies on that exact text.
@dataclass
class Bat(Base[T]):
    bar: Annotated[T, schema(description="A Foo bar")]
# GraphQL query resolver: returns the ``bar`` payload of whichever ``Base``
# variant was received.
# (Comments rather than a docstring, so nothing can leak into the schema.)
def get_bar(base: Base[str]) -> str:
    payload = base.bar
    return payload
def test_fails():
    """Check that ``schema(description=...)`` survives on generic subclasses.

    The expected text reproduces ``print_schema``'s layout exactly: fields are
    indented by two spaces and type definitions are separated by a blank line.
    Without that layout the literal could never match the printed schema.
    """
    expected = '''\
type Query {
  getBar(base: BaseInput!): String!
}

input BaseInput {
  Foo: FooInput
  Bat: BatInput
}

input FooInput {
  """A Foo bar"""
  bar: String!
}

input BatInput {
  """A Foo bar"""
  bar: String!
}
'''
    assert print_schema(graphql_schema(query=[get_bar])) == expected
On 0.16.5 it fails with:
Traceback (most recent call last):
File "/dls/science/users/tmc43/common/python/scanspec/tests/test_fails.py", line 103, in test_fails
assert print_schema(graphql_schema(query=[get_bar])) == expected
AssertionError: assert ('type Query {\n'\n ' getBar(base: BaseInput!): String!\n'\n '}\n'\n '\n'\n 'input BaseInput {\n'\n ' Foo: FooInput\n'\n ' Bat: BatInput\n'\n '}\n'\n '\n'\n 'input FooInput {\n'\n ' bar: String!\n'\n '}\n'\n '\n'\n 'input BatInput {\n'\n ' bar: String!\n'\n '}\n') == ('type Query {\n'\n ' getBar(base: BaseInput!): String!\n'\n '}\n'\n '\n'\n 'input BaseInput {\n'\n ' Foo: FooInput\n'\n ' Bat: BatInput\n'\n '}\n'\n '\n'\n 'input FooInput {\n'\n ' """A Foo bar"""\n'\n ' bar: String!\n'\n '}\n'\n '\n'\n 'input BatInput {\n'\n ' """A Foo bar"""\n'\n ' bar: String!\n'\n '}\n')
type Query {
getBar(base: BaseInput!): String!
}
input BaseInput {
Foo: FooInput
Bat: BatInput
}
input FooInput {
- """A Foo bar"""
bar: String!
}
input BatInput {
- """A Foo bar"""
bar: String!
}
Shame on me… I forgot one `include_extras=True` in #280.
I will fix it tonight and release a new patch (version 0.16.5 should be deleted because it's buggy).
Sorry for the inconvenience; I hope you can still work with v0.16.4 for now.
Thanks for the quick response; 0.16.4 will be fine for me. I only noticed because I have a weekly job on one of my projects that tests against the latest version of all its dependencies...
Thanks, all working now