workflow parameter validation (#1028)
This commit is contained in:
@@ -510,3 +510,14 @@ class BlockedHost(SkyvernHTTPException):
|
||||
f"The host in your url is blocked: {host}",
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
class InvalidWorkflowParameter(SkyvernHTTPException):
|
||||
def __init__(self, expected_parameter_type: str, value: str, workflow_permanent_id: str | None = None) -> None:
|
||||
message = f"Invalid workflow parameter. Excpected parameter type: {expected_parameter_type}. Value: {value}."
|
||||
if workflow_permanent_id:
|
||||
message += f" Workflow permanent id: {workflow_permanent_id}"
|
||||
super().__init__(
|
||||
message,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
@@ -1421,8 +1421,9 @@ class AgentDB:
|
||||
raise
|
||||
|
||||
async def create_workflow_run_parameter(
|
||||
self, workflow_run_id: str, workflow_parameter_id: str, value: Any
|
||||
self, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: Any
|
||||
) -> WorkflowRunParameter:
|
||||
workflow_parameter_id = workflow_parameter.workflow_parameter_id
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
workflow_run_parameter = WorkflowRunParameterModel(
|
||||
@@ -1433,9 +1434,6 @@ class AgentDB:
|
||||
session.add(workflow_run_parameter)
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run_parameter)
|
||||
workflow_parameter = await self.get_workflow_parameter(workflow_parameter_id)
|
||||
if not workflow_parameter:
|
||||
raise WorkflowParameterNotFound(workflow_parameter_id)
|
||||
return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
|
||||
@@ -2,10 +2,12 @@ import abc
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, Union
|
||||
from typing import Annotated, Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from skyvern.exceptions import InvalidWorkflowParameter
|
||||
|
||||
|
||||
class ParameterType(StrEnum):
|
||||
WORKFLOW = "workflow"
|
||||
@@ -114,21 +116,29 @@ class WorkflowParameterType(StrEnum):
|
||||
JSON = "json"
|
||||
FILE_URL = "file_url"
|
||||
|
||||
def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None:
|
||||
def convert_value(self, value: Any) -> str | int | float | bool | dict | list | None:
|
||||
if value is None:
|
||||
return None
|
||||
if self == WorkflowParameterType.STRING:
|
||||
return value
|
||||
elif self == WorkflowParameterType.INTEGER:
|
||||
return int(value)
|
||||
elif self == WorkflowParameterType.FLOAT:
|
||||
return float(value)
|
||||
elif self == WorkflowParameterType.BOOLEAN:
|
||||
return value.lower() in ["true", "1"]
|
||||
elif self == WorkflowParameterType.JSON:
|
||||
return json.loads(value)
|
||||
elif self == WorkflowParameterType.FILE_URL:
|
||||
return value
|
||||
try:
|
||||
if self == WorkflowParameterType.STRING:
|
||||
return str(value)
|
||||
elif self == WorkflowParameterType.INTEGER:
|
||||
return int(value)
|
||||
elif self == WorkflowParameterType.FLOAT:
|
||||
return float(value)
|
||||
elif self == WorkflowParameterType.BOOLEAN:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
lower_case = str(value).lower()
|
||||
if lower_case in ["true", "false", "1", "0"]:
|
||||
raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value))
|
||||
return lower_case in ["true", "1"]
|
||||
elif self == WorkflowParameterType.JSON:
|
||||
return json.loads(value)
|
||||
elif self == WorkflowParameterType.FILE_URL:
|
||||
return value
|
||||
except Exception:
|
||||
raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value))
|
||||
|
||||
|
||||
class WorkflowParameter(Parameter):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import structlog
|
||||
@@ -125,13 +126,13 @@ class WorkflowService:
|
||||
request_body_value = workflow_request.data[workflow_parameter.key]
|
||||
workflow_run_parameter = await self.create_workflow_run_parameter(
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
workflow_parameter_id=workflow_parameter.workflow_parameter_id,
|
||||
workflow_parameter=workflow_parameter,
|
||||
value=request_body_value,
|
||||
)
|
||||
elif workflow_parameter.default_value is not None:
|
||||
workflow_run_parameter = await self.create_workflow_run_parameter(
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
workflow_parameter_id=workflow_parameter.workflow_parameter_id,
|
||||
workflow_parameter=workflow_parameter,
|
||||
value=workflow_parameter.default_value,
|
||||
)
|
||||
else:
|
||||
@@ -565,12 +566,15 @@ class WorkflowService:
|
||||
async def create_workflow_run_parameter(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
workflow_parameter_id: str,
|
||||
value: bool | int | float | str | dict | list,
|
||||
workflow_parameter: WorkflowParameter,
|
||||
value: Any,
|
||||
) -> WorkflowRunParameter:
|
||||
# InvalidWorkflowParameter will be raised if the validation fails
|
||||
workflow_parameter.workflow_parameter_type.convert_value(value)
|
||||
|
||||
return await app.DATABASE.create_workflow_run_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_parameter_id=workflow_parameter_id,
|
||||
workflow_parameter=workflow_parameter,
|
||||
value=json.dumps(value) if isinstance(value, (dict, list)) else value,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user