diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index d0c580ef..49bace7a 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -510,3 +510,14 @@ class BlockedHost(SkyvernHTTPException): f"The host in your url is blocked: {host}", status_code=status.HTTP_400_BAD_REQUEST, ) + + +class InvalidWorkflowParameter(SkyvernHTTPException): + def __init__(self, expected_parameter_type: str, value: str, workflow_permanent_id: str | None = None) -> None: + message = f"Invalid workflow parameter. Excpected parameter type: {expected_parameter_type}. Value: {value}." + if workflow_permanent_id: + message += f" Workflow permanent id: {workflow_permanent_id}" + super().__init__( + message, + status_code=status.HTTP_400_BAD_REQUEST, + ) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 4e605b29..51010f8e 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1421,8 +1421,9 @@ class AgentDB: raise async def create_workflow_run_parameter( - self, workflow_run_id: str, workflow_parameter_id: str, value: Any + self, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: Any ) -> WorkflowRunParameter: + workflow_parameter_id = workflow_parameter.workflow_parameter_id try: async with self.Session() as session: workflow_run_parameter = WorkflowRunParameterModel( @@ -1433,9 +1434,6 @@ class AgentDB: session.add(workflow_run_parameter) await session.commit() await session.refresh(workflow_run_parameter) - workflow_parameter = await self.get_workflow_parameter(workflow_parameter_id) - if not workflow_parameter: - raise WorkflowParameterNotFound(workflow_parameter_id) return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled) except SQLAlchemyError: LOG.error("SQLAlchemyError", exc_info=True) diff --git a/skyvern/forge/sdk/workflow/models/parameter.py b/skyvern/forge/sdk/workflow/models/parameter.py index 1ebd943c..ae611db9 100644 --- a/skyvern/forge/sdk/workflow/models/parameter.py +++ b/skyvern/forge/sdk/workflow/models/parameter.py @@ -2,10 +2,12 @@ import abc import json from datetime import datetime from enum import StrEnum -from typing import Annotated, Literal, Union +from typing import Annotated, Any, Literal, Union from pydantic import BaseModel, ConfigDict, Field +from skyvern.exceptions import InvalidWorkflowParameter + class ParameterType(StrEnum): WORKFLOW = "workflow" @@ -114,21 +116,29 @@ class WorkflowParameterType(StrEnum): JSON = "json" FILE_URL = "file_url" - def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None: + def convert_value(self, value: Any) -> str | int | float | bool | dict | list | None: if value is None: return None - if self == WorkflowParameterType.STRING: - return value - elif self == WorkflowParameterType.INTEGER: - return int(value) - elif self == WorkflowParameterType.FLOAT: - return float(value) - elif self == WorkflowParameterType.BOOLEAN: - return value.lower() in ["true", "1"] - elif self == WorkflowParameterType.JSON: - return json.loads(value) - elif self == WorkflowParameterType.FILE_URL: - return value + try: + if self == WorkflowParameterType.STRING: + return str(value) + elif self == WorkflowParameterType.INTEGER: + return int(value) + elif self == WorkflowParameterType.FLOAT: + return float(value) + elif self == WorkflowParameterType.BOOLEAN: + if isinstance(value, bool): + return value + lower_case = str(value).lower() + if lower_case in ["true", "false", "1", "0"]: + raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value)) + return lower_case in ["true", "1"] + elif self == WorkflowParameterType.JSON: + return json.loads(value) + elif self == WorkflowParameterType.FILE_URL: + return value + except Exception: + raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value)) class WorkflowParameter(Parameter): diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index f008af3c..24dceae0 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -1,5 +1,6 @@ import json from datetime import datetime +from typing import Any import requests import structlog @@ -125,13 +126,13 @@ class WorkflowService: request_body_value = workflow_request.data[workflow_parameter.key] workflow_run_parameter = await self.create_workflow_run_parameter( workflow_run_id=workflow_run.workflow_run_id, - workflow_parameter_id=workflow_parameter.workflow_parameter_id, + workflow_parameter=workflow_parameter, value=request_body_value, ) elif workflow_parameter.default_value is not None: workflow_run_parameter = await self.create_workflow_run_parameter( workflow_run_id=workflow_run.workflow_run_id, - workflow_parameter_id=workflow_parameter.workflow_parameter_id, + workflow_parameter=workflow_parameter, value=workflow_parameter.default_value, ) else: @@ -565,12 +566,15 @@ class WorkflowService: async def create_workflow_run_parameter( self, workflow_run_id: str, - workflow_parameter_id: str, - value: bool | int | float | str | dict | list, + workflow_parameter: WorkflowParameter, + value: Any, ) -> WorkflowRunParameter: + # InvalidWorkflowParameter will be raised if the validation fails + workflow_parameter.workflow_parameter_type.convert_value(value) + return await app.DATABASE.create_workflow_run_parameter( workflow_run_id=workflow_run_id, - workflow_parameter_id=workflow_parameter_id, + workflow_parameter=workflow_parameter, value=json.dumps(value) if isinstance(value, (dict, list)) else value, )