shu/add_workflow_runs_api (#2063)

This commit is contained in:
Shuchang Zheng
2025-04-01 15:52:35 -04:00
committed by GitHub
parent f774135049
commit e26b816f67
7 changed files with 166 additions and 51 deletions

View File

@@ -55,13 +55,21 @@ from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
WorkflowRequestBody,
WorkflowRun,
WorkflowRunResponse,
WorkflowRunResponseBase,
WorkflowRunStatus,
WorkflowStatus,
)
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
from skyvern.services import run_service, task_v1_service, task_v2_service
from skyvern.schemas.runs import (
RunEngine,
RunResponse,
RunType,
TaskRunRequest,
TaskRunResponse,
WorkflowRunRequest,
WorkflowRunResponse,
)
from skyvern.services import run_service, task_v1_service, task_v2_service, workflow_service
from skyvern.webeye.actions.actions import Action
LOG = structlog.get_logger()
@@ -620,7 +628,7 @@ async def get_actions(
tags=["agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
"x-fern-sdk-method-name": "run_workflow",
"x-fern-sdk-method-name": "run_workflow_legacy",
},
)
@legacy_base_router.post(
@@ -628,7 +636,7 @@ async def get_actions(
response_model=RunWorkflowResponse,
include_in_schema=False,
)
async def run_workflow(
async def run_workflow_legacy(
request: Request,
background_tasks: BackgroundTasks,
workflow_id: str, # this is the workflow_permanent_id internally
@@ -647,42 +655,19 @@ async def run_workflow(
browser_session_id=workflow_request.browser_session_id,
)
if template:
if workflow_id not in await app.STORAGE.retrieve_global_workflows():
raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id)
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
request_id=request_id,
workflow_run = await workflow_service.run_workflow(
workflow_id=workflow_id,
organization_id=current_org.organization_id,
workflow_request=workflow_request,
workflow_permanent_id=workflow_id,
organization_id=current_org.organization_id,
template=template,
version=version,
max_steps_override=x_max_steps_override,
is_template_workflow=template,
)
workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_id,
organization_id=None if template else current_org.organization_id,
version=version,
)
await app.DATABASE.create_task_run(
task_run_type=RunType.workflow_run,
organization_id=current_org.organization_id,
run_id=workflow_run.workflow_run_id,
title=workflow.title,
)
if x_max_steps_override:
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await AsyncExecutorFactory.get_executor().execute_workflow(
max_steps=x_max_steps_override,
api_key=x_api_key,
request_id=request_id,
request=request,
background_tasks=background_tasks,
organization_id=current_org.organization_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
max_steps_override=x_max_steps_override,
browser_session_id=workflow_request.browser_session_id,
api_key=x_api_key,
)
return RunWorkflowResponse(
workflow_id=workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
@@ -806,7 +791,7 @@ async def get_workflow_run_timeline(
@legacy_base_router.get(
"/workflows/runs/{workflow_run_id}",
response_model=WorkflowRunResponse,
response_model=WorkflowRunResponseBase,
tags=["agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
@@ -815,13 +800,13 @@ async def get_workflow_run_timeline(
)
@legacy_base_router.get(
"/workflows/runs/{workflow_run_id}/",
response_model=WorkflowRunResponse,
response_model=WorkflowRunResponseBase,
include_in_schema=False,
)
async def get_workflow_run(
workflow_run_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowRunResponse:
) -> WorkflowRunResponseBase:
analytics.capture("skyvern-oss-agent-workflow-run-get")
return await app.WORKFLOW_SERVICE.build_workflow_run_status_response_by_workflow_id(
workflow_run_id=workflow_run_id,
@@ -1385,7 +1370,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
@base_router.post(
"/tasks",
"/tasks/run",
tags=["Agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
@@ -1398,7 +1383,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
400: {"description": "Invalid agent engine"},
},
)
@base_router.post("/tasks/", include_in_schema=False)
@base_router.post("/tasks/run/", include_in_schema=False)
async def run_task(
request: Request,
background_tasks: BackgroundTasks,
@@ -1523,3 +1508,69 @@ async def run_task(
),
)
raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")
@base_router.post(
"/workflows/run",
tags=["Agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
"x-fern-sdk-method-name": "run_workflow",
},
description="Run a workflow",
summary="Run a workflow",
responses={
200: {"description": "Successfully run workflow"},
400: {"description": "Invalid workflow run request"},
},
)
@base_router.post("/workflows/run/", include_in_schema=False)
async def run_workflow(
request: Request,
background_tasks: BackgroundTasks,
workflow_run_request: WorkflowRunRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
template: bool = Query(False),
x_api_key: Annotated[str | None, Header()] = None,
x_max_steps_override: Annotated[int | None, Header()] = None,
) -> WorkflowRunResponse:
analytics.capture("skyvern-oss-run-workflow")
await PermissionCheckerFactory.get_instance().check(
current_org, browser_session_id=workflow_run_request.browser_session_id
)
workflow_id = workflow_run_request.workflow_id
context = skyvern_context.ensure_context()
request_id = context.request_id
legacy_workflow_request = WorkflowRequestBody(
data=workflow_run_request.parameters,
proxy_location=workflow_run_request.proxy_location,
webhook_callback_url=workflow_run_request.webhook_url,
totp_identifier=workflow_run_request.totp_identifier,
totp_url=workflow_run_request.totp_url,
browser_session_id=workflow_run_request.browser_session_id,
)
workflow_run = await workflow_service.run_workflow(
workflow_id=workflow_id,
organization_id=current_org.organization_id,
workflow_request=legacy_workflow_request,
template=template,
version=None,
max_steps=x_max_steps_override,
api_key=x_api_key,
request_id=request_id,
request=request,
background_tasks=background_tasks,
)
return WorkflowRunResponse(
run_id=workflow_run.workflow_run_id,
run_type=RunType.workflow_run,
status=str(workflow_run.status),
output=workflow_run.output,
failure_reason=workflow_run.failure_reason,
created_at=workflow_run.created_at,
modified_at=workflow_run.modified_at,
run_request=workflow_run_request,
downloaded_files=workflow_run.downloaded_files,
recording_url=workflow_run.recording_url,
)

View File

@@ -133,7 +133,7 @@ class WorkflowRunOutputParameter(BaseModel):
created_at: datetime
class WorkflowRunResponse(BaseModel):
class WorkflowRunResponseBase(BaseModel):
workflow_id: str
workflow_run_id: str
status: WorkflowRunStatus

View File

@@ -80,7 +80,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRun,
WorkflowRunOutputParameter,
WorkflowRunParameter,
WorkflowRunResponse,
WorkflowRunResponseBase,
WorkflowRunStatus,
WorkflowStatus,
)
@@ -958,7 +958,7 @@ class WorkflowService:
workflow_run_id: str,
organization_id: str,
include_cost: bool = False,
) -> WorkflowRunResponse:
) -> WorkflowRunResponseBase:
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
if workflow_run is None:
LOG.error(f"Workflow run {workflow_run_id} not found")
@@ -977,7 +977,7 @@ class WorkflowService:
workflow_run_id: str,
organization_id: str,
include_cost: bool = False,
) -> WorkflowRunResponse:
) -> WorkflowRunResponseBase:
workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id)
if workflow is None:
LOG.error(f"Workflow {workflow_permanent_id} not found")
@@ -1073,7 +1073,7 @@ class WorkflowService:
# successful steps are the ones that have a status of completed and the total count of unique step.order
successful_steps = [step for step in workflow_run_steps if step.status == StepStatus.completed]
total_cost = 0.1 * (len(successful_steps) + len(text_prompt_blocks))
return WorkflowRunResponse(
return WorkflowRunResponseBase(
workflow_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run_id,
status=workflow_run.status,