From 374b2326c402421bac1c066d037a71c47c18f98e Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Fri, 10 Jan 2025 14:59:53 -0800 Subject: [PATCH] observer summary (#1532) --- ...a947c379c02_observer_summary_and_output.py | 33 +++++++ .../forge/prompts/skyvern/observer_summary.j2 | 24 +++++ skyvern/forge/sdk/db/client.py | 6 ++ skyvern/forge/sdk/db/models.py | 2 + skyvern/forge/sdk/schemas/observers.py | 3 + .../forge/sdk/services/observer_service.py | 92 +++++++++++++------ 6 files changed, 133 insertions(+), 27 deletions(-) create mode 100644 alembic/versions/2025_01_10_2246-6a947c379c02_observer_summary_and_output.py create mode 100644 skyvern/forge/prompts/skyvern/observer_summary.j2 diff --git a/alembic/versions/2025_01_10_2246-6a947c379c02_observer_summary_and_output.py b/alembic/versions/2025_01_10_2246-6a947c379c02_observer_summary_and_output.py new file mode 100644 index 00000000..3651b0f8 --- /dev/null +++ b/alembic/versions/2025_01_10_2246-6a947c379c02_observer_summary_and_output.py @@ -0,0 +1,33 @@ +"""observer summary and output + +Revision ID: 6a947c379c02 +Revises: d5640aa644b9 +Create Date: 2025-01-10 22:46:41.757862+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "6a947c379c02" +down_revision: Union[str, None] = "d5640aa644b9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("observer_cruises", sa.Column("summary", sa.String(), nullable=True)) + op.add_column("observer_cruises", sa.Column("output", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("observer_cruises", "output") + op.drop_column("observer_cruises", "summary") + # ### end Alembic commands ### diff --git a/skyvern/forge/prompts/skyvern/observer_summary.j2 b/skyvern/forge/prompts/skyvern/observer_summary.j2 new file mode 100644 index 00000000..5c71f447 --- /dev/null +++ b/skyvern/forge/prompts/skyvern/observer_summary.j2 @@ -0,0 +1,24 @@ +The AI assistant has helped the user achieve the user goal in the web. +Given the user goal, the latest screenshot of the page and the mini tasks that have been completed by the user along the way, summarize what has been achieved and output structured data related to the user goal in json format. +You want to present the response in a clear way so that user can clearly understand what has been achieved. + +Reply in JSON format with the following keys: +{ + "description": str, // Summarize what has been achieved and describe the information extracted related to the user goal if any. Be precise and concise. + "output": json, // Structured data related to the user goal if any. +} + +User goal: +``` +{{ user_goal }} +``` + +Task history (the earliest task is the first in the list and the latest is the last in the list): +``` +{{ task_history }} +``` + +Current datetime, ISO format: +``` +{{ local_datetime }} +``` diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index adbc4984..2bc9ed25 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -2093,6 +2093,8 @@ class AgentDB: workflow_permanent_id: str | None = None, url: str | None = None, prompt: str | None = None, + summary: str | None = None, + output: dict[str, Any] | None = None, organization_id: str | None = None, ) -> ObserverCruise: async with self.Session() as session: @@ -2116,6 +2118,10 @@ class AgentDB: observer_cruise.url = url if prompt: observer_cruise.prompt = prompt + if summary: + observer_cruise.summary = summary + if output: + observer_cruise.output = output await session.commit() await session.refresh(observer_cruise) return ObserverCruise.model_validate(observer_cruise) diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 7ad5b2c7..5f2c3efa 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -548,6 +548,8 @@ class ObserverCruiseModel(Base): workflow_permanent_id = Column(String, nullable=True) prompt = Column(UnicodeText, nullable=True) url = Column(String, nullable=True) + summary = Column(String, nullable=True) + output = Column(JSON, nullable=True) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) diff --git a/skyvern/forge/sdk/schemas/observers.py b/skyvern/forge/sdk/schemas/observers.py index 1b1593a3..9f039665 100644 --- a/skyvern/forge/sdk/schemas/observers.py +++ b/skyvern/forge/sdk/schemas/observers.py @@ -31,6 +31,8 @@ class ObserverCruise(BaseModel): workflow_permanent_id: str | None = None prompt: str | None = None url: HttpUrl | None = None + summary: str | None = None + output: dict[str, Any] | list | str | None = None created_at: datetime modified_at: datetime @@ -46,6 +48,7 @@ class ObserverThoughtType(StrEnum): class ObserverThoughtScenario(StrEnum): generate_plan = "generate_plan" user_goal_check = "user_goal_check" + summarization = "summarization" generate_metadata = "generate_metadata" extract_loop_values = "extract_loop_values" generate_task_in_loop = "generate_task_in_loop" diff --git a/skyvern/forge/sdk/services/observer_service.py b/skyvern/forge/sdk/services/observer_service.py index 35bf1845..c90dd3f1 100644 --- a/skyvern/forge/sdk/services/observer_service.py +++ b/skyvern/forge/sdk/services/observer_service.py @@ -17,7 +17,6 @@ from skyvern.forge.sdk.schemas.observers import ( ObserverCruise, ObserverCruiseStatus, ObserverMetadata, - ObserverThought, ObserverThoughtScenario, ObserverThoughtType, ) @@ -209,13 +208,13 @@ async def run_observer_cruise( organization_id=organization_id, ) return - except Exception: + except Exception as e: LOG.error("Failed to run observer cruise", exc_info=True) + failure_reason = f"Failed to run observer cruise: {str(e)}" await mark_observer_cruise_as_failed( observer_cruise_id, workflow_run_id=observer_cruise.workflow_run_id, - # TODO: add better failure reason - failure_reason="Failed to run observer cruise", + failure_reason=failure_reason, organization_id=organization_id, ) return @@ -395,7 +394,12 @@ async def run_observer_cruise_helper( iteration=i, workflow_run_id=workflow_run_id, ) - await app.WORKFLOW_SERVICE.mark_workflow_run_as_completed(workflow_run_id=workflow_run_id) + await _summarize_observer_cruise( + observer_cruise=observer_cruise, + task_history=task_history, + context=context, + screenshots=scraped_page.screenshots, + ) break # parse observer repsonse and run the next task @@ -573,10 +577,11 @@ async def run_observer_cruise_helper( workflow_run_id=workflow_run_id, completion_resp=completion_resp, ) - await mark_observer_cruise_as_completed( - observer_cruise_id=observer_cruise_id, - workflow_run_id=workflow_run_id, - organization_id=organization_id, + await _summarize_observer_cruise( + observer_cruise=observer_cruise, + task_history=task_history, + context=context, + screenshots=completion_screenshots, ) break else: @@ -1039,24 +1044,6 @@ async def get_observer_thought_timelines( ] -async def _record_thought_screenshot(observer_thought: ObserverThought, workflow_run_id: str) -> None: - # get the browser state for the workflow run - browser_state = app.BROWSER_MANAGER.get_for_workflow_run(workflow_run_id=workflow_run_id) - if not browser_state: - LOG.warning("No browser state found for the workflow run", workflow_run_id=workflow_run_id) - return - # get the screenshot for the workflow run - try: - screenshot = await browser_state.take_screenshot(full_page=True) - await app.ARTIFACT_MANAGER.create_observer_thought_artifact( - observer_thought=observer_thought, - artifact_type=ArtifactType.SCREENSHOT_LLM, - data=screenshot, - ) - except Exception: - LOG.warning("Failed to take screenshot for the observer thought", observer_thought=observer_thought) - - async def get_observer_cruise(observer_cruise_id: str, organization_id: str | None = None) -> ObserverCruise | None: return await app.DATABASE.get_observer_cruise(observer_cruise_id, organization_id=organization_id) @@ -1080,11 +1067,15 @@ async def mark_observer_cruise_as_completed( observer_cruise_id: str, workflow_run_id: str | None = None, organization_id: str | None = None, + summary: str | None = None, + output: dict[str, Any] | None = None, ) -> None: await app.DATABASE.update_observer_cruise( observer_cruise_id, organization_id=organization_id, status=ObserverCruiseStatus.completed, + summary=summary, + output=output, ) if workflow_run_id: await app.WORKFLOW_SERVICE.mark_workflow_run_as_completed(workflow_run_id) @@ -1157,3 +1148,50 @@ def _get_extracted_data_from_block_result( loop_output_overall.append(inner_loop_output_overall) return loop_output_overall if loop_output_overall else None return None + + +async def _summarize_observer_cruise( + observer_cruise: ObserverCruise, + task_history: list[dict], + context: SkyvernContext, + screenshots: list[bytes] | None = None, +) -> None: + observer_thought = await app.DATABASE.create_observer_thought( + observer_cruise_id=observer_cruise.observer_cruise_id, + organization_id=observer_cruise.organization_id, + workflow_run_id=observer_cruise.workflow_run_id, + workflow_id=observer_cruise.workflow_id, + workflow_permanent_id=observer_cruise.workflow_permanent_id, + observer_thought_type=ObserverThoughtType.user_goal_check, + observer_thought_scenario=ObserverThoughtScenario.summarization, + ) + # summarize the observer cruise and format the output + observer_summary_prompt = prompt_engine.load_prompt( + "observer_summary", + user_goal=observer_cruise.prompt, + task_history=task_history, + local_datetime=datetime.now(context.tz_info).isoformat(), + ) + observer_summary_resp = await app.LLM_API_HANDLER( + prompt=observer_summary_prompt, + screenshots=screenshots, + observer_thought=observer_thought, + ) + LOG.info("Observer summary response", observer_summary_resp=observer_summary_resp) + + thought = observer_summary_resp.get("description") + summarized_output = observer_summary_resp.get("output") + await app.DATABASE.update_observer_thought( + observer_thought_id=observer_thought.observer_thought_id, + organization_id=observer_cruise.organization_id, + thought=thought, + output=observer_summary_resp, + ) + + await mark_observer_cruise_as_completed( + observer_cruise_id=observer_cruise.observer_cruise_id, + workflow_run_id=observer_cruise.workflow_run_id, + organization_id=observer_cruise.organization_id, + summary=thought, + output=summarized_output, + )