From 6b417d0e835b2d1de8a9ba8c01dc8cf5a107d961 Mon Sep 17 00:00:00 2001 From: LawyZheng Date: Tue, 26 Nov 2024 11:29:33 +0800 Subject: [PATCH] use task type instead of prompt template (#1261) --- ...se_task_type_instead_of_prompt_template.py | 33 ++++++++++ skyvern/exceptions.py | 12 ++++ skyvern/forge/agent.py | 62 ++++++++++++++++++- skyvern/forge/sdk/db/client.py | 6 +- skyvern/forge/sdk/db/enums.py | 11 ++-- skyvern/forge/sdk/db/models.py | 4 +- skyvern/forge/sdk/db/utils.py | 2 +- skyvern/forge/sdk/schemas/tasks.py | 10 +-- skyvern/forge/sdk/workflow/exceptions.py | 7 --- skyvern/forge/sdk/workflow/models/block.py | 41 +----------- skyvern/forge/sdk/workflow/service.py | 5 +- 11 files changed, 124 insertions(+), 69 deletions(-) create mode 100644 alembic/versions/2024_11_26_0322-56085e451bec_use_task_type_instead_of_prompt_template.py diff --git a/alembic/versions/2024_11_26_0322-56085e451bec_use_task_type_instead_of_prompt_template.py b/alembic/versions/2024_11_26_0322-56085e451bec_use_task_type_instead_of_prompt_template.py new file mode 100644 index 00000000..38c8f388 --- /dev/null +++ b/alembic/versions/2024_11_26_0322-56085e451bec_use_task_type_instead_of_prompt_template.py @@ -0,0 +1,33 @@ +"""use task type instead of prompt template + +Revision ID: 56085e451bec +Revises: 2d79d5fc1baa +Create Date: 2024-11-26 03:22:11.224805+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "56085e451bec" +down_revision: Union[str, None] = "2d79d5fc1baa" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("tasks", sa.Column("task_type", sa.String(), nullable=True)) + op.drop_column("tasks", "prompt_template") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("tasks", sa.Column("prompt_template", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_column("tasks", "task_type") + # ### end Alembic commands ### diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index a599559f..622bc098 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -530,3 +530,15 @@ class InteractWithDisabledElement(SkyvernException): super().__init__( f"The element(id={element_id}) now is disabled, try to interact with it later when it's enabled." ) + + +class FailedToParseActionInstruction(SkyvernException): + def __init__(self, reason: str | None, error_type: str | None): + super().__init__( + f"Failed to parse the action instruction as '{reason}({error_type})'", + ) + + +class UnsupportedTaskType(SkyvernException): + def __init__(self, task_type: str): + super().__init__(f"Not supported task type [{task_type}]") diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 893c7760..8e7eed1b 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -19,6 +19,7 @@ from skyvern.exceptions import ( BrowserStateMissingPage, EmptyScrapePage, FailedToNavigateToUrl, + FailedToParseActionInstruction, FailedToSendWebhook, FailedToTakeScreenshot, InvalidTaskStatusTransition, @@ -30,6 +31,8 @@ from skyvern.exceptions import ( StepUnableToExecuteError, TaskAlreadyCanceled, TaskNotFound, + UnsupportedActionType, + UnsupportedTaskType, ) from skyvern.forge import app from skyvern.forge.async_operations import AgentPhase, AsyncOperationPool @@ -39,7 +42,7 @@ from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.validators import validate_url -from skyvern.forge.sdk.db.enums import TaskPromptTemplate +from skyvern.forge.sdk.db.enums import TaskType from skyvern.forge.sdk.models import Organization, Step, StepStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus from skyvern.forge.sdk.settings_manager import SettingsManager @@ -133,7 +136,7 @@ class ForgeAgent: task_url = validate_url(task_url) task = await app.DATABASE.create_task( url=task_url, - prompt_template=task_block.prompt_template, + task_type=task_block.task_type, complete_criterion=task_block.complete_criterion, terminate_criterion=task_block.terminate_criterion, title=task_block.title or task_block.label, @@ -524,6 +527,23 @@ class ForgeAgent: need_call_webhook=False, ) return step, detailed_output, None + except (UnsupportedActionType, UnsupportedTaskType, FailedToParseActionInstruction) as e: + LOG.warning( + "unsupported task type or action type, marking the task as failed", + task_id=task.task_id, + step_id=step.step_id, + step_order=step.order, + step_retry=step.retry_index, + ) + await self.fail_task(task, step, e.message) + await self.clean_up_task( + task=task, + last_step=step, + api_key=api_key, + need_call_webhook=False, + ) + return step, detailed_output, None + except Exception as e: LOG.exception( "Got an unexpected exception in step, marking task as failed", @@ -930,6 +950,9 @@ class ForgeAgent: output=detailed_agent_step_output.to_agent_step_output(), ) return failed_step, detailed_agent_step_output.get_clean_detailed_output() + except (UnsupportedActionType, UnsupportedTaskType, FailedToParseActionInstruction): + raise + except Exception as e: LOG.exception( "Unexpected exception in agent_step, marking step as failed", @@ -1179,6 +1202,7 @@ class ForgeAgent: element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format) extract_action_prompt = await self._build_extract_action_prompt( task, + step, browser_state, element_tree_in_prompt, verification_code_check=bool(task.totp_verification_url or task.totp_identifier), @@ -1216,6 +1240,7 @@ class ForgeAgent: async def _build_extract_action_prompt( self, task: Task, + step: Step, browser_state: BrowserState, element_tree_in_prompt: str, verification_code_check: bool = False, @@ -1234,7 +1259,37 @@ class ForgeAgent: task, expire_verification_code=expire_verification_code ) - template = task.prompt_template if task.prompt_template else TaskPromptTemplate.ExtractAction + task_type = task.task_type if task.task_type else TaskType.general + template = "" + if task_type == TaskType.general: + template = "extract-action" + elif task_type == TaskType.validation: + template = "decisive-criterion-validate" + elif task_type == TaskType.action: + prompt = prompt_engine.load_prompt("infer-action-type", navigation_goal=navigation_goal) + json_response = await app.LLM_API_HANDLER(prompt=prompt, step=step) + if json_response.get("error"): + raise FailedToParseActionInstruction( + reason=json_response.get("thought"), error_type=json_response.get("error") + ) + + action_type: str = json_response.get("action_type") or "" + action_type = ActionType[action_type.upper()] + + if action_type == ActionType.CLICK: + template = "single-click-action" + elif action_type == ActionType.INPUT_TEXT: + template = "single-input-action" + elif action_type == ActionType.UPLOAD_FILE: + template = "single-upload-action" + elif action_type == ActionType.SELECT_OPTION: + template = "single-select-action" + else: + raise UnsupportedActionType(action_type=action_type) + + if not template: + raise UnsupportedTaskType(task_type=task_type) + return prompt_engine.load_prompt( template=template, navigation_goal=navigation_goal, @@ -1916,6 +1971,7 @@ class ForgeAgent: element_tree_in_prompt: str = scraped_page.build_element_tree(ElementTreeFormat.HTML) extract_action_prompt = await self._build_extract_action_prompt( task, + step, browser_state, element_tree_in_prompt, verification_code_check=False, diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 31dcc957..52589758 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from skyvern.config import settings from skyvern.exceptions import WorkflowParameterNotFound from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType -from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskPromptTemplate +from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType from skyvern.forge.sdk.db.exceptions import NotFoundError from skyvern.forge.sdk.db.models import ( ActionModel, @@ -113,13 +113,13 @@ class AgentDB: retry: int | None = None, max_steps_per_run: int | None = None, error_code_mapping: dict[str, str] | None = None, - prompt_template: str = TaskPromptTemplate.ExtractAction, + task_type: str = TaskType.general, ) -> Task: try: async with self.Session() as session: new_task = TaskModel( status="created", - prompt_template=prompt_template, + task_type=task_type, url=url, title=title, webhook_callback_url=webhook_callback_url, diff --git a/skyvern/forge/sdk/db/enums.py b/skyvern/forge/sdk/db/enums.py index 79327f0d..99ecc30b 100644 --- a/skyvern/forge/sdk/db/enums.py +++ b/skyvern/forge/sdk/db/enums.py @@ -5,10 +5,7 @@ class OrganizationAuthTokenType(StrEnum): api = "api" -class TaskPromptTemplate(StrEnum): - ExtractAction = "extract-action" - DecisiveCriterionValidate = "decisive-criterion-validate" - SingleClickAction = "single-click-action" - SingleInputAction = "single-input-action" - SingleUploadAction = "single-upload-action" - SingleSelectAction = "single-select-action" +class TaskType(StrEnum): + general = "general" + validation = "validation" + action = "action" diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 6dbb6c42..05ad24ab 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -17,7 +17,7 @@ from sqlalchemy import ( from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import DeclarativeBase -from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskPromptTemplate +from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType from skyvern.forge.sdk.db.id import ( generate_action_id, generate_artifact_id, @@ -54,7 +54,7 @@ class TaskModel(Base): totp_verification_url = Column(String) totp_identifier = Column(String) title = Column(String) - prompt_template = Column(String, default=TaskPromptTemplate.ExtractAction) + task_type = Column(String, default=TaskType.general) url = Column(String) navigation_goal = Column(String) data_extraction_goal = Column(String) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index d8c7a1af..19ea6cf8 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -60,7 +60,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task: status=TaskStatus(task_obj.status), created_at=task_obj.created_at, modified_at=task_obj.modified_at, - prompt_template=task_obj.prompt_template, + task_type=task_obj.task_type, title=task_obj.title, url=task_obj.url, complete_criterion=task_obj.complete_criterion, diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index a38aaed4..1505b7a4 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, HttpUrl, field_validator from skyvern.exceptions import BlockedHost, InvalidTaskStatusTransition, TaskAlreadyCanceled from skyvern.forge.sdk.core.validators import is_blocked_host -from skyvern.forge.sdk.db.enums import TaskPromptTemplate +from skyvern.forge.sdk.db.enums import TaskType class ProxyLocation(StrEnum): @@ -86,10 +86,10 @@ class TaskBase(BaseModel): description="Criterion to terminate", examples=["Terminate if 'existing account' shows up on the page"], ) - prompt_template: str | None = Field( - default=TaskPromptTemplate.ExtractAction, - description="The prompt template used for task", - examples=[TaskPromptTemplate.ExtractAction, TaskPromptTemplate.DecisiveCriterionValidate], + task_type: TaskType | None = Field( + default=TaskType.general, + description="The type of the task", + examples=[TaskType.general, TaskType.validation], ) diff --git a/skyvern/forge/sdk/workflow/exceptions.py b/skyvern/forge/sdk/workflow/exceptions.py index 99eafd44..7d8b38d5 100644 --- a/skyvern/forge/sdk/workflow/exceptions.py +++ b/skyvern/forge/sdk/workflow/exceptions.py @@ -109,13 +109,6 @@ class WorkflowParameterMissingRequiredValue(BaseWorkflowHTTPException): ) -class FailedToParseActionInstruction(SkyvernException): - def __init__(self, reason: str | None, error_type: str | None): - super().__init__( - f"Failed to parse the action instruction as '{reason}({error_type})'", - ) - - class InvalidWaitBlockTime(SkyvernException): def __init__(self, max_sec: int): super().__init__(f"Invalid wait time for wait block, it should be a number between 0 and {max_sec}.") diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index a4a368b9..07dab217 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -41,12 +41,11 @@ from skyvern.forge.sdk.api.files import ( get_path_for_workflow_download_directory, ) from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory -from skyvern.forge.sdk.db.enums import TaskPromptTemplate +from skyvern.forge.sdk.db.enums import TaskType from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext from skyvern.forge.sdk.workflow.exceptions import ( - FailedToParseActionInstruction, InvalidEmailClientConfiguration, InvalidFileType, NoValidEmailRecipient, @@ -58,7 +57,6 @@ from skyvern.forge.sdk.workflow.models.parameter import ( OutputParameter, WorkflowParameter, ) -from skyvern.webeye.actions.actions import ActionType from skyvern.webeye.browser_factory import BrowserState LOG = structlog.get_logger() @@ -185,7 +183,7 @@ class Block(BaseModel, abc.ABC): class BaseTaskBlock(Block): - prompt_template: str = TaskPromptTemplate.ExtractAction + task_type: str = TaskType.general url: str | None = None title: str = "" complete_criterion: str | None = None @@ -1333,41 +1331,6 @@ class ValidationBlock(BaseTaskBlock): class ActionBlock(BaseTaskBlock): block_type: Literal[BlockType.ACTION] = BlockType.ACTION - async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: - try: - prompt = prompt_engine.load_prompt("infer-action-type", navigation_goal=self.navigation_goal) - # TODO: no step here, so LLM call won't be saved as an artifact - json_response = await app.LLM_API_HANDLER(prompt=prompt) - if json_response.get("error"): - raise FailedToParseActionInstruction( - reason=json_response.get("thought"), error_type=json_response.get("error") - ) - - action_type: str = json_response.get("action_type") or "" - action_type = ActionType[action_type.upper()] - - prompt_template = "" - if action_type == ActionType.CLICK: - prompt_template = TaskPromptTemplate.SingleClickAction - elif action_type == ActionType.INPUT_TEXT: - prompt_template = TaskPromptTemplate.SingleInputAction - elif action_type == ActionType.UPLOAD_FILE: - prompt_template = TaskPromptTemplate.SingleUploadAction - elif action_type == ActionType.SELECT_OPTION: - prompt_template = TaskPromptTemplate.SingleSelectAction - - if not prompt_template: - raise Exception( - f"Not supported action for action block. Currently we only support [click, input_text, upload_file, select_option], but got [{action_type}]" - ) - except Exception as e: - return self.build_block_result( - success=False, failure_reason=str(e), output_parameter_value=None, status=BlockStatus.failed - ) - - self.prompt_template = prompt_template - return await super().execute(workflow_run_id=workflow_run_id, kwargs=kwargs) - class NavigationBlock(BaseTaskBlock): block_type: Literal[BlockType.NAVIGATION] = BlockType.NAVIGATION diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 9052bcfa..2282341c 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -18,7 +18,7 @@ from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.skyvern_context import SkyvernContext -from skyvern.forge.sdk.db.enums import TaskPromptTemplate +from skyvern.forge.sdk.db.enums import TaskType from skyvern.forge.sdk.models import Organization, Step from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task from skyvern.forge.sdk.settings_manager import SettingsManager @@ -1354,7 +1354,7 @@ class WorkflowService: return ValidationBlock( label=block_yaml.label, - prompt_template=TaskPromptTemplate.DecisiveCriterionValidate, + task_type=TaskType.validation, parameters=validation_block_parameters, output_parameter=output_parameter, complete_criterion=block_yaml.complete_criterion, @@ -1379,6 +1379,7 @@ class WorkflowService: label=block_yaml.label, url=block_yaml.url, title=block_yaml.title, + task_type=TaskType.action, parameters=action_block_parameters, output_parameter=output_parameter, navigation_goal=block_yaml.navigation_goal,