Workflow: Output Parameters & Code Blocks (#117)

This commit is contained in:
Kerem Yilmaz
2024-03-21 17:16:56 -07:00
committed by GitHub
parent d2ca6ca792
commit 066c2302b5
11 changed files with 556 additions and 44 deletions

View File

@@ -15,11 +15,13 @@ from skyvern.forge.sdk.db.models import (
AWSSecretParameterModel,
OrganizationAuthTokenModel,
OrganizationModel,
OutputParameterModel,
StepModel,
TaskModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
WorkflowRunOutputParameterModel,
WorkflowRunParameterModel,
)
from skyvern.forge.sdk.db.utils import (
@@ -28,17 +30,30 @@ from skyvern.forge.sdk.db.utils import (
convert_to_aws_secret_parameter,
convert_to_organization,
convert_to_organization_auth_token,
convert_to_output_parameter,
convert_to_step,
convert_to_task,
convert_to_workflow,
convert_to_workflow_parameter,
convert_to_workflow_run,
convert_to_workflow_run_output_parameter,
convert_to_workflow_run_parameter,
)
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunParameter, WorkflowRunStatus
from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter,
OutputParameter,
WorkflowParameter,
WorkflowParameterType,
)
from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
WorkflowRun,
WorkflowRunOutputParameter,
WorkflowRunParameter,
WorkflowRunStatus,
)
from skyvern.webeye.actions.models import AgentStepOutput
LOG = structlog.get_logger()
@@ -777,6 +792,68 @@ class AgentDB:
session.refresh(aws_secret_parameter)
return convert_to_aws_secret_parameter(aws_secret_parameter)
async def create_output_parameter(
    self,
    workflow_id: str,
    key: str,
    description: str | None = None,
) -> OutputParameter:
    """Create and persist an output parameter definition for a workflow.

    Args:
        workflow_id: The workflow the parameter belongs to.
        key: The key the parameter is referenced by in the workflow definition.
        description: Optional human-readable description.

    Returns:
        The created OutputParameter schema object.

    Raises:
        SQLAlchemyError: re-raised after logging, matching the sibling DB methods.
    """
    # Consistency fix: the other DB methods in this class wrap their session
    # work in try/except SQLAlchemyError with logging; this one did not.
    try:
        with self.Session() as session:
            output_parameter = OutputParameterModel(
                key=key,
                description=description,
                workflow_id=workflow_id,
            )
            session.add(output_parameter)
            session.commit()
            session.refresh(output_parameter)
            return convert_to_output_parameter(output_parameter)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def get_workflow_output_parameters(self, workflow_id: str) -> list[OutputParameter]:
    """Return every output parameter definition belonging to *workflow_id*."""
    try:
        with self.Session() as session:
            rows = (
                session.query(OutputParameterModel)
                .filter_by(workflow_id=workflow_id)
                .all()
            )
            return [convert_to_output_parameter(row) for row in rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def get_workflow_run_output_parameters(self, workflow_run_id: str) -> list[WorkflowRunOutputParameter]:
    """Return the output parameter values recorded for a run, ordered by creation time."""
    try:
        with self.Session() as session:
            rows = (
                session.query(WorkflowRunOutputParameterModel)
                .filter_by(workflow_run_id=workflow_run_id)
                .order_by(WorkflowRunOutputParameterModel.created_at)
                .all()
            )
            return [convert_to_workflow_run_output_parameter(row, self.debug_enabled) for row in rows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def create_workflow_run_output_parameter(
    self, workflow_run_id: str, output_parameter_id: str, value: dict[str, Any] | list | str | None
) -> WorkflowRunOutputParameter:
    """Persist the value an output parameter took during a workflow run."""
    try:
        with self.Session() as session:
            record = WorkflowRunOutputParameterModel(
                workflow_run_id=workflow_run_id,
                output_parameter_id=output_parameter_id,
                value=value,
            )
            session.add(record)
            session.commit()
            session.refresh(record)
            return convert_to_workflow_run_output_parameter(record, self.debug_enabled)
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
try:
with self.Session() as session:

View File

@@ -37,6 +37,7 @@ WORKFLOW_PREFIX = "w"
WORKFLOW_RUN_PREFIX = "wr"
WORKFLOW_PARAMETER_PREFIX = "wp"
AWS_SECRET_PARAMETER_PREFIX = "asp"
OUTPUT_PARAMETER_PREFIX = "op"
def generate_workflow_id() -> str:
@@ -59,6 +60,11 @@ def generate_workflow_parameter_id() -> str:
return f"{WORKFLOW_PARAMETER_PREFIX}_{int_id}"
def generate_output_parameter_id() -> str:
    """Return a fresh output parameter id of the form ``op_<int>``."""
    return f"{OUTPUT_PARAMETER_PREFIX}_{generate_id()}"
def generate_organization_auth_token_id() -> str:
    """Return a fresh organization auth token id."""
    return f"{ORGANIZATION_AUTH_TOKEN_PREFIX}_{generate_id()}"

View File

@@ -9,6 +9,7 @@ from skyvern.forge.sdk.db.id import (
generate_aws_secret_parameter_id,
generate_org_id,
generate_organization_auth_token_id,
generate_output_parameter_id,
generate_step_id,
generate_task_id,
generate_workflow_id,
@@ -150,6 +151,18 @@ class WorkflowParameterModel(Base):
deleted_at = Column(DateTime, nullable=True)
class OutputParameterModel(Base):
    """ORM table for workflow-level output parameter definitions."""

    __tablename__ = "output_parameters"

    # Primary key; generated as "op_<id>".
    output_parameter_id = Column(String, primary_key=True, index=True, default=generate_output_parameter_id)
    # Key the parameter is referenced by inside the workflow definition.
    key = Column(String, nullable=False)
    description = Column(String, nullable=True)
    # Owning workflow.
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    # presumably a soft-delete marker (NULL = active) — TODO confirm against query usage
    deleted_at = Column(DateTime, nullable=True)
class AWSSecretParameterModel(Base):
__tablename__ = "aws_secret_parameters"
@@ -173,3 +186,14 @@ class WorkflowRunParameterModel(Base):
# Can be bool | int | float | str | dict | list depending on the workflow parameter type
value = Column(String, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
class WorkflowRunOutputParameterModel(Base):
    """ORM table recording the value an output parameter took in a workflow run."""

    __tablename__ = "workflow_run_output_parameters"

    # Composite primary key: at most one value per (run, output parameter) pair.
    workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True)
    output_parameter_id = Column(
        String, ForeignKey("output_parameters.output_parameter_id"), primary_key=True, index=True
    )
    # JSON-serializable value produced by the block execution.
    value = Column(JSON, nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)

View File

@@ -11,20 +11,28 @@ from skyvern.forge.sdk.db.models import (
AWSSecretParameterModel,
OrganizationAuthTokenModel,
OrganizationModel,
OutputParameterModel,
StepModel,
TaskModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
WorkflowRunOutputParameterModel,
WorkflowRunParameterModel,
)
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter,
OutputParameter,
WorkflowParameter,
WorkflowParameterType,
)
from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
WorkflowDefinition,
WorkflowRun,
WorkflowRunOutputParameter,
WorkflowRunParameter,
WorkflowRunStatus,
)
@@ -188,7 +196,7 @@ def convert_to_aws_secret_parameter(
if debug_enabled:
LOG.debug(
"Converting AWSSecretParameterModel to AWSSecretParameter",
aws_secret_parameter_id=aws_secret_parameter_model.id,
aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
)
return AWSSecretParameter(
@@ -203,6 +211,45 @@ def convert_to_aws_secret_parameter(
)
def convert_to_output_parameter(
    output_parameter_model: OutputParameterModel, debug_enabled: bool = False
) -> OutputParameter:
    """Map an OutputParameterModel ORM row to its OutputParameter schema object."""
    model = output_parameter_model
    if debug_enabled:
        LOG.debug(
            "Converting OutputParameterModel to OutputParameter",
            output_parameter_id=model.output_parameter_id,
        )
    return OutputParameter(
        output_parameter_id=model.output_parameter_id,
        key=model.key,
        description=model.description,
        workflow_id=model.workflow_id,
        created_at=model.created_at,
        modified_at=model.modified_at,
        deleted_at=model.deleted_at,
    )
def convert_to_workflow_run_output_parameter(
    workflow_run_output_parameter_model: WorkflowRunOutputParameterModel,
    debug_enabled: bool = False,
) -> WorkflowRunOutputParameter:
    """Map a WorkflowRunOutputParameterModel ORM row to its schema object."""
    model = workflow_run_output_parameter_model
    if debug_enabled:
        LOG.debug(
            "Converting WorkflowRunOutputParameterModel to WorkflowRunOutputParameter",
            workflow_run_id=model.workflow_run_id,
            output_parameter_id=model.output_parameter_id,
        )
    return WorkflowRunOutputParameter(
        workflow_run_id=model.workflow_run_id,
        output_parameter_id=model.output_parameter_id,
        value=model.value,
        created_at=model.created_at,
    )
def convert_to_workflow_run_parameter(
workflow_run_parameter_model: WorkflowRunParameterModel,
workflow_parameter: WorkflowParameter,

View File

@@ -5,12 +5,18 @@ import structlog
from skyvern.exceptions import WorkflowRunContextNotInitialized
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter
from skyvern.forge.sdk.workflow.exceptions import OutputParameterKeyCollisionError
from skyvern.forge.sdk.workflow.models.parameter import (
PARAMETER_TYPE,
OutputParameter,
Parameter,
ParameterType,
WorkflowParameter,
)
if TYPE_CHECKING:
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunParameter
LOG = structlog.get_logger()
@@ -19,7 +25,11 @@ class WorkflowRunContext:
values: dict[str, Any]
secrets: dict[str, Any]
def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None:
def __init__(
self,
workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]],
workflow_output_parameters: list[OutputParameter],
) -> None:
self.parameters = {}
self.values = {}
self.secrets = {}
@@ -34,6 +44,11 @@ class WorkflowRunContext:
self.parameters[parameter.key] = parameter
self.values[parameter.key] = run_parameter.value
for output_parameter in workflow_output_parameters:
if output_parameter.key in self.parameters:
raise OutputParameterKeyCollisionError(output_parameter.key)
self.parameters[output_parameter.key] = output_parameter
def get_parameter(self, key: str) -> Parameter:
return self.parameters[key]
@@ -48,11 +63,23 @@ class WorkflowRunContext:
def set_value(self, key: str, value: Any) -> None:
self.values[key] = value
def get_original_secret_value_or_none(self, secret_id: str) -> Any:
def get_original_secret_value_or_none(self, secret_id_or_value: Any) -> Any:
"""
Get the original secret value from the secrets dict. If the secret id is not found, return None.
This function can be called with any possible parameter value, not just the random secret id.
All the obfuscated secret values are strings, so if the parameter value is a string, we'll assume it's a
parameter value and return it.
If the parameter value is a string, it could be a random secret id or an actual parameter value. We'll check if
the parameter value is a key in the secrets dict. If it is, we'll return the secret value. If it's not, we'll
assume it's an actual parameter value and return it.
"""
return self.secrets.get(secret_id)
if type(secret_id_or_value) is str:
return self.secrets.get(secret_id_or_value)
return None
@staticmethod
def generate_random_secret_id() -> str:
@@ -68,6 +95,9 @@ class WorkflowRunContext:
raise ValueError(
f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}"
)
elif parameter.parameter_type == ParameterType.OUTPUT:
LOG.error(f"Output parameters are set after each block execution. Parameter key: {parameter.key}")
raise ValueError(f"Output parameters are set after each block execution. Parameter key: {parameter.key}")
elif parameter.parameter_type == ParameterType.AWS_SECRET:
# If the parameter is an AWS secret, fetch the secret value and store it in the secrets dict
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
@@ -77,9 +107,20 @@ class WorkflowRunContext:
random_secret_id = self.generate_random_secret_id()
self.secrets[random_secret_id] = secret_value
self.values[parameter.key] = random_secret_id
else:
elif parameter.parameter_type == ParameterType.CONTEXT:
# ContextParameter values will be set within the blocks
return None
return
else:
raise ValueError(f"Unknown parameter type: {parameter.parameter_type}")
async def register_output_parameter_value_post_execution(
    self, parameter: OutputParameter, value: dict[str, Any] | list | str | None
) -> None:
    """Store a block's output value in the run context; the first write wins."""
    already_registered = parameter.key in self.values
    if already_registered:
        LOG.error(f"Output parameter {parameter.output_parameter_id} already has a registered value")
        return
    self.values[parameter.key] = value
async def register_block_parameters(
self,
@@ -98,6 +139,13 @@ class WorkflowRunContext:
raise ValueError(
f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
)
elif parameter.parameter_type == ParameterType.OUTPUT:
LOG.error(
f"Output parameter {parameter.key} should have already been set through workflow run context init"
)
raise ValueError(
f"Output parameter {parameter.key} should have already been set through workflow run context init"
)
self.parameters[parameter.key] = parameter
await self.register_parameter_value(aws_client, parameter)
@@ -121,9 +169,12 @@ class WorkflowContextManager:
raise WorkflowRunContextNotInitialized(workflow_run_id=workflow_run_id)
def initialize_workflow_run_context(
    self,
    workflow_run_id: str,
    workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]],
    workflow_output_parameters: list[OutputParameter],
) -> WorkflowRunContext:
    """Create the WorkflowRunContext for a run and register it by run id.

    Fix: this span contained stale pre-change lines (the old single-argument
    signature and the old constructor call) interleaved with the new version;
    only the coherent new version is kept.
    """
    workflow_run_context = WorkflowRunContext(workflow_parameter_tuples, workflow_output_parameters)
    self.workflow_run_contexts[workflow_run_id] = workflow_run_context
    return workflow_run_context

