From bda119027eb45f7087ee88369f3e095b685d4daf Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sun, 8 Dec 2024 12:43:59 -0800 Subject: [PATCH] add observer cruise creation and completion (#1354) --- ...add_prompt_and_url_to_observer_cruises_.py | 33 ++++++ skyvern/exceptions.py | 2 +- skyvern/forge/sdk/core/validators.py | 13 +-- skyvern/forge/sdk/db/client.py | 103 ++++++++++++++++++ skyvern/forge/sdk/db/models.py | 2 + skyvern/forge/sdk/schemas/observers.py | 6 +- 6 files changed, 145 insertions(+), 14 deletions(-) create mode 100644 alembic/versions/2024_12_08_0532-dc2a8facf0d7_add_prompt_and_url_to_observer_cruises_.py diff --git a/alembic/versions/2024_12_08_0532-dc2a8facf0d7_add_prompt_and_url_to_observer_cruises_.py b/alembic/versions/2024_12_08_0532-dc2a8facf0d7_add_prompt_and_url_to_observer_cruises_.py new file mode 100644 index 00000000..781d72d5 --- /dev/null +++ b/alembic/versions/2024_12_08_0532-dc2a8facf0d7_add_prompt_and_url_to_observer_cruises_.py @@ -0,0 +1,33 @@ +"""add prompt and url to observer_cruises table + +Revision ID: dc2a8facf0d7 +Revises: 8069e38dc1b4 +Create Date: 2024-12-08 05:32:21.240122+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "dc2a8facf0d7" +down_revision: Union[str, None] = "8069e38dc1b4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("observer_cruises", sa.Column("prompt", sa.UnicodeText(), nullable=True)) + op.add_column("observer_cruises", sa.Column("url", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("observer_cruises", "url") + op.drop_column("observer_cruises", "prompt") + # ### end Alembic commands ### diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index 96e7f800..522209b7 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -504,7 +504,7 @@ class CachedActionPlanError(SkyvernException): super().__init__(message) -class InvalidUrl(SkyvernException): +class InvalidUrl(SkyvernHTTPException): def __init__(self, url: str) -> None: super().__init__(f"Invalid URL: {url}. Skyvern supports HTTP and HTTPS urls with max 2083 character length.") diff --git a/skyvern/forge/sdk/core/validators.py b/skyvern/forge/sdk/core/validators.py index 7da7c990..054455a3 100644 --- a/skyvern/forge/sdk/core/validators.py +++ b/skyvern/forge/sdk/core/validators.py @@ -1,7 +1,7 @@ import ipaddress from urllib.parse import urlparse -from pydantic import HttpUrl, ValidationError, parse_obj_as +from pydantic import HttpUrl, ValidationError from skyvern.config import settings from skyvern.exceptions import InvalidUrl @@ -27,17 +27,6 @@ def prepend_scheme_and_validate_url(url: str) -> str: return url -def validate_url(url: str) -> str: - try: - if url: - # Use parse_obj_as to validate the string as an HttpUrl - parse_obj_as(HttpUrl, url) - return url - except ValidationError: - # Handle the validation error - raise InvalidUrl(url=url) - - def is_blocked_host(host: str) -> bool: try: ip = ipaddress.ip_address(host) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 8008988b..bf7286ae 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -19,6 +19,8 @@ from skyvern.forge.sdk.db.models import ( BitwardenCreditCardDataParameterModel, BitwardenLoginCredentialParameterModel, BitwardenSensitiveInformationParameterModel, + ObserverCruiseModel, + ObserverThoughtModel, OrganizationAuthTokenModel, OrganizationModel, OutputParameterModel, @@ -50,6 +52,7 @@ from skyvern.forge.sdk.db.utils import ( convert_to_workflow_run_parameter, ) from skyvern.forge.sdk.models import Step, StepStatus +from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverCruiseStatus, ObserverThought from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken from skyvern.forge.sdk.schemas.task_generations import TaskGeneration from skyvern.forge.sdk.schemas.tasks import OrderBy, ProxyLocation, SortDirection, Task, TaskStatus @@ -1729,3 +1732,103 @@ class AgentDB: ) await session.execute(stmt) await session.commit() + + async def get_observer_cruise( + self, observer_cruise_id: str, organization_id: str | None = None + ) -> ObserverCruise | None: + async with self.Session() as session: + if observer_cruise := ( + await session.scalars( + select(ObserverCruiseModel) + .filter_by(observer_cruise_id=observer_cruise_id) + .filter_by(organization_id=organization_id) + ) + ).first(): + return ObserverCruise.model_validate(observer_cruise) + return None + + async def get_observer_thought( + self, observer_thought_id: str, organization_id: str | None = None + ) -> ObserverThought | None: + async with self.Session() as session: + if observer_thought := ( + await session.scalars( + select(ObserverThoughtModel) + .filter_by(observer_thought_id=observer_thought_id) + .filter_by(organization_id=organization_id) + ) + ).first(): + return ObserverThought.model_validate(observer_thought) + return None + + async def create_observer_cruise( + self, + workflow_run_id: str | None = None, + workflow_id: str | None = None, + prompt: str | None = None, + url: str | None = None, + organization_id: str | None = None, + ) -> ObserverCruise: + async with self.Session() as session: + new_observer_cruise = ObserverCruiseModel( + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + prompt=prompt, + url=url, + organization_id=organization_id, + ) + session.add(new_observer_cruise) + await session.commit() + await session.refresh(new_observer_cruise) + return ObserverCruise.model_validate(new_observer_cruise) + + async def create_observer_thought( + self, + observer_cruise_id: str, + workflow_run_id: str | None = None, + workflow_id: str | None = None, + workflow_run_block_id: str | None = None, + user_input: str | None = None, + observation: str | None = None, + thought: str | None = None, + answer: str | None = None, + organization_id: str | None = None, + ) -> ObserverThought: + async with self.Session() as session: + new_observer_thought = ObserverThoughtModel( + observer_cruise_id=observer_cruise_id, + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + workflow_run_block_id=workflow_run_block_id, + user_input=user_input, + observation=observation, + thought=thought, + answer=answer, + organization_id=organization_id, + ) + session.add(new_observer_thought) + await session.commit() + await session.refresh(new_observer_thought) + return ObserverThought.model_validate(new_observer_thought) + + async def update_observer_cruise( + self, + observer_cruise_id: str, + status: ObserverCruiseStatus | None = None, + organization_id: str | None = None, + ) -> ObserverCruise: + async with self.Session() as session: + observer_cruise = ( + await session.scalars( + select(ObserverCruiseModel) + .filter_by(observer_cruise_id=observer_cruise_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if observer_cruise: + if status: + observer_cruise.status = status + await session.commit() + await session.refresh(observer_cruise) + return ObserverCruise.model_validate(observer_cruise) + raise NotFoundError(f"ObserverCruise {observer_cruise_id} not found") diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 6f97c330..19e1ba01 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -515,6 +515,8 @@ class ObserverCruiseModel(Base): organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=True) workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), nullable=True) workflow_id = Column(String, ForeignKey("workflows.workflow_id"), nullable=True) + prompt = Column(UnicodeText, nullable=True) + url = Column(String, nullable=True) class ObserverThoughtModel(Base): diff --git a/skyvern/forge/sdk/schemas/observers.py b/skyvern/forge/sdk/schemas/observers.py index 1fa9f073..a8ada135 100644 --- a/skyvern/forge/sdk/schemas/observers.py +++ b/skyvern/forge/sdk/schemas/observers.py @@ -1,7 +1,7 @@ from datetime import datetime from enum import StrEnum -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, HttpUrl class ObserverCruiseStatus(StrEnum): @@ -23,12 +23,16 @@ class ObserverCruise(BaseModel): organization_id: str | None = None workflow_run_id: str | None = None workflow_id: str | None = None + prompt: str | None = None + url: HttpUrl | None = None created_at: datetime modified_at: datetime class ObserverThought(BaseModel): + model_config = ConfigDict(from_attributes=True) + observer_thought_id: str observer_cruise_id: str organization_id: str | None = None