is_saved_task parameter for workflows (#526)

This commit is contained in:
Kerem Yilmaz
2024-06-27 12:53:08 -07:00
committed by GitHub
parent c44a3076c0
commit 63adc860ef
8 changed files with 69 additions and 10 deletions

View File

@@ -774,6 +774,7 @@ class AgentDB:
webhook_callback_url: str | None = None,
workflow_permanent_id: str | None = None,
version: int | None = None,
is_saved_task: bool = False,
) -> Workflow:
async with self.Session() as session:
workflow = WorkflowModel(
@@ -783,6 +784,7 @@ class AgentDB:
workflow_definition=workflow_definition,
proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url,
is_saved_task=is_saved_task,
)
if workflow_permanent_id:
workflow.workflow_permanent_id = workflow_permanent_id
@@ -838,6 +840,8 @@ class AgentDB:
organization_id: str,
page: int = 1,
page_size: int = 10,
only_saved_tasks: bool = False,
only_workflows: bool = False,
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
@@ -861,17 +865,18 @@ class AgentDB:
)
.subquery()
)
main_query = select(WorkflowModel).join(
subquery,
(WorkflowModel.organization_id == subquery.c.organization_id)
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
if only_saved_tasks:
main_query = main_query.where(WorkflowModel.is_saved_task.is_(True))
elif only_workflows:
main_query = main_query.where(WorkflowModel.is_saved_task.is_(False))
main_query = (
select(WorkflowModel)
.join(
subquery,
(WorkflowModel.organization_id == subquery.c.organization_id)
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
.order_by(WorkflowModel.created_at.desc()) # Example ordering by creation date
.limit(page_size)
.offset(db_page * page_size)
main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
)
workflows = (await session.scalars(main_query)).all()
return [convert_to_workflow(workflow, self.debug_enabled) for workflow in workflows]

View File

@@ -189,6 +189,7 @@ class WorkflowModel(Base):
workflow_permanent_id = Column(String, nullable=False, default=generate_workflow_permanent_id, index=True)
version = Column(Integer, default=1, nullable=False)
is_saved_task = Column(Boolean, default=False, nullable=False)
class WorkflowRunModel(Base):

View File

@@ -160,6 +160,7 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
webhook_callback_url=workflow_model.webhook_callback_url,
proxy_location=(ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None),
version=workflow_model.version,
is_saved_task=workflow_model.is_saved_task,
description=workflow_model.description,
workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition),
created_at=workflow_model.created_at,

View File

@@ -654,16 +654,27 @@ async def delete_workflow(
async def get_workflows(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1),
only_saved_tasks: bool = Query(False),
only_workflows: bool = Query(False),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
"""
analytics.capture("skyvern-oss-agent-workflows-get")
if only_saved_tasks and only_workflows:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="only_saved_tasks and only_workflows cannot be used together",
)
return await app.WORKFLOW_SERVICE.get_workflows_by_organization_id(
organization_id=current_org.organization_id,
page=page,
page_size=page_size,
only_saved_tasks=only_saved_tasks,
only_workflows=only_workflows,
)

View File

@@ -44,6 +44,7 @@ class Workflow(BaseModel):
title: str
workflow_permanent_id: str
version: int
is_saved_task: bool
description: str | None = None
workflow_definition: WorkflowDefinition
proxy_location: ProxyLocation | None = None

View File

@@ -194,3 +194,4 @@ class WorkflowCreateYAMLRequest(BaseModel):
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
workflow_definition: WorkflowDefinitionYAML
is_saved_task: bool = False

View File

@@ -281,6 +281,7 @@ class WorkflowService:
webhook_callback_url: str | None = None,
workflow_permanent_id: str | None = None,
version: int | None = None,
is_saved_task: bool = False,
) -> Workflow:
return await app.DATABASE.create_workflow(
title=title,
@@ -291,6 +292,7 @@ class WorkflowService:
webhook_callback_url=webhook_callback_url,
workflow_permanent_id=workflow_permanent_id,
version=version,
is_saved_task=is_saved_task,
)
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow:
@@ -319,6 +321,8 @@ class WorkflowService:
organization_id: str,
page: int = 1,
page_size: int = 10,
only_saved_tasks: bool = False,
only_workflows: bool = False,
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
@@ -327,6 +331,8 @@ class WorkflowService:
organization_id=organization_id,
page=page,
page_size=page_size,
only_saved_tasks=only_saved_tasks,
only_workflows=only_workflows,
)
async def update_workflow(
@@ -773,6 +779,7 @@ class WorkflowService:
webhook_callback_url=request.webhook_callback_url,
workflow_permanent_id=workflow_permanent_id,
version=existing_version + 1,
is_saved_task=request.is_saved_task,
)
else:
workflow = await self.create_workflow(
@@ -782,6 +789,7 @@ class WorkflowService:
organization_id=organization_id,
proxy_location=request.proxy_location,
webhook_callback_url=request.webhook_callback_url,
is_saved_task=request.is_saved_task,
)
# Create parameters from the request
parameters: dict[str, PARAMETER_TYPE] = {}