use task type instead of prompt template (#1261)

This commit is contained in:
LawyZheng
2024-11-26 11:29:33 +08:00
committed by GitHub
parent 74a9fc70d6
commit 6b417d0e83
11 changed files with 124 additions and 69 deletions

View File

@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from skyvern.config import settings
from skyvern.exceptions import WorkflowParameterNotFound
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskPromptTemplate
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
ActionModel,
@@ -113,13 +113,13 @@ class AgentDB:
retry: int | None = None,
max_steps_per_run: int | None = None,
error_code_mapping: dict[str, str] | None = None,
prompt_template: str = TaskPromptTemplate.ExtractAction,
task_type: str = TaskType.general,
) -> Task:
try:
async with self.Session() as session:
new_task = TaskModel(
status="created",
prompt_template=prompt_template,
task_type=task_type,
url=url,
title=title,
webhook_callback_url=webhook_callback_url,

View File

@@ -5,10 +5,7 @@ class OrganizationAuthTokenType(StrEnum):
api = "api"
class TaskPromptTemplate(StrEnum):
ExtractAction = "extract-action"
DecisiveCriterionValidate = "decisive-criterion-validate"
SingleClickAction = "single-click-action"
SingleInputAction = "single-input-action"
SingleUploadAction = "single-upload-action"
SingleSelectAction = "single-select-action"
class TaskType(StrEnum):
general = "general"
validation = "validation"
action = "action"

View File

@@ -17,7 +17,7 @@ from sqlalchemy import (
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskPromptTemplate
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
from skyvern.forge.sdk.db.id import (
generate_action_id,
generate_artifact_id,
@@ -54,7 +54,7 @@ class TaskModel(Base):
totp_verification_url = Column(String)
totp_identifier = Column(String)
title = Column(String)
prompt_template = Column(String, default=TaskPromptTemplate.ExtractAction)
task_type = Column(String, default=TaskType.general)
url = Column(String)
navigation_goal = Column(String)
data_extraction_goal = Column(String)

View File

@@ -60,7 +60,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
status=TaskStatus(task_obj.status),
created_at=task_obj.created_at,
modified_at=task_obj.modified_at,
prompt_template=task_obj.prompt_template,
task_type=task_obj.task_type,
title=task_obj.title,
url=task_obj.url,
complete_criterion=task_obj.complete_criterion,

View File

@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, HttpUrl, field_validator
from skyvern.exceptions import BlockedHost, InvalidTaskStatusTransition, TaskAlreadyCanceled
from skyvern.forge.sdk.core.validators import is_blocked_host
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
from skyvern.forge.sdk.db.enums import TaskType
class ProxyLocation(StrEnum):
@@ -86,10 +86,10 @@ class TaskBase(BaseModel):
description="Criterion to terminate",
examples=["Terminate if 'existing account' shows up on the page"],
)
prompt_template: str | None = Field(
default=TaskPromptTemplate.ExtractAction,
description="The prompt template used for task",
examples=[TaskPromptTemplate.ExtractAction, TaskPromptTemplate.DecisiveCriterionValidate],
task_type: TaskType | None = Field(
default=TaskType.general,
description="The type of the task",
examples=[TaskType.general, TaskType.validation],
)

View File

@@ -109,13 +109,6 @@ class WorkflowParameterMissingRequiredValue(BaseWorkflowHTTPException):
)
class FailedToParseActionInstruction(SkyvernException):
    """Error reporting that an action instruction could not be parsed.

    Args:
        reason: Explanation of the failure; rendered verbatim in the message
            (``None`` is rendered as the text ``None``).
        error_type: Error label; rendered in parentheses after ``reason``.
    """

    def __init__(self, reason: str | None, error_type: str | None):
        message = f"Failed to parse the action instruction as '{reason}({error_type})'"
        super().__init__(message)
class InvalidWaitBlockTime(SkyvernException):
    """Error reporting an invalid wait time for a wait block.

    Args:
        max_sec: Upper bound, in seconds, quoted in the error message as the
            largest accepted wait time.
    """

    def __init__(self, max_sec: int):
        message = f"Invalid wait time for wait block, it should be a number between 0 and {max_sec}."
        super().__init__(message)

View File

@@ -41,12 +41,11 @@ from skyvern.forge.sdk.api.files import (
get_path_for_workflow_download_directory,
)
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.exceptions import (
FailedToParseActionInstruction,
InvalidEmailClientConfiguration,
InvalidFileType,
NoValidEmailRecipient,
@@ -58,7 +57,6 @@ from skyvern.forge.sdk.workflow.models.parameter import (
OutputParameter,
WorkflowParameter,
)
from skyvern.webeye.actions.actions import ActionType
from skyvern.webeye.browser_factory import BrowserState
LOG = structlog.get_logger()
@@ -185,7 +183,7 @@ class Block(BaseModel, abc.ABC):
class BaseTaskBlock(Block):
prompt_template: str = TaskPromptTemplate.ExtractAction
task_type: str = TaskType.general
url: str | None = None
title: str = ""
complete_criterion: str | None = None
@@ -1333,41 +1331,6 @@ class ValidationBlock(BaseTaskBlock):
class ActionBlock(BaseTaskBlock):
    """Task block that runs with a single-action prompt template.

    Before delegating to ``BaseTaskBlock.execute`` the block asks the LLM which
    action type ``navigation_goal`` describes and selects the matching
    single-action prompt template.
    """

    block_type: Literal[BlockType.ACTION] = BlockType.ACTION

    async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
        """Infer the action type, choose its prompt template, then run the task.

        Args:
            workflow_run_id: Identifier of the workflow run, forwarded to the
                base implementation.
            **kwargs: Extra arguments forwarded to the base implementation.

        Returns:
            A failed ``BlockResult`` (``BlockStatus.failed``, with the error
            text as ``failure_reason``) when the LLM reports an error, names
            an unknown or unsupported action type, or anything else raises
            during inference; otherwise whatever ``BaseTaskBlock.execute``
            returns.
        """
        try:
            prompt = prompt_engine.load_prompt("infer-action-type", navigation_goal=self.navigation_goal)
            # TODO: no step here, so LLM call won't be saved as an artifact
            json_response = await app.LLM_API_HANDLER(prompt=prompt)
            if json_response.get("error"):
                raise FailedToParseActionInstruction(
                    reason=json_response.get("thought"), error_type=json_response.get("error")
                )

            raw_action_type: str = json_response.get("action_type") or ""
            # Look the member up by name without raising: an unknown or empty
            # name must end in the readable "not supported" failure below
            # rather than in a bare KeyError whose text is just the name.
            action_type = ActionType.__members__.get(raw_action_type.upper())

            # Only these action types have a dedicated single-action template.
            supported_templates = {
                ActionType.CLICK: TaskPromptTemplate.SingleClickAction,
                ActionType.INPUT_TEXT: TaskPromptTemplate.SingleInputAction,
                ActionType.UPLOAD_FILE: TaskPromptTemplate.SingleUploadAction,
                ActionType.SELECT_OPTION: TaskPromptTemplate.SingleSelectAction,
            }
            prompt_template = supported_templates.get(action_type) if action_type is not None else None

            if not prompt_template:
                # Report the resolved member when there is one, else the raw LLM output.
                reported = action_type if action_type is not None else raw_action_type
                raise Exception(
                    f"Not supported action for action block. Currently we only support [click, input_text, upload_file, select_option], but got [{reported}]"
                )
        except Exception as e:
            # Inference problems fail the block instead of propagating to the caller.
            return self.build_block_result(
                success=False, failure_reason=str(e), output_parameter_value=None, status=BlockStatus.failed
            )

        self.prompt_template = prompt_template
        # NOTE(review): kwargs is forwarded as one keyword named "kwargs", not
        # unpacked with **kwargs; kept as in the original — confirm this is
        # what BaseTaskBlock.execute expects.
        return await super().execute(workflow_run_id=workflow_run_id, kwargs=kwargs)
class NavigationBlock(BaseTaskBlock):
block_type: Literal[BlockType.NAVIGATION] = BlockType.NAVIGATION

View File

@@ -18,7 +18,7 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import TaskPromptTemplate
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.models import Organization, Step
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task
from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -1354,7 +1354,7 @@ class WorkflowService:
return ValidationBlock(
label=block_yaml.label,
prompt_template=TaskPromptTemplate.DecisiveCriterionValidate,
task_type=TaskType.validation,
parameters=validation_block_parameters,
output_parameter=output_parameter,
complete_criterion=block_yaml.complete_criterion,
@@ -1379,6 +1379,7 @@ class WorkflowService:
label=block_yaml.label,
url=block_yaml.url,
title=block_yaml.title,
task_type=TaskType.action,
parameters=action_block_parameters,
output_parameter=output_parameter,
navigation_goal=block_yaml.navigation_goal,