workflow parameter validation (#1028)
This commit is contained in:
@@ -510,3 +510,14 @@ class BlockedHost(SkyvernHTTPException):
|
|||||||
f"The host in your url is blocked: {host}",
|
f"The host in your url is blocked: {host}",
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidWorkflowParameter(SkyvernHTTPException):
|
||||||
|
def __init__(self, expected_parameter_type: str, value: str, workflow_permanent_id: str | None = None) -> None:
|
||||||
|
message = f"Invalid workflow parameter. Excpected parameter type: {expected_parameter_type}. Value: {value}."
|
||||||
|
if workflow_permanent_id:
|
||||||
|
message += f" Workflow permanent id: {workflow_permanent_id}"
|
||||||
|
super().__init__(
|
||||||
|
message,
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1421,8 +1421,9 @@ class AgentDB:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def create_workflow_run_parameter(
|
async def create_workflow_run_parameter(
|
||||||
self, workflow_run_id: str, workflow_parameter_id: str, value: Any
|
self, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: Any
|
||||||
) -> WorkflowRunParameter:
|
) -> WorkflowRunParameter:
|
||||||
|
workflow_parameter_id = workflow_parameter.workflow_parameter_id
|
||||||
try:
|
try:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
workflow_run_parameter = WorkflowRunParameterModel(
|
workflow_run_parameter = WorkflowRunParameterModel(
|
||||||
@@ -1433,9 +1434,6 @@ class AgentDB:
|
|||||||
session.add(workflow_run_parameter)
|
session.add(workflow_run_parameter)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(workflow_run_parameter)
|
await session.refresh(workflow_run_parameter)
|
||||||
workflow_parameter = await self.get_workflow_parameter(workflow_parameter_id)
|
|
||||||
if not workflow_parameter:
|
|
||||||
raise WorkflowParameterNotFound(workflow_parameter_id)
|
|
||||||
return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
|
return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
LOG.error("SQLAlchemyError", exc_info=True)
|
LOG.error("SQLAlchemyError", exc_info=True)
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ import abc
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Annotated, Literal, Union
|
from typing import Annotated, Any, Literal, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
from skyvern.exceptions import InvalidWorkflowParameter
|
||||||
|
|
||||||
|
|
||||||
class ParameterType(StrEnum):
|
class ParameterType(StrEnum):
|
||||||
WORKFLOW = "workflow"
|
WORKFLOW = "workflow"
|
||||||
@@ -114,21 +116,29 @@ class WorkflowParameterType(StrEnum):
|
|||||||
JSON = "json"
|
JSON = "json"
|
||||||
FILE_URL = "file_url"
|
FILE_URL = "file_url"
|
||||||
|
|
||||||
def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None:
|
def convert_value(self, value: Any) -> str | int | float | bool | dict | list | None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
if self == WorkflowParameterType.STRING:
|
try:
|
||||||
return value
|
if self == WorkflowParameterType.STRING:
|
||||||
elif self == WorkflowParameterType.INTEGER:
|
return str(value)
|
||||||
return int(value)
|
elif self == WorkflowParameterType.INTEGER:
|
||||||
elif self == WorkflowParameterType.FLOAT:
|
return int(value)
|
||||||
return float(value)
|
elif self == WorkflowParameterType.FLOAT:
|
||||||
elif self == WorkflowParameterType.BOOLEAN:
|
return float(value)
|
||||||
return value.lower() in ["true", "1"]
|
elif self == WorkflowParameterType.BOOLEAN:
|
||||||
elif self == WorkflowParameterType.JSON:
|
if isinstance(value, bool):
|
||||||
return json.loads(value)
|
return value
|
||||||
elif self == WorkflowParameterType.FILE_URL:
|
lower_case = str(value).lower()
|
||||||
return value
|
if lower_case in ["true", "false", "1", "0"]:
|
||||||
|
raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value))
|
||||||
|
return lower_case in ["true", "1"]
|
||||||
|
elif self == WorkflowParameterType.JSON:
|
||||||
|
return json.loads(value)
|
||||||
|
elif self == WorkflowParameterType.FILE_URL:
|
||||||
|
return value
|
||||||
|
except Exception:
|
||||||
|
raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value))
|
||||||
|
|
||||||
|
|
||||||
class WorkflowParameter(Parameter):
|
class WorkflowParameter(Parameter):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import structlog
|
import structlog
|
||||||
@@ -125,13 +126,13 @@ class WorkflowService:
|
|||||||
request_body_value = workflow_request.data[workflow_parameter.key]
|
request_body_value = workflow_request.data[workflow_parameter.key]
|
||||||
workflow_run_parameter = await self.create_workflow_run_parameter(
|
workflow_run_parameter = await self.create_workflow_run_parameter(
|
||||||
workflow_run_id=workflow_run.workflow_run_id,
|
workflow_run_id=workflow_run.workflow_run_id,
|
||||||
workflow_parameter_id=workflow_parameter.workflow_parameter_id,
|
workflow_parameter=workflow_parameter,
|
||||||
value=request_body_value,
|
value=request_body_value,
|
||||||
)
|
)
|
||||||
elif workflow_parameter.default_value is not None:
|
elif workflow_parameter.default_value is not None:
|
||||||
workflow_run_parameter = await self.create_workflow_run_parameter(
|
workflow_run_parameter = await self.create_workflow_run_parameter(
|
||||||
workflow_run_id=workflow_run.workflow_run_id,
|
workflow_run_id=workflow_run.workflow_run_id,
|
||||||
workflow_parameter_id=workflow_parameter.workflow_parameter_id,
|
workflow_parameter=workflow_parameter,
|
||||||
value=workflow_parameter.default_value,
|
value=workflow_parameter.default_value,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -565,12 +566,15 @@ class WorkflowService:
|
|||||||
async def create_workflow_run_parameter(
|
async def create_workflow_run_parameter(
|
||||||
self,
|
self,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
workflow_parameter_id: str,
|
workflow_parameter: WorkflowParameter,
|
||||||
value: bool | int | float | str | dict | list,
|
value: Any,
|
||||||
) -> WorkflowRunParameter:
|
) -> WorkflowRunParameter:
|
||||||
|
# InvalidWorkflowParameter will be raised if the validation fails
|
||||||
|
workflow_parameter.workflow_parameter_type.convert_value(value)
|
||||||
|
|
||||||
return await app.DATABASE.create_workflow_run_parameter(
|
return await app.DATABASE.create_workflow_run_parameter(
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
workflow_parameter_id=workflow_parameter_id,
|
workflow_parameter=workflow_parameter,
|
||||||
value=json.dumps(value) if isinstance(value, (dict, list)) else value,
|
value=json.dumps(value) if isinstance(value, (dict, list)) else value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user