View File

@@ -0,0 +1,23 @@
from skyvern.exceptions import SkyvernException
class BaseWorkflowException(SkyvernException):
    """Root of the workflow-specific exception hierarchy."""

    pass
class WorkflowDefinitionHasDuplicateBlockLabels(BaseWorkflowException):
    """Raised when a workflow definition contains two blocks with the same label."""

    def __init__(self, duplicate_labels: set[str]) -> None:
        # Sort the labels so the message is deterministic; joining a set
        # directly produces a different order on every run.
        super().__init__(
            f"WorkflowDefinition has blocks with duplicate labels. Each block needs to have a unique "
            f"label. Duplicate label(s): {','.join(sorted(duplicate_labels))}"
        )
class OutputParameterKeyCollisionError(BaseWorkflowException):
    """Raised when an output parameter key already exists in the context manager.

    Args:
        key: The colliding output parameter key.
        retry_count: Remaining retries; 0 means retries are exhausted,
            None means retry information is not applicable.
    """

    def __init__(self, key: str, retry_count: int | None = None) -> None:
        message = f"Output parameter key {key} already exists in the context manager."
        # Bug fix: the original tested `retry_count is not None` first, which
        # made the `retry_count == 0` branch unreachable (0 is not None), so
        # the "max retries reached" message could never be produced.
        if retry_count == 0:
            message += " Max duplicate retries reached, aborting."
        elif retry_count is not None:
            message += f" Retrying {retry_count} more times."
        super().__init__(message)

