Fern: regenerated Python SDK (#3829)

This commit is contained in:
Stanislav Novosad
2025-10-27 16:26:37 -06:00
committed by GitHub
parent c12c047768
commit ba0b25cb4b
150 changed files with 8305 additions and 3701 deletions

View File

@@ -1,33 +1,91 @@
# This file was auto-generated by Fern from our API Definition.
from .api_error import ApiError
from .client_wrapper import AsyncClientWrapper, BaseClientWrapper, SyncClientWrapper
from .datetime_utils import serialize_datetime
from .file import File, convert_file_dict_to_httpx_tuples, with_content_type
from .http_client import AsyncHttpClient, HttpClient
from .jsonable_encoder import jsonable_encoder
from .pydantic_utilities import (
IS_PYDANTIC_V2,
UniversalBaseModel,
UniversalRootModel,
parse_obj_as,
universal_field_validator,
universal_root_validator,
update_forward_refs,
)
from .query_encoder import encode_query
from .remove_none_from_dict import remove_none_from_dict
from .request_options import RequestOptions
from .serialization import FieldMetadata, convert_and_respect_annotation_metadata
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from .api_error import ApiError
from .client_wrapper import AsyncClientWrapper, BaseClientWrapper, SyncClientWrapper
from .datetime_utils import serialize_datetime
from .file import File, convert_file_dict_to_httpx_tuples, with_content_type
from .http_client import AsyncHttpClient, HttpClient
from .http_response import AsyncHttpResponse, HttpResponse
from .jsonable_encoder import jsonable_encoder
from .pydantic_utilities import (
IS_PYDANTIC_V2,
UniversalBaseModel,
UniversalRootModel,
parse_obj_as,
universal_field_validator,
universal_root_validator,
update_forward_refs,
)
from .query_encoder import encode_query
from .remove_none_from_dict import remove_none_from_dict
from .request_options import RequestOptions
from .serialization import FieldMetadata, convert_and_respect_annotation_metadata
# Maps each lazily exported attribute name to the relative module that defines
# it. Consumed by the module-level __getattr__ below to defer submodule imports
# until first attribute access (PEP 562 lazy-import pattern).
_dynamic_imports: typing.Dict[str, str] = {
    "ApiError": ".api_error",
    "AsyncClientWrapper": ".client_wrapper",
    "AsyncHttpClient": ".http_client",
    "AsyncHttpResponse": ".http_response",
    "BaseClientWrapper": ".client_wrapper",
    "FieldMetadata": ".serialization",
    "File": ".file",
    "HttpClient": ".http_client",
    "HttpResponse": ".http_response",
    "IS_PYDANTIC_V2": ".pydantic_utilities",
    "RequestOptions": ".request_options",
    "SyncClientWrapper": ".client_wrapper",
    "UniversalBaseModel": ".pydantic_utilities",
    "UniversalRootModel": ".pydantic_utilities",
    "convert_and_respect_annotation_metadata": ".serialization",
    "convert_file_dict_to_httpx_tuples": ".file",
    "encode_query": ".query_encoder",
    "jsonable_encoder": ".jsonable_encoder",
    "parse_obj_as": ".pydantic_utilities",
    "remove_none_from_dict": ".remove_none_from_dict",
    "serialize_datetime": ".datetime_utils",
    "universal_field_validator": ".pydantic_utilities",
    "universal_root_validator": ".pydantic_utilities",
    "update_forward_refs": ".pydantic_utilities",
    "with_content_type": ".file",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Resolve ``attr_name`` lazily from ``_dynamic_imports`` (PEP 562).

    Imports the owning submodule on first access and returns either the
    submodule itself (when the attribute name matches the module name) or
    the attribute looked up inside it.
    """
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        loaded = import_module(module_name, __package__)
        # A mapping of "foo" -> ".foo" means the attribute *is* the submodule.
        return loaded if module_name == f".{attr_name}" else getattr(loaded, attr_name)
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """Expose the lazily importable names so dir() and tab-completion work."""
    # Iterating the dict yields its keys; sorted() materializes them.
    return sorted(_dynamic_imports)
__all__ = [
"ApiError",
"AsyncClientWrapper",
"AsyncHttpClient",
"AsyncHttpResponse",
"BaseClientWrapper",
"FieldMetadata",
"File",
"HttpClient",
"HttpResponse",
"IS_PYDANTIC_V2",
"RequestOptions",
"SyncClientWrapper",

View File

@@ -1,15 +1,23 @@
# This file was auto-generated by Fern from our API Definition.
import typing
from typing import Any, Dict, Optional
class ApiError(Exception):
    """Error raised for an unsuccessful API response.

    Captures the response headers, HTTP status code, and decoded body so
    callers can inspect exactly what the server returned.
    """

    # All fields are optional: an ApiError may be raised before a response
    # (e.g. transport failure), in which case everything is None.
    headers: Optional[Dict[str, str]]
    status_code: Optional[int]
    body: Any

    def __init__(
        self,
        *,
        headers: Optional[Dict[str, str]] = None,
        status_code: Optional[int] = None,
        body: Any = None,
    ) -> None:
        self.headers = headers
        self.status_code = status_code
        self.body = body

    def __str__(self) -> str:
        return f"headers: {self.headers}, status_code: {self.status_code}, body: {self.body}"

View File

@@ -1,9 +1,9 @@
# This file was auto-generated by Fern from our API Definition.
import typing
import httpx
from .http_client import HttpClient
from .http_client import AsyncHttpClient
from .http_client import AsyncHttpClient, HttpClient
class BaseClientWrapper:
@@ -11,26 +11,30 @@ class BaseClientWrapper:
self,
*,
api_key: typing.Optional[str] = None,
x_api_key: str,
headers: typing.Optional[typing.Dict[str, str]] = None,
base_url: str,
timeout: typing.Optional[float] = None,
):
self._api_key = api_key
self.x_api_key = x_api_key
self._headers = headers
self._base_url = base_url
self._timeout = timeout
def get_headers(self) -> typing.Dict[str, str]:
headers: typing.Dict[str, str] = {
"User-Agent": "skyvern/0.2.20",
"X-Fern-Language": "Python",
"X-Fern-SDK-Name": "skyvern",
"X-Fern-SDK-Version": "0.2.18",
"X-Fern-SDK-Version": "0.2.20",
**(self.get_custom_headers() or {}),
}
if self._api_key is not None:
headers["x-api-key"] = self._api_key
headers["x-api-key"] = self.x_api_key
return headers
def get_custom_headers(self) -> typing.Optional[typing.Dict[str, str]]:
return self._headers
def get_base_url(self) -> str:
return self._base_url
@@ -43,12 +47,12 @@ class SyncClientWrapper(BaseClientWrapper):
self,
*,
api_key: typing.Optional[str] = None,
x_api_key: str,
headers: typing.Optional[typing.Dict[str, str]] = None,
base_url: str,
timeout: typing.Optional[float] = None,
httpx_client: httpx.Client,
):
super().__init__(api_key=api_key, x_api_key=x_api_key, base_url=base_url, timeout=timeout)
super().__init__(api_key=api_key, headers=headers, base_url=base_url, timeout=timeout)
self.httpx_client = HttpClient(
httpx_client=httpx_client,
base_headers=self.get_headers,
@@ -62,12 +66,12 @@ class AsyncClientWrapper(BaseClientWrapper):
self,
*,
api_key: typing.Optional[str] = None,
x_api_key: str,
headers: typing.Optional[typing.Dict[str, str]] = None,
base_url: str,
timeout: typing.Optional[float] = None,
httpx_client: httpx.AsyncClient,
):
super().__init__(api_key=api_key, x_api_key=x_api_key, base_url=base_url, timeout=timeout)
super().__init__(api_key=api_key, headers=headers, base_url=base_url, timeout=timeout)
self.httpx_client = AsyncHttpClient(
httpx_client=httpx_client,
base_headers=self.get_headers,

View File

@@ -0,0 +1,18 @@
# This file was auto-generated by Fern from our API Definition.
from typing import Any, Dict
class ForceMultipartDict(Dict[str, Any]):
    """A dict that is truthy even when empty.

    Used to force multipart/form-data encoding in HTTP requests: an ordinary
    empty dict would evaluate to False and disable multipart handling, so this
    subclass overrides truthiness to always report True.
    """

    def __bool__(self) -> bool:
        # Always truthy regardless of contents — that is the whole point.
        return True


# Shared empty sentinel passed as the files argument to force multipart encoding.
FORCE_MULTIPART = ForceMultipartDict()

View File

@@ -2,7 +2,6 @@
import asyncio
import email.utils
import json
import re
import time
import typing
@@ -11,12 +10,13 @@ from contextlib import asynccontextmanager, contextmanager
from random import random
import httpx
from .file import File, convert_file_dict_to_httpx_tuples
from .force_multipart import FORCE_MULTIPART
from .jsonable_encoder import jsonable_encoder
from .query_encoder import encode_query
from .remove_none_from_dict import remove_none_from_dict
from .request_options import RequestOptions
from httpx._types import RequestFiles
INITIAL_RETRY_DELAY_SECONDS = 0.5
MAX_RETRY_DELAY_SECONDS = 10
@@ -180,11 +180,17 @@ class HttpClient:
json: typing.Optional[typing.Any] = None,
data: typing.Optional[typing.Any] = None,
content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
files: typing.Optional[typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]]] = None,
files: typing.Optional[
typing.Union[
typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
typing.List[typing.Tuple[str, File]],
]
] = None,
headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
request_options: typing.Optional[RequestOptions] = None,
retries: int = 2,
omit: typing.Optional[typing.Any] = None,
force_multipart: typing.Optional[bool] = None,
) -> httpx.Response:
base_url = self.get_base_url(base_url)
timeout = (
@@ -195,6 +201,15 @@ class HttpClient:
json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
request_files: typing.Optional[RequestFiles] = (
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
if (files is not None and files is not omit and isinstance(files, dict))
else None
)
if (request_files is None or len(request_files) == 0) and force_multipart:
request_files = FORCE_MULTIPART
response = self.httpx_client.request(
method=method,
url=urllib.parse.urljoin(f"{base_url}/", path),
@@ -227,11 +242,7 @@ class HttpClient:
json=json_body,
data=data_body,
content=content,
files=(
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
if (files is not None and files is not omit)
else None
),
files=request_files,
timeout=timeout,
)
@@ -266,11 +277,17 @@ class HttpClient:
json: typing.Optional[typing.Any] = None,
data: typing.Optional[typing.Any] = None,
content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
files: typing.Optional[typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]]] = None,
files: typing.Optional[
typing.Union[
typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
typing.List[typing.Tuple[str, File]],
]
] = None,
headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
request_options: typing.Optional[RequestOptions] = None,
retries: int = 2,
omit: typing.Optional[typing.Any] = None,
force_multipart: typing.Optional[bool] = None,
) -> typing.Iterator[httpx.Response]:
base_url = self.get_base_url(base_url)
timeout = (
@@ -279,6 +296,15 @@ class HttpClient:
else self.base_timeout()
)
request_files: typing.Optional[RequestFiles] = (
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
if (files is not None and files is not omit and isinstance(files, dict))
else None
)
if (request_files is None or len(request_files) == 0) and force_multipart:
request_files = FORCE_MULTIPART
json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
with self.httpx_client.stream(
@@ -313,11 +339,7 @@ class HttpClient:
json=json_body,
data=data_body,
content=content,
files=(
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
if (files is not None and files is not omit)
else None
),
files=request_files,
timeout=timeout,
) as stream:
yield stream
@@ -356,11 +378,17 @@ class AsyncHttpClient:
json: typing.Optional[typing.Any] = None,
data: typing.Optional[typing.Any] = None,
content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
files: typing.Optional[typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]]] = None,
files: typing.Optional[
typing.Union[
typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
typing.List[typing.Tuple[str, File]],
]
] = None,
headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
request_options: typing.Optional[RequestOptions] = None,
retries: int = 2,
omit: typing.Optional[typing.Any] = None,
force_multipart: typing.Optional[bool] = None,
) -> httpx.Response:
base_url = self.get_base_url(base_url)
timeout = (
@@ -369,6 +397,15 @@ class AsyncHttpClient:
else self.base_timeout()
)
request_files: typing.Optional[RequestFiles] = (
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
if (files is not None and files is not omit and isinstance(files, dict))
else None
)
if (request_files is None or len(request_files) == 0) and force_multipart:
request_files = FORCE_MULTIPART
json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
# Add the input to each of these and do None-safety checks
@@ -404,11 +441,7 @@ class AsyncHttpClient:
json=json_body,
data=data_body,
content=content,
files=(
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
if files is not None
else None
),
files=request_files,
timeout=timeout,
)
@@ -442,11 +475,17 @@ class AsyncHttpClient:
json: typing.Optional[typing.Any] = None,
data: typing.Optional[typing.Any] = None,
content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
files: typing.Optional[typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]]] = None,
files: typing.Optional[
typing.Union[
typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
typing.List[typing.Tuple[str, File]],
]
] = None,
headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
request_options: typing.Optional[RequestOptions] = None,
retries: int = 2,
omit: typing.Optional[typing.Any] = None,
force_multipart: typing.Optional[bool] = None,
) -> typing.AsyncIterator[httpx.Response]:
base_url = self.get_base_url(base_url)
timeout = (
@@ -455,6 +494,15 @@ class AsyncHttpClient:
else self.base_timeout()
)
request_files: typing.Optional[RequestFiles] = (
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
if (files is not None and files is not omit and isinstance(files, dict))
else None
)
if (request_files is None or len(request_files) == 0) and force_multipart:
request_files = FORCE_MULTIPART
json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
async with self.httpx_client.stream(
@@ -489,11 +537,7 @@ class AsyncHttpClient:
json=json_body,
data=data_body,
content=content,
files=(
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
if files is not None
else None
),
files=request_files,
timeout=timeout,
) as stream:
yield stream

View File

@@ -0,0 +1,55 @@
# This file was auto-generated by Fern from our API Definition.
from typing import Dict, Generic, TypeVar
import httpx
# Generic to represent the underlying type of the data wrapped by the HTTP response.
T = TypeVar("T")


class BaseHttpResponse:
    """Thin wrapper around an httpx response that exposes its headers."""

    _response: httpx.Response

    def __init__(self, response: httpx.Response):
        self._response = response

    @property
    def headers(self) -> Dict[str, str]:
        """The response headers, copied into a plain dict."""
        return dict(self._response.headers)
class HttpResponse(Generic[T], BaseHttpResponse):
    """Synchronous HTTP response wrapper carrying headers plus decoded data."""

    _data: T

    def __init__(self, response: httpx.Response, data: T):
        super().__init__(response)
        self._data = data

    @property
    def data(self) -> T:
        """The decoded response payload."""
        return self._data

    def close(self) -> None:
        """Close the underlying httpx response."""
        self._response.close()
class AsyncHttpResponse(Generic[T], BaseHttpResponse):
    """Asynchronous HTTP response wrapper carrying headers plus decoded data."""

    _data: T

    def __init__(self, response: httpx.Response, data: T):
        super().__init__(response)
        self._data = data

    @property
    def data(self) -> T:
        """The decoded response payload."""
        return self._data

    async def close(self) -> None:
        """Asynchronously close the underlying httpx response."""
        await self._response.aclose()

View File

@@ -0,0 +1,42 @@
# This file was auto-generated by Fern from our API Definition.
# isort: skip_file
import typing
from importlib import import_module
if typing.TYPE_CHECKING:
from ._api import EventSource, aconnect_sse, connect_sse
from ._exceptions import SSEError
from ._models import ServerSentEvent
# Attribute name -> relative module that defines it. Consumed by the
# module-level __getattr__ below to import the SSE helpers lazily (PEP 562).
_dynamic_imports: typing.Dict[str, str] = {
    "EventSource": "._api",
    "SSEError": "._exceptions",
    "ServerSentEvent": "._models",
    "aconnect_sse": "._api",
    "connect_sse": "._api",
}
def __getattr__(attr_name: str) -> typing.Any:
    """Lazily import SSE helpers on first attribute access (PEP 562)."""
    module_name = _dynamic_imports.get(attr_name)
    if module_name is None:
        raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
    try:
        module = import_module(module_name, __package__)
        # When the mapped module name is just ".<attr>", the attribute is the
        # submodule itself; otherwise look the attribute up inside it.
        if module_name != f".{attr_name}":
            return getattr(module, attr_name)
        return module
    except ImportError as e:
        raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
    except AttributeError as e:
        raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
def __dir__():
    """List the lazily importable SSE names for dir() and completion."""
    return sorted(_dynamic_imports)


__all__ = ["EventSource", "SSEError", "ServerSentEvent", "aconnect_sse", "connect_sse"]

View File

@@ -0,0 +1,112 @@
# This file was auto-generated by Fern from our API Definition.
import re
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Iterator, cast
import httpx
from ._decoders import SSEDecoder
from ._exceptions import SSEError
from ._models import ServerSentEvent
class EventSource:
    """Decodes an httpx streaming response as a sequence of server-sent events.

    Validates the Content-Type, then feeds the body line by line into an
    SSEDecoder, yielding each completed ServerSentEvent.
    """

    def __init__(self, response: httpx.Response) -> None:
        self._response = response

    def _check_content_type(self) -> None:
        # Compare only the media type, ignoring parameters such as ";charset=...".
        content_type = self._response.headers.get("content-type", "").partition(";")[0]
        if "text/event-stream" not in content_type:
            raise SSEError(
                f"Expected response header Content-Type to contain 'text/event-stream', got {content_type!r}"
            )

    def _get_charset(self) -> str:
        """Extract charset from Content-Type header, fallback to UTF-8."""
        content_type = self._response.headers.get("content-type", "")
        # Parse charset parameter using regex
        charset_match = re.search(r"charset=([^;\s]+)", content_type, re.IGNORECASE)
        if charset_match:
            # Charset values may be quoted, e.g. charset="utf-8".
            charset = charset_match.group(1).strip("\"'")
            # Validate that it's a known encoding
            try:
                # Test if the charset is valid by trying to encode/decode
                "test".encode(charset).decode(charset)
                return charset
            except (LookupError, UnicodeError):
                # If charset is invalid, fall back to UTF-8
                pass
        # Default to UTF-8 if no charset specified or invalid charset
        return "utf-8"

    @property
    def response(self) -> httpx.Response:
        # The underlying httpx response, exposed for status/header inspection.
        return self._response

    def iter_sse(self) -> Iterator[ServerSentEvent]:
        """Synchronously yield each ServerSentEvent decoded from the body."""
        self._check_content_type()
        decoder = SSEDecoder()
        charset = self._get_charset()
        # Raw bytes are buffered and split manually so that lines spanning
        # chunk boundaries are reassembled before decoding.
        buffer = ""
        for chunk in self._response.iter_bytes():
            # Decode chunk using detected charset
            text_chunk = chunk.decode(charset, errors="replace")
            buffer += text_chunk
            # Process complete lines
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                line = line.rstrip("\r")
                sse = decoder.decode(line)
                # when we reach a "\n\n" => line = ''
                # => decoder will attempt to return an SSE Event
                if sse is not None:
                    yield sse
        # Process any remaining data in buffer
        if buffer.strip():
            line = buffer.rstrip("\r")
            sse = decoder.decode(line)
            if sse is not None:
                yield sse

    async def aiter_sse(self) -> AsyncGenerator[ServerSentEvent, None]:
        """Asynchronously yield each ServerSentEvent decoded from the body."""
        self._check_content_type()
        decoder = SSEDecoder()
        # NOTE(review): this path relies on httpx's aiter_lines() for charset
        # handling and line splitting, unlike iter_sse above which decodes
        # manually via _get_charset — confirm the asymmetry is intentional.
        lines = cast(AsyncGenerator[str, None], self._response.aiter_lines())
        try:
            async for line in lines:
                line = line.rstrip("\n")
                sse = decoder.decode(line)
                if sse is not None:
                    yield sse
        finally:
            # Ensure the line generator is closed even if the consumer stops early.
            await lines.aclose()
@contextmanager
def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any) -> Iterator[EventSource]:
    """Open a streaming request on *client* and yield it as an EventSource."""
    request_headers = kwargs.pop("headers", {})
    # SSE endpoints require an event-stream Accept header and must not be cached.
    request_headers.update({"Accept": "text/event-stream", "Cache-Control": "no-store"})
    with client.stream(method, url, headers=request_headers, **kwargs) as response:
        yield EventSource(response)
@asynccontextmanager
async def aconnect_sse(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    **kwargs: Any,
) -> AsyncIterator[EventSource]:
    """Async variant of connect_sse: open a stream and yield an EventSource."""
    request_headers = kwargs.pop("headers", {})
    # SSE endpoints require an event-stream Accept header and must not be cached.
    request_headers.update({"Accept": "text/event-stream", "Cache-Control": "no-store"})
    async with client.stream(method, url, headers=request_headers, **kwargs) as response:
        yield EventSource(response)

View File

@@ -0,0 +1,61 @@
# This file was auto-generated by Fern from our API Definition.
from typing import List, Optional
from ._models import ServerSentEvent
class SSEDecoder:
    """Incremental, line-at-a-time decoder for the SSE wire format.

    Feed decode() one line per call; it returns a ServerSentEvent when a
    blank line terminates an accumulated event, and None otherwise.
    """

    def __init__(self) -> None:
        self._event = ""
        self._data: List[str] = []
        self._last_event_id = ""
        self._retry: Optional[int] = None

    def _dispatch_event(self) -> Optional[ServerSentEvent]:
        """Emit the buffered event, or None if nothing has accumulated."""
        if not (self._event or self._data or self._last_event_id or self._retry is not None):
            return None
        pending = ServerSentEvent(
            event=self._event,
            data="\n".join(self._data),
            id=self._last_event_id,
            retry=self._retry,
        )
        # NOTE: as per the SSE spec, last_event_id persists across events.
        self._event = ""
        self._data = []
        self._retry = None
        return pending

    def decode(self, line: str) -> Optional[ServerSentEvent]:
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
        if not line:
            # Blank line terminates the current event.
            return self._dispatch_event()
        if line.startswith(":"):
            # Lines beginning with a colon are comments and are discarded.
            return None
        fieldname, _, value = line.partition(":")
        # A single leading space after the colon is stripped per spec.
        if value.startswith(" "):
            value = value[1:]
        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id" and "\0" not in value:
            # IDs containing NUL are ignored per spec.
            self._last_event_id = value
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                # Non-integer retry values are ignored per spec.
                pass
        # All other field names are ignored.
        return None

View File

@@ -0,0 +1,7 @@
# This file was auto-generated by Fern from our API Definition.
import httpx
class SSEError(httpx.TransportError):
    """Raised when a response expected to be an SSE stream is not one."""

View File

@@ -0,0 +1,17 @@
# This file was auto-generated by Fern from our API Definition.
import json
from dataclasses import dataclass
from typing import Any, Optional
@dataclass(frozen=True)
class ServerSentEvent:
    """Immutable value object for a single decoded server-sent event."""

    # Field defaults follow the SSE stream-interpretation rules: an event
    # with no explicit type is a "message" event.
    event: str = "message"
    data: str = ""
    id: str = ""
    retry: Optional[int] = None

    def json(self) -> Any:
        """Parse the data field as JSON."""
        payload = self.data
        return json.loads(payload)

View File

@@ -17,7 +17,6 @@ from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Set, Union
import pydantic
from .datetime_utils import serialize_datetime
from .pydantic_utilities import (
IS_PYDANTIC_V2,

View File

@@ -2,90 +2,66 @@
# nopycln: file
import datetime as dt
import typing
from collections import defaultdict
import typing_extensions
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union, cast
import pydantic
from .datetime_utils import serialize_datetime
from .serialization import convert_and_respect_annotation_metadata
IS_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
if IS_PYDANTIC_V2:
# isort will try to reformat the comments on these imports, which breaks mypy
# isort: off
from pydantic.v1.datetime_parse import ( # type: ignore # pyright: ignore[reportMissingImports] # Pydantic v2
parse_date as parse_date,
)
from pydantic.v1.datetime_parse import ( # pyright: ignore[reportMissingImports] # Pydantic v2
parse_datetime as parse_datetime,
)
from pydantic.v1.json import ( # type: ignore # pyright: ignore[reportMissingImports] # Pydantic v2
ENCODERS_BY_TYPE as encoders_by_type,
)
from pydantic.v1.typing import ( # type: ignore # pyright: ignore[reportMissingImports] # Pydantic v2
get_args as get_args,
)
from pydantic.v1.typing import ( # pyright: ignore[reportMissingImports] # Pydantic v2
get_origin as get_origin,
)
from pydantic.v1.typing import ( # pyright: ignore[reportMissingImports] # Pydantic v2
is_literal_type as is_literal_type,
)
from pydantic.v1.typing import ( # pyright: ignore[reportMissingImports] # Pydantic v2
is_union as is_union,
)
from pydantic.v1.fields import ModelField as ModelField # type: ignore # pyright: ignore[reportMissingImports] # Pydantic v2
from pydantic.v1.datetime_parse import parse_date as parse_date
from pydantic.v1.datetime_parse import parse_datetime as parse_datetime
from pydantic.v1.fields import ModelField as ModelField
from pydantic.v1.json import ENCODERS_BY_TYPE as encoders_by_type # type: ignore[attr-defined]
from pydantic.v1.typing import get_args as get_args
from pydantic.v1.typing import get_origin as get_origin
from pydantic.v1.typing import is_literal_type as is_literal_type
from pydantic.v1.typing import is_union as is_union
else:
from pydantic.datetime_parse import parse_date as parse_date # type: ignore # Pydantic v1
from pydantic.datetime_parse import parse_datetime as parse_datetime # type: ignore # Pydantic v1
from pydantic.fields import ModelField as ModelField # type: ignore # Pydantic v1
from pydantic.json import ENCODERS_BY_TYPE as encoders_by_type # type: ignore # Pydantic v1
from pydantic.typing import get_args as get_args # type: ignore # Pydantic v1
from pydantic.typing import get_origin as get_origin # type: ignore # Pydantic v1
from pydantic.typing import is_literal_type as is_literal_type # type: ignore # Pydantic v1
from pydantic.typing import is_union as is_union # type: ignore # Pydantic v1
from pydantic.datetime_parse import parse_date as parse_date # type: ignore[no-redef]
from pydantic.datetime_parse import parse_datetime as parse_datetime # type: ignore[no-redef]
from pydantic.fields import ModelField as ModelField # type: ignore[attr-defined, no-redef]
from pydantic.json import ENCODERS_BY_TYPE as encoders_by_type # type: ignore[no-redef]
from pydantic.typing import get_args as get_args # type: ignore[no-redef]
from pydantic.typing import get_origin as get_origin # type: ignore[no-redef]
from pydantic.typing import is_literal_type as is_literal_type # type: ignore[no-redef]
from pydantic.typing import is_union as is_union # type: ignore[no-redef]
# isort: on
from .datetime_utils import serialize_datetime
from .serialization import convert_and_respect_annotation_metadata
from typing_extensions import TypeAlias
T = TypeVar("T")
Model = TypeVar("Model", bound=pydantic.BaseModel)
T = typing.TypeVar("T")
Model = typing.TypeVar("Model", bound=pydantic.BaseModel)
def parse_obj_as(type_: typing.Type[T], object_: typing.Any) -> T:
def parse_obj_as(type_: Type[T], object_: Any) -> T:
dealiased_object = convert_and_respect_annotation_metadata(object_=object_, annotation=type_, direction="read")
if IS_PYDANTIC_V2:
adapter = pydantic.TypeAdapter(type_) # type: ignore # Pydantic v2
adapter = pydantic.TypeAdapter(type_) # type: ignore[attr-defined]
return adapter.validate_python(dealiased_object)
else:
return pydantic.parse_obj_as(type_, dealiased_object)
return pydantic.parse_obj_as(type_, dealiased_object)
def to_jsonable_with_fallback(
obj: typing.Any, fallback_serializer: typing.Callable[[typing.Any], typing.Any]
) -> typing.Any:
def to_jsonable_with_fallback(obj: Any, fallback_serializer: Callable[[Any], Any]) -> Any:
if IS_PYDANTIC_V2:
from pydantic_core import to_jsonable_python
return to_jsonable_python(obj, fallback=fallback_serializer)
else:
return fallback_serializer(obj)
return fallback_serializer(obj)
class UniversalBaseModel(pydantic.BaseModel):
if IS_PYDANTIC_V2:
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(
model_config: ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict( # type: ignore[typeddict-unknown-key]
# Allow fields beginning with `model_` to be used in the model
protected_namespaces=(),
) # type: ignore # Pydantic v2
)
@pydantic.model_serializer(mode="wrap", when_used="json") # type: ignore # Pydantic v2
def serialize_model(self, handler: pydantic.SerializerFunctionWrapHandler) -> typing.Any: # type: ignore # Pydantic v2
serialized = handler(self)
@pydantic.model_serializer(mode="plain", when_used="json") # type: ignore[attr-defined]
def serialize_model(self) -> Any: # type: ignore[name-defined]
serialized = self.dict() # type: ignore[attr-defined]
data = {k: serialize_datetime(v) if isinstance(v, dt.datetime) else v for k, v in serialized.items()}
return data
@@ -96,34 +72,28 @@ class UniversalBaseModel(pydantic.BaseModel):
json_encoders = {dt.datetime: serialize_datetime}
@classmethod
def model_construct(
cls: typing.Type["Model"], _fields_set: typing.Optional[typing.Set[str]] = None, **values: typing.Any
) -> "Model":
def model_construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None, **values: Any) -> "Model":
dealiased_object = convert_and_respect_annotation_metadata(object_=values, annotation=cls, direction="read")
return cls.construct(_fields_set, **dealiased_object)
@classmethod
def construct(
cls: typing.Type["Model"], _fields_set: typing.Optional[typing.Set[str]] = None, **values: typing.Any
) -> "Model":
def construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None, **values: Any) -> "Model":
dealiased_object = convert_and_respect_annotation_metadata(object_=values, annotation=cls, direction="read")
if IS_PYDANTIC_V2:
return super().model_construct(_fields_set, **dealiased_object) # type: ignore # Pydantic v2
else:
return super().construct(_fields_set, **dealiased_object)
return super().model_construct(_fields_set, **dealiased_object) # type: ignore[misc]
return super().construct(_fields_set, **dealiased_object)
def json(self, **kwargs: typing.Any) -> str:
kwargs_with_defaults: typing.Any = {
def json(self, **kwargs: Any) -> str:
kwargs_with_defaults = {
"by_alias": True,
"exclude_unset": True,
**kwargs,
}
if IS_PYDANTIC_V2:
return super().model_dump_json(**kwargs_with_defaults) # type: ignore # Pydantic v2
else:
return super().json(**kwargs_with_defaults)
return super().model_dump_json(**kwargs_with_defaults) # type: ignore[misc]
return super().json(**kwargs_with_defaults)
def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]:
def dict(self, **kwargs: Any) -> Dict[str, Any]:
"""
Override the default dict method to `exclude_unset` by default. This function patches
`exclude_unset` to work include fields within non-None default values.
@@ -134,21 +104,21 @@ class UniversalBaseModel(pydantic.BaseModel):
# We'd ideally do the same for Pydantic V2, but it shells out to a library to serialize models
# that we have less control over, and this is less intrusive than custom serializers for now.
if IS_PYDANTIC_V2:
kwargs_with_defaults_exclude_unset: typing.Any = {
kwargs_with_defaults_exclude_unset = {
**kwargs,
"by_alias": True,
"exclude_unset": True,
"exclude_none": False,
}
kwargs_with_defaults_exclude_none: typing.Any = {
kwargs_with_defaults_exclude_none = {
**kwargs,
"by_alias": True,
"exclude_none": True,
"exclude_unset": False,
}
dict_dump = deep_union_pydantic_dicts(
super().model_dump(**kwargs_with_defaults_exclude_unset), # type: ignore # Pydantic v2
super().model_dump(**kwargs_with_defaults_exclude_none), # type: ignore # Pydantic v2
super().model_dump(**kwargs_with_defaults_exclude_unset), # type: ignore[misc]
super().model_dump(**kwargs_with_defaults_exclude_none), # type: ignore[misc]
)
else:
@@ -168,7 +138,7 @@ class UniversalBaseModel(pydantic.BaseModel):
if default is not None:
self.__fields_set__.add(name)
kwargs_with_defaults_exclude_unset_include_fields: typing.Any = {
kwargs_with_defaults_exclude_unset_include_fields = {
"by_alias": True,
"exclude_unset": True,
"include": _fields_set,
@@ -177,15 +147,16 @@ class UniversalBaseModel(pydantic.BaseModel):
dict_dump = super().dict(**kwargs_with_defaults_exclude_unset_include_fields)
return convert_and_respect_annotation_metadata(object_=dict_dump, annotation=self.__class__, direction="write")
return cast(
Dict[str, Any],
convert_and_respect_annotation_metadata(object_=dict_dump, annotation=self.__class__, direction="write"),
)
def _union_list_of_pydantic_dicts(
source: typing.List[typing.Any], destination: typing.List[typing.Any]
) -> typing.List[typing.Any]:
converted_list: typing.List[typing.Any] = []
def _union_list_of_pydantic_dicts(source: List[Any], destination: List[Any]) -> List[Any]:
converted_list: List[Any] = []
for i, item in enumerate(source):
destination_value = destination[i] # type: ignore
destination_value = destination[i]
if isinstance(item, dict):
converted_list.append(deep_union_pydantic_dicts(item, destination_value))
elif isinstance(item, list):
@@ -195,9 +166,7 @@ def _union_list_of_pydantic_dicts(
return converted_list
def deep_union_pydantic_dicts(
source: typing.Dict[str, typing.Any], destination: typing.Dict[str, typing.Any]
) -> typing.Dict[str, typing.Any]:
def deep_union_pydantic_dicts(source: Dict[str, Any], destination: Dict[str, Any]) -> Dict[str, Any]:
for key, value in source.items():
node = destination.setdefault(key, {})
if isinstance(value, dict):
@@ -215,18 +184,16 @@ def deep_union_pydantic_dicts(
if IS_PYDANTIC_V2:

    class V2RootModel(UniversalBaseModel, pydantic.RootModel):  # type: ignore[misc, name-defined, type-arg]
        """Root-model base for Pydantic v2 that keeps this module's serialization behavior."""

    # On v2, custom-root models must derive from pydantic.RootModel.
    UniversalRootModel: TypeAlias = V2RootModel  # type: ignore[misc]
else:
    # Pydantic v1 has no RootModel; __root__ models derive from the base model directly.
    UniversalRootModel: TypeAlias = UniversalBaseModel  # type: ignore[misc, no-redef]
def encode_by_type(o: typing.Any) -> typing.Any:
encoders_by_class_tuples: typing.Dict[typing.Callable[[typing.Any], typing.Any], typing.Tuple[typing.Any, ...]] = (
defaultdict(tuple)
)
def encode_by_type(o: Any) -> Any:
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(tuple)
for type_, encoder in encoders_by_type.items():
encoders_by_class_tuples[encoder] += (type_,)
@@ -237,54 +204,49 @@ def encode_by_type(o: typing.Any) -> typing.Any:
return encoder(o)
def update_forward_refs(model: Type["Model"], **localns: Any) -> None:
    """Resolve *model*'s forward references under either Pydantic major version.

    Args:
        model: The model class whose forward references should be rebuilt.
        **localns: Extra local names for resolution — used by Pydantic v1's
            ``update_forward_refs`` only; v2's ``model_rebuild`` resolves from
            the class's own namespaces and ignores this argument.
    """
    if IS_PYDANTIC_V2:
        # raise_errors=False: still-unresolved references are tolerated here
        # and retried lazily on first model use.
        model.model_rebuild(raise_errors=False)  # type: ignore[attr-defined]
    else:
        model.update_forward_refs(**localns)
# Mirrors Pydantic's internal typing: the loosest callable type, used by the
# validator-decorator shims below.
AnyCallable = Callable[..., Any]
def universal_root_validator(
    pre: bool = False,
) -> Callable[[AnyCallable], AnyCallable]:
    """Version-agnostic root-validator decorator factory.

    Maps to ``pydantic.model_validator`` on v2 and ``pydantic.root_validator``
    on v1, translating ``pre`` into v2's ``mode`` argument.

    Args:
        pre: When True, run before field validation (v2 ``mode="before"``,
            v1 ``pre=True``); otherwise run after.

    Returns:
        A decorator that wraps *func* in the appropriate Pydantic validator.
    """

    def decorator(func: AnyCallable) -> AnyCallable:
        if IS_PYDANTIC_V2:
            return cast(AnyCallable, pydantic.model_validator(mode="before" if pre else "after")(func))  # type: ignore[attr-defined]
        return cast(AnyCallable, pydantic.root_validator(pre=pre)(func))  # type: ignore[call-overload]

    return decorator
def universal_field_validator(field_name: str, pre: bool = False) -> Callable[[AnyCallable], AnyCallable]:
    """Version-agnostic field-validator decorator factory.

    Maps to ``pydantic.field_validator`` on v2 and ``pydantic.validator`` on
    v1, translating ``pre`` into v2's ``mode`` argument.

    Args:
        field_name: Name of the model field the validator applies to.
        pre: When True, run before field coercion (v2 ``mode="before"``,
            v1 ``pre=True``); otherwise run after.

    Returns:
        A decorator that wraps *func* in the appropriate Pydantic validator.
    """

    def decorator(func: AnyCallable) -> AnyCallable:
        if IS_PYDANTIC_V2:
            return cast(AnyCallable, pydantic.field_validator(field_name, mode="before" if pre else "after")(func))  # type: ignore[attr-defined]
        return cast(AnyCallable, pydantic.validator(field_name, pre=pre)(func))

    return decorator
# Either a Pydantic v1 ModelField or a v2 FieldInfo — the two per-field
# representations the helpers below must accept.
PydanticField = Union[ModelField, pydantic.fields.FieldInfo]
def _get_model_fields(model: Type["Model"]) -> Mapping[str, PydanticField]:
    """Return *model*'s field mapping regardless of Pydantic major version.

    Args:
        model: The model class to inspect.

    Returns:
        Mapping of field name to the version-specific field object
        (v2 ``model_fields`` / v1 ``__fields__``).
    """
    if IS_PYDANTIC_V2:
        return cast(Mapping[str, PydanticField], model.model_fields)  # type: ignore[attr-defined]
    return cast(Mapping[str, PydanticField], model.__fields__)
def _get_field_default(field: PydanticField) -> typing.Any:
def _get_field_default(field: PydanticField) -> Any:
try:
value = field.get_default() # type: ignore # Pydantic < v1.10.15
value = field.get_default() # type: ignore[union-attr]
except:
value = field.default
if IS_PYDANTIC_V2:

View File

@@ -4,9 +4,8 @@ import collections
import inspect
import typing
import typing_extensions
import pydantic
import typing_extensions
class FieldMetadata:
@@ -161,7 +160,12 @@ def _convert_mapping(
direction: typing.Literal["read", "write"],
) -> typing.Mapping[str, object]:
converted_object: typing.Dict[str, object] = {}
annotations = typing_extensions.get_type_hints(expected_type, include_extras=True)
try:
annotations = typing_extensions.get_type_hints(expected_type, include_extras=True)
except NameError:
# The TypedDict contains a circular reference, so
# we use the __annotations__ attribute directly.
annotations = getattr(expected_type, "__annotations__", {})
aliases_to_field_names = _get_alias_to_field_name(annotations)
for key, value in object_.items():
if direction == "read" and key in aliases_to_field_names: