diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d76b242b..80602648 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -105,8 +105,7 @@ jobs: AWS_REGION: "us-east-1" ENABLE_BEDROCK: "true" - - name: Run the alembic-check pre-commit hook - uses: pre-commit/action@v3.0.0 + - name: Run alembic check env: ENABLE_OPENAI: "true" OPENAI_API_KEY: "sk-dummy" @@ -117,8 +116,7 @@ jobs: AZURE_GPT4O_MINI_API_VERSION: "dummy" AWS_REGION: "us-east-1" ENABLE_BEDROCK: "true" - with: - args: "run --hook-stage manual alembic-check" + run: poetry run ./run_alembic_check.sh - name: trigger tests env: ENABLE_OPENAI: "true" diff --git a/alembic/versions/2025_05_31_0300-babaa7307e8a_add_model_column_to_task_v1_v2_tables.py b/alembic/versions/2025_05_31_0300-babaa7307e8a_add_model_column_to_task_v1_v2_tables.py new file mode 100644 index 00000000..c1464e6f --- /dev/null +++ b/alembic/versions/2025_05_31_0300-babaa7307e8a_add_model_column_to_task_v1_v2_tables.py @@ -0,0 +1,33 @@ +"""add model column to task v1 & v2 tables + +Revision ID: babaa7307e8a +Revises: af49ca791fc7 +Create Date: 2025-05-31 03:00:17.128919+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "babaa7307e8a" +down_revision: Union[str, None] = "af49ca791fc7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("observer_cruises", sa.Column("model", sa.JSON(), nullable=True)) + op.add_column("tasks", sa.Column("model", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("tasks", "model") + op.drop_column("observer_cruises", "model") + # ### end Alembic commands ### diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 231998a6..a5d0876f 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -229,6 +229,7 @@ class ForgeAgent: error_code_mapping=task_request.error_code_mapping, application=task_request.application, include_action_history_in_verification=task_request.include_action_history_in_verification, + model=task_request.model, ) LOG.info( "Created new task", diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index abc749fa..5b693f6a 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -140,6 +140,7 @@ class AgentDB: task_type: str = TaskType.general, application: str | None = None, include_action_history_in_verification: bool | None = None, + model: dict[str, Any] | None = None, ) -> Task: try: async with self.Session() as session: @@ -166,6 +167,7 @@ class AgentDB: error_code_mapping=error_code_mapping, application=application, include_action_history_in_verification=include_action_history_in_verification, + model=model, ) session.add(new_task) await session.commit() @@ -2363,6 +2365,7 @@ class AgentDB: webhook_callback_url: str | None = None, extracted_information_schema: dict | list | str | None = None, error_code_mapping: dict | None = None, + model: dict[str, Any] | None = None, ) -> TaskV2: async with self.Session() as session: new_task_v2 = TaskV2Model( @@ -2378,6 +2381,7 @@ class AgentDB: extracted_information_schema=extracted_information_schema, error_code_mapping=error_code_mapping, organization_id=organization_id, + model=model, ) session.add(new_task_v2) await session.commit() diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 9e8a2747..e4d09f0c 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -92,6 +92,7 @@ class TaskModel(Base): nullable=False, index=True, ) + model = Column(JSON, nullable=True) class StepModel(Base): @@ -593,6 +594,7 @@ class TaskV2Model(Base): created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) + model = Column(JSON, nullable=True) class ThoughtModel(Base): diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 605cb428..4e5d654f 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -91,6 +91,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p error_code_mapping=task_obj.error_code_mapping, errors=task_obj.errors, application=task_obj.application, + model=task_obj.model, ) return task diff --git a/skyvern/forge/sdk/executor/async_executor.py b/skyvern/forge/sdk/executor/async_executor.py index 2d1b0f30..2adb64e2 100644 --- a/skyvern/forge/sdk/executor/async_executor.py +++ b/skyvern/forge/sdk/executor/async_executor.py @@ -5,6 +5,7 @@ from fastapi import BackgroundTasks, Request from skyvern.exceptions import OrganizationNotFound from skyvern.forge import app +from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.schemas.organizations import Organization @@ -105,6 +106,9 @@ class BackgroundTaskExecutor(AsyncExecutor): context.organization_id = organization_id context.max_steps_override = max_steps_override + llm_key = task.llm_key + llm_caller = LLMCaller(llm_key) if llm_key else None + if background_tasks: background_tasks.add_task( app.agent.execute_step, @@ -115,6 +119,7 @@ class BackgroundTaskExecutor(AsyncExecutor): close_browser_on_completion=close_browser_on_completion, browser_session_id=browser_session_id, engine=engine, + llm_caller=llm_caller, ) async def execute_workflow( diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index caa7a62f..ff5e9826 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -168,6 +168,7 @@ async def run_task( totp_verification_url=run_request.totp_url, totp_identifier=run_request.totp_identifier, include_action_history_in_verification=run_request.include_action_history_in_verification, + model=run_request.model, ) task_v1_response = await task_v1_service.run_task( task=task_v1_request, @@ -222,6 +223,7 @@ async def run_task( extracted_information_schema=run_request.data_extraction_schema, error_code_mapping=run_request.error_code_mapping, create_task_run=True, + model=run_request.model, ) except LLMProviderError: LOG.error("LLM failure to initialize task v2", exc_info=True) diff --git a/skyvern/forge/sdk/schemas/task_v2.py b/skyvern/forge/sdk/schemas/task_v2.py index a45387fb..57c8c516 100644 --- a/skyvern/forge/sdk/schemas/task_v2.py +++ b/skyvern/forge/sdk/schemas/task_v2.py @@ -4,6 +4,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field, field_validator +from skyvern.config import settings from skyvern.schemas.runs import ProxyLocation from skyvern.utils.url_validators import validate_url @@ -43,10 +44,28 @@ class TaskV2(BaseModel): webhook_callback_url: str | None = None extracted_information_schema: dict | list | str | None = None error_code_mapping: dict | None = None - + model: dict[str, Any] | None = None created_at: datetime modified_at: datetime + @property + def llm_key(self) -> str | None: + """ + If the `TaskV2` has a `model` defined, then return the mapped llm_key for it. + + Otherwise return `None`. + """ + + if self.model: + model_name = self.model.get("model_name") + if model_name: + mapping = settings.get_model_name_to_llm_key() + llm_key = mapping.get(model_name) + if llm_key: + return llm_key + + return None + @field_validator("url", "webhook_callback_url", "totp_verification_url") @classmethod def validate_urls(cls, url: str | None) -> str | None: diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index 4cdefc23..47647984 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -8,6 +8,7 @@ from fastapi import status from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self +from skyvern.config import settings from skyvern.exceptions import ( InvalidTaskStatusTransition, SkyvernHTTPException, @@ -110,6 +111,7 @@ class TaskRequest(TaskBase): ) totp_verification_url: str | None = None browser_session_id: str | None = None + model: dict[str, Any] | None = None @model_validator(mode="after") def validate_url(self) -> Self: @@ -236,6 +238,22 @@ class Task(TaskBase): retry: int | None = None max_steps_per_run: int | None = None errors: list[dict[str, Any]] = [] + model: dict[str, Any] | None = None + + @property + def llm_key(self) -> str | None: + """ + If the `Task` has a `model` defined, then return the mapped llm_key for it. + + Otherwise return `None`. + """ + if self.model: + model_name = self.model.get("model_name") + if model_name: + mapping = settings.get_model_name_to_llm_key() + return mapping.get(model_name) + + return None def validate_update( self, diff --git a/skyvern/schemas/docs/doc_strings.py b/skyvern/schemas/docs/doc_strings.py index 64366c39..9a740653 100644 --- a/skyvern/schemas/docs/doc_strings.py +++ b/skyvern/schemas/docs/doc_strings.py @@ -61,3 +61,7 @@ URL that serves TOTP/2FA/MFA codes for Skyvern to use during the workflow run. R BROWSER_SESSION_ID_DOC_STRING = """ Run the task or workflow in the specific Skyvern browser session. Having a browser session can persist the real-time state of the browser, so that the next run can continue from where the previous run left off. """ + +MODEL_CONFIG = """ +Optional model configuration. +""" diff --git a/skyvern/schemas/runs.py b/skyvern/schemas/runs.py index 987870d2..7e283ea7 100644 --- a/skyvern/schemas/runs.py +++ b/skyvern/schemas/runs.py @@ -21,6 +21,7 @@ from skyvern.schemas.docs.doc_strings import ( DATA_EXTRACTION_SCHEMA_DOC_STRING, ERROR_CODE_MAPPING_DOC_STRING, MAX_STEPS_DOC_STRING, + MODEL_CONFIG, PROXY_LOCATION_DOC_STRING, TASK_ENGINE_DOC_STRING, TASK_PROMPT_DOC_STRING, @@ -263,6 +264,12 @@ class TaskRunRequest(BaseModel): description=BROWSER_SESSION_ID_DOC_STRING, examples=BROWSER_SESSION_ID_EXAMPLES, ) + model: dict[str, Any] | None = Field( + default=None, + description=MODEL_CONFIG, + examples=None, + ) + publish_workflow: bool = Field( default=False, description="Whether to publish this task as a reusable workflow. Only available for skyvern-2.0.", diff --git a/skyvern/services/task_v2_service.py b/skyvern/services/task_v2_service.py index 06484f18..1947d240 100644 --- a/skyvern/services/task_v2_service.py +++ b/skyvern/services/task_v2_service.py @@ -163,6 +163,7 @@ async def initialize_task_v2( extracted_information_schema: dict | list | str | None = None, error_code_mapping: dict | None = None, create_task_run: bool = False, + model: dict[str, Any] | None = None, ) -> TaskV2: task_v2 = await app.DATABASE.create_task_v2( prompt=user_prompt, @@ -173,6 +174,7 @@ async def initialize_task_v2( proxy_location=proxy_location, extracted_information_schema=extracted_information_schema, error_code_mapping=error_code_mapping, + model=model, ) # set task_v2_id in context context = skyvern_context.current() @@ -620,6 +622,7 @@ async def run_task_v2_helper( screenshots=scraped_page.screenshots, thought=thought, prompt_name="task_v2", + llm_key_override=task_v2.llm_key, ) LOG.info( "Task v2 response", diff --git a/skyvern/webeye/actions/handler.py b/skyvern/webeye/actions/handler.py index 300016b9..13e673d7 100644 --- a/skyvern/webeye/actions/handler.py +++ b/skyvern/webeye/actions/handler.py @@ -3368,6 +3368,7 @@ async def extract_information_for_navigation_goal( step=step, screenshots=scraped_page.screenshots, prompt_name="extract-information", + llm_key_override=task.llm_key, ) return ScrapeResult(