View File

@@ -14,7 +14,12 @@ from skyvern.exceptions import (
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
from skyvern.forge.sdk.workflow.models.parameter import (
PARAMETER_TYPE,
ContextParameter,
OutputParameter,
WorkflowParameter,
)
LOG = structlog.get_logger()
@@ -22,12 +27,14 @@ LOG = structlog.get_logger()
class BlockType(StrEnum):
    """Discriminator values for the workflow block union (see BlockTypeVar)."""

    TASK = "task"
    FOR_LOOP = "for_loop"
    CODE = "code"
class Block(BaseModel, abc.ABC):
# Must be unique within workflow definition
label: str
block_type: BlockType
parent_block_id: str | None = None
next_block_id: str | None = None
output_parameter: OutputParameter | None = None
@classmethod
def get_subclasses(cls) -> tuple[type["Block"], ...]:
@@ -38,7 +45,7 @@ class Block(BaseModel, abc.ABC):
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
@abc.abstractmethod
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
pass
@abc.abstractmethod
@@ -96,7 +103,7 @@ class TaskBlock(Block):
return order, retry + 1
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
current_retry = 0
# initial value for will_retry is True, so that the loop runs at least once
@@ -158,6 +165,32 @@ class TaskBlock(Block):
raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status)
if updated_task.status == TaskStatus.completed:
will_retry = False
LOG.info(
f"Task completed",
task_id=updated_task.task_id,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
)
if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution(
parameter=self.output_parameter,
value=updated_task.extracted_information,
)
await app.DATABASE.create_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=updated_task.extracted_information,
)
LOG.info(
f"Registered output parameter value",
output_parameter_id=self.output_parameter.output_parameter_id,
value=updated_task.extracted_information,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
task_id=updated_task.task_id,
)
return self.output_parameter
else:
current_retry += 1
will_retry = current_retry <= self.max_retries
@@ -172,6 +205,7 @@ class TaskBlock(Block):
current_retry=current_retry,
max_retries=self.max_retries,
)
return None
class ForLoopBlock(Block):
@@ -216,9 +250,10 @@ class ForLoopBlock(Block):
return [parameter_value]
else:
# TODO (kerem): Implement this for context parameters
# TODO (kerem): Implement this for output parameters
raise NotImplementedError
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
LOG.info(
@@ -227,14 +262,77 @@ class ForLoopBlock(Block):
workflow_run_id=workflow_run_id,
num_loop_over_values=len(loop_over_values),
)
outputs_with_loop_values = []
for loop_over_value in loop_over_values:
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
for context_parameter in context_parameters_with_value:
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
await self.loop_block.execute(workflow_run_id=workflow_run_id)
if self.loop_block.output_parameter:
outputs_with_loop_values.append(
{
"loop_value": loop_over_value,
"output_parameter": self.loop_block.output_parameter,
"output_value": workflow_run_context.get_value(self.loop_block.output_parameter.key),
}
)
if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution(
parameter=self.output_parameter,
value=outputs_with_loop_values,
)
await app.DATABASE.create_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=outputs_with_loop_values,
)
return self.output_parameter
return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock]
class CodeBlock(Block):
    """Workflow block that executes an arbitrary Python snippet.

    The snippet runs with the resolved parameter values as its globals; whatever
    it assigns to a local named ``result`` becomes the block's output value.
    """

    block_type: Literal[BlockType.CODE] = BlockType.CODE
    # Python source code to execute.
    code: str
    # Parameters whose values are exposed to the snippet as globals.
    parameters: list[PARAMETER_TYPE] = []

    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        return self.parameters

    async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
        # get workflow run context
        workflow_run_context = self.get_workflow_run_context(workflow_run_id)
        # get all parameters into a dictionary
        parameter_values = {}
        for parameter in self.parameters:
            value = workflow_run_context.get_value(parameter.key)
            # Substitute the real secret when the stored value is an obfuscated secret id.
            secret_value = workflow_run_context.get_original_secret_value_or_none(value)
            if secret_value is not None:
                parameter_values[parameter.key] = secret_value
            else:
                parameter_values[parameter.key] = value
        local_variables: dict[str, Any] = {}
        # SECURITY NOTE(review): exec runs user-supplied workflow code with no
        # sandboxing and with secret values in scope — confirm `self.code` is
        # trusted before exposing this block type to external users.
        exec(self.code, parameter_values, local_variables)
        result = {"result": local_variables.get("result")}
        if self.output_parameter:
            # Record the result both in the in-memory run context and in the DB.
            await workflow_run_context.register_output_parameter_value_post_execution(
                parameter=self.output_parameter,
                value=result,
            )
            await app.DATABASE.create_workflow_run_output_parameter(
                workflow_run_id=workflow_run_id,
                output_parameter_id=self.output_parameter.output_parameter_id,
                value=result,
            )
            return self.output_parameter
        return None
