diff --git a/pyproject.toml b/pyproject.toml index 9a5e254a..935d1902 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ cloud = [ "temporalio[opentelemetry]>=1.6.0,<2", "temporalio>=1.6.0,<2", "redis>=5.0.3,<6", + "pyrate-limiter>=3.7.0,<4", "opentelemetry-exporter-otlp-proto-grpc>=1.38.0,<2", "kr8s>=0.20.14,<1", ] diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index 58e5bede..58b9220f 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -24,6 +24,15 @@ class DisabledBlockExecutionError(SkyvernHTTPException): super().__init__(message, status_code=status.HTTP_400_BAD_REQUEST) +class RateLimitExceeded(SkyvernHTTPException): + def __init__(self, organization_id: str, max_requests: int, window_seconds: int): + message = ( + f"Rate limit exceeded for organization {organization_id}. " + f"Maximum {max_requests} requests per {window_seconds} seconds allowed." + ) + super().__init__(message, status_code=status.HTTP_429_TOO_MANY_REQUESTS) + + class InvalidOpenAIResponseFormat(SkyvernException): def __init__(self, message: str | None = None): super().__init__(f"Invalid response format: {message}") diff --git a/skyvern/forge/forge_app.py b/skyvern/forge/forge_app.py index 3eda358e..71463266 100644 --- a/skyvern/forge/forge_app.py +++ b/skyvern/forge/forge_app.py @@ -21,6 +21,7 @@ from skyvern.forge.sdk.artifact.storage.factory import StorageFactory from skyvern.forge.sdk.artifact.storage.s3 import S3Storage from skyvern.forge.sdk.cache.base import BaseCache from skyvern.forge.sdk.cache.factory import CacheFactory +from skyvern.forge.sdk.core.rate_limiter import NoopRateLimiter, RateLimiter from skyvern.forge.sdk.db.agent_db import AgentDB from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider from skyvern.forge.sdk.schemas.credentials import CredentialVaultType @@ -50,6 +51,7 @@ class ForgeApp: ARTIFACT_MANAGER: ArtifactManager BROWSER_MANAGER: BrowserManager EXPERIMENTATION_PROVIDER: BaseExperimentationProvider + RATE_LIMITER: RateLimiter LLM_API_HANDLER: LLMAPIHandler OPENAI_CLIENT: AsyncOpenAI | AsyncAzureOpenAI ANTHROPIC_CLIENT: AsyncAnthropic | AsyncAnthropicBedrock @@ -107,6 +109,7 @@ def create_forge_app() -> ForgeApp: app.ARTIFACT_MANAGER = ArtifactManager() app.BROWSER_MANAGER = RealBrowserManager() app.EXPERIMENTATION_PROVIDER = NoOpExperimentationProvider() + app.RATE_LIMITER = NoopRateLimiter() app.LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(settings.LLM_KEY) app.OPENAI_CLIENT = AsyncOpenAI( diff --git a/skyvern/forge/sdk/core/rate_limiter.py b/skyvern/forge/sdk/core/rate_limiter.py new file mode 100644 index 00000000..95501d5e --- /dev/null +++ b/skyvern/forge/sdk/core/rate_limiter.py @@ -0,0 +1,33 @@ +from typing import Protocol + + +class RateLimiter(Protocol): + """ + Protocol for rate limiting submit run requests per organization. + + Implementations should be thread-safe and work correctly in distributed environments. + """ + + async def rate_limit_submit_run(self, organization_id: str) -> None: + """ + Check and enforce rate limit for submitting a new run (task/workflow) + raises RateLimitExceeded exception if rate limit is exceeded. + + Args: + organization_id: The organization ID to rate limit + + Raises: + Exception: If rate limit is exceeded (implementation-specific exception) + """ + ... + + +class NoopRateLimiter(RateLimiter): + """ + No-op rate limiter. + + This implementation does not enforce any rate limits. + """ + + async def rate_limit_submit_run(self, organization_id: str) -> None: + """No-op implementation that never rate limits.""" diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 67136354..d6184d3a 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -162,6 +162,7 @@ async def run_task( ) -> TaskRunResponse: analytics.capture("skyvern-oss-run-task", data={"url": run_request.url}) await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id) + await app.RATE_LIMITER.rate_limit_submit_run(current_org.organization_id) if run_request.engine in CUA_ENGINES or run_request.engine == RunEngine.skyvern_v1: # create task v1 @@ -347,6 +348,7 @@ async def run_workflow( await PermissionCheckerFactory.get_instance().check( current_org, browser_session_id=workflow_run_request.browser_session_id ) + await app.RATE_LIMITER.rate_limit_submit_run(current_org.organization_id) workflow_id = workflow_run_request.workflow_id context = skyvern_context.ensure_context() request_id = context.request_id @@ -1623,6 +1625,7 @@ async def run_task_v1( ) -> CreateTaskResponse: analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url}) await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=task.browser_session_id) + await app.RATE_LIMITER.rate_limit_submit_run(current_org.organization_id) created_task = await task_v1_service.run_task( task=task, @@ -2086,6 +2089,7 @@ async def run_workflow_legacy( current_org, browser_session_id=workflow_request.browser_session_id, ) + await app.RATE_LIMITER.rate_limit_submit_run(current_org.organization_id) try: workflow_run = await workflow_service.run_workflow( @@ -2667,6 +2671,7 @@ async def run_task_v2( max_steps_override=x_max_steps_override, ) await PermissionCheckerFactory.get_instance().check(organization, browser_session_id=data.browser_session_id) + await app.RATE_LIMITER.rate_limit_submit_run(organization.organization_id) try: task_v2 = await task_v2_service.initialize_task_v2( diff --git a/uv.lock b/uv.lock index 9314c658..b5478b2b 100644 --- a/uv.lock +++ b/uv.lock @@ -3435,11 +3435,11 @@ wheels = [ [[package]] name = "packaging" -version = "25.0" +version = "24.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950, upload-time = "2024-11-08T09:47:47.202Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, + { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451, upload-time = "2024-11-08T09:47:44.722Z" }, ] [[package]] @@ -4281,6 +4281,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/24/12818598c362d7f300f18e74db45963dbcb85150324092410c8b49405e42/pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913", size = 10216, upload-time = "2024-09-29T09:24:11.978Z" }, ] +[[package]] +name = "pyrate-limiter" +version = "3.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/da/f682c5c5f9f0a5414363eb4397e6b07d84a02cde69c4ceadcbf32c85537c/pyrate_limiter-3.9.0.tar.gz", hash = "sha256:6b882e2c77cda07a241d3730975daea4258344b39c878f1dd8849df73f70b0ce", size = 289308, upload-time = "2025-07-30T14:36:58.659Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/af/d8bf0959ece9bc4679bd203908c31019556a421d76d8143b0c6871c7f614/pyrate_limiter-3.9.0-py3-none-any.whl", hash = "sha256:77357840c8cf97a36d67005d4e090787043f54000c12c2b414ff65657653e378", size = 33628, upload-time = "2025-07-30T14:36:57.71Z" }, +] + [[package]] name = "pytest" version = "7.4.4" @@ -5143,6 +5152,7 @@ cloud = [ { name = "ddtrace" }, { name = "kr8s" }, { name = "opentelemetry-exporter-otlp-proto-grpc" }, + { name = "pyrate-limiter" }, { name = "redis" }, { name = "stripe" }, { name = "temporalio", extra = ["opentelemetry"] }, @@ -5239,6 +5249,7 @@ cloud = [ { name = "ddtrace", specifier = ">=2.3.2,<3" }, { name = "kr8s", specifier = ">=0.20.14,<1" }, { name = "opentelemetry-exporter-otlp-proto-grpc", specifier = ">=1.38.0,<2" }, + { name = "pyrate-limiter", specifier = ">=3.7.0,<4" }, { name = "redis", specifier = ">=5.0.3,<6" }, { name = "stripe", specifier = ">=9.7.0,<10" }, { name = "temporalio", specifier = ">=1.6.0,<2" },