From 5796de73d1b16b70406ae4995a7a9cc75da3e19d Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 8 Jan 2025 21:45:38 -0800 Subject: [PATCH] Add AI suggestion endpoints (#1519) --- ...1-d5640aa644b9_add_ai_suggestions_table.py | 47 +++++++++++++++++++ .../prompts/skyvern/suggest-data-schema.j2 | 35 ++++++++++++++ .../forge/sdk/api/llm/api_handler_factory.py | 14 ++++++ skyvern/forge/sdk/api/llm/models.py | 2 + skyvern/forge/sdk/artifact/manager.py | 36 ++++++++++++++ skyvern/forge/sdk/artifact/models.py | 1 + skyvern/forge/sdk/artifact/storage/base.py | 7 +++ skyvern/forge/sdk/artifact/storage/local.py | 7 +++ skyvern/forge/sdk/artifact/storage/s3.py | 7 +++ skyvern/forge/sdk/db/client.py | 19 ++++++++ skyvern/forge/sdk/db/id.py | 6 +++ skyvern/forge/sdk/db/models.py | 12 +++++ skyvern/forge/sdk/routes/agent_protocol.py | 33 +++++++++++++ skyvern/forge/sdk/schemas/ai_suggestions.py | 22 +++++++++ 14 files changed, 248 insertions(+) create mode 100644 alembic/versions/2025_01_09_0541-d5640aa644b9_add_ai_suggestions_table.py create mode 100644 skyvern/forge/prompts/skyvern/suggest-data-schema.j2 create mode 100644 skyvern/forge/sdk/schemas/ai_suggestions.py diff --git a/alembic/versions/2025_01_09_0541-d5640aa644b9_add_ai_suggestions_table.py b/alembic/versions/2025_01_09_0541-d5640aa644b9_add_ai_suggestions_table.py new file mode 100644 index 00000000..2a342fa2 --- /dev/null +++ b/alembic/versions/2025_01_09_0541-d5640aa644b9_add_ai_suggestions_table.py @@ -0,0 +1,47 @@ +"""add ai_suggestions table + +Revision ID: d5640aa644b9 +Revises: d47a586d7036 +Create Date: 2025-01-09 05:41:43.872901+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d5640aa644b9" +down_revision: Union[str, None] = "d47a586d7036" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "ai_suggestions", + sa.Column("ai_suggestion_id", sa.String(), nullable=False), + sa.Column("organization_id", sa.String(), nullable=True), + sa.Column("ai_suggestion_type", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("modified_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.organization_id"], + ), + sa.PrimaryKeyConstraint("ai_suggestion_id"), + ) + op.add_column("artifacts", sa.Column("ai_suggestion_id", sa.String(), nullable=True)) + op.create_index(op.f("ix_artifacts_ai_suggestion_id"), "artifacts", ["ai_suggestion_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_artifacts_ai_suggestion_id"), table_name="artifacts") + op.drop_column("artifacts", "ai_suggestion_id") + op.drop_table("ai_suggestions") + # ### end Alembic commands ### diff --git a/skyvern/forge/prompts/skyvern/suggest-data-schema.j2 b/skyvern/forge/prompts/skyvern/suggest-data-schema.j2 new file mode 100644 index 00000000..b8ea050b --- /dev/null +++ b/skyvern/forge/prompts/skyvern/suggest-data-schema.j2 @@ -0,0 +1,35 @@ +You are given an input string from a user. This string is a data extraction goal for an AI agent. It tells the agent what to do on a web page. + +A data extraction goal describes what data to extract from the page. + +Your goal when given an input data extraction goal is to provide a JSONC schema describing a shape for the data to be extracted. + +Good data schema examples: + +Input data extraction goal: "Extract the title and link of the top post on Hacker News." +Suggested Data Schema: +```json +{ + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "The title of the top post on Hacker News." + }, + "link": { + "type": "string", + "format": "uri", + "description": "The URL link to the top post on Hacker News." + } + }, + "required": [ + "title", + "link" + ] +} +``` + +Respond only with JSON output containing a single key "output" with the value of the suggested data schema given the following input data extraction goal: +``` +{{ input }} +``` diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index 4750365d..b0b08026 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -22,6 +22,7 @@ from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_resp from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.models import Step +from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought LOG = structlog.get_logger() @@ -63,6 +64,7 @@ class LLMAPIHandlerFactory: step: Step | None = None, observer_cruise: ObserverCruise | None = None, observer_thought: ObserverThought | None = None, + ai_suggestion: AISuggestion | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, ) -> dict[str, Any]: @@ -89,6 +91,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) await app.ARTIFACT_MANAGER.create_llm_artifact( @@ -113,6 +116,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) try: response = await router.acompletion(model=main_model_group, messages=messages, **parameters) @@ -140,6 +144,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) if step: llm_cost = litellm.completion_cost(completion_response=response) @@ -160,6 +165,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) if context and len(context.hashed_href_map) > 0: @@ -172,6 +178,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) return parsed_response @@ -192,6 +199,7 @@ class LLMAPIHandlerFactory: step: Step | None = None, observer_cruise: ObserverCruise | None = None, observer_thought: ObserverThought | None = None, + ai_suggestion: AISuggestion | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, ) -> dict[str, Any]: @@ -211,6 +219,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) await app.ARTIFACT_MANAGER.create_llm_artifact( @@ -220,6 +229,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) if not llm_config.supports_vision: @@ -239,6 +249,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) t_llm_request = time.perf_counter() try: @@ -274,6 +285,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) if step: @@ -295,6 +307,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) if context and len(context.hashed_href_map) > 0: @@ -307,6 +320,7 @@ class LLMAPIHandlerFactory: step=step, observer_cruise=observer_cruise, observer_thought=observer_thought, + ai_suggestion=ai_suggestion, ) return parsed_response diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py index c913cc0c..5595e240 100644 --- a/skyvern/forge/sdk/api/llm/models.py +++ b/skyvern/forge/sdk/api/llm/models.py @@ -4,6 +4,7 @@ from typing import Any, Awaitable, Literal, Optional, Protocol, TypedDict from litellm import AllowedFailsPolicy from skyvern.forge.sdk.models import Step +from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought from skyvern.forge.sdk.settings_manager import SettingsManager @@ -81,6 +82,7 @@ class LLMAPIHandler(Protocol): step: Step | None = None, observer_cruise: ObserverCruise | None = None, observer_thought: ObserverThought | None = None, + ai_suggestion: AISuggestion | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, ) -> Awaitable[dict[str, Any]]: ... diff --git a/skyvern/forge/sdk/artifact/manager.py b/skyvern/forge/sdk/artifact/manager.py index 6dc01ad9..d8cca3b6 100644 --- a/skyvern/forge/sdk/artifact/manager.py +++ b/skyvern/forge/sdk/artifact/manager.py @@ -8,6 +8,7 @@ from skyvern.forge import app from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType from skyvern.forge.sdk.db.id import generate_artifact_id from skyvern.forge.sdk.models import Step +from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock @@ -30,6 +31,7 @@ class ArtifactManager: workflow_run_block_id: str | None = None, observer_thought_id: str | None = None, observer_cruise_id: str | None = None, + ai_suggestion_id: str | None = None, organization_id: str | None = None, data: bytes | None = None, path: str | None = None, @@ -49,6 +51,7 @@ class ArtifactManager: observer_thought_id=observer_thought_id, observer_cruise_id=observer_cruise_id, organization_id=organization_id, + ai_suggestion_id=ai_suggestion_id, ) if data: # Fire and forget @@ -173,6 +176,26 @@ class ArtifactManager: path=path, ) + async def create_ai_suggestion_artifact( + self, + ai_suggestion: AISuggestion, + artifact_type: ArtifactType, + data: bytes | None = None, + path: str | None = None, + ) -> str: + artifact_id = generate_artifact_id() + uri = app.STORAGE.build_ai_suggestion_uri(artifact_id, ai_suggestion, artifact_type) + return await self._create_artifact( + aio_task_primary_key=ai_suggestion.ai_suggestion_id, + artifact_id=artifact_id, + artifact_type=artifact_type, + uri=uri, + ai_suggestion_id=ai_suggestion.ai_suggestion_id, + organization_id=ai_suggestion.organization_id, + data=data, + path=path, + ) + async def create_llm_artifact( self, data: bytes, @@ -181,6 +204,7 @@ class ArtifactManager: step: Step | None = None, observer_thought: ObserverThought | None = None, observer_cruise: ObserverCruise | None = None, + ai_suggestion: AISuggestion | None = None, ) -> None: if step: await self.create_artifact( @@ -218,6 +242,18 @@ class ArtifactManager: artifact_type=ArtifactType.SCREENSHOT_LLM, data=screenshot, ) + elif ai_suggestion: + await self.create_ai_suggestion_artifact( + ai_suggestion=ai_suggestion, + artifact_type=artifact_type, + data=data, + ) + for screenshot in screenshots or []: + await self.create_ai_suggestion_artifact( + ai_suggestion=ai_suggestion, + artifact_type=ArtifactType.SCREENSHOT_LLM, + data=screenshot, + ) async def update_artifact_data( self, diff --git a/skyvern/forge/sdk/artifact/models.py b/skyvern/forge/sdk/artifact/models.py index 6cbe8c6e..5bdd6f87 100644 --- a/skyvern/forge/sdk/artifact/models.py +++ b/skyvern/forge/sdk/artifact/models.py @@ -75,6 +75,7 @@ class Artifact(BaseModel): workflow_run_block_id: str | None = None observer_cruise_id: str | None = None observer_thought_id: str | None = None + ai_suggestion_id: str | None = None signed_url: str | None = None organization_id: str | None = None diff --git a/skyvern/forge/sdk/artifact/storage/base.py b/skyvern/forge/sdk/artifact/storage/base.py index e00a8cfe..cbc20379 100644 --- a/skyvern/forge/sdk/artifact/storage/base.py +++ b/skyvern/forge/sdk/artifact/storage/base.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType from skyvern.forge.sdk.models import Step +from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock @@ -61,6 +62,12 @@ class BaseStorage(ABC): ) -> str: pass + @abstractmethod + def build_ai_suggestion_uri( + self, artifact_id: str, ai_suggestion: AISuggestion, artifact_type: ArtifactType + ) -> str: + pass + @abstractmethod async def store_artifact(self, artifact: Artifact, data: bytes) -> None: pass diff --git a/skyvern/forge/sdk/artifact/storage/local.py b/skyvern/forge/sdk/artifact/storage/local.py index a12e966b..e1f2215a 100644 --- a/skyvern/forge/sdk/artifact/storage/local.py +++ b/skyvern/forge/sdk/artifact/storage/local.py @@ -11,6 +11,7 @@ from skyvern.forge.sdk.api.files import get_download_dir, get_skyvern_temp_dir from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage from skyvern.forge.sdk.models import Step +from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock @@ -47,6 +48,12 @@ class LocalStorage(BaseStorage): file_ext = FILE_EXTENTSION_MAP[artifact_type] return f"file://{self.artifact_path}/{settings.ENV}/workflow_runs/{workflow_run_block.workflow_run_id}/{workflow_run_block.workflow_run_block_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + def build_ai_suggestion_uri( + self, artifact_id: str, ai_suggestion: AISuggestion, artifact_type: ArtifactType + ) -> str: + file_ext = FILE_EXTENTSION_MAP[artifact_type] + return f"file://{self.artifact_path}/{settings.ENV}/ai_suggestions/{ai_suggestion.ai_suggestion_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + async def store_artifact(self, artifact: Artifact, data: bytes) -> None: file_path = None try: diff --git a/skyvern/forge/sdk/artifact/storage/s3.py b/skyvern/forge/sdk/artifact/storage/s3.py index e5cb5f6e..e44cff66 100644 --- a/skyvern/forge/sdk/artifact/storage/s3.py +++ b/skyvern/forge/sdk/artifact/storage/s3.py @@ -15,6 +15,7 @@ from skyvern.forge.sdk.api.files import ( from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage from skyvern.forge.sdk.models import Step +from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock @@ -50,6 +51,12 @@ class S3Storage(BaseStorage): file_ext = FILE_EXTENTSION_MAP[artifact_type] return f"s3://{self.bucket}/{settings.ENV}/workflow_runs/{workflow_run_block.workflow_run_id}/{workflow_run_block.workflow_run_block_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + def build_ai_suggestion_uri( + self, artifact_id: str, ai_suggestion: AISuggestion, artifact_type: ArtifactType + ) -> str: + file_ext = FILE_EXTENTSION_MAP[artifact_type] + return f"s3://{self.bucket}/{settings.ENV}/ai_suggestions/{ai_suggestion.ai_suggestion_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + async def store_artifact(self, artifact: Artifact, data: bytes) -> None: await self.async_client.upload_file(artifact.uri, data) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index d1468a7c..adbc4984 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -14,6 +14,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType from skyvern.forge.sdk.db.exceptions import NotFoundError from skyvern.forge.sdk.db.models import ( ActionModel, + AISuggestionModel, ArtifactModel, AWSSecretParameterModel, BitwardenCreditCardDataParameterModel, @@ -56,6 +57,7 @@ from skyvern.forge.sdk.db.utils import ( ) from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs from skyvern.forge.sdk.models import Step, StepStatus +from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.observers import ( ObserverCruise, ObserverCruiseStatus, @@ -206,6 +208,7 @@ class AgentDB: workflow_run_block_id: str | None = None, observer_cruise_id: str | None = None, observer_thought_id: str | None = None, + ai_suggestion_id: str | None = None, organization_id: str | None = None, ) -> Artifact: try: @@ -220,6 +223,7 @@ class AgentDB: workflow_run_block_id=workflow_run_block_id, observer_cruise_id=observer_cruise_id, observer_thought_id=observer_thought_id, + ai_suggestion_id=ai_suggestion_id, organization_id=organization_id, ) session.add(new_artifact) @@ -1789,6 +1793,21 @@ class AgentDB: await session.refresh(new_task_generation) return TaskGeneration.model_validate(new_task_generation) + async def create_ai_suggestion( + self, + organization_id: str, + ai_suggestion_type: str, + ) -> AISuggestion: + async with self.Session() as session: + new_ai_suggestion = AISuggestionModel( + organization_id=organization_id, + ai_suggestion_type=ai_suggestion_type, + ) + session.add(new_ai_suggestion) + await session.commit() + await session.refresh(new_ai_suggestion) + return AISuggestion.model_validate(new_ai_suggestion) + async def get_task_generation_by_prompt_hash( self, user_prompt_hash: str, diff --git a/skyvern/forge/sdk/db/id.py b/skyvern/forge/sdk/db/id.py index cd3af976..6a25ceb8 100644 --- a/skyvern/forge/sdk/db/id.py +++ b/skyvern/forge/sdk/db/id.py @@ -44,6 +44,7 @@ BITWARDEN_LOGIN_CREDENTIAL_PARAMETER_PREFIX = "blc" BITWARDEN_SENSITIVE_INFORMATION_PARAMETER_PREFIX = "bsi" BITWARDEN_CREDIT_CARD_DATA_PARAMETER_PREFIX = "bccd" TASK_GENERATION_PREFIX = "tg" +AI_SUGGESTION_PREFIX = "as" OBSERVER_CRUISE_ID = "oc" OBSERVER_THOUGHT_ID = "ot" PERSISTENT_BROWSER_SESSION_ID = "pbs" @@ -134,6 +135,11 @@ def generate_task_generation_id() -> str: return f"{TASK_GENERATION_PREFIX}_{int_id}" +def generate_ai_suggestion_id() -> str: + int_id = generate_id() + return f"{AI_SUGGESTION_PREFIX}_{int_id}" + + def generate_totp_code_id() -> str: int_id = generate_id() return f"totp_{int_id}" diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 1c501042..7ad5b2c7 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -20,6 +20,7 @@ from sqlalchemy.orm import DeclarativeBase from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType from skyvern.forge.sdk.db.id import ( generate_action_id, + generate_ai_suggestion_id, generate_artifact_id, generate_aws_secret_parameter_id, generate_bitwarden_credit_card_data_parameter_id, @@ -169,6 +170,7 @@ class ArtifactModel(Base): workflow_run_block_id = Column(String, index=True) observer_cruise_id = Column(String, index=True) observer_thought_id = Column(String, index=True) + ai_suggestion_id = Column(String, index=True) task_id = Column(String, ForeignKey("tasks.task_id")) step_id = Column(String, ForeignKey("steps.step_id"), index=True) artifact_type = Column(String) @@ -441,6 +443,16 @@ class TaskGenerationModel(Base): modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) +class AISuggestionModel(Base): + __tablename__ = "ai_suggestions" + + ai_suggestion_id = Column(String, primary_key=True, default=generate_ai_suggestion_id) + organization_id = Column(String, ForeignKey("organizations.organization_id")) + ai_suggestion_type = Column(String) + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) + modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) + + class TOTPCodeModel(Base): __tablename__ = "totp_codes" diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 644fddc3..68b7e3f4 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -36,6 +36,7 @@ from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory from skyvern.forge.sdk.models import Step +from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestionBase, AISuggestionRequest from skyvern.forge.sdk.schemas.observers import CruiseRequest, ObserverCruise from skyvern.forge.sdk.schemas.organizations import ( GetOrganizationAPIKeysResponse, @@ -938,6 +939,38 @@ async def get_workflow( ) +class AISuggestionType(str, Enum): + DATA_SCHEMA = "data_schema" + + +@base_router.post("/suggest/{ai_suggestion_type}", include_in_schema=False) +@base_router.post("/suggest/{ai_suggestion_type}/") +async def make_ai_suggestion( + ai_suggestion_type: AISuggestionType, + data: AISuggestionRequest, + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> AISuggestionBase: + llm_prompt = "" + + if ai_suggestion_type == AISuggestionType.DATA_SCHEMA: + llm_prompt = prompt_engine.load_prompt("suggest-data-schema", input=data.input) + + try: + new_ai_suggestion = await app.DATABASE.create_ai_suggestion( + organization_id=current_org.organization_id, + ai_suggestion_type=ai_suggestion_type, + ) + + llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, ai_suggestion=new_ai_suggestion) + parsed_ai_suggestion = AISuggestionBase.model_validate(llm_response) + + return parsed_ai_suggestion + + except LLMProviderError: + LOG.error("Failed to suggest data schema", exc_info=True) + raise HTTPException(status_code=400, detail="Failed to suggest data schema. Please try again later.") + + @base_router.post("/generate/task", include_in_schema=False) @base_router.post("/generate/task/") async def generate_task( diff --git a/skyvern/forge/sdk/schemas/ai_suggestions.py b/skyvern/forge/sdk/schemas/ai_suggestions.py new file mode 100644 index 00000000..00608a94 --- /dev/null +++ b/skyvern/forge/sdk/schemas/ai_suggestions.py @@ -0,0 +1,22 @@ +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class AISuggestionBase(BaseModel): + output: dict[str, Any] | str | None = None + + +class AISuggestion(AISuggestionBase): + model_config = ConfigDict(from_attributes=True) + ai_suggestion_type: str + ai_suggestion_id: str + organization_id: str | None = None + + created_at: datetime + modified_at: datetime + + +class AISuggestionRequest(BaseModel): + input: str = Field(..., min_length=1)