migrate observer to task v2 (#1564)

This commit is contained in:
Shuchang Zheng
2025-01-15 09:59:18 -08:00
committed by GitHub
parent 997b0adea7
commit c158ad3f21
13 changed files with 79 additions and 83 deletions

View File

@@ -37,7 +37,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestionBase, AISuggestionRequest
from skyvern.forge.sdk.schemas.observers import CruiseRequest, ObserverCruise
from skyvern.forge.sdk.schemas.observers import ObserverTaskRequest
from skyvern.forge.sdk.schemas.organizations import (
GetOrganizationAPIKeysResponse,
GetOrganizationsResponse,
@@ -73,6 +73,7 @@ from skyvern.webeye.actions.actions import Action
from skyvern.webeye.schemas import BrowserSessionResponse
base_router = APIRouter()
v2_router = APIRouter()
LOG = structlog.get_logger()
@@ -711,18 +712,16 @@ async def get_workflow_runs_for_workflow_permanent_id(
@base_router.get(
"/workflows/{workflow_id}/runs/{workflow_run_id}",
response_model=WorkflowRunStatusResponse,
)
@base_router.get(
"/workflows/{workflow_id}/runs/{workflow_run_id}/",
response_model=WorkflowRunStatusResponse,
include_in_schema=False,
)
async def get_workflow_run(
workflow_id: str,
workflow_run_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowRunStatusResponse:
) -> dict[str, Any]:
analytics.capture("skyvern-oss-agent-workflow-run-get")
workflow_run_status_response = await app.WORKFLOW_SERVICE.build_workflow_run_status_response(
workflow_permanent_id=workflow_id,
@@ -730,12 +729,13 @@ async def get_workflow_run(
organization_id=current_org.organization_id,
include_cost=True,
)
return_dict = workflow_run_status_response.model_dump()
observer_cruise = await app.DATABASE.get_observer_cruise_by_workflow_run_id(
workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id,
)
if observer_cruise:
workflow_run_status_response.observer_cruise = observer_cruise
return_dict["observer_task"] = observer_cruise.model_dump(by_alias=True)
return workflow_run_status_response
@@ -1115,20 +1115,20 @@ async def upload_file(
)
@base_router.post("/cruise")
@base_router.post("/cruise/", include_in_schema=False)
async def observer_cruise(
@v2_router.post("/tasks")
@v2_router.post("/tasks/", include_in_schema=False)
async def observer_task(
request: Request,
background_tasks: BackgroundTasks,
data: CruiseRequest,
data: ObserverTaskRequest,
organization: Organization = Depends(org_auth_service.get_current_org),
x_max_iterations_override: Annotated[int | None, Header()] = None,
) -> ObserverCruise:
) -> dict[str, Any]:
if x_max_iterations_override:
LOG.info("Overriding max iterations for observer", max_iterations_override=x_max_iterations_override)
try:
observer_cruise = await observer_service.initialize_observer_cruise(
observer_task = await observer_service.initialize_observer_cruise(
organization=organization,
user_prompt=data.user_prompt,
user_url=str(data.url) if data.url else None,
@@ -1138,28 +1138,28 @@ async def observer_cruise(
raise HTTPException(
status_code=500, detail="Skyvern LLM failure to initialize observer cruise. Please try again later."
)
analytics.capture("skyvern-oss-agent-observer-cruise", data={"url": observer_cruise.url})
analytics.capture("skyvern-oss-agent-observer-cruise", data={"url": observer_task.url})
await AsyncExecutorFactory.get_executor().execute_cruise(
request=request,
background_tasks=background_tasks,
organization_id=organization.organization_id,
observer_cruise_id=observer_cruise.observer_cruise_id,
observer_cruise_id=observer_task.observer_cruise_id,
max_iterations_override=x_max_iterations_override,
browser_session_id=data.browser_session_id,
)
return observer_cruise
return observer_task.model_dump(by_alias=True)
@base_router.get("/cruise/{observer_cruise_id}")
@base_router.get("/cruise/{observer_cruise_id}/", include_in_schema=False)
async def get_observer_cruise(
observer_cruise_id: str,
@v2_router.get("/tasks/{task_id}")
@v2_router.get("/tasks/{task_id}/", include_in_schema=False)
async def get_observer_task(
task_id: str,
organization: Organization = Depends(org_auth_service.get_current_org),
) -> ObserverCruise:
observer_cruise = await observer_service.get_observer_cruise(observer_cruise_id, organization.organization_id)
if not observer_cruise:
raise HTTPException(status_code=404, detail=f"Observer cruise {observer_cruise_id} not found")
return observer_cruise
) -> dict[str, Any]:
observer_task = await observer_service.get_observer_cruise(task_id, organization.organization_id)
if not observer_task:
raise HTTPException(status_code=404, detail=f"Observer task {task_id} not found")
return observer_task.model_dump(by_alias=True)
@base_router.get(