prompted workflows: use (nav block, [extract block],) for v1 prompts (#3658)
This commit is contained in:
@@ -53,6 +53,7 @@ from skyvern.forge.sdk.schemas.organizations import (
|
||||
OrganizationUpdate,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.prompts import CreateFromPromptRequest
|
||||
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
|
||||
from skyvern.forge.sdk.schemas.tasks import (
|
||||
@@ -535,18 +536,21 @@ async def create_workflow(
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def create_workflow_from_prompt(
|
||||
data: TaskV2Request,
|
||||
data: CreateFromPromptRequest,
|
||||
organization: Organization = Depends(org_auth_service.get_current_org),
|
||||
x_max_iterations_override: Annotated[int | str | None, Header()] = None,
|
||||
x_max_steps_override: Annotated[int | str | None, Header()] = None,
|
||||
) -> dict[str, Any]:
|
||||
task_version = data.task_version or "v2"
|
||||
request = data.request
|
||||
|
||||
if x_max_iterations_override or x_max_steps_override:
|
||||
LOG.info(
|
||||
"Overriding max steps for workflow-from-prompt",
|
||||
max_iterations_override=x_max_iterations_override,
|
||||
max_steps_override=x_max_steps_override,
|
||||
)
|
||||
await PermissionCheckerFactory.get_instance().check(organization, browser_session_id=data.browser_session_id)
|
||||
await PermissionCheckerFactory.get_instance().check(organization, browser_session_id=request.browser_session_id)
|
||||
|
||||
if isinstance(x_max_iterations_override, str):
|
||||
try:
|
||||
@@ -559,21 +563,23 @@ async def create_workflow_from_prompt(
|
||||
x_max_steps_override = int(x_max_steps_override)
|
||||
except ValueError:
|
||||
x_max_steps_override = None
|
||||
|
||||
try:
|
||||
workflow = await app.WORKFLOW_SERVICE.create_workflow_from_prompt(
|
||||
organization=organization,
|
||||
user_prompt=data.user_prompt,
|
||||
totp_identifier=data.totp_identifier,
|
||||
totp_verification_url=data.totp_verification_url,
|
||||
webhook_callback_url=data.webhook_callback_url,
|
||||
proxy_location=data.proxy_location,
|
||||
max_screenshot_scrolling_times=data.max_screenshot_scrolls,
|
||||
extra_http_headers=data.extra_http_headers,
|
||||
user_prompt=request.user_prompt,
|
||||
totp_identifier=request.totp_identifier,
|
||||
totp_verification_url=request.totp_verification_url,
|
||||
webhook_callback_url=request.webhook_callback_url,
|
||||
proxy_location=request.proxy_location,
|
||||
max_screenshot_scrolling_times=request.max_screenshot_scrolls,
|
||||
extra_http_headers=request.extra_http_headers,
|
||||
max_iterations=x_max_iterations_override,
|
||||
max_steps=x_max_steps_override,
|
||||
status=WorkflowStatus.published if data.publish_workflow else WorkflowStatus.auto_generated,
|
||||
run_with=data.run_with,
|
||||
ai_fallback=data.ai_fallback,
|
||||
status=WorkflowStatus.published if request.publish_workflow else WorkflowStatus.auto_generated,
|
||||
run_with=request.run_with,
|
||||
ai_fallback=request.ai_fallback if request.ai_fallback is not None else True,
|
||||
task_version=task_version,
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.error("Failed to create workflow from prompt", exc_info=True, organization_id=organization.organization_id)
|
||||
|
||||
21
skyvern/forge/sdk/schemas/prompts.py
Normal file
21
skyvern/forge/sdk/schemas/prompts.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import typing as t
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
|
||||
from skyvern.forge.sdk.schemas.tasks import PromptedTaskRequest
|
||||
|
||||
|
||||
class CreateWorkflowFromPromptRequestV1(BaseModel):
|
||||
task_version: t.Literal["v1"]
|
||||
request: PromptedTaskRequest
|
||||
|
||||
|
||||
class CreateWorkflowFromPromptRequestV2(BaseModel):
|
||||
task_version: t.Literal["v2"]
|
||||
request: TaskV2Request
|
||||
|
||||
|
||||
CreateFromPromptRequest = t.Annotated[
|
||||
t.Union[CreateWorkflowFromPromptRequestV1, CreateWorkflowFromPromptRequestV2], Field(discriminator="task_version")
|
||||
]
|
||||
@@ -160,6 +160,29 @@ class TaskRequest(TaskBase):
|
||||
return validate_url(url)
|
||||
|
||||
|
||||
class PromptedTaskRequest(TaskRequest):
|
||||
ai_fallback: bool | None = Field(
|
||||
default=False,
|
||||
description="Whether to use AI fallback when the task fails.",
|
||||
examples=[True, False],
|
||||
)
|
||||
publish_workflow: bool | None = Field(
|
||||
default=False,
|
||||
description="Whether to publish the workflow created from the prompt.",
|
||||
examples=[True, False],
|
||||
)
|
||||
run_with: str | None = Field(
|
||||
default=None,
|
||||
description="The executor to run the task with.",
|
||||
examples=["code", "agent"],
|
||||
)
|
||||
user_prompt: str = Field(
|
||||
...,
|
||||
description="The user's prompt for the task.",
|
||||
examples=["Get a quote for car insurance"],
|
||||
)
|
||||
|
||||
|
||||
class TaskStatus(StrEnum):
|
||||
created = "created"
|
||||
queued = "queued"
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import uuid
|
||||
from collections import deque
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
@@ -761,6 +761,7 @@ class WorkflowService:
|
||||
status: WorkflowStatus = WorkflowStatus.auto_generated,
|
||||
run_with: str | None = None,
|
||||
ai_fallback: bool = True,
|
||||
task_version: Literal["v1", "v2"] = "v2",
|
||||
) -> Workflow:
|
||||
metadata_prompt = prompt_engine.load_prompt(
|
||||
"conversational_ui_goal",
|
||||
@@ -776,25 +777,82 @@ class WorkflowService:
|
||||
block_label: str = metadata_response.get("block_label", DEFAULT_FIRST_BLOCK_LABEL)
|
||||
title: str = metadata_response.get("title", DEFAULT_WORKFLOW_TITLE)
|
||||
|
||||
task_v2_block = TaskV2Block(
|
||||
prompt=user_prompt,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_verification_url=totp_verification_url,
|
||||
label=block_label,
|
||||
max_iterations=max_iterations or settings.MAX_ITERATIONS_PER_TASK_V2,
|
||||
max_steps=max_steps or settings.MAX_STEPS_PER_TASK_V2,
|
||||
output_parameter=OutputParameter(
|
||||
output_parameter_id=str(uuid.uuid4()),
|
||||
key=f"{block_label}_output",
|
||||
workflow_id="",
|
||||
created_at=datetime.now(UTC),
|
||||
modified_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
if task_version == "v1":
|
||||
task_prompt = prompt_engine.load_prompt(
|
||||
"generate-task",
|
||||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
task_response = await app.LLM_API_HANDLER(
|
||||
prompt=task_prompt,
|
||||
prompt_name="generate-task",
|
||||
organization_id=organization.organization_id,
|
||||
)
|
||||
|
||||
data_extraction_goal: str | None = task_response.get("data_extraction_goal")
|
||||
navigation_goal: str = task_response.get("navigation_goal", user_prompt)
|
||||
url: str = task_response.get("url", "")
|
||||
|
||||
blocks = [
|
||||
NavigationBlock(
|
||||
url=url,
|
||||
label=block_label,
|
||||
title=title,
|
||||
navigation_goal=navigation_goal,
|
||||
max_steps_per_run=max_steps or settings.MAX_STEPS_PER_RUN,
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=totp_identifier,
|
||||
output_parameter=OutputParameter(
|
||||
output_parameter_id=str(uuid.uuid4()),
|
||||
key=f"{block_label}_output",
|
||||
workflow_id="",
|
||||
created_at=datetime.now(UTC),
|
||||
modified_at=datetime.now(UTC),
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
if data_extraction_goal:
|
||||
blocks.append(
|
||||
ExtractionBlock(
|
||||
label="extract_data",
|
||||
title="Extract Data",
|
||||
data_extraction_goal=data_extraction_goal,
|
||||
output_parameter=OutputParameter(
|
||||
output_parameter_id=str(uuid.uuid4()),
|
||||
key="extract_data_output",
|
||||
workflow_id="",
|
||||
created_at=datetime.now(UTC),
|
||||
modified_at=datetime.now(UTC),
|
||||
),
|
||||
max_steps_per_run=max_steps or settings.MAX_STEPS_PER_RUN,
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=totp_identifier,
|
||||
)
|
||||
)
|
||||
|
||||
elif task_version == "v2":
|
||||
blocks = [
|
||||
TaskV2Block(
|
||||
prompt=user_prompt,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_verification_url=totp_verification_url,
|
||||
label=block_label,
|
||||
max_iterations=max_iterations or settings.MAX_ITERATIONS_PER_TASK_V2,
|
||||
max_steps=max_steps or settings.MAX_STEPS_PER_TASK_V2,
|
||||
output_parameter=OutputParameter(
|
||||
output_parameter_id=str(uuid.uuid4()),
|
||||
key=f"{block_label}_output",
|
||||
workflow_id="",
|
||||
created_at=datetime.now(UTC),
|
||||
modified_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
new_workflow = await self.create_workflow(
|
||||
title=title,
|
||||
workflow_definition=WorkflowDefinition(parameters=[], blocks=[task_v2_block]),
|
||||
workflow_definition=WorkflowDefinition(parameters=[], blocks=blocks),
|
||||
organization_id=organization.organization_id,
|
||||
proxy_location=proxy_location,
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
|
||||
Reference in New Issue
Block a user