[Backend] Fix - Task V2 conversion changes to make new proxy city/state feature work (#4153)
This commit is contained in:
@@ -66,6 +66,7 @@ from skyvern.forge.sdk.db.utils import (
|
|||||||
convert_to_script_file,
|
convert_to_script_file,
|
||||||
convert_to_step,
|
convert_to_step,
|
||||||
convert_to_task,
|
convert_to_task,
|
||||||
|
convert_to_task_v2,
|
||||||
convert_to_workflow,
|
convert_to_workflow,
|
||||||
convert_to_workflow_parameter,
|
convert_to_workflow_parameter,
|
||||||
convert_to_workflow_run,
|
convert_to_workflow_run,
|
||||||
@@ -3524,7 +3525,7 @@ class AgentDB:
|
|||||||
.filter_by(organization_id=organization_id)
|
.filter_by(organization_id=organization_id)
|
||||||
)
|
)
|
||||||
).first():
|
).first():
|
||||||
return TaskV2.model_validate(task_v2)
|
return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def delete_thoughts(self, task_v2_id: str, organization_id: str | None = None) -> None:
|
async def delete_thoughts(self, task_v2_id: str, organization_id: str | None = None) -> None:
|
||||||
@@ -3551,7 +3552,7 @@ class AgentDB:
|
|||||||
.filter_by(workflow_run_id=workflow_run_id)
|
.filter_by(workflow_run_id=workflow_run_id)
|
||||||
)
|
)
|
||||||
).first():
|
).first():
|
||||||
return TaskV2.model_validate(task_v2)
|
return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_thought(self, thought_id: str, organization_id: str | None = None) -> Thought | None:
|
async def get_thought(self, thought_id: str, organization_id: str | None = None) -> Thought | None:
|
||||||
@@ -3628,7 +3629,7 @@ class AgentDB:
|
|||||||
session.add(new_task_v2)
|
session.add(new_task_v2)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(new_task_v2)
|
await session.refresh(new_task_v2)
|
||||||
return TaskV2.model_validate(new_task_v2)
|
return convert_to_task_v2(new_task_v2, debug_enabled=self.debug_enabled)
|
||||||
|
|
||||||
async def create_thought(
|
async def create_thought(
|
||||||
self,
|
self,
|
||||||
@@ -3784,7 +3785,7 @@ class AgentDB:
|
|||||||
task_v2.webhook_failure_reason = webhook_failure_reason
|
task_v2.webhook_failure_reason = webhook_failure_reason
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(task_v2)
|
await session.refresh(task_v2)
|
||||||
return TaskV2.model_validate(task_v2)
|
return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled)
|
||||||
raise NotFoundError(f"TaskV2 {task_v2_id} not found")
|
raise NotFoundError(f"TaskV2 {task_v2_id} not found")
|
||||||
|
|
||||||
async def create_workflow_run_block(
|
async def create_workflow_run_block(
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from skyvern.forge.sdk.db.models import (
|
|||||||
ScriptModel,
|
ScriptModel,
|
||||||
StepModel,
|
StepModel,
|
||||||
TaskModel,
|
TaskModel,
|
||||||
|
TaskV2Model,
|
||||||
WorkflowModel,
|
WorkflowModel,
|
||||||
WorkflowParameterModel,
|
WorkflowParameterModel,
|
||||||
WorkflowRunBlockModel,
|
WorkflowRunBlockModel,
|
||||||
@@ -36,6 +37,7 @@ from skyvern.forge.sdk.schemas.organizations import (
|
|||||||
Organization,
|
Organization,
|
||||||
OrganizationAuthToken,
|
OrganizationAuthToken,
|
||||||
)
|
)
|
||||||
|
from skyvern.forge.sdk.schemas.task_v2 import TaskV2
|
||||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
||||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||||
@@ -204,6 +206,15 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p
|
|||||||
return task
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_task_v2(task_v2_model: TaskV2Model, debug_enabled: bool = False) -> TaskV2:
|
||||||
|
if debug_enabled:
|
||||||
|
LOG.debug("Converting TaskV2Model to TaskV2", observer_cruise_id=task_v2_model.observer_cruise_id)
|
||||||
|
task_v2_data = {column.name: getattr(task_v2_model, column.name) for column in TaskV2Model.__table__.columns}
|
||||||
|
# Deserialize proxy_location FIRST (string → GeoTarget), otherwise model_validate will fail for city/state proxy selections
|
||||||
|
task_v2_data["proxy_location"] = _deserialize_proxy_location(task_v2_model.proxy_location)
|
||||||
|
return TaskV2.model_validate(task_v2_data)
|
||||||
|
|
||||||
|
|
||||||
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
|
def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
|
||||||
if debug_enabled:
|
if debug_enabled:
|
||||||
LOG.debug("Converting StepModel to Step", step_id=step_model.step_id)
|
LOG.debug("Converting StepModel to Step", step_id=step_model.step_id)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -5,7 +6,7 @@ from typing import Any
|
|||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.schemas.runs import ProxyLocationInput
|
from skyvern.schemas.runs import GeoTarget, ProxyLocation, ProxyLocationInput
|
||||||
from skyvern.utils.url_validators import validate_url
|
from skyvern.utils.url_validators import validate_url
|
||||||
|
|
||||||
DEFAULT_WORKFLOW_TITLE = "New Workflow"
|
DEFAULT_WORKFLOW_TITLE = "New Workflow"
|
||||||
@@ -57,6 +58,32 @@ class TaskV2(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
modified_at: datetime
|
modified_at: datetime
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_proxy_location(proxy_location: ProxyLocationInput | str) -> ProxyLocationInput:
|
||||||
|
"""Handle JSON strings that were persisted to the DB."""
|
||||||
|
if proxy_location is None or isinstance(proxy_location, (ProxyLocation, GeoTarget, dict)):
|
||||||
|
return proxy_location
|
||||||
|
|
||||||
|
if isinstance(proxy_location, str):
|
||||||
|
stripped = proxy_location.strip()
|
||||||
|
if not stripped:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if stripped.startswith("{"):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(stripped)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return GeoTarget.model_validate(parsed)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
return ProxyLocation(stripped)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return proxy_location
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def llm_key(self) -> str | None:
|
def llm_key(self) -> str | None:
|
||||||
"""
|
"""
|
||||||
@@ -83,6 +110,11 @@ class TaskV2(BaseModel):
|
|||||||
|
|
||||||
return validate_url(url)
|
return validate_url(url)
|
||||||
|
|
||||||
|
@field_validator("proxy_location", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def deserialize_proxy_location(cls, proxy_location: ProxyLocationInput | str) -> ProxyLocationInput:
|
||||||
|
return cls._parse_proxy_location(proxy_location)
|
||||||
|
|
||||||
|
|
||||||
class ThoughtType(StrEnum):
|
class ThoughtType(StrEnum):
|
||||||
plan = "plan"
|
plan = "plan"
|
||||||
@@ -166,3 +198,8 @@ class TaskV2Request(BaseModel):
|
|||||||
return url
|
return url
|
||||||
|
|
||||||
return validate_url(url)
|
return validate_url(url)
|
||||||
|
|
||||||
|
@field_validator("proxy_location", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def deserialize_proxy_location(cls, proxy_location: ProxyLocationInput | str) -> ProxyLocationInput:
|
||||||
|
return TaskV2._parse_proxy_location(proxy_location)
|
||||||
|
|||||||
Reference in New Issue
Block a user