add model to Task and TaskV2; expose it to run_task endpoint; thread … (#2540)

This commit is contained in:
Shuchang Zheng
2025-05-30 20:07:12 -07:00
committed by GitHub
parent aee129a0a8
commit 2ed14f42e7
14 changed files with 103 additions and 5 deletions

View File

@@ -140,6 +140,7 @@ class AgentDB:
task_type: str = TaskType.general,
application: str | None = None,
include_action_history_in_verification: bool | None = None,
model: dict[str, Any] | None = None,
) -> Task:
try:
async with self.Session() as session:
@@ -166,6 +167,7 @@ class AgentDB:
error_code_mapping=error_code_mapping,
application=application,
include_action_history_in_verification=include_action_history_in_verification,
model=model,
)
session.add(new_task)
await session.commit()
@@ -2363,6 +2365,7 @@ class AgentDB:
webhook_callback_url: str | None = None,
extracted_information_schema: dict | list | str | None = None,
error_code_mapping: dict | None = None,
model: dict[str, Any] | None = None,
) -> TaskV2:
async with self.Session() as session:
new_task_v2 = TaskV2Model(
@@ -2378,6 +2381,7 @@ class AgentDB:
extracted_information_schema=extracted_information_schema,
error_code_mapping=error_code_mapping,
organization_id=organization_id,
model=model,
)
session.add(new_task_v2)
await session.commit()

View File

@@ -92,6 +92,7 @@ class TaskModel(Base):
nullable=False,
index=True,
)
model = Column(JSON, nullable=True)
class StepModel(Base):
@@ -593,6 +594,7 @@ class TaskV2Model(Base):
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
model = Column(JSON, nullable=True)
class ThoughtModel(Base):

View File

@@ -91,6 +91,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p
error_code_mapping=task_obj.error_code_mapping,
errors=task_obj.errors,
application=task_obj.application,
model=task_obj.model,
)
return task

View File

@@ -5,6 +5,7 @@ from fastapi import BackgroundTasks, Request
from skyvern.exceptions import OrganizationNotFound
from skyvern.forge import app
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.schemas.organizations import Organization
@@ -105,6 +106,9 @@ class BackgroundTaskExecutor(AsyncExecutor):
context.organization_id = organization_id
context.max_steps_override = max_steps_override
llm_key = task.llm_key
llm_caller = LLMCaller(llm_key) if llm_key else None
if background_tasks:
background_tasks.add_task(
app.agent.execute_step,
@@ -115,6 +119,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
close_browser_on_completion=close_browser_on_completion,
browser_session_id=browser_session_id,
engine=engine,
llm_caller=llm_caller,
)
async def execute_workflow(

View File

@@ -168,6 +168,7 @@ async def run_task(
totp_verification_url=run_request.totp_url,
totp_identifier=run_request.totp_identifier,
include_action_history_in_verification=run_request.include_action_history_in_verification,
model=run_request.model,
)
task_v1_response = await task_v1_service.run_task(
task=task_v1_request,
@@ -222,6 +223,7 @@ async def run_task(
extracted_information_schema=run_request.data_extraction_schema,
error_code_mapping=run_request.error_code_mapping,
create_task_run=True,
model=run_request.model,
)
except LLMProviderError:
LOG.error("LLM failure to initialize task v2", exc_info=True)

View File

@@ -4,6 +4,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_validator
from skyvern.config import settings
from skyvern.schemas.runs import ProxyLocation
from skyvern.utils.url_validators import validate_url
@@ -43,10 +44,28 @@ class TaskV2(BaseModel):
webhook_callback_url: str | None = None
extracted_information_schema: dict | list | str | None = None
error_code_mapping: dict | None = None
model: dict[str, Any] | None = None
created_at: datetime
modified_at: datetime
@property
def llm_key(self) -> str | None:
"""
If the `TaskV2` has a `model` defined, then return the mapped llm_key for it.
Otherwise return `None`.
"""
if self.model:
model_name = self.model.get("model_name")
if model_name:
mapping = settings.get_model_name_to_llm_key()
llm_key = mapping.get(model_name)
if llm_key:
return llm_key
return None
@field_validator("url", "webhook_callback_url", "totp_verification_url")
@classmethod
def validate_urls(cls, url: str | None) -> str | None:

View File

@@ -8,6 +8,7 @@ from fastapi import status
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from skyvern.config import settings
from skyvern.exceptions import (
InvalidTaskStatusTransition,
SkyvernHTTPException,
@@ -110,6 +111,7 @@ class TaskRequest(TaskBase):
)
totp_verification_url: str | None = None
browser_session_id: str | None = None
model: dict[str, Any] | None = None
@model_validator(mode="after")
def validate_url(self) -> Self:
@@ -236,6 +238,22 @@ class Task(TaskBase):
retry: int | None = None
max_steps_per_run: int | None = None
errors: list[dict[str, Any]] = []
model: dict[str, Any] | None = None
@property
def llm_key(self) -> str | None:
"""
If the `Task` has a `model` defined, then return the mapped llm_key for it.
Otherwise return `None`.
"""
if self.model:
model_name = self.model.get("model_name")
if model_name:
mapping = settings.get_model_name_to_llm_key()
return mapping.get(model_name)
return None
def validate_update(
self,