ExtractAction (#1632)
This commit is contained in:
@@ -64,6 +64,7 @@ from skyvern.webeye.actions.actions import (
|
|||||||
CompleteAction,
|
CompleteAction,
|
||||||
CompleteVerifyResult,
|
CompleteVerifyResult,
|
||||||
DecisiveAction,
|
DecisiveAction,
|
||||||
|
ExtractAction,
|
||||||
ReloadPageAction,
|
ReloadPageAction,
|
||||||
UserDefinedError,
|
UserDefinedError,
|
||||||
WebAction,
|
WebAction,
|
||||||
@@ -721,19 +722,7 @@ class ForgeAgent:
|
|||||||
|
|
||||||
using_cached_action_plan = False
|
using_cached_action_plan = False
|
||||||
if not task.navigation_goal and not isinstance(task_block, ValidationBlock):
|
if not task.navigation_goal and not isinstance(task_block, ValidationBlock):
|
||||||
actions = [
|
actions = [await self.create_extract_action(task, step, scraped_page)]
|
||||||
CompleteAction(
|
|
||||||
reasoning="Task has no navigation goal.",
|
|
||||||
data_extraction_goal=task.data_extraction_goal,
|
|
||||||
organization_id=task.organization_id,
|
|
||||||
task_id=task.task_id,
|
|
||||||
workflow_run_id=task.workflow_run_id,
|
|
||||||
step_id=step.step_id,
|
|
||||||
step_order=step.order,
|
|
||||||
action_order=0,
|
|
||||||
confidence_float=1.0,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
elif (
|
elif (
|
||||||
task_block
|
task_block
|
||||||
and task_block.cache_actions
|
and task_block.cache_actions
|
||||||
@@ -1039,6 +1028,21 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
detailed_agent_step_output.actions_and_results.append((complete_action, complete_results))
|
detailed_agent_step_output.actions_and_results.append((complete_action, complete_results))
|
||||||
await self.record_artifacts_after_action(task, step, browser_state)
|
await self.record_artifacts_after_action(task, step, browser_state)
|
||||||
|
|
||||||
|
# if the last action is complete and is successful, check if there's a data extraction goal
|
||||||
|
# if task has navigation goal and extraction goal at the same time, handle ExtractAction before marking step as completed
|
||||||
|
if (
|
||||||
|
task.navigation_goal
|
||||||
|
and task.data_extraction_goal
|
||||||
|
and self.step_has_completed_goal(detailed_agent_step_output)
|
||||||
|
):
|
||||||
|
working_page = await browser_state.must_get_working_page()
|
||||||
|
extract_action = await self.create_extract_action(task, step, scraped_page)
|
||||||
|
extract_results = await ActionHandler.handle_action(
|
||||||
|
scraped_page, task, step, working_page, extract_action
|
||||||
|
)
|
||||||
|
detailed_agent_step_output.actions_and_results.append((extract_action, extract_results))
|
||||||
|
|
||||||
# If no action errors return the agent state and output
|
# If no action errors return the agent state and output
|
||||||
completed_step = await self.update_step(
|
completed_step = await self.update_step(
|
||||||
step=step,
|
step=step,
|
||||||
@@ -1490,6 +1494,7 @@ class ForgeAgent:
|
|||||||
"""
|
"""
|
||||||
Find the last successful ScrapeAction for the task and return the extracted information.
|
Find the last successful ScrapeAction for the task and return the extracted information.
|
||||||
"""
|
"""
|
||||||
|
# TODO: make sure we can get extracted information with the ExtractAction change
|
||||||
steps = await app.DATABASE.get_task_steps(
|
steps = await app.DATABASE.get_task_steps(
|
||||||
task_id=task.task_id,
|
task_id=task.task_id,
|
||||||
organization_id=task.organization_id,
|
organization_id=task.organization_id,
|
||||||
@@ -1500,7 +1505,7 @@ class ForgeAgent:
|
|||||||
if not step.output or not step.output.actions_and_results:
|
if not step.output or not step.output.actions_and_results:
|
||||||
continue
|
continue
|
||||||
for action, action_results in step.output.actions_and_results:
|
for action, action_results in step.output.actions_and_results:
|
||||||
if action.action_type != ActionType.COMPLETE:
|
if action.action_type != ActionType.EXTRACT:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for action_result in action_results:
|
for action_result in action_results:
|
||||||
@@ -2197,3 +2202,43 @@ class ForgeAgent:
|
|||||||
organization_id=task.organization_id,
|
organization_id=task.organization_id,
|
||||||
errors=task_errors,
|
errors=task_errors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
async def create_extract_action(task: Task, step: Step, scraped_page: ScrapedPage) -> ExtractAction:
    """
    Build the ExtractAction for the task's data extraction goal.

    Renders the "data-extraction-summary" prompt from the task's extraction goal and
    schema, the scraped page URL and the current local datetime, sends it with the page
    screenshots to the secondary LLM handler, and uses the returned "summary" value as
    the action's reasoning.

    Args:
        task: supplies the extraction goal/schema and the ids copied onto the action.
        step: passed to the LLM call; supplies step_id and order for the action.
        scraped_page: supplies the current URL and the screenshots sent to the LLM.

    Returns:
        An ExtractAction with action_order 0 and confidence_float 1.0.
    """
    context = skyvern_context.ensure_context()

    # Generate the reasoning by prompting the LLM to think briefly about what data to extract.
    prompt = prompt_engine.load_prompt(
        "data-extraction-summary",
        data_extraction_goal=task.data_extraction_goal,
        data_extraction_schema=task.extracted_information_schema,
        current_url=scraped_page.url,
        local_datetime=datetime.now(context.tz_info).isoformat(),
    )

    data_extraction_summary_resp = await app.SECONDARY_LLM_API_HANDLER(
        prompt=prompt,
        step=step,
        screenshots=scraped_page.screenshots,
    )

    # dict.get's default only applies when the key is missing; `or` also covers a
    # null or empty "summary" so the action never ends up with blank reasoning.
    summary = data_extraction_summary_resp.get("summary") or "Extracting information from the page"

    return ExtractAction(
        reasoning=summary,
        data_extraction_goal=task.data_extraction_goal,
        organization_id=task.organization_id,
        task_id=task.task_id,
        workflow_run_id=task.workflow_run_id,
        step_id=step.step_id,
        step_order=step.order,
        action_order=0,
        confidence_float=1.0,
    )
|
||||||
|
|
||||||
|
@staticmethod
def step_has_completed_goal(detailed_agent_step_output: DetailedAgentStepOutput) -> bool:
    """
    Tell whether the step output ends with a successful goal-completing action.

    Only the last (action, results) pair is inspected: it must be a COMPLETE or
    EXTRACT action and at least one of its results must report success. An output
    with no recorded actions is never considered completed.
    """
    recorded = detailed_agent_step_output.actions_and_results
    if not recorded:
        return False

    final_action, final_results = recorded[-1]
    ends_with_goal_action = final_action.action_type in (ActionType.COMPLETE, ActionType.EXTRACT)
    return ends_with_goal_action and any(result.success for result in final_results)
|
||||||
|
|||||||
23
skyvern/forge/prompts/skyvern/data-extraction-summary.j2
Normal file
23
skyvern/forge/prompts/skyvern/data-extraction-summary.j2
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
You are an AI assistant that helps the user extract data from websites. Given a goal to extract information from a web page{% if data_extraction_schema%} and the output schema of the data you're going to extract{% endif %}, summarize what data you're going to extract from the page so that the user has a clear overview of your plan.
|
||||||
|
|
||||||
|
Reply in JSON format with the following keys:
|
||||||
|
{
|
||||||
|
"summary": str, // Summary of the data you will extract within one sentence. Be precise and concise.
|
||||||
|
}
|
||||||
|
|
||||||
|
The URL of the page you're on right now is `{{ current_url }}`.
|
||||||
|
|
||||||
|
Data extraction goal:
|
||||||
|
```
|
||||||
|
{{ data_extraction_goal }}
|
||||||
|
```{% if data_extraction_schema %}
|
||||||
|
|
||||||
|
Data extraction schema:
|
||||||
|
```
|
||||||
|
{{ data_extraction_schema }}
|
||||||
|
```{% endif %}
|
||||||
|
|
||||||
|
Current datetime, ISO format:
|
||||||
|
```
|
||||||
|
{{ local_datetime }}
|
||||||
|
```
|
||||||
@@ -2010,6 +2010,10 @@ class AgentDB:
|
|||||||
prompt: str | None = None,
|
prompt: str | None = None,
|
||||||
url: str | None = None,
|
url: str | None = None,
|
||||||
organization_id: str | None = None,
|
organization_id: str | None = None,
|
||||||
|
proxy_location: ProxyLocation | None = None,
|
||||||
|
totp_identifier: str | None = None,
|
||||||
|
totp_verification_url: str | None = None,
|
||||||
|
webhook_callback_url: str | None = None,
|
||||||
) -> ObserverTask:
|
) -> ObserverTask:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
new_observer_cruise = ObserverCruiseModel(
|
new_observer_cruise = ObserverCruiseModel(
|
||||||
@@ -2018,6 +2022,10 @@ class AgentDB:
|
|||||||
workflow_permanent_id=workflow_permanent_id,
|
workflow_permanent_id=workflow_permanent_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
url=url,
|
url=url,
|
||||||
|
proxy_location=proxy_location,
|
||||||
|
totp_identifier=totp_identifier,
|
||||||
|
totp_verification_url=totp_verification_url,
|
||||||
|
webhook_callback_url=webhook_callback_url,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
session.add(new_observer_cruise)
|
session.add(new_observer_cruise)
|
||||||
|
|||||||
@@ -87,6 +87,9 @@ class Step(BaseModel):
|
|||||||
raise ValueError(f"cant_set_is_last_to_false({self.step_id})")
|
raise ValueError(f"cant_set_is_last_to_false({self.step_id})")
|
||||||
|
|
||||||
def is_goal_achieved(self) -> bool:
|
def is_goal_achieved(self) -> bool:
|
||||||
|
# TODO: now we also consider a step has achieved the goal if the task doesn't have a navigation goal
|
||||||
|
# and the data extraction is successful
|
||||||
|
|
||||||
if self.status != StepStatus.completed:
|
if self.status != StepStatus.completed:
|
||||||
return False
|
return False
|
||||||
# TODO (kerem): Remove this check once we have backfilled all the steps
|
# TODO (kerem): Remove this check once we have backfilled all the steps
|
||||||
@@ -94,14 +97,14 @@ class Step(BaseModel):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if there is a successful complete action
|
# Check if there is a successful complete action
|
||||||
for action, action_results in self.output.actions_and_results:
|
if not self.output.actions_and_results:
|
||||||
if action.action_type != ActionType.COMPLETE:
|
return False
|
||||||
continue
|
|
||||||
|
|
||||||
if any(action_result.success for action_result in action_results):
|
last_action, last_action_results = self.output.actions_and_results[-1]
|
||||||
return True
|
if last_action.action_type not in [ActionType.COMPLETE, ActionType.EXTRACT]:
|
||||||
|
return False
|
||||||
|
|
||||||
return False
|
return any(action_result.success for action_result in last_action_results)
|
||||||
|
|
||||||
def is_success(self) -> bool:
|
def is_success(self) -> bool:
|
||||||
if self.status != StepStatus.completed:
|
if self.status != StepStatus.completed:
|
||||||
|
|||||||
@@ -1147,6 +1147,10 @@ async def observer_task(
|
|||||||
organization=organization,
|
organization=organization,
|
||||||
user_prompt=data.user_prompt,
|
user_prompt=data.user_prompt,
|
||||||
user_url=str(data.url) if data.url else None,
|
user_url=str(data.url) if data.url else None,
|
||||||
|
totp_identifier=data.totp_identifier,
|
||||||
|
totp_verification_url=data.totp_verification_url,
|
||||||
|
webhook_callback_url=data.webhook_callback_url,
|
||||||
|
proxy_location=data.proxy_location,
|
||||||
)
|
)
|
||||||
except LLMProviderError:
|
except LLMProviderError:
|
||||||
LOG.error("LLM failure to initialize observer cruise", exc_info=True)
|
LOG.error("LLM failure to initialize observer cruise", exc_info=True)
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ class ObserverTaskRequest(BaseModel):
|
|||||||
webhook_callback_url: str | None = None
|
webhook_callback_url: str | None = None
|
||||||
totp_verification_url: str | None = None
|
totp_verification_url: str | None = None
|
||||||
totp_identifier: str | None = None
|
totp_identifier: str | None = None
|
||||||
|
proxy_location: ProxyLocation | None = None
|
||||||
|
|
||||||
@field_validator("url", "webhook_callback_url", "totp_verification_url")
|
@field_validator("url", "webhook_callback_url", "totp_verification_url")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -80,11 +80,21 @@ def _generate_data_extraction_schema_for_loop(loop_values_key: str) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
async def initialize_observer_cruise(
|
async def initialize_observer_cruise(
|
||||||
organization: Organization, user_prompt: str, user_url: str | None = None
|
organization: Organization,
|
||||||
|
user_prompt: str,
|
||||||
|
user_url: str | None = None,
|
||||||
|
proxy_location: ProxyLocation | None = None,
|
||||||
|
totp_identifier: str | None = None,
|
||||||
|
totp_verification_url: str | None = None,
|
||||||
|
webhook_callback_url: str | None = None,
|
||||||
) -> ObserverTask:
|
) -> ObserverTask:
|
||||||
observer_cruise = await app.DATABASE.create_observer_cruise(
|
observer_cruise = await app.DATABASE.create_observer_cruise(
|
||||||
prompt=user_prompt,
|
prompt=user_prompt,
|
||||||
organization_id=organization.organization_id,
|
organization_id=organization.organization_id,
|
||||||
|
totp_verification_url=totp_verification_url,
|
||||||
|
totp_identifier=totp_identifier,
|
||||||
|
webhook_callback_url=webhook_callback_url,
|
||||||
|
proxy_location=proxy_location,
|
||||||
)
|
)
|
||||||
# set observer cruise id in context
|
# set observer cruise id in context
|
||||||
context = skyvern_context.current()
|
context = skyvern_context.current()
|
||||||
@@ -117,7 +127,9 @@ async def initialize_observer_cruise(
|
|||||||
# create workflow and workflow run
|
# create workflow and workflow run
|
||||||
max_steps_override = 10
|
max_steps_override = 10
|
||||||
try:
|
try:
|
||||||
new_workflow = await app.WORKFLOW_SERVICE.create_empty_workflow(organization, metadata.workflow_title)
|
new_workflow = await app.WORKFLOW_SERVICE.create_empty_workflow(
|
||||||
|
organization, metadata.workflow_title, proxy_location=proxy_location
|
||||||
|
)
|
||||||
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
|
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
|
||||||
request_id=None,
|
request_id=None,
|
||||||
workflow_request=WorkflowRequestBody(),
|
workflow_request=WorkflowRequestBody(),
|
||||||
|
|||||||
@@ -1693,7 +1693,9 @@ class WorkflowService:
|
|||||||
|
|
||||||
raise ValueError(f"Invalid block type {block_yaml.block_type}")
|
raise ValueError(f"Invalid block type {block_yaml.block_type}")
|
||||||
|
|
||||||
async def create_empty_workflow(self, organization: Organization, title: str) -> Workflow:
|
async def create_empty_workflow(
|
||||||
|
self, organization: Organization, title: str, proxy_location: ProxyLocation | None = None
|
||||||
|
) -> Workflow:
|
||||||
"""
|
"""
|
||||||
Create a blank workflow with no blocks
|
Create a blank workflow with no blocks
|
||||||
"""
|
"""
|
||||||
@@ -1704,6 +1706,7 @@ class WorkflowService:
|
|||||||
parameters=[],
|
parameters=[],
|
||||||
blocks=[],
|
blocks=[],
|
||||||
),
|
),
|
||||||
|
proxy_location=proxy_location,
|
||||||
)
|
)
|
||||||
return await app.WORKFLOW_SERVICE.create_workflow_from_request(
|
return await app.WORKFLOW_SERVICE.create_workflow_from_request(
|
||||||
organization=organization,
|
organization=organization,
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ class ActionType(StrEnum):
|
|||||||
COMPLETE = "complete"
|
COMPLETE = "complete"
|
||||||
RELOAD_PAGE = "reload_page"
|
RELOAD_PAGE = "reload_page"
|
||||||
|
|
||||||
|
EXTRACT = "extract"
|
||||||
|
|
||||||
def is_web_action(self) -> bool:
|
def is_web_action(self) -> bool:
|
||||||
return self in [
|
return self in [
|
||||||
ActionType.CLICK,
|
ActionType.CLICK,
|
||||||
@@ -248,6 +250,12 @@ class CompleteAction(DecisiveAction):
|
|||||||
data_extraction_goal: str | None = None
|
data_extraction_goal: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractAction(Action):
    # Action variant tagged with ActionType.EXTRACT.
    action_type: ActionType = ActionType.EXTRACT
    # Optional free-text description of what should be extracted.
    data_extraction_goal: str | None = None
    # Optional schema for the extracted data.
    # NOTE(review): presumably a JSON-schema-like dict; confirm against the callers that populate it.
    data_extraction_schema: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ScrapeResult(BaseModel):
|
class ScrapeResult(BaseModel):
|
||||||
"""
|
"""
|
||||||
Scraped response from a webpage, including:
|
Scraped response from a webpage, including:
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ from skyvern.exceptions import (
|
|||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.prompts import prompt_engine
|
from skyvern.forge.prompts import prompt_engine
|
||||||
from skyvern.forge.sdk.api.files import download_file, get_download_dir, list_files_in_directory
|
from skyvern.forge.sdk.api.files import download_file, get_download_dir, list_files_in_directory
|
||||||
|
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||||
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
|
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
|
||||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||||
from skyvern.forge.sdk.core.skyvern_context import ensure_context
|
from skyvern.forge.sdk.core.skyvern_context import ensure_context
|
||||||
@@ -309,6 +310,9 @@ class ActionHandler:
|
|||||||
action=action,
|
action=action,
|
||||||
)
|
)
|
||||||
actions_result.append(ActionFailure(e))
|
actions_result.append(ActionFailure(e))
|
||||||
|
except LLMProviderError as e:
|
||||||
|
LOG.exception("LLM error in action handler", action=action, exc_info=True)
|
||||||
|
actions_result.append(ActionFailure(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.exception("Unhandled exception in action handler", action=action)
|
LOG.exception("Unhandled exception in action handler", action=action)
|
||||||
actions_result.append(ActionFailure(e))
|
actions_result.append(ActionFailure(e))
|
||||||
@@ -1318,15 +1322,28 @@ async def handle_complete_action(
|
|||||||
)
|
)
|
||||||
action.verified = True
|
action.verified = True
|
||||||
|
|
||||||
|
return [ActionSuccess()]
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_extract_action(
|
||||||
|
action: actions.ExtractAction,
|
||||||
|
page: Page,
|
||||||
|
scraped_page: ScrapedPage,
|
||||||
|
task: Task,
|
||||||
|
step: Step,
|
||||||
|
) -> list[ActionResult]:
|
||||||
extracted_data = None
|
extracted_data = None
|
||||||
if action.data_extraction_goal:
|
if task.data_extraction_goal:
|
||||||
scrape_action_result = await extract_information_for_navigation_goal(
|
scrape_action_result = await extract_information_for_navigation_goal(
|
||||||
scraped_page=scraped_page,
|
scraped_page=scraped_page,
|
||||||
task=task,
|
task=task,
|
||||||
step=step,
|
step=step,
|
||||||
)
|
)
|
||||||
extracted_data = scrape_action_result.scraped_data
|
extracted_data = scrape_action_result.scraped_data
|
||||||
return [ActionSuccess(data=extracted_data)]
|
return [ActionSuccess(data=extracted_data)]
|
||||||
|
else:
|
||||||
|
LOG.warning("No data extraction goal, skipping extract action", step_id=step.step_id)
|
||||||
|
return [ActionFailure(exception=Exception("No data extraction goal"))]
|
||||||
|
|
||||||
|
|
||||||
ActionHandler.register_action_type(ActionType.SOLVE_CAPTCHA, handle_solve_captcha_action)
|
ActionHandler.register_action_type(ActionType.SOLVE_CAPTCHA, handle_solve_captcha_action)
|
||||||
@@ -1339,6 +1356,7 @@ ActionHandler.register_action_type(ActionType.SELECT_OPTION, handle_select_optio
|
|||||||
ActionHandler.register_action_type(ActionType.WAIT, handle_wait_action)
|
ActionHandler.register_action_type(ActionType.WAIT, handle_wait_action)
|
||||||
ActionHandler.register_action_type(ActionType.TERMINATE, handle_terminate_action)
|
ActionHandler.register_action_type(ActionType.TERMINATE, handle_terminate_action)
|
||||||
ActionHandler.register_action_type(ActionType.COMPLETE, handle_complete_action)
|
ActionHandler.register_action_type(ActionType.COMPLETE, handle_complete_action)
|
||||||
|
ActionHandler.register_action_type(ActionType.EXTRACT, handle_extract_action)
|
||||||
|
|
||||||
|
|
||||||
async def get_actual_value_of_parameter_if_secret(task: Task, parameter: str) -> Any:
|
async def get_actual_value_of_parameter_if_secret(task: Task, parameter: str) -> Any:
|
||||||
|
|||||||
Reference in New Issue
Block a user