diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index c416d226..6158fa9b 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -66,7 +66,7 @@ from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession -from skyvern.forge.sdk.schemas.runs import TaskRun +from skyvern.forge.sdk.schemas.runs import Run from skyvern.forge.sdk.schemas.task_generations import TaskGeneration from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus @@ -2789,7 +2789,7 @@ class AgentDB: title: str | None = None, url: str | None = None, url_hash: str | None = None, - ) -> TaskRun: + ) -> Run: async with self.Session() as session: task_run = TaskRunModel( task_run_type=task_run_type, @@ -2802,7 +2802,7 @@ class AgentDB: session.add(task_run) await session.commit() await session.refresh(task_run) - return TaskRun.model_validate(task_run) + return Run.model_validate(task_run) async def create_credential( self, @@ -2916,7 +2916,7 @@ class AgentDB: return OrganizationBitwardenCollection.model_validate(organization_bitwarden_collection) return None - async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> TaskRun: + async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> Run: async with self.Session() as session: task_run = ( await session.scalars( @@ -2927,12 +2927,12 @@ class AgentDB: task_run.cached = True await session.commit() await session.refresh(task_run) - return TaskRun.model_validate(task_run) - raise NotFoundError(f"TaskRun {run_id} not found") + return Run.model_validate(task_run) + raise NotFoundError(f"Run {run_id} not found") async def get_cached_task_run( self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None - ) -> TaskRun | None: + ) -> Run | None: async with self.Session() as session: query = select(TaskRunModel) if task_run_type: @@ -2943,16 +2943,16 @@ class AgentDB: query = query.filter_by(organization_id=organization_id) query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc()) task_run = (await session.scalars(query)).first() - return TaskRun.model_validate(task_run) if task_run else None + return Run.model_validate(task_run) if task_run else None async def get_task_run( self, run_id: str, organization_id: str | None = None, - ) -> TaskRun | None: + ) -> Run | None: async with self.Session() as session: query = select(TaskRunModel).filter_by(run_id=run_id) if organization_id: query = query.filter_by(organization_id=organization_id) task_run = (await session.scalars(query)).first() - return TaskRun.model_validate(task_run) if task_run else None + return Run.model_validate(task_run) if task_run else None diff --git a/skyvern/forge/sdk/schemas/runs.py b/skyvern/forge/sdk/schemas/runs.py index 1361baa7..409baa78 100644 --- a/skyvern/forge/sdk/schemas/runs.py +++ b/skyvern/forge/sdk/schemas/runs.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict from skyvern.schemas.runs import RunType -class TaskRun(BaseModel): +class Run(BaseModel): model_config = ConfigDict(from_attributes=True) task_run_id: str diff --git a/skyvern/forge/sdk/services/task_run_service.py b/skyvern/forge/sdk/services/task_run_service.py index 21002270..e216c3d0 100644 --- a/skyvern/forge/sdk/services/task_run_service.py +++ b/skyvern/forge/sdk/services/task_run_service.py @@ -1,9 +1,9 @@ from skyvern.forge import app -from skyvern.forge.sdk.schemas.runs import TaskRun +from skyvern.forge.sdk.schemas.runs import Run from skyvern.schemas.runs import RunEngine, RunResponse, RunType -async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None: +async def get_task_run(run_id: str, organization_id: str | None = None) -> Run | None: return await app.DATABASE.get_task_run(run_id, organization_id=organization_id)