task v2 refactor part 6 - observer_cruise_id -> task_v2_id (#1817)

This commit is contained in:
Shuchang Zheng
2025-02-23 16:03:49 -08:00
committed by GitHub
parent 2d24055c36
commit ffbc95e1b4
22 changed files with 238 additions and 250 deletions

View File

@@ -210,7 +210,7 @@ class AgentDB:
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
observer_cruise_id: str | None = None,
task_v2_id: str | None = None,
observer_thought_id: str | None = None,
ai_suggestion_id: str | None = None,
organization_id: str | None = None,
@@ -225,7 +225,7 @@ class AgentDB:
step_id=step_id,
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
observer_cruise_id=observer_cruise_id,
observer_cruise_id=task_v2_id,
observer_thought_id=observer_thought_id,
ai_suggestion_id=ai_suggestion_id,
organization_id=organization_id,
@@ -807,9 +807,9 @@ class AgentDB:
return convert_to_organization_auth_token(auth_token)
async def get_artifacts_for_observer_cruise(
async def get_artifacts_for_task_v2(
self,
observer_cruise_id: str,
task_v2_id: str,
organization_id: str | None = None,
artifact_types: list[ArtifactType] | None = None,
) -> list[Artifact]:
@@ -817,7 +817,7 @@ class AgentDB:
async with self.Session() as session:
query = (
select(ArtifactModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
)
if artifact_types:
@@ -894,7 +894,7 @@ class AgentDB:
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
observer_thought_id: str | None = None,
observer_cruise_id: str | None = None,
task_v2_id: str | None = None,
organization_id: str | None = None,
) -> list[Artifact]:
try:
@@ -913,8 +913,8 @@ class AgentDB:
query = query.filter_by(workflow_run_block_id=workflow_run_block_id)
if observer_thought_id is not None:
query = query.filter_by(observer_thought_id=observer_thought_id)
if observer_cruise_id is not None:
query = query.filter_by(observer_cruise_id=observer_cruise_id)
if task_v2_id is not None:
query = query.filter_by(observer_cruise_id=task_v2_id)
if organization_id is not None:
query = query.filter_by(organization_id=organization_id)
@@ -938,7 +938,7 @@ class AgentDB:
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
observer_thought_id: str | None = None,
observer_cruise_id: str | None = None,
task_v2_id: str | None = None,
organization_id: str | None = None,
) -> Artifact | None:
artifacts = await self.get_artifacts_by_entity_id(
@@ -948,7 +948,7 @@ class AgentDB:
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
observer_thought_id=observer_thought_id,
observer_cruise_id=observer_cruise_id,
task_v2_id=task_v2_id,
organization_id=organization_id,
)
return artifacts[0] if artifacts else None
@@ -1915,13 +1915,11 @@ class AgentDB:
await session.execute(stmt)
await session.commit()
async def delete_observer_cruise_artifacts(
self, observer_cruise_id: str, organization_id: str | None = None
) -> None:
async def delete_task_v2_artifacts(self, task_v2_id: str, organization_id: str | None = None) -> None:
async with self.Session() as session:
stmt = delete(ArtifactModel).where(
and_(
ArtifactModel.observer_cruise_id == observer_cruise_id,
ArtifactModel.observer_cruise_id == task_v2_id,
ArtifactModel.organization_id == organization_id,
)
)
@@ -2130,47 +2128,43 @@ class AgentDB:
await session.execute(stmt)
await session.commit()
async def get_observer_cruise(
self, observer_cruise_id: str, organization_id: str | None = None
) -> ObserverTask | None:
async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> ObserverTask | None:
async with self.Session() as session:
if observer_cruise := (
if task_v2 := (
await session.scalars(
select(ObserverCruiseModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
)
).first():
return ObserverTask.model_validate(observer_cruise)
return ObserverTask.model_validate(task_v2)
return None
async def delete_observer_thoughts_for_cruise(
self, observer_cruise_id: str, organization_id: str | None = None
) -> None:
async def delete_observer_thoughts_for_cruise(self, task_v2_id: str, organization_id: str | None = None) -> None:
async with self.Session() as session:
stmt = delete(ObserverThoughtModel).where(
and_(
ObserverThoughtModel.observer_cruise_id == observer_cruise_id,
ObserverThoughtModel.observer_cruise_id == task_v2_id,
ObserverThoughtModel.organization_id == organization_id,
)
)
await session.execute(stmt)
await session.commit()
async def get_observer_cruise_by_workflow_run_id(
async def get_task_v2_by_workflow_run_id(
self,
workflow_run_id: str,
organization_id: str | None = None,
) -> ObserverTask | None:
async with self.Session() as session:
if observer_cruise := (
if task_v2 := (
await session.scalars(
select(ObserverCruiseModel)
.filter_by(organization_id=organization_id)
.filter_by(workflow_run_id=workflow_run_id)
)
).first():
return ObserverTask.model_validate(observer_cruise)
return ObserverTask.model_validate(task_v2)
return None
async def get_observer_thought(
@@ -2189,14 +2183,14 @@ class AgentDB:
async def get_observer_thoughts(
self,
observer_cruise_id: str,
task_v2_id: str,
observer_thought_types: list[ObserverThoughtType] | None = None,
organization_id: str | None = None,
) -> list[ObserverThought]:
async with self.Session() as session:
query = (
select(ObserverThoughtModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
.order_by(ObserverThoughtModel.created_at)
)
@@ -2205,7 +2199,7 @@ class AgentDB:
observer_thoughts = (await session.scalars(query)).all()
return [ObserverThought.model_validate(thought) for thought in observer_thoughts]
async def create_observer_cruise(
async def create_task_v2(
self,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
@@ -2219,7 +2213,7 @@ class AgentDB:
webhook_callback_url: str | None = None,
) -> ObserverTask:
async with self.Session() as session:
new_observer_cruise = ObserverCruiseModel(
new_task_v2 = ObserverCruiseModel(
workflow_run_id=workflow_run_id,
workflow_id=workflow_id,
workflow_permanent_id=workflow_permanent_id,
@@ -2231,14 +2225,14 @@ class AgentDB:
webhook_callback_url=webhook_callback_url,
organization_id=organization_id,
)
session.add(new_observer_cruise)
session.add(new_task_v2)
await session.commit()
await session.refresh(new_observer_cruise)
return ObserverTask.model_validate(new_observer_cruise)
await session.refresh(new_task_v2)
return ObserverTask.model_validate(new_task_v2)
async def create_observer_thought(
self,
observer_cruise_id: str,
task_v2_id: str,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
workflow_permanent_id: str | None = None,
@@ -2257,7 +2251,7 @@ class AgentDB:
) -> ObserverThought:
async with self.Session() as session:
new_observer_thought = ObserverThoughtModel(
observer_cruise_id=observer_cruise_id,
observer_cruise_id=task_v2_id,
workflow_run_id=workflow_run_id,
workflow_id=workflow_id,
workflow_permanent_id=workflow_permanent_id,
@@ -2331,9 +2325,9 @@ class AgentDB:
return ObserverThought.model_validate(observer_thought)
raise NotFoundError(f"ObserverThought {observer_thought_id}")
async def update_observer_cruise(
async def update_task_v2(
self,
observer_cruise_id: str,
task_v2_id: str,
status: ObserverTaskStatus | None = None,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
@@ -2345,34 +2339,34 @@ class AgentDB:
organization_id: str | None = None,
) -> ObserverTask:
async with self.Session() as session:
observer_cruise = (
task_v2 = (
await session.scalars(
select(ObserverCruiseModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
)
).first()
if observer_cruise:
if task_v2:
if status:
observer_cruise.status = status
task_v2.status = status
if workflow_run_id:
observer_cruise.workflow_run_id = workflow_run_id
task_v2.workflow_run_id = workflow_run_id
if workflow_id:
observer_cruise.workflow_id = workflow_id
task_v2.workflow_id = workflow_id
if workflow_permanent_id:
observer_cruise.workflow_permanent_id = workflow_permanent_id
task_v2.workflow_permanent_id = workflow_permanent_id
if url:
observer_cruise.url = url
task_v2.url = url
if prompt:
observer_cruise.prompt = prompt
task_v2.prompt = prompt
if summary:
observer_cruise.summary = summary
task_v2.summary = summary
if output:
observer_cruise.output = output
task_v2.output = output
await session.commit()
await session.refresh(observer_cruise)
return ObserverTask.model_validate(observer_cruise)
raise NotFoundError(f"ObserverTask {observer_cruise_id} not found")
await session.refresh(task_v2)
return ObserverTask.model_validate(task_v2)
raise NotFoundError(f"TaskV2 {task_v2_id} not found")
async def create_workflow_run_block(
self,

View File

@@ -37,7 +37,7 @@ BITWARDEN_SENSITIVE_INFORMATION_PARAMETER_PREFIX = "bsi"
CREDENTIAL_PARAMETER_PREFIX = "cp"
CREDENTIAL_PREFIX = "cred"
ORGANIZATION_BITWARDEN_COLLECTION_PREFIX = "obc"
OBSERVER_CRUISE_ID = "oc"
TASK_V2_ID = "oc"
OBSERVER_THOUGHT_ID = "ot"
ORGANIZATION_AUTH_TOKEN_PREFIX = "oat"
ORG_PREFIX = "o"
@@ -156,9 +156,9 @@ def generate_action_id() -> str:
return f"{ACTION_PREFIX}_{int_id}"
def generate_observer_cruise_id() -> str:
def generate_task_v2_id() -> str:
int_id = generate_id()
return f"{OBSERVER_CRUISE_ID}_{int_id}"
return f"{TASK_V2_ID}_{int_id}"
def generate_observer_thought_id() -> str:

View File

@@ -29,7 +29,6 @@ from skyvern.forge.sdk.db.id import (
generate_bitwarden_sensitive_information_parameter_id,
generate_credential_id,
generate_credential_parameter_id,
generate_observer_cruise_id,
generate_observer_thought_id,
generate_org_id,
generate_organization_auth_token_id,
@@ -40,6 +39,7 @@ from skyvern.forge.sdk.db.id import (
generate_task_generation_id,
generate_task_id,
generate_task_run_id,
generate_task_v2_id,
generate_totp_code_id,
generate_workflow_id,
generate_workflow_parameter_id,
@@ -569,7 +569,8 @@ class ObserverCruiseModel(Base):
__tablename__ = "observer_cruises"
__table_args__ = (Index("oc_org_wfr_index", "organization_id", "workflow_run_id"),)
observer_cruise_id = Column(String, primary_key=True, default=generate_observer_cruise_id)
# observer_cruise_id is the task_id for task v2
observer_cruise_id = Column(String, primary_key=True, default=generate_task_v2_id)
status = Column(String, nullable=False, default="created")
organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=True)
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), nullable=True)