support global workflow (#1664)

This commit is contained in:
Shuchang Zheng
2025-01-28 15:04:18 +08:00
committed by GitHub
parent 833cd8194e
commit 1b79ef9ca3
10 changed files with 154 additions and 26 deletions

View File

@@ -124,3 +124,11 @@ class FailedToFormatJinjaStyleParameter(SkyvernException):
class NoIterableValueFound(SkyvernException):
def __init__(self) -> None:
super().__init__("No iterable value found for the loop block")
class InvalidTemplateWorkflowPermanentId(SkyvernHTTPException):
def __init__(self, workflow_permanent_id: str) -> None:
super().__init__(
message=f"Invalid template workflow permanent id: {workflow_permanent_id}. Please make sure the workflow is a valid template.",
status_code=status.HTTP_400_BAD_REQUEST,
)

View File

@@ -441,9 +441,8 @@ class BaseTaskBlock(Block):
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
workflow = await app.WORKFLOW_SERVICE.get_workflow(
workflow_id=workflow_run.workflow_id,
organization_id=organization_id,
workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_run.workflow_permanent_id,
)
# if the task url is parameterized, we need to get the value from the workflow run context
if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url):
@@ -512,12 +511,12 @@ class BaseTaskBlock(Block):
workflow_run_block = await app.DATABASE.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
task_id=task.task_id,
organization_id=workflow.organization_id,
organization_id=organization_id,
)
current_running_task = task
organization = await app.DATABASE.get_organization(organization_id=workflow.organization_id)
organization = await app.DATABASE.get_organization(organization_id=workflow_run.organization_id)
if not organization:
raise Exception(f"Organization is missing organization_id={workflow.organization_id}")
raise Exception(f"Organization is missing organization_id={workflow_run.organization_id}")
browser_state: BrowserState | None = None
if is_first_task:
@@ -544,7 +543,7 @@ class BaseTaskBlock(Block):
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
failure_reason=str(e),
)
raise e
@@ -569,7 +568,7 @@ class BaseTaskBlock(Block):
workflow_run_id=workflow_run_id,
task_id=task.task_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
step_id=step.step_id,
)
try:
@@ -578,7 +577,7 @@ class BaseTaskBlock(Block):
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
failure_reason=str(e),
)
raise e
@@ -597,13 +596,15 @@ class BaseTaskBlock(Block):
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
failure_reason=str(e),
)
raise e
# Check task status
updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id)
updated_task = await app.DATABASE.get_task(
task_id=task.task_id, organization_id=workflow_run.organization_id
)
if not updated_task:
raise TaskNotFound(task.task_id)
if not updated_task.status.is_final():
@@ -624,7 +625,7 @@ class BaseTaskBlock(Block):
task_status=updated_task.status,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
)
success = updated_task.status == TaskStatus.completed
task_output = TaskOutput.from_task(updated_task)
@@ -645,7 +646,7 @@ class BaseTaskBlock(Block):
task_status=updated_task.status,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
)
return await self.build_block_result(
success=False,
@@ -662,7 +663,7 @@ class BaseTaskBlock(Block):
task_status=updated_task.status,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
)
return await self.build_block_result(
success=False,
@@ -683,7 +684,7 @@ class BaseTaskBlock(Block):
status=updated_task.status,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
current_retry=current_retry,
max_retries=self.max_retries,
task_output=task_output.model_dump_json(),

View File

@@ -93,6 +93,7 @@ class WorkflowService:
workflow_request: WorkflowRequestBody,
workflow_permanent_id: str,
organization_id: str,
is_template_workflow: bool = False,
version: int | None = None,
max_steps_override: int | None = None,
) -> WorkflowRun:
@@ -109,7 +110,7 @@ class WorkflowService:
# Validate the workflow and the organization
workflow = await self.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_permanent_id,
organization_id=organization_id,
organization_id=None if is_template_workflow else organization_id,
version=version,
)
if workflow is None:
@@ -125,7 +126,7 @@ class WorkflowService:
workflow_request=workflow_request,
workflow_permanent_id=workflow_permanent_id,
workflow_id=workflow_id,
organization_id=workflow.organization_id,
organization_id=organization_id,
)
LOG.info(
f"Created workflow run {workflow_run.workflow_run_id} for workflow {workflow.workflow_id}",
@@ -202,7 +203,7 @@ class WorkflowService:
browser_session_id=browser_session_id,
)
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id)
workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id=workflow_run.workflow_permanent_id)
# Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
@@ -520,6 +521,24 @@ class WorkflowService:
raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id, version=version)
return workflow
async def get_workflows_by_permanent_ids(
self,
workflow_permanent_ids: list[str],
organization_id: str | None = None,
page: int = 1,
page_size: int = 10,
title: str = "",
statuses: list[WorkflowStatus] | None = None,
) -> list[Workflow]:
return await app.DATABASE.get_workflows_by_permanent_ids(
workflow_permanent_ids,
organization_id=organization_id,
page=page,
page_size=page_size,
title=title,
statuses=statuses,
)
async def get_workflows_by_organization_id(
self,
organization_id: str,
@@ -864,7 +883,7 @@ class WorkflowService:
organization_id: str,
include_cost: bool = False,
) -> WorkflowRunStatusResponse:
workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id, organization_id=organization_id)
workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id)
if workflow is None:
LOG.error(f"Workflow {workflow_permanent_id} not found")
raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id)
@@ -903,7 +922,9 @@ class WorkflowService:
try:
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
downloaded_file_urls = await app.STORAGE.get_downloaded_files(
organization_id=workflow.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id
organization_id=workflow_run.organization_id,
task_id=None,
workflow_run_id=workflow_run.workflow_run_id,
)
except asyncio.TimeoutError:
LOG.warning(
@@ -989,7 +1010,7 @@ class WorkflowService:
await self.persist_debug_artifacts(browser_state, tasks[-1], workflow, workflow_run)
if workflow.persist_browser_session and browser_state.browser_artifacts.browser_session_dir:
await app.STORAGE.store_browser_session(
workflow.organization_id,
workflow_run.organization_id,
workflow.workflow_permanent_id,
browser_state.browser_artifacts.browser_session_dir,
)
@@ -1000,7 +1021,7 @@ class WorkflowService:
try:
async with asyncio.timeout(SAVE_DOWNLOADED_FILES_TIMEOUT):
await app.STORAGE.save_downloaded_files(
workflow.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id
workflow_run.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id
)
except asyncio.TimeoutError:
LOG.warning(
@@ -1106,7 +1127,7 @@ class WorkflowService:
for video_artifact in video_artifacts:
await app.ARTIFACT_MANAGER.update_artifact_data(
artifact_id=video_artifact.video_artifact_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
data=video_artifact.video_data,
)