Workflow: Output Parameters & Code Blocks (#117)
This commit is contained in:
@@ -15,11 +15,13 @@ from skyvern.forge.sdk.db.models import (
|
||||
AWSSecretParameterModel,
|
||||
OrganizationAuthTokenModel,
|
||||
OrganizationModel,
|
||||
OutputParameterModel,
|
||||
StepModel,
|
||||
TaskModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunOutputParameterModel,
|
||||
WorkflowRunParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.utils import (
|
||||
@@ -28,17 +30,30 @@ from skyvern.forge.sdk.db.utils import (
|
||||
convert_to_aws_secret_parameter,
|
||||
convert_to_organization,
|
||||
convert_to_organization_auth_token,
|
||||
convert_to_output_parameter,
|
||||
convert_to_step,
|
||||
convert_to_task,
|
||||
convert_to_workflow,
|
||||
convert_to_workflow_parameter,
|
||||
convert_to_workflow_run,
|
||||
convert_to_workflow_run_output_parameter,
|
||||
convert_to_workflow_run_parameter,
|
||||
)
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunParameter, WorkflowRunStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
AWSSecretParameter,
|
||||
OutputParameter,
|
||||
WorkflowParameter,
|
||||
WorkflowParameterType,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
Workflow,
|
||||
WorkflowRun,
|
||||
WorkflowRunOutputParameter,
|
||||
WorkflowRunParameter,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from skyvern.webeye.actions.models import AgentStepOutput
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
@@ -777,6 +792,68 @@ class AgentDB:
|
||||
session.refresh(aws_secret_parameter)
|
||||
return convert_to_aws_secret_parameter(aws_secret_parameter)
|
||||
|
||||
async def create_output_parameter(
|
||||
self,
|
||||
workflow_id: str,
|
||||
key: str,
|
||||
description: str | None = None,
|
||||
) -> OutputParameter:
|
||||
with self.Session() as session:
|
||||
output_parameter = OutputParameterModel(
|
||||
key=key,
|
||||
description=description,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
session.add(output_parameter)
|
||||
session.commit()
|
||||
session.refresh(output_parameter)
|
||||
return convert_to_output_parameter(output_parameter)
|
||||
|
||||
async def get_workflow_output_parameters(self, workflow_id: str) -> list[OutputParameter]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
output_parameters = session.query(OutputParameterModel).filter_by(workflow_id=workflow_id).all()
|
||||
return [convert_to_output_parameter(parameter) for parameter in output_parameters]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_workflow_run_output_parameters(self, workflow_run_id: str) -> list[WorkflowRunOutputParameter]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
workflow_run_output_parameters = (
|
||||
session.query(WorkflowRunOutputParameterModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.order_by(WorkflowRunOutputParameterModel.created_at)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
convert_to_workflow_run_output_parameter(parameter, self.debug_enabled)
|
||||
for parameter in workflow_run_output_parameters
|
||||
]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_workflow_run_output_parameter(
|
||||
self, workflow_run_id: str, output_parameter_id: str, value: dict[str, Any] | list | str | None
|
||||
) -> WorkflowRunOutputParameter:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
workflow_run_output_parameter = WorkflowRunOutputParameterModel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=output_parameter_id,
|
||||
value=value,
|
||||
)
|
||||
session.add(workflow_run_output_parameter)
|
||||
session.commit()
|
||||
session.refresh(workflow_run_output_parameter)
|
||||
return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
|
||||
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
|
||||
@@ -37,6 +37,7 @@ WORKFLOW_PREFIX = "w"
|
||||
WORKFLOW_RUN_PREFIX = "wr"
|
||||
WORKFLOW_PARAMETER_PREFIX = "wp"
|
||||
AWS_SECRET_PARAMETER_PREFIX = "asp"
|
||||
OUTPUT_PARAMETER_PREFIX = "op"
|
||||
|
||||
|
||||
def generate_workflow_id() -> str:
|
||||
@@ -59,6 +60,11 @@ def generate_workflow_parameter_id() -> str:
|
||||
return f"{WORKFLOW_PARAMETER_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_output_parameter_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{OUTPUT_PARAMETER_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_organization_auth_token_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{ORGANIZATION_AUTH_TOKEN_PREFIX}_{int_id}"
|
||||
|
||||
@@ -9,6 +9,7 @@ from skyvern.forge.sdk.db.id import (
|
||||
generate_aws_secret_parameter_id,
|
||||
generate_org_id,
|
||||
generate_organization_auth_token_id,
|
||||
generate_output_parameter_id,
|
||||
generate_step_id,
|
||||
generate_task_id,
|
||||
generate_workflow_id,
|
||||
@@ -150,6 +151,18 @@ class WorkflowParameterModel(Base):
|
||||
deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class OutputParameterModel(Base):
|
||||
__tablename__ = "output_parameters"
|
||||
|
||||
output_parameter_id = Column(String, primary_key=True, index=True, default=generate_output_parameter_id)
|
||||
key = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class AWSSecretParameterModel(Base):
|
||||
__tablename__ = "aws_secret_parameters"
|
||||
|
||||
@@ -173,3 +186,14 @@ class WorkflowRunParameterModel(Base):
|
||||
# Can be bool | int | float | str | dict | list depending on the workflow parameter type
|
||||
value = Column(String, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class WorkflowRunOutputParameterModel(Base):
|
||||
__tablename__ = "workflow_run_output_parameters"
|
||||
|
||||
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True)
|
||||
output_parameter_id = Column(
|
||||
String, ForeignKey("output_parameters.output_parameter_id"), primary_key=True, index=True
|
||||
)
|
||||
value = Column(JSON, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
@@ -11,20 +11,28 @@ from skyvern.forge.sdk.db.models import (
|
||||
AWSSecretParameterModel,
|
||||
OrganizationAuthTokenModel,
|
||||
OrganizationModel,
|
||||
OutputParameterModel,
|
||||
StepModel,
|
||||
TaskModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunOutputParameterModel,
|
||||
WorkflowRunParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
AWSSecretParameter,
|
||||
OutputParameter,
|
||||
WorkflowParameter,
|
||||
WorkflowParameterType,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDefinition,
|
||||
WorkflowRun,
|
||||
WorkflowRunOutputParameter,
|
||||
WorkflowRunParameter,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
@@ -188,7 +196,7 @@ def convert_to_aws_secret_parameter(
|
||||
if debug_enabled:
|
||||
LOG.debug(
|
||||
"Converting AWSSecretParameterModel to AWSSecretParameter",
|
||||
aws_secret_parameter_id=aws_secret_parameter_model.id,
|
||||
aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
|
||||
)
|
||||
|
||||
return AWSSecretParameter(
|
||||
@@ -203,6 +211,45 @@ def convert_to_aws_secret_parameter(
|
||||
)
|
||||
|
||||
|
||||
def convert_to_output_parameter(
|
||||
output_parameter_model: OutputParameterModel, debug_enabled: bool = False
|
||||
) -> OutputParameter:
|
||||
if debug_enabled:
|
||||
LOG.debug(
|
||||
"Converting OutputParameterModel to OutputParameter",
|
||||
output_parameter_id=output_parameter_model.output_parameter_id,
|
||||
)
|
||||
|
||||
return OutputParameter(
|
||||
output_parameter_id=output_parameter_model.output_parameter_id,
|
||||
key=output_parameter_model.key,
|
||||
description=output_parameter_model.description,
|
||||
workflow_id=output_parameter_model.workflow_id,
|
||||
created_at=output_parameter_model.created_at,
|
||||
modified_at=output_parameter_model.modified_at,
|
||||
deleted_at=output_parameter_model.deleted_at,
|
||||
)
|
||||
|
||||
|
||||
def convert_to_workflow_run_output_parameter(
|
||||
workflow_run_output_parameter_model: WorkflowRunOutputParameterModel,
|
||||
debug_enabled: bool = False,
|
||||
) -> WorkflowRunOutputParameter:
|
||||
if debug_enabled:
|
||||
LOG.debug(
|
||||
"Converting WorkflowRunOutputParameterModel to WorkflowRunOutputParameter",
|
||||
workflow_run_id=workflow_run_output_parameter_model.workflow_run_id,
|
||||
output_parameter_id=workflow_run_output_parameter_model.output_parameter_id,
|
||||
)
|
||||
|
||||
return WorkflowRunOutputParameter(
|
||||
workflow_run_id=workflow_run_output_parameter_model.workflow_run_id,
|
||||
output_parameter_id=workflow_run_output_parameter_model.output_parameter_id,
|
||||
value=workflow_run_output_parameter_model.value,
|
||||
created_at=workflow_run_output_parameter_model.created_at,
|
||||
)
|
||||
|
||||
|
||||
def convert_to_workflow_run_parameter(
|
||||
workflow_run_parameter_model: WorkflowRunParameterModel,
|
||||
workflow_parameter: WorkflowParameter,
|
||||
|
||||
@@ -5,12 +5,18 @@ import structlog
|
||||
|
||||
from skyvern.exceptions import WorkflowRunContextNotInitialized
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter
|
||||
from skyvern.forge.sdk.workflow.exceptions import OutputParameterKeyCollisionError
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
PARAMETER_TYPE,
|
||||
OutputParameter,
|
||||
Parameter,
|
||||
ParameterType,
|
||||
WorkflowParameter,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunParameter
|
||||
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@@ -19,7 +25,11 @@ class WorkflowRunContext:
|
||||
values: dict[str, Any]
|
||||
secrets: dict[str, Any]
|
||||
|
||||
def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]],
|
||||
workflow_output_parameters: list[OutputParameter],
|
||||
) -> None:
|
||||
self.parameters = {}
|
||||
self.values = {}
|
||||
self.secrets = {}
|
||||
@@ -34,6 +44,11 @@ class WorkflowRunContext:
|
||||
self.parameters[parameter.key] = parameter
|
||||
self.values[parameter.key] = run_parameter.value
|
||||
|
||||
for output_parameter in workflow_output_parameters:
|
||||
if output_parameter.key in self.parameters:
|
||||
raise OutputParameterKeyCollisionError(output_parameter.key)
|
||||
self.parameters[output_parameter.key] = output_parameter
|
||||
|
||||
def get_parameter(self, key: str) -> Parameter:
|
||||
return self.parameters[key]
|
||||
|
||||
@@ -48,11 +63,23 @@ class WorkflowRunContext:
|
||||
def set_value(self, key: str, value: Any) -> None:
|
||||
self.values[key] = value
|
||||
|
||||
def get_original_secret_value_or_none(self, secret_id: str) -> Any:
|
||||
def get_original_secret_value_or_none(self, secret_id_or_value: Any) -> Any:
|
||||
"""
|
||||
Get the original secret value from the secrets dict. If the secret id is not found, return None.
|
||||
|
||||
This function can be called with any possible parameter value, not just the random secret id.
|
||||
|
||||
All the obfuscated secret values are strings, so if the parameter value is a string, we'll assume it's a
|
||||
parameter value and return it.
|
||||
|
||||
If the parameter value is a string, it could be a random secret id or an actual parameter value. We'll check if
|
||||
the parameter value is a key in the secrets dict. If it is, we'll return the secret value. If it's not, we'll
|
||||
assume it's an actual parameter value and return it.
|
||||
|
||||
"""
|
||||
return self.secrets.get(secret_id)
|
||||
if type(secret_id_or_value) is str:
|
||||
return self.secrets.get(secret_id_or_value)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def generate_random_secret_id() -> str:
|
||||
@@ -68,6 +95,9 @@ class WorkflowRunContext:
|
||||
raise ValueError(
|
||||
f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}"
|
||||
)
|
||||
elif parameter.parameter_type == ParameterType.OUTPUT:
|
||||
LOG.error(f"Output parameters are set after each block execution. Parameter key: {parameter.key}")
|
||||
raise ValueError(f"Output parameters are set after each block execution. Parameter key: {parameter.key}")
|
||||
elif parameter.parameter_type == ParameterType.AWS_SECRET:
|
||||
# If the parameter is an AWS secret, fetch the secret value and store it in the secrets dict
|
||||
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
|
||||
@@ -77,9 +107,20 @@ class WorkflowRunContext:
|
||||
random_secret_id = self.generate_random_secret_id()
|
||||
self.secrets[random_secret_id] = secret_value
|
||||
self.values[parameter.key] = random_secret_id
|
||||
else:
|
||||
elif parameter.parameter_type == ParameterType.CONTEXT:
|
||||
# ContextParameter values will be set within the blocks
|
||||
return None
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Unknown parameter type: {parameter.parameter_type}")
|
||||
|
||||
async def register_output_parameter_value_post_execution(
|
||||
self, parameter: OutputParameter, value: dict[str, Any] | list | str | None
|
||||
) -> None:
|
||||
if parameter.key in self.values:
|
||||
LOG.error(f"Output parameter {parameter.output_parameter_id} already has a registered value")
|
||||
return
|
||||
|
||||
self.values[parameter.key] = value
|
||||
|
||||
async def register_block_parameters(
|
||||
self,
|
||||
@@ -98,6 +139,13 @@ class WorkflowRunContext:
|
||||
raise ValueError(
|
||||
f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
|
||||
)
|
||||
elif parameter.parameter_type == ParameterType.OUTPUT:
|
||||
LOG.error(
|
||||
f"Output parameter {parameter.key} should have already been set through workflow run context init"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Output parameter {parameter.key} should have already been set through workflow run context init"
|
||||
)
|
||||
|
||||
self.parameters[parameter.key] = parameter
|
||||
await self.register_parameter_value(aws_client, parameter)
|
||||
@@ -121,9 +169,12 @@ class WorkflowContextManager:
|
||||
raise WorkflowRunContextNotInitialized(workflow_run_id=workflow_run_id)
|
||||
|
||||
def initialize_workflow_run_context(
|
||||
self, workflow_run_id: str, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]],
|
||||
workflow_output_parameters: list[OutputParameter],
|
||||
) -> WorkflowRunContext:
|
||||
workflow_run_context = WorkflowRunContext(workflow_parameter_tuples)
|
||||
workflow_run_context = WorkflowRunContext(workflow_parameter_tuples, workflow_output_parameters)
|
||||
self.workflow_run_contexts[workflow_run_id] = workflow_run_context
|
||||
return workflow_run_context
|
||||
|
||||
|
||||
23
skyvern/forge/sdk/workflow/exceptions.py
Normal file
23
skyvern/forge/sdk/workflow/exceptions.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from skyvern.exceptions import SkyvernException
|
||||
|
||||
|
||||
class BaseWorkflowException(SkyvernException):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowDefinitionHasDuplicateBlockLabels(BaseWorkflowException):
|
||||
def __init__(self, duplicate_labels: set[str]) -> None:
|
||||
super().__init__(
|
||||
f"WorkflowDefinition has blocks with duplicate labels. Each block needs to have a unique "
|
||||
f"label. Duplicate label(s): {','.join(duplicate_labels)}"
|
||||
)
|
||||
|
||||
|
||||
class OutputParameterKeyCollisionError(BaseWorkflowException):
|
||||
def __init__(self, key: str, retry_count: int | None = None) -> None:
|
||||
message = f"Output parameter key {key} already exists in the context manager."
|
||||
if retry_count is not None:
|
||||
message += f" Retrying {retry_count} more times."
|
||||
elif retry_count == 0:
|
||||
message += " Max duplicate retries reached, aborting."
|
||||
super().__init__(message)
|
||||
@@ -14,7 +14,12 @@ from skyvern.exceptions import (
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
PARAMETER_TYPE,
|
||||
ContextParameter,
|
||||
OutputParameter,
|
||||
WorkflowParameter,
|
||||
)
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -22,12 +27,14 @@ LOG = structlog.get_logger()
|
||||
class BlockType(StrEnum):
|
||||
TASK = "task"
|
||||
FOR_LOOP = "for_loop"
|
||||
CODE = "code"
|
||||
|
||||
|
||||
class Block(BaseModel, abc.ABC):
|
||||
# Must be unique within workflow definition
|
||||
label: str
|
||||
block_type: BlockType
|
||||
parent_block_id: str | None = None
|
||||
next_block_id: str | None = None
|
||||
output_parameter: OutputParameter | None = None
|
||||
|
||||
@classmethod
|
||||
def get_subclasses(cls) -> tuple[type["Block"], ...]:
|
||||
@@ -38,7 +45,7 @@ class Block(BaseModel, abc.ABC):
|
||||
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -96,7 +103,7 @@ class TaskBlock(Block):
|
||||
|
||||
return order, retry + 1
|
||||
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
current_retry = 0
|
||||
# initial value for will_retry is True, so that the loop runs at least once
|
||||
@@ -158,6 +165,32 @@ class TaskBlock(Block):
|
||||
raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status)
|
||||
if updated_task.status == TaskStatus.completed:
|
||||
will_retry = False
|
||||
LOG.info(
|
||||
f"Task completed",
|
||||
task_id=updated_task.task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow.workflow_id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=updated_task.extracted_information,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=updated_task.extracted_information,
|
||||
)
|
||||
LOG.info(
|
||||
f"Registered output parameter value",
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=updated_task.extracted_information,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow.workflow_id,
|
||||
task_id=updated_task.task_id,
|
||||
)
|
||||
return self.output_parameter
|
||||
else:
|
||||
current_retry += 1
|
||||
will_retry = current_retry <= self.max_retries
|
||||
@@ -172,6 +205,7 @@ class TaskBlock(Block):
|
||||
current_retry=current_retry,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class ForLoopBlock(Block):
|
||||
@@ -216,9 +250,10 @@ class ForLoopBlock(Block):
|
||||
return [parameter_value]
|
||||
else:
|
||||
# TODO (kerem): Implement this for context parameters
|
||||
# TODO (kerem): Implement this for output parameters
|
||||
raise NotImplementedError
|
||||
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
|
||||
LOG.info(
|
||||
@@ -227,14 +262,77 @@ class ForLoopBlock(Block):
|
||||
workflow_run_id=workflow_run_id,
|
||||
num_loop_over_values=len(loop_over_values),
|
||||
)
|
||||
outputs_with_loop_values = []
|
||||
for loop_over_value in loop_over_values:
|
||||
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
|
||||
for context_parameter in context_parameters_with_value:
|
||||
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
|
||||
await self.loop_block.execute(workflow_run_id=workflow_run_id)
|
||||
if self.loop_block.output_parameter:
|
||||
outputs_with_loop_values.append(
|
||||
{
|
||||
"loop_value": loop_over_value,
|
||||
"output_parameter": self.loop_block.output_parameter,
|
||||
"output_value": workflow_run_context.get_value(self.loop_block.output_parameter.key),
|
||||
}
|
||||
)
|
||||
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=outputs_with_loop_values,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=outputs_with_loop_values,
|
||||
)
|
||||
return self.output_parameter
|
||||
|
||||
return None
|
||||
|
||||
|
||||
BlockSubclasses = Union[ForLoopBlock, TaskBlock]
|
||||
class CodeBlock(Block):
|
||||
block_type: Literal[BlockType.CODE] = BlockType.CODE
|
||||
|
||||
code: str
|
||||
parameters: list[PARAMETER_TYPE] = []
|
||||
|
||||
def get_all_parameters(
|
||||
self,
|
||||
) -> list[PARAMETER_TYPE]:
|
||||
return self.parameters
|
||||
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||
# get workflow run context
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
# get all parameters into a dictionary
|
||||
parameter_values = {}
|
||||
for parameter in self.parameters:
|
||||
value = workflow_run_context.get_value(parameter.key)
|
||||
secret_value = workflow_run_context.get_original_secret_value_or_none(value)
|
||||
if secret_value is not None:
|
||||
parameter_values[parameter.key] = secret_value
|
||||
else:
|
||||
parameter_values[parameter.key] = value
|
||||
|
||||
local_variables: dict[str, Any] = {}
|
||||
exec(self.code, parameter_values, local_variables)
|
||||
result = {"result": local_variables.get("result")}
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=result,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=result,
|
||||
)
|
||||
return self.output_parameter
|
||||
|
||||
return None
|
||||
|
||||
|
||||
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock]
|
||||
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||
|
||||
@@ -11,6 +11,7 @@ class ParameterType(StrEnum):
|
||||
WORKFLOW = "workflow"
|
||||
CONTEXT = "context"
|
||||
AWS_SECRET = "aws_secret"
|
||||
OUTPUT = "output"
|
||||
|
||||
|
||||
class Parameter(BaseModel, abc.ABC):
|
||||
@@ -80,5 +81,16 @@ class ContextParameter(Parameter):
|
||||
value: str | int | float | bool | dict | list | None = None
|
||||
|
||||
|
||||
ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter]
|
||||
class OutputParameter(Parameter):
|
||||
parameter_type: Literal[ParameterType.OUTPUT] = ParameterType.OUTPUT
|
||||
|
||||
output_parameter_id: str
|
||||
workflow_id: str
|
||||
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
deleted_at: datetime | None = None
|
||||
|
||||
|
||||
ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter, OutputParameter]
|
||||
PARAMETER_TYPE = Annotated[ParameterSubclasses, Field(discriminator="parameter_type")]
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
|
||||
|
||||
|
||||
@@ -22,6 +23,18 @@ class RunWorkflowResponse(BaseModel):
|
||||
class WorkflowDefinition(BaseModel):
|
||||
blocks: List[BlockTypeVar]
|
||||
|
||||
def validate(self) -> None:
|
||||
labels: set[str] = set()
|
||||
duplicate_labels: set[str] = set()
|
||||
for block in self.blocks:
|
||||
if block.label in labels:
|
||||
duplicate_labels.add(block.label)
|
||||
else:
|
||||
labels.add(block.label)
|
||||
|
||||
if duplicate_labels:
|
||||
raise WorkflowDefinitionHasDuplicateBlockLabels(duplicate_labels)
|
||||
|
||||
|
||||
class Workflow(BaseModel):
|
||||
workflow_id: str
|
||||
@@ -61,6 +74,13 @@ class WorkflowRunParameter(BaseModel):
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class WorkflowRunOutputParameter(BaseModel):
|
||||
workflow_run_id: str
|
||||
output_parameter_id: str
|
||||
value: dict[str, Any] | list | str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class WorkflowRunStatusResponse(BaseModel):
|
||||
workflow_id: str
|
||||
workflow_run_id: str
|
||||
|
||||
@@ -20,12 +20,18 @@ from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
AWSSecretParameter,
|
||||
OutputParameter,
|
||||
WorkflowParameter,
|
||||
WorkflowParameterType,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDefinition,
|
||||
WorkflowRequestBody,
|
||||
WorkflowRun,
|
||||
WorkflowRunOutputParameter,
|
||||
WorkflowRunParameter,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunStatusResponse,
|
||||
@@ -124,7 +130,10 @@ class WorkflowService:
|
||||
|
||||
# Get all <workflow parameter, workflow run parameter> tuples
|
||||
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
|
||||
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(workflow_run_id, wp_wps_tuples)
|
||||
workflow_output_parameters = await self.get_workflow_output_parameters(workflow_id=workflow.workflow_id)
|
||||
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(
|
||||
workflow_run_id, wp_wps_tuples, workflow_output_parameters
|
||||
)
|
||||
# Execute workflow blocks
|
||||
blocks = workflow.workflow_definition.blocks
|
||||
try:
|
||||
@@ -148,15 +157,27 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id)
|
||||
if not tasks:
|
||||
LOG.warning(
|
||||
f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook, marking as failed",
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
)
|
||||
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
|
||||
return workflow_run
|
||||
|
||||
workflow_run = await self.handle_workflow_status(workflow_run=workflow_run, tasks=tasks)
|
||||
if tasks:
|
||||
workflow_run = await self.handle_workflow_status(workflow_run=workflow_run, tasks=tasks)
|
||||
else:
|
||||
# Check if the workflow run has any workflow run output parameters
|
||||
# if it does, mark the workflow run as completed, else mark it as failed
|
||||
workflow_run_output_parameters = await self.get_workflow_run_output_parameters(
|
||||
workflow_run_id=workflow_run.workflow_run_id
|
||||
)
|
||||
if workflow_run_output_parameters:
|
||||
LOG.info(
|
||||
f"Workflow run {workflow_run.workflow_run_id} has output parameters, marking as completed",
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
)
|
||||
await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id)
|
||||
else:
|
||||
LOG.error(
|
||||
f"Workflow run {workflow_run.workflow_run_id} has no tasks or output parameters, marking as failed",
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
)
|
||||
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
|
||||
await self.send_workflow_response(
|
||||
workflow=workflow,
|
||||
workflow_run=workflow_run,
|
||||
@@ -232,6 +253,8 @@ class WorkflowService:
|
||||
description: str | None = None,
|
||||
workflow_definition: WorkflowDefinition | None = None,
|
||||
) -> Workflow | None:
|
||||
if workflow_definition:
|
||||
workflow_definition.validate()
|
||||
return await app.DATABASE.update_workflow(
|
||||
workflow_id=workflow_id,
|
||||
title=title,
|
||||
@@ -314,6 +337,11 @@ class WorkflowService:
|
||||
workflow_id=workflow_id, aws_key=aws_key, key=key, description=description
|
||||
)
|
||||
|
||||
async def create_output_parameter(
|
||||
self, workflow_id: str, key: str, description: str | None = None
|
||||
) -> OutputParameter:
|
||||
return await app.DATABASE.create_output_parameter(workflow_id=workflow_id, key=key, description=description)
|
||||
|
||||
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
|
||||
return await app.DATABASE.get_workflow_parameters(workflow_id=workflow_id)
|
||||
|
||||
@@ -334,6 +362,33 @@ class WorkflowService:
|
||||
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
|
||||
return await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_workflow_output_parameters(workflow_id: str) -> list[OutputParameter]:
|
||||
return await app.DATABASE.get_workflow_output_parameters(workflow_id=workflow_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_workflow_run_output_parameters(
|
||||
workflow_run_id: str,
|
||||
) -> list[WorkflowRunOutputParameter]:
|
||||
return await app.DATABASE.get_workflow_run_output_parameters(workflow_run_id=workflow_run_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_output_parameter_workflow_run_output_parameter_tuples(
|
||||
workflow_id: str,
|
||||
workflow_run_id: str,
|
||||
) -> list[tuple[OutputParameter, WorkflowRunOutputParameter]]:
|
||||
workflow_run_output_parameters = await app.DATABASE.get_workflow_run_output_parameters(
|
||||
workflow_run_id=workflow_run_id
|
||||
)
|
||||
output_parameters = await app.DATABASE.get_workflow_output_parameters(workflow_id=workflow_id)
|
||||
|
||||
return [
|
||||
(output_parameter, workflow_run_output_parameter)
|
||||
for output_parameter in output_parameters
|
||||
for workflow_run_output_parameter in workflow_run_output_parameters
|
||||
if output_parameter.output_parameter_id == workflow_run_output_parameter.output_parameter_id
|
||||
]
|
||||
|
||||
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
|
||||
return await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
|
||||
|
||||
@@ -377,15 +432,27 @@ class WorkflowService:
|
||||
|
||||
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
|
||||
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
|
||||
payload = {
|
||||
task.task_id: {
|
||||
"title": task.title,
|
||||
"extracted_information": task.extracted_information,
|
||||
"navigation_payload": task.navigation_payload,
|
||||
"errors": await app.agent.get_task_errors(task=task),
|
||||
output_parameter_tuples: list[
|
||||
tuple[OutputParameter, WorkflowRunOutputParameter]
|
||||
] = await self.get_output_parameter_workflow_run_output_parameter_tuples(
|
||||
workflow_id=workflow_id, workflow_run_id=workflow_run_id
|
||||
)
|
||||
if output_parameter_tuples:
|
||||
payload = {
|
||||
output_parameter.key: wfrp.value
|
||||
for output_parameter, wfrp in output_parameter_tuples
|
||||
if wfrp.value is not None
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
task.task_id: {
|
||||
"title": task.title,
|
||||
"extracted_information": task.extracted_information,
|
||||
"navigation_payload": task.navigation_payload,
|
||||
"errors": await app.agent.get_task_errors(task=task),
|
||||
}
|
||||
for task in workflow_run_tasks
|
||||
}
|
||||
for task in workflow_run_tasks
|
||||
}
|
||||
return WorkflowRunStatusResponse(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
@@ -419,6 +486,13 @@ class WorkflowService:
|
||||
|
||||
await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)
|
||||
|
||||
workflow_run_status_response = await self.build_workflow_run_status_response(
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
LOG.info("Built workflow run status response", workflow_run_status_response=workflow_run_status_response)
|
||||
|
||||
if not workflow_run.webhook_callback_url:
|
||||
LOG.warning(
|
||||
"Workflow has no webhook callback url. Not sending workflow response",
|
||||
@@ -435,11 +509,6 @@ class WorkflowService:
|
||||
)
|
||||
return
|
||||
|
||||
workflow_run_status_response = await self.build_workflow_run_status_response(
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
# send task_response to the webhook callback url
|
||||
# TODO: use async requests (httpx)
|
||||
timestamp = str(int(datetime.utcnow().timestamp()))
|
||||
|
||||
Reference in New Issue
Block a user