workflow parameter validation (#1028)

This commit is contained in:
Shuchang Zheng
2024-10-22 17:36:25 -07:00
committed by GitHub
parent 7cba401e2e
commit 0e3da8d1d3
4 changed files with 46 additions and 23 deletions

View File

@@ -510,3 +510,14 @@ class BlockedHost(SkyvernHTTPException):
f"The host in your url is blocked: {host}",
status_code=status.HTTP_400_BAD_REQUEST,
)
class InvalidWorkflowParameter(SkyvernHTTPException):
def __init__(self, expected_parameter_type: str, value: str, workflow_permanent_id: str | None = None) -> None:
message = f"Invalid workflow parameter. Excpected parameter type: {expected_parameter_type}. Value: {value}."
if workflow_permanent_id:
message += f" Workflow permanent id: {workflow_permanent_id}"
super().__init__(
message,
status_code=status.HTTP_400_BAD_REQUEST,
)

View File

@@ -1421,8 +1421,9 @@ class AgentDB:
raise
async def create_workflow_run_parameter(
self, workflow_run_id: str, workflow_parameter_id: str, value: Any
self, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: Any
) -> WorkflowRunParameter:
workflow_parameter_id = workflow_parameter.workflow_parameter_id
try:
async with self.Session() as session:
workflow_run_parameter = WorkflowRunParameterModel(
@@ -1433,9 +1434,6 @@ class AgentDB:
session.add(workflow_run_parameter)
await session.commit()
await session.refresh(workflow_run_parameter)
workflow_parameter = await self.get_workflow_parameter(workflow_parameter_id)
if not workflow_parameter:
raise WorkflowParameterNotFound(workflow_parameter_id)
return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)

View File

@@ -2,10 +2,12 @@ import abc
import json
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Literal, Union
from typing import Annotated, Any, Literal, Union
from pydantic import BaseModel, ConfigDict, Field
from skyvern.exceptions import InvalidWorkflowParameter
class ParameterType(StrEnum):
WORKFLOW = "workflow"
@@ -114,21 +116,29 @@ class WorkflowParameterType(StrEnum):
JSON = "json"
FILE_URL = "file_url"
def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None:
def convert_value(self, value: Any) -> str | int | float | bool | dict | list | None:
if value is None:
return None
if self == WorkflowParameterType.STRING:
return value
elif self == WorkflowParameterType.INTEGER:
return int(value)
elif self == WorkflowParameterType.FLOAT:
return float(value)
elif self == WorkflowParameterType.BOOLEAN:
return value.lower() in ["true", "1"]
elif self == WorkflowParameterType.JSON:
return json.loads(value)
elif self == WorkflowParameterType.FILE_URL:
return value
try:
if self == WorkflowParameterType.STRING:
return str(value)
elif self == WorkflowParameterType.INTEGER:
return int(value)
elif self == WorkflowParameterType.FLOAT:
return float(value)
elif self == WorkflowParameterType.BOOLEAN:
if isinstance(value, bool):
return value
lower_case = str(value).lower()
if lower_case in ["true", "false", "1", "0"]:
raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value))
return lower_case in ["true", "1"]
elif self == WorkflowParameterType.JSON:
return json.loads(value)
elif self == WorkflowParameterType.FILE_URL:
return value
except Exception:
raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value))
class WorkflowParameter(Parameter):

View File

@@ -1,5 +1,6 @@
import json
from datetime import datetime
from typing import Any
import requests
import structlog
@@ -125,13 +126,13 @@ class WorkflowService:
request_body_value = workflow_request.data[workflow_parameter.key]
workflow_run_parameter = await self.create_workflow_run_parameter(
workflow_run_id=workflow_run.workflow_run_id,
workflow_parameter_id=workflow_parameter.workflow_parameter_id,
workflow_parameter=workflow_parameter,
value=request_body_value,
)
elif workflow_parameter.default_value is not None:
workflow_run_parameter = await self.create_workflow_run_parameter(
workflow_run_id=workflow_run.workflow_run_id,
workflow_parameter_id=workflow_parameter.workflow_parameter_id,
workflow_parameter=workflow_parameter,
value=workflow_parameter.default_value,
)
else:
@@ -565,12 +566,15 @@ class WorkflowService:
async def create_workflow_run_parameter(
self,
workflow_run_id: str,
workflow_parameter_id: str,
value: bool | int | float | str | dict | list,
workflow_parameter: WorkflowParameter,
value: Any,
) -> WorkflowRunParameter:
# InvalidWorkflowParameter will be raised if the validation fails
workflow_parameter.workflow_parameter_type.convert_value(value)
return await app.DATABASE.create_workflow_run_parameter(
workflow_run_id=workflow_run_id,
workflow_parameter_id=workflow_parameter_id,
workflow_parameter=workflow_parameter,
value=json.dumps(value) if isinstance(value, (dict, list)) else value,
)