backend changes extracted from codex/jon/SKY-5016 (#2508)
This commit is contained in:
@@ -46,7 +46,7 @@ from skyvern.forge.sdk.api.files import (
|
||||
download_from_s3,
|
||||
get_path_for_workflow_download_directory,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
@@ -126,6 +126,7 @@ class Block(BaseModel, abc.ABC):
|
||||
block_type: BlockType
|
||||
output_parameter: OutputParameter
|
||||
continue_on_failure: bool = False
|
||||
model: dict[str, Any] | None = None
|
||||
|
||||
async def record_output_parameter_value(
|
||||
self,
|
||||
@@ -618,6 +619,9 @@ class BaseTaskBlock(Block):
|
||||
try:
|
||||
current_context = skyvern_context.ensure_context()
|
||||
current_context.task_id = task.task_id
|
||||
llm_key = workflow.determine_llm_key(block=self)
|
||||
llm_caller = None if not llm_key else LLMCaller(llm_key=llm_key)
|
||||
|
||||
await app.agent.execute_step(
|
||||
organization=organization,
|
||||
task=task,
|
||||
@@ -627,6 +631,7 @@ class BaseTaskBlock(Block):
|
||||
close_browser_on_completion=browser_session_id is None,
|
||||
complete_verification=self.complete_verification,
|
||||
engine=self.engine,
|
||||
llm_caller=llm_caller,
|
||||
)
|
||||
except Exception as e:
|
||||
# Make sure the task is marked as failed in the database before raising the exception
|
||||
|
||||
@@ -5,10 +5,11 @@ from typing import Any, List
|
||||
from pydantic import BaseModel, field_validator
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2
|
||||
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
|
||||
from skyvern.forge.sdk.workflow.models.block import Block, BlockTypeVar
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE
|
||||
from skyvern.schemas.runs import ProxyLocation
|
||||
from skyvern.utils.url_validators import validate_url
|
||||
@@ -74,12 +75,41 @@ class Workflow(BaseModel):
|
||||
totp_verification_url: str | None = None
|
||||
totp_identifier: str | None = None
|
||||
persist_browser_session: bool = False
|
||||
model: dict[str, Any] | None = None
|
||||
status: WorkflowStatus = WorkflowStatus.published
|
||||
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
deleted_at: datetime | None = None
|
||||
|
||||
def determine_llm_key(self, *, block: Block | None = None) -> str | None:
|
||||
"""
|
||||
Determine the LLM key override to use for a block, if it has one.
|
||||
|
||||
It has one if:
|
||||
- it defines one, or
|
||||
- the workflow it is a part of (if applicable) defines one
|
||||
"""
|
||||
|
||||
mapping = settings.get_model_name_to_llm_key()
|
||||
|
||||
if block:
|
||||
model_name = (block.model or {}).get("model")
|
||||
|
||||
if model_name:
|
||||
llm_key = mapping.get(model_name)
|
||||
if llm_key:
|
||||
return llm_key
|
||||
|
||||
workflow_model_name = (self.model or {}).get("model")
|
||||
|
||||
if workflow_model_name:
|
||||
llm_key = mapping.get(workflow_model_name)
|
||||
if llm_key:
|
||||
return llm_key
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class WorkflowRunStatus(StrEnum):
|
||||
created = "created"
|
||||
|
||||
@@ -117,6 +117,7 @@ class BlockYAML(BaseModel, abc.ABC):
|
||||
block_type: BlockType
|
||||
label: str
|
||||
continue_on_failure: bool = False
|
||||
model: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskBlockYAML(BlockYAML):
|
||||
@@ -413,6 +414,7 @@ class WorkflowCreateYAMLRequest(BaseModel):
|
||||
totp_verification_url: str | None = None
|
||||
totp_identifier: str | None = None
|
||||
persist_browser_session: bool = False
|
||||
model: dict[str, Any] | None = None
|
||||
workflow_definition: WorkflowDefinitionYAML
|
||||
is_saved_task: bool = False
|
||||
status: WorkflowStatus = WorkflowStatus.published
|
||||
|
||||
Reference in New Issue
Block a user