task v2 termination (#4589)
This commit is contained in:
@@ -146,6 +146,80 @@ async def _summarize_max_steps_failure_reason(
|
||||
return ""
|
||||
|
||||
|
||||
async def _handle_task_v2_termination(
|
||||
task_v2_id: str,
|
||||
organization_id: str,
|
||||
workflow_run_id: str,
|
||||
workflow_id: str,
|
||||
workflow_permanent_id: str,
|
||||
termination_reason: str | None,
|
||||
iteration: int,
|
||||
source: str | None = None,
|
||||
) -> TaskV2:
|
||||
"""
|
||||
Handle task v2 termination by creating a termination thought and marking the task as terminated.
|
||||
|
||||
Args:
|
||||
task_v2_id: The task v2 ID
|
||||
organization_id: The organization ID
|
||||
workflow_run_id: The workflow run ID
|
||||
workflow_id: The workflow ID
|
||||
workflow_permanent_id: The workflow permanent ID
|
||||
termination_reason: The reason for termination (from LLM response)
|
||||
iteration: The current iteration number
|
||||
source: Optional source identifier (e.g., "completion_check")
|
||||
|
||||
Returns:
|
||||
The updated TaskV2 object with terminated status
|
||||
"""
|
||||
log_message = "Task v2 should terminate"
|
||||
if source:
|
||||
log_message = f"Task v2 should terminate according to {source}"
|
||||
log_message += " - goal is impossible to achieve"
|
||||
|
||||
LOG.info(
|
||||
log_message,
|
||||
iteration=iteration,
|
||||
workflow_run_id=workflow_run_id,
|
||||
termination_reason=termination_reason,
|
||||
)
|
||||
|
||||
# Create a dedicated termination thought for UI visibility
|
||||
termination_thought = await app.DATABASE.create_thought(
|
||||
task_v2_id=task_v2_id,
|
||||
organization_id=organization_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
thought_type=ThoughtType.termination,
|
||||
thought_scenario=ThoughtScenario.termination,
|
||||
thought=termination_reason or "Task goal is impossible to achieve",
|
||||
)
|
||||
|
||||
output: dict[str, Any] = {
|
||||
"should_terminate": True,
|
||||
"termination_reason": termination_reason,
|
||||
"iteration": iteration,
|
||||
}
|
||||
if source:
|
||||
output["source"] = source
|
||||
|
||||
await app.DATABASE.update_thought(
|
||||
thought_id=termination_thought.observer_thought_id,
|
||||
organization_id=organization_id,
|
||||
output=output,
|
||||
)
|
||||
|
||||
task_v2 = await mark_task_v2_as_terminated(
|
||||
task_v2_id=task_v2_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
failure_reason=termination_reason or "Task goal is impossible to achieve",
|
||||
)
|
||||
|
||||
return task_v2
|
||||
|
||||
|
||||
async def initialize_task_v2(
|
||||
organization: Organization,
|
||||
user_prompt: str,
|
||||
@@ -526,6 +600,16 @@ async def run_task_v2_helper(
|
||||
current_run_id,
|
||||
properties={"organization_id": organization_id, "task_url": task_v2.url},
|
||||
)
|
||||
enable_task_v2_termination = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
|
||||
"ENABLE_TASK_V2_TERMINATION",
|
||||
current_run_id,
|
||||
properties={"organization_id": organization_id, "task_url": task_v2.url},
|
||||
)
|
||||
LOG.info(
|
||||
"Task v2 termination feature flag",
|
||||
enable_task_v2_termination=enable_task_v2_termination,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
skyvern_context.set(
|
||||
SkyvernContext(
|
||||
organization_id=organization_id,
|
||||
@@ -702,6 +786,7 @@ async def run_task_v2_helper(
|
||||
user_goal=user_prompt,
|
||||
task_history=task_history,
|
||||
local_datetime=datetime.now(context.tz_info).isoformat(),
|
||||
enable_termination=bool(enable_task_v2_termination),
|
||||
)
|
||||
thought = await app.DATABASE.create_thought(
|
||||
task_v2_id=task_v2_id,
|
||||
@@ -730,6 +815,8 @@ async def run_task_v2_helper(
|
||||
)
|
||||
# see if the user goal has achieved or not
|
||||
user_goal_achieved = task_v2_response.get("user_goal_achieved", False)
|
||||
should_terminate = task_v2_response.get("should_terminate", False)
|
||||
termination_reason = task_v2_response.get("termination_reason")
|
||||
observation = task_v2_response.get("page_info", "")
|
||||
thoughts: str = task_v2_response.get("thoughts", "")
|
||||
plan = task_v2_response.get("plan", "")
|
||||
@@ -741,7 +828,12 @@ async def run_task_v2_helper(
|
||||
thought=thoughts,
|
||||
observation=observation,
|
||||
answer=plan,
|
||||
output={"task_type": task_type, "user_goal_achieved": user_goal_achieved},
|
||||
output={
|
||||
"task_type": task_type,
|
||||
"user_goal_achieved": user_goal_achieved,
|
||||
"should_terminate": should_terminate,
|
||||
"termination_reason": termination_reason,
|
||||
},
|
||||
)
|
||||
|
||||
if user_goal_achieved is True:
|
||||
@@ -763,6 +855,19 @@ async def run_task_v2_helper(
|
||||
)
|
||||
break
|
||||
|
||||
# Only handle termination if the feature flag is enabled
|
||||
if enable_task_v2_termination and should_terminate is True:
|
||||
task_v2 = await _handle_task_v2_termination(
|
||||
task_v2_id=task_v2_id,
|
||||
organization_id=organization_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
termination_reason=termination_reason,
|
||||
iteration=i,
|
||||
)
|
||||
return workflow, workflow_run, task_v2
|
||||
|
||||
if not plan:
|
||||
LOG.warning("No plan found in task v2 response", task_v2_response=task_v2_response)
|
||||
continue
|
||||
@@ -925,6 +1030,7 @@ async def run_task_v2_helper(
|
||||
user_goal=user_prompt,
|
||||
task_history=task_history,
|
||||
local_datetime=datetime.now(context.tz_info).isoformat(),
|
||||
enable_termination=bool(enable_task_v2_termination),
|
||||
)
|
||||
thought = await app.DATABASE.create_thought(
|
||||
task_v2_id=task_v2_id,
|
||||
@@ -949,12 +1055,18 @@ async def run_task_v2_helper(
|
||||
task_history=task_history,
|
||||
)
|
||||
user_goal_achieved = completion_resp.get("user_goal_achieved", False)
|
||||
should_terminate = completion_resp.get("should_terminate", False)
|
||||
termination_reason = completion_resp.get("termination_reason")
|
||||
thought_content = completion_resp.get("thoughts", "")
|
||||
await app.DATABASE.update_thought(
|
||||
thought_id=thought.observer_thought_id,
|
||||
organization_id=organization_id,
|
||||
thought=thought_content,
|
||||
output={"user_goal_achieved": user_goal_achieved},
|
||||
output={
|
||||
"user_goal_achieved": user_goal_achieved,
|
||||
"should_terminate": should_terminate,
|
||||
"termination_reason": termination_reason,
|
||||
},
|
||||
)
|
||||
if user_goal_achieved:
|
||||
LOG.info(
|
||||
@@ -977,6 +1089,20 @@ async def run_task_v2_helper(
|
||||
)
|
||||
break
|
||||
|
||||
# Only handle termination if the feature flag is enabled
|
||||
if enable_task_v2_termination and should_terminate:
|
||||
task_v2 = await _handle_task_v2_termination(
|
||||
task_v2_id=task_v2_id,
|
||||
organization_id=organization_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
termination_reason=termination_reason,
|
||||
iteration=i,
|
||||
source="completion_check",
|
||||
)
|
||||
return workflow, workflow_run, task_v2
|
||||
|
||||
# total step number validation
|
||||
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
|
||||
total_step_count = await app.DATABASE.get_total_unique_step_order_count_by_task_ids(
|
||||
|
||||
Reference in New Issue
Block a user