validate taskV2 step (#1811)
This commit is contained in:
@@ -355,11 +355,21 @@ class UnknownElementTreeFormat(SkyvernException):
|
|||||||
super().__init__(f"Unknown element tree format {fmt}")
|
super().__init__(f"Unknown element tree format {fmt}")
|
||||||
|
|
||||||
|
|
||||||
class StepTerminationError(SkyvernException):
|
class TerminationError(SkyvernException):
|
||||||
def __init__(self, step_id: str, reason: str) -> None:
|
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
|
||||||
|
super().__init__(f"Termination error. Reason: {reason}")
|
||||||
|
|
||||||
|
|
||||||
|
class StepTerminationError(TerminationError):
|
||||||
|
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
|
||||||
super().__init__(f"Step {step_id} cannot be executed and task is failed. Reason: {reason}")
|
super().__init__(f"Step {step_id} cannot be executed and task is failed. Reason: {reason}")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskTerminationError(TerminationError):
|
||||||
|
def __init__(self, reason: str, step_id: str | None = None, task_id: str | None = None) -> None:
|
||||||
|
super().__init__(f"Task {task_id} failed. Reason: {reason}")
|
||||||
|
|
||||||
|
|
||||||
class BlockTerminationError(SkyvernException):
|
class BlockTerminationError(SkyvernException):
|
||||||
def __init__(self, workflow_run_block_id: str, workflow_run_id: str, reason: str) -> None:
|
def __init__(self, workflow_run_block_id: str, workflow_run_id: str, reason: str) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|||||||
@@ -425,6 +425,11 @@ class AgentFunction:
|
|||||||
) -> None:
|
) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
async def validate_task_execution(
|
||||||
|
self, organization_id: str | None = None, task_id: str | None = None, task_version: str | None = None
|
||||||
|
) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
async def prepare_step_execution(
|
async def prepare_step_execution(
|
||||||
self,
|
self,
|
||||||
organization: Organization | None,
|
organization: Organization | None,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import httpx
|
|||||||
import structlog
|
import structlog
|
||||||
from sqlalchemy.exc import OperationalError
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
from skyvern.exceptions import FailedToSendWebhook, ObserverCruiseNotFound, UrlGenerationFailure
|
from skyvern.exceptions import FailedToSendWebhook, ObserverCruiseNotFound, TaskTerminationError, UrlGenerationFailure
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.prompts import prompt_engine
|
from skyvern.forge.prompts import prompt_engine
|
||||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||||
@@ -252,6 +252,15 @@ async def run_observer_task(
|
|||||||
max_iterations_override=max_iterations_override,
|
max_iterations_override=max_iterations_override,
|
||||||
browser_session_id=browser_session_id,
|
browser_session_id=browser_session_id,
|
||||||
)
|
)
|
||||||
|
except TaskTerminationError as e:
|
||||||
|
observer_task = await mark_observer_task_as_terminated(
|
||||||
|
observer_cruise_id=observer_cruise_id,
|
||||||
|
workflow_run_id=observer_task.workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
failure_reason=e.message,
|
||||||
|
)
|
||||||
|
LOG.info("Task v2 is terminated", observer_cruise_id=observer_cruise_id, failure_reason=e.message)
|
||||||
|
return observer_task
|
||||||
except OperationalError:
|
except OperationalError:
|
||||||
LOG.error("Database error when running observer cruise", exc_info=True)
|
LOG.error("Database error when running observer cruise", exc_info=True)
|
||||||
observer_task = await mark_observer_task_as_failed(
|
observer_task = await mark_observer_task_as_failed(
|
||||||
@@ -373,6 +382,13 @@ async def run_observer_task_helper(
|
|||||||
|
|
||||||
max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS
|
max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS
|
||||||
for i in range(max_iterations):
|
for i in range(max_iterations):
|
||||||
|
# validate the task execution
|
||||||
|
await app.AGENT_FUNCTION.validate_task_execution(
|
||||||
|
organization_id=organization_id,
|
||||||
|
task_id=observer_cruise_id,
|
||||||
|
task_version="v2",
|
||||||
|
)
|
||||||
|
|
||||||
# check the status of the workflow run
|
# check the status of the workflow run
|
||||||
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id, organization_id=organization_id)
|
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id, organization_id=organization_id)
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
@@ -1251,6 +1267,23 @@ async def mark_observer_task_as_canceled(
|
|||||||
return observer_task
|
return observer_task
|
||||||
|
|
||||||
|
|
||||||
|
async def mark_observer_task_as_terminated(
|
||||||
|
observer_cruise_id: str,
|
||||||
|
workflow_run_id: str | None = None,
|
||||||
|
organization_id: str | None = None,
|
||||||
|
failure_reason: str | None = None,
|
||||||
|
) -> ObserverTask:
|
||||||
|
observer_task = await app.DATABASE.update_observer_cruise(
|
||||||
|
observer_cruise_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
status=ObserverTaskStatus.terminated,
|
||||||
|
)
|
||||||
|
if workflow_run_id:
|
||||||
|
await app.WORKFLOW_SERVICE.mark_workflow_run_as_terminated(workflow_run_id, failure_reason)
|
||||||
|
await send_observer_task_webhook(observer_task)
|
||||||
|
return observer_task
|
||||||
|
|
||||||
|
|
||||||
def _get_extracted_data_from_block_result(
|
def _get_extracted_data_from_block_result(
|
||||||
block_result: BlockResult,
|
block_result: BlockResult,
|
||||||
task_type: str,
|
task_type: str,
|
||||||
|
|||||||
Reference in New Issue
Block a user