fix loop_values in observer (#1522)

This commit is contained in:
Shuchang Zheng
2025-01-08 23:53:21 -08:00
committed by GitHub
parent 5796de73d1
commit d4ffcdbfda

View File

@@ -5,7 +5,6 @@ from datetime import datetime
from typing import Any
import structlog
from pydantic import BaseModel
from sqlalchemy.exc import OperationalError
from skyvern.exceptions import UrlGenerationFailure
@@ -56,29 +55,26 @@ DEFAULT_WORKFLOW_TITLE = "New Workflow"
# Alphabet of ASCII letters plus digits.
# NOTE(review): presumably the pool sampled by _generate_random_string() — its body is not visible here; confirm.
RANDOM_STRING_POOL = string.ascii_letters + string.digits
# Default iteration cap.
# NOTE(review): the place where this limit is enforced is not visible in this chunk.
DEFAULT_MAX_ITERATIONS = 10
# JSON schema describing the extraction output for a loop task. It declares two properties:
#   - "loop_values": an array of strings, the values to iterate over;
#   - "is_loop_value_link": a boolean telling whether those values are URLs to visit
#     or plain values used from the same starting page.
# The "description" strings are instructions read by the extraction step, so their text is
# behavior, not commentary.
DATA_EXTRACTION_SCHEMA_FOR_LOOP = {
"type": "object",
"properties": {
"loop_values": {
"type": "array",
"description": 'User will later iterate through this array of values to achieve their "big goal" in the web. In each iteration, the user will try to take the same actions in the web but with a different value of its own. If the value is a url link, make sure it is a full url with http/https protocol, domain and path if any, based on the current url. For examples: \n1. When the goal is "Open up to 10 links from an ecomm search result page, and extract information like the price of each product.", user will iterate through an array of product links or URLs. In each iteration, the user will go to the linked page and extrat price information of the product. As a result, the array consists of 10 product urls scraped from the search result page.\n2. When the goal is "download 10 documents found on a page", user will iterate through an array of document names. In each iteration, the user will use a different value variant to start from the same page (the existing page) and take actions based on the variant. As a result, the array consists of up to 10 document names scraped from the page that the user wants to download.',
"items": {"type": "string", "description": "The relevant value"},
},
"is_loop_value_link": {
"type": "boolean",
"description": "true if the loop_values is an array of urls to be visited for each task. false if the loop_values is an array of non-link values to be used in each task (for each task they start from the same page / link).",
},
},
}
# Prompt template with two named placeholders: {mini_goal} (the step to achieve now) and
# {main_goal} (the overall goal, supplied as context).
# NOTE(review): presumably filled in with str.format — the call site is not visible here.
MINI_GOAL_TEMPLATE = """Achieve the following mini goal and once it's achieved, complete: {mini_goal}
This mini goal is part of the big goal the user wants to achieve and use the big goal as context to achieve the mini goal: {main_goal}"""
class LoopExtractionOutput(BaseModel):
    """Typed view of the data extracted for a loop task.

    The two fields correspond to the two properties a loop extraction
    produces: the values to iterate over and a flag describing them.
    """

    # Values the loop iterates over, one per iteration.
    loop_values: list[str]

    # True when the values are URLs to visit; False when they are plain
    # values used while starting from the same page.
    is_loop_value_link: bool
def _generate_data_extraction_schema_for_loop(loop_values_key: str) -> dict:
return {
"type": "object",
"properties": {
loop_values_key: {
"type": "array",
"description": 'User will later iterate through this array of values to achieve their "big goal" in the web. In each iteration, the user will try to take the same actions in the web but with a different value of its own. If the value is a url link, make sure it is a full url with http/https protocol, domain and path if any, based on the current url. For examples: \n1. When the goal is "Open up to 10 links from an ecomm search result page, and extract information like the price of each product.", user will iterate through an array of product links or URLs. In each iteration, the user will go to the linked page and extrat price information of the product. As a result, the array consists of 10 product urls scraped from the search result page.\n2. When the goal is "download 10 documents found on a page", user will iterate through an array of document names. In each iteration, the user will use a different value variant to start from the same page (the existing page) and take actions based on the variant. As a result, the array consists of up to 10 document names scraped from the page that the user wants to download.',
"items": {"type": "string", "description": "The relevant value"},
},
"is_loop_value_link": {
"type": "boolean",
"description": "true if the loop_values is an array of urls to be visited for each task. false if the loop_values is an array of non-link values to be used in each task (for each task they start from the same page / link).",
},
},
}
async def initialize_observer_cruise(
@@ -442,7 +438,7 @@ async def run_observer_cruise_helper(
task_history_record = {
"type": task_type,
"task": plan,
"loop_over_values": extraction_obj.loop_values,
"loop_over_values": extraction_obj.get("loop_values"),
"task_inside_the_loop": inner_task,
}
except Exception:
@@ -668,7 +664,7 @@ async def _generate_loop_task(
browser_state: BrowserState,
original_url: str,
scraped_page: ScrapedPage,
) -> tuple[ForLoopBlock, list[BLOCK_YAML_TYPES], list[PARAMETER_YAML_TYPES], LoopExtractionOutput, dict[str, Any]]:
) -> tuple[ForLoopBlock, list[BLOCK_YAML_TYPES], list[PARAMETER_YAML_TYPES], dict[str, Any], dict[str, Any]]:
for_loop_parameter_yaml_list: list[PARAMETER_YAML_TYPES] = []
loop_value_extraction_goal = prompt_engine.load_prompt(
"observer_loop_task_extraction_goal",
@@ -693,11 +689,13 @@ async def _generate_loop_task(
artifact_type=ArtifactType.SCREENSHOT_LLM,
data=screenshot,
)
label = f"extraction_task_for_loop_{_generate_random_string()}"
loop_random_string = _generate_random_string()
label = f"extraction_task_for_loop_{loop_random_string}"
loop_values_key = f"loop_values_{loop_random_string}"
extraction_block_yaml = ExtractionBlockYAML(
label=label,
data_extraction_goal=loop_value_extraction_goal,
data_schema=DATA_EXTRACTION_SCHEMA_FOR_LOOP,
data_schema=_generate_data_extraction_schema_for_loop(loop_values_key),
)
loop_value_extraction_output_parameter = await app.WORKFLOW_SERVICE.create_output_parameter_for_block(
workflow_id=workflow_id,
@@ -706,7 +704,7 @@ async def _generate_loop_task(
extraction_block_for_loop = ExtractionBlock(
label=label,
data_extraction_goal=loop_value_extraction_goal,
data_schema=DATA_EXTRACTION_SCHEMA_FOR_LOOP,
data_schema=_generate_data_extraction_schema_for_loop(loop_values_key),
output_parameter=loop_value_extraction_output_parameter,
)
@@ -729,9 +727,15 @@ async def _generate_loop_task(
raise Exception("extraction_block failed")
# validate output parameter
try:
output_value_obj = LoopExtractionOutput.model_validate(
extraction_block_result.output_parameter_value.get("extracted_information") # type: ignore
)
output_value_obj: dict[str, Any] = extraction_block_result.output_parameter_value.get("extracted_information") # type: ignore
if not output_value_obj or not isinstance(output_value_obj, dict):
raise Exception("Invalid output parameter of the extraction block for the loop task")
if loop_values_key not in output_value_obj:
raise Exception("loop_values_key not found in the output parameter of the extraction block")
if "is_loop_value_link" not in output_value_obj:
raise Exception("is_loop_value_link not found in the output parameter of the extraction block")
loop_values = output_value_obj.get(loop_values_key, [])
is_loop_value_link = output_value_obj.get("is_loop_value_link")
except Exception:
LOG.error(
"Failed to validate the output parameter of the extraction block for the loop task",
@@ -747,12 +751,12 @@ async def _generate_loop_task(
await app.DATABASE.update_observer_thought(
observer_thought_id=observer_thought.observer_thought_id,
organization_id=observer_cruise.organization_id,
output=output_value_obj.model_dump(),
output=output_value_obj,
)
# create ContextParameter for the loop over pointer that ForLoopBlock needs.
loop_for_context_parameter = ContextParameter(
key=f"loop_values_{_generate_random_string()}",
key=loop_values_key,
source=loop_value_extraction_output_parameter,
)
for_loop_parameter_yaml_list.append(
@@ -769,11 +773,11 @@ async def _generate_loop_task(
value=extraction_block_result.output_parameter_value,
)
task_parameters: list[PARAMETER_TYPE] = []
if output_value_obj.is_loop_value_link:
LOG.info("Loop values are links", loop_values=output_value_obj.loop_values)
context_parameter_key = url = f"task_in_loop_url_{_generate_random_string()}"
if is_loop_value_link is True:
LOG.info("Loop values are links", loop_values=loop_values)
context_parameter_key = url = f"task_in_loop_url_{loop_random_string}"
else:
LOG.info("Loop values are not links", loop_values=output_value_obj.loop_values)
LOG.info("Loop values are not links", loop_values=loop_values)
page = await browser_state.get_working_page()
url = str(
await SkyvernFrame.evaluate(frame=page, expression="() => document.location.href") if page else original_url
@@ -801,8 +805,8 @@ async def _generate_loop_task(
"observer_generate_task_block",
plan=plan,
local_datetime=datetime.now(context.tz_info).isoformat(),
is_link=output_value_obj.is_loop_value_link,
loop_values=output_value_obj.loop_values,
is_link=is_loop_value_link,
loop_values=loop_values,
)
observer_thought_task_in_loop = await app.DATABASE.create_observer_thought(
observer_cruise_id=observer_cruise.observer_cruise_id,
@@ -1108,8 +1112,18 @@ def _get_extracted_data_from_block_result(
)
continue
output_value = inner_output.get("output_value", {})
if "extracted_information" in output_value and output_value["extracted_information"]:
inner_loop_output_overall.append(output_value["extracted_information"])
if not isinstance(output_value, dict):
LOG.warning(
"output_value is not a dict",
output_value=output_value,
observer_cruise_id=observer_cruise_id,
workflow_run_id=workflow_run_id,
workflow_run_block_id=block_result.workflow_run_block_id,
)
continue
else:
if "extracted_information" in output_value and output_value["extracted_information"]:
inner_loop_output_overall.append(output_value["extracted_information"])
loop_output_overall.append(inner_loop_output_overall)
return loop_output_overall if loop_output_overall else None
return None