add model to Task and TaskV2; expose it to run_task endpoint; thread … (#2540)
This commit is contained in:
@@ -8,6 +8,7 @@ from fastapi import status
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import (
|
||||
InvalidTaskStatusTransition,
|
||||
SkyvernHTTPException,
|
||||
@@ -110,6 +111,7 @@ class TaskRequest(TaskBase):
|
||||
)
|
||||
totp_verification_url: str | None = None
|
||||
browser_session_id: str | None = None
|
||||
model: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_url(self) -> Self:
|
||||
@@ -236,6 +238,22 @@ class Task(TaskBase):
|
||||
retry: int | None = None
|
||||
max_steps_per_run: int | None = None
|
||||
errors: list[dict[str, Any]] = []
|
||||
model: dict[str, Any] | None = None
|
||||
|
||||
@property
|
||||
def llm_key(self) -> str | None:
|
||||
"""
|
||||
If the `Task` has a `model` defined, then return the mapped llm_key for it.
|
||||
|
||||
Otherwise return `None`.
|
||||
"""
|
||||
if self.model:
|
||||
model_name = self.model.get("model_name")
|
||||
if model_name:
|
||||
mapping = settings.get_model_name_to_llm_key()
|
||||
return mapping.get(model_name)
|
||||
|
||||
return None
|
||||
|
||||
def validate_update(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user