diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index 01275de6..a53681c8 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -128,6 +128,11 @@ class UnknownBlockType(SkyvernException): super().__init__(f"Unknown block type {block_type}") +class BlockNotFound(SkyvernException): + def __init__(self, block_label: str) -> None: + super().__init__(f"Block {block_label} not found") + + class WorkflowNotFound(SkyvernHTTPException): def __init__( self, diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 594d02b8..1f826d0d 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -1,3 +1,4 @@ +import asyncio from enum import Enum from typing import Annotated, Any @@ -76,6 +77,8 @@ from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest from skyvern.schemas.artifacts import EntityType, entity_type_to_param from skyvern.schemas.runs import ( CUA_ENGINES, + BlockRunRequest, + BlockRunResponse, RunEngine, RunResponse, RunType, @@ -85,7 +88,7 @@ from skyvern.schemas.runs import ( WorkflowRunResponse, ) from skyvern.schemas.workflows import WorkflowRequest -from skyvern.services import run_service, task_v1_service, task_v2_service, workflow_service +from skyvern.services import block_service, run_service, task_v1_service, task_v2_service, workflow_service from skyvern.webeye.actions.actions import Action LOG = structlog.get_logger() @@ -850,6 +853,57 @@ async def retry_run_webhook( await run_service.retry_run_webhook(run_id, organization_id=current_org.organization_id, api_key=x_api_key) +@base_router.post( + "/run/workflows/blocks", + include_in_schema=False, + response_model=BlockRunResponse, +) +async def run_block( + block_run_request: BlockRunRequest, + organization: Organization = Depends(org_auth_service.get_current_org), + template: bool = Query(False), + x_api_key: Annotated[str | None, Header()] = None, +) -> BlockRunResponse: + """ + Kick off the execution of one or more blocks in a workflow. Returns the + workflow_run_id. + """ + + workflow_run = await block_service.ensure_workflow_run( + organization=organization, + template=template, + workflow_permanent_id=block_run_request.workflow_id, + workflow_run_request=block_run_request, + ) + + browser_session_id = block_run_request.browser_session_id + + asyncio.create_task( + block_service.execute_blocks( + api_key=x_api_key or "", + block_labels=block_run_request.block_labels, + workflow_run_id=workflow_run.workflow_run_id, + organization=organization, + browser_session_id=browser_session_id, + ) + ) + + return BlockRunResponse( + block_labels=block_run_request.block_labels, + run_id=workflow_run.workflow_run_id, + run_type=RunType.workflow_run, + status=str(workflow_run.status), + output=None, + failure_reason=workflow_run.failure_reason, + created_at=workflow_run.created_at, + modified_at=workflow_run.modified_at, + run_request=block_run_request, + downloaded_files=None, + recording_url=None, + app_url=f"{settings.SKYVERN_APP_URL.rstrip('/')}/workflows/{workflow_run.workflow_permanent_id}/{workflow_run.workflow_run_id}", + ) + + ################# Legacy Endpoints ################# @legacy_base_router.post( "/webhook", diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 508c84b2..1ef5b50b 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -10,6 +10,7 @@ from skyvern import analytics from skyvern.config import settings from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT, SAVE_DOWNLOADED_FILES_TIMEOUT from skyvern.exceptions import ( + BlockNotFound, BrowserSessionNotFound, FailedToSendWebhook, InvalidCredentialId, @@ -252,6 +253,7 @@ class WorkflowService: workflow_run_id: str, api_key: str, organization: Organization, + block_labels: list[str] | None = None, browser_session_id: str | None = None, ) -> WorkflowRun: """Execute a workflow.""" @@ -326,8 +328,32 @@ class WorkflowService: ) return workflow_run + all_blocks = workflow.workflow_definition.blocks + + if block_labels and len(block_labels): + blocks: list[BlockTypeVar] = [] + all_labels = {block.label: block for block in all_blocks} + + for label in block_labels: + if label not in all_labels: + raise BlockNotFound(block_label=label) + + blocks.append(all_labels[label]) + + LOG.info( + "Executing workflow blocks via whitelist", + workflow_run_id=workflow_run.workflow_run_id, + block_cnt=len(blocks), + block_labels=block_labels, + ) + + else: + blocks = all_blocks + + if not blocks: + raise SkyvernException(f"No blocks found for the given block labels: {block_labels}") + # Execute workflow blocks - blocks = workflow.workflow_definition.blocks blocks_cnt = len(blocks) block_result = None for block_idx, block in enumerate(blocks): diff --git a/skyvern/schemas/runs.py b/skyvern/schemas/runs.py index 667d2bbe..bf9ada7f 100644 --- a/skyvern/schemas/runs.py +++ b/skyvern/schemas/runs.py @@ -357,6 +357,13 @@ class WorkflowRunRequest(BaseModel): return validate_url(url) +class BlockRunRequest(WorkflowRunRequest): + block_labels: list[str] = Field( + description="Labels of the blocks to execute", + examples=["block_1", "block_2"], + ) + + class BaseRunResponse(BaseModel): run_id: str = Field( description="Unique identifier for this run. Run ID starts with `tsk_` for task runs and `wr_` for workflow runs.", @@ -415,3 +422,7 @@ class WorkflowRunResponse(BaseRunResponse): RunResponse = Annotated[Union[TaskRunResponse, WorkflowRunResponse], Field(discriminator="run_type")] + + +class BlockRunResponse(WorkflowRunResponse): + block_labels: list[str] = Field(description="A whitelist of block labels; only these blocks will execute") diff --git a/skyvern/services/block_service.py b/skyvern/services/block_service.py new file mode 100644 index 00000000..5f24361a --- /dev/null +++ b/skyvern/services/block_service.py @@ -0,0 +1,72 @@ +import structlog + +from skyvern.forge import app +from skyvern.forge.sdk.core import skyvern_context +from skyvern.forge.sdk.schemas.organizations import Organization +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRun +from skyvern.schemas.runs import WorkflowRunRequest +from skyvern.services import workflow_service + +LOG = structlog.get_logger() + + +async def ensure_workflow_run( + organization: Organization, + template: bool, + workflow_permanent_id: str, + workflow_run_request: WorkflowRunRequest, + x_max_steps_override: int | None = None, +) -> WorkflowRun: + context = skyvern_context.ensure_context() + + legacy_workflow_request = WorkflowRequestBody( + data=workflow_run_request.parameters, + proxy_location=workflow_run_request.proxy_location, + webhook_callback_url=workflow_run_request.webhook_url, + totp_identifier=workflow_run_request.totp_identifier, + totp_verification_url=workflow_run_request.totp_url, + browser_session_id=workflow_run_request.browser_session_id, + max_screenshot_scrolls=workflow_run_request.max_screenshot_scrolls, + extra_http_headers=workflow_run_request.extra_http_headers, + ) + + workflow_run = await workflow_service.prepare_workflow( + workflow_id=workflow_permanent_id, + organization=organization, + workflow_request=legacy_workflow_request, + template=template, + version=None, + max_steps=x_max_steps_override, + request_id=context.request_id, + ) + + return workflow_run + + +async def execute_blocks( + api_key: str, + block_labels: list[str], + workflow_run_id: str, + organization: Organization, + browser_session_id: str | None = None, +) -> WorkflowRun: + """ + Runs one or more blocks of a workflow. + """ + + LOG.info( + "Executing block(s)", + organization_id=organization.organization_id, + workflow_run_id=workflow_run_id, + block_labels=block_labels, + ) + + workflow_run = await app.WORKFLOW_SERVICE.execute_workflow( + workflow_run_id=workflow_run_id, + api_key=api_key, + organization=organization, + block_labels=block_labels, + browser_session_id=browser_session_id, + ) + + return workflow_run diff --git a/skyvern/services/workflow_service.py b/skyvern/services/workflow_service.py index e7200e68..e403672d 100644 --- a/skyvern/services/workflow_service.py +++ b/skyvern/services/workflow_service.py @@ -12,18 +12,18 @@ from skyvern.schemas.runs import RunStatus, RunType, WorkflowRunRequest, Workflo LOG = structlog.get_logger(__name__) -async def run_workflow( +async def prepare_workflow( workflow_id: str, organization: Organization, workflow_request: WorkflowRequestBody, # this is the deprecated workflow request body template: bool = False, version: int | None = None, max_steps: int | None = None, - api_key: str | None = None, request_id: str | None = None, - request: Request | None = None, - background_tasks: BackgroundTasks | None = None, ) -> WorkflowRun: + """ + Prepare a workflow to be run. + """ if template: if workflow_id not in await app.STORAGE.retrieve_global_workflows(): raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id) @@ -37,19 +37,48 @@ async def run_workflow( max_steps_override=max_steps, is_template_workflow=template, ) + workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( workflow_permanent_id=workflow_id, organization_id=None if template else organization.organization_id, version=version, ) + await app.DATABASE.create_task_run( task_run_type=RunType.workflow_run, organization_id=organization.organization_id, run_id=workflow_run.workflow_run_id, title=workflow.title, ) + if max_steps: LOG.info("Overriding max steps per run", max_steps_override=max_steps) + + return workflow_run + + +async def run_workflow( + workflow_id: str, + organization: Organization, + workflow_request: WorkflowRequestBody, # this is the deprecated workflow request body + template: bool = False, + version: int | None = None, + max_steps: int | None = None, + api_key: str | None = None, + request_id: str | None = None, + request: Request | None = None, + background_tasks: BackgroundTasks | None = None, +) -> WorkflowRun: + workflow_run = await prepare_workflow( + workflow_id=workflow_id, + organization=organization, + workflow_request=workflow_request, + template=template, + version=version, + max_steps=max_steps, + request_id=request_id, + ) + await AsyncExecutorFactory.get_executor().execute_workflow( request=request, background_tasks=background_tasks, @@ -60,6 +89,7 @@ async def run_workflow( browser_session_id=workflow_request.browser_session_id, api_key=api_key, ) + return workflow_run