make observer thought artifacts work (#1423)

This commit is contained in:
Shuchang Zheng
2024-12-22 23:01:02 -08:00
committed by GitHub
parent 94a3779bd7
commit 2029c0c41f
4 changed files with 46 additions and 6 deletions

View File

@@ -84,3 +84,4 @@ class LogEntityType(StrEnum):
TASK = "task"
WORKFLOW_RUN = "workflow_run"
WORKFLOW_RUN_BLOCK = "workflow_run_block"
OBSERVER = "observer"

View File

@@ -1975,6 +1975,37 @@ class AgentDB:
await session.refresh(new_observer_thought)
return ObserverThought.model_validate(new_observer_thought)
async def update_observer_thought(
self,
observer_thought_id: str,
workflow_run_block_id: str | None = None,
observation: str | None = None,
thought: str | None = None,
answer: str | None = None,
organization_id: str | None = None,
) -> ObserverThought:
async with self.Session() as session:
observer_thought = (
await session.scalars(
select(ObserverThoughtModel)
.filter_by(observer_thought_id=observer_thought_id)
.filter_by(organization_id=organization_id)
)
).first()
if observer_thought:
if workflow_run_block_id:
observer_thought.workflow_run_block_id = workflow_run_block_id
if observation:
observer_thought.observation = observation
if thought:
observer_thought.thought = thought
if answer:
observer_thought.answer = answer
await session.commit()
await session.refresh(observer_thought)
return ObserverThought.model_validate(observer_thought)
raise NotFoundError(f"ObserverThought {observer_thought_id}")
async def update_observer_cruise(
self,
observer_cruise_id: str,

View File

@@ -478,6 +478,7 @@ class EntityType(str, Enum):
TASK = "task"
WORKFLOW_RUN = "workflow_run"
WORKFLOW_RUN_BLOCK = "workflow_run_block"
OBSERVER_THOUGHT = "observer_thought"
entity_type_to_param = {
@@ -485,6 +486,7 @@ entity_type_to_param = {
EntityType.TASK: "task_id",
EntityType.WORKFLOW_RUN: "workflow_run_id",
EntityType.WORKFLOW_RUN_BLOCK: "workflow_run_block_id",
EntityType.OBSERVER_THOUGHT: "observer_thought_id",
}

View File

@@ -231,8 +231,17 @@ async def run_observer_cruise(
task_history=task_history,
local_datetime=datetime.now(context.tz_info).isoformat(),
)
observer_thought = await app.DATABASE.create_observer_thought(
observer_cruise_id=observer_cruise_id,
organization_id=organization_id,
workflow_run_id=workflow_run.workflow_run_id,
workflow_id=workflow.workflow_id,
workflow_permanent_id=workflow.workflow_permanent_id,
)
observer_response = await app.LLM_API_HANDLER(
prompt=observer_prompt, screenshots=scraped_page.screenshots, observer_cruise=observer_cruise
prompt=observer_prompt,
screenshots=scraped_page.screenshots,
observer_thought=observer_thought,
)
LOG.info(
"Observer response",
@@ -247,12 +256,9 @@ async def run_observer_cruise(
thoughts: str = observer_response.get("thoughts", "")
plan: str = observer_response.get("plan", "")
# Create and save observer thought
await app.DATABASE.create_observer_thought(
observer_cruise_id=observer_cruise_id,
await app.DATABASE.update_observer_thought(
observer_thought_id=observer_thought.observer_thought_id,
organization_id=organization_id,
workflow_run_id=workflow_run.workflow_run_id,
workflow_id=workflow.workflow_id,
workflow_permanent_id=workflow.workflow_permanent_id,
thought=thoughts,
observation=observation,
answer=plan,