validate taskV2 step (#1811)

This commit is contained in:
Shuchang Zheng
2025-02-22 00:44:12 -08:00
committed by GitHub
parent 8821b7e150
commit 9a07c0bc6f
3 changed files with 51 additions and 3 deletions

View File

@@ -355,11 +355,21 @@ class UnknownElementTreeFormat(SkyvernException):
super().__init__(f"Unknown element tree format {fmt}") super().__init__(f"Unknown element tree format {fmt}")
class StepTerminationError(SkyvernException): class TerminationError(SkyvernException):
def __init__(self, step_id: str, reason: str) -> None: def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
super().__init__(f"Termination error. Reason: {reason}")
class StepTerminationError(TerminationError):
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
super().__init__(f"Step {step_id} cannot be executed and task is failed. Reason: {reason}") super().__init__(f"Step {step_id} cannot be executed and task is failed. Reason: {reason}")
class TaskTerminationError(TerminationError):
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
super().__init__(f"Task {task_id} failed. Reason: {reason}")
class BlockTerminationError(SkyvernException): class BlockTerminationError(SkyvernException):
def __init__(self, workflow_run_block_id: str, workflow_run_id: str, reason: str) -> None: def __init__(self, workflow_run_block_id: str, workflow_run_id: str, reason: str) -> None:
super().__init__( super().__init__(

View File

@@ -425,6 +425,11 @@ class AgentFunction:
) -> None: ) -> None:
return return
async def validate_task_execution(
self, organization_id: str | None = None, task_id: str | None = None, task_version: str | None = None
) -> None:
return
async def prepare_step_execution( async def prepare_step_execution(
self, self,
organization: Organization | None, organization: Organization | None,

View File

@@ -8,7 +8,7 @@ import httpx
import structlog import structlog
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from skyvern.exceptions import FailedToSendWebhook, ObserverCruiseNotFound, UrlGenerationFailure from skyvern.exceptions import FailedToSendWebhook, ObserverCruiseNotFound, TaskTerminationError, UrlGenerationFailure
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.artifact.models import ArtifactType
@@ -252,6 +252,15 @@ async def run_observer_task(
max_iterations_override=max_iterations_override, max_iterations_override=max_iterations_override,
browser_session_id=browser_session_id, browser_session_id=browser_session_id,
) )
except TaskTerminationError as e:
observer_task = await mark_observer_task_as_terminated(
observer_cruise_id=observer_cruise_id,
workflow_run_id=observer_task.workflow_run_id,
organization_id=organization_id,
failure_reason=e.message,
)
LOG.info("Task v2 is terminated", observer_cruise_id=observer_cruise_id, failure_reason=e.message)
return observer_task
except OperationalError: except OperationalError:
LOG.error("Database error when running observer cruise", exc_info=True) LOG.error("Database error when running observer cruise", exc_info=True)
observer_task = await mark_observer_task_as_failed( observer_task = await mark_observer_task_as_failed(
@@ -373,6 +382,13 @@ async def run_observer_task_helper(
max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS
for i in range(max_iterations): for i in range(max_iterations):
# validate the task execution
await app.AGENT_FUNCTION.validate_task_execution(
organization_id=organization_id,
task_id=observer_cruise_id,
task_version="v2",
)
# check the status of the workflow run # check the status of the workflow run
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id, organization_id=organization_id) workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id, organization_id=organization_id)
if not workflow_run: if not workflow_run:
@@ -1251,6 +1267,23 @@ async def mark_observer_task_as_canceled(
return observer_task return observer_task
async def mark_observer_task_as_terminated(
observer_cruise_id: str,
workflow_run_id: str | None = None,
organization_id: str | None = None,
failure_reason: str | None = None,
) -> ObserverTask:
observer_task = await app.DATABASE.update_observer_cruise(
observer_cruise_id,
organization_id=organization_id,
status=ObserverTaskStatus.terminated,
)
if workflow_run_id:
await app.WORKFLOW_SERVICE.mark_workflow_run_as_terminated(workflow_run_id, failure_reason)
await send_observer_task_webhook(observer_task)
return observer_task
def _get_extracted_data_from_block_result( def _get_extracted_data_from_block_result(
block_result: BlockResult, block_result: BlockResult,
task_type: str, task_type: str,