TaskRun pydantic model gets renamed to Run (#2042)

This commit is contained in:
Shuchang Zheng
2025-03-30 18:41:24 -07:00
committed by GitHub
parent 05e28931bc
commit 12ef2100b5
3 changed files with 13 additions and 13 deletions

View File

@@ -66,7 +66,7 @@ from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
from skyvern.forge.sdk.schemas.runs import TaskRun
from skyvern.forge.sdk.schemas.runs import Run
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType
from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus
@@ -2789,7 +2789,7 @@ class AgentDB:
title: str | None = None,
url: str | None = None,
url_hash: str | None = None,
) -> TaskRun:
) -> Run:
async with self.Session() as session:
task_run = TaskRunModel(
task_run_type=task_run_type,
@@ -2802,7 +2802,7 @@ class AgentDB:
session.add(task_run)
await session.commit()
await session.refresh(task_run)
return TaskRun.model_validate(task_run)
return Run.model_validate(task_run)
async def create_credential(
self,
@@ -2916,7 +2916,7 @@ class AgentDB:
return OrganizationBitwardenCollection.model_validate(organization_bitwarden_collection)
return None
async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> TaskRun:
async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> Run:
async with self.Session() as session:
task_run = (
await session.scalars(
@@ -2927,12 +2927,12 @@ class AgentDB:
task_run.cached = True
await session.commit()
await session.refresh(task_run)
return TaskRun.model_validate(task_run)
raise NotFoundError(f"TaskRun {run_id} not found")
return Run.model_validate(task_run)
raise NotFoundError(f"Run {run_id} not found")
async def get_cached_task_run(
self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None
) -> TaskRun | None:
) -> Run | None:
async with self.Session() as session:
query = select(TaskRunModel)
if task_run_type:
@@ -2943,16 +2943,16 @@ class AgentDB:
query = query.filter_by(organization_id=organization_id)
query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
task_run = (await session.scalars(query)).first()
return TaskRun.model_validate(task_run) if task_run else None
return Run.model_validate(task_run) if task_run else None
async def get_task_run(
self,
run_id: str,
organization_id: str | None = None,
) -> TaskRun | None:
) -> Run | None:
async with self.Session() as session:
query = select(TaskRunModel).filter_by(run_id=run_id)
if organization_id:
query = query.filter_by(organization_id=organization_id)
task_run = (await session.scalars(query)).first()
return TaskRun.model_validate(task_run) if task_run else None
return Run.model_validate(task_run) if task_run else None

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict
from skyvern.schemas.runs import RunType
class TaskRun(BaseModel):
class Run(BaseModel):
model_config = ConfigDict(from_attributes=True)
task_run_id: str

View File

@@ -1,9 +1,9 @@
from skyvern.forge import app
from skyvern.forge.sdk.schemas.runs import TaskRun
from skyvern.forge.sdk.schemas.runs import Run
from skyvern.schemas.runs import RunEngine, RunResponse, RunType
async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None:
async def get_task_run(run_id: str, organization_id: str | None = None) -> Run | None:
return await app.DATABASE.get_task_run(run_id, organization_id=organization_id)