Sync cloud skyvern to oss skyvern (#55)

This commit is contained in:
Kerem Yilmaz
2024-03-12 22:28:16 -07:00
committed by GitHub
parent 647ea2ac0f
commit 15d78d7b08
25 changed files with 554 additions and 163 deletions

View File

@@ -1,6 +1,5 @@
import asyncio
import json
import time
from collections import Counter
from datetime import datetime
import requests
@@ -19,8 +18,8 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.workflow.context_manager import ContextManager
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
@@ -55,7 +54,6 @@ class WorkflowService:
:param max_steps_override: The max steps override for the workflow run, if any.
:return: The created workflow run.
"""
LOG.info(f"Setting up workflow run for workflow {workflow_id}", workflow_id=workflow_id)
# Validate the workflow and the organization
workflow = await self.get_workflow(workflow_id=workflow_id)
if workflow is None:
@@ -83,9 +81,6 @@ class WorkflowService:
)
)
# Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
# Create all the workflow run parameters, AWSSecretParameter won't have workflow run parameters created.
all_workflow_parameters = await self.get_workflow_parameters(workflow_id=workflow.workflow_id)
workflow_run_parameters = []
@@ -113,11 +108,6 @@ class WorkflowService:
workflow_run_parameters.append(workflow_run_parameter)
LOG.info(
f"Created workflow run parameters for workflow run {workflow_run.workflow_run_id}",
workflow_run_id=workflow_run.workflow_run_id,
)
return workflow_run
async def execute_workflow(
@@ -129,59 +119,92 @@ class WorkflowService:
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id)
await app.BROWSER_MANAGER.get_or_create_for_workflow_run(workflow_run=workflow_run)
# Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
# Get all <workflow parameter, workflow run parameter> tuples
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
# todo(kerem): do this in a better way (a shared context manager? (not really shared because we use batch job))
context_manager = ContextManager(wp_wps_tuples)
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(workflow_run_id, wp_wps_tuples)
# Execute workflow blocks
blocks = workflow.workflow_definition.blocks
for block_idx, block in enumerate(blocks):
parameters = block.get_all_parameters()
await context_manager.register_block_parameters(parameters)
LOG.info(
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run.workflow_run_id}",
block_type=block.block_type,
try:
for block_idx, block in enumerate(blocks):
parameters = block.get_all_parameters()
await app.WORKFLOW_CONTEXT_MANAGER.register_block_parameters_for_workflow_run(
workflow_run_id, parameters
)
LOG.info(
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run_id}",
block_type=block.block_type,
workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx,
)
await block.execute(workflow_run_id=workflow_run_id)
except Exception:
LOG.exception(
f"Error while executing workflow run {workflow_run.workflow_run_id}",
workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx,
exc_info=True,
)
await block.execute(workflow_run_id=workflow_run.workflow_run_id, context_manager=context_manager)
# Get last task for workflow run
task = await self.get_last_task_for_workflow_run(workflow_run_id=workflow_run.workflow_run_id)
if not task:
tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id)
if not tasks:
LOG.warning(
f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook",
f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook, marking as failed",
workflow_run_id=workflow_run.workflow_run_id,
)
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
return workflow_run
# Update workflow status
if task.status == "completed":
await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id)
elif task.status == "failed":
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
elif task.status == "terminated":
await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id)
else:
LOG.warning(
f"Task {task.task_id} has an incomplete status {task.status}, not updating workflow run status",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
task_id=task.task_id,
status=task.status,
workflow_run_status=workflow_run.status,
)
workflow_run = await self.handle_workflow_status(workflow_run=workflow_run, tasks=tasks)
await self.send_workflow_response(
workflow=workflow,
workflow_run=workflow_run,
tasks=tasks,
api_key=api_key,
last_task=task,
)
return workflow_run
async def handle_workflow_status(self, workflow_run: WorkflowRun, tasks: list[Task]) -> WorkflowRun:
task_counts_by_status = Counter(task.status for task in tasks)
# Create a mapping of status to (action, log_func, log_message)
status_action_mapping = {
TaskStatus.running: (None, LOG.error, "has running tasks, this should not happen"),
TaskStatus.terminated: (
self.mark_workflow_run_as_terminated,
LOG.warning,
"has terminated tasks, marking as terminated",
),
TaskStatus.failed: (self.mark_workflow_run_as_failed, LOG.warning, "has failed tasks, marking as failed"),
TaskStatus.completed: (
self.mark_workflow_run_as_completed,
LOG.info,
"tasks are completed, marking as completed",
),
}
for status, (action, log_func, log_message) in status_action_mapping.items():
if task_counts_by_status.get(status, 0) > 0:
if action is not None:
await action(workflow_run_id=workflow_run.workflow_run_id)
if log_func and log_message:
log_func(
f"Workflow run {workflow_run.workflow_run_id} {log_message}",
workflow_run_id=workflow_run.workflow_run_id,
task_counts_by_status=task_counts_by_status,
)
return workflow_run
# Handle unexpected state
LOG.error(
f"Workflow run {workflow_run.workflow_run_id} has tasks in an unexpected state, marking as failed",
workflow_run_id=workflow_run.workflow_run_id,
task_counts_by_status=task_counts_by_status,
)
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
return workflow_run
async def create_workflow(
self,
organization_id: str,
@@ -354,6 +377,15 @@ class WorkflowService:
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
payload = {
task.task_id: {
"title": task.title,
"extracted_information": task.extracted_information,
"navigation_payload": task.navigation_payload,
"errors": await app.agent.get_task_errors(task=task),
}
for task in workflow_run_tasks
}
return WorkflowRunStatusResponse(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
@@ -365,50 +397,28 @@ class WorkflowService:
parameters=parameters_with_value,
screenshot_urls=screenshot_urls,
recording_url=recording_url,
payload=payload,
)
async def send_workflow_response(
self,
workflow: Workflow,
workflow_run: WorkflowRun,
last_task: Task,
tasks: list[Task],
api_key: str | None = None,
close_browser_on_completion: bool = True,
) -> None:
analytics.capture("skyvern-oss-agent-workflow-status", {"status": workflow_run.status})
all_workflow_task_ids = [task.task_id for task in tasks]
browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
workflow_run.workflow_run_id, close_browser_on_completion
workflow_run.workflow_run_id, all_workflow_task_ids, close_browser_on_completion
)
if browser_state:
await self.persist_video_data(browser_state, workflow, workflow_run)
await self.persist_har_data(browser_state, last_task, workflow, workflow_run)
await self.persist_debug_artifacts(browser_state, tasks[-1], workflow, workflow_run)
# Wait for all tasks to complete before generating the links for the artifacts
all_workflow_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(
workflow_run_id=workflow_run.workflow_run_id
)
all_workflow_task_ids = [task.task_id for task in all_workflow_tasks]
await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)
try:
# Wait for all tasks to complete. Currently we're using asyncio.create_task() only for uploading artifacts to S3.
# We're excluding the current task from the list of tasks to wait for to prevent a deadlock.
st = time.time()
async with asyncio.timeout(30):
await asyncio.gather(
*[aio_task for aio_task in (asyncio.all_tasks() - {asyncio.current_task()}) if not aio_task.done()]
)
LOG.info(
f"Waiting for all S3 uploads to complete took {time.time() - st} seconds",
duration=time.time() - st,
)
except asyncio.TimeoutError:
LOG.warning(
"Timed out waiting for all S3 uploads to complete, not all artifacts may be uploaded. Waited 30 seconds.",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
)
if not workflow_run.webhook_callback_url:
LOG.warning(
"Workflow has no webhook callback url. Not sending workflow response",
@@ -493,19 +503,35 @@ class WorkflowService:
)
async def persist_har_data(
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
self, browser_state: BrowserState, last_step: Step, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
har_data = await app.BROWSER_MANAGER.get_har_data(
workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, browser_state=browser_state
)
if har_data:
last_step = await app.DATABASE.get_latest_step(
task_id=last_task.task_id, organization_id=last_task.organization_id
await app.ARTIFACT_MANAGER.create_artifact(
step=last_step,
artifact_type=ArtifactType.HAR,
data=har_data,
)
if last_step:
await app.ARTIFACT_MANAGER.create_artifact(
step=last_step,
artifact_type=ArtifactType.HAR,
data=har_data,
)
async def persist_tracing_data(
self, browser_state: BrowserState, last_step: Step, workflow_run: WorkflowRun
) -> None:
if browser_state.browser_context is None or browser_state.browser_artifacts.traces_dir is None:
return
trace_path = f"{browser_state.browser_artifacts.traces_dir}/{workflow_run.workflow_run_id}.zip"
await app.ARTIFACT_MANAGER.create_artifact(step=last_step, artifact_type=ArtifactType.TRACE, path=trace_path)
async def persist_debug_artifacts(
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
last_step = await app.DATABASE.get_latest_step(
task_id=last_task.task_id, organization_id=last_task.organization_id
)
if not last_step:
return
await self.persist_har_data(browser_state, last_step, workflow, workflow_run)
await self.persist_tracing_data(browser_state, last_step, workflow_run)