Fern: regenerated Python SDK (#3829)
This commit is contained in:
committed by
GitHub
parent
c12c047768
commit
ba0b25cb4b
@@ -1,33 +1,91 @@
|
||||
# This file was auto-generated by Fern from our API Definition.
|
||||
|
||||
from .api_error import ApiError
|
||||
from .client_wrapper import AsyncClientWrapper, BaseClientWrapper, SyncClientWrapper
|
||||
from .datetime_utils import serialize_datetime
|
||||
from .file import File, convert_file_dict_to_httpx_tuples, with_content_type
|
||||
from .http_client import AsyncHttpClient, HttpClient
|
||||
from .jsonable_encoder import jsonable_encoder
|
||||
from .pydantic_utilities import (
|
||||
IS_PYDANTIC_V2,
|
||||
UniversalBaseModel,
|
||||
UniversalRootModel,
|
||||
parse_obj_as,
|
||||
universal_field_validator,
|
||||
universal_root_validator,
|
||||
update_forward_refs,
|
||||
)
|
||||
from .query_encoder import encode_query
|
||||
from .remove_none_from_dict import remove_none_from_dict
|
||||
from .request_options import RequestOptions
|
||||
from .serialization import FieldMetadata, convert_and_respect_annotation_metadata
|
||||
# isort: skip_file
|
||||
|
||||
import typing
|
||||
from importlib import import_module
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .api_error import ApiError
|
||||
from .client_wrapper import AsyncClientWrapper, BaseClientWrapper, SyncClientWrapper
|
||||
from .datetime_utils import serialize_datetime
|
||||
from .file import File, convert_file_dict_to_httpx_tuples, with_content_type
|
||||
from .http_client import AsyncHttpClient, HttpClient
|
||||
from .http_response import AsyncHttpResponse, HttpResponse
|
||||
from .jsonable_encoder import jsonable_encoder
|
||||
from .pydantic_utilities import (
|
||||
IS_PYDANTIC_V2,
|
||||
UniversalBaseModel,
|
||||
UniversalRootModel,
|
||||
parse_obj_as,
|
||||
universal_field_validator,
|
||||
universal_root_validator,
|
||||
update_forward_refs,
|
||||
)
|
||||
from .query_encoder import encode_query
|
||||
from .remove_none_from_dict import remove_none_from_dict
|
||||
from .request_options import RequestOptions
|
||||
from .serialization import FieldMetadata, convert_and_respect_annotation_metadata
|
||||
_dynamic_imports: typing.Dict[str, str] = {
|
||||
"ApiError": ".api_error",
|
||||
"AsyncClientWrapper": ".client_wrapper",
|
||||
"AsyncHttpClient": ".http_client",
|
||||
"AsyncHttpResponse": ".http_response",
|
||||
"BaseClientWrapper": ".client_wrapper",
|
||||
"FieldMetadata": ".serialization",
|
||||
"File": ".file",
|
||||
"HttpClient": ".http_client",
|
||||
"HttpResponse": ".http_response",
|
||||
"IS_PYDANTIC_V2": ".pydantic_utilities",
|
||||
"RequestOptions": ".request_options",
|
||||
"SyncClientWrapper": ".client_wrapper",
|
||||
"UniversalBaseModel": ".pydantic_utilities",
|
||||
"UniversalRootModel": ".pydantic_utilities",
|
||||
"convert_and_respect_annotation_metadata": ".serialization",
|
||||
"convert_file_dict_to_httpx_tuples": ".file",
|
||||
"encode_query": ".query_encoder",
|
||||
"jsonable_encoder": ".jsonable_encoder",
|
||||
"parse_obj_as": ".pydantic_utilities",
|
||||
"remove_none_from_dict": ".remove_none_from_dict",
|
||||
"serialize_datetime": ".datetime_utils",
|
||||
"universal_field_validator": ".pydantic_utilities",
|
||||
"universal_root_validator": ".pydantic_utilities",
|
||||
"update_forward_refs": ".pydantic_utilities",
|
||||
"with_content_type": ".file",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(attr_name: str) -> typing.Any:
|
||||
module_name = _dynamic_imports.get(attr_name)
|
||||
if module_name is None:
|
||||
raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
|
||||
try:
|
||||
module = import_module(module_name, __package__)
|
||||
if module_name == f".{attr_name}":
|
||||
return module
|
||||
else:
|
||||
return getattr(module, attr_name)
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
|
||||
except AttributeError as e:
|
||||
raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
|
||||
|
||||
|
||||
def __dir__():
|
||||
lazy_attrs = list(_dynamic_imports.keys())
|
||||
return sorted(lazy_attrs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ApiError",
|
||||
"AsyncClientWrapper",
|
||||
"AsyncHttpClient",
|
||||
"AsyncHttpResponse",
|
||||
"BaseClientWrapper",
|
||||
"FieldMetadata",
|
||||
"File",
|
||||
"HttpClient",
|
||||
"HttpResponse",
|
||||
"IS_PYDANTIC_V2",
|
||||
"RequestOptions",
|
||||
"SyncClientWrapper",
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
# This file was auto-generated by Fern from our API Definition.
|
||||
|
||||
import typing
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class ApiError(Exception):
|
||||
status_code: typing.Optional[int]
|
||||
body: typing.Any
|
||||
headers: Optional[Dict[str, str]]
|
||||
status_code: Optional[int]
|
||||
body: Any
|
||||
|
||||
def __init__(self, *, status_code: typing.Optional[int] = None, body: typing.Any = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
status_code: Optional[int] = None,
|
||||
body: Any = None,
|
||||
) -> None:
|
||||
self.headers = headers
|
||||
self.status_code = status_code
|
||||
self.body = body
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"status_code: {self.status_code}, body: {self.body}"
|
||||
return f"headers: {self.headers}, status_code: {self.status_code}, body: {self.body}"
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# This file was auto-generated by Fern from our API Definition.
|
||||
|
||||
import typing
|
||||
|
||||
import httpx
|
||||
from .http_client import HttpClient
|
||||
from .http_client import AsyncHttpClient
|
||||
from .http_client import AsyncHttpClient, HttpClient
|
||||
|
||||
|
||||
class BaseClientWrapper:
|
||||
@@ -11,26 +11,30 @@ class BaseClientWrapper:
|
||||
self,
|
||||
*,
|
||||
api_key: typing.Optional[str] = None,
|
||||
x_api_key: str,
|
||||
headers: typing.Optional[typing.Dict[str, str]] = None,
|
||||
base_url: str,
|
||||
timeout: typing.Optional[float] = None,
|
||||
):
|
||||
self._api_key = api_key
|
||||
self.x_api_key = x_api_key
|
||||
self._headers = headers
|
||||
self._base_url = base_url
|
||||
self._timeout = timeout
|
||||
|
||||
def get_headers(self) -> typing.Dict[str, str]:
|
||||
headers: typing.Dict[str, str] = {
|
||||
"User-Agent": "skyvern/0.2.20",
|
||||
"X-Fern-Language": "Python",
|
||||
"X-Fern-SDK-Name": "skyvern",
|
||||
"X-Fern-SDK-Version": "0.2.18",
|
||||
"X-Fern-SDK-Version": "0.2.20",
|
||||
**(self.get_custom_headers() or {}),
|
||||
}
|
||||
if self._api_key is not None:
|
||||
headers["x-api-key"] = self._api_key
|
||||
headers["x-api-key"] = self.x_api_key
|
||||
return headers
|
||||
|
||||
def get_custom_headers(self) -> typing.Optional[typing.Dict[str, str]]:
|
||||
return self._headers
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return self._base_url
|
||||
|
||||
@@ -43,12 +47,12 @@ class SyncClientWrapper(BaseClientWrapper):
|
||||
self,
|
||||
*,
|
||||
api_key: typing.Optional[str] = None,
|
||||
x_api_key: str,
|
||||
headers: typing.Optional[typing.Dict[str, str]] = None,
|
||||
base_url: str,
|
||||
timeout: typing.Optional[float] = None,
|
||||
httpx_client: httpx.Client,
|
||||
):
|
||||
super().__init__(api_key=api_key, x_api_key=x_api_key, base_url=base_url, timeout=timeout)
|
||||
super().__init__(api_key=api_key, headers=headers, base_url=base_url, timeout=timeout)
|
||||
self.httpx_client = HttpClient(
|
||||
httpx_client=httpx_client,
|
||||
base_headers=self.get_headers,
|
||||
@@ -62,12 +66,12 @@ class AsyncClientWrapper(BaseClientWrapper):
|
||||
self,
|
||||
*,
|
||||
api_key: typing.Optional[str] = None,
|
||||
x_api_key: str,
|
||||
headers: typing.Optional[typing.Dict[str, str]] = None,
|
||||
base_url: str,
|
||||
timeout: typing.Optional[float] = None,
|
||||
httpx_client: httpx.AsyncClient,
|
||||
):
|
||||
super().__init__(api_key=api_key, x_api_key=x_api_key, base_url=base_url, timeout=timeout)
|
||||
super().__init__(api_key=api_key, headers=headers, base_url=base_url, timeout=timeout)
|
||||
self.httpx_client = AsyncHttpClient(
|
||||
httpx_client=httpx_client,
|
||||
base_headers=self.get_headers,
|
||||
|
||||
18
skyvern/client/core/force_multipart.py
Normal file
18
skyvern/client/core/force_multipart.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# This file was auto-generated by Fern from our API Definition.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class ForceMultipartDict(Dict[str, Any]):
|
||||
"""
|
||||
A dictionary subclass that always evaluates to True in boolean contexts.
|
||||
|
||||
This is used to force multipart/form-data encoding in HTTP requests even when
|
||||
the dictionary is empty, which would normally evaluate to False.
|
||||
"""
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
FORCE_MULTIPART = ForceMultipartDict()
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import asyncio
|
||||
import email.utils
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import typing
|
||||
@@ -11,12 +10,13 @@ from contextlib import asynccontextmanager, contextmanager
|
||||
from random import random
|
||||
|
||||
import httpx
|
||||
|
||||
from .file import File, convert_file_dict_to_httpx_tuples
|
||||
from .force_multipart import FORCE_MULTIPART
|
||||
from .jsonable_encoder import jsonable_encoder
|
||||
from .query_encoder import encode_query
|
||||
from .remove_none_from_dict import remove_none_from_dict
|
||||
from .request_options import RequestOptions
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
INITIAL_RETRY_DELAY_SECONDS = 0.5
|
||||
MAX_RETRY_DELAY_SECONDS = 10
|
||||
@@ -180,11 +180,17 @@ class HttpClient:
|
||||
json: typing.Optional[typing.Any] = None,
|
||||
data: typing.Optional[typing.Any] = None,
|
||||
content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
|
||||
files: typing.Optional[typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]]] = None,
|
||||
files: typing.Optional[
|
||||
typing.Union[
|
||||
typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
|
||||
typing.List[typing.Tuple[str, File]],
|
||||
]
|
||||
] = None,
|
||||
headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
request_options: typing.Optional[RequestOptions] = None,
|
||||
retries: int = 2,
|
||||
omit: typing.Optional[typing.Any] = None,
|
||||
force_multipart: typing.Optional[bool] = None,
|
||||
) -> httpx.Response:
|
||||
base_url = self.get_base_url(base_url)
|
||||
timeout = (
|
||||
@@ -195,6 +201,15 @@ class HttpClient:
|
||||
|
||||
json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
|
||||
|
||||
request_files: typing.Optional[RequestFiles] = (
|
||||
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
|
||||
if (files is not None and files is not omit and isinstance(files, dict))
|
||||
else None
|
||||
)
|
||||
|
||||
if (request_files is None or len(request_files) == 0) and force_multipart:
|
||||
request_files = FORCE_MULTIPART
|
||||
|
||||
response = self.httpx_client.request(
|
||||
method=method,
|
||||
url=urllib.parse.urljoin(f"{base_url}/", path),
|
||||
@@ -227,11 +242,7 @@ class HttpClient:
|
||||
json=json_body,
|
||||
data=data_body,
|
||||
content=content,
|
||||
files=(
|
||||
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
|
||||
if (files is not None and files is not omit)
|
||||
else None
|
||||
),
|
||||
files=request_files,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
@@ -266,11 +277,17 @@ class HttpClient:
|
||||
json: typing.Optional[typing.Any] = None,
|
||||
data: typing.Optional[typing.Any] = None,
|
||||
content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
|
||||
files: typing.Optional[typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]]] = None,
|
||||
files: typing.Optional[
|
||||
typing.Union[
|
||||
typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
|
||||
typing.List[typing.Tuple[str, File]],
|
||||
]
|
||||
] = None,
|
||||
headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
request_options: typing.Optional[RequestOptions] = None,
|
||||
retries: int = 2,
|
||||
omit: typing.Optional[typing.Any] = None,
|
||||
force_multipart: typing.Optional[bool] = None,
|
||||
) -> typing.Iterator[httpx.Response]:
|
||||
base_url = self.get_base_url(base_url)
|
||||
timeout = (
|
||||
@@ -279,6 +296,15 @@ class HttpClient:
|
||||
else self.base_timeout()
|
||||
)
|
||||
|
||||
request_files: typing.Optional[RequestFiles] = (
|
||||
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
|
||||
if (files is not None and files is not omit and isinstance(files, dict))
|
||||
else None
|
||||
)
|
||||
|
||||
if (request_files is None or len(request_files) == 0) and force_multipart:
|
||||
request_files = FORCE_MULTIPART
|
||||
|
||||
json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
|
||||
|
||||
with self.httpx_client.stream(
|
||||
@@ -313,11 +339,7 @@ class HttpClient:
|
||||
json=json_body,
|
||||
data=data_body,
|
||||
content=content,
|
||||
files=(
|
||||
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
|
||||
if (files is not None and files is not omit)
|
||||
else None
|
||||
),
|
||||
files=request_files,
|
||||
timeout=timeout,
|
||||
) as stream:
|
||||
yield stream
|
||||
@@ -356,11 +378,17 @@ class AsyncHttpClient:
|
||||
json: typing.Optional[typing.Any] = None,
|
||||
data: typing.Optional[typing.Any] = None,
|
||||
content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
|
||||
files: typing.Optional[typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]]] = None,
|
||||
files: typing.Optional[
|
||||
typing.Union[
|
||||
typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
|
||||
typing.List[typing.Tuple[str, File]],
|
||||
]
|
||||
] = None,
|
||||
headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
request_options: typing.Optional[RequestOptions] = None,
|
||||
retries: int = 2,
|
||||
omit: typing.Optional[typing.Any] = None,
|
||||
force_multipart: typing.Optional[bool] = None,
|
||||
) -> httpx.Response:
|
||||
base_url = self.get_base_url(base_url)
|
||||
timeout = (
|
||||
@@ -369,6 +397,15 @@ class AsyncHttpClient:
|
||||
else self.base_timeout()
|
||||
)
|
||||
|
||||
request_files: typing.Optional[RequestFiles] = (
|
||||
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
|
||||
if (files is not None and files is not omit and isinstance(files, dict))
|
||||
else None
|
||||
)
|
||||
|
||||
if (request_files is None or len(request_files) == 0) and force_multipart:
|
||||
request_files = FORCE_MULTIPART
|
||||
|
||||
json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
|
||||
|
||||
# Add the input to each of these and do None-safety checks
|
||||
@@ -404,11 +441,7 @@ class AsyncHttpClient:
|
||||
json=json_body,
|
||||
data=data_body,
|
||||
content=content,
|
||||
files=(
|
||||
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
|
||||
if files is not None
|
||||
else None
|
||||
),
|
||||
files=request_files,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
@@ -442,11 +475,17 @@ class AsyncHttpClient:
|
||||
json: typing.Optional[typing.Any] = None,
|
||||
data: typing.Optional[typing.Any] = None,
|
||||
content: typing.Optional[typing.Union[bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]]] = None,
|
||||
files: typing.Optional[typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]]] = None,
|
||||
files: typing.Optional[
|
||||
typing.Union[
|
||||
typing.Dict[str, typing.Optional[typing.Union[File, typing.List[File]]]],
|
||||
typing.List[typing.Tuple[str, File]],
|
||||
]
|
||||
] = None,
|
||||
headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
||||
request_options: typing.Optional[RequestOptions] = None,
|
||||
retries: int = 2,
|
||||
omit: typing.Optional[typing.Any] = None,
|
||||
force_multipart: typing.Optional[bool] = None,
|
||||
) -> typing.AsyncIterator[httpx.Response]:
|
||||
base_url = self.get_base_url(base_url)
|
||||
timeout = (
|
||||
@@ -455,6 +494,15 @@ class AsyncHttpClient:
|
||||
else self.base_timeout()
|
||||
)
|
||||
|
||||
request_files: typing.Optional[RequestFiles] = (
|
||||
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
|
||||
if (files is not None and files is not omit and isinstance(files, dict))
|
||||
else None
|
||||
)
|
||||
|
||||
if (request_files is None or len(request_files) == 0) and force_multipart:
|
||||
request_files = FORCE_MULTIPART
|
||||
|
||||
json_body, data_body = get_request_body(json=json, data=data, request_options=request_options, omit=omit)
|
||||
|
||||
async with self.httpx_client.stream(
|
||||
@@ -489,11 +537,7 @@ class AsyncHttpClient:
|
||||
json=json_body,
|
||||
data=data_body,
|
||||
content=content,
|
||||
files=(
|
||||
convert_file_dict_to_httpx_tuples(remove_omit_from_dict(remove_none_from_dict(files), omit))
|
||||
if files is not None
|
||||
else None
|
||||
),
|
||||
files=request_files,
|
||||
timeout=timeout,
|
||||
) as stream:
|
||||
yield stream
|
||||
|
||||
55
skyvern/client/core/http_response.py
Normal file
55
skyvern/client/core/http_response.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# This file was auto-generated by Fern from our API Definition.
|
||||
|
||||
from typing import Dict, Generic, TypeVar
|
||||
|
||||
import httpx
|
||||
|
||||
# Generic to represent the underlying type of the data wrapped by the HTTP response.
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseHttpResponse:
|
||||
"""Minimalist HTTP response wrapper that exposes response headers."""
|
||||
|
||||
_response: httpx.Response
|
||||
|
||||
def __init__(self, response: httpx.Response):
|
||||
self._response = response
|
||||
|
||||
@property
|
||||
def headers(self) -> Dict[str, str]:
|
||||
return dict(self._response.headers)
|
||||
|
||||
|
||||
class HttpResponse(Generic[T], BaseHttpResponse):
|
||||
"""HTTP response wrapper that exposes response headers and data."""
|
||||
|
||||
_data: T
|
||||
|
||||
def __init__(self, response: httpx.Response, data: T):
|
||||
super().__init__(response)
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def data(self) -> T:
|
||||
return self._data
|
||||
|
||||
def close(self) -> None:
|
||||
self._response.close()
|
||||
|
||||
|
||||
class AsyncHttpResponse(Generic[T], BaseHttpResponse):
|
||||
"""HTTP response wrapper that exposes response headers and data."""
|
||||
|
||||
_data: T
|
||||
|
||||
def __init__(self, response: httpx.Response, data: T):
|
||||
super().__init__(response)
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def data(self) -> T:
|
||||
return self._data
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._response.aclose()
|
||||
42
skyvern/client/core/http_sse/__init__.py
Normal file
42
skyvern/client/core/http_sse/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# This file was auto-generated by Fern from our API Definition.
|
||||
|
||||
# isort: skip_file
|
||||
|
||||
import typing
|
||||
from importlib import import_module
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._api import EventSource, aconnect_sse, connect_sse
|
||||
from ._exceptions import SSEError
|
||||
from ._models import ServerSentEvent
|
||||
_dynamic_imports: typing.Dict[str, str] = {
|
||||
"EventSource": "._api",
|
||||
"SSEError": "._exceptions",
|
||||
"ServerSentEvent": "._models",
|
||||
"aconnect_sse": "._api",
|
||||
"connect_sse": "._api",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(attr_name: str) -> typing.Any:
|
||||
module_name = _dynamic_imports.get(attr_name)
|
||||
if module_name is None:
|
||||
raise AttributeError(f"No {attr_name} found in _dynamic_imports for module name -> {__name__}")
|
||||
try:
|
||||
module = import_module(module_name, __package__)
|
||||
if module_name == f".{attr_name}":
|
||||
return module
|
||||
else:
|
||||
return getattr(module, attr_name)
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Failed to import {attr_name} from {module_name}: {e}") from e
|
||||
except AttributeError as e:
|
||||
raise AttributeError(f"Failed to get {attr_name} from {module_name}: {e}") from e
|
||||
|
||||
|
||||
def __dir__():
|
||||
lazy_attrs = list(_dynamic_imports.keys())
|
||||
return sorted(lazy_attrs)
|
||||
|
||||
|
||||
__all__ = ["EventSource", "SSEError", "ServerSentEvent", "aconnect_sse", "connect_sse"]
|
||||
112
skyvern/client/core/http_sse/_api.py
Normal file
112
skyvern/client/core/http_sse/_api.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# This file was auto-generated by Fern from our API Definition.
|
||||
|
||||
import re
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Iterator, cast
|
||||
|
||||
import httpx
|
||||
from ._decoders import SSEDecoder
|
||||
from ._exceptions import SSEError
|
||||
from ._models import ServerSentEvent
|
||||
|
||||
|
||||
class EventSource:
|
||||
def __init__(self, response: httpx.Response) -> None:
|
||||
self._response = response
|
||||
|
||||
def _check_content_type(self) -> None:
|
||||
content_type = self._response.headers.get("content-type", "").partition(";")[0]
|
||||
if "text/event-stream" not in content_type:
|
||||
raise SSEError(
|
||||
f"Expected response header Content-Type to contain 'text/event-stream', got {content_type!r}"
|
||||
)
|
||||
|
||||
def _get_charset(self) -> str:
|
||||
"""Extract charset from Content-Type header, fallback to UTF-8."""
|
||||
content_type = self._response.headers.get("content-type", "")
|
||||
|
||||
# Parse charset parameter using regex
|
||||
charset_match = re.search(r"charset=([^;\s]+)", content_type, re.IGNORECASE)
|
||||
if charset_match:
|
||||
charset = charset_match.group(1).strip("\"'")
|
||||
# Validate that it's a known encoding
|
||||
try:
|
||||
# Test if the charset is valid by trying to encode/decode
|
||||
"test".encode(charset).decode(charset)
|
||||
return charset
|
||||
except (LookupError, UnicodeError):
|
||||
# If charset is invalid, fall back to UTF-8
|
||||
pass
|
||||
|
||||
# Default to UTF-8 if no charset specified or invalid charset
|
||||
return "utf-8"
|
||||
|
||||
@property
|
||||
def response(self) -> httpx.Response:
|
||||
return self._response
|
||||
|
||||
def iter_sse(self) -> Iterator[ServerSentEvent]:
|
||||
self._check_content_type()
|
||||
decoder = SSEDecoder()
|
||||
charset = self._get_charset()
|
||||
|
||||
buffer = ""
|
||||
for chunk in self._response.iter_bytes():
|
||||
# Decode chunk using detected charset
|
||||
text_chunk = chunk.decode(charset, errors="replace")
|
||||
buffer += text_chunk
|
||||
|
||||
# Process complete lines
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.rstrip("\r")
|
||||
sse = decoder.decode(line)
|
||||
# when we reach a "\n\n" => line = ''
|
||||
# => decoder will attempt to return an SSE Event
|
||||
if sse is not None:
|
||||
yield sse
|
||||
|
||||
# Process any remaining data in buffer
|
||||
if buffer.strip():
|
||||
line = buffer.rstrip("\r")
|
||||
sse = decoder.decode(line)
|
||||
if sse is not None:
|
||||
yield sse
|
||||
|
||||
async def aiter_sse(self) -> AsyncGenerator[ServerSentEvent, None]:
|
||||
self._check_content_type()
|
||||
decoder = SSEDecoder()
|
||||
lines = cast(AsyncGenerator[str, None], self._response.aiter_lines())
|
||||
try:
|
||||
async for line in lines:
|
||||
line = line.rstrip("\n")
|
||||
sse = decoder.decode(line)
|
||||
if sse is not None:
|
||||
yield sse
|
||||
finally:
|
||||
await lines.aclose()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any) -> Iterator[EventSource]:
|
||||
headers = kwargs.pop("headers", {})
|
||||
headers["Accept"] = "text/event-stream"
|
||||
headers["Cache-Control"] = "no-store"
|
||||
|
||||
with client.stream(method, url, headers=headers, **kwargs) as response:
|
||||
yield EventSource(response)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def aconnect_sse(
|
||||
client: httpx.AsyncClient,
|
||||
method: str,
|
||||
url: str,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[EventSource]:
|
||||
headers = kwargs.pop("headers", {})
|
||||
headers["Accept"] = "text/event-stream"
|
||||
headers["Cache-Control"] = "no-store"
|
||||
|
||||
async with client.stream(method, url, headers=headers, **kwargs) as response:
|
||||
yield EventSource(response)
|
||||
61
skyvern/client/core/http_sse/_decoders.py
Normal file
61
skyvern/client/core/http_sse/_decoders.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# This file was auto-generated by Fern from our API Definition.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from ._models import ServerSentEvent
|
||||
|
||||
|
||||
class SSEDecoder:
|
||||
def __init__(self) -> None:
|
||||
self._event = ""
|
||||
self._data: List[str] = []
|
||||
self._last_event_id = ""
|
||||
self._retry: Optional[int] = None
|
||||
|
||||
def decode(self, line: str) -> Optional[ServerSentEvent]:
|
||||
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
|
||||
|
||||
if not line:
|
||||
if not self._event and not self._data and not self._last_event_id and self._retry is None:
|
||||
return None
|
||||
|
||||
sse = ServerSentEvent(
|
||||
event=self._event,
|
||||
data="\n".join(self._data),
|
||||
id=self._last_event_id,
|
||||
retry=self._retry,
|
||||
)
|
||||
|
||||
# NOTE: as per the SSE spec, do not reset last_event_id.
|
||||
self._event = ""
|
||||
self._data = []
|
||||
self._retry = None
|
||||
|
||||
return sse
|
||||
|
||||
if line.startswith(":"):
|
||||
return None
|
||||
|
||||
fieldname, _, value = line.partition(":")
|
||||
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
|
||||
if fieldname == "event":
|
||||
self._event = value
|
||||
elif fieldname == "data":
|
||||
self._data.append(value)
|
||||
elif fieldname == "id":
|
||||
if "\0" in value:
|
||||
pass
|
||||
else:
|
||||
self._last_event_id = value
|
||||
elif fieldname == "retry":
|
||||
try:
|
||||
self._retry = int(value)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
pass # Field is ignored.
|
||||
|
||||
return None
|
||||
7
skyvern/client/core/http_sse/_exceptions.py
Normal file
7
skyvern/client/core/http_sse/_exceptions.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# This file was auto-generated by Fern from our API Definition.
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class SSEError(httpx.TransportError):
|
||||
pass
|
||||
17
skyvern/client/core/http_sse/_models.py
Normal file
17
skyvern/client/core/http_sse/_models.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# This file was auto-generated by Fern from our API Definition.
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ServerSentEvent:
|
||||
event: str = "message"
|
||||
data: str = ""
|
||||
id: str = ""
|
||||
retry: Optional[int] = None
|
||||
|
||||
def json(self) -> Any:
|
||||
"""Parse the data field as JSON."""
|
||||
return json.loads(self.data)
|
||||
@@ -17,7 +17,6 @@ from types import GeneratorType
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import pydantic
|
||||
|
||||
from .datetime_utils import serialize_datetime
|
||||
from .pydantic_utilities import (
|
||||
IS_PYDANTIC_V2,
|
||||
|
||||
@@ -2,90 +2,66 @@
|
||||
|
||||
# nopycln: file
|
||||
import datetime as dt
|
||||
import typing
|
||||
from collections import defaultdict
|
||||
|
||||
import typing_extensions
|
||||
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union, cast
|
||||
|
||||
import pydantic
|
||||
|
||||
from .datetime_utils import serialize_datetime
|
||||
from .serialization import convert_and_respect_annotation_metadata
|
||||
|
||||
IS_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
|
||||
|
||||
if IS_PYDANTIC_V2:
|
||||
# isort will try to reformat the comments on these imports, which breaks mypy
|
||||
# isort: off
|
||||
from pydantic.v1.datetime_parse import ( # type: ignore # pyright: ignore[reportMissingImports] # Pydantic v2
|
||||
parse_date as parse_date,
|
||||
)
|
||||
from pydantic.v1.datetime_parse import ( # pyright: ignore[reportMissingImports] # Pydantic v2
|
||||
parse_datetime as parse_datetime,
|
||||
)
|
||||
from pydantic.v1.json import ( # type: ignore # pyright: ignore[reportMissingImports] # Pydantic v2
|
||||
ENCODERS_BY_TYPE as encoders_by_type,
|
||||
)
|
||||
from pydantic.v1.typing import ( # type: ignore # pyright: ignore[reportMissingImports] # Pydantic v2
|
||||
get_args as get_args,
|
||||
)
|
||||
from pydantic.v1.typing import ( # pyright: ignore[reportMissingImports] # Pydantic v2
|
||||
get_origin as get_origin,
|
||||
)
|
||||
from pydantic.v1.typing import ( # pyright: ignore[reportMissingImports] # Pydantic v2
|
||||
is_literal_type as is_literal_type,
|
||||
)
|
||||
from pydantic.v1.typing import ( # pyright: ignore[reportMissingImports] # Pydantic v2
|
||||
is_union as is_union,
|
||||
)
|
||||
from pydantic.v1.fields import ModelField as ModelField # type: ignore # pyright: ignore[reportMissingImports] # Pydantic v2
|
||||
from pydantic.v1.datetime_parse import parse_date as parse_date
|
||||
from pydantic.v1.datetime_parse import parse_datetime as parse_datetime
|
||||
from pydantic.v1.fields import ModelField as ModelField
|
||||
from pydantic.v1.json import ENCODERS_BY_TYPE as encoders_by_type # type: ignore[attr-defined]
|
||||
from pydantic.v1.typing import get_args as get_args
|
||||
from pydantic.v1.typing import get_origin as get_origin
|
||||
from pydantic.v1.typing import is_literal_type as is_literal_type
|
||||
from pydantic.v1.typing import is_union as is_union
|
||||
else:
|
||||
from pydantic.datetime_parse import parse_date as parse_date # type: ignore # Pydantic v1
|
||||
from pydantic.datetime_parse import parse_datetime as parse_datetime # type: ignore # Pydantic v1
|
||||
from pydantic.fields import ModelField as ModelField # type: ignore # Pydantic v1
|
||||
from pydantic.json import ENCODERS_BY_TYPE as encoders_by_type # type: ignore # Pydantic v1
|
||||
from pydantic.typing import get_args as get_args # type: ignore # Pydantic v1
|
||||
from pydantic.typing import get_origin as get_origin # type: ignore # Pydantic v1
|
||||
from pydantic.typing import is_literal_type as is_literal_type # type: ignore # Pydantic v1
|
||||
from pydantic.typing import is_union as is_union # type: ignore # Pydantic v1
|
||||
from pydantic.datetime_parse import parse_date as parse_date # type: ignore[no-redef]
|
||||
from pydantic.datetime_parse import parse_datetime as parse_datetime # type: ignore[no-redef]
|
||||
from pydantic.fields import ModelField as ModelField # type: ignore[attr-defined, no-redef]
|
||||
from pydantic.json import ENCODERS_BY_TYPE as encoders_by_type # type: ignore[no-redef]
|
||||
from pydantic.typing import get_args as get_args # type: ignore[no-redef]
|
||||
from pydantic.typing import get_origin as get_origin # type: ignore[no-redef]
|
||||
from pydantic.typing import is_literal_type as is_literal_type # type: ignore[no-redef]
|
||||
from pydantic.typing import is_union as is_union # type: ignore[no-redef]
|
||||
|
||||
# isort: on
|
||||
from .datetime_utils import serialize_datetime
|
||||
from .serialization import convert_and_respect_annotation_metadata
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
T = TypeVar("T")
|
||||
Model = TypeVar("Model", bound=pydantic.BaseModel)
|
||||
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
Model = typing.TypeVar("Model", bound=pydantic.BaseModel)
|
||||
|
||||
|
||||
def parse_obj_as(type_: typing.Type[T], object_: typing.Any) -> T:
|
||||
def parse_obj_as(type_: Type[T], object_: Any) -> T:
|
||||
dealiased_object = convert_and_respect_annotation_metadata(object_=object_, annotation=type_, direction="read")
|
||||
if IS_PYDANTIC_V2:
|
||||
adapter = pydantic.TypeAdapter(type_) # type: ignore # Pydantic v2
|
||||
adapter = pydantic.TypeAdapter(type_) # type: ignore[attr-defined]
|
||||
return adapter.validate_python(dealiased_object)
|
||||
else:
|
||||
return pydantic.parse_obj_as(type_, dealiased_object)
|
||||
return pydantic.parse_obj_as(type_, dealiased_object)
|
||||
|
||||
|
||||
def to_jsonable_with_fallback(
|
||||
obj: typing.Any, fallback_serializer: typing.Callable[[typing.Any], typing.Any]
|
||||
) -> typing.Any:
|
||||
def to_jsonable_with_fallback(obj: Any, fallback_serializer: Callable[[Any], Any]) -> Any:
|
||||
if IS_PYDANTIC_V2:
|
||||
from pydantic_core import to_jsonable_python
|
||||
|
||||
return to_jsonable_python(obj, fallback=fallback_serializer)
|
||||
else:
|
||||
return fallback_serializer(obj)
|
||||
return fallback_serializer(obj)
|
||||
|
||||
|
||||
class UniversalBaseModel(pydantic.BaseModel):
|
||||
if IS_PYDANTIC_V2:
|
||||
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(
|
||||
model_config: ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict( # type: ignore[typeddict-unknown-key]
|
||||
# Allow fields beginning with `model_` to be used in the model
|
||||
protected_namespaces=(),
|
||||
) # type: ignore # Pydantic v2
|
||||
)
|
||||
|
||||
@pydantic.model_serializer(mode="wrap", when_used="json") # type: ignore # Pydantic v2
|
||||
def serialize_model(self, handler: pydantic.SerializerFunctionWrapHandler) -> typing.Any: # type: ignore # Pydantic v2
|
||||
serialized = handler(self)
|
||||
@pydantic.model_serializer(mode="plain", when_used="json") # type: ignore[attr-defined]
|
||||
def serialize_model(self) -> Any: # type: ignore[name-defined]
|
||||
serialized = self.dict() # type: ignore[attr-defined]
|
||||
data = {k: serialize_datetime(v) if isinstance(v, dt.datetime) else v for k, v in serialized.items()}
|
||||
return data
|
||||
|
||||
@@ -96,34 +72,28 @@ class UniversalBaseModel(pydantic.BaseModel):
|
||||
json_encoders = {dt.datetime: serialize_datetime}
|
||||
|
||||
@classmethod
|
||||
def model_construct(
|
||||
cls: typing.Type["Model"], _fields_set: typing.Optional[typing.Set[str]] = None, **values: typing.Any
|
||||
) -> "Model":
|
||||
def model_construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None, **values: Any) -> "Model":
|
||||
dealiased_object = convert_and_respect_annotation_metadata(object_=values, annotation=cls, direction="read")
|
||||
return cls.construct(_fields_set, **dealiased_object)
|
||||
|
||||
@classmethod
|
||||
def construct(
|
||||
cls: typing.Type["Model"], _fields_set: typing.Optional[typing.Set[str]] = None, **values: typing.Any
|
||||
) -> "Model":
|
||||
def construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None, **values: Any) -> "Model":
|
||||
dealiased_object = convert_and_respect_annotation_metadata(object_=values, annotation=cls, direction="read")
|
||||
if IS_PYDANTIC_V2:
|
||||
return super().model_construct(_fields_set, **dealiased_object) # type: ignore # Pydantic v2
|
||||
else:
|
||||
return super().construct(_fields_set, **dealiased_object)
|
||||
return super().model_construct(_fields_set, **dealiased_object) # type: ignore[misc]
|
||||
return super().construct(_fields_set, **dealiased_object)
|
||||
|
||||
def json(self, **kwargs: typing.Any) -> str:
|
||||
kwargs_with_defaults: typing.Any = {
|
||||
def json(self, **kwargs: Any) -> str:
|
||||
kwargs_with_defaults = {
|
||||
"by_alias": True,
|
||||
"exclude_unset": True,
|
||||
**kwargs,
|
||||
}
|
||||
if IS_PYDANTIC_V2:
|
||||
return super().model_dump_json(**kwargs_with_defaults) # type: ignore # Pydantic v2
|
||||
else:
|
||||
return super().json(**kwargs_with_defaults)
|
||||
return super().model_dump_json(**kwargs_with_defaults) # type: ignore[misc]
|
||||
return super().json(**kwargs_with_defaults)
|
||||
|
||||
def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]:
|
||||
def dict(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Override the default dict method to `exclude_unset` by default. This function patches
|
||||
`exclude_unset` to work include fields within non-None default values.
|
||||
@@ -134,21 +104,21 @@ class UniversalBaseModel(pydantic.BaseModel):
|
||||
# We'd ideally do the same for Pydantic V2, but it shells out to a library to serialize models
|
||||
# that we have less control over, and this is less intrusive than custom serializers for now.
|
||||
if IS_PYDANTIC_V2:
|
||||
kwargs_with_defaults_exclude_unset: typing.Any = {
|
||||
kwargs_with_defaults_exclude_unset = {
|
||||
**kwargs,
|
||||
"by_alias": True,
|
||||
"exclude_unset": True,
|
||||
"exclude_none": False,
|
||||
}
|
||||
kwargs_with_defaults_exclude_none: typing.Any = {
|
||||
kwargs_with_defaults_exclude_none = {
|
||||
**kwargs,
|
||||
"by_alias": True,
|
||||
"exclude_none": True,
|
||||
"exclude_unset": False,
|
||||
}
|
||||
dict_dump = deep_union_pydantic_dicts(
|
||||
super().model_dump(**kwargs_with_defaults_exclude_unset), # type: ignore # Pydantic v2
|
||||
super().model_dump(**kwargs_with_defaults_exclude_none), # type: ignore # Pydantic v2
|
||||
super().model_dump(**kwargs_with_defaults_exclude_unset), # type: ignore[misc]
|
||||
super().model_dump(**kwargs_with_defaults_exclude_none), # type: ignore[misc]
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -168,7 +138,7 @@ class UniversalBaseModel(pydantic.BaseModel):
|
||||
if default is not None:
|
||||
self.__fields_set__.add(name)
|
||||
|
||||
kwargs_with_defaults_exclude_unset_include_fields: typing.Any = {
|
||||
kwargs_with_defaults_exclude_unset_include_fields = {
|
||||
"by_alias": True,
|
||||
"exclude_unset": True,
|
||||
"include": _fields_set,
|
||||
@@ -177,15 +147,16 @@ class UniversalBaseModel(pydantic.BaseModel):
|
||||
|
||||
dict_dump = super().dict(**kwargs_with_defaults_exclude_unset_include_fields)
|
||||
|
||||
return convert_and_respect_annotation_metadata(object_=dict_dump, annotation=self.__class__, direction="write")
|
||||
return cast(
|
||||
Dict[str, Any],
|
||||
convert_and_respect_annotation_metadata(object_=dict_dump, annotation=self.__class__, direction="write"),
|
||||
)
|
||||
|
||||
|
||||
def _union_list_of_pydantic_dicts(
|
||||
source: typing.List[typing.Any], destination: typing.List[typing.Any]
|
||||
) -> typing.List[typing.Any]:
|
||||
converted_list: typing.List[typing.Any] = []
|
||||
def _union_list_of_pydantic_dicts(source: List[Any], destination: List[Any]) -> List[Any]:
|
||||
converted_list: List[Any] = []
|
||||
for i, item in enumerate(source):
|
||||
destination_value = destination[i] # type: ignore
|
||||
destination_value = destination[i]
|
||||
if isinstance(item, dict):
|
||||
converted_list.append(deep_union_pydantic_dicts(item, destination_value))
|
||||
elif isinstance(item, list):
|
||||
@@ -195,9 +166,7 @@ def _union_list_of_pydantic_dicts(
|
||||
return converted_list
|
||||
|
||||
|
||||
def deep_union_pydantic_dicts(
|
||||
source: typing.Dict[str, typing.Any], destination: typing.Dict[str, typing.Any]
|
||||
) -> typing.Dict[str, typing.Any]:
|
||||
def deep_union_pydantic_dicts(source: Dict[str, Any], destination: Dict[str, Any]) -> Dict[str, Any]:
|
||||
for key, value in source.items():
|
||||
node = destination.setdefault(key, {})
|
||||
if isinstance(value, dict):
|
||||
@@ -215,18 +184,16 @@ def deep_union_pydantic_dicts(
|
||||
|
||||
if IS_PYDANTIC_V2:
|
||||
|
||||
class V2RootModel(UniversalBaseModel, pydantic.RootModel): # type: ignore # Pydantic v2
|
||||
class V2RootModel(UniversalBaseModel, pydantic.RootModel): # type: ignore[misc, name-defined, type-arg]
|
||||
pass
|
||||
|
||||
UniversalRootModel: typing_extensions.TypeAlias = V2RootModel # type: ignore
|
||||
UniversalRootModel: TypeAlias = V2RootModel # type: ignore[misc]
|
||||
else:
|
||||
UniversalRootModel: typing_extensions.TypeAlias = UniversalBaseModel # type: ignore
|
||||
UniversalRootModel: TypeAlias = UniversalBaseModel # type: ignore[misc, no-redef]
|
||||
|
||||
|
||||
def encode_by_type(o: typing.Any) -> typing.Any:
|
||||
encoders_by_class_tuples: typing.Dict[typing.Callable[[typing.Any], typing.Any], typing.Tuple[typing.Any, ...]] = (
|
||||
defaultdict(tuple)
|
||||
)
|
||||
def encode_by_type(o: Any) -> Any:
|
||||
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(tuple)
|
||||
for type_, encoder in encoders_by_type.items():
|
||||
encoders_by_class_tuples[encoder] += (type_,)
|
||||
|
||||
@@ -237,54 +204,49 @@ def encode_by_type(o: typing.Any) -> typing.Any:
|
||||
return encoder(o)
|
||||
|
||||
|
||||
def update_forward_refs(model: typing.Type["Model"], **localns: typing.Any) -> None:
|
||||
def update_forward_refs(model: Type["Model"], **localns: Any) -> None:
|
||||
if IS_PYDANTIC_V2:
|
||||
model.model_rebuild(raise_errors=False) # type: ignore # Pydantic v2
|
||||
model.model_rebuild(raise_errors=False) # type: ignore[attr-defined]
|
||||
else:
|
||||
model.update_forward_refs(**localns)
|
||||
|
||||
|
||||
# Mirrors Pydantic's internal typing
|
||||
AnyCallable = typing.Callable[..., typing.Any]
|
||||
AnyCallable = Callable[..., Any]
|
||||
|
||||
|
||||
def universal_root_validator(
|
||||
pre: bool = False,
|
||||
) -> typing.Callable[[AnyCallable], AnyCallable]:
|
||||
) -> Callable[[AnyCallable], AnyCallable]:
|
||||
def decorator(func: AnyCallable) -> AnyCallable:
|
||||
if IS_PYDANTIC_V2:
|
||||
return pydantic.model_validator(mode="before" if pre else "after")(func) # type: ignore # Pydantic v2
|
||||
else:
|
||||
return pydantic.root_validator(pre=pre)(func) # type: ignore # Pydantic v1
|
||||
return cast(AnyCallable, pydantic.model_validator(mode="before" if pre else "after")(func)) # type: ignore[attr-defined]
|
||||
return cast(AnyCallable, pydantic.root_validator(pre=pre)(func)) # type: ignore[call-overload]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def universal_field_validator(field_name: str, pre: bool = False) -> typing.Callable[[AnyCallable], AnyCallable]:
|
||||
def universal_field_validator(field_name: str, pre: bool = False) -> Callable[[AnyCallable], AnyCallable]:
|
||||
def decorator(func: AnyCallable) -> AnyCallable:
|
||||
if IS_PYDANTIC_V2:
|
||||
return pydantic.field_validator(field_name, mode="before" if pre else "after")(func) # type: ignore # Pydantic v2
|
||||
else:
|
||||
return pydantic.validator(field_name, pre=pre)(func) # type: ignore # Pydantic v1
|
||||
return cast(AnyCallable, pydantic.field_validator(field_name, mode="before" if pre else "after")(func)) # type: ignore[attr-defined]
|
||||
return cast(AnyCallable, pydantic.validator(field_name, pre=pre)(func))
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
PydanticField = typing.Union[ModelField, pydantic.fields.FieldInfo]
|
||||
PydanticField = Union[ModelField, pydantic.fields.FieldInfo]
|
||||
|
||||
|
||||
def _get_model_fields(
|
||||
model: typing.Type["Model"],
|
||||
) -> typing.Mapping[str, PydanticField]:
|
||||
def _get_model_fields(model: Type["Model"]) -> Mapping[str, PydanticField]:
|
||||
if IS_PYDANTIC_V2:
|
||||
return model.model_fields # type: ignore # Pydantic v2
|
||||
else:
|
||||
return model.__fields__ # type: ignore # Pydantic v1
|
||||
return cast(Mapping[str, PydanticField], model.model_fields) # type: ignore[attr-defined]
|
||||
return cast(Mapping[str, PydanticField], model.__fields__)
|
||||
|
||||
|
||||
def _get_field_default(field: PydanticField) -> typing.Any:
|
||||
def _get_field_default(field: PydanticField) -> Any:
|
||||
try:
|
||||
value = field.get_default() # type: ignore # Pydantic < v1.10.15
|
||||
value = field.get_default() # type: ignore[union-attr]
|
||||
except:
|
||||
value = field.default
|
||||
if IS_PYDANTIC_V2:
|
||||
|
||||
@@ -4,9 +4,8 @@ import collections
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
import typing_extensions
|
||||
|
||||
import pydantic
|
||||
import typing_extensions
|
||||
|
||||
|
||||
class FieldMetadata:
|
||||
@@ -161,7 +160,12 @@ def _convert_mapping(
|
||||
direction: typing.Literal["read", "write"],
|
||||
) -> typing.Mapping[str, object]:
|
||||
converted_object: typing.Dict[str, object] = {}
|
||||
annotations = typing_extensions.get_type_hints(expected_type, include_extras=True)
|
||||
try:
|
||||
annotations = typing_extensions.get_type_hints(expected_type, include_extras=True)
|
||||
except NameError:
|
||||
# The TypedDict contains a circular reference, so
|
||||
# we use the __annotations__ attribute directly.
|
||||
annotations = getattr(expected_type, "__annotations__", {})
|
||||
aliases_to_field_names = _get_alias_to_field_name(annotations)
|
||||
for key, value in object_.items():
|
||||
if direction == "read" and key in aliases_to_field_names:
|
||||
|
||||
Reference in New Issue
Block a user