From c158ad3f216f1479a793f4526d0bee25ccc74d7d Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 15 Jan 2025 09:59:18 -0800 Subject: [PATCH] migrate observer to task v2 (#1564) --- skyvern/forge/api_app.py | 3 +- .../forge/sdk/api/llm/api_handler_factory.py | 6 +-- skyvern/forge/sdk/api/llm/models.py | 4 +- skyvern/forge/sdk/artifact/manager.py | 6 +-- skyvern/forge/sdk/artifact/storage/base.py | 4 +- skyvern/forge/sdk/artifact/storage/local.py | 4 +- skyvern/forge/sdk/artifact/storage/s3.py | 4 +- skyvern/forge/sdk/db/client.py | 27 +++++------ skyvern/forge/sdk/executor/async_executor.py | 4 +- skyvern/forge/sdk/routes/agent_protocol.py | 46 +++++++++---------- skyvern/forge/sdk/schemas/observers.py | 24 +++++----- .../forge/sdk/services/observer_service.py | 26 +++++------ skyvern/forge/sdk/workflow/models/workflow.py | 4 +- 13 files changed, 79 insertions(+), 83 deletions(-) diff --git a/skyvern/forge/api_app.py b/skyvern/forge/api_app.py index ca2ca93c..7c9e0606 100644 --- a/skyvern/forge/api_app.py +++ b/skyvern/forge/api_app.py @@ -17,7 +17,7 @@ from skyvern.forge import app as forge_app from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.db.exceptions import NotFoundError -from skyvern.forge.sdk.routes.agent_protocol import base_router +from skyvern.forge.sdk.routes.agent_protocol import base_router, v2_router from skyvern.forge.sdk.routes.streaming import websocket_router LOG = structlog.get_logger() @@ -47,6 +47,7 @@ def get_agent_app() -> FastAPI: ) app.include_router(base_router, prefix="/api/v1") + app.include_router(v2_router, prefix="/api/v2") app.include_router(websocket_router, prefix="/api/v1/stream") app.add_middleware( diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index 17bdee7d..357e4e8f 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -23,7 +23,7 @@ from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion -from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought +from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverThought LOG = structlog.get_logger() @@ -62,7 +62,7 @@ class LLMAPIHandlerFactory: async def llm_api_handler_with_router_and_fallback( prompt: str, step: Step | None = None, - observer_cruise: ObserverCruise | None = None, + observer_cruise: ObserverTask | None = None, observer_thought: ObserverThought | None = None, ai_suggestion: AISuggestion | None = None, screenshots: list[bytes] | None = None, @@ -201,7 +201,7 @@ class LLMAPIHandlerFactory: async def llm_api_handler( prompt: str, step: Step | None = None, - observer_cruise: ObserverCruise | None = None, + observer_cruise: ObserverTask | None = None, observer_thought: ObserverThought | None = None, ai_suggestion: AISuggestion | None = None, screenshots: list[bytes] | None = None, diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py index 5595e240..59fbedcc 100644 --- a/skyvern/forge/sdk/api/llm/models.py +++ b/skyvern/forge/sdk/api/llm/models.py @@ -5,7 +5,7 @@ from litellm import AllowedFailsPolicy from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion -from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought +from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverThought from skyvern.forge.sdk.settings_manager import SettingsManager @@ -80,7 +80,7 @@ class LLMAPIHandler(Protocol): self, prompt: str, step: Step | None = None, - observer_cruise: ObserverCruise | None = None, + observer_cruise: ObserverTask | None = None, observer_thought: ObserverThought | None = None, ai_suggestion: AISuggestion | None = None, screenshots: list[bytes] | None = None, diff --git a/skyvern/forge/sdk/artifact/manager.py b/skyvern/forge/sdk/artifact/manager.py index d8cca3b6..a90c0fd0 100644 --- a/skyvern/forge/sdk/artifact/manager.py +++ b/skyvern/forge/sdk/artifact/manager.py @@ -9,7 +9,7 @@ from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityT from skyvern.forge.sdk.db.id import generate_artifact_id from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion -from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought +from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverThought from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock LOG = structlog.get_logger(__name__) @@ -137,7 +137,7 @@ class ArtifactManager: async def create_observer_cruise_artifact( self, - observer_cruise: ObserverCruise, + observer_cruise: ObserverTask, artifact_type: ArtifactType, data: bytes | None = None, path: str | None = None, @@ -203,7 +203,7 @@ class ArtifactManager: screenshots: list[bytes] | None = None, step: Step | None = None, observer_thought: ObserverThought | None = None, - observer_cruise: ObserverCruise | None = None, + observer_cruise: ObserverTask | None = None, ai_suggestion: AISuggestion | None = None, ) -> None: if step: diff --git a/skyvern/forge/sdk/artifact/storage/base.py b/skyvern/forge/sdk/artifact/storage/base.py index cbc20379..27a36b6d 100644 --- a/skyvern/forge/sdk/artifact/storage/base.py +++ b/skyvern/forge/sdk/artifact/storage/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion -from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought +from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverThought from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock # TODO: This should be a part of the ArtifactType model @@ -52,7 +52,7 @@ class BaseStorage(ABC): @abstractmethod def build_observer_cruise_uri( - self, artifact_id: str, observer_cruise: ObserverCruise, artifact_type: ArtifactType + self, artifact_id: str, observer_cruise: ObserverTask, artifact_type: ArtifactType ) -> str: pass diff --git a/skyvern/forge/sdk/artifact/storage/local.py b/skyvern/forge/sdk/artifact/storage/local.py index e1f2215a..50a77812 100644 --- a/skyvern/forge/sdk/artifact/storage/local.py +++ b/skyvern/forge/sdk/artifact/storage/local.py @@ -12,7 +12,7 @@ from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityT from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion -from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought +from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverThought from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock LOG = structlog.get_logger() @@ -37,7 +37,7 @@ class LocalStorage(BaseStorage): return f"file://{self.artifact_path}/{settings.ENV}/observers/{observer_thought.observer_cruise_id}/{observer_thought.observer_thought_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" def build_observer_cruise_uri( - self, artifact_id: str, observer_cruise: ObserverCruise, artifact_type: ArtifactType + self, artifact_id: str, observer_cruise: ObserverTask, artifact_type: ArtifactType ) -> str: file_ext = FILE_EXTENTSION_MAP[artifact_type] return f"file://{self.artifact_path}/{settings.ENV}/observers/{observer_cruise.observer_cruise_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" diff --git a/skyvern/forge/sdk/artifact/storage/s3.py b/skyvern/forge/sdk/artifact/storage/s3.py index e44cff66..87ba1e25 100644 --- a/skyvern/forge/sdk/artifact/storage/s3.py +++ b/skyvern/forge/sdk/artifact/storage/s3.py @@ -16,7 +16,7 @@ from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityT from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion -from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought +from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverThought from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock @@ -40,7 +40,7 @@ class S3Storage(BaseStorage): return f"s3://{self.bucket}/{settings.ENV}/observers/{observer_thought.observer_cruise_id}/{observer_thought.observer_thought_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" def build_observer_cruise_uri( - self, artifact_id: str, observer_cruise: ObserverCruise, artifact_type: ArtifactType + self, artifact_id: str, observer_cruise: ObserverTask, artifact_type: ArtifactType ) -> str: file_ext = FILE_EXTENTSION_MAP[artifact_type] return f"s3://{self.bucket}/{settings.ENV}/observers/{observer_cruise.observer_cruise_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 33d0b953..1c03a194 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -58,12 +58,7 @@ from skyvern.forge.sdk.db.utils import ( from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs from skyvern.forge.sdk.models import Step, StepStatus from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion -from skyvern.forge.sdk.schemas.observers import ( - ObserverCruise, - ObserverCruiseStatus, - ObserverThought, - ObserverThoughtType, -) +from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskStatus, ObserverThought, ObserverThoughtType from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession from skyvern.forge.sdk.schemas.task_generations import TaskGeneration @@ -1934,7 +1929,7 @@ class AgentDB: async def get_observer_cruise( self, observer_cruise_id: str, organization_id: str | None = None - ) -> ObserverCruise | None: + ) -> ObserverTask | None: async with self.Session() as session: if observer_cruise := ( await session.scalars( @@ -1943,7 +1938,7 @@ class AgentDB: .filter_by(organization_id=organization_id) ) ).first(): - return ObserverCruise.model_validate(observer_cruise) + return ObserverTask.model_validate(observer_cruise) return None async def delete_observer_thoughts_for_cruise( @@ -1963,7 +1958,7 @@ class AgentDB: self, workflow_run_id: str, organization_id: str | None = None, - ) -> ObserverCruise | None: + ) -> ObserverTask | None: async with self.Session() as session: if observer_cruise := ( await session.scalars( @@ -1972,7 +1967,7 @@ class AgentDB: .filter_by(workflow_run_id=workflow_run_id) ) ).first(): - return ObserverCruise.model_validate(observer_cruise) + return ObserverTask.model_validate(observer_cruise) return None async def get_observer_thought( @@ -2015,7 +2010,7 @@ class AgentDB: prompt: str | None = None, url: str | None = None, organization_id: str | None = None, - ) -> ObserverCruise: + ) -> ObserverTask: async with self.Session() as session: new_observer_cruise = ObserverCruiseModel( workflow_run_id=workflow_run_id, @@ -2028,7 +2023,7 @@ class AgentDB: session.add(new_observer_cruise) await session.commit() await session.refresh(new_observer_cruise) - return ObserverCruise.model_validate(new_observer_cruise) + return ObserverTask.model_validate(new_observer_cruise) async def create_observer_thought( self, @@ -2113,7 +2108,7 @@ class AgentDB: async def update_observer_cruise( self, observer_cruise_id: str, - status: ObserverCruiseStatus | None = None, + status: ObserverTaskStatus | None = None, workflow_run_id: str | None = None, workflow_id: str | None = None, workflow_permanent_id: str | None = None, @@ -2122,7 +2117,7 @@ class AgentDB: summary: str | None = None, output: dict[str, Any] | None = None, organization_id: str | None = None, - ) -> ObserverCruise: + ) -> ObserverTask: async with self.Session() as session: observer_cruise = ( await session.scalars( @@ -2150,8 +2145,8 @@ class AgentDB: observer_cruise.output = output await session.commit() await session.refresh(observer_cruise) - return ObserverCruise.model_validate(observer_cruise) - raise NotFoundError(f"ObserverCruise {observer_cruise_id} not found") + return ObserverTask.model_validate(observer_cruise) + raise NotFoundError(f"ObserverTask {observer_cruise_id} not found") async def create_workflow_run_block( self, diff --git a/skyvern/forge/sdk/executor/async_executor.py b/skyvern/forge/sdk/executor/async_executor.py index 15c96825..342bdb80 100644 --- a/skyvern/forge/sdk/executor/async_executor.py +++ b/skyvern/forge/sdk/executor/async_executor.py @@ -7,7 +7,7 @@ from skyvern.exceptions import OrganizationNotFound from skyvern.forge import app from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext -from skyvern.forge.sdk.schemas.observers import ObserverCruiseStatus +from skyvern.forge.sdk.schemas.observers import ObserverTaskStatus from skyvern.forge.sdk.schemas.tasks import TaskStatus from skyvern.forge.sdk.services import observer_service from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus @@ -166,7 +166,7 @@ class BackgroundTaskExecutor(AsyncExecutor): # mark observer cruise as queued await app.DATABASE.update_observer_cruise( observer_cruise_id, - status=ObserverCruiseStatus.queued, + status=ObserverTaskStatus.queued, organization_id=organization_id, ) await app.DATABASE.update_workflow_run( diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index b75e87a6..66c7a07c 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -37,7 +37,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestionBase, AISuggestionRequest -from skyvern.forge.sdk.schemas.observers import CruiseRequest, ObserverCruise +from skyvern.forge.sdk.schemas.observers import ObserverTaskRequest from skyvern.forge.sdk.schemas.organizations import ( GetOrganizationAPIKeysResponse, GetOrganizationsResponse, @@ -73,6 +73,7 @@ from skyvern.webeye.actions.actions import Action from skyvern.webeye.schemas import BrowserSessionResponse base_router = APIRouter() +v2_router = APIRouter() LOG = structlog.get_logger() @@ -711,18 +712,16 @@ async def get_workflow_runs_for_workflow_permanent_id( @base_router.get( "/workflows/{workflow_id}/runs/{workflow_run_id}", - response_model=WorkflowRunStatusResponse, ) @base_router.get( "/workflows/{workflow_id}/runs/{workflow_run_id}/", - response_model=WorkflowRunStatusResponse, include_in_schema=False, ) async def get_workflow_run( workflow_id: str, workflow_run_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), -) -> WorkflowRunStatusResponse: +) -> dict[str, Any]: analytics.capture("skyvern-oss-agent-workflow-run-get") workflow_run_status_response = await app.WORKFLOW_SERVICE.build_workflow_run_status_response( workflow_permanent_id=workflow_id, @@ -730,12 +729,13 @@ async def get_workflow_run( organization_id=current_org.organization_id, include_cost=True, ) + return_dict = workflow_run_status_response.model_dump() observer_cruise = await app.DATABASE.get_observer_cruise_by_workflow_run_id( workflow_run_id=workflow_run_id, organization_id=current_org.organization_id, ) if observer_cruise: - workflow_run_status_response.observer_cruise = observer_cruise + return_dict["observer_task"] = observer_cruise.model_dump(by_alias=True) return workflow_run_status_response @@ -1115,20 +1115,20 @@ async def upload_file( ) -@base_router.post("/cruise") -@base_router.post("/cruise/", include_in_schema=False) -async def observer_cruise( +@v2_router.post("/tasks") +@v2_router.post("/tasks/", include_in_schema=False) +async def observer_task( request: Request, background_tasks: BackgroundTasks, - data: CruiseRequest, + data: ObserverTaskRequest, organization: Organization = Depends(org_auth_service.get_current_org), x_max_iterations_override: Annotated[int | None, Header()] = None, -) -> ObserverCruise: +) -> dict[str, Any]: if x_max_iterations_override: LOG.info("Overriding max iterations for observer", max_iterations_override=x_max_iterations_override) try: - observer_cruise = await observer_service.initialize_observer_cruise( + observer_task = await observer_service.initialize_observer_cruise( organization=organization, user_prompt=data.user_prompt, user_url=str(data.url) if data.url else None, @@ -1138,28 +1138,28 @@ async def observer_cruise( raise HTTPException( status_code=500, detail="Skyvern LLM failure to initialize observer cruise. Please try again later." ) - analytics.capture("skyvern-oss-agent-observer-cruise", data={"url": observer_cruise.url}) + analytics.capture("skyvern-oss-agent-observer-cruise", data={"url": observer_task.url}) await AsyncExecutorFactory.get_executor().execute_cruise( request=request, background_tasks=background_tasks, organization_id=organization.organization_id, - observer_cruise_id=observer_cruise.observer_cruise_id, + observer_cruise_id=observer_task.observer_cruise_id, max_iterations_override=x_max_iterations_override, browser_session_id=data.browser_session_id, ) - return observer_cruise + return observer_task.model_dump(by_alias=True) -@base_router.get("/cruise/{observer_cruise_id}") -@base_router.get("/cruise/{observer_cruise_id}/", include_in_schema=False) -async def get_observer_cruise( - observer_cruise_id: str, +@v2_router.get("/tasks/{task_id}") +@v2_router.get("/tasks/{task_id}/", include_in_schema=False) +async def get_observer_task( + task_id: str, organization: Organization = Depends(org_auth_service.get_current_org), -) -> ObserverCruise: - observer_cruise = await observer_service.get_observer_cruise(observer_cruise_id, organization.organization_id) - if not observer_cruise: - raise HTTPException(status_code=404, detail=f"Observer cruise {observer_cruise_id} not found") - return observer_cruise +) -> dict[str, Any]: + observer_task = await observer_service.get_observer_cruise(task_id, organization.organization_id) + if not observer_task: + raise HTTPException(status_code=404, detail=f"Observer task {task_id} not found") + return observer_task.model_dump(by_alias=True) @base_router.get( diff --git a/skyvern/forge/sdk/schemas/observers.py b/skyvern/forge/sdk/schemas/observers.py index a49a2d40..019ff912 100644 --- a/skyvern/forge/sdk/schemas/observers.py +++ b/skyvern/forge/sdk/schemas/observers.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import StrEnum from typing import Any -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from skyvern.forge.sdk.core.validators import validate_url from skyvern.forge.sdk.schemas.tasks import ProxyLocation @@ -10,7 +10,7 @@ from skyvern.forge.sdk.schemas.tasks import ProxyLocation DEFAULT_WORKFLOW_TITLE = "New Workflow" -class ObserverCruiseStatus(StrEnum): +class ObserverTaskStatus(StrEnum): created = "created" queued = "queued" running = "running" @@ -21,11 +21,11 @@ class ObserverCruiseStatus(StrEnum): completed = "completed" -class ObserverCruise(BaseModel): - model_config = ConfigDict(from_attributes=True) +class ObserverTask(BaseModel): + model_config = ConfigDict(from_attributes=True, populate_by_name=True) - observer_cruise_id: str - status: ObserverCruiseStatus + observer_cruise_id: str = Field(alias="task_id") + status: ObserverTaskStatus organization_id: str | None = None workflow_run_id: str | None = None workflow_id: str | None = None @@ -69,10 +69,10 @@ class ObserverThoughtScenario(StrEnum): class ObserverThought(BaseModel): - model_config = ConfigDict(from_attributes=True) + model_config = ConfigDict(from_attributes=True, populate_by_name=True) - observer_thought_id: str - observer_cruise_id: str + observer_thought_id: str = Field(alias="thought_id") + observer_cruise_id: str = Field(alias="task_id") organization_id: str | None = None workflow_run_id: str | None = None workflow_run_block_id: str | None = None @@ -82,8 +82,8 @@ class ObserverThought(BaseModel): observation: str | None = None thought: str | None = None answer: str | None = None - observer_thought_type: ObserverThoughtType | None = ObserverThoughtType.plan - observer_thought_scenario: ObserverThoughtScenario | None = None + observer_thought_type: ObserverThoughtType | None = Field(alias="thought_type", default=ObserverThoughtType.plan) + observer_thought_scenario: ObserverThoughtScenario | None = Field(alias="thought_scenario", default=None) output: dict[str, Any] | None = None created_at: datetime @@ -102,7 +102,7 @@ class ObserverMetadata(BaseModel): return validate_url(v) -class CruiseRequest(BaseModel): +class ObserverTaskRequest(BaseModel): user_prompt: str url: str | None = None browser_session_id: str | None = None diff --git a/skyvern/forge/sdk/services/observer_service.py b/skyvern/forge/sdk/services/observer_service.py index aa5892df..ca1f8adf 100644 --- a/skyvern/forge/sdk/services/observer_service.py +++ b/skyvern/forge/sdk/services/observer_service.py @@ -17,9 +17,9 @@ from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.schemas.observers import ( - ObserverCruise, - ObserverCruiseStatus, ObserverMetadata, + ObserverTask, + ObserverTaskStatus, ObserverThoughtScenario, ObserverThoughtType, ) @@ -81,7 +81,7 @@ def _generate_data_extraction_schema_for_loop(loop_values_key: str) -> dict: async def initialize_observer_cruise( organization: Organization, user_prompt: str, user_url: str | None = None -) -> ObserverCruise: +) -> ObserverTask: observer_cruise = await app.DATABASE.create_observer_cruise( prompt=user_prompt, organization_id=organization.organization_id, @@ -237,14 +237,14 @@ async def run_observer_cruise( async def run_observer_cruise_helper( organization: Organization, - observer_cruise: ObserverCruise, + observer_cruise: ObserverTask, request_id: str | None = None, max_iterations_override: str | int | None = None, browser_session_id: str | None = None, ) -> tuple[Workflow, WorkflowRun] | tuple[None, None]: organization_id = organization.organization_id observer_cruise_id = observer_cruise.observer_cruise_id - if observer_cruise.status != ObserverCruiseStatus.queued: + if observer_cruise.status != ObserverTaskStatus.queued: LOG.error( "Observer cruise is not queued. Duplicate observer cruise", observer_cruise_id=observer_cruise_id, @@ -310,7 +310,7 @@ async def run_observer_cruise_helper( ) await app.DATABASE.update_observer_cruise( - observer_cruise_id=observer_cruise_id, organization_id=organization_id, status=ObserverCruiseStatus.running + observer_cruise_id=observer_cruise_id, organization_id=organization_id, status=ObserverTaskStatus.running ) await app.WORKFLOW_SERVICE.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id) await _set_up_workflow_context(workflow_id, workflow_run_id) @@ -700,7 +700,7 @@ async def _set_up_workflow_context(workflow_id: str, workflow_run_id: str) -> No async def _generate_loop_task( - observer_cruise: ObserverCruise, + observer_cruise: ObserverTask, workflow_id: str, workflow_permanent_id: str, workflow_run_id: str, @@ -938,7 +938,7 @@ async def _generate_loop_task( async def _generate_extraction_task( - observer_cruise: ObserverCruise, + observer_cruise: ObserverTask, workflow_id: str, workflow_permanent_id: str, workflow_run_id: str, @@ -1059,7 +1059,7 @@ async def get_observer_thought_timelines( ] -async def get_observer_cruise(observer_cruise_id: str, organization_id: str | None = None) -> ObserverCruise | None: +async def get_observer_cruise(observer_cruise_id: str, organization_id: str | None = None) -> ObserverTask | None: return await app.DATABASE.get_observer_cruise(observer_cruise_id, organization_id=organization_id) @@ -1070,7 +1070,7 @@ async def mark_observer_cruise_as_failed( organization_id: str | None = None, ) -> None: await app.DATABASE.update_observer_cruise( - observer_cruise_id, organization_id=organization_id, status=ObserverCruiseStatus.failed + observer_cruise_id, organization_id=organization_id, status=ObserverTaskStatus.failed ) if workflow_run_id: await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed( @@ -1091,7 +1091,7 @@ async def mark_observer_cruise_as_completed( await app.DATABASE.update_observer_cruise( observer_cruise_id, organization_id=organization_id, - status=ObserverCruiseStatus.completed, + status=ObserverTaskStatus.completed, summary=summary, output=output, ) @@ -1173,7 +1173,7 @@ def _get_extracted_data_from_block_result( async def _summarize_observer_cruise( - observer_cruise: ObserverCruise, + observer_cruise: ObserverTask, task_history: list[dict], context: SkyvernContext, screenshots: list[bytes] | None = None, @@ -1219,7 +1219,7 @@ async def _summarize_observer_cruise( ) -async def send_observer_cruise_webhook(observer_cruise: ObserverCruise) -> None: +async def send_observer_cruise_webhook(observer_cruise: ObserverTask) -> None: if not observer_cruise.webhook_callback_url: return organization_id = observer_cruise.organization_id diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index c59cacc1..51e0e33c 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -5,7 +5,7 @@ from typing import Any, List from pydantic import BaseModel, field_validator from skyvern.forge.sdk.core.validators import validate_url -from skyvern.forge.sdk.schemas.observers import ObserverCruise +from skyvern.forge.sdk.schemas.observers import ObserverTask from skyvern.forge.sdk.schemas.tasks import ProxyLocation from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels from skyvern.forge.sdk.workflow.models.block import BlockTypeVar @@ -138,4 +138,4 @@ class WorkflowRunStatusResponse(BaseModel): outputs: dict[str, Any] | None = None total_steps: int | None = None total_cost: float | None = None - observer_cruise: ObserverCruise | None = None + observer_cruise: ObserverTask | None = None