Move the code over from private repository (#3)

This commit is contained in:
Kerem Yilmaz
2024-03-01 10:09:30 -08:00
committed by GitHub
parent 32dd6d92a5
commit 9eddb3d812
93 changed files with 16798 additions and 0 deletions

View File

View File

@@ -0,0 +1,900 @@
from datetime import datetime
from typing import Any
import structlog
from sqlalchemy import and_, create_engine, delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from skyvern.exceptions import WorkflowParameterNotFound
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
ArtifactModel,
AWSSecretParameterModel,
OrganizationAuthTokenModel,
OrganizationModel,
StepModel,
TaskModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
WorkflowRunParameterModel,
)
from skyvern.forge.sdk.db.utils import (
_custom_json_serializer,
convert_to_artifact,
convert_to_aws_secret_parameter,
convert_to_organization,
convert_to_organization_auth_token,
convert_to_step,
convert_to_task,
convert_to_workflow,
convert_to_workflow_parameter,
convert_to_workflow_run,
convert_to_workflow_run_parameter,
)
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunParameter, WorkflowRunStatus
from skyvern.webeye.actions.models import AgentStepOutput
LOG = structlog.get_logger()
class AgentDB:
def __init__(self, database_string: str, debug_enabled: bool = False) -> None:
super().__init__()
self.debug_enabled = debug_enabled
self.engine = create_engine(database_string, json_serializer=_custom_json_serializer)
self.Session = sessionmaker(bind=self.engine)
async def create_task(
self,
url: str,
navigation_goal: str | None,
data_extraction_goal: str | None,
navigation_payload: dict[str, Any] | list | str | None,
webhook_callback_url: str | None = None,
organization_id: str | None = None,
proxy_location: ProxyLocation | None = None,
extracted_information_schema: dict[str, Any] | list | str | None = None,
workflow_run_id: str | None = None,
order: int | None = None,
retry: int | None = None,
) -> Task:
try:
with self.Session() as session:
new_task = TaskModel(
status="created",
url=url,
webhook_callback_url=webhook_callback_url,
navigation_goal=navigation_goal,
data_extraction_goal=data_extraction_goal,
navigation_payload=navigation_payload,
organization_id=organization_id,
proxy_location=proxy_location,
extracted_information_schema=extracted_information_schema,
workflow_run_id=workflow_run_id,
order=order,
retry=retry,
)
session.add(new_task)
session.commit()
session.refresh(new_task)
return convert_to_task(new_task, self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def create_step(
self,
task_id: str,
order: int,
retry_index: int,
organization_id: str | None = None,
) -> Step:
try:
with self.Session() as session:
new_step = StepModel(
task_id=task_id,
order=order,
retry_index=retry_index,
status="created",
organization_id=organization_id,
)
session.add(new_step)
session.commit()
session.refresh(new_step)
return convert_to_step(new_step, debug_enabled=self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def create_artifact(
self,
artifact_id: str,
step_id: str,
task_id: str,
artifact_type: str,
uri: str,
organization_id: str | None = None,
) -> Artifact:
try:
with self.Session() as session:
new_artifact = ArtifactModel(
artifact_id=artifact_id,
task_id=task_id,
step_id=step_id,
artifact_type=artifact_type,
uri=uri,
organization_id=organization_id,
)
session.add(new_artifact)
session.commit()
session.refresh(new_artifact)
return convert_to_artifact(new_artifact, self.debug_enabled)
except SQLAlchemyError:
LOG.exception("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.exception("UnexpectedError", exc_info=True)
raise
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
"""Get a task by its id"""
try:
with self.Session() as session:
if task_obj := (
session.query(TaskModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.first()
):
return convert_to_task(task_obj, self.debug_enabled)
else:
LOG.info("Task not found", task_id=task_id, organization_id=organization_id)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None:
try:
with self.Session() as session:
if step := (
session.query(StepModel)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.first()
):
return convert_to_step(step, debug_enabled=self.debug_enabled)
else:
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_task_steps(self, task_id: str, organization_id: str | None = None) -> list[Step]:
try:
with self.Session() as session:
if (
steps := session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order)
.order_by(StepModel.retry_index)
.all()
):
return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
else:
return []
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> list[StepModel]:
try:
with self.Session() as session:
return (
session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order)
.order_by(StepModel.retry_index)
.all()
)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
try:
with self.Session() as session:
if step := (
session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order.desc())
.first()
):
return convert_to_step(step, debug_enabled=self.debug_enabled)
else:
LOG.info("Latest step not found", task_id=task_id, organization_id=organization_id)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def update_step(
self,
task_id: str,
step_id: str,
status: StepStatus | None = None,
output: AgentStepOutput | None = None,
is_last: bool | None = None,
retry_index: int | None = None,
organization_id: str | None = None,
chat_completion_price: ChatCompletionPrice | None = None,
) -> Step:
try:
with self.Session() as session:
if (
step := session.query(StepModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.first()
):
if status is not None:
step.status = status
if output is not None:
step.output = output.model_dump()
if is_last is not None:
step.is_last = is_last
if retry_index is not None:
step.retry_index = retry_index
if chat_completion_price is not None:
if step.input_token_count is None:
step.input_token_count = 0
if step.output_token_count is None:
step.output_token_count = 0
step.input_token_count += chat_completion_price.input_token_count
step.output_token_count += chat_completion_price.output_token_count
step.step_cost = chat_completion_price.openai_model_to_price_lambda(
step.input_token_count, step.output_token_count
)
session.commit()
updated_step = await self.get_step(task_id, step_id, organization_id)
if not updated_step:
raise NotFoundError("Step not found")
return updated_step
else:
raise NotFoundError("Step not found")
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except NotFoundError:
LOG.error("NotFoundError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def update_task(
self,
task_id: str,
status: TaskStatus,
extracted_information: dict[str, Any] | list | str | None = None,
failure_reason: str | None = None,
organization_id: str | None = None,
) -> Task:
try:
with self.Session() as session:
if (
task := session.query(TaskModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.first()
):
task.status = status
if extracted_information is not None:
task.extracted_information = extracted_information
if failure_reason is not None:
task.failure_reason = failure_reason
session.commit()
updated_task = await self.get_task(task_id, organization_id=organization_id)
if not updated_task:
raise NotFoundError("Task not found")
return updated_task
else:
raise NotFoundError("Task not found")
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except NotFoundError:
LOG.error("NotFoundError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_tasks(self, page: int = 1, page_size: int = 10, organization_id: str | None = None) -> list[Task]:
"""
Get all tasks.
:param page: Starts at 1
:param page_size:
:return:
"""
if page < 1:
raise ValueError(f"Page must be greater than 0, got {page}")
try:
with self.Session() as session:
db_page = page - 1 # offset logic is 0 based
tasks = (
session.query(TaskModel)
.filter_by(organization_id=organization_id)
.order_by(TaskModel.created_at.desc())
.limit(page_size)
.offset(db_page * page_size)
.all()
)
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_organization(self, organization_id: str) -> Organization | None:
try:
with self.Session() as session:
if organization := (
session.query(OrganizationModel).filter_by(organization_id=organization_id).first()
):
return convert_to_organization(organization)
else:
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def create_organization(
self,
organization_name: str,
webhook_callback_url: str | None = None,
max_steps_per_run: int | None = None,
) -> Organization:
with self.Session() as session:
org = OrganizationModel(
organization_name=organization_name,
webhook_callback_url=webhook_callback_url,
max_steps_per_run=max_steps_per_run,
)
session.add(org)
session.commit()
session.refresh(org)
return convert_to_organization(org)
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
) -> OrganizationAuthToken | None:
try:
with self.Session() as session:
if token := (
session.query(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(valid=True)
.first()
):
return convert_to_organization_auth_token(token)
else:
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def validate_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
token: str,
) -> OrganizationAuthToken | None:
try:
with self.Session() as session:
if token_obj := (
session.query(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(token=token)
.filter_by(valid=True)
.first()
):
return convert_to_organization_auth_token(token_obj)
else:
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def create_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
token: str,
) -> OrganizationAuthToken:
with self.Session() as session:
token = OrganizationAuthTokenModel(
organization_id=organization_id,
token_type=token_type,
token=token,
)
session.add(token)
session.commit()
session.refresh(token)
return convert_to_organization_auth_token(token)
async def get_artifacts_for_task_step(
self,
task_id: str,
step_id: str,
organization_id: str | None = None,
) -> list[Artifact]:
try:
with self.Session() as session:
if artifacts := (
session.query(ArtifactModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.all()
):
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
else:
return []
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_artifact_by_id(
self,
artifact_id: str,
organization_id: str,
) -> Artifact | None:
try:
with self.Session() as session:
if artifact := (
session.query(ArtifactModel)
.filter_by(artifact_id=artifact_id)
.filter_by(organization_id=organization_id)
.first()
):
return convert_to_artifact(artifact, self.debug_enabled)
else:
return None
except SQLAlchemyError:
LOG.exception("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.exception("UnexpectedError", exc_info=True)
raise
async def get_artifact(
self,
task_id: str,
step_id: str,
artifact_type: ArtifactType,
organization_id: str | None = None,
) -> Artifact | None:
try:
with self.Session() as session:
artifact = (
session.query(ArtifactModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.filter_by(artifact_type=artifact_type)
.order_by(ArtifactModel.created_at.desc())
.first()
)
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_artifact_for_workflow_run(
self,
workflow_run_id: str,
artifact_type: ArtifactType,
organization_id: str | None = None,
) -> Artifact | None:
try:
with self.Session() as session:
artifact = (
session.query(ArtifactModel)
.join(TaskModel, TaskModel.task_id == ArtifactModel.task_id)
.filter(TaskModel.workflow_run_id == workflow_run_id)
.filter(ArtifactModel.artifact_type == artifact_type)
.filter(ArtifactModel.organization_id == organization_id)
.order_by(ArtifactModel.created_at.desc())
.first()
)
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_latest_artifact(
self,
task_id: str,
step_id: str | None = None,
artifact_types: list[ArtifactType] | None = None,
organization_id: str | None = None,
) -> Artifact | None:
try:
with self.Session() as session:
artifact_query = session.query(ArtifactModel).filter_by(task_id=task_id)
if step_id:
artifact_query = artifact_query.filter_by(step_id=step_id)
if organization_id:
artifact_query = artifact_query.filter_by(organization_id=organization_id)
if artifact_types:
artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types))
artifact = artifact_query.order_by(ArtifactModel.created_at.desc()).first()
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.exception("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.exception("UnexpectedError", exc_info=True)
raise
async def get_latest_task_by_workflow_id(
self,
organization_id: str,
workflow_id: str,
before: datetime | None = None,
) -> Task | None:
try:
with self.Session() as session:
query = (
session.query(TaskModel)
.filter_by(organization_id=organization_id)
.filter_by(workflow_id=workflow_id)
)
if before:
query = query.filter(TaskModel.created_at < before)
task = query.order_by(TaskModel.created_at.desc()).first()
if task:
return convert_to_task(task, debug_enabled=self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def create_workflow(
self,
organization_id: str,
title: str,
workflow_definition: dict[str, Any],
description: str | None = None,
) -> Workflow:
with self.Session() as session:
workflow = WorkflowModel(
organization_id=organization_id,
title=title,
description=description,
workflow_definition=workflow_definition,
)
session.add(workflow)
session.commit()
session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled)
async def get_workflow(self, workflow_id: str) -> Workflow | None:
try:
with self.Session() as session:
if workflow := session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first():
return convert_to_workflow(workflow, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def update_workflow(
self,
workflow_id: str,
title: str | None = None,
description: str | None = None,
workflow_definition: dict[str, Any] | None = None,
) -> Workflow | None:
with self.Session() as session:
workflow = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
if workflow:
if title:
workflow.title = title
if description:
workflow.description = description
if workflow_definition:
workflow.workflow_definition = workflow_definition
session.commit()
session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled)
LOG.error("Workflow not found, nothing to update", workflow_id=workflow_id)
return None
async def create_workflow_run(
self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None
) -> WorkflowRun:
try:
with self.Session() as session:
workflow_run = WorkflowRunModel(
workflow_id=workflow_id,
proxy_location=proxy_location,
status="created",
webhook_callback_url=webhook_callback_url,
)
session.add(workflow_run)
session.commit()
session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def update_workflow_run(self, workflow_run_id: str, status: WorkflowRunStatus) -> WorkflowRun | None:
with self.Session() as session:
workflow_run = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
if workflow_run:
workflow_run.status = status
session.commit()
session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run)
LOG.error("WorkflowRun not found, nothing to update", workflow_run_id=workflow_run_id)
return None
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None:
try:
with self.Session() as session:
if workflow_run := session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first():
return convert_to_workflow_run(workflow_run)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
try:
with self.Session() as session:
workflow_runs = session.query(WorkflowRunModel).filter_by(workflow_id=workflow_id).all()
return [convert_to_workflow_run(run) for run in workflow_runs]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def create_workflow_parameter(
self,
workflow_id: str,
workflow_parameter_type: WorkflowParameterType,
key: str,
default_value: Any,
description: str | None = None,
) -> WorkflowParameter:
try:
with self.Session() as session:
workflow_parameter = WorkflowParameterModel(
workflow_id=workflow_id,
workflow_parameter_type=workflow_parameter_type,
key=key,
default_value=default_value,
description=description,
)
session.add(workflow_parameter)
session.commit()
session.refresh(workflow_parameter)
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def create_aws_secret_parameter(
self,
workflow_id: str,
key: str,
aws_key: str,
description: str | None = None,
) -> AWSSecretParameter:
with self.Session() as session:
aws_secret_parameter = AWSSecretParameterModel(
workflow_id=workflow_id,
key=key,
aws_key=aws_key,
description=description,
)
session.add(aws_secret_parameter)
session.commit()
session.refresh(aws_secret_parameter)
return convert_to_aws_secret_parameter(aws_secret_parameter)
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
try:
with self.Session() as session:
workflow_parameters = session.query(WorkflowParameterModel).filter_by(workflow_id=workflow_id).all()
return [convert_to_workflow_parameter(parameter) for parameter in workflow_parameters]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_workflow_parameter(self, workflow_parameter_id: str) -> WorkflowParameter | None:
try:
with self.Session() as session:
if workflow_parameter := (
session.query(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id).first()
):
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def create_workflow_run_parameter(
self, workflow_run_id: str, workflow_parameter_id: str, value: Any
) -> WorkflowRunParameter:
try:
with self.Session() as session:
workflow_run_parameter = WorkflowRunParameterModel(
workflow_run_id=workflow_run_id,
workflow_parameter_id=workflow_parameter_id,
value=value,
)
session.add(workflow_run_parameter)
session.commit()
session.refresh(workflow_run_parameter)
workflow_parameter = await self.get_workflow_parameter(workflow_parameter_id)
if not workflow_parameter:
raise WorkflowParameterNotFound(workflow_parameter_id)
return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_workflow_run_parameters(
self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
try:
with self.Session() as session:
workflow_run_parameters = (
session.query(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id).all()
)
results = []
for workflow_run_parameter in workflow_run_parameters:
workflow_parameter = await self.get_workflow_parameter(workflow_run_parameter.workflow_parameter_id)
if not workflow_parameter:
raise WorkflowParameterNotFound(
workflow_parameter_id=workflow_run_parameter.workflow_parameter_id
)
results.append(
(
workflow_parameter,
convert_to_workflow_run_parameter(
workflow_run_parameter, workflow_parameter, self.debug_enabled
),
)
)
return results
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
try:
with self.Session() as session:
if task := (
session.query(TaskModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(TaskModel.created_at.desc())
.first()
):
return convert_to_task(task, debug_enabled=self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
try:
with self.Session() as session:
tasks = (
session.query(TaskModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(TaskModel.created_at)
.all()
)
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
with self.Session() as session:
# delete artifacts by filtering organization_id and task_id
stmt = delete(ArtifactModel).where(
and_(
ArtifactModel.organization_id == organization_id,
ArtifactModel.task_id == task_id,
)
)
session.execute(stmt)
session.commit()
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
with self.Session() as session:
# delete artifacts by filtering organization_id and task_id
stmt = delete(StepModel).where(
and_(
StepModel.organization_id == organization_id,
StepModel.task_id == task_id,
)
)
session.execute(stmt)
session.commit()

View File

@@ -0,0 +1,15 @@
from enum import StrEnum
class OrganizationAuthTokenType(StrEnum):
api = "api"
class ScheduleRuleUnit(StrEnum):
# No support for scheduling every second
minute = "minute"
hour = "hour"
day = "day"
week = "week"
month = "month"
year = "year"

View File

@@ -0,0 +1,2 @@
class NotFoundError(Exception):
pass

136
skyvern/forge/sdk/db/id.py Normal file
View File

@@ -0,0 +1,136 @@
import hashlib
import itertools
import os
import platform
import random
import time
# 6/20/2022 12AM
BASE_EPOCH = 1655683200
VERSION = 0
# Number of bits
TIMESTAMP_BITS = 32
WORKER_ID_BITS = 21
SEQUENCE_BITS = 10
VERSION_BITS = 1
# Bit shits (left)
TIMESTAMP_SHIFT = 32
WORKER_ID_SHIFT = 11
SEQUENCE_SHIFT = 1
VERSION_SHIFT = 0
SEQUENCE_MAX = (2**SEQUENCE_BITS) - 1
_sequence_start = None
SEQUENCE_COUNTER = itertools.count()
_worker_hash = None
# prefix
ORGANIZATION_AUTH_TOKEN_PREFIX = "oat"
ORG_PREFIX = "o"
TASK_PREFIX = "tsk"
USER_PREFIX = "u"
STEP_PREFIX = "stp"
ARTIFACT_PREFIX = "a"
WORKFLOW_PREFIX = "w"
WORKFLOW_RUN_PREFIX = "wr"
WORKFLOW_PARAMETER_PREFIX = "wp"
AWS_SECRET_PARAMETER_PREFIX = "asp"
def generate_workflow_id() -> str:
int_id = generate_id()
return f"{WORKFLOW_PREFIX}_{int_id}"
def generate_workflow_run_id() -> str:
int_id = generate_id()
return f"{WORKFLOW_RUN_PREFIX}_{int_id}"
def generate_aws_secret_parameter_id() -> str:
int_id = generate_id()
return f"{AWS_SECRET_PARAMETER_PREFIX}_{int_id}"
def generate_workflow_parameter_id() -> str:
int_id = generate_id()
return f"{WORKFLOW_PARAMETER_PREFIX}_{int_id}"
def generate_organization_auth_token_id() -> str:
int_id = generate_id()
return f"{ORGANIZATION_AUTH_TOKEN_PREFIX}_{int_id}"
def generate_org_id() -> str:
int_id = generate_id()
return f"{ORG_PREFIX}_{int_id}"
def generate_task_id() -> str:
int_id = generate_id()
return f"{TASK_PREFIX}_{int_id}"
def generate_step_id() -> str:
int_id = generate_id()
return f"{STEP_PREFIX}_{int_id}"
def generate_artifact_id() -> str:
int_id = generate_id()
return f"{ARTIFACT_PREFIX}_{int_id}"
def generate_user_id() -> str:
int_id = generate_id()
return f"{USER_PREFIX}_{int_id}"
def generate_id() -> int:
"""
generate a 64-bit int ID
"""
create_at = current_time() - BASE_EPOCH
sequence = _increment_and_get_sequence()
time_part = _mask_shift(create_at, TIMESTAMP_BITS, TIMESTAMP_SHIFT)
worker_part = _mask_shift(_get_worker_hash(), WORKER_ID_BITS, WORKER_ID_SHIFT)
sequence_part = _mask_shift(sequence, SEQUENCE_BITS, SEQUENCE_SHIFT)
version_part = _mask_shift(VERSION, VERSION_BITS, VERSION_SHIFT)
return time_part | worker_part | sequence_part | version_part
def _increment_and_get_sequence() -> int:
global _sequence_start
if _sequence_start is None:
_sequence_start = random.randint(0, SEQUENCE_MAX)
return (_sequence_start + next(SEQUENCE_COUNTER)) % SEQUENCE_MAX
def current_time() -> int:
return int(time.time())
def current_time_ms() -> int:
return int(time.time() * 1000)
def _mask_shift(value: int, mask_bits: int, shift_bits: int) -> int:
return (value & ((2**mask_bits) - 1)) << shift_bits
def _get_worker_hash() -> int:
global _worker_hash
if _worker_hash is None:
_worker_hash = _generate_worker_hash()
return _worker_hash
def _generate_worker_hash() -> int:
worker_identity = f"{platform.node()}:{os.getpid()}"
return int(hashlib.md5(worker_identity.encode()).hexdigest()[-15:], 16)

View File

@@ -0,0 +1,172 @@
import datetime
from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Integer, Numeric, String, UnicodeText
from sqlalchemy.orm import DeclarativeBase
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.id import (
generate_artifact_id,
generate_aws_secret_parameter_id,
generate_org_id,
generate_organization_auth_token_id,
generate_step_id,
generate_task_id,
generate_workflow_id,
generate_workflow_parameter_id,
generate_workflow_run_id,
)
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
class Base(DeclarativeBase):
pass
class TaskModel(Base):
__tablename__ = "tasks"
task_id = Column(String, primary_key=True, index=True, default=generate_task_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"))
status = Column(String)
webhook_callback_url = Column(String)
url = Column(String)
navigation_goal = Column(String)
data_extraction_goal = Column(String)
navigation_payload = Column(JSON)
extracted_information = Column(JSON)
failure_reason = Column(String)
proxy_location = Column(Enum(ProxyLocation))
extracted_information_schema = Column(JSON)
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"))
order = Column(Integer, nullable=True)
retry = Column(Integer, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class StepModel(Base):
__tablename__ = "steps"
step_id = Column(String, primary_key=True, index=True, default=generate_step_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"))
task_id = Column(String, ForeignKey("tasks.task_id"))
status = Column(String)
output = Column(JSON)
order = Column(Integer)
is_last = Column(Boolean, default=False)
retry_index = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
input_token_count = Column(Integer, default=0)
output_token_count = Column(Integer, default=0)
step_cost = Column(Numeric, default=0)
class OrganizationModel(Base):
__tablename__ = "organizations"
organization_id = Column(String, primary_key=True, index=True, default=generate_org_id)
organization_name = Column(String, nullable=False)
webhook_callback_url = Column(UnicodeText)
max_steps_per_run = Column(Integer)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime, nullable=False)
class OrganizationAuthTokenModel(Base):
__tablename__ = "organization_auth_tokens"
id = Column(
String,
primary_key=True,
index=True,
default=generate_organization_auth_token_id,
)
organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
token_type = Column(Enum(OrganizationAuthTokenType), nullable=False)
token = Column(String, index=True, nullable=False)
valid = Column(Boolean, nullable=False, default=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime, nullable=False)
deleted_at = Column(DateTime, nullable=True)
class ArtifactModel(Base):
__tablename__ = "artifacts"
artifact_id = Column(String, primary_key=True, index=True, default=generate_artifact_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"))
task_id = Column(String, ForeignKey("tasks.task_id"))
step_id = Column(String, ForeignKey("steps.step_id"))
artifact_type = Column(String)
uri = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class WorkflowModel(Base):
__tablename__ = "workflows"
workflow_id = Column(String, primary_key=True, index=True, default=generate_workflow_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"))
title = Column(String, nullable=False)
description = Column(String, nullable=True)
workflow_definition = Column(JSON, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
deleted_at = Column(DateTime, nullable=True)
class WorkflowRunModel(Base):
__tablename__ = "workflow_runs"
workflow_run_id = Column(String, primary_key=True, index=True, default=generate_workflow_run_id)
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=False)
status = Column(String, nullable=False)
proxy_location = Column(Enum(ProxyLocation))
webhook_callback_url = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class WorkflowParameterModel(Base):
__tablename__ = "workflow_parameters"
workflow_parameter_id = Column(String, primary_key=True, index=True, default=generate_workflow_parameter_id)
workflow_parameter_type = Column(String, nullable=False)
key = Column(String, nullable=False)
description = Column(String, nullable=True)
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
default_value = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
deleted_at = Column(DateTime, nullable=True)
class AWSSecretParameterModel(Base):
__tablename__ = "aws_secret_parameters"
aws_secret_parameter_id = Column(String, primary_key=True, index=True, default=generate_aws_secret_parameter_id)
workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
key = Column(String, nullable=False)
description = Column(String, nullable=True)
aws_key = Column(String, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
deleted_at = Column(DateTime, nullable=True)
class WorkflowRunParameterModel(Base):
__tablename__ = "workflow_run_parameters"
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True)
workflow_parameter_id = Column(
String, ForeignKey("workflow_parameters.workflow_parameter_id"), primary_key=True, index=True
)
# Can be bool | int | float | str | dict | list depending on the workflow parameter type
value = Column(String, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)

View File

@@ -0,0 +1,220 @@
import json
import typing
import pydantic.json
import structlog
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.models import (
ArtifactModel,
AWSSecretParameterModel,
OrganizationAuthTokenModel,
OrganizationModel,
StepModel,
TaskModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
WorkflowRunParameterModel,
)
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
WorkflowDefinition,
WorkflowRun,
WorkflowRunParameter,
WorkflowRunStatus,
)
LOG = structlog.get_logger()
@typing.no_type_check
def _custom_json_serializer(*args, **kwargs) -> str:
"""
Encodes json in the same way that pydantic does.
"""
return json.dumps(*args, default=pydantic.json.pydantic_encoder, **kwargs)
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
if debug_enabled:
LOG.debug("Converting TaskModel to Task", task_id=task_obj.task_id)
task = Task(
task_id=task_obj.task_id,
status=TaskStatus(task_obj.status),
created_at=task_obj.created_at,
modified_at=task_obj.modified_at,
url=task_obj.url,
webhook_callback_url=task_obj.webhook_callback_url,
navigation_goal=task_obj.navigation_goal,
data_extraction_goal=task_obj.data_extraction_goal,
navigation_payload=task_obj.navigation_payload,
extracted_information=task_obj.extracted_information,
failure_reason=task_obj.failure_reason,
organization_id=task_obj.organization_id,
proxy_location=ProxyLocation(task_obj.proxy_location) if task_obj.proxy_location else None,
extracted_information_schema=task_obj.extracted_information_schema,
workflow_run_id=task_obj.workflow_run_id,
order=task_obj.order,
retry=task_obj.retry,
)
return task
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
if debug_enabled:
LOG.debug("Converting StepModel to Step", step_id=step_model.step_id)
return Step(
task_id=step_model.task_id,
step_id=step_model.step_id,
created_at=step_model.created_at,
modified_at=step_model.modified_at,
status=StepStatus(step_model.status),
output=step_model.output,
order=step_model.order,
is_last=step_model.is_last,
retry_index=step_model.retry_index,
organization_id=step_model.organization_id,
input_token_count=step_model.input_token_count,
output_token_count=step_model.output_token_count,
step_cost=step_model.step_cost,
)
def convert_to_organization(org_model: OrganizationModel) -> Organization:
return Organization(
organization_id=org_model.organization_id,
organization_name=org_model.organization_name,
webhook_callback_url=org_model.webhook_callback_url,
max_steps_per_run=org_model.max_steps_per_run,
created_at=org_model.created_at,
modified_at=org_model.modified_at,
)
def convert_to_organization_auth_token(org_auth_token: OrganizationAuthTokenModel) -> OrganizationAuthToken:
return OrganizationAuthToken(
id=org_auth_token.id,
organization_id=org_auth_token.organization_id,
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
token=org_auth_token.token,
valid=org_auth_token.valid,
created_at=org_auth_token.created_at,
modified_at=org_auth_token.modified_at,
)
def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact:
if debug_enabled:
LOG.debug("Converting ArtifactModel to Artifact", artifact_id=artifact_model.artifact_id)
return Artifact(
artifact_id=artifact_model.artifact_id,
artifact_type=ArtifactType[artifact_model.artifact_type.upper()],
uri=artifact_model.uri,
task_id=artifact_model.task_id,
step_id=artifact_model.step_id,
created_at=artifact_model.created_at,
modified_at=artifact_model.modified_at,
organization_id=artifact_model.organization_id,
)
def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = False) -> Workflow:
if debug_enabled:
LOG.debug("Converting WorkflowModel to Workflow", workflow_id=workflow_model.workflow_id)
return Workflow(
workflow_id=workflow_model.workflow_id,
organization_id=workflow_model.organization_id,
title=workflow_model.title,
description=workflow_model.description,
workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition),
created_at=workflow_model.created_at,
modified_at=workflow_model.modified_at,
deleted_at=workflow_model.deleted_at,
)
def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled: bool = False) -> WorkflowRun:
if debug_enabled:
LOG.debug("Converting WorkflowRunModel to WorkflowRun", workflow_run_id=workflow_run_model.workflow_run_id)
return WorkflowRun(
workflow_run_id=workflow_run_model.workflow_run_id,
workflow_id=workflow_run_model.workflow_id,
status=WorkflowRunStatus[workflow_run_model.status],
proxy_location=ProxyLocation(workflow_run_model.proxy_location) if workflow_run_model.proxy_location else None,
webhook_callback_url=workflow_run_model.webhook_callback_url,
created_at=workflow_run_model.created_at,
modified_at=workflow_run_model.modified_at,
)
def convert_to_workflow_parameter(
workflow_parameter_model: WorkflowParameterModel, debug_enabled: bool = False
) -> WorkflowParameter:
if debug_enabled:
LOG.debug(
"Converting WorkflowParameterModel to WorkflowParameter",
workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
)
workflow_parameter_type = WorkflowParameterType[workflow_parameter_model.workflow_parameter_type.upper()]
return WorkflowParameter(
workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
workflow_parameter_type=workflow_parameter_type,
workflow_id=workflow_parameter_model.workflow_id,
default_value=workflow_parameter_type.convert_value(workflow_parameter_model.default_value),
key=workflow_parameter_model.key,
description=workflow_parameter_model.description,
created_at=workflow_parameter_model.created_at,
modified_at=workflow_parameter_model.modified_at,
deleted_at=workflow_parameter_model.deleted_at,
)
def convert_to_aws_secret_parameter(
aws_secret_parameter_model: AWSSecretParameterModel, debug_enabled: bool = False
) -> AWSSecretParameter:
if debug_enabled:
LOG.debug(
"Converting AWSSecretParameterModel to AWSSecretParameter",
aws_secret_parameter_id=aws_secret_parameter_model.id,
)
return AWSSecretParameter(
aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
workflow_id=aws_secret_parameter_model.workflow_id,
key=aws_secret_parameter_model.key,
description=aws_secret_parameter_model.description,
aws_key=aws_secret_parameter_model.aws_key,
created_at=aws_secret_parameter_model.created_at,
modified_at=aws_secret_parameter_model.modified_at,
deleted_at=aws_secret_parameter_model.deleted_at,
)
def convert_to_workflow_run_parameter(
workflow_run_parameter_model: WorkflowRunParameterModel,
workflow_parameter: WorkflowParameter,
debug_enabled: bool = False,
) -> WorkflowRunParameter:
if debug_enabled:
LOG.debug(
"Converting WorkflowRunParameterModel to WorkflowRunParameter",
workflow_run_id=workflow_run_parameter_model.workflow_run_id,
workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
)
return WorkflowRunParameter(
workflow_run_id=workflow_run_parameter_model.workflow_run_id,
workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
value=workflow_parameter.workflow_parameter_type.convert_value(workflow_run_parameter_model.value),
created_at=workflow_run_parameter_model.created_at,
)