add actions db model and caching V0 (#980)

This commit is contained in:
Shuchang Zheng
2024-10-15 12:06:50 -07:00
committed by GitHub
parent e7583ac878
commit 9048cdfa73
19 changed files with 731 additions and 90 deletions

View File

@@ -13,6 +13,7 @@ from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
ActionModel,
ArtifactModel,
AWSSecretParameterModel,
BitwardenCreditCardDataParameterModel,
@@ -68,6 +69,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunParameter,
WorkflowRunStatus,
)
from skyvern.webeye.actions.actions import Action
from skyvern.webeye.actions.models import AgentStepOutput
LOG = structlog.get_logger()
@@ -1571,3 +1573,59 @@ class AgentDB:
)
totp_code = (await session.scalars(query)).all()
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
async def create_action(self, action: Action) -> Action:
async with self.Session() as session:
new_action = ActionModel(
action_type=action.action_type,
source_action_id=action.source_action_id,
organization_id=action.organization_id,
workflow_run_id=action.workflow_run_id,
task_id=action.task_id,
step_id=action.step_id,
step_order=action.step_order,
action_order=action.action_order,
status=action.status,
reasoning=action.reasoning,
intention=action.intention,
response=action.response,
element_id=action.element_id,
skyvern_element_hash=action.skyvern_element_hash,
skyvern_element_data=action.skyvern_element_data,
action_json=action.model_dump(),
)
session.add(new_action)
await session.commit()
await session.refresh(new_action)
return Action.model_validate(new_action)
async def retrieve_action_plan(self, task: Task) -> list[Action]:
async with self.Session() as session:
subquery = (
select(TaskModel.task_id)
.filter(TaskModel.url == task.url)
.filter(TaskModel.navigation_goal == task.navigation_goal)
.filter(TaskModel.status == TaskStatus.completed)
.order_by(TaskModel.created_at.desc())
.limit(1)
.subquery()
)
query = (
select(ActionModel)
.filter(ActionModel.task_id == subquery.c.task_id)
.order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
)
actions = (await session.scalars(query)).all()
return [Action.model_validate(action) for action in actions]
async def get_previous_actions_for_task(self, task_id: str) -> list[Action]:
async with self.Session() as session:
query = (
select(ActionModel)
.filter_by(task_id=task_id)
.order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
)
actions = (await session.scalars(query)).all()
return [Action.model_validate(action) for action in actions]