Sqlalchemy AsyncSession (#122)
This commit is contained in:
@@ -2,9 +2,9 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, create_engine, delete
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from skyvern.exceptions import WorkflowParameterNotFound
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
@@ -63,8 +63,8 @@ class AgentDB:
|
||||
def __init__(self, database_string: str, debug_enabled: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.debug_enabled = debug_enabled
|
||||
self.engine = create_engine(database_string, json_serializer=_custom_json_serializer)
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
self.engine = create_async_engine(database_string, json_serializer=_custom_json_serializer)
|
||||
self.Session = async_sessionmaker(bind=self.engine)
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
@@ -83,7 +83,7 @@ class AgentDB:
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
) -> Task:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
new_task = TaskModel(
|
||||
status="created",
|
||||
url=url,
|
||||
@@ -101,8 +101,8 @@ class AgentDB:
|
||||
error_code_mapping=error_code_mapping,
|
||||
)
|
||||
session.add(new_task)
|
||||
session.commit()
|
||||
session.refresh(new_task)
|
||||
await session.commit()
|
||||
await session.refresh(new_task)
|
||||
return convert_to_task(new_task, self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
@@ -119,7 +119,7 @@ class AgentDB:
|
||||
organization_id: str | None = None,
|
||||
) -> Step:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
new_step = StepModel(
|
||||
task_id=task_id,
|
||||
order=order,
|
||||
@@ -128,8 +128,8 @@ class AgentDB:
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(new_step)
|
||||
session.commit()
|
||||
session.refresh(new_step)
|
||||
await session.commit()
|
||||
await session.refresh(new_step)
|
||||
return convert_to_step(new_step, debug_enabled=self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
@@ -148,7 +148,7 @@ class AgentDB:
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
new_artifact = ArtifactModel(
|
||||
artifact_id=artifact_id,
|
||||
task_id=task_id,
|
||||
@@ -158,8 +158,8 @@ class AgentDB:
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(new_artifact)
|
||||
session.commit()
|
||||
session.refresh(new_artifact)
|
||||
await session.commit()
|
||||
await session.refresh(new_artifact)
|
||||
return convert_to_artifact(new_artifact, self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.exception("SQLAlchemyError", exc_info=True)
|
||||
@@ -171,13 +171,12 @@ class AgentDB:
|
||||
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
|
||||
"""Get a task by its id"""
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
if task_obj := (
|
||||
session.query(TaskModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
return convert_to_task(task_obj, self.debug_enabled)
|
||||
else:
|
||||
LOG.info("Task not found", task_id=task_id, organization_id=organization_id)
|
||||
@@ -191,13 +190,12 @@ class AgentDB:
|
||||
|
||||
async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
if step := (
|
||||
session.query(StepModel)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
await session.scalars(
|
||||
select(StepModel).filter_by(step_id=step_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
return convert_to_step(step, debug_enabled=self.debug_enabled)
|
||||
|
||||
else:
|
||||
@@ -211,15 +209,16 @@ class AgentDB:
|
||||
|
||||
async def get_task_steps(self, task_id: str, organization_id: str | None = None) -> list[Step]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
steps := session.query(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order)
|
||||
.order_by(StepModel.retry_index)
|
||||
.all()
|
||||
):
|
||||
async with self.Session() as session:
|
||||
if steps := (
|
||||
await session.scalars(
|
||||
select(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order)
|
||||
.order_by(StepModel.retry_index)
|
||||
)
|
||||
).all():
|
||||
return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
|
||||
else:
|
||||
return []
|
||||
@@ -232,15 +231,16 @@ class AgentDB:
|
||||
|
||||
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> list[StepModel]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
return (
|
||||
session.query(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order)
|
||||
.order_by(StepModel.retry_index)
|
||||
.all()
|
||||
)
|
||||
await session.scalars(
|
||||
select(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order)
|
||||
.order_by(StepModel.retry_index)
|
||||
)
|
||||
).all()
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
@@ -250,14 +250,15 @@ class AgentDB:
|
||||
|
||||
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
if step := (
|
||||
session.query(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order.desc())
|
||||
.first()
|
||||
):
|
||||
await session.scalars(
|
||||
select(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order.desc())
|
||||
)
|
||||
).first():
|
||||
return convert_to_step(step, debug_enabled=self.debug_enabled)
|
||||
else:
|
||||
LOG.info("Latest step not found", task_id=task_id, organization_id=organization_id)
|
||||
@@ -281,14 +282,15 @@ class AgentDB:
|
||||
incremental_cost: float | None = None,
|
||||
) -> Step:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
step := session.query(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
async with self.Session() as session:
|
||||
if step := (
|
||||
await session.scalars(
|
||||
select(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
if status is not None:
|
||||
step.status = status
|
||||
if output is not None:
|
||||
@@ -300,7 +302,7 @@ class AgentDB:
|
||||
if incremental_cost is not None:
|
||||
step.step_cost = incremental_cost + float(step.step_cost or 0)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
updated_step = await self.get_step(task_id, step_id, organization_id)
|
||||
if not updated_step:
|
||||
raise NotFoundError("Step not found")
|
||||
@@ -331,13 +333,12 @@ class AgentDB:
|
||||
"At least one of status, extracted_information, or failure_reason must be provided to update the task"
|
||||
)
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
task := session.query(TaskModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
async with self.Session() as session:
|
||||
if task := (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if extracted_information is not None:
|
||||
@@ -346,7 +347,7 @@ class AgentDB:
|
||||
task.failure_reason = failure_reason
|
||||
if errors is not None:
|
||||
task.errors = errors
|
||||
session.commit()
|
||||
await session.commit()
|
||||
updated_task = await self.get_task(task_id, organization_id=organization_id)
|
||||
if not updated_task:
|
||||
raise NotFoundError("Task not found")
|
||||
@@ -374,16 +375,17 @@ class AgentDB:
|
||||
raise ValueError(f"Page must be greater than 0, got {page}")
|
||||
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
db_page = page - 1 # offset logic is 0 based
|
||||
tasks = (
|
||||
session.query(TaskModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(TaskModel.created_at.desc())
|
||||
.limit(page_size)
|
||||
.offset(db_page * page_size)
|
||||
.all()
|
||||
)
|
||||
await session.scalars(
|
||||
select(TaskModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(TaskModel.created_at.desc())
|
||||
.limit(page_size)
|
||||
.offset(db_page * page_size)
|
||||
)
|
||||
).all()
|
||||
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
@@ -394,10 +396,10 @@ class AgentDB:
|
||||
|
||||
async def get_organization(self, organization_id: str) -> Organization | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
if organization := (
|
||||
session.query(OrganizationModel).filter_by(organization_id=organization_id).first()
|
||||
):
|
||||
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
|
||||
).first():
|
||||
return convert_to_organization(organization)
|
||||
else:
|
||||
return None
|
||||
@@ -414,15 +416,15 @@ class AgentDB:
|
||||
webhook_callback_url: str | None = None,
|
||||
max_steps_per_run: int | None = None,
|
||||
) -> Organization:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
org = OrganizationModel(
|
||||
organization_name=organization_name,
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
max_steps_per_run=max_steps_per_run,
|
||||
)
|
||||
session.add(org)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
|
||||
return convert_to_organization(org)
|
||||
|
||||
@@ -432,14 +434,15 @@ class AgentDB:
|
||||
token_type: OrganizationAuthTokenType,
|
||||
) -> OrganizationAuthToken | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
if token := (
|
||||
session.query(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(valid=True)
|
||||
.first()
|
||||
):
|
||||
await session.scalars(
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(valid=True)
|
||||
)
|
||||
).first():
|
||||
return convert_to_organization_auth_token(token)
|
||||
else:
|
||||
return None
|
||||
@@ -457,15 +460,16 @@ class AgentDB:
|
||||
token: str,
|
||||
) -> OrganizationAuthToken | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
if token_obj := (
|
||||
session.query(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(token=token)
|
||||
.filter_by(valid=True)
|
||||
.first()
|
||||
):
|
||||
await session.scalars(
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(token=token)
|
||||
.filter_by(valid=True)
|
||||
)
|
||||
).first():
|
||||
return convert_to_organization_auth_token(token_obj)
|
||||
else:
|
||||
return None
|
||||
@@ -482,15 +486,15 @@ class AgentDB:
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
) -> OrganizationAuthToken:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
token = OrganizationAuthTokenModel(
|
||||
organization_id=organization_id,
|
||||
token_type=token_type,
|
||||
token=token,
|
||||
)
|
||||
session.add(token)
|
||||
session.commit()
|
||||
session.refresh(token)
|
||||
await session.commit()
|
||||
await session.refresh(token)
|
||||
|
||||
return convert_to_organization_auth_token(token)
|
||||
|
||||
@@ -501,14 +505,15 @@ class AgentDB:
|
||||
organization_id: str | None = None,
|
||||
) -> list[Artifact]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
if artifacts := (
|
||||
session.query(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.all()
|
||||
):
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).all():
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
|
||||
else:
|
||||
return []
|
||||
@@ -525,13 +530,14 @@ class AgentDB:
|
||||
organization_id: str,
|
||||
) -> Artifact | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
if artifact := (
|
||||
session.query(ArtifactModel)
|
||||
.filter_by(artifact_id=artifact_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter_by(artifact_id=artifact_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
else:
|
||||
return None
|
||||
@@ -550,16 +556,17 @@ class AgentDB:
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
artifact = (
|
||||
session.query(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(artifact_type=artifact_type)
|
||||
.order_by(ArtifactModel.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(artifact_type=artifact_type)
|
||||
.order_by(ArtifactModel.created_at.desc())
|
||||
)
|
||||
).first()
|
||||
if artifact:
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
return None
|
||||
@@ -577,16 +584,17 @@ class AgentDB:
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
artifact = (
|
||||
session.query(ArtifactModel)
|
||||
.join(TaskModel, TaskModel.task_id == ArtifactModel.task_id)
|
||||
.filter(TaskModel.workflow_run_id == workflow_run_id)
|
||||
.filter(ArtifactModel.artifact_type == artifact_type)
|
||||
.filter(ArtifactModel.organization_id == organization_id)
|
||||
.order_by(ArtifactModel.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.join(TaskModel, TaskModel.task_id == ArtifactModel.task_id)
|
||||
.filter(TaskModel.workflow_run_id == workflow_run_id)
|
||||
.filter(ArtifactModel.artifact_type == artifact_type)
|
||||
.filter(ArtifactModel.organization_id == organization_id)
|
||||
.order_by(ArtifactModel.created_at.desc())
|
||||
)
|
||||
).first()
|
||||
if artifact:
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
return None
|
||||
@@ -605,8 +613,8 @@ class AgentDB:
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
artifact_query = session.query(ArtifactModel).filter_by(task_id=task_id)
|
||||
async with self.Session() as session:
|
||||
artifact_query = select(ArtifactModel).filter_by(task_id=task_id)
|
||||
if step_id:
|
||||
artifact_query = artifact_query.filter_by(step_id=step_id)
|
||||
if organization_id:
|
||||
@@ -614,7 +622,7 @@ class AgentDB:
|
||||
if artifact_types:
|
||||
artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types))
|
||||
|
||||
artifact = artifact_query.order_by(ArtifactModel.created_at.desc()).first()
|
||||
artifact = (await session.scalars(artifact_query.order_by(ArtifactModel.created_at.desc()))).first()
|
||||
if artifact:
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
return None
|
||||
@@ -632,15 +640,11 @@ class AgentDB:
|
||||
before: datetime | None = None,
|
||||
) -> Task | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
query = (
|
||||
session.query(TaskModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(workflow_id=workflow_id)
|
||||
)
|
||||
async with self.Session() as session:
|
||||
query = select(TaskModel).filter_by(organization_id=organization_id).filter_by(workflow_id=workflow_id)
|
||||
if before:
|
||||
query = query.filter(TaskModel.created_at < before)
|
||||
task = query.order_by(TaskModel.created_at.desc()).first()
|
||||
task = (await session.scalars(query.order_by(TaskModel.created_at.desc()))).first()
|
||||
if task:
|
||||
return convert_to_task(task, debug_enabled=self.debug_enabled)
|
||||
return None
|
||||
@@ -655,7 +659,7 @@ class AgentDB:
|
||||
workflow_definition: dict[str, Any],
|
||||
description: str | None = None,
|
||||
) -> Workflow:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
workflow = WorkflowModel(
|
||||
organization_id=organization_id,
|
||||
title=title,
|
||||
@@ -663,14 +667,16 @@ class AgentDB:
|
||||
workflow_definition=workflow_definition,
|
||||
)
|
||||
session.add(workflow)
|
||||
session.commit()
|
||||
session.refresh(workflow)
|
||||
await session.commit()
|
||||
await session.refresh(workflow)
|
||||
return convert_to_workflow(workflow, self.debug_enabled)
|
||||
|
||||
async def get_workflow(self, workflow_id: str) -> Workflow | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if workflow := session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first():
|
||||
async with self.Session() as session:
|
||||
if workflow := (
|
||||
await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id))
|
||||
).first():
|
||||
return convert_to_workflow(workflow, self.debug_enabled)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
@@ -684,8 +690,8 @@ class AgentDB:
|
||||
description: str | None = None,
|
||||
workflow_definition: dict[str, Any] | None = None,
|
||||
) -> Workflow | None:
|
||||
with self.Session() as session:
|
||||
workflow = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
|
||||
async with self.Session() as session:
|
||||
workflow = (await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id))).first()
|
||||
if workflow:
|
||||
if title:
|
||||
workflow.title = title
|
||||
@@ -693,8 +699,8 @@ class AgentDB:
|
||||
workflow.description = description
|
||||
if workflow_definition:
|
||||
workflow.workflow_definition = workflow_definition
|
||||
session.commit()
|
||||
session.refresh(workflow)
|
||||
await session.commit()
|
||||
await session.refresh(workflow)
|
||||
return convert_to_workflow(workflow, self.debug_enabled)
|
||||
LOG.error("Workflow not found, nothing to update", workflow_id=workflow_id)
|
||||
return None
|
||||
@@ -703,7 +709,7 @@ class AgentDB:
|
||||
self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None
|
||||
) -> WorkflowRun:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
workflow_run = WorkflowRunModel(
|
||||
workflow_id=workflow_id,
|
||||
proxy_location=proxy_location,
|
||||
@@ -711,28 +717,32 @@ class AgentDB:
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
)
|
||||
session.add(workflow_run)
|
||||
session.commit()
|
||||
session.refresh(workflow_run)
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run)
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_workflow_run(self, workflow_run_id: str, status: WorkflowRunStatus) -> WorkflowRun | None:
|
||||
with self.Session() as session:
|
||||
workflow_run = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
|
||||
async with self.Session() as session:
|
||||
workflow_run = (
|
||||
await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
|
||||
).first()
|
||||
if workflow_run:
|
||||
workflow_run.status = status
|
||||
session.commit()
|
||||
session.refresh(workflow_run)
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run)
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
LOG.error("WorkflowRun not found, nothing to update", workflow_run_id=workflow_run_id)
|
||||
return None
|
||||
|
||||
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if workflow_run := session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first():
|
||||
async with self.Session() as session:
|
||||
if workflow_run := (
|
||||
await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
|
||||
).first():
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
@@ -741,8 +751,10 @@ class AgentDB:
|
||||
|
||||
async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
workflow_runs = session.query(WorkflowRunModel).filter_by(workflow_id=workflow_id).all()
|
||||
async with self.Session() as session:
|
||||
workflow_runs = (
|
||||
await session.scalars(select(WorkflowRunModel).filter_by(workflow_id=workflow_id))
|
||||
).all()
|
||||
return [convert_to_workflow_run(run) for run in workflow_runs]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
@@ -757,7 +769,7 @@ class AgentDB:
|
||||
description: str | None = None,
|
||||
) -> WorkflowParameter:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
workflow_parameter = WorkflowParameterModel(
|
||||
workflow_id=workflow_id,
|
||||
workflow_parameter_type=workflow_parameter_type,
|
||||
@@ -766,8 +778,8 @@ class AgentDB:
|
||||
description=description,
|
||||
)
|
||||
session.add(workflow_parameter)
|
||||
session.commit()
|
||||
session.refresh(workflow_parameter)
|
||||
await session.commit()
|
||||
await session.refresh(workflow_parameter)
|
||||
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
@@ -780,7 +792,7 @@ class AgentDB:
|
||||
aws_key: str,
|
||||
description: str | None = None,
|
||||
) -> AWSSecretParameter:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
aws_secret_parameter = AWSSecretParameterModel(
|
||||
workflow_id=workflow_id,
|
||||
key=key,
|
||||
@@ -788,8 +800,8 @@ class AgentDB:
|
||||
description=description,
|
||||
)
|
||||
session.add(aws_secret_parameter)
|
||||
session.commit()
|
||||
session.refresh(aws_secret_parameter)
|
||||
await session.commit()
|
||||
await session.refresh(aws_secret_parameter)
|
||||
return convert_to_aws_secret_parameter(aws_secret_parameter)
|
||||
|
||||
async def create_output_parameter(
|
||||
@@ -798,21 +810,23 @@ class AgentDB:
|
||||
key: str,
|
||||
description: str | None = None,
|
||||
) -> OutputParameter:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
output_parameter = OutputParameterModel(
|
||||
key=key,
|
||||
description=description,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
session.add(output_parameter)
|
||||
session.commit()
|
||||
session.refresh(output_parameter)
|
||||
await session.commit()
|
||||
await session.refresh(output_parameter)
|
||||
return convert_to_output_parameter(output_parameter)
|
||||
|
||||
async def get_workflow_output_parameters(self, workflow_id: str) -> list[OutputParameter]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
output_parameters = session.query(OutputParameterModel).filter_by(workflow_id=workflow_id).all()
|
||||
async with self.Session() as session:
|
||||
output_parameters = (
|
||||
await session.scalars(select(OutputParameterModel).filter_by(workflow_id=workflow_id))
|
||||
).all()
|
||||
return [convert_to_output_parameter(parameter) for parameter in output_parameters]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
@@ -820,13 +834,14 @@ class AgentDB:
|
||||
|
||||
async def get_workflow_run_output_parameters(self, workflow_run_id: str) -> list[WorkflowRunOutputParameter]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
workflow_run_output_parameters = (
|
||||
session.query(WorkflowRunOutputParameterModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.order_by(WorkflowRunOutputParameterModel.created_at)
|
||||
.all()
|
||||
)
|
||||
await session.scalars(
|
||||
select(WorkflowRunOutputParameterModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.order_by(WorkflowRunOutputParameterModel.created_at)
|
||||
)
|
||||
).all()
|
||||
return [
|
||||
convert_to_workflow_run_output_parameter(parameter, self.debug_enabled)
|
||||
for parameter in workflow_run_output_parameters
|
||||
@@ -839,15 +854,15 @@ class AgentDB:
|
||||
self, workflow_run_id: str, output_parameter_id: str, value: dict[str, Any] | list | str | None
|
||||
) -> WorkflowRunOutputParameter:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
workflow_run_output_parameter = WorkflowRunOutputParameterModel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=output_parameter_id,
|
||||
value=value,
|
||||
)
|
||||
session.add(workflow_run_output_parameter)
|
||||
session.commit()
|
||||
session.refresh(workflow_run_output_parameter)
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run_output_parameter)
|
||||
return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
|
||||
|
||||
except SQLAlchemyError:
|
||||
@@ -856,8 +871,10 @@ class AgentDB:
|
||||
|
||||
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
workflow_parameters = session.query(WorkflowParameterModel).filter_by(workflow_id=workflow_id).all()
|
||||
async with self.Session() as session:
|
||||
workflow_parameters = (
|
||||
await session.scalars(select(WorkflowParameterModel).filter_by(workflow_id=workflow_id))
|
||||
).all()
|
||||
return [convert_to_workflow_parameter(parameter) for parameter in workflow_parameters]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
@@ -865,10 +882,12 @@ class AgentDB:
|
||||
|
||||
async def get_workflow_parameter(self, workflow_parameter_id: str) -> WorkflowParameter | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
if workflow_parameter := (
|
||||
session.query(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id).first()
|
||||
):
|
||||
await session.scalars(
|
||||
select(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id)
|
||||
)
|
||||
).first():
|
||||
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
@@ -879,15 +898,15 @@ class AgentDB:
|
||||
self, workflow_run_id: str, workflow_parameter_id: str, value: Any
|
||||
) -> WorkflowRunParameter:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
workflow_run_parameter = WorkflowRunParameterModel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_parameter_id=workflow_parameter_id,
|
||||
value=value,
|
||||
)
|
||||
session.add(workflow_run_parameter)
|
||||
session.commit()
|
||||
session.refresh(workflow_run_parameter)
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run_parameter)
|
||||
workflow_parameter = await self.get_workflow_parameter(workflow_parameter_id)
|
||||
if not workflow_parameter:
|
||||
raise WorkflowParameterNotFound(workflow_parameter_id)
|
||||
@@ -900,10 +919,10 @@ class AgentDB:
|
||||
self, workflow_run_id: str
|
||||
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
workflow_run_parameters = (
|
||||
session.query(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id).all()
|
||||
)
|
||||
await session.scalars(select(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id))
|
||||
).all()
|
||||
results = []
|
||||
for workflow_run_parameter in workflow_run_parameters:
|
||||
workflow_parameter = await self.get_workflow_parameter(workflow_run_parameter.workflow_parameter_id)
|
||||
@@ -926,13 +945,14 @@ class AgentDB:
|
||||
|
||||
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
if task := (
|
||||
session.query(TaskModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.order_by(TaskModel.created_at.desc())
|
||||
.first()
|
||||
):
|
||||
await session.scalars(
|
||||
select(TaskModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.order_by(TaskModel.created_at.desc())
|
||||
)
|
||||
).first():
|
||||
return convert_to_task(task, debug_enabled=self.debug_enabled)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
@@ -941,20 +961,19 @@ class AgentDB:
|
||||
|
||||
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
tasks = (
|
||||
session.query(TaskModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.order_by(TaskModel.created_at)
|
||||
.all()
|
||||
)
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at)
|
||||
)
|
||||
).all()
|
||||
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
# delete artifacts by filtering organization_id and task_id
|
||||
stmt = delete(ArtifactModel).where(
|
||||
and_(
|
||||
@@ -962,11 +981,11 @@ class AgentDB:
|
||||
ArtifactModel.task_id == task_id,
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
|
||||
with self.Session() as session:
|
||||
async with self.Session() as session:
|
||||
# delete artifacts by filtering organization_id and task_id
|
||||
stmt = delete(StepModel).where(
|
||||
and_(
|
||||
@@ -974,5 +993,5 @@ class AgentDB:
|
||||
StepModel.task_id == task_id,
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Integer, Numeric, String, UnicodeText
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
@@ -19,7 +20,7 @@ from skyvern.forge.sdk.db.id import (
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
class Base(AsyncAttrs, DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user