validate taskV2 step (#1811)
This commit is contained in:
@@ -355,11 +355,21 @@ class UnknownElementTreeFormat(SkyvernException):
|
||||
super().__init__(f"Unknown element tree format {fmt}")
|
||||
|
||||
|
||||
class StepTerminationError(SkyvernException):
|
||||
def __init__(self, step_id: str, reason: str) -> None:
|
||||
class TerminationError(SkyvernException):
|
||||
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
|
||||
super().__init__(f"Termination error. Reason: {reason}")
|
||||
|
||||
|
||||
class StepTerminationError(TerminationError):
|
||||
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
|
||||
super().__init__(f"Step {step_id} cannot be executed and task is failed. Reason: {reason}")
|
||||
|
||||
|
||||
class TaskTerminationError(TerminationError):
|
||||
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
|
||||
super().__init__(f"Task {task_id} failed. Reason: {reason}")
|
||||
|
||||
|
||||
class BlockTerminationError(SkyvernException):
|
||||
def __init__(self, workflow_run_block_id: str, workflow_run_id: str, reason: str) -> None:
|
||||
super().__init__(
|
||||
|
||||
@@ -425,6 +425,11 @@ class AgentFunction:
|
||||
) -> None:
|
||||
return
|
||||
|
||||
async def validate_task_execution(
|
||||
self, organization_id: str | None = None, task_id: str | None = None, task_version: str | None = None
|
||||
) -> None:
|
||||
return
|
||||
|
||||
async def prepare_step_execution(
|
||||
self,
|
||||
organization: Organization | None,
|
||||
|
||||
@@ -8,7 +8,7 @@ import httpx
|
||||
import structlog
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from skyvern.exceptions import FailedToSendWebhook, ObserverCruiseNotFound, UrlGenerationFailure
|
||||
from skyvern.exceptions import FailedToSendWebhook, ObserverCruiseNotFound, TaskTerminationError, UrlGenerationFailure
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
@@ -252,6 +252,15 @@ async def run_observer_task(
|
||||
max_iterations_override=max_iterations_override,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
except TaskTerminationError as e:
|
||||
observer_task = await mark_observer_task_as_terminated(
|
||||
observer_cruise_id=observer_cruise_id,
|
||||
workflow_run_id=observer_task.workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
failure_reason=e.message,
|
||||
)
|
||||
LOG.info("Task v2 is terminated", observer_cruise_id=observer_cruise_id, failure_reason=e.message)
|
||||
return observer_task
|
||||
except OperationalError:
|
||||
LOG.error("Database error when running observer cruise", exc_info=True)
|
||||
observer_task = await mark_observer_task_as_failed(
|
||||
@@ -373,6 +382,13 @@ async def run_observer_task_helper(
|
||||
|
||||
max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS
|
||||
for i in range(max_iterations):
|
||||
# validate the task execution
|
||||
await app.AGENT_FUNCTION.validate_task_execution(
|
||||
organization_id=organization_id,
|
||||
task_id=observer_cruise_id,
|
||||
task_version="v2",
|
||||
)
|
||||
|
||||
# check the status of the workflow run
|
||||
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id, organization_id=organization_id)
|
||||
if not workflow_run:
|
||||
@@ -1251,6 +1267,23 @@ async def mark_observer_task_as_canceled(
|
||||
return observer_task
|
||||
|
||||
|
||||
async def mark_observer_task_as_terminated(
|
||||
observer_cruise_id: str,
|
||||
workflow_run_id: str | None = None,
|
||||
organization_id: str | None = None,
|
||||
failure_reason: str | None = None,
|
||||
) -> ObserverTask:
|
||||
observer_task = await app.DATABASE.update_observer_cruise(
|
||||
observer_cruise_id,
|
||||
organization_id=organization_id,
|
||||
status=ObserverTaskStatus.terminated,
|
||||
)
|
||||
if workflow_run_id:
|
||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_terminated(workflow_run_id, failure_reason)
|
||||
await send_observer_task_webhook(observer_task)
|
||||
return observer_task
|
||||
|
||||
|
||||
def _get_extracted_data_from_block_result(
|
||||
block_result: BlockResult,
|
||||
task_type: str,
|
||||
|
||||
Reference in New Issue
Block a user