From 066c2302b53e1c9c7c623612c5d213943cf109dc Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Thu, 21 Mar 2024 17:16:56 -0700 Subject: [PATCH] Workflow: Output Parameters & Code Blocks (#117) --- ...10-ffe2f57bd288_create_output_parameter.py | 85 +++++++++++++ skyvern/forge/sdk/db/client.py | 81 +++++++++++- skyvern/forge/sdk/db/id.py | 6 + skyvern/forge/sdk/db/models.py | 24 ++++ skyvern/forge/sdk/db/utils.py | 51 +++++++- skyvern/forge/sdk/workflow/context_manager.py | 69 +++++++++-- skyvern/forge/sdk/workflow/exceptions.py | 23 ++++ skyvern/forge/sdk/workflow/models/block.py | 112 +++++++++++++++-- .../forge/sdk/workflow/models/parameter.py | 14 ++- skyvern/forge/sdk/workflow/models/workflow.py | 20 +++ skyvern/forge/sdk/workflow/service.py | 115 ++++++++++++++---- 11 files changed, 556 insertions(+), 44 deletions(-) create mode 100644 alembic/versions/2024_03_22_0010-ffe2f57bd288_create_output_parameter.py create mode 100644 skyvern/forge/sdk/workflow/exceptions.py diff --git a/alembic/versions/2024_03_22_0010-ffe2f57bd288_create_output_parameter.py b/alembic/versions/2024_03_22_0010-ffe2f57bd288_create_output_parameter.py new file mode 100644 index 00000000..51b40501 --- /dev/null +++ b/alembic/versions/2024_03_22_0010-ffe2f57bd288_create_output_parameter.py @@ -0,0 +1,85 @@ +"""Create output parameter + +Revision ID: ffe2f57bd288 +Revises: 82a0c686152d +Create Date: 2024-03-22 00:10:16.225454+00:00 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "ffe2f57bd288" +down_revision: Union[str, None] = "82a0c686152d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "output_parameters", + sa.Column("output_parameter_id", sa.String(), nullable=False), + sa.Column("key", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("workflow_id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("modified_at", sa.DateTime(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["workflow_id"], + ["workflows.workflow_id"], + ), + sa.PrimaryKeyConstraint("output_parameter_id"), + ) + op.create_index( + op.f("ix_output_parameters_output_parameter_id"), "output_parameters", ["output_parameter_id"], unique=False + ) + op.create_index(op.f("ix_output_parameters_workflow_id"), "output_parameters", ["workflow_id"], unique=False) + op.create_table( + "workflow_run_output_parameters", + sa.Column("workflow_run_id", sa.String(), nullable=False), + sa.Column("output_parameter_id", sa.String(), nullable=False), + sa.Column("value", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["output_parameter_id"], + ["output_parameters.output_parameter_id"], + ), + sa.ForeignKeyConstraint( + ["workflow_run_id"], + ["workflow_runs.workflow_run_id"], + ), + sa.PrimaryKeyConstraint("workflow_run_id", "output_parameter_id"), + ) + op.create_index( + op.f("ix_workflow_run_output_parameters_output_parameter_id"), + "workflow_run_output_parameters", + ["output_parameter_id"], + unique=False, + ) + op.create_index( + op.f("ix_workflow_run_output_parameters_workflow_run_id"), + "workflow_run_output_parameters", + ["workflow_run_id"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index( + op.f("ix_workflow_run_output_parameters_workflow_run_id"), table_name="workflow_run_output_parameters" + ) + op.drop_index( + op.f("ix_workflow_run_output_parameters_output_parameter_id"), table_name="workflow_run_output_parameters" + ) + op.drop_table("workflow_run_output_parameters") + op.drop_index(op.f("ix_output_parameters_workflow_id"), table_name="output_parameters") + op.drop_index(op.f("ix_output_parameters_output_parameter_id"), table_name="output_parameters") + op.drop_table("output_parameters") + # ### end Alembic commands ### diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index b8291be1..9a8e2a18 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -15,11 +15,13 @@ from skyvern.forge.sdk.db.models import ( AWSSecretParameterModel, OrganizationAuthTokenModel, OrganizationModel, + OutputParameterModel, StepModel, TaskModel, WorkflowModel, WorkflowParameterModel, WorkflowRunModel, + WorkflowRunOutputParameterModel, WorkflowRunParameterModel, ) from skyvern.forge.sdk.db.utils import ( @@ -28,17 +30,30 @@ from skyvern.forge.sdk.db.utils import ( convert_to_aws_secret_parameter, convert_to_organization, convert_to_organization_auth_token, + convert_to_output_parameter, convert_to_step, convert_to_task, convert_to_workflow, convert_to_workflow_parameter, convert_to_workflow_run, + convert_to_workflow_run_output_parameter, convert_to_workflow_run_parameter, ) from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus -from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType -from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunParameter, WorkflowRunStatus +from skyvern.forge.sdk.workflow.models.parameter import ( + AWSSecretParameter, + OutputParameter, + WorkflowParameter, + WorkflowParameterType, +) +from skyvern.forge.sdk.workflow.models.workflow import ( + Workflow, + WorkflowRun, + WorkflowRunOutputParameter, + WorkflowRunParameter, + WorkflowRunStatus, +) from skyvern.webeye.actions.models import AgentStepOutput LOG = structlog.get_logger() @@ -777,6 +792,68 @@ class AgentDB: session.refresh(aws_secret_parameter) return convert_to_aws_secret_parameter(aws_secret_parameter) + async def create_output_parameter( + self, + workflow_id: str, + key: str, + description: str | None = None, + ) -> OutputParameter: + with self.Session() as session: + output_parameter = OutputParameterModel( + key=key, + description=description, + workflow_id=workflow_id, + ) + session.add(output_parameter) + session.commit() + session.refresh(output_parameter) + return convert_to_output_parameter(output_parameter) + + async def get_workflow_output_parameters(self, workflow_id: str) -> list[OutputParameter]: + try: + with self.Session() as session: + output_parameters = session.query(OutputParameterModel).filter_by(workflow_id=workflow_id).all() + return [convert_to_output_parameter(parameter) for parameter in output_parameters] + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + + async def get_workflow_run_output_parameters(self, workflow_run_id: str) -> list[WorkflowRunOutputParameter]: + try: + with self.Session() as session: + workflow_run_output_parameters = ( + session.query(WorkflowRunOutputParameterModel) + .filter_by(workflow_run_id=workflow_run_id) + .order_by(WorkflowRunOutputParameterModel.created_at) + .all() + ) + return [ + convert_to_workflow_run_output_parameter(parameter, self.debug_enabled) + for parameter in workflow_run_output_parameters + ] + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + + async def create_workflow_run_output_parameter( + self, workflow_run_id: str, output_parameter_id: str, value: dict[str, Any] | list | str | None + ) -> WorkflowRunOutputParameter: + try: + with self.Session() as session: + workflow_run_output_parameter = WorkflowRunOutputParameterModel( + workflow_run_id=workflow_run_id, + output_parameter_id=output_parameter_id, + value=value, + ) + session.add(workflow_run_output_parameter) + session.commit() + session.refresh(workflow_run_output_parameter) + return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled) + + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]: try: with self.Session() as session: diff --git a/skyvern/forge/sdk/db/id.py b/skyvern/forge/sdk/db/id.py index 329cc057..03380dad 100644 --- a/skyvern/forge/sdk/db/id.py +++ b/skyvern/forge/sdk/db/id.py @@ -37,6 +37,7 @@ WORKFLOW_PREFIX = "w" WORKFLOW_RUN_PREFIX = "wr" WORKFLOW_PARAMETER_PREFIX = "wp" AWS_SECRET_PARAMETER_PREFIX = "asp" +OUTPUT_PARAMETER_PREFIX = "op" def generate_workflow_id() -> str: @@ -59,6 +60,11 @@ def generate_workflow_parameter_id() -> str: return f"{WORKFLOW_PARAMETER_PREFIX}_{int_id}" +def generate_output_parameter_id() -> str: + int_id = generate_id() + return f"{OUTPUT_PARAMETER_PREFIX}_{int_id}" + + def generate_organization_auth_token_id() -> str: int_id = generate_id() return f"{ORGANIZATION_AUTH_TOKEN_PREFIX}_{int_id}" diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 6c759747..e4302f65 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -9,6 +9,7 @@ from skyvern.forge.sdk.db.id import ( generate_aws_secret_parameter_id, generate_org_id, generate_organization_auth_token_id, + generate_output_parameter_id, generate_step_id, generate_task_id, generate_workflow_id, @@ -150,6 +151,18 @@ class WorkflowParameterModel(Base): deleted_at = Column(DateTime, nullable=True) +class OutputParameterModel(Base): + __tablename__ = "output_parameters" + + output_parameter_id = Column(String, primary_key=True, index=True, default=generate_output_parameter_id) + key = Column(String, nullable=False) + description = Column(String, nullable=True) + workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False) + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) + modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) + deleted_at = Column(DateTime, nullable=True) + + class AWSSecretParameterModel(Base): __tablename__ = "aws_secret_parameters" @@ -173,3 +186,14 @@ class WorkflowRunParameterModel(Base): # Can be bool | int | float | str | dict | list depending on the workflow parameter type value = Column(String, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) + + +class WorkflowRunOutputParameterModel(Base): + __tablename__ = "workflow_run_output_parameters" + + workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True) + output_parameter_id = Column( + String, ForeignKey("output_parameters.output_parameter_id"), primary_key=True, index=True + ) + value = Column(JSON, nullable=False) + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index b18b4b7b..f711a334 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -11,20 +11,28 @@ from skyvern.forge.sdk.db.models import ( AWSSecretParameterModel, OrganizationAuthTokenModel, OrganizationModel, + OutputParameterModel, StepModel, TaskModel, WorkflowModel, WorkflowParameterModel, WorkflowRunModel, + WorkflowRunOutputParameterModel, WorkflowRunParameterModel, ) from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus -from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType +from skyvern.forge.sdk.workflow.models.parameter import ( + AWSSecretParameter, + OutputParameter, + WorkflowParameter, + WorkflowParameterType, +) from skyvern.forge.sdk.workflow.models.workflow import ( Workflow, WorkflowDefinition, WorkflowRun, + WorkflowRunOutputParameter, WorkflowRunParameter, WorkflowRunStatus, ) @@ -188,7 +196,7 @@ def convert_to_aws_secret_parameter( if debug_enabled: LOG.debug( "Converting AWSSecretParameterModel to AWSSecretParameter", - aws_secret_parameter_id=aws_secret_parameter_model.id, + aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id, ) return AWSSecretParameter( @@ -203,6 +211,45 @@ def convert_to_aws_secret_parameter( ) +def convert_to_output_parameter( + output_parameter_model: OutputParameterModel, debug_enabled: bool = False +) -> OutputParameter: + if debug_enabled: + LOG.debug( + "Converting OutputParameterModel to OutputParameter", + output_parameter_id=output_parameter_model.output_parameter_id, + ) + + return OutputParameter( + output_parameter_id=output_parameter_model.output_parameter_id, + key=output_parameter_model.key, + description=output_parameter_model.description, + workflow_id=output_parameter_model.workflow_id, + created_at=output_parameter_model.created_at, + modified_at=output_parameter_model.modified_at, + deleted_at=output_parameter_model.deleted_at, + ) + + +def convert_to_workflow_run_output_parameter( + workflow_run_output_parameter_model: WorkflowRunOutputParameterModel, + debug_enabled: bool = False, +) -> WorkflowRunOutputParameter: + if debug_enabled: + LOG.debug( + "Converting WorkflowRunOutputParameterModel to WorkflowRunOutputParameter", + workflow_run_id=workflow_run_output_parameter_model.workflow_run_id, + output_parameter_id=workflow_run_output_parameter_model.output_parameter_id, + ) + + return WorkflowRunOutputParameter( + workflow_run_id=workflow_run_output_parameter_model.workflow_run_id, + output_parameter_id=workflow_run_output_parameter_model.output_parameter_id, + value=workflow_run_output_parameter_model.value, + created_at=workflow_run_output_parameter_model.created_at, + ) + + def convert_to_workflow_run_parameter( workflow_run_parameter_model: WorkflowRunParameterModel, workflow_parameter: WorkflowParameter, diff --git a/skyvern/forge/sdk/workflow/context_manager.py b/skyvern/forge/sdk/workflow/context_manager.py index c3296645..c781772a 100644 --- a/skyvern/forge/sdk/workflow/context_manager.py +++ b/skyvern/forge/sdk/workflow/context_manager.py @@ -5,12 +5,18 @@ import structlog from skyvern.exceptions import WorkflowRunContextNotInitialized from skyvern.forge.sdk.api.aws import AsyncAWSClient -from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter +from skyvern.forge.sdk.workflow.exceptions import OutputParameterKeyCollisionError +from skyvern.forge.sdk.workflow.models.parameter import ( + PARAMETER_TYPE, + OutputParameter, + Parameter, + ParameterType, + WorkflowParameter, +) if TYPE_CHECKING: from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunParameter - LOG = structlog.get_logger() @@ -19,7 +25,11 @@ class WorkflowRunContext: values: dict[str, Any] secrets: dict[str, Any] - def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None: + def __init__( + self, + workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]], + workflow_output_parameters: list[OutputParameter], + ) -> None: self.parameters = {} self.values = {} self.secrets = {} @@ -34,6 +44,11 @@ class WorkflowRunContext: self.parameters[parameter.key] = parameter self.values[parameter.key] = run_parameter.value + for output_parameter in workflow_output_parameters: + if output_parameter.key in self.parameters: + raise OutputParameterKeyCollisionError(output_parameter.key) + self.parameters[output_parameter.key] = output_parameter + def get_parameter(self, key: str) -> Parameter: return self.parameters[key] @@ -48,11 +63,23 @@ class WorkflowRunContext: def set_value(self, key: str, value: Any) -> None: self.values[key] = value - def get_original_secret_value_or_none(self, secret_id: str) -> Any: + def get_original_secret_value_or_none(self, secret_id_or_value: Any) -> Any: """ Get the original secret value from the secrets dict. If the secret id is not found, return None. + + This function can be called with any possible parameter value, not just the random secret id. + + All the obfuscated secret values are strings, so if the parameter value is a string, we'll assume it's a + parameter value and return it. + + If the parameter value is a string, it could be a random secret id or an actual parameter value. We'll check if + the parameter value is a key in the secrets dict. If it is, we'll return the secret value. If it's not, we'll + assume it's an actual parameter value and return it. + """ - return self.secrets.get(secret_id) + if type(secret_id_or_value) is str: + return self.secrets.get(secret_id_or_value) + return None @staticmethod def generate_random_secret_id() -> str: @@ -68,6 +95,9 @@ class WorkflowRunContext: raise ValueError( f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}" ) + elif parameter.parameter_type == ParameterType.OUTPUT: + LOG.error(f"Output parameters are set after each block execution. Parameter key: {parameter.key}") + raise ValueError(f"Output parameters are set after each block execution. Parameter key: {parameter.key}") elif parameter.parameter_type == ParameterType.AWS_SECRET: # If the parameter is an AWS secret, fetch the secret value and store it in the secrets dict # The value of the parameter will be the random secret id with format `secret_`. @@ -77,9 +107,20 @@ class WorkflowRunContext: random_secret_id = self.generate_random_secret_id() self.secrets[random_secret_id] = secret_value self.values[parameter.key] = random_secret_id - else: + elif parameter.parameter_type == ParameterType.CONTEXT: # ContextParameter values will be set within the blocks - return None + return + else: + raise ValueError(f"Unknown parameter type: {parameter.parameter_type}") + + async def register_output_parameter_value_post_execution( + self, parameter: OutputParameter, value: dict[str, Any] | list | str | None + ) -> None: + if parameter.key in self.values: + LOG.error(f"Output parameter {parameter.output_parameter_id} already has a registered value") + return + + self.values[parameter.key] = value async def register_block_parameters( self, @@ -98,6 +139,13 @@ class WorkflowRunContext: raise ValueError( f"Workflow parameter {parameter.key} should have already been set through workflow run parameters" ) + elif parameter.parameter_type == ParameterType.OUTPUT: + LOG.error( + f"Output parameter {parameter.key} should have already been set through workflow run context init" + ) + raise ValueError( + f"Output parameter {parameter.key} should have already been set through workflow run context init" + ) self.parameters[parameter.key] = parameter await self.register_parameter_value(aws_client, parameter) @@ -121,9 +169,12 @@ class WorkflowContextManager: raise WorkflowRunContextNotInitialized(workflow_run_id=workflow_run_id) def initialize_workflow_run_context( - self, workflow_run_id: str, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]] + self, + workflow_run_id: str, + workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]], + workflow_output_parameters: list[OutputParameter], ) -> WorkflowRunContext: - workflow_run_context = WorkflowRunContext(workflow_parameter_tuples) + workflow_run_context = WorkflowRunContext(workflow_parameter_tuples, workflow_output_parameters) self.workflow_run_contexts[workflow_run_id] = workflow_run_context return workflow_run_context diff --git a/skyvern/forge/sdk/workflow/exceptions.py b/skyvern/forge/sdk/workflow/exceptions.py new file mode 100644 index 00000000..2f994365 --- /dev/null +++ b/skyvern/forge/sdk/workflow/exceptions.py @@ -0,0 +1,23 @@ +from skyvern.exceptions import SkyvernException + + +class BaseWorkflowException(SkyvernException): + pass + + +class WorkflowDefinitionHasDuplicateBlockLabels(BaseWorkflowException): + def __init__(self, duplicate_labels: set[str]) -> None: + super().__init__( + f"WorkflowDefinition has blocks with duplicate labels. Each block needs to have a unique " + f"label. Duplicate label(s): {','.join(duplicate_labels)}" + ) + + +class OutputParameterKeyCollisionError(BaseWorkflowException): + def __init__(self, key: str, retry_count: int | None = None) -> None: + message = f"Output parameter key {key} already exists in the context manager." + if retry_count is not None: + message += f" Retrying {retry_count} more times." + elif retry_count == 0: + message += " Max duplicate retries reached, aborting." + super().__init__(message) diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 3336c3e5..d7653439 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -14,7 +14,12 @@ from skyvern.exceptions import ( from skyvern.forge import app from skyvern.forge.sdk.schemas.tasks import TaskStatus from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext -from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter +from skyvern.forge.sdk.workflow.models.parameter import ( + PARAMETER_TYPE, + ContextParameter, + OutputParameter, + WorkflowParameter, +) LOG = structlog.get_logger() @@ -22,12 +27,14 @@ LOG = structlog.get_logger() class BlockType(StrEnum): TASK = "task" FOR_LOOP = "for_loop" + CODE = "code" class Block(BaseModel, abc.ABC): + # Must be unique within workflow definition + label: str block_type: BlockType - parent_block_id: str | None = None - next_block_id: str | None = None + output_parameter: OutputParameter | None = None @classmethod def get_subclasses(cls) -> tuple[type["Block"], ...]: @@ -38,7 +45,7 @@ class Block(BaseModel, abc.ABC): return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id) @abc.abstractmethod - async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any: + async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None: pass @abc.abstractmethod @@ -96,7 +103,7 @@ class TaskBlock(Block): return order, retry + 1 - async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any: + async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None: workflow_run_context = self.get_workflow_run_context(workflow_run_id) current_retry = 0 # initial value for will_retry is True, so that the loop runs at least once @@ -158,6 +165,32 @@ class TaskBlock(Block): raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status) if updated_task.status == TaskStatus.completed: will_retry = False + LOG.info( + f"Task completed", + task_id=updated_task.task_id, + workflow_run_id=workflow_run_id, + workflow_id=workflow.workflow_id, + organization_id=workflow.organization_id, + ) + if self.output_parameter: + await workflow_run_context.register_output_parameter_value_post_execution( + parameter=self.output_parameter, + value=updated_task.extracted_information, + ) + await app.DATABASE.create_workflow_run_output_parameter( + workflow_run_id=workflow_run_id, + output_parameter_id=self.output_parameter.output_parameter_id, + value=updated_task.extracted_information, + ) + LOG.info( + f"Registered output parameter value", + output_parameter_id=self.output_parameter.output_parameter_id, + value=updated_task.extracted_information, + workflow_run_id=workflow_run_id, + workflow_id=workflow.workflow_id, + task_id=updated_task.task_id, + ) + return self.output_parameter else: current_retry += 1 will_retry = current_retry <= self.max_retries @@ -172,6 +205,7 @@ class TaskBlock(Block): current_retry=current_retry, max_retries=self.max_retries, ) + return None class ForLoopBlock(Block): @@ -216,9 +250,10 @@ class ForLoopBlock(Block): return [parameter_value] else: # TODO (kerem): Implement this for context parameters + # TODO (kerem): Implement this for output parameters raise NotImplementedError - async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any: + async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None: workflow_run_context = self.get_workflow_run_context(workflow_run_id) loop_over_values = self.get_loop_over_parameter_values(workflow_run_context) LOG.info( @@ -227,14 +262,77 @@ class ForLoopBlock(Block): workflow_run_id=workflow_run_id, num_loop_over_values=len(loop_over_values), ) + outputs_with_loop_values = [] for loop_over_value in loop_over_values: context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value) for context_parameter in context_parameters_with_value: workflow_run_context.set_value(context_parameter.key, context_parameter.value) await self.loop_block.execute(workflow_run_id=workflow_run_id) + if self.loop_block.output_parameter: + outputs_with_loop_values.append( + { + "loop_value": loop_over_value, + "output_parameter": self.loop_block.output_parameter, + "output_value": workflow_run_context.get_value(self.loop_block.output_parameter.key), + } + ) + + if self.output_parameter: + await workflow_run_context.register_output_parameter_value_post_execution( + parameter=self.output_parameter, + value=outputs_with_loop_values, + ) + await app.DATABASE.create_workflow_run_output_parameter( + workflow_run_id=workflow_run_id, + output_parameter_id=self.output_parameter.output_parameter_id, + value=outputs_with_loop_values, + ) + return self.output_parameter return None -BlockSubclasses = Union[ForLoopBlock, TaskBlock] +class CodeBlock(Block): + block_type: Literal[BlockType.CODE] = BlockType.CODE + + code: str + parameters: list[PARAMETER_TYPE] = [] + + def get_all_parameters( + self, + ) -> list[PARAMETER_TYPE]: + return self.parameters + + async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None: + # get workflow run context + workflow_run_context = self.get_workflow_run_context(workflow_run_id) + # get all parameters into a dictionary + parameter_values = {} + for parameter in self.parameters: + value = workflow_run_context.get_value(parameter.key) + secret_value = workflow_run_context.get_original_secret_value_or_none(value) + if secret_value is not None: + parameter_values[parameter.key] = secret_value + else: + parameter_values[parameter.key] = value + + local_variables: dict[str, Any] = {} + exec(self.code, parameter_values, local_variables) + result = {"result": local_variables.get("result")} + if self.output_parameter: + await workflow_run_context.register_output_parameter_value_post_execution( + parameter=self.output_parameter, + value=result, + ) + await app.DATABASE.create_workflow_run_output_parameter( + workflow_run_id=workflow_run_id, + output_parameter_id=self.output_parameter.output_parameter_id, + value=result, + ) + return self.output_parameter + + return None + + +BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock] BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")] diff --git a/skyvern/forge/sdk/workflow/models/parameter.py b/skyvern/forge/sdk/workflow/models/parameter.py index ec0364e6..c9692977 100644 --- a/skyvern/forge/sdk/workflow/models/parameter.py +++ b/skyvern/forge/sdk/workflow/models/parameter.py @@ -11,6 +11,7 @@ class ParameterType(StrEnum): WORKFLOW = "workflow" CONTEXT = "context" AWS_SECRET = "aws_secret" + OUTPUT = "output" class Parameter(BaseModel, abc.ABC): @@ -80,5 +81,16 @@ class ContextParameter(Parameter): value: str | int | float | bool | dict | list | None = None -ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter] +class OutputParameter(Parameter): + parameter_type: Literal[ParameterType.OUTPUT] = ParameterType.OUTPUT + + output_parameter_id: str + workflow_id: str + + created_at: datetime + modified_at: datetime + deleted_at: datetime | None = None + + +ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter, OutputParameter] PARAMETER_TYPE = Annotated[ParameterSubclasses, Field(discriminator="parameter_type")] diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index 4bf2b902..5e7ea185 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -5,6 +5,7 @@ from typing import Any, List from pydantic import BaseModel from skyvern.forge.sdk.schemas.tasks import ProxyLocation +from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels from skyvern.forge.sdk.workflow.models.block import BlockTypeVar @@ -22,6 +23,18 @@ class RunWorkflowResponse(BaseModel): class WorkflowDefinition(BaseModel): blocks: List[BlockTypeVar] + def validate(self) -> None: + labels: set[str] = set() + duplicate_labels: set[str] = set() + for block in self.blocks: + if block.label in labels: + duplicate_labels.add(block.label) + else: + labels.add(block.label) + + if duplicate_labels: + raise WorkflowDefinitionHasDuplicateBlockLabels(duplicate_labels) + class Workflow(BaseModel): workflow_id: str @@ -61,6 +74,13 @@ class WorkflowRunParameter(BaseModel): created_at: datetime +class WorkflowRunOutputParameter(BaseModel): + workflow_run_id: str + output_parameter_id: str + value: dict[str, Any] | list | str | None + created_at: datetime + + class WorkflowRunStatusResponse(BaseModel): workflow_id: str workflow_run_id: str diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 72a3089b..d0e77078 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -20,12 +20,18 @@ from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus -from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType +from skyvern.forge.sdk.workflow.models.parameter import ( + AWSSecretParameter, + OutputParameter, + WorkflowParameter, + WorkflowParameterType, +) from skyvern.forge.sdk.workflow.models.workflow import ( Workflow, WorkflowDefinition, WorkflowRequestBody, WorkflowRun, + WorkflowRunOutputParameter, WorkflowRunParameter, WorkflowRunStatus, WorkflowRunStatusResponse, @@ -124,7 +130,10 @@ class WorkflowService: # Get all tuples wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id) - app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(workflow_run_id, wp_wps_tuples) + workflow_output_parameters = await self.get_workflow_output_parameters(workflow_id=workflow.workflow_id) + app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context( + workflow_run_id, wp_wps_tuples, workflow_output_parameters + ) # Execute workflow blocks blocks = workflow.workflow_definition.blocks try: @@ -148,15 +157,27 @@ class WorkflowService: ) tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id) - if not tasks: - LOG.warning( - f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook, marking as failed", - workflow_run_id=workflow_run.workflow_run_id, - ) - await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id) - return workflow_run - workflow_run = await self.handle_workflow_status(workflow_run=workflow_run, tasks=tasks) + if tasks: + workflow_run = await self.handle_workflow_status(workflow_run=workflow_run, tasks=tasks) + else: + # Check if the workflow run has any workflow run output parameters + # if it does, mark the workflow run as completed, else mark it as failed + workflow_run_output_parameters = await self.get_workflow_run_output_parameters( + workflow_run_id=workflow_run.workflow_run_id + ) + if workflow_run_output_parameters: + LOG.info( + f"Workflow run {workflow_run.workflow_run_id} has output parameters, marking as completed", + workflow_run_id=workflow_run.workflow_run_id, + ) + await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id) + else: + LOG.error( + f"Workflow run {workflow_run.workflow_run_id} has no tasks or output parameters, marking as failed", + workflow_run_id=workflow_run.workflow_run_id, + ) + await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id) await self.send_workflow_response( workflow=workflow, workflow_run=workflow_run, @@ -232,6 +253,8 @@ class WorkflowService: description: str | None = None, workflow_definition: WorkflowDefinition | None = None, ) -> Workflow | None: + if workflow_definition: + workflow_definition.validate() return await app.DATABASE.update_workflow( workflow_id=workflow_id, title=title, @@ -314,6 +337,11 @@ class WorkflowService: workflow_id=workflow_id, aws_key=aws_key, key=key, description=description ) + async def create_output_parameter( + self, workflow_id: str, key: str, description: str | None = None + ) -> OutputParameter: + return await app.DATABASE.create_output_parameter(workflow_id=workflow_id, key=key, description=description) + async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]: return await app.DATABASE.get_workflow_parameters(workflow_id=workflow_id) @@ -334,6 +362,33 @@ class WorkflowService: ) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]: return await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id) + @staticmethod + async def get_workflow_output_parameters(workflow_id: str) -> list[OutputParameter]: + return await app.DATABASE.get_workflow_output_parameters(workflow_id=workflow_id) + + @staticmethod + async def get_workflow_run_output_parameters( + workflow_run_id: str, + ) -> list[WorkflowRunOutputParameter]: + return await app.DATABASE.get_workflow_run_output_parameters(workflow_run_id=workflow_run_id) + + @staticmethod + async def get_output_parameter_workflow_run_output_parameter_tuples( + workflow_id: str, + workflow_run_id: str, + ) -> list[tuple[OutputParameter, WorkflowRunOutputParameter]]: + workflow_run_output_parameters = await app.DATABASE.get_workflow_run_output_parameters( + workflow_run_id=workflow_run_id + ) + output_parameters = await app.DATABASE.get_workflow_output_parameters(workflow_id=workflow_id) + + return [ + (output_parameter, workflow_run_output_parameter) + for output_parameter in output_parameters + for workflow_run_output_parameter in workflow_run_output_parameters + if output_parameter.output_parameter_id == workflow_run_output_parameter.output_parameter_id + ] + async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None: return await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id) @@ -377,15 +432,27 @@ class WorkflowService: workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id) parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples} - payload = { - task.task_id: { - "title": task.title, - "extracted_information": task.extracted_information, - "navigation_payload": task.navigation_payload, - "errors": await app.agent.get_task_errors(task=task), + output_parameter_tuples: list[ + tuple[OutputParameter, WorkflowRunOutputParameter] + ] = await self.get_output_parameter_workflow_run_output_parameter_tuples( + workflow_id=workflow_id, workflow_run_id=workflow_run_id + ) + if output_parameter_tuples: + payload = { + output_parameter.key: wfrp.value + for output_parameter, wfrp in output_parameter_tuples + if wfrp.value is not None + } + else: + payload = { + task.task_id: { + "title": task.title, + "extracted_information": task.extracted_information, + "navigation_payload": task.navigation_payload, + "errors": await app.agent.get_task_errors(task=task), + } + for task in workflow_run_tasks } - for task in workflow_run_tasks - } return WorkflowRunStatusResponse( workflow_id=workflow_id, workflow_run_id=workflow_run_id, @@ -419,6 +486,13 @@ class WorkflowService: await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids) + workflow_run_status_response = await self.build_workflow_run_status_response( + workflow_id=workflow.workflow_id, + workflow_run_id=workflow_run.workflow_run_id, + organization_id=workflow.organization_id, + ) + LOG.info("Built workflow run status response", workflow_run_status_response=workflow_run_status_response) + if not workflow_run.webhook_callback_url: LOG.warning( "Workflow has no webhook callback url. Not sending workflow response", @@ -435,11 +509,6 @@ class WorkflowService: ) return - workflow_run_status_response = await self.build_workflow_run_status_response( - workflow_id=workflow.workflow_id, - workflow_run_id=workflow_run.workflow_run_id, - organization_id=workflow.organization_id, - ) # send task_response to the webhook callback url # TODO: use async requests (httpx) timestamp = str(int(datetime.utcnow().timestamp()))