new observer thoughts (#1442)

This commit is contained in:
Shuchang Zheng
2024-12-27 09:04:09 -08:00
committed by GitHub
parent d03957d590
commit 9e6c2362bf
6 changed files with 198 additions and 17 deletions

View File

@@ -55,7 +55,12 @@ from skyvern.forge.sdk.db.utils import (
)
from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverCruiseStatus, ObserverThought
from skyvern.forge.sdk.schemas.observers import (
ObserverCruise,
ObserverCruiseStatus,
ObserverThought,
ObserverThoughtType,
)
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
from skyvern.forge.sdk.schemas.tasks import OrderBy, ProxyLocation, SortDirection, Task, TaskStatus
@@ -1924,17 +1929,19 @@ class AgentDB:
async def get_observer_thoughts(
self,
observer_cruise_id: str,
observer_thought_types: list[ObserverThoughtType] | None = None,
organization_id: str | None = None,
) -> list[ObserverThought]:
async with self.Session() as session:
observer_thoughts = (
await session.scalars(
select(ObserverThoughtModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(organization_id=organization_id)
.order_by(ObserverThoughtModel.created_at)
)
).all()
query = (
select(ObserverThoughtModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(organization_id=organization_id)
.order_by(ObserverThoughtModel.created_at)
)
if observer_thought_types:
query = query.filter(ObserverThoughtModel.observer_thought_type.in_(observer_thought_types))
observer_thoughts = (await session.scalars(query)).all()
return [ObserverThought.model_validate(thought) for thought in observer_thoughts]
async def create_observer_cruise(
@@ -1971,6 +1978,9 @@ class AgentDB:
observation: str | None = None,
thought: str | None = None,
answer: str | None = None,
observer_thought_scenario: str | None = None,
observer_thought_type: str = ObserverThoughtType.plan,
output: dict[str, Any] | None = None,
organization_id: str | None = None,
) -> ObserverThought:
async with self.Session() as session:
@@ -1984,6 +1994,9 @@ class AgentDB:
observation=observation,
thought=thought,
answer=answer,
observer_thought_scenario=observer_thought_scenario,
observer_thought_type=observer_thought_type,
output=output,
organization_id=organization_id,
)
session.add(new_observer_thought)
@@ -1995,9 +2008,13 @@ class AgentDB:
self,
observer_thought_id: str,
workflow_run_block_id: str | None = None,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
workflow_permanent_id: str | None = None,
observation: str | None = None,
thought: str | None = None,
answer: str | None = None,
output: dict[str, Any] | None = None,
organization_id: str | None = None,
) -> ObserverThought:
async with self.Session() as session:
@@ -2011,12 +2028,20 @@ class AgentDB:
if observer_thought:
if workflow_run_block_id:
observer_thought.workflow_run_block_id = workflow_run_block_id
if workflow_run_id:
observer_thought.workflow_run_id = workflow_run_id
if workflow_id:
observer_thought.workflow_id = workflow_id
if workflow_permanent_id:
observer_thought.workflow_permanent_id = workflow_permanent_id
if observation:
observer_thought.observation = observation
if thought:
observer_thought.thought = thought
if answer:
observer_thought.answer = answer
if output:
observer_thought.output = output
await session.commit()
await session.refresh(observer_thought)
return ObserverThought.model_validate(observer_thought)