task_v2 refactor Part 1 - observer_service -> task_v2_service (#1812)

This commit is contained in:
Shuchang Zheng
2025-02-22 01:36:35 -08:00
committed by GitHub
parent 9a07c0bc6f
commit 1e7318d004
5 changed files with 12 additions and 12 deletions

View File

@@ -9,7 +9,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskRequest, ObserverTaskStatus from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskRequest, ObserverTaskStatus
from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.services import observer_service from skyvern.forge.sdk.services import task_v2_service
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
from skyvern.utils import migrate_db from skyvern.utils import migrate_db
@@ -80,7 +80,7 @@ class Agent:
status=WorkflowRunStatus.queued, status=WorkflowRunStatus.queued,
) )
await observer_service.run_observer_task( await task_v2_service.run_observer_task(
organization=organization, organization=organization,
observer_cruise_id=observer_task.observer_cruise_id, observer_cruise_id=observer_task.observer_cruise_id,
) )
@@ -156,7 +156,7 @@ class Agent:
async def observer_task_v_2(self, task_request: ObserverTaskRequest) -> ObserverTask: async def observer_task_v_2(self, task_request: ObserverTaskRequest) -> ObserverTask:
organization = await self._get_organization() organization = await self._get_organization()
observer_task = await observer_service.initialize_observer_task( observer_task = await task_v2_service.initialize_observer_task(
organization=organization, organization=organization,
user_prompt=task_request.user_prompt, user_prompt=task_request.user_prompt,
user_url=str(task_request.url) if task_request.url else None, user_url=str(task_request.url) if task_request.url else None,

View File

@@ -9,7 +9,7 @@ from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.schemas.observers import ObserverTaskStatus from skyvern.forge.sdk.schemas.observers import ObserverTaskStatus
from skyvern.forge.sdk.schemas.tasks import TaskStatus from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.services import observer_service from skyvern.forge.sdk.services import task_v2_service
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
LOG = structlog.get_logger() LOG = structlog.get_logger()
@@ -176,7 +176,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
if background_tasks: if background_tasks:
background_tasks.add_task( background_tasks.add_task(
observer_service.run_observer_task, task_v2_service.run_observer_task,
organization=organization, organization=organization,
observer_cruise_id=observer_cruise_id, observer_cruise_id=observer_cruise_id,
max_iterations_override=max_iterations_override, max_iterations_override=max_iterations_override,

View File

@@ -58,7 +58,7 @@ from skyvern.forge.sdk.schemas.tasks import (
TaskStatus, TaskStatus,
) )
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline
from skyvern.forge.sdk.services import observer_service, org_auth_service from skyvern.forge.sdk.services import org_auth_service, task_v2_service
from skyvern.forge.sdk.workflow.exceptions import ( from skyvern.forge.sdk.workflow.exceptions import (
FailedToCreateWorkflow, FailedToCreateWorkflow,
FailedToUpdateWorkflow, FailedToUpdateWorkflow,
@@ -1234,7 +1234,7 @@ async def observer_task(
LOG.info("Overriding max iterations for observer", max_iterations_override=x_max_iterations_override) LOG.info("Overriding max iterations for observer", max_iterations_override=x_max_iterations_override)
try: try:
observer_task = await observer_service.initialize_observer_task( observer_task = await task_v2_service.initialize_observer_task(
organization=organization, organization=organization,
user_prompt=data.user_prompt, user_prompt=data.user_prompt,
user_url=str(data.url) if data.url else None, user_url=str(data.url) if data.url else None,
@@ -1268,7 +1268,7 @@ async def get_observer_task(
task_id: str, task_id: str,
organization: Organization = Depends(org_auth_service.get_current_org), organization: Organization = Depends(org_auth_service.get_current_org),
) -> dict[str, Any]: ) -> dict[str, Any]:
observer_task = await observer_service.get_observer_cruise(task_id, organization.organization_id) observer_task = await task_v2_service.get_observer_cruise(task_id, organization.organization_id)
if not observer_task: if not observer_task:
raise HTTPException(status_code=404, detail=f"Observer task {task_id} not found") raise HTTPException(status_code=404, detail=f"Observer task {task_id} not found")
return observer_task.model_dump(by_alias=True) return observer_task.model_dump(by_alias=True)
@@ -1408,7 +1408,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
final_workflow_run_block_timeline.extend(workflow_blocks) final_workflow_run_block_timeline.extend(workflow_blocks)
if observer_task_obj and observer_task_obj.observer_cruise_id: if observer_task_obj and observer_task_obj.observer_cruise_id:
observer_thought_timeline = await observer_service.get_observer_thought_timelines( observer_thought_timeline = await task_v2_service.get_observer_thought_timelines(
observer_cruise_id=observer_task_obj.observer_cruise_id, observer_cruise_id=observer_task_obj.observer_cruise_id,
organization_id=organization_id, organization_id=organization_id,
) )

View File

@@ -2138,7 +2138,7 @@ class TaskV2Block(Block):
browser_session_id: str | None = None, browser_session_id: str | None = None,
**kwargs: dict, **kwargs: dict,
) -> BlockResult: ) -> BlockResult:
from skyvern.forge.sdk.services import observer_service from skyvern.forge.sdk.services import task_v2_service
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
if not organization_id: if not organization_id:
@@ -2150,7 +2150,7 @@ class TaskV2Block(Block):
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id, organization_id) workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id, organization_id)
if not workflow_run: if not workflow_run:
raise ValueError(f"WorkflowRun not found {workflow_run_id} when running TaskV2Block") raise ValueError(f"WorkflowRun not found {workflow_run_id} when running TaskV2Block")
observer_task = await observer_service.initialize_observer_task( observer_task = await task_v2_service.initialize_observer_task(
organization, organization,
user_prompt=self.prompt, user_prompt=self.prompt,
user_url=self.url, user_url=self.url,
@@ -2171,7 +2171,7 @@ class TaskV2Block(Block):
block_workflow_run_id=observer_task.workflow_run_id, block_workflow_run_id=observer_task.workflow_run_id,
) )
observer_task = await observer_service.run_observer_task( observer_task = await task_v2_service.run_observer_task(
organization=organization, organization=organization,
observer_cruise_id=observer_task.observer_cruise_id, observer_cruise_id=observer_task.observer_cruise_id,
request_id=None, request_id=None,