diff --git a/evaluation/core/__init__.py b/evaluation/core/__init__.py index ac72c71d..9a815e4f 100644 --- a/evaluation/core/__init__.py +++ b/evaluation/core/__init__.py @@ -14,7 +14,7 @@ from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.api.files import create_folder_if_not_exist from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request from skyvern.forge.sdk.schemas.tasks import TaskRequest, TaskResponse, TaskStatus -from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRunResponse, WorkflowRunStatus +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRunResponseBase, WorkflowRunStatus from skyvern.schemas.runs import ProxyLocation @@ -71,7 +71,7 @@ class SkyvernClient: assert response.status_code == 200, f"Expected to get task response status 200, but got {response.status_code}" return TaskResponse(**response.json()) - async def get_workflow_run(self, workflow_pid: str, workflow_run_id: str) -> WorkflowRunResponse: + async def get_workflow_run(self, workflow_pid: str, workflow_run_id: str) -> WorkflowRunResponseBase: url = f"{self.base_url}/workflows/{workflow_pid}/runs/{workflow_run_id}" headers = {"x-api-key": self.credentials} async with httpx.AsyncClient() as client: @@ -79,7 +79,7 @@ class SkyvernClient: assert response.status_code == 200, ( f"Expected to get workflow run response status 200, but got {response.status_code}" ) - return WorkflowRunResponse(**response.json()) + return WorkflowRunResponseBase(**response.json()) class Evaluator: diff --git a/skyvern/__init__.py b/skyvern/__init__.py index 43eb4242..84c82dbf 100644 --- a/skyvern/__init__.py +++ b/skyvern/__init__.py @@ -15,6 +15,6 @@ setup_logger() from skyvern.forge import app # noqa: E402, F401 from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402 -from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponse # noqa: E402 +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponseBase # noqa: E402 -__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponse"] +__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponseBase"] diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index eb5e9427..a2ba2006 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -55,13 +55,21 @@ from skyvern.forge.sdk.workflow.models.workflow import ( Workflow, WorkflowRequestBody, WorkflowRun, - WorkflowRunResponse, + WorkflowRunResponseBase, WorkflowRunStatus, WorkflowStatus, ) from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest -from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse -from skyvern.services import run_service, task_v1_service, task_v2_service +from skyvern.schemas.runs import ( + RunEngine, + RunResponse, + RunType, + TaskRunRequest, + TaskRunResponse, + WorkflowRunRequest, + WorkflowRunResponse, +) +from skyvern.services import run_service, task_v1_service, task_v2_service, workflow_service from skyvern.webeye.actions.actions import Action LOG = structlog.get_logger() @@ -620,7 +628,7 @@ async def get_actions( tags=["agent"], openapi_extra={ "x-fern-sdk-group-name": "agent", - "x-fern-sdk-method-name": "run_workflow", + "x-fern-sdk-method-name": "run_workflow_legacy", }, ) @legacy_base_router.post( @@ -628,7 +636,7 @@ async def get_actions( response_model=RunWorkflowResponse, include_in_schema=False, ) -async def run_workflow( +async def run_workflow_legacy( request: Request, background_tasks: BackgroundTasks, workflow_id: str, # this is the workflow_permanent_id internally @@ -647,42 +655,19 @@ async def run_workflow( browser_session_id=workflow_request.browser_session_id, ) - if template: - if workflow_id not in await app.STORAGE.retrieve_global_workflows(): - raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id) - - workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run( - request_id=request_id, + workflow_run = await workflow_service.run_workflow( + workflow_id=workflow_id, + organization_id=current_org.organization_id, workflow_request=workflow_request, - workflow_permanent_id=workflow_id, - organization_id=current_org.organization_id, + template=template, version=version, - max_steps_override=x_max_steps_override, - is_template_workflow=template, - ) - workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( - workflow_permanent_id=workflow_id, - organization_id=None if template else current_org.organization_id, - version=version, - ) - await app.DATABASE.create_task_run( - task_run_type=RunType.workflow_run, - organization_id=current_org.organization_id, - run_id=workflow_run.workflow_run_id, - title=workflow.title, - ) - if x_max_steps_override: - LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override) - await AsyncExecutorFactory.get_executor().execute_workflow( + max_steps=x_max_steps_override, + api_key=x_api_key, + request_id=request_id, request=request, background_tasks=background_tasks, - organization_id=current_org.organization_id, - workflow_id=workflow_run.workflow_id, - workflow_run_id=workflow_run.workflow_run_id, - max_steps_override=x_max_steps_override, - browser_session_id=workflow_request.browser_session_id, - api_key=x_api_key, ) + return RunWorkflowResponse( workflow_id=workflow_id, workflow_run_id=workflow_run.workflow_run_id, @@ -806,7 +791,7 @@ async def get_workflow_run_timeline( @legacy_base_router.get( "/workflows/runs/{workflow_run_id}", - response_model=WorkflowRunResponse, + response_model=WorkflowRunResponseBase, tags=["agent"], openapi_extra={ "x-fern-sdk-group-name": "agent", @@ -815,13 +800,13 @@ async def get_workflow_run_timeline( ) @legacy_base_router.get( "/workflows/runs/{workflow_run_id}/", - response_model=WorkflowRunResponse, + response_model=WorkflowRunResponseBase, include_in_schema=False, ) async def get_workflow_run( workflow_run_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), -) -> WorkflowRunResponse: +) -> WorkflowRunResponseBase: analytics.capture("skyvern-oss-agent-workflow-run-get") return await app.WORKFLOW_SERVICE.build_workflow_run_status_response_by_workflow_id( workflow_run_id=workflow_run_id, @@ -1385,7 +1370,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id: @base_router.post( - "/tasks", + "/tasks/run", tags=["Agent"], openapi_extra={ "x-fern-sdk-group-name": "agent", @@ -1398,7 +1383,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id: 400: {"description": "Invalid agent engine"}, }, ) -@base_router.post("/tasks/", include_in_schema=False) +@base_router.post("/tasks/run/", include_in_schema=False) async def run_task( request: Request, background_tasks: BackgroundTasks, @@ -1523,3 +1508,69 @@ async def run_task( ), ) raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}") + + +@base_router.post( + "/workflows/run", + tags=["Agent"], + openapi_extra={ + "x-fern-sdk-group-name": "agent", + "x-fern-sdk-method-name": "run_workflow", + }, + description="Run a workflow", + summary="Run a workflow", + responses={ + 200: {"description": "Successfully run workflow"}, + 400: {"description": "Invalid workflow run request"}, + }, +) +@base_router.post("/workflows/run/", include_in_schema=False) +async def run_workflow( + request: Request, + background_tasks: BackgroundTasks, + workflow_run_request: WorkflowRunRequest, + current_org: Organization = Depends(org_auth_service.get_current_org), + template: bool = Query(False), + x_api_key: Annotated[str | None, Header()] = None, + x_max_steps_override: Annotated[int | None, Header()] = None, +) -> WorkflowRunResponse: + analytics.capture("skyvern-oss-run-workflow") + await PermissionCheckerFactory.get_instance().check( + current_org, browser_session_id=workflow_run_request.browser_session_id + ) + workflow_id = workflow_run_request.workflow_id + context = skyvern_context.ensure_context() + request_id = context.request_id + legacy_workflow_request = WorkflowRequestBody( + data=workflow_run_request.parameters, + proxy_location=workflow_run_request.proxy_location, + webhook_callback_url=workflow_run_request.webhook_url, + totp_identifier=workflow_run_request.totp_identifier, + totp_url=workflow_run_request.totp_url, + browser_session_id=workflow_run_request.browser_session_id, + ) + workflow_run = await workflow_service.run_workflow( + workflow_id=workflow_id, + organization_id=current_org.organization_id, + workflow_request=legacy_workflow_request, + template=template, + version=None, + max_steps=x_max_steps_override, + api_key=x_api_key, + request_id=request_id, + request=request, + background_tasks=background_tasks, + ) + + return WorkflowRunResponse( + run_id=workflow_run.workflow_run_id, + run_type=RunType.workflow_run, + status=str(workflow_run.status), + output=workflow_run.output, + failure_reason=workflow_run.failure_reason, + created_at=workflow_run.created_at, + modified_at=workflow_run.modified_at, + run_request=workflow_run_request, + downloaded_files=workflow_run.downloaded_files, + recording_url=workflow_run.recording_url, + ) diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index 97e2cfa4..f8d4a466 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -133,7 +133,7 @@ class WorkflowRunOutputParameter(BaseModel): created_at: datetime -class WorkflowRunResponse(BaseModel): +class WorkflowRunResponseBase(BaseModel): workflow_id: str workflow_run_id: str status: WorkflowRunStatus diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 856faa21..3e6ac2cb 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -80,7 +80,7 @@ from skyvern.forge.sdk.workflow.models.workflow import ( WorkflowRun, WorkflowRunOutputParameter, WorkflowRunParameter, - WorkflowRunResponse, + WorkflowRunResponseBase, WorkflowRunStatus, WorkflowStatus, ) @@ -958,7 +958,7 @@ class WorkflowService: workflow_run_id: str, organization_id: str, include_cost: bool = False, - ) -> WorkflowRunResponse: + ) -> WorkflowRunResponseBase: workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id) if workflow_run is None: LOG.error(f"Workflow run {workflow_run_id} not found") @@ -977,7 +977,7 @@ class WorkflowService: workflow_run_id: str, organization_id: str, include_cost: bool = False, - ) -> WorkflowRunResponse: + ) -> WorkflowRunResponseBase: workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id) if workflow is None: LOG.error(f"Workflow {workflow_permanent_id} not found") @@ -1073,7 +1073,7 @@ class WorkflowService: # successful steps are the ones that have a status of completed and the total count of unique step.order successful_steps = [step for step in workflow_run_steps if step.status == StepStatus.completed] total_cost = 0.1 * (len(successful_steps) + len(text_prompt_blocks)) - return WorkflowRunResponse( + return WorkflowRunResponseBase( workflow_id=workflow.workflow_permanent_id, workflow_run_id=workflow_run_id, status=workflow_run.status, diff --git a/skyvern/schemas/runs.py b/skyvern/schemas/runs.py index 54e5de0f..9151a676 100644 --- a/skyvern/schemas/runs.py +++ b/skyvern/schemas/runs.py @@ -5,6 +5,7 @@ from zoneinfo import ZoneInfo from pydantic import BaseModel, Field, field_validator +from skyvern.forge.sdk.schemas.files import FileInfo from skyvern.utils.url_validators import validate_url @@ -206,6 +207,8 @@ class BaseRunResponse(BaseModel): output: dict | list | str | None = Field( default=None, description="Output data from the run, if any. Format depends on the schema in the input" ) + downloaded_files: list[FileInfo] | None = Field(default=None, description="List of files downloaded during the run") + recording_url: str | None = Field(default=None, description="URL to the recording of the run") failure_reason: str | None = Field(default=None, description="Reason for failure if the run failed") created_at: datetime = Field(description="Timestamp when this run was created") modified_at: datetime = Field(description="Timestamp when this run was last modified") diff --git a/skyvern/services/workflow_service.py b/skyvern/services/workflow_service.py new file mode 100644 index 00000000..f12d7aeb --- /dev/null +++ b/skyvern/services/workflow_service.py @@ -0,0 +1,61 @@ +import structlog +from fastapi import BackgroundTasks, Request + +from skyvern.forge import app +from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory +from skyvern.forge.sdk.workflow.exceptions import InvalidTemplateWorkflowPermanentId +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRun +from skyvern.schemas.runs import RunType + +LOG = structlog.get_logger(__name__) + + +async def run_workflow( + workflow_id: str, + organization_id: str, + workflow_request: WorkflowRequestBody, # this is the deprecated workflow request body + template: bool = False, + version: int | None = None, + max_steps: int | None = None, + api_key: str | None = None, + request_id: str | None = None, + request: Request | None = None, + background_tasks: BackgroundTasks | None = None, +) -> WorkflowRun: + if template: + if workflow_id not in await app.STORAGE.retrieve_global_workflows(): + raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id) + + workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run( + request_id=request_id, + workflow_request=workflow_request, + workflow_permanent_id=workflow_id, + organization_id=organization_id, + version=version, + max_steps_override=max_steps, + is_template_workflow=template, + ) + workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( + workflow_permanent_id=workflow_id, + organization_id=None if template else organization_id, + version=version, + ) + await app.DATABASE.create_task_run( + task_run_type=RunType.workflow_run, + organization_id=organization_id, + run_id=workflow_run.workflow_run_id, + title=workflow.title, + ) + if max_steps: + LOG.info("Overriding max steps per run", max_steps_override=max_steps) + await AsyncExecutorFactory.get_executor().execute_workflow( + request=request, + background_tasks=background_tasks, + organization_id=organization_id, + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.workflow_run_id, + max_steps_override=max_steps, + browser_session_id=workflow_request.browser_session_id, + api_key=api_key, + ) + return workflow_run