script gen: extract action support (#3238)
This commit is contained in:
@@ -121,6 +121,7 @@ class Settings(BaseSettings):
|
||||
SINGLE_CLICK_AGENT_LLM_KEY: str | None = None
|
||||
SINGLE_INPUT_AGENT_LLM_KEY: str | None = None
|
||||
PROMPT_BLOCK_LLM_KEY: str | None = None
|
||||
EXTRACTION_LLM_KEY: str | None = None
|
||||
# COMMON
|
||||
LLM_CONFIG_TIMEOUT: int = 300
|
||||
LLM_CONFIG_MAX_TOKENS: int = 4096
|
||||
|
||||
@@ -151,11 +151,15 @@ def _make_decorator(block_label: str, block: dict[str, Any]) -> cst.Decorator:
|
||||
)
|
||||
|
||||
|
||||
def _action_to_stmt(act: dict[str, Any]) -> cst.BaseStatement:
|
||||
def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.BaseStatement:
|
||||
"""
|
||||
Turn one Action dict into:
|
||||
|
||||
await page.<method>(xpath=..., intention=..., data=context.parameters)
|
||||
|
||||
Or if assign_to_output is True for extract actions:
|
||||
|
||||
output = await page.extract(...)
|
||||
"""
|
||||
method = ACTION_MAP[act["action_type"]]
|
||||
|
||||
@@ -248,13 +252,23 @@ def _action_to_stmt(act: dict[str, Any]) -> cst.BaseStatement:
|
||||
# await page.method(...)
|
||||
await_expr = cst.Await(call)
|
||||
|
||||
# Wrap in a statement line: await ...
|
||||
return cst.SimpleStatementLine([cst.Expr(await_expr)])
|
||||
# If this is an extract action and we want to assign to output
|
||||
if assign_to_output and method == "extract":
|
||||
# output = await page.extract(...)
|
||||
assign = cst.Assign(
|
||||
targets=[cst.AssignTarget(cst.Name("output"))],
|
||||
value=await_expr,
|
||||
)
|
||||
return cst.SimpleStatementLine([assign])
|
||||
else:
|
||||
# Wrap in a statement line: await ...
|
||||
return cst.SimpleStatementLine([cst.Expr(await_expr)])
|
||||
|
||||
|
||||
def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> FunctionDef:
|
||||
name = block.get("label") or _safe_name(block.get("title") or f"block_{block.get('workflow_run_block_id')}")
|
||||
body_stmts: list[cst.BaseStatement] = []
|
||||
is_extraction_block = block.get("block_type") == "extraction"
|
||||
|
||||
if block.get("url"):
|
||||
body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})"))
|
||||
@@ -262,9 +276,19 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
|
||||
for act in actions:
|
||||
if act["action_type"] in [ActionType.COMPLETE, ActionType.TERMINATE, ActionType.NULL_ACTION]:
|
||||
continue
|
||||
body_stmts.append(_action_to_stmt(act))
|
||||
|
||||
if not body_stmts:
|
||||
# For extraction blocks, assign extract action results to output variable
|
||||
assign_to_output = is_extraction_block and act["action_type"] == "extract"
|
||||
body_stmts.append(_action_to_stmt(act, assign_to_output=assign_to_output))
|
||||
|
||||
# For extraction blocks, add return output statement if we have actions
|
||||
if is_extraction_block and any(
|
||||
act["action_type"] == "extract"
|
||||
for act in actions
|
||||
if act["action_type"] not in [ActionType.COMPLETE, ActionType.TERMINATE, ActionType.NULL_ACTION]
|
||||
):
|
||||
body_stmts.append(cst.parse_statement("return output"))
|
||||
elif not body_stmts:
|
||||
body_stmts.append(cst.parse_statement("return None"))
|
||||
|
||||
return FunctionDef(
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from enum import StrEnum
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
@@ -16,6 +16,7 @@ from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.files import download_file
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.utils.prompt_engine import load_prompt_with_elements
|
||||
from skyvern.webeye.actions import handler_utils
|
||||
from skyvern.webeye.actions.action_types import ActionType
|
||||
from skyvern.webeye.actions.actions import Action, ActionStatus
|
||||
@@ -156,7 +157,7 @@ class SkyvernPage:
|
||||
raise
|
||||
finally:
|
||||
skyvern_page._record(call)
|
||||
# Auto-create action before execution
|
||||
# Auto-create action after execution
|
||||
await skyvern_page._create_action_before_execution(
|
||||
action_type=action,
|
||||
intention=intention,
|
||||
@@ -418,10 +419,43 @@ class SkyvernPage:
|
||||
|
||||
@action_wrap(ActionType.EXTRACT)
|
||||
async def extract(
|
||||
self, data_extraction_goal: str, intention: str | None = None, data: str | dict[str, Any] | None = None
|
||||
) -> None:
|
||||
# TODO: extract the data
|
||||
return
|
||||
self,
|
||||
data_extraction_goal: str,
|
||||
data_schema: dict[str, Any] | list | str | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
intention: str | None = None,
|
||||
data: str | dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
scraped_page_refreshed = await self.scraped_page.refresh()
|
||||
context = skyvern_context.current()
|
||||
tz_info = datetime.now(tz=timezone.utc).tzinfo
|
||||
if context and context.tz_info:
|
||||
tz_info = context.tz_info
|
||||
extract_information_prompt = load_prompt_with_elements(
|
||||
element_tree_builder=scraped_page_refreshed,
|
||||
prompt_engine=prompt_engine,
|
||||
template_name="extract-information",
|
||||
html_need_skyvern_attrs=False,
|
||||
data_extraction_goal=data_extraction_goal,
|
||||
extracted_information_schema=data_schema,
|
||||
current_url=scraped_page_refreshed.url,
|
||||
extracted_text=scraped_page_refreshed.extracted_text,
|
||||
error_code_mapping_str=(json.dumps(error_code_mapping) if error_code_mapping else None),
|
||||
local_datetime=datetime.now(tz_info).isoformat(),
|
||||
)
|
||||
step = None
|
||||
if context and context.organization_id and context.task_id and context.step_id:
|
||||
step = await app.DATABASE.get_step(
|
||||
task_id=context.task_id, step_id=context.step_id, organization_id=context.organization_id
|
||||
)
|
||||
|
||||
result = await app.EXTRACTION_LLM_API_HANDLER(
|
||||
prompt=extract_information_prompt,
|
||||
step=step,
|
||||
screenshots=scraped_page_refreshed.screenshots,
|
||||
prompt_name="extract-information",
|
||||
)
|
||||
return result
|
||||
|
||||
@action_wrap(ActionType.VERIFICATION_CODE)
|
||||
async def verification_code(
|
||||
|
||||
@@ -74,6 +74,11 @@ SINGLE_INPUT_AGENT_LLM_API_HANDLER = (
|
||||
if SETTINGS_MANAGER.SINGLE_INPUT_AGENT_LLM_KEY
|
||||
else SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
EXTRACTION_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.EXTRACTION_LLM_KEY)
|
||||
if SETTINGS_MANAGER.EXTRACTION_LLM_KEY
|
||||
else LLM_API_HANDLER
|
||||
)
|
||||
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
||||
WORKFLOW_SERVICE = WorkflowService()
|
||||
AGENT_FUNCTION = AgentFunction()
|
||||
|
||||
@@ -4,17 +4,19 @@ import hashlib
|
||||
import importlib.util
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import structlog
|
||||
from fastapi import BackgroundTasks, HTTPException
|
||||
|
||||
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT
|
||||
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
|
||||
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
|
||||
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskOutput, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockStatus, BlockType
|
||||
from skyvern.schemas.scripts import CreateScriptResponse, FileNode, ScriptFileCreate
|
||||
|
||||
@@ -314,31 +316,75 @@ async def _create_workflow_block_run_and_task(
|
||||
return None, None
|
||||
|
||||
|
||||
async def _update_workflow_block_status(
|
||||
async def _record_output_parameter_value(
|
||||
workflow_run_id: str,
|
||||
output: dict[str, Any] | list | str | None,
|
||||
) -> None:
|
||||
# TODO support this in the future
|
||||
# workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
|
||||
# await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
# parameter=self.output_parameter,
|
||||
# value=value,
|
||||
# )
|
||||
# await app.DATABASE.create_or_update_workflow_run_output_parameter(
|
||||
# workflow_run_id=workflow_run_id,
|
||||
# output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
# value=value,
|
||||
# )
|
||||
return
|
||||
|
||||
|
||||
async def _update_workflow_block(
|
||||
workflow_run_block_id: str,
|
||||
status: BlockStatus,
|
||||
task_id: str | None = None,
|
||||
task_status: TaskStatus = TaskStatus.completed,
|
||||
failure_reason: str | None = None,
|
||||
output: dict[str, Any] | list | str | None = None,
|
||||
) -> None:
|
||||
"""Update the status of a workflow run block."""
|
||||
try:
|
||||
context = skyvern_context.current()
|
||||
if not context or not context.organization_id:
|
||||
if not context or not context.organization_id or not context.workflow_run_id:
|
||||
return
|
||||
await app.DATABASE.update_workflow_run_block(
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=context.organization_id if context else None,
|
||||
status=status,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
final_output = output
|
||||
if task_id:
|
||||
await app.DATABASE.update_task(
|
||||
updated_task = await app.DATABASE.update_task(
|
||||
task_id=task_id,
|
||||
organization_id=context.organization_id,
|
||||
status=task_status,
|
||||
failure_reason=failure_reason,
|
||||
extracted_information=output,
|
||||
)
|
||||
downloaded_files: list[FileInfo] = []
|
||||
try:
|
||||
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
|
||||
downloaded_files = await app.STORAGE.get_downloaded_files(
|
||||
organization_id=context.organization_id,
|
||||
run_id=context.workflow_run_id,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
LOG.warning("Timeout getting downloaded files", task_id=task_id)
|
||||
|
||||
task_output = TaskOutput.from_task(updated_task, downloaded_files)
|
||||
final_output = task_output.model_dump()
|
||||
await app.DATABASE.update_workflow_run_block(
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=context.organization_id if context else None,
|
||||
status=status,
|
||||
failure_reason=failure_reason,
|
||||
output=final_output,
|
||||
)
|
||||
else:
|
||||
final_output = None
|
||||
await app.DATABASE.update_workflow_run_block(
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=context.organization_id if context else None,
|
||||
status=status,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
await _record_output_parameter_value(context.workflow_run_id, final_output)
|
||||
|
||||
except Exception as e:
|
||||
LOG.warning(
|
||||
"Failed to update workflow block status",
|
||||
@@ -349,12 +395,12 @@ async def _update_workflow_block_status(
|
||||
)
|
||||
|
||||
|
||||
async def _run_cached_function(cache_key: str) -> None:
|
||||
async def _run_cached_function(cache_key: str) -> Any:
|
||||
cached_fn = script_run_context_manager.get_cached_fn(cache_key)
|
||||
if cached_fn:
|
||||
# TODO: handle exceptions here and fall back to AI run in case of error
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
await cached_fn(page=run_context.page, context=run_context)
|
||||
return await cached_fn(page=run_context.page, context=run_context)
|
||||
else:
|
||||
raise Exception(f"Cache key {cache_key} not found")
|
||||
|
||||
@@ -378,12 +424,12 @@ async def run_task(
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
|
||||
await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
|
||||
|
||||
except Exception as e:
|
||||
# Update block status to failed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.failed,
|
||||
task_id=task_id,
|
||||
@@ -393,7 +439,7 @@ async def run_task(
|
||||
raise
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.failed,
|
||||
task_id=task_id,
|
||||
@@ -422,12 +468,12 @@ async def download(
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
|
||||
await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
|
||||
|
||||
except Exception as e:
|
||||
# Update block status to failed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.failed,
|
||||
task_id=task_id,
|
||||
@@ -437,7 +483,7 @@ async def download(
|
||||
raise
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.failed,
|
||||
task_id=task_id,
|
||||
@@ -466,12 +512,12 @@ async def action(
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
|
||||
await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
|
||||
|
||||
except Exception as e:
|
||||
# Update block status to failed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.failed,
|
||||
task_id=task_id,
|
||||
@@ -481,7 +527,7 @@ async def action(
|
||||
raise
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.failed,
|
||||
task_id=task_id,
|
||||
@@ -510,12 +556,12 @@ async def login(
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
|
||||
await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
|
||||
|
||||
except Exception as e:
|
||||
# Update block status to failed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.failed,
|
||||
task_id=task_id,
|
||||
@@ -525,7 +571,7 @@ async def login(
|
||||
raise
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.failed,
|
||||
task_id=task_id,
|
||||
@@ -540,36 +586,44 @@ async def extract(
|
||||
url: str | None = None,
|
||||
max_steps: int | None = None,
|
||||
cache_key: str | None = None,
|
||||
) -> None:
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
# Auto-create workflow block run and task if workflow_run_id is available
|
||||
workflow_run_block_id, task_id = await _create_workflow_block_run_and_task(
|
||||
block_type=BlockType.EXTRACTION,
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
)
|
||||
output: dict[str, Any] | list | str | None = None
|
||||
|
||||
if cache_key:
|
||||
try:
|
||||
await _run_cached_function(cache_key)
|
||||
output = cast(dict[str, Any] | list | str | None, await _run_cached_function(cache_key))
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.completed,
|
||||
task_id=task_id,
|
||||
output=output,
|
||||
)
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
# Update block status to failed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.failed,
|
||||
task_id=task_id,
|
||||
task_status=TaskStatus.failed,
|
||||
failure_reason=str(e),
|
||||
output=output,
|
||||
)
|
||||
raise
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(
|
||||
await _update_workflow_block(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.failed,
|
||||
task_id=task_id,
|
||||
@@ -588,12 +642,12 @@ async def wait(seconds: int) -> None:
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(workflow_run_block_id, BlockStatus.completed)
|
||||
await _update_workflow_block(workflow_run_block_id, BlockStatus.completed)
|
||||
|
||||
except Exception as e:
|
||||
# Update block status to failed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block_status(workflow_run_block_id, BlockStatus.failed, failure_reason=str(e))
|
||||
await _update_workflow_block(workflow_run_block_id, BlockStatus.failed, failure_reason=str(e))
|
||||
raise
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user