TaskRunType -> RunType (#2041)
This commit is contained in:
@@ -66,8 +66,8 @@ from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
|
||||
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.runs import TaskRun
|
||||
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
|
||||
from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType
|
||||
from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus
|
||||
from skyvern.forge.sdk.schemas.totp_codes import TOTPCode
|
||||
@@ -91,7 +91,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
WorkflowRunStatus,
|
||||
WorkflowStatus,
|
||||
)
|
||||
from skyvern.schemas.runs import ProxyLocation
|
||||
from skyvern.schemas.runs import ProxyLocation, RunType
|
||||
from skyvern.webeye.actions.actions import Action
|
||||
from skyvern.webeye.actions.models import AgentStepOutput
|
||||
|
||||
@@ -2783,7 +2783,7 @@ class AgentDB:
|
||||
|
||||
async def create_task_run(
|
||||
self,
|
||||
task_run_type: TaskRunType,
|
||||
task_run_type: RunType,
|
||||
organization_id: str,
|
||||
run_id: str,
|
||||
title: str | None = None,
|
||||
@@ -2931,7 +2931,7 @@ class AgentDB:
|
||||
raise NotFoundError(f"TaskRun {run_id} not found")
|
||||
|
||||
async def get_cached_task_run(
|
||||
self, task_run_type: TaskRunType, url_hash: str | None = None, organization_id: str | None = None
|
||||
self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None
|
||||
) -> TaskRun | None:
|
||||
async with self.Session() as session:
|
||||
query = select(TaskRunModel)
|
||||
|
||||
@@ -41,7 +41,6 @@ from skyvern.forge.sdk.schemas.organizations import (
|
||||
OrganizationUpdate,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration
|
||||
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
|
||||
from skyvern.forge.sdk.schemas.tasks import (
|
||||
CreateTaskResponse,
|
||||
@@ -71,7 +70,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
WorkflowStatus,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
|
||||
from skyvern.schemas.runs import RunEngine, TaskRunRequest, TaskRunResponse
|
||||
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest
|
||||
from skyvern.services import task_v1_service, task_v2_service
|
||||
from skyvern.webeye.actions.actions import Action
|
||||
from skyvern.webeye.schemas import BrowserSessionResponse
|
||||
@@ -445,7 +444,7 @@ async def get_runs(
|
||||
@base_router.get(
|
||||
"/runs/{run_id}",
|
||||
tags=["agent"],
|
||||
response_model=TaskRunResponse,
|
||||
response_model=RunResponse,
|
||||
openapi_extra={
|
||||
"x-fern-sdk-group-name": "agent",
|
||||
"x-fern-sdk-method-name": "get_run",
|
||||
@@ -453,13 +452,13 @@ async def get_runs(
|
||||
)
|
||||
@base_router.get(
|
||||
"/runs/{run_id}/",
|
||||
response_model=TaskRunResponse,
|
||||
response_model=RunResponse,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def get_run(
|
||||
run_id: str,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> TaskRunResponse:
|
||||
) -> RunResponse:
|
||||
task_run_response = await task_run_service.get_task_run_response(
|
||||
run_id, organization_id=current_org.organization_id
|
||||
)
|
||||
@@ -683,7 +682,7 @@ async def run_workflow(
|
||||
version=version,
|
||||
)
|
||||
await app.DATABASE.create_task_run(
|
||||
task_run_type=TaskRunType.workflow_run,
|
||||
task_run_type=RunType.workflow_run,
|
||||
organization_id=current_org.organization_id,
|
||||
run_id=workflow_run.workflow_run_id,
|
||||
title=workflow.title,
|
||||
@@ -1512,7 +1511,7 @@ async def run_task(
|
||||
run_request: TaskRunRequest,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
) -> TaskRunResponse:
|
||||
) -> RunResponse:
|
||||
analytics.capture("skyvern-oss-run-task", data={"url": run_request.url})
|
||||
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id)
|
||||
|
||||
@@ -1555,7 +1554,7 @@ async def run_task(
|
||||
background_tasks=background_tasks,
|
||||
)
|
||||
# build the task run response
|
||||
return TaskRunResponse(
|
||||
return RunResponse(
|
||||
run_id=task_v1_response.task_id,
|
||||
title=task_v1_response.title,
|
||||
status=str(task_v1_response.status),
|
||||
@@ -1603,7 +1602,7 @@ async def run_task(
|
||||
max_steps_override=run_request.max_steps,
|
||||
browser_session_id=run_request.browser_session_id,
|
||||
)
|
||||
return TaskRunResponse(
|
||||
return RunResponse(
|
||||
run_id=task_v2.observer_cruise_id,
|
||||
title=run_request.title,
|
||||
status=str(task_v2.status),
|
||||
|
||||
@@ -1,20 +1,15 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class TaskRunType(StrEnum):
|
||||
task_v1 = "task_v1"
|
||||
task_v2 = "task_v2"
|
||||
workflow_run = "workflow_run"
|
||||
from skyvern.schemas.runs import RunType
|
||||
|
||||
|
||||
class TaskRun(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
task_run_id: str
|
||||
task_run_type: TaskRunType
|
||||
task_run_type: RunType
|
||||
run_id: str
|
||||
organization_id: str | None = None
|
||||
title: str | None = None
|
||||
@@ -1,23 +1,23 @@
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType
|
||||
from skyvern.schemas.runs import RunEngine, TaskRunResponse
|
||||
from skyvern.forge.sdk.schemas.runs import TaskRun
|
||||
from skyvern.schemas.runs import RunEngine, RunResponse, RunType
|
||||
|
||||
|
||||
async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None:
|
||||
return await app.DATABASE.get_task_run(run_id, organization_id=organization_id)
|
||||
|
||||
|
||||
async def get_task_run_response(run_id: str, organization_id: str | None = None) -> TaskRunResponse | None:
|
||||
async def get_task_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
|
||||
task_run = await get_task_run(run_id, organization_id=organization_id)
|
||||
if not task_run:
|
||||
return None
|
||||
|
||||
if task_run.task_run_type == TaskRunType.task_v1:
|
||||
if task_run.task_run_type == RunType.task_v1:
|
||||
# fetch task v1 from db and transform to task run response
|
||||
task_v1 = await app.DATABASE.get_task(task_run.task_v1_id, organization_id=organization_id)
|
||||
if not task_v1:
|
||||
return None
|
||||
return TaskRunResponse(
|
||||
return RunResponse(
|
||||
run_id=task_run.run_id,
|
||||
engine=RunEngine.skyvern_v1,
|
||||
status=task_v1.status,
|
||||
@@ -32,11 +32,11 @@ async def get_task_run_response(run_id: str, organization_id: str | None = None)
|
||||
created_at=task_v1.created_at,
|
||||
modified_at=task_v1.modified_at,
|
||||
)
|
||||
elif task_run.task_run_type == TaskRunType.task_v2:
|
||||
elif task_run.task_run_type == RunType.task_v2:
|
||||
task_v2 = await app.DATABASE.get_task_v2(task_run.task_v2_id, organization_id=organization_id)
|
||||
if not task_v2:
|
||||
return None
|
||||
return TaskRunResponse(
|
||||
return RunResponse(
|
||||
run_id=task_run.run_id,
|
||||
engine=RunEngine.skyvern_v2,
|
||||
status=task_v2.status,
|
||||
|
||||
Reference in New Issue
Block a user