# Discriminated union of all concrete block types; pydantic selects the
# subclass based on the `block_type` field.
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]

View File

@@ -11,6 +11,7 @@ class ParameterType(StrEnum):
WORKFLOW = "workflow"
CONTEXT = "context"
AWS_SECRET = "aws_secret"
OUTPUT = "output"
class Parameter(BaseModel, abc.ABC):
@@ -80,5 +81,16 @@ class ContextParameter(Parameter):
value: str | int | float | bool | dict | list | None = None
ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter]
class OutputParameter(Parameter):
    """Parameter definition that receives a block's output value during a run."""

    parameter_type: Literal[ParameterType.OUTPUT] = ParameterType.OUTPUT
    output_parameter_id: str
    # Workflow the parameter belongs to.
    workflow_id: str
    created_at: datetime
    modified_at: datetime
    # presumably a soft-delete marker (None = active) — TODO confirm
    deleted_at: datetime | None = None
# Discriminated union of all concrete parameter schemas; pydantic selects the
# subclass based on the `parameter_type` field.
ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter, OutputParameter]
PARAMETER_TYPE = Annotated[ParameterSubclasses, Field(discriminator="parameter_type")]

View File

@@ -5,6 +5,7 @@ from typing import Any, List
from pydantic import BaseModel
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
@@ -22,6 +23,18 @@ class RunWorkflowResponse(BaseModel):
class WorkflowDefinition(BaseModel):
    """Ordered collection of blocks that make up a workflow."""

    blocks: List[BlockTypeVar]

    def validate(self) -> None:
        """Raise WorkflowDefinitionHasDuplicateBlockLabels if two blocks share a label."""
        seen: set[str] = set()
        duplicate_labels: set[str] = set()
        for block in self.blocks:
            target = duplicate_labels if block.label in seen else seen
            target.add(block.label)
        if duplicate_labels:
            raise WorkflowDefinitionHasDuplicateBlockLabels(duplicate_labels)
