workflow parameter validation (#1028)

This commit is contained in:
Shuchang Zheng
2024-10-22 17:36:25 -07:00
committed by GitHub
parent 7cba401e2e
commit 0e3da8d1d3
4 changed files with 46 additions and 23 deletions

View File

@@ -510,3 +510,14 @@ class BlockedHost(SkyvernHTTPException):
f"The host in your url is blocked: {host}", f"The host in your url is blocked: {host}",
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
) )
class InvalidWorkflowParameter(SkyvernHTTPException):
def __init__(self, expected_parameter_type: str, value: str, workflow_permanent_id: str | None = None) -> None:
message = f"Invalid workflow parameter. Excpected parameter type: {expected_parameter_type}. Value: {value}."
if workflow_permanent_id:
message += f" Workflow permanent id: {workflow_permanent_id}"
super().__init__(
message,
status_code=status.HTTP_400_BAD_REQUEST,
)

View File

@@ -1421,8 +1421,9 @@ class AgentDB:
raise raise
async def create_workflow_run_parameter( async def create_workflow_run_parameter(
self, workflow_run_id: str, workflow_parameter_id: str, value: Any self, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: Any
) -> WorkflowRunParameter: ) -> WorkflowRunParameter:
workflow_parameter_id = workflow_parameter.workflow_parameter_id
try: try:
async with self.Session() as session: async with self.Session() as session:
workflow_run_parameter = WorkflowRunParameterModel( workflow_run_parameter = WorkflowRunParameterModel(
@@ -1433,9 +1434,6 @@ class AgentDB:
session.add(workflow_run_parameter) session.add(workflow_run_parameter)
await session.commit() await session.commit()
await session.refresh(workflow_run_parameter) await session.refresh(workflow_run_parameter)
workflow_parameter = await self.get_workflow_parameter(workflow_parameter_id)
if not workflow_parameter:
raise WorkflowParameterNotFound(workflow_parameter_id)
return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled) return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
except SQLAlchemyError: except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True) LOG.error("SQLAlchemyError", exc_info=True)

View File

@@ -2,10 +2,12 @@ import abc
import json import json
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import Annotated, Literal, Union from typing import Annotated, Any, Literal, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from skyvern.exceptions import InvalidWorkflowParameter
class ParameterType(StrEnum): class ParameterType(StrEnum):
WORKFLOW = "workflow" WORKFLOW = "workflow"
@@ -114,21 +116,29 @@ class WorkflowParameterType(StrEnum):
JSON = "json" JSON = "json"
FILE_URL = "file_url" FILE_URL = "file_url"
def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None: def convert_value(self, value: Any) -> str | int | float | bool | dict | list | None:
if value is None: if value is None:
return None return None
if self == WorkflowParameterType.STRING: try:
return value if self == WorkflowParameterType.STRING:
elif self == WorkflowParameterType.INTEGER: return str(value)
return int(value) elif self == WorkflowParameterType.INTEGER:
elif self == WorkflowParameterType.FLOAT: return int(value)
return float(value) elif self == WorkflowParameterType.FLOAT:
elif self == WorkflowParameterType.BOOLEAN: return float(value)
return value.lower() in ["true", "1"] elif self == WorkflowParameterType.BOOLEAN:
elif self == WorkflowParameterType.JSON: if isinstance(value, bool):
return json.loads(value) return value
elif self == WorkflowParameterType.FILE_URL: lower_case = str(value).lower()
return value if lower_case in ["true", "false", "1", "0"]:
raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value))
return lower_case in ["true", "1"]
elif self == WorkflowParameterType.JSON:
return json.loads(value)
elif self == WorkflowParameterType.FILE_URL:
return value
except Exception:
raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value))
class WorkflowParameter(Parameter): class WorkflowParameter(Parameter):

View File

@@ -1,5 +1,6 @@
import json import json
from datetime import datetime from datetime import datetime
from typing import Any
import requests import requests
import structlog import structlog
@@ -125,13 +126,13 @@ class WorkflowService:
request_body_value = workflow_request.data[workflow_parameter.key] request_body_value = workflow_request.data[workflow_parameter.key]
workflow_run_parameter = await self.create_workflow_run_parameter( workflow_run_parameter = await self.create_workflow_run_parameter(
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
workflow_parameter_id=workflow_parameter.workflow_parameter_id, workflow_parameter=workflow_parameter,
value=request_body_value, value=request_body_value,
) )
elif workflow_parameter.default_value is not None: elif workflow_parameter.default_value is not None:
workflow_run_parameter = await self.create_workflow_run_parameter( workflow_run_parameter = await self.create_workflow_run_parameter(
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
workflow_parameter_id=workflow_parameter.workflow_parameter_id, workflow_parameter=workflow_parameter,
value=workflow_parameter.default_value, value=workflow_parameter.default_value,
) )
else: else:
@@ -565,12 +566,15 @@ class WorkflowService:
async def create_workflow_run_parameter( async def create_workflow_run_parameter(
self, self,
workflow_run_id: str, workflow_run_id: str,
workflow_parameter_id: str, workflow_parameter: WorkflowParameter,
value: bool | int | float | str | dict | list, value: Any,
) -> WorkflowRunParameter: ) -> WorkflowRunParameter:
# InvalidWorkflowParameter will be raised if the validation fails
workflow_parameter.workflow_parameter_type.convert_value(value)
return await app.DATABASE.create_workflow_run_parameter( return await app.DATABASE.create_workflow_run_parameter(
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
workflow_parameter_id=workflow_parameter_id, workflow_parameter=workflow_parameter,
value=json.dumps(value) if isinstance(value, (dict, list)) else value, value=json.dumps(value) if isinstance(value, (dict, list)) else value,
) )