use task type instead of prompt template (#1261)
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
"""use task type instead of prompt template
|
||||
|
||||
Revision ID: 56085e451bec
|
||||
Revises: 2d79d5fc1baa
|
||||
Create Date: 2024-11-26 03:22:11.224805+00:00
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "56085e451bec"
|
||||
down_revision: Union[str, None] = "2d79d5fc1baa"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Replace ``tasks.prompt_template`` with ``tasks.task_type``.

    The new column is added first, then backfilled from the old column, and
    only afterwards is the old column dropped. Without the backfill every
    pre-existing row would end up with a NULL ``task_type`` and lose the
    information about which kind of task it was.

    Mapping used for the backfill:

    * ``extract-action``              -> ``general``
    * ``decisive-criterion-validate`` -> ``validation``
    * ``single-*-action``             -> ``action``

    Rows whose ``prompt_template`` is NULL or holds any other value are left
    with a NULL ``task_type``.
    """
    op.add_column("tasks", sa.Column("task_type", sa.String(), nullable=True))

    # Carry the old template choice over before the column disappears.
    # A CASE without ELSE yields NULL, so unrecognized templates stay NULL
    # instead of being coerced to a type they may not correspond to.
    op.execute(
        sa.text(
            """
            UPDATE tasks
            SET task_type = CASE prompt_template
                WHEN 'extract-action' THEN 'general'
                WHEN 'decisive-criterion-validate' THEN 'validation'
                WHEN 'single-click-action' THEN 'action'
                WHEN 'single-input-action' THEN 'action'
                WHEN 'single-upload-action' THEN 'action'
                WHEN 'single-select-action' THEN 'action'
            END
            WHERE prompt_template IS NOT NULL
            """
        )
    )

    op.drop_column("tasks", "prompt_template")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore ``tasks.prompt_template`` and remove ``tasks.task_type``.

    The old column is re-created, backfilled from ``task_type`` where the
    mapping is unambiguous, and then ``task_type`` is dropped.

    Mapping used for the backfill:

    * ``general``    -> ``extract-action``
    * ``validation`` -> ``decisive-criterion-validate``

    Rows of type ``action`` are left with a NULL ``prompt_template``: the
    concrete ``single-*-action`` template was picked per task from the
    navigation goal and is not stored in ``task_type``, so it cannot be
    reconstructed here.
    """
    op.add_column(
        "tasks",
        sa.Column("prompt_template", sa.VARCHAR(), autoincrement=False, nullable=True),
    )

    # Restore the template for the task types that map to exactly one
    # template. A CASE without ELSE yields NULL for everything else.
    op.execute(
        sa.text(
            """
            UPDATE tasks
            SET prompt_template = CASE task_type
                WHEN 'general' THEN 'extract-action'
                WHEN 'validation' THEN 'decisive-criterion-validate'
            END
            WHERE task_type IS NOT NULL
            """
        )
    )

    op.drop_column("tasks", "task_type")
|
||||
@@ -530,3 +530,15 @@ class InteractWithDisabledElement(SkyvernException):
|
||||
super().__init__(
|
||||
f"The element(id={element_id}) now is disabled, try to interact with it later when it's enabled."
|
||||
)
|
||||
|
||||
|
||||
class FailedToParseActionInstruction(SkyvernException):
    """Error signalling that an action instruction could not be parsed.

    Args:
        reason: Explanation of the failure; interpolated into the message
            as-is (``None`` is rendered as the text ``None``).
        error_type: Short error identifier; shown in parentheses directly
            after ``reason``.
    """

    def __init__(self, reason: str | None, error_type: str | None):
        message = f"Failed to parse the action instruction as '{reason}({error_type})'"
        super().__init__(message)
|
||||
|
||||
|
||||
class UnsupportedTaskType(SkyvernException):
    """Error signalling that a task carries a type that is not supported.

    Args:
        task_type: The offending task type; embedded in the error message.
    """

    def __init__(self, task_type: str):
        message = f"Not supported task type [{task_type}]"
        super().__init__(message)
|
||||
|
||||
@@ -19,6 +19,7 @@ from skyvern.exceptions import (
|
||||
BrowserStateMissingPage,
|
||||
EmptyScrapePage,
|
||||
FailedToNavigateToUrl,
|
||||
FailedToParseActionInstruction,
|
||||
FailedToSendWebhook,
|
||||
FailedToTakeScreenshot,
|
||||
InvalidTaskStatusTransition,
|
||||
@@ -30,6 +31,8 @@ from skyvern.exceptions import (
|
||||
StepUnableToExecuteError,
|
||||
TaskAlreadyCanceled,
|
||||
TaskNotFound,
|
||||
UnsupportedActionType,
|
||||
UnsupportedTaskType,
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.async_operations import AgentPhase, AsyncOperationPool
|
||||
@@ -39,7 +42,7 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.core.validators import validate_url
|
||||
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.models import Organization, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
@@ -133,7 +136,7 @@ class ForgeAgent:
|
||||
task_url = validate_url(task_url)
|
||||
task = await app.DATABASE.create_task(
|
||||
url=task_url,
|
||||
prompt_template=task_block.prompt_template,
|
||||
task_type=task_block.task_type,
|
||||
complete_criterion=task_block.complete_criterion,
|
||||
terminate_criterion=task_block.terminate_criterion,
|
||||
title=task_block.title or task_block.label,
|
||||
@@ -524,6 +527,23 @@ class ForgeAgent:
|
||||
need_call_webhook=False,
|
||||
)
|
||||
return step, detailed_output, None
|
||||
except (UnsupportedActionType, UnsupportedTaskType, FailedToParseActionInstruction) as e:
|
||||
LOG.warning(
|
||||
"unsupported task type or action type, marking the task as failed",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
step_order=step.order,
|
||||
step_retry=step.retry_index,
|
||||
)
|
||||
await self.fail_task(task, step, e.message)
|
||||
await self.clean_up_task(
|
||||
task=task,
|
||||
last_step=step,
|
||||
api_key=api_key,
|
||||
need_call_webhook=False,
|
||||
)
|
||||
return step, detailed_output, None
|
||||
|
||||
except Exception as e:
|
||||
LOG.exception(
|
||||
"Got an unexpected exception in step, marking task as failed",
|
||||
@@ -930,6 +950,9 @@ class ForgeAgent:
|
||||
output=detailed_agent_step_output.to_agent_step_output(),
|
||||
)
|
||||
return failed_step, detailed_agent_step_output.get_clean_detailed_output()
|
||||
except (UnsupportedActionType, UnsupportedTaskType, FailedToParseActionInstruction):
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
LOG.exception(
|
||||
"Unexpected exception in agent_step, marking step as failed",
|
||||
@@ -1179,6 +1202,7 @@ class ForgeAgent:
|
||||
element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
|
||||
extract_action_prompt = await self._build_extract_action_prompt(
|
||||
task,
|
||||
step,
|
||||
browser_state,
|
||||
element_tree_in_prompt,
|
||||
verification_code_check=bool(task.totp_verification_url or task.totp_identifier),
|
||||
@@ -1216,6 +1240,7 @@ class ForgeAgent:
|
||||
async def _build_extract_action_prompt(
|
||||
self,
|
||||
task: Task,
|
||||
step: Step,
|
||||
browser_state: BrowserState,
|
||||
element_tree_in_prompt: str,
|
||||
verification_code_check: bool = False,
|
||||
@@ -1234,7 +1259,37 @@ class ForgeAgent:
|
||||
task, expire_verification_code=expire_verification_code
|
||||
)
|
||||
|
||||
template = task.prompt_template if task.prompt_template else TaskPromptTemplate.ExtractAction
|
||||
task_type = task.task_type if task.task_type else TaskType.general
|
||||
template = ""
|
||||
if task_type == TaskType.general:
|
||||
template = "extract-action"
|
||||
elif task_type == TaskType.validation:
|
||||
template = "decisive-criterion-validate"
|
||||
elif task_type == TaskType.action:
|
||||
prompt = prompt_engine.load_prompt("infer-action-type", navigation_goal=navigation_goal)
|
||||
json_response = await app.LLM_API_HANDLER(prompt=prompt, step=step)
|
||||
if json_response.get("error"):
|
||||
raise FailedToParseActionInstruction(
|
||||
reason=json_response.get("thought"), error_type=json_response.get("error")
|
||||
)
|
||||
|
||||
action_type: str = json_response.get("action_type") or ""
|
||||
action_type = ActionType[action_type.upper()]
|
||||
|
||||
if action_type == ActionType.CLICK:
|
||||
template = "single-click-action"
|
||||
elif action_type == ActionType.INPUT_TEXT:
|
||||
template = "single-input-action"
|
||||
elif action_type == ActionType.UPLOAD_FILE:
|
||||
template = "single-upload-action"
|
||||
elif action_type == ActionType.SELECT_OPTION:
|
||||
template = "single-select-action"
|
||||
else:
|
||||
raise UnsupportedActionType(action_type=action_type)
|
||||
|
||||
if not template:
|
||||
raise UnsupportedTaskType(task_type=task_type)
|
||||
|
||||
return prompt_engine.load_prompt(
|
||||
template=template,
|
||||
navigation_goal=navigation_goal,
|
||||
@@ -1916,6 +1971,7 @@ class ForgeAgent:
|
||||
element_tree_in_prompt: str = scraped_page.build_element_tree(ElementTreeFormat.HTML)
|
||||
extract_action_prompt = await self._build_extract_action_prompt(
|
||||
task,
|
||||
step,
|
||||
browser_state,
|
||||
element_tree_in_prompt,
|
||||
verification_code_check=False,
|
||||
|
||||
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import WorkflowParameterNotFound
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
ActionModel,
|
||||
@@ -113,13 +113,13 @@ class AgentDB:
|
||||
retry: int | None = None,
|
||||
max_steps_per_run: int | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
prompt_template: str = TaskPromptTemplate.ExtractAction,
|
||||
task_type: str = TaskType.general,
|
||||
) -> Task:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
new_task = TaskModel(
|
||||
status="created",
|
||||
prompt_template=prompt_template,
|
||||
task_type=task_type,
|
||||
url=url,
|
||||
title=title,
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
|
||||
@@ -5,10 +5,7 @@ class OrganizationAuthTokenType(StrEnum):
|
||||
api = "api"
|
||||
|
||||
|
||||
class TaskPromptTemplate(StrEnum):
|
||||
ExtractAction = "extract-action"
|
||||
DecisiveCriterionValidate = "decisive-criterion-validate"
|
||||
SingleClickAction = "single-click-action"
|
||||
SingleInputAction = "single-input-action"
|
||||
SingleUploadAction = "single-upload-action"
|
||||
SingleSelectAction = "single-select-action"
|
||||
class TaskType(StrEnum):
|
||||
general = "general"
|
||||
validation = "validation"
|
||||
action = "action"
|
||||
|
||||
@@ -17,7 +17,7 @@ from sqlalchemy import (
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
|
||||
from skyvern.forge.sdk.db.id import (
|
||||
generate_action_id,
|
||||
generate_artifact_id,
|
||||
@@ -54,7 +54,7 @@ class TaskModel(Base):
|
||||
totp_verification_url = Column(String)
|
||||
totp_identifier = Column(String)
|
||||
title = Column(String)
|
||||
prompt_template = Column(String, default=TaskPromptTemplate.ExtractAction)
|
||||
task_type = Column(String, default=TaskType.general)
|
||||
url = Column(String)
|
||||
navigation_goal = Column(String)
|
||||
data_extraction_goal = Column(String)
|
||||
|
||||
@@ -60,7 +60,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
|
||||
status=TaskStatus(task_obj.status),
|
||||
created_at=task_obj.created_at,
|
||||
modified_at=task_obj.modified_at,
|
||||
prompt_template=task_obj.prompt_template,
|
||||
task_type=task_obj.task_type,
|
||||
title=task_obj.title,
|
||||
url=task_obj.url,
|
||||
complete_criterion=task_obj.complete_criterion,
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, HttpUrl, field_validator
|
||||
|
||||
from skyvern.exceptions import BlockedHost, InvalidTaskStatusTransition, TaskAlreadyCanceled
|
||||
from skyvern.forge.sdk.core.validators import is_blocked_host
|
||||
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
|
||||
|
||||
class ProxyLocation(StrEnum):
|
||||
@@ -86,10 +86,10 @@ class TaskBase(BaseModel):
|
||||
description="Criterion to terminate",
|
||||
examples=["Terminate if 'existing account' shows up on the page"],
|
||||
)
|
||||
prompt_template: str | None = Field(
|
||||
default=TaskPromptTemplate.ExtractAction,
|
||||
description="The prompt template used for task",
|
||||
examples=[TaskPromptTemplate.ExtractAction, TaskPromptTemplate.DecisiveCriterionValidate],
|
||||
task_type: TaskType | None = Field(
|
||||
default=TaskType.general,
|
||||
description="The type of the task",
|
||||
examples=[TaskType.general, TaskType.validation],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -109,13 +109,6 @@ class WorkflowParameterMissingRequiredValue(BaseWorkflowHTTPException):
|
||||
)
|
||||
|
||||
|
||||
class FailedToParseActionInstruction(SkyvernException):
|
||||
def __init__(self, reason: str | None, error_type: str | None):
|
||||
super().__init__(
|
||||
f"Failed to parse the action instruction as '{reason}({error_type})'",
|
||||
)
|
||||
|
||||
|
||||
class InvalidWaitBlockTime(SkyvernException):
|
||||
def __init__(self, max_sec: int):
|
||||
super().__init__(f"Invalid wait time for wait block, it should be a number between 0 and {max_sec}.")
|
||||
|
||||
@@ -41,12 +41,11 @@ from skyvern.forge.sdk.api.files import (
|
||||
get_path_for_workflow_download_directory,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.exceptions import (
|
||||
FailedToParseActionInstruction,
|
||||
InvalidEmailClientConfiguration,
|
||||
InvalidFileType,
|
||||
NoValidEmailRecipient,
|
||||
@@ -58,7 +57,6 @@ from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
OutputParameter,
|
||||
WorkflowParameter,
|
||||
)
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
from skyvern.webeye.browser_factory import BrowserState
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
@@ -185,7 +183,7 @@ class Block(BaseModel, abc.ABC):
|
||||
|
||||
|
||||
class BaseTaskBlock(Block):
|
||||
prompt_template: str = TaskPromptTemplate.ExtractAction
|
||||
task_type: str = TaskType.general
|
||||
url: str | None = None
|
||||
title: str = ""
|
||||
complete_criterion: str | None = None
|
||||
@@ -1333,41 +1331,6 @@ class ValidationBlock(BaseTaskBlock):
|
||||
class ActionBlock(BaseTaskBlock):
|
||||
block_type: Literal[BlockType.ACTION] = BlockType.ACTION
|
||||
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
|
||||
try:
|
||||
prompt = prompt_engine.load_prompt("infer-action-type", navigation_goal=self.navigation_goal)
|
||||
# TODO: no step here, so LLM call won't be saved as an artifact
|
||||
json_response = await app.LLM_API_HANDLER(prompt=prompt)
|
||||
if json_response.get("error"):
|
||||
raise FailedToParseActionInstruction(
|
||||
reason=json_response.get("thought"), error_type=json_response.get("error")
|
||||
)
|
||||
|
||||
action_type: str = json_response.get("action_type") or ""
|
||||
action_type = ActionType[action_type.upper()]
|
||||
|
||||
prompt_template = ""
|
||||
if action_type == ActionType.CLICK:
|
||||
prompt_template = TaskPromptTemplate.SingleClickAction
|
||||
elif action_type == ActionType.INPUT_TEXT:
|
||||
prompt_template = TaskPromptTemplate.SingleInputAction
|
||||
elif action_type == ActionType.UPLOAD_FILE:
|
||||
prompt_template = TaskPromptTemplate.SingleUploadAction
|
||||
elif action_type == ActionType.SELECT_OPTION:
|
||||
prompt_template = TaskPromptTemplate.SingleSelectAction
|
||||
|
||||
if not prompt_template:
|
||||
raise Exception(
|
||||
f"Not supported action for action block. Currently we only support [click, input_text, upload_file, select_option], but got [{action_type}]"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.build_block_result(
|
||||
success=False, failure_reason=str(e), output_parameter_value=None, status=BlockStatus.failed
|
||||
)
|
||||
|
||||
self.prompt_template = prompt_template
|
||||
return await super().execute(workflow_run_id=workflow_run_id, kwargs=kwargs)
|
||||
|
||||
|
||||
class NavigationBlock(BaseTaskBlock):
|
||||
block_type: Literal[BlockType.NAVIGATION] = BlockType.NAVIGATION
|
||||
|
||||
@@ -18,7 +18,7 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.models import Organization, Step
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
@@ -1354,7 +1354,7 @@ class WorkflowService:
|
||||
|
||||
return ValidationBlock(
|
||||
label=block_yaml.label,
|
||||
prompt_template=TaskPromptTemplate.DecisiveCriterionValidate,
|
||||
task_type=TaskType.validation,
|
||||
parameters=validation_block_parameters,
|
||||
output_parameter=output_parameter,
|
||||
complete_criterion=block_yaml.complete_criterion,
|
||||
@@ -1379,6 +1379,7 @@ class WorkflowService:
|
||||
label=block_yaml.label,
|
||||
url=block_yaml.url,
|
||||
title=block_yaml.title,
|
||||
task_type=TaskType.action,
|
||||
parameters=action_block_parameters,
|
||||
output_parameter=output_parameter,
|
||||
navigation_goal=block_yaml.navigation_goal,
|
||||
|
||||
Reference in New Issue
Block a user