migrate observer to task v2 (#1564)
This commit is contained in:
@@ -37,7 +37,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestionBase, AISuggestionRequest
|
||||
from skyvern.forge.sdk.schemas.observers import CruiseRequest, ObserverCruise
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverTaskRequest
|
||||
from skyvern.forge.sdk.schemas.organizations import (
|
||||
GetOrganizationAPIKeysResponse,
|
||||
GetOrganizationsResponse,
|
||||
@@ -73,6 +73,7 @@ from skyvern.webeye.actions.actions import Action
|
||||
from skyvern.webeye.schemas import BrowserSessionResponse
|
||||
|
||||
base_router = APIRouter()
|
||||
v2_router = APIRouter()
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -711,18 +712,16 @@ async def get_workflow_runs_for_workflow_permanent_id(
|
||||
|
||||
@base_router.get(
|
||||
"/workflows/{workflow_id}/runs/{workflow_run_id}",
|
||||
response_model=WorkflowRunStatusResponse,
|
||||
)
|
||||
@base_router.get(
|
||||
"/workflows/{workflow_id}/runs/{workflow_run_id}/",
|
||||
response_model=WorkflowRunStatusResponse,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def get_workflow_run(
|
||||
workflow_id: str,
|
||||
workflow_run_id: str,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> WorkflowRunStatusResponse:
|
||||
) -> dict[str, Any]:
|
||||
analytics.capture("skyvern-oss-agent-workflow-run-get")
|
||||
workflow_run_status_response = await app.WORKFLOW_SERVICE.build_workflow_run_status_response(
|
||||
workflow_permanent_id=workflow_id,
|
||||
@@ -730,12 +729,13 @@ async def get_workflow_run(
|
||||
organization_id=current_org.organization_id,
|
||||
include_cost=True,
|
||||
)
|
||||
return_dict = workflow_run_status_response.model_dump()
|
||||
observer_cruise = await app.DATABASE.get_observer_cruise_by_workflow_run_id(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=current_org.organization_id,
|
||||
)
|
||||
if observer_cruise:
|
||||
workflow_run_status_response.observer_cruise = observer_cruise
|
||||
return_dict["observer_task"] = observer_cruise.model_dump(by_alias=True)
|
||||
return workflow_run_status_response
|
||||
|
||||
|
||||
@@ -1115,20 +1115,20 @@ async def upload_file(
|
||||
)
|
||||
|
||||
|
||||
@base_router.post("/cruise")
|
||||
@base_router.post("/cruise/", include_in_schema=False)
|
||||
async def observer_cruise(
|
||||
@v2_router.post("/tasks")
|
||||
@v2_router.post("/tasks/", include_in_schema=False)
|
||||
async def observer_task(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
data: CruiseRequest,
|
||||
data: ObserverTaskRequest,
|
||||
organization: Organization = Depends(org_auth_service.get_current_org),
|
||||
x_max_iterations_override: Annotated[int | None, Header()] = None,
|
||||
) -> ObserverCruise:
|
||||
) -> dict[str, Any]:
|
||||
if x_max_iterations_override:
|
||||
LOG.info("Overriding max iterations for observer", max_iterations_override=x_max_iterations_override)
|
||||
|
||||
try:
|
||||
observer_cruise = await observer_service.initialize_observer_cruise(
|
||||
observer_task = await observer_service.initialize_observer_cruise(
|
||||
organization=organization,
|
||||
user_prompt=data.user_prompt,
|
||||
user_url=str(data.url) if data.url else None,
|
||||
@@ -1138,28 +1138,28 @@ async def observer_cruise(
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Skyvern LLM failure to initialize observer cruise. Please try again later."
|
||||
)
|
||||
analytics.capture("skyvern-oss-agent-observer-cruise", data={"url": observer_cruise.url})
|
||||
analytics.capture("skyvern-oss-agent-observer-cruise", data={"url": observer_task.url})
|
||||
await AsyncExecutorFactory.get_executor().execute_cruise(
|
||||
request=request,
|
||||
background_tasks=background_tasks,
|
||||
organization_id=organization.organization_id,
|
||||
observer_cruise_id=observer_cruise.observer_cruise_id,
|
||||
observer_cruise_id=observer_task.observer_cruise_id,
|
||||
max_iterations_override=x_max_iterations_override,
|
||||
browser_session_id=data.browser_session_id,
|
||||
)
|
||||
return observer_cruise
|
||||
return observer_task.model_dump(by_alias=True)
|
||||
|
||||
|
||||
@base_router.get("/cruise/{observer_cruise_id}")
|
||||
@base_router.get("/cruise/{observer_cruise_id}/", include_in_schema=False)
|
||||
async def get_observer_cruise(
|
||||
observer_cruise_id: str,
|
||||
@v2_router.get("/tasks/{task_id}")
|
||||
@v2_router.get("/tasks/{task_id}/", include_in_schema=False)
|
||||
async def get_observer_task(
|
||||
task_id: str,
|
||||
organization: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> ObserverCruise:
|
||||
observer_cruise = await observer_service.get_observer_cruise(observer_cruise_id, organization.organization_id)
|
||||
if not observer_cruise:
|
||||
raise HTTPException(status_code=404, detail=f"Observer cruise {observer_cruise_id} not found")
|
||||
return observer_cruise
|
||||
) -> dict[str, Any]:
|
||||
observer_task = await observer_service.get_observer_cruise(task_id, organization.organization_id)
|
||||
if not observer_task:
|
||||
raise HTTPException(status_code=404, detail=f"Observer task {task_id} not found")
|
||||
return observer_task.model_dump(by_alias=True)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
|
||||
Reference in New Issue
Block a user