add observer task block (#1665)
This commit is contained in:
@@ -48,7 +48,8 @@ from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core.validators import prepend_scheme_and_validate_url
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverTaskStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskOutput, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.exceptions import (
|
||||
FailedToFormatJinjaStyleParameter,
|
||||
@@ -71,6 +72,7 @@ LOG = structlog.get_logger()
|
||||
|
||||
class BlockType(StrEnum):
|
||||
TASK = "task"
|
||||
TaskV2 = "task_v2"
|
||||
FOR_LOOP = "for_loop"
|
||||
CODE = "code"
|
||||
TEXT_PROMPT = "text_prompt"
|
||||
@@ -2072,6 +2074,80 @@ class UrlBlock(BaseTaskBlock):
|
||||
url: str
|
||||
|
||||
|
||||
# observer block
|
||||
class TaskV2Block(Block):
|
||||
block_type: Literal[BlockType.TaskV2] = BlockType.TaskV2
|
||||
prompt: str
|
||||
url: str | None = None
|
||||
totp_verification_url: str | None = None
|
||||
totp_identifier: str | None = None
|
||||
max_iterations: int = 10
|
||||
|
||||
def get_all_parameters(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
) -> list[PARAMETER_TYPE]:
|
||||
return []
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
workflow_run_block_id: str,
|
||||
organization_id: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
**kwargs: dict,
|
||||
) -> BlockResult:
|
||||
from skyvern.forge.sdk.services import observer_service
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
|
||||
if not organization_id:
|
||||
raise ValueError("Running TaskV2Block requires organization_id")
|
||||
|
||||
organization = await app.DATABASE.get_organization(organization_id)
|
||||
if not organization:
|
||||
raise ValueError(f"Organization not found {organization_id}")
|
||||
observer_task = await observer_service.initialize_observer_task(
|
||||
organization,
|
||||
user_prompt=self.prompt,
|
||||
user_url=self.url,
|
||||
parent_workflow_run_id=workflow_run_id,
|
||||
proxy_location=ProxyLocation.NONE,
|
||||
)
|
||||
await app.DATABASE.update_observer_cruise(
|
||||
observer_task.observer_cruise_id, status=ObserverTaskStatus.queued, organization_id=organization_id
|
||||
)
|
||||
if observer_task.workflow_run_id:
|
||||
await app.DATABASE.update_workflow_run(
|
||||
workflow_run_id=observer_task.workflow_run_id,
|
||||
status=WorkflowRunStatus.queued,
|
||||
)
|
||||
await app.DATABASE.update_workflow_run_block(
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
block_workflow_run_id=observer_task.workflow_run_id,
|
||||
)
|
||||
|
||||
observer_task = await observer_service.run_observer_task(
|
||||
organization=organization,
|
||||
observer_cruise_id=observer_task.observer_cruise_id,
|
||||
request_id=None,
|
||||
max_iterations_override=self.max_iterations,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
result_dict = None
|
||||
if observer_task:
|
||||
result_dict = observer_task.output
|
||||
|
||||
return await self.build_block_result(
|
||||
success=True,
|
||||
failure_reason=None,
|
||||
output_parameter_value=result_dict,
|
||||
status=BlockStatus.completed,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
|
||||
BlockSubclasses = Union[
|
||||
ForLoopBlock,
|
||||
TaskBlock,
|
||||
@@ -2090,5 +2166,6 @@ BlockSubclasses = Union[
|
||||
WaitBlock,
|
||||
FileDownloadBlock,
|
||||
UrlBlock,
|
||||
TaskV2Block,
|
||||
]
|
||||
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||
|
||||
@@ -108,6 +108,7 @@ class WorkflowRun(BaseModel):
|
||||
totp_verification_url: str | None = None
|
||||
totp_identifier: str | None = None
|
||||
failure_reason: str | None = None
|
||||
parent_workflow_run_id: str | None = None
|
||||
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
|
||||
@@ -323,6 +323,15 @@ class UrlBlockYAML(BlockYAML):
|
||||
url: str
|
||||
|
||||
|
||||
class TaskV2BlockYAML(BlockYAML):
|
||||
block_type: Literal[BlockType.TaskV2] = BlockType.TaskV2 # type: ignore
|
||||
prompt: str
|
||||
url: str | None = None
|
||||
totp_verification_url: str | None = None
|
||||
totp_identifier: str | None = None
|
||||
max_iterations: int = 10
|
||||
|
||||
|
||||
PARAMETER_YAML_SUBCLASSES = (
|
||||
AWSSecretParameterYAML
|
||||
| BitwardenLoginCredentialParameterYAML
|
||||
@@ -352,6 +361,7 @@ BLOCK_YAML_SUBCLASSES = (
|
||||
| FileDownloadBlockYAML
|
||||
| UrlBlockYAML
|
||||
| PDFParserBlockYAML
|
||||
| TaskV2BlockYAML
|
||||
)
|
||||
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ from skyvern.forge.sdk.workflow.models.block import (
|
||||
PDFParserBlock,
|
||||
SendEmailBlock,
|
||||
TaskBlock,
|
||||
TaskV2Block,
|
||||
TextPromptBlock,
|
||||
UploadToS3Block,
|
||||
ValidationBlock,
|
||||
@@ -96,6 +97,7 @@ class WorkflowService:
|
||||
is_template_workflow: bool = False,
|
||||
version: int | None = None,
|
||||
max_steps_override: int | None = None,
|
||||
parent_workflow_run_id: str | None = None,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Create a workflow run and its parameters. Validate the workflow and the organization. If there are missing
|
||||
@@ -127,6 +129,7 @@ class WorkflowService:
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
workflow_id=workflow_id,
|
||||
organization_id=organization_id,
|
||||
parent_workflow_run_id=parent_workflow_run_id,
|
||||
)
|
||||
LOG.info(
|
||||
f"Created workflow run {workflow_run.workflow_run_id} for workflow {workflow.workflow_id}",
|
||||
@@ -625,7 +628,12 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
async def create_workflow_run(
|
||||
self, workflow_request: WorkflowRequestBody, workflow_permanent_id: str, workflow_id: str, organization_id: str
|
||||
self,
|
||||
workflow_request: WorkflowRequestBody,
|
||||
workflow_permanent_id: str,
|
||||
workflow_id: str,
|
||||
organization_id: str,
|
||||
parent_workflow_run_id: str | None = None,
|
||||
) -> WorkflowRun:
|
||||
return await app.DATABASE.create_workflow_run(
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
@@ -635,6 +643,7 @@ class WorkflowService:
|
||||
webhook_callback_url=workflow_request.webhook_callback_url,
|
||||
totp_verification_url=workflow_request.totp_verification_url,
|
||||
totp_identifier=workflow_request.totp_identifier,
|
||||
parent_workflow_run_id=parent_workflow_run_id,
|
||||
)
|
||||
|
||||
async def mark_workflow_run_as_completed(self, workflow_run_id: str) -> None:
|
||||
@@ -1731,6 +1740,16 @@ class WorkflowService:
|
||||
cache_actions=block_yaml.cache_actions,
|
||||
complete_on_download=True,
|
||||
)
|
||||
elif block_yaml.block_type == BlockType.TaskV2:
|
||||
return TaskV2Block(
|
||||
label=block_yaml.label,
|
||||
prompt=block_yaml.prompt,
|
||||
url=block_yaml.url,
|
||||
totp_verification_url=block_yaml.totp_verification_url,
|
||||
totp_identifier=block_yaml.totp_identifier,
|
||||
max_iterations=block_yaml.max_iterations,
|
||||
output_parameter=output_parameter,
|
||||
)
|
||||
|
||||
raise ValueError(f"Invalid block type {block_yaml.block_type}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user