use task type instead of prompt template (#1261)
This commit is contained in:
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import WorkflowParameterNotFound
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
ActionModel,
|
||||
@@ -113,13 +113,13 @@ class AgentDB:
|
||||
retry: int | None = None,
|
||||
max_steps_per_run: int | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
prompt_template: str = TaskPromptTemplate.ExtractAction,
|
||||
task_type: str = TaskType.general,
|
||||
) -> Task:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
new_task = TaskModel(
|
||||
status="created",
|
||||
prompt_template=prompt_template,
|
||||
task_type=task_type,
|
||||
url=url,
|
||||
title=title,
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
|
||||
@@ -5,10 +5,7 @@ class OrganizationAuthTokenType(StrEnum):
|
||||
api = "api"
|
||||
|
||||
|
||||
class TaskPromptTemplate(StrEnum):
|
||||
ExtractAction = "extract-action"
|
||||
DecisiveCriterionValidate = "decisive-criterion-validate"
|
||||
SingleClickAction = "single-click-action"
|
||||
SingleInputAction = "single-input-action"
|
||||
SingleUploadAction = "single-upload-action"
|
||||
SingleSelectAction = "single-select-action"
|
||||
class TaskType(StrEnum):
|
||||
general = "general"
|
||||
validation = "validation"
|
||||
action = "action"
|
||||
|
||||
@@ -17,7 +17,7 @@ from sqlalchemy import (
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
|
||||
from skyvern.forge.sdk.db.id import (
|
||||
generate_action_id,
|
||||
generate_artifact_id,
|
||||
@@ -54,7 +54,7 @@ class TaskModel(Base):
|
||||
totp_verification_url = Column(String)
|
||||
totp_identifier = Column(String)
|
||||
title = Column(String)
|
||||
prompt_template = Column(String, default=TaskPromptTemplate.ExtractAction)
|
||||
task_type = Column(String, default=TaskType.general)
|
||||
url = Column(String)
|
||||
navigation_goal = Column(String)
|
||||
data_extraction_goal = Column(String)
|
||||
|
||||
@@ -60,7 +60,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
|
||||
status=TaskStatus(task_obj.status),
|
||||
created_at=task_obj.created_at,
|
||||
modified_at=task_obj.modified_at,
|
||||
prompt_template=task_obj.prompt_template,
|
||||
task_type=task_obj.task_type,
|
||||
title=task_obj.title,
|
||||
url=task_obj.url,
|
||||
complete_criterion=task_obj.complete_criterion,
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, HttpUrl, field_validator
|
||||
|
||||
from skyvern.exceptions import BlockedHost, InvalidTaskStatusTransition, TaskAlreadyCanceled
|
||||
from skyvern.forge.sdk.core.validators import is_blocked_host
|
||||
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
|
||||
|
||||
class ProxyLocation(StrEnum):
|
||||
@@ -86,10 +86,10 @@ class TaskBase(BaseModel):
|
||||
description="Criterion to terminate",
|
||||
examples=["Terminate if 'existing account' shows up on the page"],
|
||||
)
|
||||
prompt_template: str | None = Field(
|
||||
default=TaskPromptTemplate.ExtractAction,
|
||||
description="The prompt template used for task",
|
||||
examples=[TaskPromptTemplate.ExtractAction, TaskPromptTemplate.DecisiveCriterionValidate],
|
||||
task_type: TaskType | None = Field(
|
||||
default=TaskType.general,
|
||||
description="The type of the task",
|
||||
examples=[TaskType.general, TaskType.validation],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -109,13 +109,6 @@ class WorkflowParameterMissingRequiredValue(BaseWorkflowHTTPException):
|
||||
)
|
||||
|
||||
|
||||
class FailedToParseActionInstruction(SkyvernException):
|
||||
def __init__(self, reason: str | None, error_type: str | None):
|
||||
super().__init__(
|
||||
f"Failed to parse the action instruction as '{reason}({error_type})'",
|
||||
)
|
||||
|
||||
|
||||
class InvalidWaitBlockTime(SkyvernException):
|
||||
def __init__(self, max_sec: int):
|
||||
super().__init__(f"Invalid wait time for wait block, it should be a number between 0 and {max_sec}.")
|
||||
|
||||
@@ -41,12 +41,11 @@ from skyvern.forge.sdk.api.files import (
|
||||
get_path_for_workflow_download_directory,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.exceptions import (
|
||||
FailedToParseActionInstruction,
|
||||
InvalidEmailClientConfiguration,
|
||||
InvalidFileType,
|
||||
NoValidEmailRecipient,
|
||||
@@ -58,7 +57,6 @@ from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
OutputParameter,
|
||||
WorkflowParameter,
|
||||
)
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
from skyvern.webeye.browser_factory import BrowserState
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
@@ -185,7 +183,7 @@ class Block(BaseModel, abc.ABC):
|
||||
|
||||
|
||||
class BaseTaskBlock(Block):
|
||||
prompt_template: str = TaskPromptTemplate.ExtractAction
|
||||
task_type: str = TaskType.general
|
||||
url: str | None = None
|
||||
title: str = ""
|
||||
complete_criterion: str | None = None
|
||||
@@ -1333,41 +1331,6 @@ class ValidationBlock(BaseTaskBlock):
|
||||
class ActionBlock(BaseTaskBlock):
|
||||
block_type: Literal[BlockType.ACTION] = BlockType.ACTION
|
||||
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
|
||||
try:
|
||||
prompt = prompt_engine.load_prompt("infer-action-type", navigation_goal=self.navigation_goal)
|
||||
# TODO: no step here, so LLM call won't be saved as an artifact
|
||||
json_response = await app.LLM_API_HANDLER(prompt=prompt)
|
||||
if json_response.get("error"):
|
||||
raise FailedToParseActionInstruction(
|
||||
reason=json_response.get("thought"), error_type=json_response.get("error")
|
||||
)
|
||||
|
||||
action_type: str = json_response.get("action_type") or ""
|
||||
action_type = ActionType[action_type.upper()]
|
||||
|
||||
prompt_template = ""
|
||||
if action_type == ActionType.CLICK:
|
||||
prompt_template = TaskPromptTemplate.SingleClickAction
|
||||
elif action_type == ActionType.INPUT_TEXT:
|
||||
prompt_template = TaskPromptTemplate.SingleInputAction
|
||||
elif action_type == ActionType.UPLOAD_FILE:
|
||||
prompt_template = TaskPromptTemplate.SingleUploadAction
|
||||
elif action_type == ActionType.SELECT_OPTION:
|
||||
prompt_template = TaskPromptTemplate.SingleSelectAction
|
||||
|
||||
if not prompt_template:
|
||||
raise Exception(
|
||||
f"Not supported action for action block. Currently we only support [click, input_text, upload_file, select_option], but got [{action_type}]"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.build_block_result(
|
||||
success=False, failure_reason=str(e), output_parameter_value=None, status=BlockStatus.failed
|
||||
)
|
||||
|
||||
self.prompt_template = prompt_template
|
||||
return await super().execute(workflow_run_id=workflow_run_id, kwargs=kwargs)
|
||||
|
||||
|
||||
class NavigationBlock(BaseTaskBlock):
|
||||
block_type: Literal[BlockType.NAVIGATION] = BlockType.NAVIGATION
|
||||
|
||||
@@ -18,7 +18,7 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.models import Organization, Step
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
@@ -1354,7 +1354,7 @@ class WorkflowService:
|
||||
|
||||
return ValidationBlock(
|
||||
label=block_yaml.label,
|
||||
prompt_template=TaskPromptTemplate.DecisiveCriterionValidate,
|
||||
task_type=TaskType.validation,
|
||||
parameters=validation_block_parameters,
|
||||
output_parameter=output_parameter,
|
||||
complete_criterion=block_yaml.complete_criterion,
|
||||
@@ -1379,6 +1379,7 @@ class WorkflowService:
|
||||
label=block_yaml.label,
|
||||
url=block_yaml.url,
|
||||
title=block_yaml.title,
|
||||
task_type=TaskType.action,
|
||||
parameters=action_block_parameters,
|
||||
output_parameter=output_parameter,
|
||||
navigation_goal=block_yaml.navigation_goal,
|
||||
|
||||
Reference in New Issue
Block a user