Sqlalchemy AsyncSession (#122)

This commit is contained in:
Kerem Yilmaz
2024-03-24 12:47:47 -07:00
committed by GitHub
parent 8b9db3a295
commit cf4749c1d5
4 changed files with 295 additions and 277 deletions

View File

@@ -2,9 +2,9 @@ from datetime import datetime
from typing import Any
import structlog
from sqlalchemy import and_, create_engine, delete
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from skyvern.exceptions import WorkflowParameterNotFound
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
@@ -63,8 +63,8 @@ class AgentDB:
def __init__(self, database_string: str, debug_enabled: bool = False) -> None:
super().__init__()
self.debug_enabled = debug_enabled
self.engine = create_engine(database_string, json_serializer=_custom_json_serializer)
self.Session = sessionmaker(bind=self.engine)
self.engine = create_async_engine(database_string, json_serializer=_custom_json_serializer)
self.Session = async_sessionmaker(bind=self.engine)
async def create_task(
self,
@@ -83,7 +83,7 @@ class AgentDB:
error_code_mapping: dict[str, str] | None = None,
) -> Task:
try:
with self.Session() as session:
async with self.Session() as session:
new_task = TaskModel(
status="created",
url=url,
@@ -101,8 +101,8 @@ class AgentDB:
error_code_mapping=error_code_mapping,
)
session.add(new_task)
session.commit()
session.refresh(new_task)
await session.commit()
await session.refresh(new_task)
return convert_to_task(new_task, self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
@@ -119,7 +119,7 @@ class AgentDB:
organization_id: str | None = None,
) -> Step:
try:
with self.Session() as session:
async with self.Session() as session:
new_step = StepModel(
task_id=task_id,
order=order,
@@ -128,8 +128,8 @@ class AgentDB:
organization_id=organization_id,
)
session.add(new_step)
session.commit()
session.refresh(new_step)
await session.commit()
await session.refresh(new_step)
return convert_to_step(new_step, debug_enabled=self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
@@ -148,7 +148,7 @@ class AgentDB:
organization_id: str | None = None,
) -> Artifact:
try:
with self.Session() as session:
async with self.Session() as session:
new_artifact = ArtifactModel(
artifact_id=artifact_id,
task_id=task_id,
@@ -158,8 +158,8 @@ class AgentDB:
organization_id=organization_id,
)
session.add(new_artifact)
session.commit()
session.refresh(new_artifact)
await session.commit()
await session.refresh(new_artifact)
return convert_to_artifact(new_artifact, self.debug_enabled)
except SQLAlchemyError:
LOG.exception("SQLAlchemyError", exc_info=True)
@@ -171,13 +171,12 @@ class AgentDB:
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
"""Get a task by its id"""
try:
with self.Session() as session:
async with self.Session() as session:
if task_obj := (
session.query(TaskModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.first()
):
await session.scalars(
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first():
return convert_to_task(task_obj, self.debug_enabled)
else:
LOG.info("Task not found", task_id=task_id, organization_id=organization_id)
@@ -191,13 +190,12 @@ class AgentDB:
async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None:
try:
with self.Session() as session:
async with self.Session() as session:
if step := (
session.query(StepModel)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.first()
):
await session.scalars(
select(StepModel).filter_by(step_id=step_id).filter_by(organization_id=organization_id)
)
).first():
return convert_to_step(step, debug_enabled=self.debug_enabled)
else:
@@ -211,15 +209,16 @@ class AgentDB:
async def get_task_steps(self, task_id: str, organization_id: str | None = None) -> list[Step]:
try:
with self.Session() as session:
if (
steps := session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order)
.order_by(StepModel.retry_index)
.all()
):
async with self.Session() as session:
if steps := (
await session.scalars(
select(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order)
.order_by(StepModel.retry_index)
)
).all():
return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
else:
return []
@@ -232,15 +231,16 @@ class AgentDB:
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> list[StepModel]:
try:
with self.Session() as session:
async with self.Session() as session:
return (
session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order)
.order_by(StepModel.retry_index)
.all()
)
await session.scalars(
select(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order)
.order_by(StepModel.retry_index)
)
).all()
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
@@ -250,14 +250,15 @@ class AgentDB:
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
try:
with self.Session() as session:
async with self.Session() as session:
if step := (
session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order.desc())
.first()
):
await session.scalars(
select(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order.desc())
)
).first():
return convert_to_step(step, debug_enabled=self.debug_enabled)
else:
LOG.info("Latest step not found", task_id=task_id, organization_id=organization_id)
@@ -281,14 +282,15 @@ class AgentDB:
incremental_cost: float | None = None,
) -> Step:
try:
with self.Session() as session:
if (
step := session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.first()
):
async with self.Session() as session:
if step := (
await session.scalars(
select(StepModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
)
).first():
if status is not None:
step.status = status
if output is not None:
@@ -300,7 +302,7 @@ class AgentDB:
if incremental_cost is not None:
step.step_cost = incremental_cost + float(step.step_cost or 0)
session.commit()
await session.commit()
updated_step = await self.get_step(task_id, step_id, organization_id)
if not updated_step:
raise NotFoundError("Step not found")
@@ -331,13 +333,12 @@ class AgentDB:
"At least one of status, extracted_information, or failure_reason must be provided to update the task"
)
try:
with self.Session() as session:
if (
task := session.query(TaskModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.first()
):
async with self.Session() as session:
if task := (
await session.scalars(
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first():
if status is not None:
task.status = status
if extracted_information is not None:
@@ -346,7 +347,7 @@ class AgentDB:
task.failure_reason = failure_reason
if errors is not None:
task.errors = errors
session.commit()
await session.commit()
updated_task = await self.get_task(task_id, organization_id=organization_id)
if not updated_task:
raise NotFoundError("Task not found")
@@ -374,16 +375,17 @@ class AgentDB:
raise ValueError(f"Page must be greater than 0, got {page}")
try:
with self.Session() as session:
async with self.Session() as session:
db_page = page - 1 # offset logic is 0 based
tasks = (
session.query(TaskModel)
.filter_by(organization_id=organization_id)
.order_by(TaskModel.created_at.desc())
.limit(page_size)
.offset(db_page * page_size)
.all()
)
await session.scalars(
select(TaskModel)
.filter_by(organization_id=organization_id)
.order_by(TaskModel.created_at.desc())
.limit(page_size)
.offset(db_page * page_size)
)
).all()
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
@@ -394,10 +396,10 @@ class AgentDB:
async def get_organization(self, organization_id: str) -> Organization | None:
try:
with self.Session() as session:
async with self.Session() as session:
if organization := (
session.query(OrganizationModel).filter_by(organization_id=organization_id).first()
):
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
).first():
return convert_to_organization(organization)
else:
return None
@@ -414,15 +416,15 @@ class AgentDB:
webhook_callback_url: str | None = None,
max_steps_per_run: int | None = None,
) -> Organization:
with self.Session() as session:
async with self.Session() as session:
org = OrganizationModel(
organization_name=organization_name,
webhook_callback_url=webhook_callback_url,
max_steps_per_run=max_steps_per_run,
)
session.add(org)
session.commit()
session.refresh(org)
await session.commit()
await session.refresh(org)
return convert_to_organization(org)
@@ -432,14 +434,15 @@ class AgentDB:
token_type: OrganizationAuthTokenType,
) -> OrganizationAuthToken | None:
try:
with self.Session() as session:
async with self.Session() as session:
if token := (
session.query(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(valid=True)
.first()
):
await session.scalars(
select(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(valid=True)
)
).first():
return convert_to_organization_auth_token(token)
else:
return None
@@ -457,15 +460,16 @@ class AgentDB:
token: str,
) -> OrganizationAuthToken | None:
try:
with self.Session() as session:
async with self.Session() as session:
if token_obj := (
session.query(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(token=token)
.filter_by(valid=True)
.first()
):
await session.scalars(
select(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(token=token)
.filter_by(valid=True)
)
).first():
return convert_to_organization_auth_token(token_obj)
else:
return None
@@ -482,15 +486,15 @@ class AgentDB:
token_type: OrganizationAuthTokenType,
token: str,
) -> OrganizationAuthToken:
with self.Session() as session:
async with self.Session() as session:
token = OrganizationAuthTokenModel(
organization_id=organization_id,
token_type=token_type,
token=token,
)
session.add(token)
session.commit()
session.refresh(token)
await session.commit()
await session.refresh(token)
return convert_to_organization_auth_token(token)
@@ -501,14 +505,15 @@ class AgentDB:
organization_id: str | None = None,
) -> list[Artifact]:
try:
with self.Session() as session:
async with self.Session() as session:
if artifacts := (
session.query(ArtifactModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.all()
):
await session.scalars(
select(ArtifactModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
)
).all():
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
else:
return []
@@ -525,13 +530,14 @@ class AgentDB:
organization_id: str,
) -> Artifact | None:
try:
with self.Session() as session:
async with self.Session() as session:
if artifact := (
session.query(ArtifactModel)
.filter_by(artifact_id=artifact_id)
.filter_by(organization_id=organization_id)
.first()
):
await session.scalars(
select(ArtifactModel)
.filter_by(artifact_id=artifact_id)
.filter_by(organization_id=organization_id)
)
).first():
return convert_to_artifact(artifact, self.debug_enabled)
else:
return None
@@ -550,16 +556,17 @@ class AgentDB:
organization_id: str | None = None,
) -> Artifact | None:
try:
with self.Session() as session:
async with self.Session() as session:
artifact = (
session.query(ArtifactModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.filter_by(artifact_type=artifact_type)
.order_by(ArtifactModel.created_at.desc())
.first()
)
await session.scalars(
select(ArtifactModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.filter_by(artifact_type=artifact_type)
.order_by(ArtifactModel.created_at.desc())
)
).first()
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
@@ -577,16 +584,17 @@ class AgentDB:
organization_id: str | None = None,
) -> Artifact | None:
try:
with self.Session() as session:
async with self.Session() as session:
artifact = (
session.query(ArtifactModel)
.join(TaskModel, TaskModel.task_id == ArtifactModel.task_id)
.filter(TaskModel.workflow_run_id == workflow_run_id)
.filter(ArtifactModel.artifact_type == artifact_type)
.filter(ArtifactModel.organization_id == organization_id)
.order_by(ArtifactModel.created_at.desc())
.first()
)
await session.scalars(
select(ArtifactModel)
.join(TaskModel, TaskModel.task_id == ArtifactModel.task_id)
.filter(TaskModel.workflow_run_id == workflow_run_id)
.filter(ArtifactModel.artifact_type == artifact_type)
.filter(ArtifactModel.organization_id == organization_id)
.order_by(ArtifactModel.created_at.desc())
)
).first()
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
@@ -605,8 +613,8 @@ class AgentDB:
organization_id: str | None = None,
) -> Artifact | None:
try:
with self.Session() as session:
artifact_query = session.query(ArtifactModel).filter_by(task_id=task_id)
async with self.Session() as session:
artifact_query = select(ArtifactModel).filter_by(task_id=task_id)
if step_id:
artifact_query = artifact_query.filter_by(step_id=step_id)
if organization_id:
@@ -614,7 +622,7 @@ class AgentDB:
if artifact_types:
artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types))
artifact = artifact_query.order_by(ArtifactModel.created_at.desc()).first()
artifact = (await session.scalars(artifact_query.order_by(ArtifactModel.created_at.desc()))).first()
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
@@ -632,15 +640,11 @@ class AgentDB:
before: datetime | None = None,
) -> Task | None:
try:
with self.Session() as session:
query = (
session.query(TaskModel)
.filter_by(organization_id=organization_id)
.filter_by(workflow_id=workflow_id)
)
async with self.Session() as session:
query = select(TaskModel).filter_by(organization_id=organization_id).filter_by(workflow_id=workflow_id)
if before:
query = query.filter(TaskModel.created_at < before)
task = query.order_by(TaskModel.created_at.desc()).first()
task = (await session.scalars(query.order_by(TaskModel.created_at.desc()))).first()
if task:
return convert_to_task(task, debug_enabled=self.debug_enabled)
return None
@@ -655,7 +659,7 @@ class AgentDB:
workflow_definition: dict[str, Any],
description: str | None = None,
) -> Workflow:
with self.Session() as session:
async with self.Session() as session:
workflow = WorkflowModel(
organization_id=organization_id,
title=title,
@@ -663,14 +667,16 @@ class AgentDB:
workflow_definition=workflow_definition,
)
session.add(workflow)
session.commit()
session.refresh(workflow)
await session.commit()
await session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled)
async def get_workflow(self, workflow_id: str) -> Workflow | None:
try:
with self.Session() as session:
if workflow := session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first():
async with self.Session() as session:
if workflow := (
await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id))
).first():
return convert_to_workflow(workflow, self.debug_enabled)
return None
except SQLAlchemyError:
@@ -684,8 +690,8 @@ class AgentDB:
description: str | None = None,
workflow_definition: dict[str, Any] | None = None,
) -> Workflow | None:
with self.Session() as session:
workflow = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
async with self.Session() as session:
workflow = (await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id))).first()
if workflow:
if title:
workflow.title = title
@@ -693,8 +699,8 @@ class AgentDB:
workflow.description = description
if workflow_definition:
workflow.workflow_definition = workflow_definition
session.commit()
session.refresh(workflow)
await session.commit()
await session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled)
LOG.error("Workflow not found, nothing to update", workflow_id=workflow_id)
return None
@@ -703,7 +709,7 @@ class AgentDB:
self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None
) -> WorkflowRun:
try:
with self.Session() as session:
async with self.Session() as session:
workflow_run = WorkflowRunModel(
workflow_id=workflow_id,
proxy_location=proxy_location,
@@ -711,28 +717,32 @@ class AgentDB:
webhook_callback_url=webhook_callback_url,
)
session.add(workflow_run)
session.commit()
session.refresh(workflow_run)
await session.commit()
await session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def update_workflow_run(self, workflow_run_id: str, status: WorkflowRunStatus) -> WorkflowRun | None:
with self.Session() as session:
workflow_run = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
async with self.Session() as session:
workflow_run = (
await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
).first()
if workflow_run:
workflow_run.status = status
session.commit()
session.refresh(workflow_run)
await session.commit()
await session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run)
LOG.error("WorkflowRun not found, nothing to update", workflow_run_id=workflow_run_id)
return None
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None:
try:
with self.Session() as session:
if workflow_run := session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first():
async with self.Session() as session:
if workflow_run := (
await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
).first():
return convert_to_workflow_run(workflow_run)
return None
except SQLAlchemyError:
@@ -741,8 +751,10 @@ class AgentDB:
async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
try:
with self.Session() as session:
workflow_runs = session.query(WorkflowRunModel).filter_by(workflow_id=workflow_id).all()
async with self.Session() as session:
workflow_runs = (
await session.scalars(select(WorkflowRunModel).filter_by(workflow_id=workflow_id))
).all()
return [convert_to_workflow_run(run) for run in workflow_runs]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
@@ -757,7 +769,7 @@ class AgentDB:
description: str | None = None,
) -> WorkflowParameter:
try:
with self.Session() as session:
async with self.Session() as session:
workflow_parameter = WorkflowParameterModel(
workflow_id=workflow_id,
workflow_parameter_type=workflow_parameter_type,
@@ -766,8 +778,8 @@ class AgentDB:
description=description,
)
session.add(workflow_parameter)
session.commit()
session.refresh(workflow_parameter)
await session.commit()
await session.refresh(workflow_parameter)
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
@@ -780,7 +792,7 @@ class AgentDB:
aws_key: str,
description: str | None = None,
) -> AWSSecretParameter:
with self.Session() as session:
async with self.Session() as session:
aws_secret_parameter = AWSSecretParameterModel(
workflow_id=workflow_id,
key=key,
@@ -788,8 +800,8 @@ class AgentDB:
description=description,
)
session.add(aws_secret_parameter)
session.commit()
session.refresh(aws_secret_parameter)
await session.commit()
await session.refresh(aws_secret_parameter)
return convert_to_aws_secret_parameter(aws_secret_parameter)
async def create_output_parameter(
@@ -798,21 +810,23 @@ class AgentDB:
key: str,
description: str | None = None,
) -> OutputParameter:
with self.Session() as session:
async with self.Session() as session:
output_parameter = OutputParameterModel(
key=key,
description=description,
workflow_id=workflow_id,
)
session.add(output_parameter)
session.commit()
session.refresh(output_parameter)
await session.commit()
await session.refresh(output_parameter)
return convert_to_output_parameter(output_parameter)
async def get_workflow_output_parameters(self, workflow_id: str) -> list[OutputParameter]:
try:
with self.Session() as session:
output_parameters = session.query(OutputParameterModel).filter_by(workflow_id=workflow_id).all()
async with self.Session() as session:
output_parameters = (
await session.scalars(select(OutputParameterModel).filter_by(workflow_id=workflow_id))
).all()
return [convert_to_output_parameter(parameter) for parameter in output_parameters]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
@@ -820,13 +834,14 @@ class AgentDB:
async def get_workflow_run_output_parameters(self, workflow_run_id: str) -> list[WorkflowRunOutputParameter]:
try:
with self.Session() as session:
async with self.Session() as session:
workflow_run_output_parameters = (
session.query(WorkflowRunOutputParameterModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(WorkflowRunOutputParameterModel.created_at)
.all()
)
await session.scalars(
select(WorkflowRunOutputParameterModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(WorkflowRunOutputParameterModel.created_at)
)
).all()
return [
convert_to_workflow_run_output_parameter(parameter, self.debug_enabled)
for parameter in workflow_run_output_parameters
@@ -839,15 +854,15 @@ class AgentDB:
self, workflow_run_id: str, output_parameter_id: str, value: dict[str, Any] | list | str | None
) -> WorkflowRunOutputParameter:
try:
with self.Session() as session:
async with self.Session() as session:
workflow_run_output_parameter = WorkflowRunOutputParameterModel(
workflow_run_id=workflow_run_id,
output_parameter_id=output_parameter_id,
value=value,
)
session.add(workflow_run_output_parameter)
session.commit()
session.refresh(workflow_run_output_parameter)
await session.commit()
await session.refresh(workflow_run_output_parameter)
return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
except SQLAlchemyError:
@@ -856,8 +871,10 @@ class AgentDB:
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
try:
with self.Session() as session:
workflow_parameters = session.query(WorkflowParameterModel).filter_by(workflow_id=workflow_id).all()
async with self.Session() as session:
workflow_parameters = (
await session.scalars(select(WorkflowParameterModel).filter_by(workflow_id=workflow_id))
).all()
return [convert_to_workflow_parameter(parameter) for parameter in workflow_parameters]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
@@ -865,10 +882,12 @@ class AgentDB:
async def get_workflow_parameter(self, workflow_parameter_id: str) -> WorkflowParameter | None:
try:
with self.Session() as session:
async with self.Session() as session:
if workflow_parameter := (
session.query(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id).first()
):
await session.scalars(
select(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id)
)
).first():
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
return None
except SQLAlchemyError:
@@ -879,15 +898,15 @@ class AgentDB:
self, workflow_run_id: str, workflow_parameter_id: str, value: Any
) -> WorkflowRunParameter:
try:
with self.Session() as session:
async with self.Session() as session:
workflow_run_parameter = WorkflowRunParameterModel(
workflow_run_id=workflow_run_id,
workflow_parameter_id=workflow_parameter_id,
value=value,
)
session.add(workflow_run_parameter)
session.commit()
session.refresh(workflow_run_parameter)
await session.commit()
await session.refresh(workflow_run_parameter)
workflow_parameter = await self.get_workflow_parameter(workflow_parameter_id)
if not workflow_parameter:
raise WorkflowParameterNotFound(workflow_parameter_id)
@@ -900,10 +919,10 @@ class AgentDB:
self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
try:
with self.Session() as session:
async with self.Session() as session:
workflow_run_parameters = (
session.query(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id).all()
)
await session.scalars(select(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id))
).all()
results = []
for workflow_run_parameter in workflow_run_parameters:
workflow_parameter = await self.get_workflow_parameter(workflow_run_parameter.workflow_parameter_id)
@@ -926,13 +945,14 @@ class AgentDB:
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
try:
with self.Session() as session:
async with self.Session() as session:
if task := (
session.query(TaskModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(TaskModel.created_at.desc())
.first()
):
await session.scalars(
select(TaskModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(TaskModel.created_at.desc())
)
).first():
return convert_to_task(task, debug_enabled=self.debug_enabled)
return None
except SQLAlchemyError:
@@ -941,20 +961,19 @@ class AgentDB:
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
try:
with self.Session() as session:
async with self.Session() as session:
tasks = (
session.query(TaskModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(TaskModel.created_at)
.all()
)
await session.scalars(
select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at)
)
).all()
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
with self.Session() as session:
async with self.Session() as session:
# delete artifacts by filtering organization_id and task_id
stmt = delete(ArtifactModel).where(
and_(
@@ -962,11 +981,11 @@ class AgentDB:
ArtifactModel.task_id == task_id,
)
)
session.execute(stmt)
session.commit()
await session.execute(stmt)
await session.commit()
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
with self.Session() as session:
async with self.Session() as session:
# delete artifacts by filtering organization_id and task_id
stmt = delete(StepModel).where(
and_(
@@ -974,5 +993,5 @@ class AgentDB:
StepModel.task_id == task_id,
)
)
session.execute(stmt)
session.commit()
await session.execute(stmt)
await session.commit()

View File

@@ -1,6 +1,7 @@
import datetime
from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Integer, Numeric, String, UnicodeText
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
@@ -19,7 +20,7 @@ from skyvern.forge.sdk.db.id import (
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
class Base(DeclarativeBase):
class Base(AsyncAttrs, DeclarativeBase):
pass