add get_run endpoint (#1944)
This commit is contained in:
@@ -2905,3 +2905,15 @@ class AgentDB:
|
|||||||
query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
|
query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
|
||||||
task_run = (await session.scalars(query)).first()
|
task_run = (await session.scalars(query)).first()
|
||||||
return TaskRun.model_validate(task_run) if task_run else None
|
return TaskRun.model_validate(task_run) if task_run else None
|
||||||
|
|
||||||
|
async def get_task_run(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
organization_id: str | None = None,
|
||||||
|
) -> TaskRun | None:
|
||||||
|
async with self.Session() as session:
|
||||||
|
query = select(TaskRunModel).filter_by(run_id=run_id)
|
||||||
|
if organization_id:
|
||||||
|
query = query.filter_by(organization_id=organization_id)
|
||||||
|
task_run = (await session.scalars(query)).first()
|
||||||
|
return TaskRun.model_validate(task_run) if task_run else None
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ BITWARDEN_SENSITIVE_INFORMATION_PARAMETER_PREFIX = "bsi"
|
|||||||
CREDENTIAL_PARAMETER_PREFIX = "cp"
|
CREDENTIAL_PARAMETER_PREFIX = "cp"
|
||||||
CREDENTIAL_PREFIX = "cred"
|
CREDENTIAL_PREFIX = "cred"
|
||||||
ORGANIZATION_BITWARDEN_COLLECTION_PREFIX = "obc"
|
ORGANIZATION_BITWARDEN_COLLECTION_PREFIX = "obc"
|
||||||
TASK_V2_ID = "oc"
|
TASK_V2_ID = "tsk_v2"
|
||||||
THOUGHT_ID = "ot"
|
THOUGHT_ID = "ot"
|
||||||
ORGANIZATION_AUTH_TOKEN_PREFIX = "oat"
|
ORGANIZATION_AUTH_TOKEN_PREFIX = "oat"
|
||||||
ORG_PREFIX = "o"
|
ORG_PREFIX = "o"
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from skyvern.forge.sdk.schemas.organizations import (
|
|||||||
OrganizationUpdate,
|
OrganizationUpdate,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
|
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
|
||||||
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
|
from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse, TaskRunType
|
||||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
|
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
|
||||||
from skyvern.forge.sdk.schemas.tasks import (
|
from skyvern.forge.sdk.schemas.tasks import (
|
||||||
CreateTaskResponse,
|
CreateTaskResponse,
|
||||||
@@ -57,7 +57,7 @@ from skyvern.forge.sdk.schemas.tasks import (
|
|||||||
TaskStatus,
|
TaskStatus,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline
|
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline
|
||||||
from skyvern.forge.sdk.services import org_auth_service, task_v2_service
|
from skyvern.forge.sdk.services import org_auth_service, task_run_service, task_v2_service
|
||||||
from skyvern.forge.sdk.workflow.exceptions import (
|
from skyvern.forge.sdk.workflow.exceptions import (
|
||||||
FailedToCreateWorkflow,
|
FailedToCreateWorkflow,
|
||||||
FailedToUpdateWorkflow,
|
FailedToUpdateWorkflow,
|
||||||
@@ -391,6 +391,23 @@ async def get_runs(
|
|||||||
return ORJSONResponse([run.model_dump() for run in runs])
|
return ORJSONResponse([run.model_dump() for run in runs])
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.get("/runs/{run_id}", response_model=TaskRunResponse)
|
||||||
|
@base_router.get("/runs/{run_id}/", response_model=TaskRunResponse, include_in_schema=False)
|
||||||
|
async def get_run(
|
||||||
|
run_id: str,
|
||||||
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
|
) -> TaskRunResponse:
|
||||||
|
task_run_response = await task_run_service.get_task_run_response(
|
||||||
|
run_id, organization_id=current_org.organization_id
|
||||||
|
)
|
||||||
|
if not task_run_response:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Task run not found {run_id}",
|
||||||
|
)
|
||||||
|
return task_run_response
|
||||||
|
|
||||||
|
|
||||||
@base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step])
|
@base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step])
|
||||||
@base_router.get(
|
@base_router.get(
|
||||||
"/tasks/{task_id}/steps/",
|
"/tasks/{task_id}/steps/",
|
||||||
|
|||||||
@@ -3,6 +3,24 @@ from enum import StrEnum
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRunStatus(StrEnum):
|
||||||
|
created = "created"
|
||||||
|
queued = "queued"
|
||||||
|
running = "running"
|
||||||
|
timed_out = "timed_out"
|
||||||
|
failed = "failed"
|
||||||
|
terminated = "terminated"
|
||||||
|
completed = "completed"
|
||||||
|
canceled = "canceled"
|
||||||
|
|
||||||
|
|
||||||
|
class RunEngine(StrEnum):
|
||||||
|
skyvern_v1 = "skyvern-1.0"
|
||||||
|
skyvern_v2 = "skyvern-2.0"
|
||||||
|
|
||||||
|
|
||||||
class TaskRunType(StrEnum):
|
class TaskRunType(StrEnum):
|
||||||
task_v1 = "task_v1"
|
task_v1 = "task_v1"
|
||||||
@@ -22,3 +40,22 @@ class TaskRun(BaseModel):
|
|||||||
cached: bool = False
|
cached: bool = False
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
modified_at: datetime
|
modified_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRunResponse(BaseModel):
|
||||||
|
run_id: str
|
||||||
|
engine: RunEngine = RunEngine.skyvern_v1
|
||||||
|
status: TaskRunStatus
|
||||||
|
goal: str | None = None
|
||||||
|
url: str | None = None
|
||||||
|
output: dict | list | str | None = None
|
||||||
|
failure_reason: str | None = None
|
||||||
|
webhook_url: str | None = None
|
||||||
|
totp_identifier: str | None = None
|
||||||
|
totp_url: str | None = None
|
||||||
|
proxy_location: ProxyLocation | None = None
|
||||||
|
error_code_mapping: dict[str, str] | None = None
|
||||||
|
title: str | None = None
|
||||||
|
max_steps: int | None = None
|
||||||
|
created_at: datetime
|
||||||
|
modified_at: datetime
|
||||||
|
|||||||
53
skyvern/forge/sdk/services/task_run_service.py
Normal file
53
skyvern/forge/sdk/services/task_run_service.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.sdk.schemas.task_runs import RunEngine, TaskRun, TaskRunResponse, TaskRunType
|
||||||
|
|
||||||
|
|
||||||
|
async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None:
|
||||||
|
return await app.DATABASE.get_task_run(run_id, organization_id=organization_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_task_run_response(run_id: str, organization_id: str | None = None) -> TaskRunResponse | None:
|
||||||
|
task_run = await get_task_run(run_id, organization_id=organization_id)
|
||||||
|
if not task_run:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if task_run.task_run_type == TaskRunType.task_v1:
|
||||||
|
# fetch task v1 from db and transform to task run response
|
||||||
|
task_v1 = await app.DATABASE.get_task(task_run.task_v1_id, organization_id=organization_id)
|
||||||
|
if not task_v1:
|
||||||
|
return None
|
||||||
|
return TaskRunResponse(
|
||||||
|
run_id=task_run.run_id,
|
||||||
|
engine=RunEngine.skyvern_v1,
|
||||||
|
status=task_v1.status,
|
||||||
|
goal=task_v1.navigation_goal,
|
||||||
|
url=task_v1.url,
|
||||||
|
output=task_v1.extracted_information,
|
||||||
|
failure_reason=task_v1.failure_reason,
|
||||||
|
webhook_url=task_v1.webhook_callback_url,
|
||||||
|
totp_identifier=task_v1.totp_identifier,
|
||||||
|
totp_url=task_v1.totp_verification_url,
|
||||||
|
proxy_location=task_v1.proxy_location,
|
||||||
|
created_at=task_v1.created_at,
|
||||||
|
modified_at=task_v1.modified_at,
|
||||||
|
)
|
||||||
|
elif task_run.task_run_type == TaskRunType.task_v2:
|
||||||
|
task_v2 = await app.DATABASE.get_task_v2(task_run.task_v2_id, organization_id=organization_id)
|
||||||
|
if not task_v2:
|
||||||
|
return None
|
||||||
|
return TaskRunResponse(
|
||||||
|
run_id=task_run.run_id,
|
||||||
|
engine=RunEngine.skyvern_v2,
|
||||||
|
status=task_v2.status,
|
||||||
|
goal=task_v2.prompt,
|
||||||
|
url=task_v2.url,
|
||||||
|
output=task_v2.output,
|
||||||
|
failure_reason=task_v2.failure_reason,
|
||||||
|
webhook_url=task_v2.webhook_callback_url,
|
||||||
|
totp_identifier=task_v2.totp_identifier,
|
||||||
|
totp_url=task_v2.totp_verification_url,
|
||||||
|
proxy_location=task_v2.proxy_location,
|
||||||
|
created_at=task_v2.created_at,
|
||||||
|
modified_at=task_v2.modified_at,
|
||||||
|
)
|
||||||
|
raise ValueError(f"Invalid task run type: {task_run.task_run_type}")
|
||||||
Reference in New Issue
Block a user