class Workflow(BaseModel):
workflow_id: str
@@ -61,6 +74,13 @@ class WorkflowRunParameter(BaseModel):
created_at: datetime
class WorkflowRunOutputParameter(BaseModel):
    """Value recorded for an output parameter during a specific workflow run."""

    workflow_run_id: str
    output_parameter_id: str
    # JSON-serializable value produced by the block execution.
    value: dict[str, Any] | list | str | None
    created_at: datetime
class WorkflowRunStatusResponse(BaseModel):
workflow_id: str
workflow_run_id: str

View File

@@ -20,12 +20,18 @@ from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter,
OutputParameter,
WorkflowParameter,
WorkflowParameterType,
)
from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
WorkflowDefinition,
WorkflowRequestBody,
WorkflowRun,
WorkflowRunOutputParameter,
WorkflowRunParameter,
WorkflowRunStatus,
WorkflowRunStatusResponse,
@@ -124,7 +130,10 @@ class WorkflowService:
# Get all <workflow parameter, workflow run parameter> tuples
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(workflow_run_id, wp_wps_tuples)
workflow_output_parameters = await self.get_workflow_output_parameters(workflow_id=workflow.workflow_id)
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(
workflow_run_id, wp_wps_tuples, workflow_output_parameters
)
# Execute workflow blocks
blocks = workflow.workflow_definition.blocks
try:
@@ -148,15 +157,27 @@ class WorkflowService:
)
tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id)
if not tasks:
LOG.warning(
f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook, marking as failed",
workflow_run_id=workflow_run.workflow_run_id,
)
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
return workflow_run
workflow_run = await self.handle_workflow_status(workflow_run=workflow_run, tasks=tasks)
if tasks:
workflow_run = await self.handle_workflow_status(workflow_run=workflow_run, tasks=tasks)
else:
# Check if the workflow run has any workflow run output parameters
# if it does, mark the workflow run as completed, else mark it as failed
workflow_run_output_parameters = await self.get_workflow_run_output_parameters(
workflow_run_id=workflow_run.workflow_run_id
)
if workflow_run_output_parameters:
LOG.info(
f"Workflow run {workflow_run.workflow_run_id} has output parameters, marking as completed",
workflow_run_id=workflow_run.workflow_run_id,
)
await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id)
else:
LOG.error(
f"Workflow run {workflow_run.workflow_run_id} has no tasks or output parameters, marking as failed",
workflow_run_id=workflow_run.workflow_run_id,
)
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
await self.send_workflow_response(
workflow=workflow,
workflow_run=workflow_run,
@@ -232,6 +253,8 @@ class WorkflowService:
description: str | None = None,
workflow_definition: WorkflowDefinition | None = None,
) -> Workflow | None:
if workflow_definition:
workflow_definition.validate()
return await app.DATABASE.update_workflow(
workflow_id=workflow_id,
title=title,
@@ -314,6 +337,11 @@ class WorkflowService:
workflow_id=workflow_id, aws_key=aws_key, key=key, description=description
)
async def create_output_parameter(
    self, workflow_id: str, key: str, description: str | None = None
) -> OutputParameter:
    """Create an output parameter definition for the workflow (delegates to the DB layer)."""
    return await app.DATABASE.create_output_parameter(workflow_id=workflow_id, key=key, description=description)
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
return await app.DATABASE.get_workflow_parameters(workflow_id=workflow_id)
@@ -334,6 +362,33 @@ class WorkflowService:
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
return await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
@staticmethod
async def get_workflow_output_parameters(workflow_id: str) -> list[OutputParameter]:
    """Return all output parameter definitions for the workflow (delegates to the DB layer)."""
    return await app.DATABASE.get_workflow_output_parameters(workflow_id=workflow_id)
