diff --git a/alembic/versions/2024_12_27_1610-d13af1e466fa_new_observer_thoughts.py b/alembic/versions/2024_12_27_1610-d13af1e466fa_new_observer_thoughts.py new file mode 100644 index 00000000..f4ed78e8 --- /dev/null +++ b/alembic/versions/2024_12_27_1610-d13af1e466fa_new_observer_thoughts.py @@ -0,0 +1,35 @@ +"""new observer thoughts + +Revision ID: d13af1e466fa +Revises: 835522a23b19 +Create Date: 2024-12-27 16:10:36.555540+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d13af1e466fa" +down_revision: Union[str, None] = "835522a23b19" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("observer_thoughts", sa.Column("observer_thought_type", sa.String(), nullable=True)) + op.add_column("observer_thoughts", sa.Column("observer_thought_scenario", sa.String(), nullable=True)) + op.add_column("observer_thoughts", sa.Column("output", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("observer_thoughts", "output") + op.drop_column("observer_thoughts", "observer_thought_scenario") + op.drop_column("observer_thoughts", "observer_thought_type") + # ### end Alembic commands ### diff --git a/skyvern/forge/prompts/skyvern/observer_loop_task_extraction_goal.j2 b/skyvern/forge/prompts/skyvern/observer_loop_task_extraction_goal.j2 index a537eb7f..5928b573 100644 --- a/skyvern/forge/prompts/skyvern/observer_loop_task_extraction_goal.j2 +++ b/skyvern/forge/prompts/skyvern/observer_loop_task_extraction_goal.j2 @@ -1,6 +1,6 @@ The user is trying to achieve a goal the web. Now they've decided to go through a list of values and take the same tasks with each variant in the list. 
-Help to user extract this list of values based on what they want to achieve: +Help the user extract a list of values based on what they want to achieve: ``` {{ plan }} -``` +``` \ No newline at end of file diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index bb0ee46e..bffbb268 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -55,7 +55,12 @@ from skyvern.forge.sdk.db.utils import ( ) from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs from skyvern.forge.sdk.models import Step, StepStatus -from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverCruiseStatus, ObserverThought +from skyvern.forge.sdk.schemas.observers import ( + ObserverCruise, + ObserverCruiseStatus, + ObserverThought, + ObserverThoughtType, +) from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken from skyvern.forge.sdk.schemas.task_generations import TaskGeneration from skyvern.forge.sdk.schemas.tasks import OrderBy, ProxyLocation, SortDirection, Task, TaskStatus @@ -1924,17 +1929,19 @@ class AgentDB: async def get_observer_thoughts( self, observer_cruise_id: str, + observer_thought_types: list[ObserverThoughtType] | None = None, organization_id: str | None = None, ) -> list[ObserverThought]: async with self.Session() as session: - observer_thoughts = ( - await session.scalars( - select(ObserverThoughtModel) - .filter_by(observer_cruise_id=observer_cruise_id) - .filter_by(organization_id=organization_id) - .order_by(ObserverThoughtModel.created_at) - ) - ).all() + query = ( + select(ObserverThoughtModel) + .filter_by(observer_cruise_id=observer_cruise_id) + .filter_by(organization_id=organization_id) + .order_by(ObserverThoughtModel.created_at) + ) + if observer_thought_types: + query = query.filter(ObserverThoughtModel.observer_thought_type.in_(observer_thought_types)) + observer_thoughts = (await session.scalars(query)).all() return 
[ObserverThought.model_validate(thought) for thought in observer_thoughts] async def create_observer_cruise( @@ -1971,6 +1978,9 @@ class AgentDB: observation: str | None = None, thought: str | None = None, answer: str | None = None, + observer_thought_scenario: str | None = None, + observer_thought_type: str = ObserverThoughtType.plan, + output: dict[str, Any] | None = None, organization_id: str | None = None, ) -> ObserverThought: async with self.Session() as session: @@ -1984,6 +1994,9 @@ class AgentDB: observation=observation, thought=thought, answer=answer, + observer_thought_scenario=observer_thought_scenario, + observer_thought_type=observer_thought_type, + output=output, organization_id=organization_id, ) session.add(new_observer_thought) @@ -1995,9 +2008,13 @@ class AgentDB: self, observer_thought_id: str, workflow_run_block_id: str | None = None, + workflow_run_id: str | None = None, + workflow_id: str | None = None, + workflow_permanent_id: str | None = None, observation: str | None = None, thought: str | None = None, answer: str | None = None, + output: dict[str, Any] | None = None, organization_id: str | None = None, ) -> ObserverThought: async with self.Session() as session: @@ -2011,12 +2028,20 @@ class AgentDB: if observer_thought: if workflow_run_block_id: observer_thought.workflow_run_block_id = workflow_run_block_id + if workflow_run_id: + observer_thought.workflow_run_id = workflow_run_id + if workflow_id: + observer_thought.workflow_id = workflow_id + if workflow_permanent_id: + observer_thought.workflow_permanent_id = workflow_permanent_id if observation: observer_thought.observation = observation if thought: observer_thought.thought = thought if answer: observer_thought.answer = answer + if output: + observer_thought.output = output await session.commit() await session.refresh(observer_thought) return ObserverThought.model_validate(observer_thought) diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 
001f869b..2966243f 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -41,6 +41,7 @@ from skyvern.forge.sdk.db.id import ( generate_workflow_run_block_id, generate_workflow_run_id, ) +from skyvern.forge.sdk.schemas.observers import ObserverThoughtType from skyvern.forge.sdk.schemas.tasks import ProxyLocation @@ -559,6 +560,10 @@ class ObserverThoughtModel(Base): thought = Column(String, nullable=True) answer = Column(String, nullable=True) + observer_thought_type = Column(String, nullable=True, default=ObserverThoughtType.plan) + observer_thought_scenario = Column(String, nullable=True) + output = Column(JSON, nullable=True) + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) diff --git a/skyvern/forge/sdk/schemas/observers.py b/skyvern/forge/sdk/schemas/observers.py index 8fc2bf7f..29d583ed 100644 --- a/skyvern/forge/sdk/schemas/observers.py +++ b/skyvern/forge/sdk/schemas/observers.py @@ -1,5 +1,6 @@ from datetime import datetime from enum import StrEnum +from typing import Any from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator @@ -35,6 +36,22 @@ class ObserverCruise(BaseModel): modified_at: datetime +class ObserverThoughtType(StrEnum): + plan = "plan" + metadata = "metadata" + user_goal_check = "user_goal_check" + internal_plan = "internal_plan" + + +class ObserverThoughtScenario(StrEnum): + generate_plan = "generate_plan" + user_goal_check = "user_goal_check" + generate_metadata = "generate_metadata" + extract_loop_values = "extract_loop_values" + generate_task_in_loop = "generate_task_in_loop" + generate_task = "generate_general_task" + + class ObserverThought(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -49,6 +66,9 @@ class ObserverThought(BaseModel): observation: str | None = None thought: str | None = None answer: str | None = None + 
observer_thought_type: ObserverThoughtType | None = ObserverThoughtType.plan + observer_thought_scenario: ObserverThoughtScenario | None = None + output: dict[str, Any] | None = None created_at: datetime modified_at: datetime diff --git a/skyvern/forge/sdk/services/observer_service.py b/skyvern/forge/sdk/services/observer_service.py index 92de49b0..b4f3c18e 100644 --- a/skyvern/forge/sdk/services/observer_service.py +++ b/skyvern/forge/sdk/services/observer_service.py @@ -10,9 +10,16 @@ from pydantic import BaseModel from skyvern.exceptions import UrlGenerationFailure from skyvern.forge import app from skyvern.forge.prompts import prompt_engine +from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext -from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverCruiseStatus, ObserverMetadata +from skyvern.forge.sdk.schemas.observers import ( + ObserverCruise, + ObserverCruiseStatus, + ObserverMetadata, + ObserverThoughtScenario, + ObserverThoughtType, +) from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.tasks import ProxyLocation from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline, WorkflowRunTimelineType @@ -76,6 +83,13 @@ async def initialize_observer_cruise( organization_id=organization.organization_id, ) + observer_thought = await app.DATABASE.create_observer_thought( + observer_cruise_id=observer_cruise.observer_cruise_id, + organization_id=organization.organization_id, + observer_thought_type=ObserverThoughtType.metadata, + observer_thought_scenario=ObserverThoughtScenario.generate_metadata, + ) + metadata_prompt = prompt_engine.load_prompt("observer_generate_metadata", user_goal=user_prompt, user_url=user_url) metadata_response = await app.SECONDARY_LLM_API_HANDLER(prompt=metadata_prompt, observer_cruise=observer_cruise) # validate @@ -103,6 +117,16 @@ async def 
initialize_observer_cruise( version=None, max_steps_override=max_steps_override, ) + await app.DATABASE.update_observer_thought( + observer_thought_id=observer_thought.observer_thought_id, + organization_id=organization.organization_id, + workflow_run_id=workflow_run.workflow_run_id, + workflow_id=new_workflow.workflow_id, + workflow_permanent_id=new_workflow.workflow_permanent_id, + thought=metadata_response.get("thoughts", ""), + output=metadata.model_dump(), + ) + # update oserver cruise observer_cruise = await app.DATABASE.update_observer_cruise( observer_cruise_id=observer_cruise.observer_cruise_id, @@ -237,6 +261,8 @@ async def run_observer_cruise( workflow_run_id=workflow_run.workflow_run_id, workflow_id=workflow.workflow_id, workflow_permanent_id=workflow.workflow_permanent_id, + observer_thought_type=ObserverThoughtType.plan, + observer_thought_scenario=ObserverThoughtScenario.generate_plan, ) observer_response = await app.LLM_API_HANDLER( prompt=observer_prompt, @@ -255,6 +281,7 @@ async def run_observer_cruise( observation = observer_response.get("page_info", "") thoughts: str = observer_response.get("thoughts", "") plan: str = observer_response.get("plan", "") + task_type: str = observer_response.get("task_type", "") # Create and save observer thought await app.DATABASE.update_observer_thought( observer_thought_id=observer_thought.observer_thought_id, @@ -262,6 +289,7 @@ async def run_observer_cruise( thought=thoughts, observation=observation, answer=plan, + output={"task_type": task_type, "user_goal_achieved": user_goal_achieved}, ) if user_goal_achieved is True: @@ -274,7 +302,6 @@ async def run_observer_cruise( break # parse observer repsonse and run the next task - task_type = observer_response.get("task_type") if not task_type: LOG.error("No task type found in observer response", observer_response=observer_response) await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed( @@ -288,6 +315,8 @@ async def run_observer_cruise( block, block_yaml_list, 
parameter_yaml_list = await _generate_extraction_task( observer_cruise=observer_cruise, workflow_id=workflow_id, + workflow_permanent_id=workflow.workflow_permanent_id, + workflow_run_id=workflow_run_id, current_url=current_url, element_tree_in_prompt=element_tree_in_prompt, data_extraction_goal=plan, @@ -298,6 +327,8 @@ async def run_observer_cruise( original_url = url if i == 0 else None block, block_yaml_list, parameter_yaml_list = await _generate_navigation_task( workflow_id=workflow_id, + workflow_permanent_id=workflow.workflow_permanent_id, + workflow_run_id=workflow_run_id, original_url=original_url, navigation_goal=plan, ) @@ -307,6 +338,7 @@ async def run_observer_cruise( block, block_yaml_list, parameter_yaml_list, extraction_obj, inner_task = await _generate_loop_task( observer_cruise=observer_cruise, workflow_id=workflow_id, + workflow_permanent_id=workflow.workflow_permanent_id, workflow_run_id=workflow_run_id, plan=plan, browser_state=browser_state, @@ -378,8 +410,18 @@ async def run_observer_cruise( task_history=task_history, local_datetime=datetime.now(context.tz_info).isoformat(), ) + observer_thought = await app.DATABASE.create_observer_thought( + observer_cruise_id=observer_cruise_id, + organization_id=organization_id, + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + workflow_permanent_id=workflow.workflow_permanent_id, + observer_thought_type=ObserverThoughtType.user_goal_check, + observer_thought_scenario=ObserverThoughtScenario.user_goal_check, + ) completion_resp = await app.LLM_API_HANDLER( - prompt=observer_completion_prompt, observer_cruise=observer_cruise + prompt=observer_completion_prompt, + observer_thought=observer_thought, ) LOG.info( "Observer completion check response", @@ -388,7 +430,15 @@ async def run_observer_cruise( workflow_run_id=workflow_run_id, task_history=task_history, ) - if completion_resp.get("user_goal_achieved", False): + user_goal_achieved = completion_resp.get("user_goal_achieved", False) + thought = 
completion_resp.get("thoughts", "") + await app.DATABASE.update_observer_thought( + observer_thought_id=observer_thought.observer_thought_id, + organization_id=organization_id, + thought=thought, + output={"user_goal_achieved": user_goal_achieved}, + ) + if user_goal_achieved: LOG.info( "User goal achieved according to the observer completion check", iteration=i, @@ -514,6 +564,7 @@ async def _set_up_workflow_context(workflow_id: str, workflow_run_id: str) -> No async def _generate_loop_task( observer_cruise: ObserverCruise, workflow_id: str, + workflow_permanent_id: str, workflow_run_id: str, plan: str, browser_state: BrowserState, @@ -525,7 +576,25 @@ async def _generate_loop_task( "observer_loop_task_extraction_goal", plan=plan, ) - + data_extraction_thought = f"Going to generate a list of values to go through based on the plan: {plan}." + observer_thought = await app.DATABASE.create_observer_thought( + observer_cruise_id=observer_cruise.observer_cruise_id, + organization_id=observer_cruise.organization_id, + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + workflow_permanent_id=workflow_permanent_id, + observer_thought_type=ObserverThoughtType.plan, + observer_thought_scenario=ObserverThoughtScenario.extract_loop_values, + thought=data_extraction_thought, + ) + # generate screenshot artifact for the observer thought + if scraped_page.screenshots: + for screenshot in scraped_page.screenshots: + await app.ARTIFACT_MANAGER.create_observer_thought_artifact( + observer_thought=observer_thought, + artifact_type=ArtifactType.SCREENSHOT_LLM, + data=screenshot, + ) label = f"extraction_task_for_loop_{_generate_random_string()}" extraction_block_yaml = ExtractionBlockYAML( label=label, @@ -576,6 +645,13 @@ async def _generate_loop_task( ) raise + # update the observer thought + await app.DATABASE.update_observer_thought( + observer_thought_id=observer_thought.observer_thought_id, + organization_id=observer_cruise.organization_id, + 
output=output_value_obj.model_dump(), + ) + # create ContextParameter for the loop over pointer that ForLoopBlock needs. loop_for_context_parameter = ContextParameter( key="loop_values", @@ -627,15 +703,31 @@ async def _generate_loop_task( is_link=output_value_obj.is_loop_value_link, loop_values=output_value_obj.loop_values, ) + observer_thought_task_in_loop = await app.DATABASE.create_observer_thought( + observer_cruise_id=observer_cruise.observer_cruise_id, + organization_id=observer_cruise.organization_id, + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + workflow_permanent_id=workflow_permanent_id, + observer_thought_type=ObserverThoughtType.internal_plan, + observer_thought_scenario=ObserverThoughtScenario.generate_task_in_loop, + ) task_in_loop_metadata_response = await app.LLM_API_HANDLER( task_in_loop_metadata_prompt, screenshots=scraped_page.screenshots, - observer_cruise=observer_cruise, + observer_thought=observer_thought_task_in_loop, ) LOG.info("Task in loop metadata response", task_in_loop_metadata_response=task_in_loop_metadata_response) navigation_goal = task_in_loop_metadata_response.get("navigation_goal") data_extraction_goal = task_in_loop_metadata_response.get("data_extraction_goal") data_extraction_schema = task_in_loop_metadata_response.get("data_schema") + thought = task_in_loop_metadata_response.get("thoughts") + await app.DATABASE.update_observer_thought( + observer_thought_id=observer_thought_task_in_loop.observer_thought_id, + organization_id=observer_cruise.organization_id, + thought=thought, + output=task_in_loop_metadata_response, + ) if data_extraction_goal and navigation_goal: navigation_goal = ( navigation_goal @@ -699,6 +791,8 @@ async def _generate_loop_task( async def _generate_extraction_task( observer_cruise: ObserverCruise, workflow_id: str, + workflow_permanent_id: str, + workflow_run_id: str, current_url: str, element_tree_in_prompt: str, data_extraction_goal: str, @@ -753,6 +847,8 @@ async def 
_generate_extraction_task( async def _generate_navigation_task( workflow_id: str, + workflow_permanent_id: str, + workflow_run_id: str, navigation_goal: str, original_url: str | None = None, ) -> tuple[NavigationBlock, list[BLOCK_YAML_TYPES], list[PARAMETER_YAML_TYPES]]: