add model to Task and TaskV2; expose it to run_task endpoint; thread … (#2540)
This commit is contained in:
6
.github/workflows/ci.yml
vendored
6
.github/workflows/ci.yml
vendored
@@ -105,8 +105,7 @@ jobs:
|
||||
AWS_REGION: "us-east-1"
|
||||
ENABLE_BEDROCK: "true"
|
||||
|
||||
- name: Run the alembic-check pre-commit hook
|
||||
uses: pre-commit/action@v3.0.0
|
||||
- name: Run alembic check
|
||||
env:
|
||||
ENABLE_OPENAI: "true"
|
||||
OPENAI_API_KEY: "sk-dummy"
|
||||
@@ -117,8 +116,7 @@ jobs:
|
||||
AZURE_GPT4O_MINI_API_VERSION: "dummy"
|
||||
AWS_REGION: "us-east-1"
|
||||
ENABLE_BEDROCK: "true"
|
||||
with:
|
||||
args: "run --hook-stage manual alembic-check"
|
||||
run: poetry run ./run_alembic_check.sh
|
||||
- name: trigger tests
|
||||
env:
|
||||
ENABLE_OPENAI: "true"
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add model column to task v1 & v2 tables
|
||||
|
||||
Revision ID: babaa7307e8a
|
||||
Revises: af49ca791fc7
|
||||
Create Date: 2025-05-31 03:00:17.128919+00:00
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "babaa7307e8a"
|
||||
down_revision: Union[str, None] = "af49ca791fc7"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("observer_cruises", sa.Column("model", sa.JSON(), nullable=True))
|
||||
op.add_column("tasks", sa.Column("model", sa.JSON(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("tasks", "model")
|
||||
op.drop_column("observer_cruises", "model")
|
||||
# ### end Alembic commands ###
|
||||
@@ -229,6 +229,7 @@ class ForgeAgent:
|
||||
error_code_mapping=task_request.error_code_mapping,
|
||||
application=task_request.application,
|
||||
include_action_history_in_verification=task_request.include_action_history_in_verification,
|
||||
model=task_request.model,
|
||||
)
|
||||
LOG.info(
|
||||
"Created new task",
|
||||
|
||||
@@ -140,6 +140,7 @@ class AgentDB:
|
||||
task_type: str = TaskType.general,
|
||||
application: str | None = None,
|
||||
include_action_history_in_verification: bool | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> Task:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
@@ -166,6 +167,7 @@ class AgentDB:
|
||||
error_code_mapping=error_code_mapping,
|
||||
application=application,
|
||||
include_action_history_in_verification=include_action_history_in_verification,
|
||||
model=model,
|
||||
)
|
||||
session.add(new_task)
|
||||
await session.commit()
|
||||
@@ -2363,6 +2365,7 @@ class AgentDB:
|
||||
webhook_callback_url: str | None = None,
|
||||
extracted_information_schema: dict | list | str | None = None,
|
||||
error_code_mapping: dict | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> TaskV2:
|
||||
async with self.Session() as session:
|
||||
new_task_v2 = TaskV2Model(
|
||||
@@ -2378,6 +2381,7 @@ class AgentDB:
|
||||
extracted_information_schema=extracted_information_schema,
|
||||
error_code_mapping=error_code_mapping,
|
||||
organization_id=organization_id,
|
||||
model=model,
|
||||
)
|
||||
session.add(new_task_v2)
|
||||
await session.commit()
|
||||
|
||||
@@ -92,6 +92,7 @@ class TaskModel(Base):
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
model = Column(JSON, nullable=True)
|
||||
|
||||
|
||||
class StepModel(Base):
|
||||
@@ -593,6 +594,7 @@ class TaskV2Model(Base):
|
||||
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
model = Column(JSON, nullable=True)
|
||||
|
||||
|
||||
class ThoughtModel(Base):
|
||||
|
||||
@@ -91,6 +91,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p
|
||||
error_code_mapping=task_obj.error_code_mapping,
|
||||
errors=task_obj.errors,
|
||||
application=task_obj.application,
|
||||
model=task_obj.model,
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from fastapi import BackgroundTasks, Request
|
||||
|
||||
from skyvern.exceptions import OrganizationNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
@@ -105,6 +106,9 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
context.organization_id = organization_id
|
||||
context.max_steps_override = max_steps_override
|
||||
|
||||
llm_key = task.llm_key
|
||||
llm_caller = LLMCaller(llm_key) if llm_key else None
|
||||
|
||||
if background_tasks:
|
||||
background_tasks.add_task(
|
||||
app.agent.execute_step,
|
||||
@@ -115,6 +119,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
close_browser_on_completion=close_browser_on_completion,
|
||||
browser_session_id=browser_session_id,
|
||||
engine=engine,
|
||||
llm_caller=llm_caller,
|
||||
)
|
||||
|
||||
async def execute_workflow(
|
||||
|
||||
@@ -168,6 +168,7 @@ async def run_task(
|
||||
totp_verification_url=run_request.totp_url,
|
||||
totp_identifier=run_request.totp_identifier,
|
||||
include_action_history_in_verification=run_request.include_action_history_in_verification,
|
||||
model=run_request.model,
|
||||
)
|
||||
task_v1_response = await task_v1_service.run_task(
|
||||
task=task_v1_request,
|
||||
@@ -222,6 +223,7 @@ async def run_task(
|
||||
extracted_information_schema=run_request.data_extraction_schema,
|
||||
error_code_mapping=run_request.error_code_mapping,
|
||||
create_task_run=True,
|
||||
model=run_request.model,
|
||||
)
|
||||
except LLMProviderError:
|
||||
LOG.error("LLM failure to initialize task v2", exc_info=True)
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.schemas.runs import ProxyLocation
|
||||
from skyvern.utils.url_validators import validate_url
|
||||
|
||||
@@ -43,10 +44,28 @@ class TaskV2(BaseModel):
|
||||
webhook_callback_url: str | None = None
|
||||
extracted_information_schema: dict | list | str | None = None
|
||||
error_code_mapping: dict | None = None
|
||||
|
||||
model: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
|
||||
@property
|
||||
def llm_key(self) -> str | None:
|
||||
"""
|
||||
If the `TaskV2` has a `model` defined, then return the mapped llm_key for it.
|
||||
|
||||
Otherwise return `None`.
|
||||
"""
|
||||
|
||||
if self.model:
|
||||
model_name = self.model.get("model_name")
|
||||
if model_name:
|
||||
mapping = settings.get_model_name_to_llm_key()
|
||||
llm_key = mapping.get(model_name)
|
||||
if llm_key:
|
||||
return llm_key
|
||||
|
||||
return None
|
||||
|
||||
@field_validator("url", "webhook_callback_url", "totp_verification_url")
|
||||
@classmethod
|
||||
def validate_urls(cls, url: str | None) -> str | None:
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi import status
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import (
|
||||
InvalidTaskStatusTransition,
|
||||
SkyvernHTTPException,
|
||||
@@ -110,6 +111,7 @@ class TaskRequest(TaskBase):
|
||||
)
|
||||
totp_verification_url: str | None = None
|
||||
browser_session_id: str | None = None
|
||||
model: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_url(self) -> Self:
|
||||
@@ -236,6 +238,22 @@ class Task(TaskBase):
|
||||
retry: int | None = None
|
||||
max_steps_per_run: int | None = None
|
||||
errors: list[dict[str, Any]] = []
|
||||
model: dict[str, Any] | None = None
|
||||
|
||||
@property
|
||||
def llm_key(self) -> str | None:
|
||||
"""
|
||||
If the `Task` has a `model` defined, then return the mapped llm_key for it.
|
||||
|
||||
Otherwise return `None`.
|
||||
"""
|
||||
if self.model:
|
||||
model_name = self.model.get("model_name")
|
||||
if model_name:
|
||||
mapping = settings.get_model_name_to_llm_key()
|
||||
return mapping.get(model_name)
|
||||
|
||||
return None
|
||||
|
||||
def validate_update(
|
||||
self,
|
||||
|
||||
@@ -61,3 +61,7 @@ URL that serves TOTP/2FA/MFA codes for Skyvern to use during the workflow run. R
|
||||
BROWSER_SESSION_ID_DOC_STRING = """
|
||||
Run the task or workflow in the specific Skyvern browser session. Having a browser session can persist the real-time state of the browser, so that the next run can continue from where the previous run left off.
|
||||
"""
|
||||
|
||||
MODEL_CONFIG = """
|
||||
Optional model configuration.
|
||||
"""
|
||||
|
||||
@@ -21,6 +21,7 @@ from skyvern.schemas.docs.doc_strings import (
|
||||
DATA_EXTRACTION_SCHEMA_DOC_STRING,
|
||||
ERROR_CODE_MAPPING_DOC_STRING,
|
||||
MAX_STEPS_DOC_STRING,
|
||||
MODEL_CONFIG,
|
||||
PROXY_LOCATION_DOC_STRING,
|
||||
TASK_ENGINE_DOC_STRING,
|
||||
TASK_PROMPT_DOC_STRING,
|
||||
@@ -263,6 +264,12 @@ class TaskRunRequest(BaseModel):
|
||||
description=BROWSER_SESSION_ID_DOC_STRING,
|
||||
examples=BROWSER_SESSION_ID_EXAMPLES,
|
||||
)
|
||||
model: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=MODEL_CONFIG,
|
||||
examples=None,
|
||||
)
|
||||
|
||||
publish_workflow: bool = Field(
|
||||
default=False,
|
||||
description="Whether to publish this task as a reusable workflow. Only available for skyvern-2.0.",
|
||||
|
||||
@@ -163,6 +163,7 @@ async def initialize_task_v2(
|
||||
extracted_information_schema: dict | list | str | None = None,
|
||||
error_code_mapping: dict | None = None,
|
||||
create_task_run: bool = False,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> TaskV2:
|
||||
task_v2 = await app.DATABASE.create_task_v2(
|
||||
prompt=user_prompt,
|
||||
@@ -173,6 +174,7 @@ async def initialize_task_v2(
|
||||
proxy_location=proxy_location,
|
||||
extracted_information_schema=extracted_information_schema,
|
||||
error_code_mapping=error_code_mapping,
|
||||
model=model,
|
||||
)
|
||||
# set task_v2_id in context
|
||||
context = skyvern_context.current()
|
||||
@@ -620,6 +622,7 @@ async def run_task_v2_helper(
|
||||
screenshots=scraped_page.screenshots,
|
||||
thought=thought,
|
||||
prompt_name="task_v2",
|
||||
llm_key_override=task_v2.llm_key,
|
||||
)
|
||||
LOG.info(
|
||||
"Task v2 response",
|
||||
|
||||
@@ -3368,6 +3368,7 @@ async def extract_information_for_navigation_goal(
|
||||
step=step,
|
||||
screenshots=scraped_page.screenshots,
|
||||
prompt_name="extract-information",
|
||||
llm_key_override=task.llm_key,
|
||||
)
|
||||
|
||||
return ScrapeResult(
|
||||
|
||||
Reference in New Issue
Block a user