@staticmethod
async def get_workflow_run_output_parameters(
    workflow_run_id: str,
) -> list[WorkflowRunOutputParameter]:
    """Return the output parameter values recorded for the run (delegates to the DB layer)."""
    return await app.DATABASE.get_workflow_run_output_parameters(workflow_run_id=workflow_run_id)
@staticmethod
async def get_output_parameter_workflow_run_output_parameter_tuples(
    workflow_id: str,
    workflow_run_id: str,
) -> list[tuple[OutputParameter, WorkflowRunOutputParameter]]:
    """Pair each output parameter definition with the values recorded for it in the run.

    Ordering matches the original nested scan: definitions in workflow order,
    and for each definition its run values in the order the DB returned them.
    """
    workflow_run_output_parameters = await app.DATABASE.get_workflow_run_output_parameters(
        workflow_run_id=workflow_run_id
    )
    output_parameters = await app.DATABASE.get_workflow_output_parameters(workflow_id=workflow_id)
    # Group run values by output_parameter_id once (O(n + m)) instead of the
    # previous O(n * m) nested comprehension.
    run_values_by_id: dict[str, list[WorkflowRunOutputParameter]] = {}
    for run_output_parameter in workflow_run_output_parameters:
        run_values_by_id.setdefault(run_output_parameter.output_parameter_id, []).append(run_output_parameter)
    return [
        (output_parameter, run_output_parameter)
        for output_parameter in output_parameters
        for run_output_parameter in run_values_by_id.get(output_parameter.output_parameter_id, [])
    ]
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
    """Return the most recent task of the workflow run, or None (delegates to the DB layer)."""
    return await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
