Observer code open source (#1417)

This commit is contained in:
Shuchang Zheng
2024-12-19 17:26:08 -08:00
committed by GitHub
parent bd0d6a5920
commit a12776e630
12 changed files with 1071 additions and 16 deletions

View File

@@ -7,7 +7,10 @@ from skyvern.exceptions import OrganizationNotFound
from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.schemas.observers import ObserverCruiseStatus
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.services import observer_service
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
LOG = structlog.get_logger()
@@ -40,12 +43,24 @@ class AsyncExecutor(abc.ABC):
) -> None:
pass
@abc.abstractmethod
async def execute_cruise(
self,
request: Request | None,
background_tasks: BackgroundTasks | None,
organization_id: str,
observer_cruise_id: str,
max_iterations_override: int | None,
**kwargs: dict,
) -> None:
pass
class BackgroundTaskExecutor(AsyncExecutor):
async def execute_task(
self,
request: Request | None,
background_tasks: BackgroundTasks,
background_tasks: BackgroundTasks | None,
task_id: str,
organization_id: str,
max_steps_override: int | None,
@@ -76,18 +91,19 @@ class BackgroundTaskExecutor(AsyncExecutor):
context.organization_id = organization_id
context.max_steps_override = max_steps_override
background_tasks.add_task(
app.agent.execute_step,
organization,
task,
step,
api_key,
)
if background_tasks:
background_tasks.add_task(
app.agent.execute_step,
organization,
task,
step,
api_key,
)
async def execute_workflow(
self,
request: Request | None,
background_tasks: BackgroundTasks,
background_tasks: BackgroundTasks | None,
organization_id: str,
workflow_id: str,
workflow_run_id: str,
@@ -104,9 +120,53 @@ class BackgroundTaskExecutor(AsyncExecutor):
if organization is None:
raise OrganizationNotFound(organization_id)
background_tasks.add_task(
app.WORKFLOW_SERVICE.execute_workflow,
workflow_run_id=workflow_run_id,
api_key=api_key,
organization=organization,
if background_tasks:
background_tasks.add_task(
app.WORKFLOW_SERVICE.execute_workflow,
workflow_run_id=workflow_run_id,
api_key=api_key,
organization=organization,
)
async def execute_cruise(
self,
request: Request | None,
background_tasks: BackgroundTasks | None,
organization_id: str,
observer_cruise_id: str,
max_iterations_override: int | None,
**kwargs: dict,
) -> None:
LOG.info(
"Executing cruise using background task executor",
observer_cruise_id=observer_cruise_id,
)
organization = await app.DATABASE.get_organization(organization_id)
if organization is None:
raise OrganizationNotFound(organization_id)
observer_cruise = await app.DATABASE.get_observer_cruise(
observer_cruise_id=observer_cruise_id, organization_id=organization_id
)
if not observer_cruise or not observer_cruise.workflow_run_id:
raise ValueError("No observer cruise or no workflow run associated with observer cruise")
# mark observer cruise as queued
await app.DATABASE.update_observer_cruise(
observer_cruise_id,
status=ObserverCruiseStatus.queued,
organization_id=organization_id,
)
await app.DATABASE.update_workflow_run(
workflow_run_id=observer_cruise.workflow_run_id,
status=WorkflowRunStatus.queued,
)
if background_tasks:
background_tasks.add_task(
observer_service.run_observer_cruise,
organization=organization,
observer_cruise_id=observer_cruise_id,
max_iterations_override=max_iterations_override,
)