Sync cloud skyvern to oss skyvern (#55)

This commit is contained in:
Kerem Yilmaz
2024-03-12 22:28:16 -07:00
committed by GitHub
parent 647ea2ac0f
commit 15d78d7b08
25 changed files with 554 additions and 163 deletions

View File

@@ -1,7 +1,9 @@
import uuid
from typing import TYPE_CHECKING, Any
import structlog
from skyvern.exceptions import WorkflowRunContextNotInitialized
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter
@@ -12,15 +14,15 @@ if TYPE_CHECKING:
LOG = structlog.get_logger()
class ContextManager:
aws_client: AsyncAWSClient
class WorkflowRunContext:
parameters: dict[str, PARAMETER_TYPE]
values: dict[str, Any]
secrets: dict[str, Any]
def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None:
self.aws_client = AsyncAWSClient()
self.parameters = {}
self.values = {}
self.secrets = {}
for parameter, run_parameter in workflow_parameter_tuples:
if parameter.key in self.parameters:
prev_value = self.parameters[parameter.key]
@@ -32,8 +34,33 @@ class ContextManager:
self.parameters[parameter.key] = parameter
self.values[parameter.key] = run_parameter.value
def get_parameter(self, key: str) -> Parameter:
return self.parameters[key]
def get_value(self, key: str) -> Any:
"""
Get the value of a parameter. If the parameter is an AWS secret, the value will be the random secret id, not
the actual secret value. This will be used when building the navigation payload since we don't want to expose
the actual secret value in the payload.
"""
return self.values[key]
def set_value(self, key: str, value: Any) -> None:
self.values[key] = value
def get_original_secret_value_or_none(self, secret_id: str) -> Any:
"""
Get the original secret value from the secrets dict. If the secret id is not found, return None.
"""
return self.secrets.get(secret_id)
@staticmethod
def generate_random_secret_id() -> str:
return f"secret_{uuid.uuid4()}"
async def register_parameter_value(
self,
aws_client: AsyncAWSClient,
parameter: PARAMETER_TYPE,
) -> None:
if parameter.parameter_type == ParameterType.WORKFLOW:
@@ -42,15 +69,21 @@ class ContextManager:
f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}"
)
elif parameter.parameter_type == ParameterType.AWS_SECRET:
secret_value = await self.aws_client.get_secret(parameter.aws_key)
# If the parameter is an AWS secret, fetch the secret value and store it in the secrets dict
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
# We'll replace the random secret id with the actual secret value when we need to use it.
secret_value = await aws_client.get_secret(parameter.aws_key)
if secret_value is not None:
self.values[parameter.key] = secret_value
random_secret_id = self.generate_random_secret_id()
self.secrets[random_secret_id] = secret_value
self.values[parameter.key] = random_secret_id
else:
# ContextParameter values will be set within the blocks
return None
async def register_block_parameters(
self,
aws_client: AsyncAWSClient,
parameters: list[PARAMETER_TYPE],
) -> None:
for parameter in parameters:
@@ -67,13 +100,41 @@ class ContextManager:
)
self.parameters[parameter.key] = parameter
await self.register_parameter_value(parameter)
await self.register_parameter_value(aws_client, parameter)
def get_parameter(self, key: str) -> Parameter:
return self.parameters[key]
def get_value(self, key: str) -> Any:
return self.values[key]
class WorkflowContextManager:
aws_client: AsyncAWSClient
workflow_run_contexts: dict[str, WorkflowRunContext]
def set_value(self, key: str, value: Any) -> None:
self.values[key] = value
parameters: dict[str, PARAMETER_TYPE]
values: dict[str, Any]
secrets: dict[str, Any]
def __init__(self) -> None:
self.aws_client = AsyncAWSClient()
self.workflow_run_contexts = {}
def _validate_workflow_run_context(self, workflow_run_id: str) -> None:
if workflow_run_id not in self.workflow_run_contexts:
LOG.error(f"WorkflowRunContext not initialized for workflow run {workflow_run_id}")
raise WorkflowRunContextNotInitialized(workflow_run_id=workflow_run_id)
def initialize_workflow_run_context(
self, workflow_run_id: str, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]
) -> WorkflowRunContext:
workflow_run_context = WorkflowRunContext(workflow_parameter_tuples)
self.workflow_run_contexts[workflow_run_id] = workflow_run_context
return workflow_run_context
def get_workflow_run_context(self, workflow_run_id: str) -> WorkflowRunContext:
self._validate_workflow_run_context(workflow_run_id)
return self.workflow_run_contexts[workflow_run_id]
async def register_block_parameters_for_workflow_run(
self,
workflow_run_id: str,
parameters: list[PARAMETER_TYPE],
) -> None:
self._validate_workflow_run_context(workflow_run_id)
await self.workflow_run_contexts[workflow_run_id].register_block_parameters(self.aws_client, parameters)