@@ -377,15 +432,27 @@ class WorkflowService:
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
payload = {
task.task_id: {
"title": task.title,
"extracted_information": task.extracted_information,
"navigation_payload": task.navigation_payload,
"errors": await app.agent.get_task_errors(task=task),
output_parameter_tuples: list[
tuple[OutputParameter, WorkflowRunOutputParameter]
] = await self.get_output_parameter_workflow_run_output_parameter_tuples(
workflow_id=workflow_id, workflow_run_id=workflow_run_id
)
if output_parameter_tuples:
payload = {
output_parameter.key: wfrp.value
for output_parameter, wfrp in output_parameter_tuples
if wfrp.value is not None
}
else:
payload = {
task.task_id: {
"title": task.title,
"extracted_information": task.extracted_information,
"navigation_payload": task.navigation_payload,
"errors": await app.agent.get_task_errors(task=task),
}
for task in workflow_run_tasks
}
for task in workflow_run_tasks
}
return WorkflowRunStatusResponse(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
@@ -419,6 +486,13 @@ class WorkflowService:
await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)
workflow_run_status_response = await self.build_workflow_run_status_response(
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
organization_id=workflow.organization_id,
)
LOG.info("Built workflow run status response", workflow_run_status_response=workflow_run_status_response)
if not workflow_run.webhook_callback_url:
LOG.warning(
"Workflow has no webhook callback url. Not sending workflow response",
@@ -435,11 +509,6 @@ class WorkflowService:
)
return
workflow_run_status_response = await self.build_workflow_run_status_response(
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
organization_id=workflow.organization_id,
)
# send task_response to the webhook callback url
# TODO: use async requests (httpx)
timestamp = str(int(datetime.utcnow().timestamp()))