From 185fc330a4c0c956474bdbceccf52dec4d3f61cd Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Tue, 28 Jan 2025 16:59:54 +0800 Subject: [PATCH] add observer task block (#1665) --- ..._add_block_workflow_run_id_to_workflow_.py | 41 ++++++++++ skyvern/forge/sdk/db/client.py | 11 ++- skyvern/forge/sdk/db/models.py | 4 + skyvern/forge/sdk/db/utils.py | 2 + skyvern/forge/sdk/routes/agent_protocol.py | 69 +++++++++++----- skyvern/forge/sdk/schemas/workflow_runs.py | 1 + .../forge/sdk/services/observer_service.py | 7 +- skyvern/forge/sdk/workflow/models/block.py | 79 ++++++++++++++++++- skyvern/forge/sdk/workflow/models/workflow.py | 1 + skyvern/forge/sdk/workflow/models/yaml.py | 10 +++ skyvern/forge/sdk/workflow/service.py | 21 ++++- 11 files changed, 224 insertions(+), 22 deletions(-) create mode 100644 alembic/versions/2025_01_28_0853-3aa0ef96942d_add_block_workflow_run_id_to_workflow_.py diff --git a/alembic/versions/2025_01_28_0853-3aa0ef96942d_add_block_workflow_run_id_to_workflow_.py b/alembic/versions/2025_01_28_0853-3aa0ef96942d_add_block_workflow_run_id_to_workflow_.py new file mode 100644 index 00000000..59bd2ff4 --- /dev/null +++ b/alembic/versions/2025_01_28_0853-3aa0ef96942d_add_block_workflow_run_id_to_workflow_.py @@ -0,0 +1,41 @@ +"""add block_workflow_run_id to workflow_run_blocks table; add parent_workflow_run_id to workflow_runs table + +Revision ID: 3aa0ef96942d +Revises: 957ad2d1d3f7 +Create Date: 2025-01-28 08:53:06.357361+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "3aa0ef96942d" +down_revision: Union[str, None] = "957ad2d1d3f7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("workflow_run_blocks", sa.Column("block_workflow_run_id", sa.String(), nullable=True)) + op.create_foreign_key(None, "workflow_run_blocks", "workflow_runs", ["block_workflow_run_id"], ["workflow_run_id"]) + op.add_column("workflow_runs", sa.Column("parent_workflow_run_id", sa.String(), nullable=True)) + op.create_index( + op.f("ix_workflow_runs_parent_workflow_run_id"), "workflow_runs", ["parent_workflow_run_id"], unique=False + ) + op.create_foreign_key(None, "workflow_runs", "workflow_runs", ["parent_workflow_run_id"], ["workflow_run_id"]) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "workflow_runs", type_="foreignkey") + op.drop_index(op.f("ix_workflow_runs_parent_workflow_run_id"), table_name="workflow_runs") + op.drop_column("workflow_runs", "parent_workflow_run_id") + op.drop_constraint(None, "workflow_run_blocks", type_="foreignkey") + op.drop_column("workflow_run_blocks", "block_workflow_run_id") + # ### end Alembic commands ### diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 0efd2ac7..e8fff53a 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1344,6 +1344,7 @@ class AgentDB: webhook_callback_url: str | None = None, totp_verification_url: str | None = None, totp_identifier: str | None = None, + parent_workflow_run_id: str | None = None, ) -> WorkflowRun: try: async with self.Session() as session: @@ -1356,6 +1357,7 @@ class AgentDB: webhook_callback_url=webhook_callback_url, totp_verification_url=totp_verification_url, totp_identifier=totp_identifier, + parent_workflow_run_id=parent_workflow_run_id, ) session.add(workflow_run) await session.commit() @@ -1404,7 +1406,11 @@ class AgentDB: try: async with self.Session() as session: db_page = page - 1 # offset logic is 0 based - query = select(WorkflowRunModel).filter(WorkflowRunModel.organization_id == 
organization_id) + query = ( + select(WorkflowRunModel) + .filter(WorkflowRunModel.organization_id == organization_id) + .filter(WorkflowRunModel.parent_workflow_run_id.is_(None)) + ) if status: query = query.filter(WorkflowRunModel.status.in_(status)) query = query.order_by(WorkflowRunModel.created_at.desc()).limit(page_size).offset(db_page * page_size) @@ -2293,6 +2299,7 @@ class AgentDB: prompt: str | None = None, wait_sec: int | None = None, description: str | None = None, + block_workflow_run_id: str | None = None, ) -> WorkflowRunBlock: async with self.Session() as session: workflow_run_block = ( @@ -2331,6 +2338,8 @@ class AgentDB: workflow_run_block.wait_sec = wait_sec if description: workflow_run_block.description = description + if block_workflow_run_id: + workflow_run_block.block_workflow_run_id = block_workflow_run_id await session.commit() await session.refresh(workflow_run_block) else: diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index af10b018..3864d544 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -230,6 +230,8 @@ class WorkflowRunModel(Base): workflow_run_id = Column(String, primary_key=True, index=True, default=generate_workflow_run_id) workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=False) workflow_permanent_id = Column(String, nullable=False, index=True) + # workflow runs with parent_workflow_run_id are nested workflow runs which won't show up in the workflow run history + parent_workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), nullable=True, index=True) organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=False, index=True) status = Column(String, nullable=False) failure_reason = Column(String) @@ -505,6 +507,8 @@ class WorkflowRunBlockModel(Base): workflow_run_block_id = Column(String, primary_key=True, default=generate_workflow_run_block_id) workflow_run_id = Column(String,
ForeignKey("workflow_runs.workflow_run_id"), nullable=False) + # this is the inner workflow run id of the taskv2 block + block_workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), nullable=True) parent_workflow_run_block_id = Column( String, ForeignKey("workflow_run_blocks.workflow_run_block_id"), nullable=True ) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index cd9f4b01..a3e4361a 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -202,6 +202,7 @@ def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled: return WorkflowRun( workflow_run_id=workflow_run_model.workflow_run_id, workflow_permanent_id=workflow_run_model.workflow_permanent_id, + parent_workflow_run_id=workflow_run_model.parent_workflow_run_id, workflow_id=workflow_run_model.workflow_id, organization_id=workflow_run_model.organization_id, status=WorkflowRunStatus[workflow_run_model.status], @@ -382,6 +383,7 @@ def convert_to_workflow_run_block( block = WorkflowRunBlock( workflow_run_block_id=workflow_run_block_model.workflow_run_block_id, workflow_run_id=workflow_run_block_model.workflow_run_id, + block_workflow_run_id=workflow_run_block_model.block_workflow_run_id, organization_id=workflow_run_block_model.organization_id, parent_workflow_run_block_id=workflow_run_block_model.parent_workflow_run_block_id, description=workflow_run_block_model.description, diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index fea8646a..78ea5485 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -63,6 +63,7 @@ from skyvern.forge.sdk.workflow.exceptions import ( InvalidTemplateWorkflowPermanentId, WorkflowParameterMissingRequiredValue, ) +from skyvern.forge.sdk.workflow.models.block import BlockType from skyvern.forge.sdk.workflow.models.workflow import ( RunWorkflowResponse, Workflow, @@ -767,24 +768,7 @@ 
async def get_workflow_run_timeline( page_size: int = Query(20, ge=1), current_org: Organization = Depends(org_auth_service.get_current_org), ) -> list[WorkflowRunTimeline]: - # get observer cruise by workflow run id - observer_cruise_obj = await app.DATABASE.get_observer_cruise_by_workflow_run_id( - workflow_run_id=workflow_run_id, - organization_id=current_org.organization_id, - ) - # get all the workflow run blocks - workflow_run_block_timeline = await app.WORKFLOW_SERVICE.get_workflow_run_timeline( - workflow_run_id=workflow_run_id, - organization_id=current_org.organization_id, - ) - if observer_cruise_obj and observer_cruise_obj.observer_cruise_id: - observer_thought_timeline = await observer_service.get_observer_thought_timelines( - observer_cruise_id=observer_cruise_obj.observer_cruise_id, - organization_id=current_org.organization_id, - ) - workflow_run_block_timeline.extend(observer_thought_timeline) - workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True) - return workflow_run_block_timeline + return await _flatten_workflow_run_timeline(current_org.organization_id, workflow_run_id) @base_router.get( @@ -1310,3 +1294,52 @@ async def close_browser_session( status_code=200, media_type="application/json", ) + + +async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id: str) -> list[WorkflowRunTimeline]: + """ + Get the workflow run timeline, including nested workflow runs, as a flattened list + """ + + # get observer task by workflow run id + observer_task_obj = await app.DATABASE.get_observer_cruise_by_workflow_run_id( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + # get all the workflow run blocks + workflow_run_block_timeline = await app.WORKFLOW_SERVICE.get_workflow_run_timeline( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + # loop through the run block timeline, find the task_v2 blocks, flatten the timeline for task_v2 +
final_workflow_run_block_timeline = [] + for timeline in workflow_run_block_timeline: + if not timeline.block: + continue + if timeline.block.block_type != BlockType.TaskV2: + # not a task_v2 block: keep the timeline entry as-is + final_workflow_run_block_timeline.append(timeline) + continue + if not timeline.block.block_workflow_run_id: + LOG.error( + "Block workflow run id is not set for task_v2 block", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + observer_cruise_id=observer_task_obj.observer_cruise_id if observer_task_obj else None, + ) + continue + # in the future, if we want a nested taskv2 to show up as a nested block, we should not flatten the timeline + workflow_blocks = await _flatten_workflow_run_timeline( + organization_id=organization_id, + workflow_run_id=timeline.block.block_workflow_run_id, + ) + final_workflow_run_block_timeline.extend(workflow_blocks) + + if observer_task_obj and observer_task_obj.observer_cruise_id: + observer_thought_timeline = await observer_service.get_observer_thought_timelines( + observer_cruise_id=observer_task_obj.observer_cruise_id, + organization_id=organization_id, + ) + final_workflow_run_block_timeline.extend(observer_thought_timeline) + final_workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True) + return final_workflow_run_block_timeline diff --git a/skyvern/forge/sdk/schemas/workflow_runs.py b/skyvern/forge/sdk/schemas/workflow_runs.py index bc4c8540..4bbd5a66 100644 --- a/skyvern/forge/sdk/schemas/workflow_runs.py +++ b/skyvern/forge/sdk/schemas/workflow_runs.py @@ -13,6 +13,7 @@ from skyvern.webeye.actions.actions import Action class WorkflowRunBlock(BaseModel): workflow_run_block_id: str + block_workflow_run_id: str | None = None workflow_run_id: str organization_id: str | None = None description: str | None = None diff --git a/skyvern/forge/sdk/services/observer_service.py b/skyvern/forge/sdk/services/observer_service.py index 196ba4c9..06a7e2fa 100644 ---
a/skyvern/forge/sdk/services/observer_service.py +++ b/skyvern/forge/sdk/services/observer_service.py @@ -94,6 +94,7 @@ async def initialize_observer_task( totp_verification_url: str | None = None, webhook_callback_url: str | None = None, publish_workflow: bool = False, + parent_workflow_run_id: str | None = None, ) -> ObserverTask: observer_task = await app.DATABASE.create_observer_cruise( prompt=user_prompt, @@ -148,6 +149,7 @@ async def initialize_observer_task( organization_id=organization.organization_id, version=None, max_steps_override=max_steps_override, + parent_workflow_run_id=parent_workflow_run_id, ) except Exception: LOG.error("Failed to setup cruise workflow run", exc_info=True) @@ -501,7 +503,10 @@ async def run_observer_task_helper( break # generate the extraction task - block_result = await block.execute_safe(workflow_run_id=workflow_run_id, organization_id=organization_id) + block_result = await block.execute_safe( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) task_history_record["status"] = str(block_result.status) if block_result.failure_reason: task_history_record["reason"] = block_result.failure_reason diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index c2862ed3..83b64db6 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -48,7 +48,8 @@ from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core.validators import prepend_scheme_and_validate_url from skyvern.forge.sdk.db.enums import TaskType -from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus +from skyvern.forge.sdk.schemas.observers import ObserverTaskStatus +from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskOutput, TaskStatus from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, 
WorkflowRunContext from skyvern.forge.sdk.workflow.exceptions import ( FailedToFormatJinjaStyleParameter, @@ -71,6 +72,7 @@ LOG = structlog.get_logger() class BlockType(StrEnum): TASK = "task" + TaskV2 = "task_v2" FOR_LOOP = "for_loop" CODE = "code" TEXT_PROMPT = "text_prompt" @@ -2072,6 +2074,80 @@ class UrlBlock(BaseTaskBlock): url: str +# observer block +class TaskV2Block(Block): + block_type: Literal[BlockType.TaskV2] = BlockType.TaskV2 + prompt: str + url: str | None = None + totp_verification_url: str | None = None + totp_identifier: str | None = None + max_iterations: int = 10 + + def get_all_parameters( + self, + workflow_run_id: str, + ) -> list[PARAMETER_TYPE]: + return [] + + async def execute( + self, + workflow_run_id: str, + workflow_run_block_id: str, + organization_id: str | None = None, + browser_session_id: str | None = None, + **kwargs: dict, + ) -> BlockResult: + from skyvern.forge.sdk.services import observer_service + from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus + + if not organization_id: + raise ValueError("Running TaskV2Block requires organization_id") + + organization = await app.DATABASE.get_organization(organization_id) + if not organization: + raise ValueError(f"Organization not found {organization_id}") + observer_task = await observer_service.initialize_observer_task( + organization, + user_prompt=self.prompt, + user_url=self.url, + parent_workflow_run_id=workflow_run_id, + proxy_location=ProxyLocation.NONE, + ) + await app.DATABASE.update_observer_cruise( + observer_task.observer_cruise_id, status=ObserverTaskStatus.queued, organization_id=organization_id + ) + if observer_task.workflow_run_id: + await app.DATABASE.update_workflow_run( + workflow_run_id=observer_task.workflow_run_id, + status=WorkflowRunStatus.queued, + ) + await app.DATABASE.update_workflow_run_block( + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, + block_workflow_run_id=observer_task.workflow_run_id, + 
) + + observer_task = await observer_service.run_observer_task( + organization=organization, + observer_cruise_id=observer_task.observer_cruise_id, + request_id=None, + max_iterations_override=self.max_iterations, + browser_session_id=browser_session_id, + ) + result_dict = None + if observer_task: + result_dict = observer_task.output + + return await self.build_block_result( + success=True, + failure_reason=None, + output_parameter_value=result_dict, + status=BlockStatus.completed, + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, + ) + + BlockSubclasses = Union[ ForLoopBlock, TaskBlock, @@ -2090,5 +2166,6 @@ BlockSubclasses = Union[ WaitBlock, FileDownloadBlock, UrlBlock, + TaskV2Block, ] BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")] diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index d3f02e27..5fcf8e68 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -108,6 +108,7 @@ class WorkflowRun(BaseModel): totp_verification_url: str | None = None totp_identifier: str | None = None failure_reason: str | None = None + parent_workflow_run_id: str | None = None created_at: datetime modified_at: datetime diff --git a/skyvern/forge/sdk/workflow/models/yaml.py b/skyvern/forge/sdk/workflow/models/yaml.py index 00fb87ba..50e4e669 100644 --- a/skyvern/forge/sdk/workflow/models/yaml.py +++ b/skyvern/forge/sdk/workflow/models/yaml.py @@ -323,6 +323,15 @@ class UrlBlockYAML(BlockYAML): url: str +class TaskV2BlockYAML(BlockYAML): + block_type: Literal[BlockType.TaskV2] = BlockType.TaskV2 # type: ignore + prompt: str + url: str | None = None + totp_verification_url: str | None = None + totp_identifier: str | None = None + max_iterations: int = 10 + + PARAMETER_YAML_SUBCLASSES = ( AWSSecretParameterYAML | BitwardenLoginCredentialParameterYAML @@ -352,6 +361,7 @@ BLOCK_YAML_SUBCLASSES = ( | 
FileDownloadBlockYAML | UrlBlockYAML | PDFParserBlockYAML + | TaskV2BlockYAML ) BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")] diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 07b0027a..894e16c4 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -49,6 +49,7 @@ from skyvern.forge.sdk.workflow.models.block import ( PDFParserBlock, SendEmailBlock, TaskBlock, + TaskV2Block, TextPromptBlock, UploadToS3Block, ValidationBlock, @@ -96,6 +97,7 @@ class WorkflowService: is_template_workflow: bool = False, version: int | None = None, max_steps_override: int | None = None, + parent_workflow_run_id: str | None = None, ) -> WorkflowRun: """ Create a workflow run and its parameters. Validate the workflow and the organization. If there are missing @@ -127,6 +129,7 @@ class WorkflowService: workflow_permanent_id=workflow_permanent_id, workflow_id=workflow_id, organization_id=organization_id, + parent_workflow_run_id=parent_workflow_run_id, ) LOG.info( f"Created workflow run {workflow_run.workflow_run_id} for workflow {workflow.workflow_id}", @@ -625,7 +628,12 @@ class WorkflowService: ) async def create_workflow_run( - self, workflow_request: WorkflowRequestBody, workflow_permanent_id: str, workflow_id: str, organization_id: str + self, + workflow_request: WorkflowRequestBody, + workflow_permanent_id: str, + workflow_id: str, + organization_id: str, + parent_workflow_run_id: str | None = None, ) -> WorkflowRun: return await app.DATABASE.create_workflow_run( workflow_permanent_id=workflow_permanent_id, @@ -635,6 +643,7 @@ class WorkflowService: webhook_callback_url=workflow_request.webhook_callback_url, totp_verification_url=workflow_request.totp_verification_url, totp_identifier=workflow_request.totp_identifier, + parent_workflow_run_id=parent_workflow_run_id, ) async def mark_workflow_run_as_completed(self, workflow_run_id: str) -> None: @@ -1731,6 
+1740,16 @@ class WorkflowService: cache_actions=block_yaml.cache_actions, complete_on_download=True, ) + elif block_yaml.block_type == BlockType.TaskV2: + return TaskV2Block( + label=block_yaml.label, + prompt=block_yaml.prompt, + url=block_yaml.url, + totp_verification_url=block_yaml.totp_verification_url, + totp_identifier=block_yaml.totp_identifier, + max_iterations=block_yaml.max_iterations, + output_parameter=output_parameter, + ) raise ValueError(f"Invalid block type {block_yaml.block_type}")