shu/add_workflow_runs_api (#2063)

This commit is contained in:
Shuchang Zheng
2025-04-01 15:52:35 -04:00
committed by GitHub
parent f774135049
commit e26b816f67
7 changed files with 166 additions and 51 deletions

View File

@@ -14,7 +14,7 @@ from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.files import create_folder_if_not_exist from skyvern.forge.sdk.api.files import create_folder_if_not_exist
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request
from skyvern.forge.sdk.schemas.tasks import TaskRequest, TaskResponse, TaskStatus from skyvern.forge.sdk.schemas.tasks import TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRunResponse, WorkflowRunStatus from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRunResponseBase, WorkflowRunStatus
from skyvern.schemas.runs import ProxyLocation from skyvern.schemas.runs import ProxyLocation
@@ -71,7 +71,7 @@ class SkyvernClient:
assert response.status_code == 200, f"Expected to get task response status 200, but got {response.status_code}" assert response.status_code == 200, f"Expected to get task response status 200, but got {response.status_code}"
return TaskResponse(**response.json()) return TaskResponse(**response.json())
async def get_workflow_run(self, workflow_pid: str, workflow_run_id: str) -> WorkflowRunResponse: async def get_workflow_run(self, workflow_pid: str, workflow_run_id: str) -> WorkflowRunResponseBase:
url = f"{self.base_url}/workflows/{workflow_pid}/runs/{workflow_run_id}" url = f"{self.base_url}/workflows/{workflow_pid}/runs/{workflow_run_id}"
headers = {"x-api-key": self.credentials} headers = {"x-api-key": self.credentials}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@@ -79,7 +79,7 @@ class SkyvernClient:
assert response.status_code == 200, ( assert response.status_code == 200, (
f"Expected to get workflow run response status 200, but got {response.status_code}" f"Expected to get workflow run response status 200, but got {response.status_code}"
) )
return WorkflowRunResponse(**response.json()) return WorkflowRunResponseBase(**response.json())
class Evaluator: class Evaluator:

View File

@@ -15,6 +15,6 @@ setup_logger()
from skyvern.forge import app # noqa: E402, F401 from skyvern.forge import app # noqa: E402, F401
from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402 from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponse # noqa: E402 from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponseBase # noqa: E402
__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponse"] __all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponseBase"]

View File

