validate taskV2 step (#1811)

This commit is contained in:
Shuchang Zheng
2025-02-22 00:44:12 -08:00
committed by GitHub
parent 8821b7e150
commit 9a07c0bc6f
3 changed files with 51 additions and 3 deletions

View File

@@ -355,11 +355,21 @@ class UnknownElementTreeFormat(SkyvernException):
super().__init__(f"Unknown element tree format {fmt}")
class StepTerminationError(SkyvernException):
def __init__(self, step_id: str, reason: str) -> None:
class TerminationError(SkyvernException):
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
super().__init__(f"Termination error. Reason: {reason}")
class StepTerminationError(TerminationError):
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
super().__init__(f"Step {step_id} cannot be executed and task is failed. Reason: {reason}")
class TaskTerminationError(TerminationError):
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
super().__init__(f"Task {task_id} failed. Reason: {reason}")
class BlockTerminationError(SkyvernException):
def __init__(self, workflow_run_block_id: str, workflow_run_id: str, reason: str) -> None:
super().__init__(

View File

@@ -425,6 +425,11 @@ class AgentFunction:
) -> None:
return
async def validate_task_execution(
self, organization_id: str | None = None, task_id: str | None = None, task_version: str | None = None
) -> None:
return
async def prepare_step_execution(
self,
organization: Organization | None,

View File

@@ -8,7 +8,7 @@ import httpx
import structlog
from sqlalchemy.exc import OperationalError
from skyvern.exceptions import FailedToSendWebhook, ObserverCruiseNotFound, UrlGenerationFailure
from skyvern.exceptions import FailedToSendWebhook, ObserverCruiseNotFound, TaskTerminationError, UrlGenerationFailure
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.artifact.models import ArtifactType
@@ -252,6 +252,15 @@ async def run_observer_task(
max_iterations_override=max_iterations_override,
browser_session_id=browser_session_id,
)
except TaskTerminationError as e:
observer_task = await mark_observer_task_as_terminated(
observer_cruise_id=observer_cruise_id,
workflow_run_id=observer_task.workflow_run_id,
organization_id=organization_id,
failure_reason=e.message,
)
LOG.info("Task v2 is terminated", observer_cruise_id=observer_cruise_id, failure_reason=e.message)
return observer_task
except OperationalError:
LOG.error("Database error when running observer cruise", exc_info=True)
observer_task = await mark_observer_task_as_failed(
@@ -373,6 +382,13 @@ async def run_observer_task_helper(
max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS
for i in range(max_iterations):
# validate the task execution
await app.AGENT_FUNCTION.validate_task_execution(
organization_id=organization_id,
task_id=observer_cruise_id,
task_version="v2",
)
# check the status of the workflow run
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id, organization_id=organization_id)
if not workflow_run:
@@ -1251,6 +1267,23 @@ async def mark_observer_task_as_canceled(
return observer_task
async def mark_observer_task_as_terminated(
observer_cruise_id: str,
workflow_run_id: str | None = None,
organization_id: str | None = None,
failure_reason: str | None = None,
) -> ObserverTask:
observer_task = await app.DATABASE.update_observer_cruise(
observer_cruise_id,
organization_id=organization_id,
status=ObserverTaskStatus.terminated,
)
if workflow_run_id:
await app.WORKFLOW_SERVICE.mark_workflow_run_as_terminated(workflow_run_id, failure_reason)
await send_observer_task_webhook(observer_task)
return observer_task
def _get_extracted_data_from_block_result(
block_result: BlockResult,
task_type: str,