From acce1c869d79e3337423a9b2f4c25ce50fa5e0b7 Mon Sep 17 00:00:00 2001 From: Marc Kelechava Date: Mon, 1 Dec 2025 16:08:36 -0800 Subject: [PATCH] [Backend] Fix - Task V2 conversion changes to make new proxy city/state feature work (#4153) --- skyvern/forge/sdk/db/client.py | 9 ++++--- skyvern/forge/sdk/db/utils.py | 11 ++++++++ skyvern/forge/sdk/schemas/task_v2.py | 39 +++++++++++++++++++++++++++- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index e5bfbcb7..dd1bc3c8 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -66,6 +66,7 @@ from skyvern.forge.sdk.db.utils import ( convert_to_script_file, convert_to_step, convert_to_task, + convert_to_task_v2, convert_to_workflow, convert_to_workflow_parameter, convert_to_workflow_run, @@ -3524,7 +3525,7 @@ class AgentDB: .filter_by(organization_id=organization_id) ) ).first(): - return TaskV2.model_validate(task_v2) + return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled) return None async def delete_thoughts(self, task_v2_id: str, organization_id: str | None = None) -> None: @@ -3551,7 +3552,7 @@ class AgentDB: .filter_by(workflow_run_id=workflow_run_id) ) ).first(): - return TaskV2.model_validate(task_v2) + return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled) return None async def get_thought(self, thought_id: str, organization_id: str | None = None) -> Thought | None: @@ -3628,7 +3629,7 @@ class AgentDB: session.add(new_task_v2) await session.commit() await session.refresh(new_task_v2) - return TaskV2.model_validate(new_task_v2) + return convert_to_task_v2(new_task_v2, debug_enabled=self.debug_enabled) async def create_thought( self, @@ -3784,7 +3785,7 @@ class AgentDB: task_v2.webhook_failure_reason = webhook_failure_reason await session.commit() await session.refresh(task_v2) - return TaskV2.model_validate(task_v2) + return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled) raise NotFoundError(f"TaskV2 {task_v2_id} not found") async def create_workflow_run_block( diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index a7106717..24260dd3 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -20,6 +20,7 @@ from skyvern.forge.sdk.db.models import ( ScriptModel, StepModel, TaskModel, + TaskV2Model, WorkflowModel, WorkflowParameterModel, WorkflowRunBlockModel, @@ -36,6 +37,7 @@ from skyvern.forge.sdk.schemas.organizations import ( Organization, OrganizationAuthToken, ) +from skyvern.forge.sdk.schemas.task_v2 import TaskV2 from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock from skyvern.forge.sdk.workflow.models.parameter import ( @@ -204,6 +206,15 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p return task +def convert_to_task_v2(task_v2_model: TaskV2Model, debug_enabled: bool = False) -> TaskV2: + if debug_enabled: + LOG.debug("Converting TaskV2Model to TaskV2", observer_cruise_id=task_v2_model.observer_cruise_id) + task_v2_data = {column.name: getattr(task_v2_model, column.name) for column in TaskV2Model.__table__.columns} + # Deserialize proxy_location FIRST (string → GeoTarget), otherwise model_validate will fail for city/state proxy selections + task_v2_data["proxy_location"] = _deserialize_proxy_location(task_v2_model.proxy_location) + return TaskV2.model_validate(task_v2_data) + + def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step: if debug_enabled: LOG.debug("Converting StepModel to Step", step_id=step_model.step_id) diff --git a/skyvern/forge/sdk/schemas/task_v2.py b/skyvern/forge/sdk/schemas/task_v2.py index afaacaaf..d0393de8 100644 --- a/skyvern/forge/sdk/schemas/task_v2.py +++ b/skyvern/forge/sdk/schemas/task_v2.py @@ -1,3 +1,4 @@ +import json from datetime import datetime from enum import StrEnum from typing import Any @@ -5,7 +6,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field, field_validator from skyvern.config import settings -from skyvern.schemas.runs import ProxyLocationInput +from skyvern.schemas.runs import GeoTarget, ProxyLocation, ProxyLocationInput from skyvern.utils.url_validators import validate_url DEFAULT_WORKFLOW_TITLE = "New Workflow" @@ -57,6 +58,32 @@ class TaskV2(BaseModel): created_at: datetime modified_at: datetime + @staticmethod + def _parse_proxy_location(proxy_location: ProxyLocationInput | str) -> ProxyLocationInput: + """Handle JSON strings that were persisted to the DB.""" + if proxy_location is None or isinstance(proxy_location, (ProxyLocation, GeoTarget, dict)): + return proxy_location + + if isinstance(proxy_location, str): + stripped = proxy_location.strip() + if not stripped: + return None + + if stripped.startswith("{"): + try: + parsed = json.loads(stripped) + if isinstance(parsed, dict): + return GeoTarget.model_validate(parsed) + except (json.JSONDecodeError, ValueError): + pass + + try: + return ProxyLocation(stripped) + except ValueError: + return None + + return proxy_location + @property def llm_key(self) -> str | None: """ @@ -83,6 +110,11 @@ class TaskV2(BaseModel): return validate_url(url) + @field_validator("proxy_location", mode="before") + @classmethod + def deserialize_proxy_location(cls, proxy_location: ProxyLocationInput | str) -> ProxyLocationInput: + return cls._parse_proxy_location(proxy_location) + class ThoughtType(StrEnum): plan = "plan" @@ -166,3 +198,8 @@ class TaskV2Request(BaseModel): return url return validate_url(url) + + @field_validator("proxy_location", mode="before") + @classmethod + def deserialize_proxy_location(cls, proxy_location: ProxyLocationInput | str) -> ProxyLocationInput: + return TaskV2._parse_proxy_location(proxy_location)