@@ -55,13 +55,21 @@ from skyvern.forge.sdk.workflow.models.workflow import (
Workflow, Workflow,
WorkflowRequestBody, WorkflowRequestBody,
WorkflowRun, WorkflowRun,
WorkflowRunResponse, WorkflowRunResponseBase,
WorkflowRunStatus, WorkflowRunStatus,
WorkflowStatus, WorkflowStatus,
) )
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse from skyvern.schemas.runs import (
from skyvern.services import run_service, task_v1_service, task_v2_service RunEngine,
RunResponse,
RunType,
TaskRunRequest,
TaskRunResponse,
WorkflowRunRequest,
WorkflowRunResponse,
)
from skyvern.services import run_service, task_v1_service, task_v2_service, workflow_service
from skyvern.webeye.actions.actions import Action from skyvern.webeye.actions.actions import Action
LOG = structlog.get_logger() LOG = structlog.get_logger()
@@ -620,7 +628,7 @@ async def get_actions(
tags=["agent"], tags=["agent"],
openapi_extra={ openapi_extra={
"x-fern-sdk-group-name": "agent", "x-fern-sdk-group-name": "agent",
"x-fern-sdk-method-name": "run_workflow", "x-fern-sdk-method-name": "run_workflow_legacy",
}, },
) )
@legacy_base_router.post( @legacy_base_router.post(
@@ -628,7 +636,7 @@ async def get_actions(
response_model=RunWorkflowResponse, response_model=RunWorkflowResponse,
include_in_schema=False, include_in_schema=False,
) )
async def run_workflow( async def run_workflow_legacy(
request: Request, request: Request,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
workflow_id: str, # this is the workflow_permanent_id internally workflow_id: str, # this is the workflow_permanent_id internally
@@ -647,42 +655,19 @@ async def run_workflow(
browser_session_id=workflow_request.browser_session_id, browser_session_id=workflow_request.browser_session_id,
) )
if template: workflow_run = await workflow_service.run_workflow(
if workflow_id not in await app.STORAGE.retrieve_global_workflows(): workflow_id=workflow_id,
raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id) organization_id=current_org.organization_id,
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
request_id=request_id,
workflow_request=workflow_request, workflow_request=workflow_request,
workflow_permanent_id=workflow_id, template=template,
organization_id=current_org.organization_id,
version=version, version=version,
max_steps_override=x_max_steps_override, max_steps=x_max_steps_override,
is_template_workflow=template, api_key=x_api_key,
) request_id=request_id,
workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_id,
organization_id=None if template else current_org.organization_id,
version=version,
)
await app.DATABASE.create_task_run(
task_run_type=RunType.workflow_run,
organization_id=current_org.organization_id,
run_id=workflow_run.workflow_run_id,
title=workflow.title,
)
if x_max_steps_override:
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await AsyncExecutorFactory.get_executor().execute_workflow(
request=request, request=request,
background_tasks=background_tasks, background_tasks=background_tasks,
organization_id=current_org.organization_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
max_steps_override=x_max_steps_override,
browser_session_id=workflow_request.browser_session_id,
api_key=x_api_key,
) )
return RunWorkflowResponse( return RunWorkflowResponse(
workflow_id=workflow_id, workflow_id=workflow_id,
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
@@ -806,7 +791,7 @@ async def get_workflow_run_timeline(
@legacy_base_router.get( @legacy_base_router.get(
"/workflows/runs/{workflow_run_id}", "/workflows/runs/{workflow_run_id}",
response_model=WorkflowRunResponse, response_model=WorkflowRunResponseBase,
tags=["agent"], tags=["agent"],
openapi_extra={ openapi_extra={
"x-fern-sdk-group-name": "agent", "x-fern-sdk-group-name": "agent",
@@ -815,13 +800,13 @@ async def get_workflow_run_timeline(
) )
@legacy_base_router.get( @legacy_base_router.get(
"/workflows/runs/{workflow_run_id}/", "/workflows/runs/{workflow_run_id}/",
response_model=WorkflowRunResponse, response_model=WorkflowRunResponseBase,
include_in_schema=False, include_in_schema=False,
) )
async def get_workflow_run( async def get_workflow_run(
workflow_run_id: str, workflow_run_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org), current_org: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowRunResponse: ) -> WorkflowRunResponseBase:
analytics.capture("skyvern-oss-agent-workflow-run-get") analytics.capture("skyvern-oss-agent-workflow-run-get")
return await app.WORKFLOW_SERVICE.build_workflow_run_status_response_by_workflow_id( return await app.WORKFLOW_SERVICE.build_workflow_run_status_response_by_workflow_id(
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
@@ -1385,7 +1370,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
@base_router.post( @base_router.post(
"/tasks", "/tasks/run",
tags=["Agent"], tags=["Agent"],
openapi_extra={ openapi_extra={
"x-fern-sdk-group-name": "agent", "x-fern-sdk-group-name": "agent",
@@ -1398,7 +1383,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
400: {"description": "Invalid agent engine"}, 400: {"description": "Invalid agent engine"},
}, },
) )
@base_router.post("/tasks/", include_in_schema=False) @base_router.post("/tasks/run/", include_in_schema=False)
async def run_task( async def run_task(
request: Request, request: Request,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
@@ -1523,3 +1508,69 @@ async def run_task(
), ),
) )
raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}") raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")
@base_router.post(
"/workflows/run",
tags=["Agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
"x-fern-sdk-method-name": "run_workflow",
},
description="Run a workflow",
summary="Run a workflow",
responses={
200: {"description": "Successfully run workflow"},
400: {"description": "Invalid workflow run request"},
},
)
@base_router.post("/workflows/run/", include_in_schema=False)
async def run_workflow(
request: Request,
background_tasks: BackgroundTasks,
workflow_run_request: WorkflowRunRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
template: bool = Query(False),
x_api_key: Annotated[str | None, Header()] = None,
x_max_steps_override: Annotated[int | None, Header()] = None,
) -> WorkflowRunResponse:
analytics.capture("skyvern-oss-run-workflow")
await PermissionCheckerFactory.get_instance().check(
current_org, browser_session_id=workflow_run_request.browser_session_id
)
workflow_id = workflow_run_request.workflow_id
context = skyvern_context.ensure_context()
request_id = context.request_id
legacy_workflow_request = WorkflowRequestBody(
data=workflow_run_request.parameters,
proxy_location=workflow_run_request.proxy_location,
webhook_callback_url=workflow_run_request.webhook_url,
totp_identifier=workflow_run_request.totp_identifier,
totp_url=workflow_run_request.totp_url,
browser_session_id=workflow_run_request.browser_session_id,
)
workflow_run = await workflow_service.run_workflow(
workflow_id=workflow_id,
organization_id=current_org.organization_id,
workflow_request=legacy_workflow_request,
template=template,
version=None,
max_steps=x_max_steps_override,
api_key=x_api_key,
request_id=request_id,
request=request,
background_tasks=background_tasks,
)
return WorkflowRunResponse(
run_id=workflow_run.workflow_run_id,
run_type=RunType.workflow_run,
status=str(workflow_run.status),
output=workflow_run.output,
failure_reason=workflow_run.failure_reason,
created_at=workflow_run.created_at,
modified_at=workflow_run.modified_at,
run_request=workflow_run_request,
downloaded_files=workflow_run.downloaded_files,
recording_url=workflow_run.recording_url,
)

View File

@@ -133,7 +133,7 @@ class WorkflowRunOutputParameter(BaseModel):
created_at: datetime created_at: datetime
class WorkflowRunResponse(BaseModel): class WorkflowRunResponseBase(BaseModel):
workflow_id: str workflow_id: str
workflow_run_id: str workflow_run_id: str
status: WorkflowRunStatus status: WorkflowRunStatus

View File

@@ -80,7 +80,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRun, WorkflowRun,
WorkflowRunOutputParameter, WorkflowRunOutputParameter,
WorkflowRunParameter, WorkflowRunParameter,
WorkflowRunResponse, WorkflowRunResponseBase,
WorkflowRunStatus, WorkflowRunStatus,
WorkflowStatus, WorkflowStatus,
) )
@@ -958,7 +958,7 @@ class WorkflowService:
workflow_run_id: str, workflow_run_id: str,
organization_id: str, organization_id: str,
include_cost: bool = False, include_cost: bool = False,
) -> WorkflowRunResponse: ) -> WorkflowRunResponseBase:
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id) workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
if workflow_run is None: if workflow_run is None:
LOG.error(f"Workflow run {workflow_run_id} not found") LOG.error(f"Workflow run {workflow_run_id} not found")
@@ -977,7 +977,7 @@ class WorkflowService:
workflow_run_id: str, workflow_run_id: str,
organization_id: str, organization_id: str,
include_cost: bool = False, include_cost: bool = False,
) -> WorkflowRunResponse: ) -> WorkflowRunResponseBase:
workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id) workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id)
if workflow is None: if workflow is None:
LOG.error(f"Workflow {workflow_permanent_id} not found") LOG.error(f"Workflow {workflow_permanent_id} not found")
@@ -1073,7 +1073,7 @@ class WorkflowService:
# successful steps are the ones that have a status of completed and the total count of unique step.order # successful steps are the ones that have a status of completed and the total count of unique step.order
successful_steps = [step for step in workflow_run_steps if step.status == StepStatus.completed] successful_steps = [step for step in workflow_run_steps if step.status == StepStatus.completed]
total_cost = 0.1 * (len(successful_steps) + len(text_prompt_blocks)) total_cost = 0.1 * (len(successful_steps) + len(text_prompt_blocks))
return WorkflowRunResponse( return WorkflowRunResponseBase(
workflow_id=workflow.workflow_permanent_id, workflow_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
status=workflow_run.status, status=workflow_run.status,

View File

@@ -5,6 +5,7 @@ from zoneinfo import ZoneInfo
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.utils.url_validators import validate_url from skyvern.utils.url_validators import validate_url
@@ -206,6 +207,8 @@ class BaseRunResponse(BaseModel):
output: dict | list | str | None = Field( output: dict | list | str | None = Field(
default=None, description="Output data from the run, if any. Format depends on the schema in the input" default=None, description="Output data from the run, if any. Format depends on the schema in the input"
) )
downloaded_files: list[FileInfo] | None = Field(default=None, description="List of files downloaded during the run")
recording_url: str | None = Field(default=None, description="URL to the recording of the run")
failure_reason: str | None = Field(default=None, description="Reason for failure if the run failed") failure_reason: str | None = Field(default=None, description="Reason for failure if the run failed")
created_at: datetime = Field(description="Timestamp when this run was created") created_at: datetime = Field(description="Timestamp when this run was created")
modified_at: datetime = Field(description="Timestamp when this run was last modified") modified_at: datetime = Field(description="Timestamp when this run was last modified")

View File

@@ -0,0 +1,61 @@
import structlog
from fastapi import BackgroundTasks, Request
from skyvern.forge import app
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.workflow.exceptions import InvalidTemplateWorkflowPermanentId
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRun
from skyvern.schemas.runs import RunType
LOG = structlog.get_logger(__name__)
async def run_workflow(
workflow_id: str,
organization_id: str,
workflow_request: WorkflowRequestBody, # this is the deprecated workflow request body
template: bool = False,
version: int | None = None,
max_steps: int | None = None,
api_key: str | None = None,
request_id: str | None = None,
request: Request | None = None,
background_tasks: BackgroundTasks | None = None,
) -> WorkflowRun:
if template:
if workflow_id not in await app.STORAGE.retrieve_global_workflows():
raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id)
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
request_id=request_id,
workflow_request=workflow_request,
workflow_permanent_id=workflow_id,
organization_id=organization_id,
version=version,
max_steps_override=max_steps,
is_template_workflow=template,
)
workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_id,
organization_id=None if template else organization_id,
version=version,
)
await app.DATABASE.create_task_run(
task_run_type=RunType.workflow_run,
organization_id=organization_id,
run_id=workflow_run.workflow_run_id,
title=workflow.title,
)
if max_steps:
LOG.info("Overriding max steps per run", max_steps_override=max_steps)
await AsyncExecutorFactory.get_executor().execute_workflow(
request=request,
background_tasks=background_tasks,
organization_id=organization_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
max_steps_override=max_steps,
browser_session_id=workflow_request.browser_session_id,
api_key=api_key,
)
return workflow_run