Sync cloud skyvern to oss skyvern (#55)
This commit is contained in:
@@ -13,6 +13,7 @@ from starlette_context.plugins.base import Plugin
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.routes.agent_protocol import base_router
|
||||
from skyvern.scheduler import SCHEDULER
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -58,6 +59,12 @@ class Agent:
|
||||
),
|
||||
)
|
||||
|
||||
# Register the scheduler on startup so that we can schedule jobs dynamically
|
||||
@app.on_event("startup")
|
||||
def start_scheduler() -> None:
|
||||
LOG.info("Starting the skyvern scheduler.")
|
||||
SCHEDULER.start()
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse:
|
||||
LOG.exception("Unexpected error in agent server.", exc_info=exc)
|
||||
|
||||
@@ -55,6 +55,7 @@ class AgentDB:
|
||||
async def create_task(
|
||||
self,
|
||||
url: str,
|
||||
title: str | None,
|
||||
navigation_goal: str | None,
|
||||
data_extraction_goal: str | None,
|
||||
navigation_payload: dict[str, Any] | list | str | None,
|
||||
@@ -65,12 +66,14 @@ class AgentDB:
|
||||
workflow_run_id: str | None = None,
|
||||
order: int | None = None,
|
||||
retry: int | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
) -> Task:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
new_task = TaskModel(
|
||||
status="created",
|
||||
url=url,
|
||||
title=title,
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
navigation_goal=navigation_goal,
|
||||
data_extraction_goal=data_extraction_goal,
|
||||
@@ -81,6 +84,7 @@ class AgentDB:
|
||||
workflow_run_id=workflow_run_id,
|
||||
order=order,
|
||||
retry=retry,
|
||||
error_code_mapping=error_code_mapping,
|
||||
)
|
||||
session.add(new_task)
|
||||
session.commit()
|
||||
@@ -312,11 +316,16 @@ class AgentDB:
|
||||
async def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus,
|
||||
status: TaskStatus | None = None,
|
||||
extracted_information: dict[str, Any] | list | str | None = None,
|
||||
failure_reason: str | None = None,
|
||||
errors: list[dict[str, Any]] | None = None,
|
||||
organization_id: str | None = None,
|
||||
) -> Task:
|
||||
if status is None and extracted_information is None and failure_reason is None and errors is None:
|
||||
raise ValueError(
|
||||
"At least one of status, extracted_information, or failure_reason must be provided to update the task"
|
||||
)
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
@@ -325,11 +334,14 @@ class AgentDB:
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
task.status = status
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if extracted_information is not None:
|
||||
task.extracted_information = extracted_information
|
||||
if failure_reason is not None:
|
||||
task.failure_reason = failure_reason
|
||||
if errors is not None:
|
||||
task.errors = errors
|
||||
session.commit()
|
||||
updated_task = await self.get_task(task_id, organization_id=organization_id)
|
||||
if not updated_task:
|
||||
|
||||
@@ -29,6 +29,7 @@ class TaskModel(Base):
|
||||
organization_id = Column(String, ForeignKey("organizations.organization_id"))
|
||||
status = Column(String)
|
||||
webhook_callback_url = Column(String)
|
||||
title = Column(String)
|
||||
url = Column(String)
|
||||
navigation_goal = Column(String)
|
||||
data_extraction_goal = Column(String)
|
||||
@@ -40,6 +41,8 @@ class TaskModel(Base):
|
||||
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"))
|
||||
order = Column(Integer, nullable=True)
|
||||
retry = Column(Integer, nullable=True)
|
||||
error_code_mapping = Column(JSON, nullable=True)
|
||||
errors = Column(JSON, default=[], nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
|
||||
status=TaskStatus(task_obj.status),
|
||||
created_at=task_obj.created_at,
|
||||
modified_at=task_obj.modified_at,
|
||||
title=task_obj.title,
|
||||
url=task_obj.url,
|
||||
webhook_callback_url=task_obj.webhook_callback_url,
|
||||
navigation_goal=task_obj.navigation_goal,
|
||||
@@ -61,6 +62,8 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
|
||||
workflow_run_id=task_obj.workflow_run_id,
|
||||
order=task_obj.order,
|
||||
retry=task_obj.retry,
|
||||
error_code_mapping=task_obj.error_code_mapping,
|
||||
errors=task_obj.errors,
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import abc
|
||||
|
||||
import structlog
|
||||
from fastapi import BackgroundTasks
|
||||
|
||||
from skyvern.forge import app
|
||||
@@ -8,6 +9,8 @@ from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.models import Organization
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AsyncExecutor(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
@@ -43,6 +46,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
max_steps_override: int | None,
|
||||
api_key: str | None,
|
||||
) -> None:
|
||||
LOG.info("Executing task using background task executor", task_id=task.task_id)
|
||||
step = await app.DATABASE.create_step(
|
||||
task.task_id,
|
||||
order=0,
|
||||
@@ -52,7 +56,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
|
||||
task = await app.DATABASE.update_task(
|
||||
task.task_id,
|
||||
TaskStatus.running,
|
||||
status=TaskStatus.running,
|
||||
organization_id=organization.organization_id,
|
||||
)
|
||||
|
||||
@@ -78,6 +82,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
max_steps_override: int | None,
|
||||
api_key: str | None,
|
||||
) -> None:
|
||||
LOG.info("Executing workflow using background task executor", workflow_run_id=workflow_run_id)
|
||||
background_tasks.add_task(
|
||||
app.WORKFLOW_SERVICE.execute_workflow,
|
||||
workflow_run_id=workflow_run_id,
|
||||
|
||||
@@ -11,8 +11,16 @@ from skyvern.forge import app
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
||||
from skyvern.forge.sdk.models import Organization, Step
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import (
|
||||
CreateTaskResponse,
|
||||
ProxyLocation,
|
||||
Task,
|
||||
TaskRequest,
|
||||
TaskResponse,
|
||||
TaskStatus,
|
||||
)
|
||||
from skyvern.forge.sdk.services import org_auth_service
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
@@ -80,10 +88,13 @@ async def create_agent_task(
|
||||
analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url})
|
||||
agent = request["agent"]
|
||||
|
||||
if current_org and current_org.organization_name == "CoverageCat":
|
||||
task.proxy_location = ProxyLocation.RESIDENTIAL
|
||||
|
||||
created_task = await agent.create_task(task, current_org.organization_id)
|
||||
if x_max_steps_override:
|
||||
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
|
||||
await app.ASYNC_EXECUTOR.execute_task(
|
||||
await AsyncExecutorFactory.get_executor().execute_task(
|
||||
background_tasks=background_tasks,
|
||||
task=created_task,
|
||||
organization=current_org,
|
||||
@@ -398,10 +409,6 @@ async def execute_workflow(
|
||||
x_max_steps_override: Annotated[int | None, Header()] = None,
|
||||
) -> RunWorkflowResponse:
|
||||
analytics.capture("skyvern-oss-agent-workflow-execute")
|
||||
LOG.info(
|
||||
f"Running workflow {workflow_id}",
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
context = skyvern_context.ensure_context()
|
||||
request_id = context.request_id
|
||||
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
|
||||
@@ -413,7 +420,7 @@ async def execute_workflow(
|
||||
)
|
||||
if x_max_steps_override:
|
||||
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
|
||||
await app.ASYNC_EXECUTOR.execute_workflow(
|
||||
await AsyncExecutorFactory.get_executor().execute_workflow(
|
||||
background_tasks=background_tasks,
|
||||
organization=current_org,
|
||||
workflow_id=workflow_id,
|
||||
|
||||
@@ -18,6 +18,11 @@ class ProxyLocation(StrEnum):
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
|
||||
title: str | None = Field(
|
||||
default=None,
|
||||
description="The title of the task.",
|
||||
examples=["Get a quote for car insurance"],
|
||||
)
|
||||
url: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
@@ -41,17 +46,27 @@ class TaskRequest(BaseModel):
|
||||
examples=["Extract the quote price"],
|
||||
)
|
||||
navigation_payload: dict[str, Any] | list | str | None = Field(
|
||||
None,
|
||||
default=None,
|
||||
description="The user's details needed to achieve the task.",
|
||||
examples=[{"name": "John Doe", "email": "john@doe.com"}],
|
||||
)
|
||||
error_code_mapping: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
description="The mapping of error codes and their descriptions.",
|
||||
examples=[
|
||||
{
|
||||
"out_of_stock": "Return this error when the product is out of stock",
|
||||
"not_found": "Return this error when the product is not found",
|
||||
}
|
||||
],
|
||||
)
|
||||
proxy_location: ProxyLocation | None = Field(
|
||||
None,
|
||||
default=None,
|
||||
description="The location of the proxy to use for the task.",
|
||||
examples=["US-WA", "US-CA", "US-FL", "US-NY", "US-TX"],
|
||||
)
|
||||
extracted_information_schema: dict[str, Any] | list | str | None = Field(
|
||||
None,
|
||||
default=None,
|
||||
description="The requested schema of the extracted information.",
|
||||
)
|
||||
|
||||
@@ -122,6 +137,7 @@ class Task(TaskRequest):
|
||||
workflow_run_id: str | None = None
|
||||
order: int | None = None
|
||||
retry: int | None = None
|
||||
errors: list[dict[str, Any]] = []
|
||||
|
||||
def validate_update(
|
||||
self,
|
||||
@@ -162,6 +178,7 @@ class Task(TaskRequest):
|
||||
failure_reason=failure_reason or self.failure_reason,
|
||||
screenshot_url=screenshot_url,
|
||||
recording_url=recording_url,
|
||||
errors=self.errors,
|
||||
)
|
||||
|
||||
|
||||
@@ -175,6 +192,7 @@ class TaskResponse(BaseModel):
|
||||
screenshot_url: str | None = None
|
||||
recording_url: str | None = None
|
||||
failure_reason: str | None = None
|
||||
errors: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
class CreateTaskResponse(BaseModel):
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.exceptions import WorkflowRunContextNotInitialized
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter
|
||||
|
||||
@@ -12,15 +14,15 @@ if TYPE_CHECKING:
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class ContextManager:
|
||||
aws_client: AsyncAWSClient
|
||||
class WorkflowRunContext:
|
||||
parameters: dict[str, PARAMETER_TYPE]
|
||||
values: dict[str, Any]
|
||||
secrets: dict[str, Any]
|
||||
|
||||
def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None:
|
||||
self.aws_client = AsyncAWSClient()
|
||||
self.parameters = {}
|
||||
self.values = {}
|
||||
self.secrets = {}
|
||||
for parameter, run_parameter in workflow_parameter_tuples:
|
||||
if parameter.key in self.parameters:
|
||||
prev_value = self.parameters[parameter.key]
|
||||
@@ -32,8 +34,33 @@ class ContextManager:
|
||||
self.parameters[parameter.key] = parameter
|
||||
self.values[parameter.key] = run_parameter.value
|
||||
|
||||
def get_parameter(self, key: str) -> Parameter:
|
||||
return self.parameters[key]
|
||||
|
||||
def get_value(self, key: str) -> Any:
|
||||
"""
|
||||
Get the value of a parameter. If the parameter is an AWS secret, the value will be the random secret id, not
|
||||
the actual secret value. This will be used when building the navigation payload since we don't want to expose
|
||||
the actual secret value in the payload.
|
||||
"""
|
||||
return self.values[key]
|
||||
|
||||
def set_value(self, key: str, value: Any) -> None:
|
||||
self.values[key] = value
|
||||
|
||||
def get_original_secret_value_or_none(self, secret_id: str) -> Any:
|
||||
"""
|
||||
Get the original secret value from the secrets dict. If the secret id is not found, return None.
|
||||
"""
|
||||
return self.secrets.get(secret_id)
|
||||
|
||||
@staticmethod
|
||||
def generate_random_secret_id() -> str:
|
||||
return f"secret_{uuid.uuid4()}"
|
||||
|
||||
async def register_parameter_value(
|
||||
self,
|
||||
aws_client: AsyncAWSClient,
|
||||
parameter: PARAMETER_TYPE,
|
||||
) -> None:
|
||||
if parameter.parameter_type == ParameterType.WORKFLOW:
|
||||
@@ -42,15 +69,21 @@ class ContextManager:
|
||||
f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}"
|
||||
)
|
||||
elif parameter.parameter_type == ParameterType.AWS_SECRET:
|
||||
secret_value = await self.aws_client.get_secret(parameter.aws_key)
|
||||
# If the parameter is an AWS secret, fetch the secret value and store it in the secrets dict
|
||||
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
|
||||
# We'll replace the random secret id with the actual secret value when we need to use it.
|
||||
secret_value = await aws_client.get_secret(parameter.aws_key)
|
||||
if secret_value is not None:
|
||||
self.values[parameter.key] = secret_value
|
||||
random_secret_id = self.generate_random_secret_id()
|
||||
self.secrets[random_secret_id] = secret_value
|
||||
self.values[parameter.key] = random_secret_id
|
||||
else:
|
||||
# ContextParameter values will be set within the blocks
|
||||
return None
|
||||
|
||||
async def register_block_parameters(
|
||||
self,
|
||||
aws_client: AsyncAWSClient,
|
||||
parameters: list[PARAMETER_TYPE],
|
||||
) -> None:
|
||||
for parameter in parameters:
|
||||
@@ -67,13 +100,41 @@ class ContextManager:
|
||||
)
|
||||
|
||||
self.parameters[parameter.key] = parameter
|
||||
await self.register_parameter_value(parameter)
|
||||
await self.register_parameter_value(aws_client, parameter)
|
||||
|
||||
def get_parameter(self, key: str) -> Parameter:
|
||||
return self.parameters[key]
|
||||
|
||||
def get_value(self, key: str) -> Any:
|
||||
return self.values[key]
|
||||
class WorkflowContextManager:
|
||||
aws_client: AsyncAWSClient
|
||||
workflow_run_contexts: dict[str, WorkflowRunContext]
|
||||
|
||||
def set_value(self, key: str, value: Any) -> None:
|
||||
self.values[key] = value
|
||||
parameters: dict[str, PARAMETER_TYPE]
|
||||
values: dict[str, Any]
|
||||
secrets: dict[str, Any]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.aws_client = AsyncAWSClient()
|
||||
self.workflow_run_contexts = {}
|
||||
|
||||
def _validate_workflow_run_context(self, workflow_run_id: str) -> None:
|
||||
if workflow_run_id not in self.workflow_run_contexts:
|
||||
LOG.error(f"WorkflowRunContext not initialized for workflow run {workflow_run_id}")
|
||||
raise WorkflowRunContextNotInitialized(workflow_run_id=workflow_run_id)
|
||||
|
||||
def initialize_workflow_run_context(
|
||||
self, workflow_run_id: str, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]
|
||||
) -> WorkflowRunContext:
|
||||
workflow_run_context = WorkflowRunContext(workflow_parameter_tuples)
|
||||
self.workflow_run_contexts[workflow_run_id] = workflow_run_context
|
||||
return workflow_run_context
|
||||
|
||||
def get_workflow_run_context(self, workflow_run_id: str) -> WorkflowRunContext:
|
||||
self._validate_workflow_run_context(workflow_run_id)
|
||||
return self.workflow_run_contexts[workflow_run_id]
|
||||
|
||||
async def register_block_parameters_for_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
parameters: list[PARAMETER_TYPE],
|
||||
) -> None:
|
||||
self._validate_workflow_run_context(workflow_run_id)
|
||||
await self.workflow_run_contexts[workflow_run_id].register_block_parameters(self.aws_client, parameters)
|
||||
|
||||
@@ -13,7 +13,7 @@ from skyvern.exceptions import (
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.workflow.context_manager import ContextManager
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
@@ -33,8 +33,12 @@ class Block(BaseModel, abc.ABC):
|
||||
def get_subclasses(cls) -> tuple[type["Block"], ...]:
|
||||
return tuple(cls.__subclasses__())
|
||||
|
||||
@staticmethod
|
||||
def get_workflow_run_context(workflow_run_id: str) -> WorkflowRunContext:
|
||||
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -48,9 +52,12 @@ class TaskBlock(Block):
|
||||
block_type: Literal[BlockType.TASK] = BlockType.TASK
|
||||
|
||||
url: str | None = None
|
||||
title: str = "Untitled Task"
|
||||
navigation_goal: str | None = None
|
||||
data_extraction_goal: str | None = None
|
||||
data_schema: dict[str, Any] | None = None
|
||||
# error code to error description for the LLM
|
||||
error_code_mapping: dict[str, str] | None = None
|
||||
max_retries: int = 0
|
||||
parameters: list[PARAMETER_TYPE] = []
|
||||
|
||||
@@ -89,8 +96,8 @@ class TaskBlock(Block):
|
||||
|
||||
return order, retry + 1
|
||||
|
||||
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
|
||||
task = None
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
current_retry = 0
|
||||
# initial value for will_retry is True, so that the loop runs at least once
|
||||
will_retry = True
|
||||
@@ -104,7 +111,7 @@ class TaskBlock(Block):
|
||||
task_block=self,
|
||||
workflow=workflow,
|
||||
workflow_run=workflow_run,
|
||||
context_manager=context_manager,
|
||||
workflow_run_context=workflow_run_context,
|
||||
task_order=task_order,
|
||||
task_retry=task_retry,
|
||||
)
|
||||
@@ -131,7 +138,18 @@ class TaskBlock(Block):
|
||||
if self.url:
|
||||
await browser_state.page.goto(self.url)
|
||||
|
||||
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
|
||||
try:
|
||||
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
|
||||
except Exception as e:
|
||||
# Make sure the task is marked as failed in the database before raising the exception
|
||||
await app.DATABASE.update_task(
|
||||
task.task_id,
|
||||
status=TaskStatus.failed,
|
||||
organization_id=workflow.organization_id,
|
||||
failure_reason=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
# Check task status
|
||||
updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id)
|
||||
if not updated_task:
|
||||
@@ -188,9 +206,9 @@ class ForLoopBlock(Block):
|
||||
|
||||
return context_parameters
|
||||
|
||||
def get_loop_over_parameter_values(self, context_manager: ContextManager) -> list[Any]:
|
||||
def get_loop_over_parameter_values(self, workflow_run_context: WorkflowRunContext) -> list[Any]:
|
||||
if isinstance(self.loop_over, WorkflowParameter):
|
||||
parameter_value = context_manager.get_value(self.loop_over.key)
|
||||
parameter_value = workflow_run_context.get_value(self.loop_over.key)
|
||||
if isinstance(parameter_value, list):
|
||||
return parameter_value
|
||||
else:
|
||||
@@ -200,8 +218,9 @@ class ForLoopBlock(Block):
|
||||
# TODO (kerem): Implement this for context parameters
|
||||
raise NotImplementedError
|
||||
|
||||
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
|
||||
loop_over_values = self.get_loop_over_parameter_values(context_manager)
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
|
||||
LOG.info(
|
||||
f"Number of loop_over values: {len(loop_over_values)}",
|
||||
block_type=self.block_type,
|
||||
@@ -211,8 +230,8 @@ class ForLoopBlock(Block):
|
||||
for loop_over_value in loop_over_values:
|
||||
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
|
||||
for context_parameter in context_parameters_with_value:
|
||||
context_manager.set_value(context_parameter.key, context_parameter.value)
|
||||
await self.loop_block.execute(workflow_run_id=workflow_run_id, context_manager=context_manager)
|
||||
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
|
||||
await self.loop_block.execute(workflow_run_id=workflow_run_id)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -72,3 +72,4 @@ class WorkflowRunStatusResponse(BaseModel):
|
||||
parameters: dict[str, Any]
|
||||
screenshot_urls: list[str] | None = None
|
||||
recording_url: str | None = None
|
||||
payload: dict[str, Any] | None = None
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
@@ -19,8 +18,8 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.workflow.context_manager import ContextManager
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
Workflow,
|
||||
@@ -55,7 +54,6 @@ class WorkflowService:
|
||||
:param max_steps_override: The max steps override for the workflow run, if any.
|
||||
:return: The created workflow run.
|
||||
"""
|
||||
LOG.info(f"Setting up workflow run for workflow {workflow_id}", workflow_id=workflow_id)
|
||||
# Validate the workflow and the organization
|
||||
workflow = await self.get_workflow(workflow_id=workflow_id)
|
||||
if workflow is None:
|
||||
@@ -83,9 +81,6 @@ class WorkflowService:
|
||||
)
|
||||
)
|
||||
|
||||
# Set workflow run status to running, create workflow run parameters
|
||||
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
|
||||
|
||||
# Create all the workflow run parameters, AWSSecretParameter won't have workflow run parameters created.
|
||||
all_workflow_parameters = await self.get_workflow_parameters(workflow_id=workflow.workflow_id)
|
||||
workflow_run_parameters = []
|
||||
@@ -113,11 +108,6 @@ class WorkflowService:
|
||||
|
||||
workflow_run_parameters.append(workflow_run_parameter)
|
||||
|
||||
LOG.info(
|
||||
f"Created workflow run parameters for workflow run {workflow_run.workflow_run_id}",
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
)
|
||||
|
||||
return workflow_run
|
||||
|
||||
async def execute_workflow(
|
||||
@@ -129,59 +119,92 @@ class WorkflowService:
|
||||
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
|
||||
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id)
|
||||
|
||||
await app.BROWSER_MANAGER.get_or_create_for_workflow_run(workflow_run=workflow_run)
|
||||
# Set workflow run status to running, create workflow run parameters
|
||||
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
|
||||
|
||||
# Get all <workflow parameter, workflow run parameter> tuples
|
||||
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
|
||||
# todo(kerem): do this in a better way (a shared context manager? (not really shared because we use batch job))
|
||||
context_manager = ContextManager(wp_wps_tuples)
|
||||
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(workflow_run_id, wp_wps_tuples)
|
||||
# Execute workflow blocks
|
||||
blocks = workflow.workflow_definition.blocks
|
||||
for block_idx, block in enumerate(blocks):
|
||||
parameters = block.get_all_parameters()
|
||||
await context_manager.register_block_parameters(parameters)
|
||||
LOG.info(
|
||||
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run.workflow_run_id}",
|
||||
block_type=block.block_type,
|
||||
try:
|
||||
for block_idx, block in enumerate(blocks):
|
||||
parameters = block.get_all_parameters()
|
||||
await app.WORKFLOW_CONTEXT_MANAGER.register_block_parameters_for_workflow_run(
|
||||
workflow_run_id, parameters
|
||||
)
|
||||
LOG.info(
|
||||
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run_id}",
|
||||
block_type=block.block_type,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
await block.execute(workflow_run_id=workflow_run_id)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
f"Error while executing workflow run {workflow_run.workflow_run_id}",
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
block_idx=block_idx,
|
||||
exc_info=True,
|
||||
)
|
||||
await block.execute(workflow_run_id=workflow_run.workflow_run_id, context_manager=context_manager)
|
||||
|
||||
# Get last task for workflow run
|
||||
task = await self.get_last_task_for_workflow_run(workflow_run_id=workflow_run.workflow_run_id)
|
||||
if not task:
|
||||
tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id)
|
||||
if not tasks:
|
||||
LOG.warning(
|
||||
f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook",
|
||||
f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook, marking as failed",
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
)
|
||||
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
|
||||
return workflow_run
|
||||
|
||||
# Update workflow status
|
||||
if task.status == "completed":
|
||||
await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id)
|
||||
elif task.status == "failed":
|
||||
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
|
||||
elif task.status == "terminated":
|
||||
await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id)
|
||||
else:
|
||||
LOG.warning(
|
||||
f"Task {task.task_id} has an incomplete status {task.status}, not updating workflow run status",
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
task_id=task.task_id,
|
||||
status=task.status,
|
||||
workflow_run_status=workflow_run.status,
|
||||
)
|
||||
|
||||
workflow_run = await self.handle_workflow_status(workflow_run=workflow_run, tasks=tasks)
|
||||
await self.send_workflow_response(
|
||||
workflow=workflow,
|
||||
workflow_run=workflow_run,
|
||||
tasks=tasks,
|
||||
api_key=api_key,
|
||||
last_task=task,
|
||||
)
|
||||
return workflow_run
|
||||
|
||||
async def handle_workflow_status(self, workflow_run: WorkflowRun, tasks: list[Task]) -> WorkflowRun:
|
||||
task_counts_by_status = Counter(task.status for task in tasks)
|
||||
|
||||
# Create a mapping of status to (action, log_func, log_message)
|
||||
status_action_mapping = {
|
||||
TaskStatus.running: (None, LOG.error, "has running tasks, this should not happen"),
|
||||
TaskStatus.terminated: (
|
||||
self.mark_workflow_run_as_terminated,
|
||||
LOG.warning,
|
||||
"has terminated tasks, marking as terminated",
|
||||
),
|
||||
TaskStatus.failed: (self.mark_workflow_run_as_failed, LOG.warning, "has failed tasks, marking as failed"),
|
||||
TaskStatus.completed: (
|
||||
self.mark_workflow_run_as_completed,
|
||||
LOG.info,
|
||||
"tasks are completed, marking as completed",
|
||||
),
|
||||
}
|
||||
|
||||
for status, (action, log_func, log_message) in status_action_mapping.items():
|
||||
if task_counts_by_status.get(status, 0) > 0:
|
||||
if action is not None:
|
||||
await action(workflow_run_id=workflow_run.workflow_run_id)
|
||||
if log_func and log_message:
|
||||
log_func(
|
||||
f"Workflow run {workflow_run.workflow_run_id} {log_message}",
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
task_counts_by_status=task_counts_by_status,
|
||||
)
|
||||
return workflow_run
|
||||
|
||||
# Handle unexpected state
|
||||
LOG.error(
|
||||
f"Workflow run {workflow_run.workflow_run_id} has tasks in an unexpected state, marking as failed",
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
task_counts_by_status=task_counts_by_status,
|
||||
)
|
||||
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
|
||||
return workflow_run
|
||||
|
||||
async def create_workflow(
|
||||
self,
|
||||
organization_id: str,
|
||||
@@ -354,6 +377,15 @@ class WorkflowService:
|
||||
|
||||
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
|
||||
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
|
||||
payload = {
|
||||
task.task_id: {
|
||||
"title": task.title,
|
||||
"extracted_information": task.extracted_information,
|
||||
"navigation_payload": task.navigation_payload,
|
||||
"errors": await app.agent.get_task_errors(task=task),
|
||||
}
|
||||
for task in workflow_run_tasks
|
||||
}
|
||||
return WorkflowRunStatusResponse(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
@@ -365,50 +397,28 @@ class WorkflowService:
|
||||
parameters=parameters_with_value,
|
||||
screenshot_urls=screenshot_urls,
|
||||
recording_url=recording_url,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
async def send_workflow_response(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
workflow_run: WorkflowRun,
|
||||
last_task: Task,
|
||||
tasks: list[Task],
|
||||
api_key: str | None = None,
|
||||
close_browser_on_completion: bool = True,
|
||||
) -> None:
|
||||
analytics.capture("skyvern-oss-agent-workflow-status", {"status": workflow_run.status})
|
||||
all_workflow_task_ids = [task.task_id for task in tasks]
|
||||
browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
|
||||
workflow_run.workflow_run_id, close_browser_on_completion
|
||||
workflow_run.workflow_run_id, all_workflow_task_ids, close_browser_on_completion
|
||||
)
|
||||
if browser_state:
|
||||
await self.persist_video_data(browser_state, workflow, workflow_run)
|
||||
await self.persist_har_data(browser_state, last_task, workflow, workflow_run)
|
||||
await self.persist_debug_artifacts(browser_state, tasks[-1], workflow, workflow_run)
|
||||
|
||||
# Wait for all tasks to complete before generating the links for the artifacts
|
||||
all_workflow_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(
|
||||
workflow_run_id=workflow_run.workflow_run_id
|
||||
)
|
||||
all_workflow_task_ids = [task.task_id for task in all_workflow_tasks]
|
||||
await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)
|
||||
|
||||
try:
|
||||
# Wait for all tasks to complete. Currently we're using asyncio.create_task() only for uploading artifacts to S3.
|
||||
# We're excluding the current task from the list of tasks to wait for to prevent a deadlock.
|
||||
st = time.time()
|
||||
async with asyncio.timeout(30):
|
||||
await asyncio.gather(
|
||||
*[aio_task for aio_task in (asyncio.all_tasks() - {asyncio.current_task()}) if not aio_task.done()]
|
||||
)
|
||||
LOG.info(
|
||||
f"Waiting for all S3 uploads to complete took {time.time() - st} seconds",
|
||||
duration=time.time() - st,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
LOG.warning(
|
||||
"Timed out waiting for all S3 uploads to complete, not all artifacts may be uploaded. Waited 30 seconds.",
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
)
|
||||
|
||||
if not workflow_run.webhook_callback_url:
|
||||
LOG.warning(
|
||||
"Workflow has no webhook callback url. Not sending workflow response",
|
||||
@@ -493,19 +503,35 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
async def persist_har_data(
|
||||
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
|
||||
self, browser_state: BrowserState, last_step: Step, workflow: Workflow, workflow_run: WorkflowRun
|
||||
) -> None:
|
||||
har_data = await app.BROWSER_MANAGER.get_har_data(
|
||||
workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, browser_state=browser_state
|
||||
)
|
||||
if har_data:
|
||||
last_step = await app.DATABASE.get_latest_step(
|
||||
task_id=last_task.task_id, organization_id=last_task.organization_id
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=last_step,
|
||||
artifact_type=ArtifactType.HAR,
|
||||
data=har_data,
|
||||
)
|
||||
|
||||
if last_step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=last_step,
|
||||
artifact_type=ArtifactType.HAR,
|
||||
data=har_data,
|
||||
)
|
||||
async def persist_tracing_data(
|
||||
self, browser_state: BrowserState, last_step: Step, workflow_run: WorkflowRun
|
||||
) -> None:
|
||||
if browser_state.browser_context is None or browser_state.browser_artifacts.traces_dir is None:
|
||||
return
|
||||
|
||||
trace_path = f"{browser_state.browser_artifacts.traces_dir}/{workflow_run.workflow_run_id}.zip"
|
||||
await app.ARTIFACT_MANAGER.create_artifact(step=last_step, artifact_type=ArtifactType.TRACE, path=trace_path)
|
||||
|
||||
async def persist_debug_artifacts(
|
||||
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
|
||||
) -> None:
|
||||
last_step = await app.DATABASE.get_latest_step(
|
||||
task_id=last_task.task_id, organization_id=last_task.organization_id
|
||||
)
|
||||
if not last_step:
|
||||
return
|
||||
|
||||
await self.persist_har_data(browser_state, last_step, workflow, workflow_run)
|
||||
await self.persist_tracing_data(browser_state, last_step, workflow_run)
|
||||
|
||||
Reference in New Issue
Block a user