add actions db model and caching V0 (#980)

This commit is contained in:
Shuchang Zheng
2024-10-15 12:06:50 -07:00
committed by GitHub
parent e7583ac878
commit 9048cdfa73
19 changed files with 731 additions and 90 deletions

View File

@@ -0,0 +1,8 @@
import hashlib
def calculate_sha256(data: str) -> str:
    """Return the hex-encoded SHA-256 digest of a string.

    The text is converted to bytes with ``str.encode()`` (default codec,
    UTF-8) before it is hashed.
    """
    return hashlib.sha256(data.encode()).hexdigest()

View File

@@ -113,7 +113,7 @@ def rename_file(file_path: str, new_file_name: str) -> str:
return file_path
def calculate_sha256(file_path: str) -> str:
def calculate_sha256_for_file(file_path: str) -> str:
"""Helper function to calculate SHA256 hash of a file."""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:

View File

@@ -13,6 +13,7 @@ from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
ActionModel,
ArtifactModel,
AWSSecretParameterModel,
BitwardenCreditCardDataParameterModel,
@@ -68,6 +69,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunParameter,
WorkflowRunStatus,
)
from skyvern.webeye.actions.actions import Action
from skyvern.webeye.actions.models import AgentStepOutput
LOG = structlog.get_logger()
@@ -1571,3 +1573,59 @@ class AgentDB:
)
totp_code = (await session.scalars(query)).all()
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
async def create_action(self, action: Action) -> Action:
    """Insert ``action`` as a new row in the actions table.

    Every column listed in ``mirrored_fields`` is copied from the attribute
    of the same name on ``action``; ``action_json`` stores the full
    ``action.model_dump()`` output. No ``action_id`` is passed to the model.

    Returns:
        An ``Action`` validated from the row after it has been committed
        and refreshed.
    """
    # Columns whose value comes straight from the same-named Action attribute.
    mirrored_fields = (
        "action_type",
        "source_action_id",
        "organization_id",
        "workflow_run_id",
        "task_id",
        "step_id",
        "step_order",
        "action_order",
        "status",
        "reasoning",
        "intention",
        "response",
        "element_id",
        "skyvern_element_hash",
        "skyvern_element_data",
    )
    async with self.Session() as session:
        column_values = {name: getattr(action, name) for name in mirrored_fields}
        row = ActionModel(**column_values, action_json=action.model_dump())
        session.add(row)
        await session.commit()
        # Reload the row so values filled in at insert time are present.
        await session.refresh(row)
        return Action.model_validate(row)
