From 2029c0c41fc7e72bf09d6c5294c12391f7185a26 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sun, 22 Dec 2024 23:01:02 -0800 Subject: [PATCH] make observer thought artifacts work (#1423) --- skyvern/forge/sdk/artifact/models.py | 1 + skyvern/forge/sdk/db/client.py | 31 +++++++++++++++++++ skyvern/forge/sdk/routes/agent_protocol.py | 2 ++ .../forge/sdk/services/observer_service.py | 18 +++++++---- 4 files changed, 46 insertions(+), 6 deletions(-) diff --git a/skyvern/forge/sdk/artifact/models.py b/skyvern/forge/sdk/artifact/models.py index 307c902b..c91a37a4 100644 --- a/skyvern/forge/sdk/artifact/models.py +++ b/skyvern/forge/sdk/artifact/models.py @@ -84,3 +84,4 @@ class LogEntityType(StrEnum): TASK = "task" WORKFLOW_RUN = "workflow_run" WORKFLOW_RUN_BLOCK = "workflow_run_block" + OBSERVER = "observer" diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 0460f08b..74f1f38d 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1975,6 +1975,37 @@ class AgentDB: await session.refresh(new_observer_thought) return ObserverThought.model_validate(new_observer_thought) + async def update_observer_thought( + self, + observer_thought_id: str, + workflow_run_block_id: str | None = None, + observation: str | None = None, + thought: str | None = None, + answer: str | None = None, + organization_id: str | None = None, + ) -> ObserverThought: + async with self.Session() as session: + observer_thought = ( + await session.scalars( + select(ObserverThoughtModel) + .filter_by(observer_thought_id=observer_thought_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if observer_thought: + if workflow_run_block_id: + observer_thought.workflow_run_block_id = workflow_run_block_id + if observation: + observer_thought.observation = observation + if thought: + observer_thought.thought = thought + if answer: + observer_thought.answer = answer + await session.commit() + await session.refresh(observer_thought) + return ObserverThought.model_validate(observer_thought) + raise NotFoundError(f"ObserverThought {observer_thought_id}") + async def update_observer_cruise( self, observer_cruise_id: str, diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 020d9712..82e52ff4 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -478,6 +478,7 @@ class EntityType(str, Enum): TASK = "task" WORKFLOW_RUN = "workflow_run" WORKFLOW_RUN_BLOCK = "workflow_run_block" + OBSERVER_THOUGHT = "observer_thought" entity_type_to_param = { @@ -485,6 +486,7 @@ entity_type_to_param = { EntityType.TASK: "task_id", EntityType.WORKFLOW_RUN: "workflow_run_id", EntityType.WORKFLOW_RUN_BLOCK: "workflow_run_block_id", + EntityType.OBSERVER_THOUGHT: "observer_thought_id", } diff --git a/skyvern/forge/sdk/services/observer_service.py b/skyvern/forge/sdk/services/observer_service.py index fdbc5473..92de49b0 100644 --- a/skyvern/forge/sdk/services/observer_service.py +++ b/skyvern/forge/sdk/services/observer_service.py @@ -231,8 +231,17 @@ async def run_observer_cruise( task_history=task_history, local_datetime=datetime.now(context.tz_info).isoformat(), ) + observer_thought = await app.DATABASE.create_observer_thought( + observer_cruise_id=observer_cruise_id, + organization_id=organization_id, + workflow_run_id=workflow_run.workflow_run_id, + workflow_id=workflow.workflow_id, + workflow_permanent_id=workflow.workflow_permanent_id, + ) observer_response = await app.LLM_API_HANDLER( - prompt=observer_prompt, screenshots=scraped_page.screenshots, observer_cruise=observer_cruise + prompt=observer_prompt, + screenshots=scraped_page.screenshots, + observer_thought=observer_thought, ) LOG.info( "Observer response", @@ -247,12 +256,9 @@ async def run_observer_cruise( thoughts: str = observer_response.get("thoughts", "") plan: str = observer_response.get("plan", "") # Create and save observer thought - await app.DATABASE.create_observer_thought( - observer_cruise_id=observer_cruise_id, + await app.DATABASE.update_observer_thought( + observer_thought_id=observer_thought.observer_thought_id, organization_id=organization_id, - workflow_run_id=workflow_run.workflow_run_id, - workflow_id=workflow.workflow_id, - workflow_permanent_id=workflow.workflow_permanent_id, thought=thoughts, observation=observation, answer=plan,