use task type instead of prompt template (#1261)

This commit is contained in:
LawyZheng
2024-11-26 11:29:33 +08:00
committed by GitHub
parent 74a9fc70d6
commit 6b417d0e83
11 changed files with 124 additions and 69 deletions

View File

@@ -0,0 +1,33 @@
"""use task type instead of prompt template
Revision ID: 56085e451bec
Revises: 2d79d5fc1baa
Create Date: 2024-11-26 03:22:11.224805+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the migration it
# revises; both mirror the "Revision ID" / "Revises" lines in the module docstring.
revision: str = "56085e451bec"
down_revision: Union[str, None] = "2d79d5fc1baa"
# No branch labels and no cross-revision dependencies are declared here.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Replace ``tasks.prompt_template`` with ``tasks.task_type``.

    Steps, in order:
      1. add the nullable ``task_type`` column;
      2. backfill ``task_type`` for rows whose old prompt template maps to a
         non-default task type;
      3. drop ``prompt_template``.

    Rows whose template was ``extract-action`` (or NULL) are intentionally left
    with a NULL ``task_type``: the agent falls back to the ``general`` task type
    when the column is empty, so touching only the non-default rows keeps the
    UPDATE small.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("tasks", sa.Column("task_type", sa.String(), nullable=True))
    # The backfill must run before the drop below, otherwise the information
    # held in prompt_template is lost for existing validation/action tasks.
    op.execute(
        sa.text(
            """
            UPDATE tasks
            SET task_type = CASE
                WHEN prompt_template = 'decisive-criterion-validate' THEN 'validation'
                ELSE 'action'
            END
            WHERE prompt_template IN (
                'decisive-criterion-validate',
                'single-click-action',
                'single-input-action',
                'single-upload-action',
                'single-select-action'
            )
            """
        )
    )
    op.drop_column("tasks", "prompt_template")
    # ### end Alembic commands ###
def downgrade() -> None:
    """Restore ``tasks.prompt_template`` and remove ``tasks.task_type``.

    Steps, in order:
      1. re-add the nullable ``prompt_template`` column;
      2. restore the template for ``validation`` tasks, the only task type
         with an unambiguous template;
      3. drop ``task_type``.

    ``general`` rows are left NULL because the pre-migration code treated an
    empty ``prompt_template`` as ``extract-action``. ``action`` rows are also
    left NULL: which single-action template they used is not recorded in
    ``task_type`` and cannot be reconstructed here.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("tasks", sa.Column("prompt_template", sa.VARCHAR(), autoincrement=False, nullable=True))
    # Restore what can be restored before task_type disappears.
    op.execute(
        sa.text(
            """
            UPDATE tasks
            SET prompt_template = 'decisive-criterion-validate'
            WHERE task_type = 'validation'
            """
        )
    )
    op.drop_column("tasks", "task_type")
    # ### end Alembic commands ###

View File

@@ -530,3 +530,15 @@ class InteractWithDisabledElement(SkyvernException):
super().__init__(
f"The element(id={element_id}) now is disabled, try to interact with it later when it's enabled."
)
class FailedToParseActionInstruction(SkyvernException):
    """Raised when an action instruction cannot be turned into an action.

    Args:
        reason: Explanation of why parsing failed; may be ``None``.
        error_type: Short error identifier; may be ``None``.
    """

    def __init__(self, reason: str | None, error_type: str | None):
        # Build the message first so the super() call stays a one-liner.
        message = f"Failed to parse the action instruction as '{reason}({error_type})'"
        super().__init__(message)
class UnsupportedTaskType(SkyvernException):
    """Raised when a task has a task type for which no prompt can be chosen.

    Args:
        task_type: The offending task type value, echoed in the message.
    """

    def __init__(self, task_type: str):
        message = f"Not supported task type [{task_type}]"
        super().__init__(message)

View File

@@ -19,6 +19,7 @@ from skyvern.exceptions import (
BrowserStateMissingPage,
EmptyScrapePage,
FailedToNavigateToUrl,
FailedToParseActionInstruction,
FailedToSendWebhook,
FailedToTakeScreenshot,
InvalidTaskStatusTransition,
@@ -30,6 +31,8 @@ from skyvern.exceptions import (
StepUnableToExecuteError,
TaskAlreadyCanceled,
TaskNotFound,
UnsupportedActionType,
UnsupportedTaskType,
)
from skyvern.forge import app
from skyvern.forge.async_operations import AgentPhase, AsyncOperationPool
@@ -39,7 +42,7 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.validators import validate_url
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.models import Organization, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus
from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -133,7 +136,7 @@ class ForgeAgent:
task_url = validate_url(task_url)
task = await app.DATABASE.create_task(
url=task_url,
prompt_template=task_block.prompt_template,
task_type=task_block.task_type,
complete_criterion=task_block.complete_criterion,
terminate_criterion=task_block.terminate_criterion,
title=task_block.title or task_block.label,
@@ -524,6 +527,23 @@ class ForgeAgent:
need_call_webhook=False,
)
return step, detailed_output, None
except (UnsupportedActionType, UnsupportedTaskType, FailedToParseActionInstruction) as e:
LOG.warning(
"unsupported task type or action type, marking the task as failed",
task_id=task.task_id,
step_id=step.step_id,
step_order=step.order,
step_retry=step.retry_index,
)
await self.fail_task(task, step, e.message)
await self.clean_up_task(
task=task,
last_step=step,
api_key=api_key,
need_call_webhook=False,
)
return step, detailed_output, None
except Exception as e:
LOG.exception(
"Got an unexpected exception in step, marking task as failed",
@@ -930,6 +950,9 @@ class ForgeAgent:
output=detailed_agent_step_output.to_agent_step_output(),
)
return failed_step, detailed_agent_step_output.get_clean_detailed_output()
except (UnsupportedActionType, UnsupportedTaskType, FailedToParseActionInstruction):
raise
except Exception as e:
LOG.exception(
"Unexpected exception in agent_step, marking step as failed",
@@ -1179,6 +1202,7 @@ class ForgeAgent:
element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
extract_action_prompt = await self._build_extract_action_prompt(
task,
step,
browser_state,
element_tree_in_prompt,
verification_code_check=bool(task.totp_verification_url or task.totp_identifier),
@@ -1216,6 +1240,7 @@ class ForgeAgent:
async def _build_extract_action_prompt(
self,
task: Task,
step: Step,
browser_state: BrowserState,
element_tree_in_prompt: str,
verification_code_check: bool = False,
@@ -1234,7 +1259,37 @@ class ForgeAgent:
task, expire_verification_code=expire_verification_code
)
template = task.prompt_template if task.prompt_template else TaskPromptTemplate.ExtractAction
task_type = task.task_type if task.task_type else TaskType.general
template = ""
if task_type == TaskType.general:
template = "extract-action"
elif task_type == TaskType.validation:
template = "decisive-criterion-validate"
elif task_type == TaskType.action:
prompt = prompt_engine.load_prompt("infer-action-type", navigation_goal=navigation_goal)
json_response = await app.LLM_API_HANDLER(prompt=prompt, step=step)
if json_response.get("error"):
raise FailedToParseActionInstruction(
reason=json_response.get("thought"), error_type=json_response.get("error")
)
action_type: str = json_response.get("action_type") or ""
action_type = ActionType[action_type.upper()]
if action_type == ActionType.CLICK:
template = "single-click-action"
elif action_type == ActionType.INPUT_TEXT:
template = "single-input-action"
elif action_type == ActionType.UPLOAD_FILE:
template = "single-upload-action"
elif action_type == ActionType.SELECT_OPTION:
template = "single-select-action"
else:
raise UnsupportedActionType(action_type=action_type)
if not template:
raise UnsupportedTaskType(task_type=task_type)
return prompt_engine.load_prompt(
template=template,
navigation_goal=navigation_goal,
@@ -1916,6 +1971,7 @@ class ForgeAgent:
element_tree_in_prompt: str = scraped_page.build_element_tree(ElementTreeFormat.HTML)
extract_action_prompt = await self._build_extract_action_prompt(
task,
step,
browser_state,
element_tree_in_prompt,
verification_code_check=False,

View File

@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from skyvern.config import settings
from skyvern.exceptions import WorkflowParameterNotFound
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskPromptTemplate
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
ActionModel,
@@ -113,13 +113,13 @@ class AgentDB:
retry: int | None = None,
max_steps_per_run: int | None = None,
error_code_mapping: dict[str, str] | None = None,
prompt_template: str = TaskPromptTemplate.ExtractAction,
task_type: str = TaskType.general,
) -> Task:
try:
async with self.Session() as session:
new_task = TaskModel(
status="created",
prompt_template=prompt_template,
task_type=task_type,
url=url,
title=title,
webhook_callback_url=webhook_callback_url,

View File

@@ -5,10 +5,7 @@ class OrganizationAuthTokenType(StrEnum):
api = "api"
class TaskPromptTemplate(StrEnum):
ExtractAction = "extract-action"
DecisiveCriterionValidate = "decisive-criterion-validate"
SingleClickAction = "single-click-action"
SingleInputAction = "single-input-action"
SingleUploadAction = "single-upload-action"
SingleSelectAction = "single-select-action"
class TaskType(StrEnum):
    """Kind of task, stored as a plain string in the ``tasks.task_type`` column.

    The agent uses this value to decide which prompt to load when it builds
    the extract-action prompt for a step.
    """

    # Default task type; the agent loads the "extract-action" prompt.
    general = "general"
    # Set on validation blocks; the agent loads the
    # "decisive-criterion-validate" prompt.
    validation = "validation"
    # Set on action blocks; the agent first infers a single action type from
    # the navigation goal, then loads the matching "single-*-action" prompt.
    action = "action"

View File

@@ -17,7 +17,7 @@ from sqlalchemy import (
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskPromptTemplate
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
from skyvern.forge.sdk.db.id import (
generate_action_id,
generate_artifact_id,
@@ -54,7 +54,7 @@ class TaskModel(Base):
totp_verification_url = Column(String)
totp_identifier = Column(String)
title = Column(String)
prompt_template = Column(String, default=TaskPromptTemplate.ExtractAction)
task_type = Column(String, default=TaskType.general)
url = Column(String)
navigation_goal = Column(String)
data_extraction_goal = Column(String)

View File

@@ -60,7 +60,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
status=TaskStatus(task_obj.status),
created_at=task_obj.created_at,
modified_at=task_obj.modified_at,
prompt_template=task_obj.prompt_template,
task_type=task_obj.task_type,
title=task_obj.title,
url=task_obj.url,
complete_criterion=task_obj.complete_criterion,

View File

@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, HttpUrl, field_validator
from skyvern.exceptions import BlockedHost, InvalidTaskStatusTransition, TaskAlreadyCanceled
from skyvern.forge.sdk.core.validators import is_blocked_host
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
from skyvern.forge.sdk.db.enums import TaskType
class ProxyLocation(StrEnum):
@@ -86,10 +86,10 @@ class TaskBase(BaseModel):
description="Criterion to terminate",
examples=["Terminate if 'existing account' shows up on the page"],
)
prompt_template: str | None = Field(
default=TaskPromptTemplate.ExtractAction,
description="The prompt template used for task",
examples=[TaskPromptTemplate.ExtractAction, TaskPromptTemplate.DecisiveCriterionValidate],
task_type: TaskType | None = Field(
default=TaskType.general,
description="The type of the task",
examples=[TaskType.general, TaskType.validation],
)

View File

@@ -109,13 +109,6 @@ class WorkflowParameterMissingRequiredValue(BaseWorkflowHTTPException):
)
class FailedToParseActionInstruction(SkyvernException):
def __init__(self, reason: str | None, error_type: str | None):
super().__init__(
f"Failed to parse the action instruction as '{reason}({error_type})'",
)
class InvalidWaitBlockTime(SkyvernException):
def __init__(self, max_sec: int):
super().__init__(f"Invalid wait time for wait block, it should be a number between 0 and {max_sec}.")

View File

@@ -41,12 +41,11 @@ from skyvern.forge.sdk.api.files import (
get_path_for_workflow_download_directory,
)
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.exceptions import (
FailedToParseActionInstruction,
InvalidEmailClientConfiguration,
InvalidFileType,
NoValidEmailRecipient,
@@ -58,7 +57,6 @@ from skyvern.forge.sdk.workflow.models.parameter import (
OutputParameter,
WorkflowParameter,
)
from skyvern.webeye.actions.actions import ActionType
from skyvern.webeye.browser_factory import BrowserState
LOG = structlog.get_logger()
@@ -185,7 +183,7 @@ class Block(BaseModel, abc.ABC):
class BaseTaskBlock(Block):
prompt_template: str = TaskPromptTemplate.ExtractAction
task_type: str = TaskType.general
url: str | None = None
title: str = ""
complete_criterion: str | None = None
@@ -1333,41 +1331,6 @@ class ValidationBlock(BaseTaskBlock):
class ActionBlock(BaseTaskBlock):
block_type: Literal[BlockType.ACTION] = BlockType.ACTION
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
try:
prompt = prompt_engine.load_prompt("infer-action-type", navigation_goal=self.navigation_goal)
# TODO: no step here, so LLM call won't be saved as an artifact
json_response = await app.LLM_API_HANDLER(prompt=prompt)
if json_response.get("error"):
raise FailedToParseActionInstruction(
reason=json_response.get("thought"), error_type=json_response.get("error")
)
action_type: str = json_response.get("action_type") or ""
action_type = ActionType[action_type.upper()]
prompt_template = ""
if action_type == ActionType.CLICK:
prompt_template = TaskPromptTemplate.SingleClickAction
elif action_type == ActionType.INPUT_TEXT:
prompt_template = TaskPromptTemplate.SingleInputAction
elif action_type == ActionType.UPLOAD_FILE:
prompt_template = TaskPromptTemplate.SingleUploadAction
elif action_type == ActionType.SELECT_OPTION:
prompt_template = TaskPromptTemplate.SingleSelectAction
if not prompt_template:
raise Exception(
f"Not supported action for action block. Currently we only support [click, input_text, upload_file, select_option], but got [{action_type}]"
)
except Exception as e:
return self.build_block_result(
success=False, failure_reason=str(e), output_parameter_value=None, status=BlockStatus.failed
)
self.prompt_template = prompt_template
return await super().execute(workflow_run_id=workflow_run_id, kwargs=kwargs)
class NavigationBlock(BaseTaskBlock):
block_type: Literal[BlockType.NAVIGATION] = BlockType.NAVIGATION

View File

@@ -18,7 +18,7 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.models import Organization, Step
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task
from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -1354,7 +1354,7 @@ class WorkflowService:
return ValidationBlock(
label=block_yaml.label,
prompt_template=TaskPromptTemplate.DecisiveCriterionValidate,
task_type=TaskType.validation,
parameters=validation_block_parameters,
output_parameter=output_parameter,
complete_criterion=block_yaml.complete_criterion,
@@ -1379,6 +1379,7 @@ class WorkflowService:
label=block_yaml.label,
url=block_yaml.url,
title=block_yaml.title,
task_type=TaskType.action,
parameters=action_block_parameters,
output_parameter=output_parameter,
navigation_goal=block_yaml.navigation_goal,