async def retrieve_action_plan(self, task: Task) -> list[Action]:
    """Return the stored actions of the newest completed task matching ``task``.

    A task matches when its ``url`` and ``navigation_goal`` equal those of
    ``task`` and its status is ``TaskStatus.completed``. Only the most
    recently created match is used. Its actions are returned ordered by
    ``step_order``, then ``action_order``, then ``created_at``. The list is
    empty when no task matches.
    """
    # NOTE(review): the task lookup is not restricted by organization_id —
    # confirm that reusing a plan recorded under another organization is intended.
    async with self.Session() as session:
        latest_matching_task = (
            select(TaskModel.task_id)
            .filter(
                TaskModel.url == task.url,
                TaskModel.navigation_goal == task.navigation_goal,
                TaskModel.status == TaskStatus.completed,
            )
            .order_by(TaskModel.created_at.desc())
            .limit(1)
            .subquery()
        )
        ordering = (ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
        stmt = (
            select(ActionModel)
            .filter(ActionModel.task_id == latest_matching_task.c.task_id)
            .order_by(*ordering)
        )
        rows = (await session.scalars(stmt)).all()
        return [Action.model_validate(row) for row in rows]
async def get_previous_actions_for_task(self, task_id: str) -> list[Action]:
    """Return every stored action whose ``task_id`` equals ``task_id``.

    Results are ordered by ``step_order``, then ``action_order``, then
    ``created_at``.
    """
    async with self.Session() as session:
        stmt = select(ActionModel).filter_by(task_id=task_id)
        stmt = stmt.order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
        result = await session.scalars(stmt)
        return [Action.model_validate(row) for row in result.all()]

View File

@@ -130,6 +130,11 @@ def generate_totp_code_id() -> str:
return f"totp_{int_id}"
def generate_action_id() -> str:
    """Return a new action ID: the ``a_`` prefix followed by a value from ``generate_id()``."""
    return f"a_{generate_id()}"
def generate_id() -> int:
"""
generate a 64-bit int ID

View File

@@ -19,6 +19,7 @@ from sqlalchemy.orm import DeclarativeBase
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.id import (
generate_action_id,
generate_artifact_id,
generate_aws_secret_parameter_id,
generate_bitwarden_credit_card_data_parameter_id,
@@ -437,3 +438,29 @@ class TOTPCodeModel(Base):
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
expired_at = Column(DateTime, index=True)
class ActionModel(Base):
# ORM model for the "actions" table; each row stores one action together with
# the task/step it belongs to and its position within them.
__tablename__ = "actions"
# Composite index over (organization_id, task_id, step_id).
__table_args__ = (Index("action_org_task_step_index", "organization_id", "task_id", "step_id"),)
# Primary key; generate_action_id supplies the value when none is given.
action_id = Column(String, primary_key=True, index=True, default=generate_action_id)
action_type = Column(String, nullable=False)
# Optional self-referential link to another row of this table.
source_action_id = Column(String, ForeignKey("actions.action_id"), nullable=True, index=True)
# Ownership / context references: organization and workflow run are optional,
# task and step are required.
organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=True)
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), nullable=True)
task_id = Column(String, ForeignKey("tasks.task_id"), nullable=False, index=True)
step_id = Column(String, ForeignKey("steps.step_id"), nullable=False)
# Required integer positions of the step and of the action.
step_order = Column(Integer, nullable=False)
action_order = Column(Integer, nullable=False)
status = Column(String, nullable=False)
# Optional free-text fields.
reasoning = Column(String, nullable=True)
intention = Column(String, nullable=True)
response = Column(String, nullable=True)
# Optional element reference plus its hash and JSON payload.
element_id = Column(String, nullable=True)
skyvern_element_hash = Column(String, nullable=True)
skyvern_element_data = Column(JSON, nullable=True)
# Optional JSON payload for the action.
action_json = Column(JSON, nullable=True)
# Timestamps default to datetime.utcnow; modified_at is also refreshed on update.
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)

View File

@@ -32,7 +32,7 @@ from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.api.files import (
calculate_sha256,
calculate_sha256_for_file,
download_file,
download_from_s3,
get_path_for_workflow_download_directory,
@@ -181,6 +181,7 @@ class TaskBlock(Block):
download_suffix: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
cache_actions: bool = False
def get_all_parameters(
self,
@@ -1057,7 +1058,7 @@ class SendEmailBlock(Block):
subtype=subtype,
filename=attachment_filename,
)
file_hash = calculate_sha256(path)
file_hash = calculate_sha256_for_file(path)
file_names_by_hash[file_hash].append(path)
finally:
if path:

View File

@@ -129,6 +129,7 @@ class TaskBlockYAML(BlockYAML):
download_suffix: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
cache_actions: bool = False
class ForLoopBlockYAML(BlockYAML):

View File

@@ -985,7 +985,8 @@ class WorkflowService:
bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
bitwarden_collection_id=parameter.bitwarden_collection_id,
# TODO: remove "# type: ignore" after ensuring bitwarden_collection_id is always set
bitwarden_collection_id=parameter.bitwarden_collection_id, # type: ignore
bitwarden_item_id=parameter.bitwarden_item_id,
key=parameter.key,
description=parameter.description,
@@ -1128,6 +1129,7 @@ class WorkflowService:
continue_on_failure=block_yaml.continue_on_failure,
totp_verification_url=block_yaml.totp_verification_url,
totp_identifier=block_yaml.totp_identifier,
cache_actions=block_yaml.cache_actions,
)
elif block_yaml.block_type == BlockType.FOR_LOOP:
loop_blocks = [