Move the code over from private repository (#3)
This commit is contained in:
0
skyvern/forge/sdk/db/__init__.py
Normal file
0
skyvern/forge/sdk/db/__init__.py
Normal file
900
skyvern/forge/sdk/db/client.py
Normal file
900
skyvern/forge/sdk/db/client.py
Normal file
@@ -0,0 +1,900 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, create_engine, delete
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from skyvern.exceptions import WorkflowParameterNotFound
|
||||
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
ArtifactModel,
|
||||
AWSSecretParameterModel,
|
||||
OrganizationAuthTokenModel,
|
||||
OrganizationModel,
|
||||
StepModel,
|
||||
TaskModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.utils import (
|
||||
_custom_json_serializer,
|
||||
convert_to_artifact,
|
||||
convert_to_aws_secret_parameter,
|
||||
convert_to_organization,
|
||||
convert_to_organization_auth_token,
|
||||
convert_to_step,
|
||||
convert_to_task,
|
||||
convert_to_workflow,
|
||||
convert_to_workflow_parameter,
|
||||
convert_to_workflow_run,
|
||||
convert_to_workflow_run_parameter,
|
||||
)
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunParameter, WorkflowRunStatus
|
||||
from skyvern.webeye.actions.models import AgentStepOutput
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AgentDB:
    """Thin async data-access layer over the SQLAlchemy models used by the agent."""

    def __init__(self, database_string: str, debug_enabled: bool = False) -> None:
        """Build an engine and session factory for the given connection string."""
        super().__init__()
        self.engine = create_engine(database_string, json_serializer=_custom_json_serializer)
        self.Session = sessionmaker(bind=self.engine)
        self.debug_enabled = debug_enabled
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
url: str,
|
||||
navigation_goal: str | None,
|
||||
data_extraction_goal: str | None,
|
||||
navigation_payload: dict[str, Any] | list | str | None,
|
||||
webhook_callback_url: str | None = None,
|
||||
organization_id: str | None = None,
|
||||
proxy_location: ProxyLocation | None = None,
|
||||
extracted_information_schema: dict[str, Any] | list | str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
order: int | None = None,
|
||||
retry: int | None = None,
|
||||
) -> Task:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
new_task = TaskModel(
|
||||
status="created",
|
||||
url=url,
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
navigation_goal=navigation_goal,
|
||||
data_extraction_goal=data_extraction_goal,
|
||||
navigation_payload=navigation_payload,
|
||||
organization_id=organization_id,
|
||||
proxy_location=proxy_location,
|
||||
extracted_information_schema=extracted_information_schema,
|
||||
workflow_run_id=workflow_run_id,
|
||||
order=order,
|
||||
retry=retry,
|
||||
)
|
||||
session.add(new_task)
|
||||
session.commit()
|
||||
session.refresh(new_task)
|
||||
return convert_to_task(new_task, self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_step(
|
||||
self,
|
||||
task_id: str,
|
||||
order: int,
|
||||
retry_index: int,
|
||||
organization_id: str | None = None,
|
||||
) -> Step:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
new_step = StepModel(
|
||||
task_id=task_id,
|
||||
order=order,
|
||||
retry_index=retry_index,
|
||||
status="created",
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(new_step)
|
||||
session.commit()
|
||||
session.refresh(new_step)
|
||||
return convert_to_step(new_step, debug_enabled=self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_artifact(
|
||||
self,
|
||||
artifact_id: str,
|
||||
step_id: str,
|
||||
task_id: str,
|
||||
artifact_type: str,
|
||||
uri: str,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
new_artifact = ArtifactModel(
|
||||
artifact_id=artifact_id,
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
artifact_type=artifact_type,
|
||||
uri=uri,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(new_artifact)
|
||||
session.commit()
|
||||
session.refresh(new_artifact)
|
||||
return convert_to_artifact(new_artifact, self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.exception("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.exception("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
|
||||
"""Get a task by its id"""
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if task_obj := (
|
||||
session.query(TaskModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
return convert_to_task(task_obj, self.debug_enabled)
|
||||
else:
|
||||
LOG.info("Task not found", task_id=task_id, organization_id=organization_id)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if step := (
|
||||
session.query(StepModel)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
return convert_to_step(step, debug_enabled=self.debug_enabled)
|
||||
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_task_steps(self, task_id: str, organization_id: str | None = None) -> list[Step]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
steps := session.query(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order)
|
||||
.order_by(StepModel.retry_index)
|
||||
.all()
|
||||
):
|
||||
return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
|
||||
else:
|
||||
return []
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> list[StepModel]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
return (
|
||||
session.query(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order)
|
||||
.order_by(StepModel.retry_index)
|
||||
.all()
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if step := (
|
||||
session.query(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order.desc())
|
||||
.first()
|
||||
):
|
||||
return convert_to_step(step, debug_enabled=self.debug_enabled)
|
||||
else:
|
||||
LOG.info("Latest step not found", task_id=task_id, organization_id=organization_id)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_step(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
status: StepStatus | None = None,
|
||||
output: AgentStepOutput | None = None,
|
||||
is_last: bool | None = None,
|
||||
retry_index: int | None = None,
|
||||
organization_id: str | None = None,
|
||||
chat_completion_price: ChatCompletionPrice | None = None,
|
||||
) -> Step:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
step := session.query(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
if status is not None:
|
||||
step.status = status
|
||||
if output is not None:
|
||||
step.output = output.model_dump()
|
||||
if is_last is not None:
|
||||
step.is_last = is_last
|
||||
if retry_index is not None:
|
||||
step.retry_index = retry_index
|
||||
if chat_completion_price is not None:
|
||||
if step.input_token_count is None:
|
||||
step.input_token_count = 0
|
||||
|
||||
if step.output_token_count is None:
|
||||
step.output_token_count = 0
|
||||
|
||||
step.input_token_count += chat_completion_price.input_token_count
|
||||
step.output_token_count += chat_completion_price.output_token_count
|
||||
step.step_cost = chat_completion_price.openai_model_to_price_lambda(
|
||||
step.input_token_count, step.output_token_count
|
||||
)
|
||||
|
||||
session.commit()
|
||||
updated_step = await self.get_step(task_id, step_id, organization_id)
|
||||
if not updated_step:
|
||||
raise NotFoundError("Step not found")
|
||||
return updated_step
|
||||
else:
|
||||
raise NotFoundError("Step not found")
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except NotFoundError:
|
||||
LOG.error("NotFoundError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus,
|
||||
extracted_information: dict[str, Any] | list | str | None = None,
|
||||
failure_reason: str | None = None,
|
||||
organization_id: str | None = None,
|
||||
) -> Task:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if (
|
||||
task := session.query(TaskModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
task.status = status
|
||||
if extracted_information is not None:
|
||||
task.extracted_information = extracted_information
|
||||
if failure_reason is not None:
|
||||
task.failure_reason = failure_reason
|
||||
session.commit()
|
||||
updated_task = await self.get_task(task_id, organization_id=organization_id)
|
||||
if not updated_task:
|
||||
raise NotFoundError("Task not found")
|
||||
return updated_task
|
||||
else:
|
||||
raise NotFoundError("Task not found")
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except NotFoundError:
|
||||
LOG.error("NotFoundError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_tasks(self, page: int = 1, page_size: int = 10, organization_id: str | None = None) -> list[Task]:
|
||||
"""
|
||||
Get all tasks.
|
||||
:param page: Starts at 1
|
||||
:param page_size:
|
||||
:return:
|
||||
"""
|
||||
if page < 1:
|
||||
raise ValueError(f"Page must be greater than 0, got {page}")
|
||||
|
||||
try:
|
||||
with self.Session() as session:
|
||||
db_page = page - 1 # offset logic is 0 based
|
||||
tasks = (
|
||||
session.query(TaskModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(TaskModel.created_at.desc())
|
||||
.limit(page_size)
|
||||
.offset(db_page * page_size)
|
||||
.all()
|
||||
)
|
||||
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_organization(self, organization_id: str) -> Organization | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if organization := (
|
||||
session.query(OrganizationModel).filter_by(organization_id=organization_id).first()
|
||||
):
|
||||
return convert_to_organization(organization)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_organization(
|
||||
self,
|
||||
organization_name: str,
|
||||
webhook_callback_url: str | None = None,
|
||||
max_steps_per_run: int | None = None,
|
||||
) -> Organization:
|
||||
with self.Session() as session:
|
||||
org = OrganizationModel(
|
||||
organization_name=organization_name,
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
max_steps_per_run=max_steps_per_run,
|
||||
)
|
||||
session.add(org)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
|
||||
return convert_to_organization(org)
|
||||
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
) -> OrganizationAuthToken | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if token := (
|
||||
session.query(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(valid=True)
|
||||
.first()
|
||||
):
|
||||
return convert_to_organization_auth_token(token)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def validate_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
) -> OrganizationAuthToken | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if token_obj := (
|
||||
session.query(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(token=token)
|
||||
.filter_by(valid=True)
|
||||
.first()
|
||||
):
|
||||
return convert_to_organization_auth_token(token_obj)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
) -> OrganizationAuthToken:
|
||||
with self.Session() as session:
|
||||
token = OrganizationAuthTokenModel(
|
||||
organization_id=organization_id,
|
||||
token_type=token_type,
|
||||
token=token,
|
||||
)
|
||||
session.add(token)
|
||||
session.commit()
|
||||
session.refresh(token)
|
||||
|
||||
return convert_to_organization_auth_token(token)
|
||||
|
||||
async def get_artifacts_for_task_step(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[Artifact]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if artifacts := (
|
||||
session.query(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.all()
|
||||
):
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
|
||||
else:
|
||||
return []
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_artifact_by_id(
|
||||
self,
|
||||
artifact_id: str,
|
||||
organization_id: str,
|
||||
) -> Artifact | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if artifact := (
|
||||
session.query(ArtifactModel)
|
||||
.filter_by(artifact_id=artifact_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.first()
|
||||
):
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.exception("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.exception("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
artifact_type: ArtifactType,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
artifact = (
|
||||
session.query(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(artifact_type=artifact_type)
|
||||
.order_by(ArtifactModel.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if artifact:
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_artifact_for_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
artifact_type: ArtifactType,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
artifact = (
|
||||
session.query(ArtifactModel)
|
||||
.join(TaskModel, TaskModel.task_id == ArtifactModel.task_id)
|
||||
.filter(TaskModel.workflow_run_id == workflow_run_id)
|
||||
.filter(ArtifactModel.artifact_type == artifact_type)
|
||||
.filter(ArtifactModel.organization_id == organization_id)
|
||||
.order_by(ArtifactModel.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if artifact:
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_latest_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str | None = None,
|
||||
artifact_types: list[ArtifactType] | None = None,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
artifact_query = session.query(ArtifactModel).filter_by(task_id=task_id)
|
||||
if step_id:
|
||||
artifact_query = artifact_query.filter_by(step_id=step_id)
|
||||
if organization_id:
|
||||
artifact_query = artifact_query.filter_by(organization_id=organization_id)
|
||||
if artifact_types:
|
||||
artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types))
|
||||
|
||||
artifact = artifact_query.order_by(ArtifactModel.created_at.desc()).first()
|
||||
if artifact:
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.exception("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.exception("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_latest_task_by_workflow_id(
|
||||
self,
|
||||
organization_id: str,
|
||||
workflow_id: str,
|
||||
before: datetime | None = None,
|
||||
) -> Task | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
query = (
|
||||
session.query(TaskModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(workflow_id=workflow_id)
|
||||
)
|
||||
if before:
|
||||
query = query.filter(TaskModel.created_at < before)
|
||||
task = query.order_by(TaskModel.created_at.desc()).first()
|
||||
if task:
|
||||
return convert_to_task(task, debug_enabled=self.debug_enabled)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_workflow(
|
||||
self,
|
||||
organization_id: str,
|
||||
title: str,
|
||||
workflow_definition: dict[str, Any],
|
||||
description: str | None = None,
|
||||
) -> Workflow:
|
||||
with self.Session() as session:
|
||||
workflow = WorkflowModel(
|
||||
organization_id=organization_id,
|
||||
title=title,
|
||||
description=description,
|
||||
workflow_definition=workflow_definition,
|
||||
)
|
||||
session.add(workflow)
|
||||
session.commit()
|
||||
session.refresh(workflow)
|
||||
return convert_to_workflow(workflow, self.debug_enabled)
|
||||
|
||||
async def get_workflow(self, workflow_id: str) -> Workflow | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if workflow := session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first():
|
||||
return convert_to_workflow(workflow, self.debug_enabled)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_workflow(
|
||||
self,
|
||||
workflow_id: str,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
workflow_definition: dict[str, Any] | None = None,
|
||||
) -> Workflow | None:
|
||||
with self.Session() as session:
|
||||
workflow = session.query(WorkflowModel).filter_by(workflow_id=workflow_id).first()
|
||||
if workflow:
|
||||
if title:
|
||||
workflow.title = title
|
||||
if description:
|
||||
workflow.description = description
|
||||
if workflow_definition:
|
||||
workflow.workflow_definition = workflow_definition
|
||||
session.commit()
|
||||
session.refresh(workflow)
|
||||
return convert_to_workflow(workflow, self.debug_enabled)
|
||||
LOG.error("Workflow not found, nothing to update", workflow_id=workflow_id)
|
||||
return None
|
||||
|
||||
async def create_workflow_run(
|
||||
self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None
|
||||
) -> WorkflowRun:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
workflow_run = WorkflowRunModel(
|
||||
workflow_id=workflow_id,
|
||||
proxy_location=proxy_location,
|
||||
status="created",
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
)
|
||||
session.add(workflow_run)
|
||||
session.commit()
|
||||
session.refresh(workflow_run)
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_workflow_run(self, workflow_run_id: str, status: WorkflowRunStatus) -> WorkflowRun | None:
|
||||
with self.Session() as session:
|
||||
workflow_run = session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first()
|
||||
if workflow_run:
|
||||
workflow_run.status = status
|
||||
session.commit()
|
||||
session.refresh(workflow_run)
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
LOG.error("WorkflowRun not found, nothing to update", workflow_run_id=workflow_run_id)
|
||||
return None
|
||||
|
||||
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if workflow_run := session.query(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id).first():
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_workflow_runs(self, workflow_id: str) -> list[WorkflowRun]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
workflow_runs = session.query(WorkflowRunModel).filter_by(workflow_id=workflow_id).all()
|
||||
return [convert_to_workflow_run(run) for run in workflow_runs]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_workflow_parameter(
|
||||
self,
|
||||
workflow_id: str,
|
||||
workflow_parameter_type: WorkflowParameterType,
|
||||
key: str,
|
||||
default_value: Any,
|
||||
description: str | None = None,
|
||||
) -> WorkflowParameter:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
workflow_parameter = WorkflowParameterModel(
|
||||
workflow_id=workflow_id,
|
||||
workflow_parameter_type=workflow_parameter_type,
|
||||
key=key,
|
||||
default_value=default_value,
|
||||
description=description,
|
||||
)
|
||||
session.add(workflow_parameter)
|
||||
session.commit()
|
||||
session.refresh(workflow_parameter)
|
||||
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_aws_secret_parameter(
|
||||
self,
|
||||
workflow_id: str,
|
||||
key: str,
|
||||
aws_key: str,
|
||||
description: str | None = None,
|
||||
) -> AWSSecretParameter:
|
||||
with self.Session() as session:
|
||||
aws_secret_parameter = AWSSecretParameterModel(
|
||||
workflow_id=workflow_id,
|
||||
key=key,
|
||||
aws_key=aws_key,
|
||||
description=description,
|
||||
)
|
||||
session.add(aws_secret_parameter)
|
||||
session.commit()
|
||||
session.refresh(aws_secret_parameter)
|
||||
return convert_to_aws_secret_parameter(aws_secret_parameter)
|
||||
|
||||
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
workflow_parameters = session.query(WorkflowParameterModel).filter_by(workflow_id=workflow_id).all()
|
||||
return [convert_to_workflow_parameter(parameter) for parameter in workflow_parameters]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_workflow_parameter(self, workflow_parameter_id: str) -> WorkflowParameter | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if workflow_parameter := (
|
||||
session.query(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id).first()
|
||||
):
|
||||
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_workflow_run_parameter(
|
||||
self, workflow_run_id: str, workflow_parameter_id: str, value: Any
|
||||
) -> WorkflowRunParameter:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
workflow_run_parameter = WorkflowRunParameterModel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_parameter_id=workflow_parameter_id,
|
||||
value=value,
|
||||
)
|
||||
session.add(workflow_run_parameter)
|
||||
session.commit()
|
||||
session.refresh(workflow_run_parameter)
|
||||
workflow_parameter = await self.get_workflow_parameter(workflow_parameter_id)
|
||||
if not workflow_parameter:
|
||||
raise WorkflowParameterNotFound(workflow_parameter_id)
|
||||
return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_workflow_run_parameters(
|
||||
self, workflow_run_id: str
|
||||
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
workflow_run_parameters = (
|
||||
session.query(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id).all()
|
||||
)
|
||||
results = []
|
||||
for workflow_run_parameter in workflow_run_parameters:
|
||||
workflow_parameter = await self.get_workflow_parameter(workflow_run_parameter.workflow_parameter_id)
|
||||
if not workflow_parameter:
|
||||
raise WorkflowParameterNotFound(
|
||||
workflow_parameter_id=workflow_run_parameter.workflow_parameter_id
|
||||
)
|
||||
results.append(
|
||||
(
|
||||
workflow_parameter,
|
||||
convert_to_workflow_run_parameter(
|
||||
workflow_run_parameter, workflow_parameter, self.debug_enabled
|
||||
),
|
||||
)
|
||||
)
|
||||
return results
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
if task := (
|
||||
session.query(TaskModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.order_by(TaskModel.created_at.desc())
|
||||
.first()
|
||||
):
|
||||
return convert_to_task(task, debug_enabled=self.debug_enabled)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
tasks = (
|
||||
session.query(TaskModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.order_by(TaskModel.created_at)
|
||||
.all()
|
||||
)
|
||||
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
|
||||
with self.Session() as session:
|
||||
# delete artifacts by filtering organization_id and task_id
|
||||
stmt = delete(ArtifactModel).where(
|
||||
and_(
|
||||
ArtifactModel.organization_id == organization_id,
|
||||
ArtifactModel.task_id == task_id,
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
|
||||
with self.Session() as session:
|
||||
# delete artifacts by filtering organization_id and task_id
|
||||
stmt = delete(StepModel).where(
|
||||
and_(
|
||||
StepModel.organization_id == organization_id,
|
||||
StepModel.task_id == task_id,
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
15
skyvern/forge/sdk/db/enums.py
Normal file
15
skyvern/forge/sdk/db/enums.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class OrganizationAuthTokenType(StrEnum):
|
||||
api = "api"
|
||||
|
||||
|
||||
class ScheduleRuleUnit(StrEnum):
|
||||
# No support for scheduling every second
|
||||
minute = "minute"
|
||||
hour = "hour"
|
||||
day = "day"
|
||||
week = "week"
|
||||
month = "month"
|
||||
year = "year"
|
||||
2
skyvern/forge/sdk/db/exceptions.py
Normal file
2
skyvern/forge/sdk/db/exceptions.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class NotFoundError(Exception):
    """Raised when a requested database row does not exist."""
|
||||
136
skyvern/forge/sdk/db/id.py
Normal file
136
skyvern/forge/sdk/db/id.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import hashlib
|
||||
import itertools
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import time
|
||||
|
||||
# Snowflake-style 64-bit ID layout:
#   [ timestamp: 32 | worker: 21 | sequence: 10 | version: 1 ]
# Timestamps are seconds since BASE_EPOCH, not the Unix epoch.

# 6/20/2022 12AM
BASE_EPOCH = 1655683200
VERSION = 0

# Number of bits
TIMESTAMP_BITS = 32
WORKER_ID_BITS = 21
SEQUENCE_BITS = 10
VERSION_BITS = 1

# Bit shifts (left)
TIMESTAMP_SHIFT = 32
WORKER_ID_SHIFT = 11
SEQUENCE_SHIFT = 1
VERSION_SHIFT = 0

# Largest value the sequence field can hold.
SEQUENCE_MAX = (2**SEQUENCE_BITS) - 1
# Random per-process offset for the sequence; set lazily on first use.
_sequence_start = None
SEQUENCE_COUNTER = itertools.count()
# Cached worker identity hash; computed lazily from hostname + PID.
_worker_hash = None


# prefix
ORGANIZATION_AUTH_TOKEN_PREFIX = "oat"
ORG_PREFIX = "o"
TASK_PREFIX = "tsk"
USER_PREFIX = "u"
STEP_PREFIX = "stp"
ARTIFACT_PREFIX = "a"
WORKFLOW_PREFIX = "w"
WORKFLOW_RUN_PREFIX = "wr"
WORKFLOW_PARAMETER_PREFIX = "wp"
AWS_SECRET_PARAMETER_PREFIX = "asp"
|
||||
|
||||
|
||||
def generate_workflow_id() -> str:
    """Return a new workflow ID: the "w" prefix joined to a fresh 64-bit int."""
    return f"{WORKFLOW_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_workflow_run_id() -> str:
    """Return a new workflow-run ID: the "wr" prefix joined to a fresh 64-bit int."""
    return f"{WORKFLOW_RUN_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_aws_secret_parameter_id() -> str:
    """Return a new AWS-secret-parameter ID: the "asp" prefix joined to a fresh 64-bit int."""
    return f"{AWS_SECRET_PARAMETER_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_workflow_parameter_id() -> str:
    """Return a new workflow-parameter ID: the "wp" prefix joined to a fresh 64-bit int."""
    return f"{WORKFLOW_PARAMETER_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_organization_auth_token_id() -> str:
    """Return a new org auth-token ID: the "oat" prefix joined to a fresh 64-bit int."""
    return f"{ORGANIZATION_AUTH_TOKEN_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_org_id() -> str:
    """Return a new organization ID: the "o" prefix joined to a fresh 64-bit int."""
    return f"{ORG_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_task_id() -> str:
    """Return a new task ID: the "tsk" prefix joined to a fresh 64-bit int."""
    return f"{TASK_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_step_id() -> str:
    """Return a new step ID: the "stp" prefix joined to a fresh 64-bit int."""
    return f"{STEP_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_artifact_id() -> str:
    """Return a new artifact ID: the "a" prefix joined to a fresh 64-bit int."""
    return f"{ARTIFACT_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_user_id() -> str:
    """Return a new user ID: the "u" prefix joined to a fresh 64-bit int."""
    return f"{USER_PREFIX}_{generate_id()}"
|
||||
|
||||
|
||||
def generate_id() -> int:
    """Generate a 64-bit int ID.

    Layout (high to low): seconds since BASE_EPOCH (32 bits), worker hash
    (21 bits), per-process sequence (10 bits), format version (1 bit).
    """
    created_at = current_time() - BASE_EPOCH
    sequence = _increment_and_get_sequence()

    # Each field is masked to its width and shifted into place, then the
    # four disjoint fields are OR-ed together.
    return (
        _mask_shift(created_at, TIMESTAMP_BITS, TIMESTAMP_SHIFT)
        | _mask_shift(_get_worker_hash(), WORKER_ID_BITS, WORKER_ID_SHIFT)
        | _mask_shift(sequence, SEQUENCE_BITS, SEQUENCE_SHIFT)
        | _mask_shift(VERSION, VERSION_BITS, VERSION_SHIFT)
    )
|
||||
|
||||
|
||||
def _increment_and_get_sequence() -> int:
    """Return the next value of the per-process sequence counter.

    The counter starts at a random offset (so workers started at the same
    second are unlikely to collide) and wraps within the SEQUENCE_BITS-wide
    field.
    """
    global _sequence_start
    if _sequence_start is None:
        _sequence_start = random.randint(0, SEQUENCE_MAX)

    # Wrap over the full 2**SEQUENCE_BITS range. The original `% SEQUENCE_MAX`
    # was an off-by-one: with SEQUENCE_MAX == 2**SEQUENCE_BITS - 1 it could
    # never produce the value SEQUENCE_MAX and shortened the cycle by one.
    return (_sequence_start + next(SEQUENCE_COUNTER)) % (SEQUENCE_MAX + 1)
|
||||
|
||||
|
||||
def current_time() -> int:
    """Return the current Unix time, truncated to whole seconds."""
    now = time.time()
    return int(now)
|
||||
|
||||
|
||||
def current_time_ms() -> int:
    """Return the current Unix time, truncated to whole milliseconds."""
    now_ms = time.time() * 1000
    return int(now_ms)
|
||||
|
||||
|
||||
def _mask_shift(value: int, mask_bits: int, shift_bits: int) -> int:
|
||||
return (value & ((2**mask_bits) - 1)) << shift_bits
|
||||
|
||||
|
||||
def _get_worker_hash() -> int:
    """Return the worker identity hash, computing and caching it on first call.

    The cache lives in the module-level `_worker_hash`, so the value is stable
    for the lifetime of the process.
    """
    global _worker_hash
    if _worker_hash is None:
        _worker_hash = _generate_worker_hash()
    return _worker_hash
|
||||
|
||||
|
||||
def _generate_worker_hash() -> int:
|
||||
worker_identity = f"{platform.node()}:{os.getpid()}"
|
||||
return int(hashlib.md5(worker_identity.encode()).hexdigest()[-15:], 16)
|
||||
172
skyvern/forge/sdk/db/models.py
Normal file
172
skyvern/forge/sdk/db/models.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Integer, Numeric, String, UnicodeText
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.id import (
|
||||
generate_artifact_id,
|
||||
generate_aws_secret_parameter_id,
|
||||
generate_org_id,
|
||||
generate_organization_auth_token_id,
|
||||
generate_step_id,
|
||||
generate_task_id,
|
||||
generate_workflow_id,
|
||||
generate_workflow_parameter_id,
|
||||
generate_workflow_run_id,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models in this package."""

    pass
|
||||
|
||||
|
||||
class TaskModel(Base):
    """ORM row for a single task (goal, payload, status, and extraction results)."""

    __tablename__ = "tasks"

    task_id = Column(String, primary_key=True, index=True, default=generate_task_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    status = Column(String)
    webhook_callback_url = Column(String)
    url = Column(String)
    navigation_goal = Column(String)
    data_extraction_goal = Column(String)
    navigation_payload = Column(JSON)
    extracted_information = Column(JSON)
    failure_reason = Column(String)
    proxy_location = Column(Enum(ProxyLocation))
    extracted_information_schema = Column(JSON)
    # Nullable link back to the workflow run that created this task, if any.
    workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"))
    order = Column(Integer, nullable=True)
    retry = Column(Integer, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class StepModel(Base):
    """ORM row for a single execution step of a task, with token/cost accounting."""

    __tablename__ = "steps"

    step_id = Column(String, primary_key=True, index=True, default=generate_step_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    task_id = Column(String, ForeignKey("tasks.task_id"))
    status = Column(String)
    output = Column(JSON)
    order = Column(Integer)
    is_last = Column(Boolean, default=False)
    # 0 for the first attempt; incremented on each retry of the same step.
    retry_index = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    # LLM usage accounting for this step.
    input_token_count = Column(Integer, default=0)
    output_token_count = Column(Integer, default=0)
    step_cost = Column(Numeric, default=0)
|
||||
|
||||
|
||||
class OrganizationModel(Base):
    """ORM row for an organization (tenant)."""

    __tablename__ = "organizations"

    organization_id = Column(String, primary_key=True, index=True, default=generate_org_id)
    organization_name = Column(String, nullable=False)
    webhook_callback_url = Column(UnicodeText)
    max_steps_per_run = Column(Integer)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    # Fix: the original passed `onupdate=datetime.datetime` — the class itself,
    # not a callable that returns a timestamp — which fails when SQLAlchemy
    # invokes it on UPDATE. Every other model here uses `datetime.datetime.utcnow`.
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class OrganizationAuthTokenModel(Base):
    """ORM row for an organization's auth token (soft-deletable via deleted_at)."""

    __tablename__ = "organization_auth_tokens"

    id = Column(
        String,
        primary_key=True,
        index=True,
        default=generate_organization_auth_token_id,
    )

    organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
    token_type = Column(Enum(OrganizationAuthTokenType), nullable=False)
    token = Column(String, index=True, nullable=False)
    # Tokens are invalidated by flipping this flag rather than deleting the row.
    valid = Column(Boolean, nullable=False, default=True)

    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    # Fix: the original passed `onupdate=datetime.datetime` — the class itself,
    # not a callable returning a timestamp — which fails when SQLAlchemy invokes
    # it on UPDATE. Aligned with the `datetime.datetime.utcnow` used elsewhere.
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class ArtifactModel(Base):
    """ORM row for an artifact (screenshot, HTML dump, etc.) produced by a step."""

    __tablename__ = "artifacts"

    artifact_id = Column(String, primary_key=True, index=True, default=generate_artifact_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    task_id = Column(String, ForeignKey("tasks.task_id"))
    step_id = Column(String, ForeignKey("steps.step_id"))
    # Stored as a plain string; converted to the ArtifactType enum in db.utils.
    artifact_type = Column(String)
    uri = Column(String)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class WorkflowModel(Base):
    """ORM row for a workflow definition (soft-deletable via deleted_at)."""

    __tablename__ = "workflows"

    workflow_id = Column(String, primary_key=True, index=True, default=generate_workflow_id)
    organization_id = Column(String, ForeignKey("organizations.organization_id"))
    title = Column(String, nullable=False)
    description = Column(String, nullable=True)
    # Serialized WorkflowDefinition; validated back into the pydantic model in db.utils.
    workflow_definition = Column(JSON, nullable=False)

    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class WorkflowRunModel(Base):
    """ORM row for one execution of a workflow."""

    __tablename__ = "workflow_runs"

    workflow_run_id = Column(String, primary_key=True, index=True, default=generate_workflow_run_id)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=False)
    # Stored as a plain string; converted to WorkflowRunStatus in db.utils.
    status = Column(String, nullable=False)
    proxy_location = Column(Enum(ProxyLocation))
    webhook_callback_url = Column(String)

    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class WorkflowParameterModel(Base):
    """ORM row for a declared workflow parameter (soft-deletable via deleted_at)."""

    __tablename__ = "workflow_parameters"

    workflow_parameter_id = Column(String, primary_key=True, index=True, default=generate_workflow_parameter_id)
    # Stored as a plain string; mapped to WorkflowParameterType in db.utils.
    workflow_parameter_type = Column(String, nullable=False)
    key = Column(String, nullable=False)
    description = Column(String, nullable=True)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
    # Stringified; converted back per workflow_parameter_type in db.utils.
    default_value = Column(String, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class AWSSecretParameterModel(Base):
    """ORM row for a workflow parameter backed by an AWS secret (soft-deletable)."""

    __tablename__ = "aws_secret_parameters"

    aws_secret_parameter_id = Column(String, primary_key=True, index=True, default=generate_aws_secret_parameter_id)
    workflow_id = Column(String, ForeignKey("workflows.workflow_id"), index=True, nullable=False)
    key = Column(String, nullable=False)
    description = Column(String, nullable=True)
    # Identifier of the secret in AWS, not the secret value itself.
    aws_key = Column(String, nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class WorkflowRunParameterModel(Base):
    """ORM row binding a concrete value to a workflow parameter for one run.

    Composite primary key: (workflow_run_id, workflow_parameter_id).
    """

    __tablename__ = "workflow_run_parameters"

    workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), primary_key=True, index=True)
    workflow_parameter_id = Column(
        String, ForeignKey("workflow_parameters.workflow_parameter_id"), primary_key=True, index=True
    )
    # Can be bool | int | float | str | dict | list depending on the workflow parameter type
    value = Column(String, nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
220
skyvern/forge/sdk/db/utils.py
Normal file
220
skyvern/forge/sdk/db/utils.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import json
|
||||
import typing
|
||||
|
||||
import pydantic.json
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
ArtifactModel,
|
||||
AWSSecretParameterModel,
|
||||
OrganizationAuthTokenModel,
|
||||
OrganizationModel,
|
||||
StepModel,
|
||||
TaskModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDefinition,
|
||||
WorkflowRun,
|
||||
WorkflowRunParameter,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@typing.no_type_check
def _custom_json_serializer(*args, **kwargs) -> str:
    """
    Encodes json in the same way that pydantic does.

    Imported by the DB client (see client.py) — presumably wired in as the
    engine's JSON serializer so JSON columns can hold pydantic models.
    NOTE(review): `pydantic.json.pydantic_encoder` is the pydantic v1 API;
    confirm the pinned pydantic version before upgrading to v2.
    """
    return json.dumps(*args, default=pydantic.json.pydantic_encoder, **kwargs)
|
||||
|
||||
|
||||
def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
    """Map a TaskModel ORM row onto the Task API schema.

    Args:
        task_obj: the ORM row to convert.
        debug_enabled: when True, log each conversion at debug level.

    Returns:
        A Task built field-by-field from the row; status and proxy_location
        are re-wrapped in their enums (proxy_location only when non-null).
    """
    if debug_enabled:
        LOG.debug("Converting TaskModel to Task", task_id=task_obj.task_id)
    task = Task(
        task_id=task_obj.task_id,
        status=TaskStatus(task_obj.status),
        created_at=task_obj.created_at,
        modified_at=task_obj.modified_at,
        url=task_obj.url,
        webhook_callback_url=task_obj.webhook_callback_url,
        navigation_goal=task_obj.navigation_goal,
        data_extraction_goal=task_obj.data_extraction_goal,
        navigation_payload=task_obj.navigation_payload,
        extracted_information=task_obj.extracted_information,
        failure_reason=task_obj.failure_reason,
        organization_id=task_obj.organization_id,
        proxy_location=ProxyLocation(task_obj.proxy_location) if task_obj.proxy_location else None,
        extracted_information_schema=task_obj.extracted_information_schema,
        workflow_run_id=task_obj.workflow_run_id,
        order=task_obj.order,
        retry=task_obj.retry,
    )
    return task
|
||||
|
||||
|
||||
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
    """Map a StepModel ORM row onto the Step API schema.

    Args:
        step_model: the ORM row to convert.
        debug_enabled: when True, log each conversion at debug level.

    Returns:
        A Step built field-by-field from the row; status is re-wrapped in
        the StepStatus enum.
    """
    if debug_enabled:
        LOG.debug("Converting StepModel to Step", step_id=step_model.step_id)
    return Step(
        task_id=step_model.task_id,
        step_id=step_model.step_id,
        created_at=step_model.created_at,
        modified_at=step_model.modified_at,
        status=StepStatus(step_model.status),
        output=step_model.output,
        order=step_model.order,
        is_last=step_model.is_last,
        retry_index=step_model.retry_index,
        organization_id=step_model.organization_id,
        input_token_count=step_model.input_token_count,
        output_token_count=step_model.output_token_count,
        step_cost=step_model.step_cost,
    )
|
||||
|
||||
|
||||
def convert_to_organization(org_model: OrganizationModel) -> Organization:
    """Map an OrganizationModel ORM row onto the Organization API schema (1:1 fields)."""
    field_names = (
        "organization_id",
        "organization_name",
        "webhook_callback_url",
        "max_steps_per_run",
        "created_at",
        "modified_at",
    )
    # Same-named attributes on both sides, so copy them generically.
    return Organization(**{name: getattr(org_model, name) for name in field_names})
|
||||
|
||||
|
||||
def convert_to_organization_auth_token(org_auth_token: OrganizationAuthTokenModel) -> OrganizationAuthToken:
    """Map an OrganizationAuthTokenModel ORM row onto the OrganizationAuthToken schema."""
    # Re-wrap the stored token type in its enum before building the schema object.
    token_type = OrganizationAuthTokenType(org_auth_token.token_type)
    return OrganizationAuthToken(
        id=org_auth_token.id,
        organization_id=org_auth_token.organization_id,
        token_type=token_type,
        token=org_auth_token.token,
        valid=org_auth_token.valid,
        created_at=org_auth_token.created_at,
        modified_at=org_auth_token.modified_at,
    )
|
||||
|
||||
|
||||
def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact:
    """Map an ArtifactModel ORM row onto the Artifact API schema.

    Args:
        artifact_model: the ORM row to convert.
        debug_enabled: when True, log each conversion at debug level.

    Raises:
        KeyError: if artifact_type doesn't match an ArtifactType member name
        (the lookup is by upper-cased enum NAME, not by value).
    """
    if debug_enabled:
        LOG.debug("Converting ArtifactModel to Artifact", artifact_id=artifact_model.artifact_id)

    return Artifact(
        artifact_id=artifact_model.artifact_id,
        artifact_type=ArtifactType[artifact_model.artifact_type.upper()],
        uri=artifact_model.uri,
        task_id=artifact_model.task_id,
        step_id=artifact_model.step_id,
        created_at=artifact_model.created_at,
        modified_at=artifact_model.modified_at,
        organization_id=artifact_model.organization_id,
    )
|
||||
|
||||
|
||||
def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = False) -> Workflow:
    """Map a WorkflowModel ORM row onto the Workflow API schema.

    The stored JSON workflow_definition is validated back into the
    WorkflowDefinition pydantic model; validation errors propagate.
    """
    if debug_enabled:
        LOG.debug("Converting WorkflowModel to Workflow", workflow_id=workflow_model.workflow_id)

    return Workflow(
        workflow_id=workflow_model.workflow_id,
        organization_id=workflow_model.organization_id,
        title=workflow_model.title,
        description=workflow_model.description,
        workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition),
        created_at=workflow_model.created_at,
        modified_at=workflow_model.modified_at,
        deleted_at=workflow_model.deleted_at,
    )
|
||||
|
||||
|
||||
def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled: bool = False) -> WorkflowRun:
    """Map a WorkflowRunModel ORM row onto the WorkflowRun API schema.

    Note: status is looked up by enum NAME (WorkflowRunStatus[...]), so the
    stored string must match a member name; KeyError propagates otherwise.
    """
    if debug_enabled:
        LOG.debug("Converting WorkflowRunModel to WorkflowRun", workflow_run_id=workflow_run_model.workflow_run_id)

    return WorkflowRun(
        workflow_run_id=workflow_run_model.workflow_run_id,
        workflow_id=workflow_run_model.workflow_id,
        status=WorkflowRunStatus[workflow_run_model.status],
        proxy_location=ProxyLocation(workflow_run_model.proxy_location) if workflow_run_model.proxy_location else None,
        webhook_callback_url=workflow_run_model.webhook_callback_url,
        created_at=workflow_run_model.created_at,
        modified_at=workflow_run_model.modified_at,
    )
|
||||
|
||||
|
||||
def convert_to_workflow_parameter(
    workflow_parameter_model: WorkflowParameterModel, debug_enabled: bool = False
) -> WorkflowParameter:
    """Map a WorkflowParameterModel ORM row onto the WorkflowParameter schema.

    The stored type string is looked up by upper-cased enum NAME, and the
    string default_value is converted back to its typed form via
    WorkflowParameterType.convert_value.
    """
    if debug_enabled:
        LOG.debug(
            "Converting WorkflowParameterModel to WorkflowParameter",
            workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
        )

    workflow_parameter_type = WorkflowParameterType[workflow_parameter_model.workflow_parameter_type.upper()]

    return WorkflowParameter(
        workflow_parameter_id=workflow_parameter_model.workflow_parameter_id,
        workflow_parameter_type=workflow_parameter_type,
        workflow_id=workflow_parameter_model.workflow_id,
        default_value=workflow_parameter_type.convert_value(workflow_parameter_model.default_value),
        key=workflow_parameter_model.key,
        description=workflow_parameter_model.description,
        created_at=workflow_parameter_model.created_at,
        modified_at=workflow_parameter_model.modified_at,
        deleted_at=workflow_parameter_model.deleted_at,
    )
|
||||
|
||||
|
||||
def convert_to_aws_secret_parameter(
    aws_secret_parameter_model: AWSSecretParameterModel, debug_enabled: bool = False
) -> AWSSecretParameter:
    """Map an AWSSecretParameterModel ORM row onto the AWSSecretParameter schema.

    Args:
        aws_secret_parameter_model: the ORM row to convert.
        debug_enabled: when True, log each conversion at debug level.
    """
    if debug_enabled:
        LOG.debug(
            "Converting AWSSecretParameterModel to AWSSecretParameter",
            # Fix: the model's primary key is `aws_secret_parameter_id`; it has
            # no `id` attribute, so the original `aws_secret_parameter_model.id`
            # raised AttributeError whenever debug logging was enabled.
            aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
        )

    return AWSSecretParameter(
        aws_secret_parameter_id=aws_secret_parameter_model.aws_secret_parameter_id,
        workflow_id=aws_secret_parameter_model.workflow_id,
        key=aws_secret_parameter_model.key,
        description=aws_secret_parameter_model.description,
        aws_key=aws_secret_parameter_model.aws_key,
        created_at=aws_secret_parameter_model.created_at,
        modified_at=aws_secret_parameter_model.modified_at,
        deleted_at=aws_secret_parameter_model.deleted_at,
    )
|
||||
|
||||
|
||||
def convert_to_workflow_run_parameter(
    workflow_run_parameter_model: WorkflowRunParameterModel,
    workflow_parameter: WorkflowParameter,
    debug_enabled: bool = False,
) -> WorkflowRunParameter:
    """Map a WorkflowRunParameterModel ORM row onto the WorkflowRunParameter schema.

    Args:
        workflow_run_parameter_model: the ORM row to convert.
        workflow_parameter: the parameter definition, needed because the row
            stores the value as a string and the definition's type drives the
            conversion back to the typed value.
        debug_enabled: when True, log each conversion at debug level.
    """
    if debug_enabled:
        LOG.debug(
            "Converting WorkflowRunParameterModel to WorkflowRunParameter",
            workflow_run_id=workflow_run_parameter_model.workflow_run_id,
            workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
        )

    return WorkflowRunParameter(
        workflow_run_id=workflow_run_parameter_model.workflow_run_id,
        workflow_parameter_id=workflow_run_parameter_model.workflow_parameter_id,
        value=workflow_parameter.workflow_parameter_type.convert_value(workflow_run_parameter_model.value),
        created_at=workflow_run_parameter_model.created_at,
    )
|
||||
Reference in New Issue
Block a user