View File

@@ -13,7 +13,7 @@ from skyvern.exceptions import (
)
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import ContextManager
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
LOG = structlog.get_logger()
@@ -33,8 +33,12 @@ class Block(BaseModel, abc.ABC):
def get_subclasses(cls) -> tuple[type["Block"], ...]:
return tuple(cls.__subclasses__())
@staticmethod
def get_workflow_run_context(workflow_run_id: str) -> WorkflowRunContext:
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
@abc.abstractmethod
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
pass
@abc.abstractmethod
@@ -48,9 +52,12 @@ class TaskBlock(Block):
block_type: Literal[BlockType.TASK] = BlockType.TASK
url: str | None = None
title: str = "Untitled Task"
navigation_goal: str | None = None
data_extraction_goal: str | None = None
data_schema: dict[str, Any] | None = None
# error code to error description for the LLM
error_code_mapping: dict[str, str] | None = None
max_retries: int = 0
parameters: list[PARAMETER_TYPE] = []
@@ -89,8 +96,8 @@ class TaskBlock(Block):
return order, retry + 1
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
task = None
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
current_retry = 0
# initial value for will_retry is True, so that the loop runs at least once
will_retry = True
@@ -104,7 +111,7 @@ class TaskBlock(Block):
task_block=self,
workflow=workflow,
workflow_run=workflow_run,
context_manager=context_manager,
workflow_run_context=workflow_run_context,
task_order=task_order,
task_retry=task_retry,
)
@@ -131,7 +138,18 @@ class TaskBlock(Block):
if self.url:
await browser_state.page.goto(self.url)
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
try:
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
except Exception as e:
# Make sure the task is marked as failed in the database before raising the exception
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow.organization_id,
failure_reason=str(e),
)
raise e
# Check task status
updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id)
if not updated_task:
@@ -188,9 +206,9 @@ class ForLoopBlock(Block):
return context_parameters
def get_loop_over_parameter_values(self, context_manager: ContextManager) -> list[Any]:
def get_loop_over_parameter_values(self, workflow_run_context: WorkflowRunContext) -> list[Any]:
if isinstance(self.loop_over, WorkflowParameter):
parameter_value = context_manager.get_value(self.loop_over.key)
parameter_value = workflow_run_context.get_value(self.loop_over.key)
if isinstance(parameter_value, list):
return parameter_value
else:
@@ -200,8 +218,9 @@ class ForLoopBlock(Block):
# TODO (kerem): Implement this for context parameters
raise NotImplementedError
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
loop_over_values = self.get_loop_over_parameter_values(context_manager)
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
LOG.info(
f"Number of loop_over values: {len(loop_over_values)}",
block_type=self.block_type,
@@ -211,8 +230,8 @@ class ForLoopBlock(Block):
for loop_over_value in loop_over_values:
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
for context_parameter in context_parameters_with_value:
context_manager.set_value(context_parameter.key, context_parameter.value)
await self.loop_block.execute(workflow_run_id=workflow_run_id, context_manager=context_manager)
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
await self.loop_block.execute(workflow_run_id=workflow_run_id)
return None

View File

@@ -72,3 +72,4 @@ class WorkflowRunStatusResponse(BaseModel):
parameters: dict[str, Any]
screenshot_urls: list[str] | None = None
recording_url: str | None = None
payload: dict[str, Any] | None = None

View File

@@ -1,6 +1,5 @@
import asyncio
import json
import time
from collections import Counter
from datetime import datetime
import requests
@@ -19,8 +18,8 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.workflow.context_manager import ContextManager
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
@@ -55,7 +54,6 @@ class WorkflowService:
:param max_steps_override: The max steps override for the workflow run, if any.
:return: The created workflow run.
"""
LOG.info(f"Setting up workflow run for workflow {workflow_id}", workflow_id=workflow_id)
# Validate the workflow and the organization
workflow = await self.get_workflow(workflow_id=workflow_id)
if workflow is None:
@@ -83,9 +81,6 @@ class WorkflowService:
)
)
# Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
# Create all the workflow run parameters, AWSSecretParameter won't have workflow run parameters created.
all_workflow_parameters = await self.get_workflow_parameters(workflow_id=workflow.workflow_id)
workflow_run_parameters = []
@@ -113,11 +108,6 @@ class WorkflowService:
workflow_run_parameters.append(workflow_run_parameter)
LOG.info(
f"Created workflow run parameters for workflow run {workflow_run.workflow_run_id}",
workflow_run_id=workflow_run.workflow_run_id,
)
return workflow_run
async def execute_workflow(
@@ -129,59 +119,92 @@ class WorkflowService:
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id)
await app.BROWSER_MANAGER.get_or_create_for_workflow_run(workflow_run=workflow_run)
# Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
# Get all <workflow parameter, workflow run parameter> tuples
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
# todo(kerem): do this in a better way (a shared context manager? (not really shared because we use batch job))
context_manager = ContextManager(wp_wps_tuples)
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(workflow_run_id, wp_wps_tuples)
# Execute workflow blocks
blocks = workflow.workflow_definition.blocks
for block_idx, block in enumerate(blocks):
parameters = block.get_all_parameters()
await context_manager.register_block_parameters(parameters)
LOG.info(
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run.workflow_run_id}",
block_type=block.block_type,
try:
for block_idx, block in enumerate(blocks):
parameters = block.get_all_parameters()
await app.WORKFLOW_CONTEXT_MANAGER.register_block_parameters_for_workflow_run(
workflow_run_id, parameters
)
LOG.info(
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run_id}",
block_type=block.block_type,
workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx,
)
await block.execute(workflow_run_id=workflow_run_id)
except Exception:
LOG.exception(
f"Error while executing workflow run {workflow_run.workflow_run_id}",
workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx,
exc_info=True,
)
await block.execute(workflow_run_id=workflow_run.workflow_run_id, context_manager=context_manager)
# Get last task for workflow run
task = await self.get_last_task_for_workflow_run(workflow_run_id=workflow_run.workflow_run_id)
if not task:
tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id)
if not tasks:
LOG.warning(
f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook",
f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook, marking as failed",
workflow_run_id=workflow_run.workflow_run_id,
)
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
return workflow_run
# Update workflow status
if task.status == "completed":
await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id)
elif task.status == "failed":
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
elif task.status == "terminated":
await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id)
else:
LOG.warning(
f"Task {task.task_id} has an incomplete status {task.status}, not updating workflow run status",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
task_id=task.task_id,
status=task.status,
workflow_run_status=workflow_run.status,
)
workflow_run = await self.handle_workflow_status(workflow_run=workflow_run, tasks=tasks)
await self.send_workflow_response(
workflow=workflow,
workflow_run=workflow_run,
tasks=tasks,
api_key=api_key,
last_task=task,
)
return workflow_run
async def handle_workflow_status(self, workflow_run: WorkflowRun, tasks: list[Task]) -> WorkflowRun:
task_counts_by_status = Counter(task.status for task in tasks)
# Create a mapping of status to (action, log_func, log_message)
status_action_mapping = {
TaskStatus.running: (None, LOG.error, "has running tasks, this should not happen"),
TaskStatus.terminated: (
self.mark_workflow_run_as_terminated,
LOG.warning,
"has terminated tasks, marking as terminated",
),
TaskStatus.failed: (self.mark_workflow_run_as_failed, LOG.warning, "has failed tasks, marking as failed"),
TaskStatus.completed: (
self.mark_workflow_run_as_completed,
LOG.info,
"tasks are completed, marking as completed",
),
}
for status, (action, log_func, log_message) in status_action_mapping.items():
if task_counts_by_status.get(status, 0) > 0:
if action is not None:
await action(workflow_run_id=workflow_run.workflow_run_id)
if log_func and log_message:
log_func(
f"Workflow run {workflow_run.workflow_run_id} {log_message}",
workflow_run_id=workflow_run.workflow_run_id,
task_counts_by_status=task_counts_by_status,
)
return workflow_run
# Handle unexpected state
LOG.error(
f"Workflow run {workflow_run.workflow_run_id} has tasks in an unexpected state, marking as failed",
workflow_run_id=workflow_run.workflow_run_id,
task_counts_by_status=task_counts_by_status,
)
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
return workflow_run
async def create_workflow(
self,
organization_id: str,
@@ -354,6 +377,15 @@ class WorkflowService:
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
payload = {
task.task_id: {
"title": task.title,
"extracted_information": task.extracted_information,
"navigation_payload": task.navigation_payload,
"errors": await app.agent.get_task_errors(task=task),
}
for task in workflow_run_tasks
}
return WorkflowRunStatusResponse(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
@@ -365,50 +397,28 @@ class WorkflowService:
parameters=parameters_with_value,
screenshot_urls=screenshot_urls,
recording_url=recording_url,
payload=payload,
)
async def send_workflow_response(
self,
workflow: Workflow,
workflow_run: WorkflowRun,
last_task: Task,
tasks: list[Task],
api_key: str | None = None,
close_browser_on_completion: bool = True,
) -> None:
analytics.capture("skyvern-oss-agent-workflow-status", {"status": workflow_run.status})
all_workflow_task_ids = [task.task_id for task in tasks]
browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
workflow_run.workflow_run_id, close_browser_on_completion
workflow_run.workflow_run_id, all_workflow_task_ids, close_browser_on_completion
)
if browser_state:
await self.persist_video_data(browser_state, workflow, workflow_run)
await self.persist_har_data(browser_state, last_task, workflow, workflow_run)
await self.persist_debug_artifacts(browser_state, tasks[-1], workflow, workflow_run)
# Wait for all tasks to complete before generating the links for the artifacts
all_workflow_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(
workflow_run_id=workflow_run.workflow_run_id
)
all_workflow_task_ids = [task.task_id for task in all_workflow_tasks]
await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)
try:
# Wait for all tasks to complete. Currently we're using asyncio.create_task() only for uploading artifacts to S3.
# We're excluding the current task from the list of tasks to wait for to prevent a deadlock.
st = time.time()
async with asyncio.timeout(30):
await asyncio.gather(
*[aio_task for aio_task in (asyncio.all_tasks() - {asyncio.current_task()}) if not aio_task.done()]
)
LOG.info(
f"Waiting for all S3 uploads to complete took {time.time() - st} seconds",
duration=time.time() - st,
)
except asyncio.TimeoutError:
LOG.warning(
"Timed out waiting for all S3 uploads to complete, not all artifacts may be uploaded. Waited 30 seconds.",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
)
if not workflow_run.webhook_callback_url:
LOG.warning(
"Workflow has no webhook callback url. Not sending workflow response",
@@ -493,19 +503,35 @@ class WorkflowService:
)
async def persist_har_data(
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
self, browser_state: BrowserState, last_step: Step, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
har_data = await app.BROWSER_MANAGER.get_har_data(
workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, browser_state=browser_state
)
if har_data:
last_step = await app.DATABASE.get_latest_step(
task_id=last_task.task_id, organization_id=last_task.organization_id
await app.ARTIFACT_MANAGER.create_artifact(
step=last_step,
artifact_type=ArtifactType.HAR,
data=har_data,
)
if last_step:
await app.ARTIFACT_MANAGER.create_artifact(
step=last_step,
artifact_type=ArtifactType.HAR,
data=har_data,
)
async def persist_tracing_data(
self, browser_state: BrowserState, last_step: Step, workflow_run: WorkflowRun
) -> None:
if browser_state.browser_context is None or browser_state.browser_artifacts.traces_dir is None:
return
trace_path = f"{browser_state.browser_artifacts.traces_dir}/{workflow_run.workflow_run_id}.zip"
await app.ARTIFACT_MANAGER.create_artifact(step=last_step, artifact_type=ArtifactType.TRACE, path=trace_path)
async def persist_debug_artifacts(
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
last_step = await app.DATABASE.get_latest_step(
task_id=last_task.task_id, organization_id=last_task.organization_id
)
if not last_step:
return
await self.persist_har_data(browser_state, last_step, workflow, workflow_run)
await self.persist_tracing_data(browser_state, last_step, workflow_run)