diff --git a/skyvern/forge/sdk/routes/sdk.py b/skyvern/forge/sdk/routes/sdk.py index a01924be..631e9196 100644 --- a/skyvern/forge/sdk/routes/sdk.py +++ b/skyvern/forge/sdk/routes/sdk.py @@ -130,6 +130,8 @@ async def run_sdk_action( [], [], [], + None, + workflow, ) context = skyvern_context.ensure_context() diff --git a/skyvern/forge/sdk/workflow/context_manager.py b/skyvern/forge/sdk/workflow/context_manager.py index 9e9a7c75..b1b263a5 100644 --- a/skyvern/forge/sdk/workflow/context_manager.py +++ b/skyvern/forge/sdk/workflow/context_manager.py @@ -44,7 +44,7 @@ from skyvern.forge.sdk.workflow.models.parameter import ( from skyvern.utils.strings import generate_random_string if TYPE_CHECKING: - from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunParameter + from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRunParameter LOG = structlog.get_logger() @@ -76,6 +76,7 @@ class WorkflowRunContext: | CredentialParameter ], block_outputs: dict[str, Any] | None = None, + workflow: "Workflow | None" = None, ) -> Self: # key is label name workflow_run_context = cls( @@ -84,6 +85,7 @@ class WorkflowRunContext: workflow_permanent_id=workflow_permanent_id, workflow_run_id=workflow_run_id, aws_client=aws_client, + workflow=workflow, ) workflow_run_context.organization_id = organization.organization_id @@ -161,11 +163,13 @@ class WorkflowRunContext: workflow_permanent_id: str, workflow_run_id: str, aws_client: AsyncAWSClient, + workflow: "Workflow | None" = None, ) -> None: self.workflow_title = workflow_title self.workflow_id = workflow_id self.workflow_permanent_id = workflow_permanent_id self.workflow_run_id = workflow_run_id + self.workflow = workflow self.blocks_metadata: dict[str, BlockMetadata] = {} self.parameters: dict[str, PARAMETER_TYPE] = {} self.values: dict[str, Any] = {} @@ -175,6 +179,13 @@ class WorkflowRunContext: self.include_secrets_in_templates: bool = False self.credential_totp_identifiers: dict[str, str] = {} + def set_workflow(self, workflow: "Workflow") -> None: + """ + Update the cached workflow object in the context. + This is used when the workflow is fetched from the database as a fallback. + """ + self.workflow = workflow + def get_parameter(self, key: str) -> Parameter: return self.parameters[key] @@ -1078,6 +1089,7 @@ class WorkflowContextManager: | CredentialParameter ], block_outputs: dict[str, Any] | None = None, + workflow: "Workflow | None" = None, ) -> WorkflowRunContext: workflow_run_context = await WorkflowRunContext.init( self.aws_client, @@ -1091,6 +1103,7 @@ class WorkflowContextManager: context_parameters, secret_parameters, block_outputs, + workflow, ) self.workflow_run_contexts[workflow_run_id] = workflow_run_context return workflow_run_context diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index a78a4157..40bd739c 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -626,9 +626,14 @@ class BaseTaskBlock(Block): workflow_run_id=workflow_run_id, organization_id=organization_id, ) - workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( - workflow_permanent_id=workflow_run.workflow_permanent_id, - ) + # Get workflow from context if available, otherwise query database + workflow = workflow_run_context.workflow + if workflow is None: + workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( + workflow_permanent_id=workflow_run.workflow_permanent_id, + ) + # Cache the workflow back to context for future block executions + workflow_run_context.set_workflow(workflow) # if the task url is parameterized, we need to get the value from the workflow run context if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url): task_url_parameter_value = workflow_run_context.get_value(self.url) diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index f90b7a50..2d60ec68 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -624,6 +624,7 @@ class WorkflowService: context_parameters, secret_parameters, block_outputs, + workflow, ) except Exception as e: LOG.exception( diff --git a/skyvern/services/task_v2_service.py b/skyvern/services/task_v2_service.py index a5ae965e..0a83f315 100644 --- a/skyvern/services/task_v2_service.py +++ b/skyvern/services/task_v2_service.py @@ -1096,6 +1096,8 @@ async def _set_up_workflow_context(workflow: Workflow, workflow_run_id: str, org workflow_output_parameters, [], [], + None, + workflow, )