diff --git a/alembic/versions/2024_06_07_2257-312d305c6b18_add_task_generations_table.py b/alembic/versions/2024_06_07_2257-312d305c6b18_add_task_generations_table.py new file mode 100644 index 00000000..b00d08dd --- /dev/null +++ b/alembic/versions/2024_06_07_2257-312d305c6b18_add_task_generations_table.py @@ -0,0 +1,53 @@ +"""add task_generations table + +Revision ID: 312d305c6b18 +Revises: 04bf06540db6 +Create Date: 2024-06-07 22:57:18.228793+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "312d305c6b18" +down_revision: Union[str, None] = "04bf06540db6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "task_generations", + sa.Column("task_generation_id", sa.String(), nullable=False), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("user_prompt", sa.String(), nullable=False), + sa.Column("url", sa.String(), nullable=True), + sa.Column("navigation_goal", sa.String(), nullable=True), + sa.Column("navigation_payload", sa.JSON(), nullable=True), + sa.Column("data_extraction_goal", sa.String(), nullable=True), + sa.Column("extracted_information_schema", sa.JSON(), nullable=True), + sa.Column("llm", sa.String(), nullable=True), + sa.Column("llm_prompt", sa.String(), nullable=True), + sa.Column("llm_response", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("modified_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.organization_id"], + ), + sa.PrimaryKeyConstraint("task_generation_id"), + ) + op.create_index(op.f("ix_task_generations_user_prompt"), "task_generations", ["user_prompt"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_task_generations_user_prompt"), table_name="task_generations") + op.drop_table("task_generations") + # ### end Alembic commands ### diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 75a6847a..0a9e0df8 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -210,6 +210,7 @@ class ForgeAgent: ) # Check some conditions before executing the step, throw an exception if the step can't be executed await app.AGENT_FUNCTION.validate_step_execution(task, step) + ( step, browser_state, @@ -337,27 +338,31 @@ class ForgeAgent: ) raise except StepTerminationError as e: - LOG.error( + LOG.warning( "Step cannot be executed. Task failed.", task_id=task.task_id, step_id=step.step_id, + exc_info=True, ) await self.update_step( step=step, status=StepStatus.failed, + force_update=True, ) task = await self.update_task( task, status=TaskStatus.failed, failure_reason=e.message, + force_update=True, ) await self.send_task_response( task=task, last_step=step, api_key=api_key, close_browser_on_completion=close_browser_on_completion, + skip_cleanup=True, ) - return step, detailed_output, next_step + return step, detailed_output, None except FailedToSendWebhook: LOG.exception( "Failed to send webhook", @@ -939,6 +944,7 @@ class ForgeAgent: api_key: str | None = None, close_browser_on_completion: bool = True, skip_artifacts: bool = False, + skip_cleanup: bool = False, ) -> None: """ send the task response to the webhook callback url @@ -957,6 +963,10 @@ class ForgeAgent: ) raise TaskNotFound(task_id=task.task_id) from e task = refreshed_task + if skip_cleanup: + await self.execute_task_webhook(task=task, last_step=last_step, api_key=api_key) + return + # log the task status as an event analytics.capture("skyvern-oss-agent-task-status", {"status": task.status}) # We skip the artifacts and send the webhook response directly only when there is an issue with the browser @@ -1165,8 +1175,10 @@ class ForgeAgent: output: AgentStepOutput | None = None, is_last: bool | None = None, retry_index: int | None = None, + force_update: bool = False, ) -> Step: - step.validate_update(status, output, is_last) + if not force_update: + step.validate_update(status, output, is_last) updates: dict[str, Any] = {} if status is not None: updates["status"] = status @@ -1200,8 +1212,10 @@ class ForgeAgent: status: TaskStatus, extracted_information: dict[str, Any] | list | str | None = None, failure_reason: str | None = None, + force_update: bool = False, ) -> Task: - task.validate_update(status, extracted_information, failure_reason) + if not force_update: + task.validate_update(status, extracted_information, failure_reason) updates: dict[str, Any] = {} if status is not None: updates["status"] = status diff --git a/skyvern/forge/prompts/skyvern/generate-task.j2 b/skyvern/forge/prompts/skyvern/generate-task.j2 new file mode 100644 index 00000000..b6f8ca94 --- /dev/null +++ b/skyvern/forge/prompts/skyvern/generate-task.j2 @@ -0,0 +1,16 @@ +We are building an AI agent that can automate browser tasks. The task creation schema is a JSON object with the following fields: + +url: required field, the starting URL for the task. This will be the first page the agent visits in order to achieve its goals. +navigation_goal: optional. The value should be a string that we can use as an input to a Large Language Modal. It needs to tell the agent the goal in terms of navigating the website. It needs to define a single goal. You can include explicit completion and failure criteria. You can define guardrails that could help the agent from taking certain actions or getting derailed. +data_extraction_goal: optional. The value should be a string that we can use as an input to a Large Language Modal. It needs to tell the agent the goal in terms of extracting data. It needs to be a single goal. +navigation_payload: optional. The value should be JSON. Use this field if there is any information for the agent to be able to complete the task such as values that can help fill a form, parameters for queries and so on. +extracted_information_schema: optional. The exact schema of the data to be extracted. + +At least one of navigation goal or data extraction goal should be provided. The agent can't proceed without any goals. + +If a field is not required to achieve a task, provide the value `null`. + +Respond with only JSON output that follows the task creation schema for the following prompt: +``` +{{ user_prompt }} +``` diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index ad42ca61..98f96950 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -18,6 +18,7 @@ from skyvern.forge.sdk.db.models import ( OrganizationModel, OutputParameterModel, StepModel, + TaskGenerationModel, TaskModel, WorkflowModel, WorkflowParameterModel, @@ -42,6 +43,7 @@ from skyvern.forge.sdk.db.utils import ( convert_to_workflow_run_parameter, ) from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus +from skyvern.forge.sdk.schemas.task_generations import TaskGeneration from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus from skyvern.forge.sdk.workflow.models.parameter import ( AWSSecretParameter, @@ -1236,3 +1238,34 @@ class AgentDB: ) await session.execute(stmt) await session.commit() + + async def create_task_generation( + self, + organization_id: str, + user_prompt: str, + url: str | None = None, + navigation_goal: str | None = None, + navigation_payload: dict[str, Any] | None = None, + data_extraction_goal: str | None = None, + extracted_information_schema: dict[str, Any] | None = None, + llm: str | None = None, + llm_prompt: str | None = None, + llm_response: str | None = None, + ) -> TaskGeneration: + async with self.Session() as session: + new_task_generation = TaskGenerationModel( + organization_id=organization_id, + user_prompt=user_prompt, + url=url, + navigation_goal=navigation_goal, + navigation_payload=navigation_payload, + data_extraction_goal=data_extraction_goal, + extracted_information_schema=extracted_information_schema, + llm=llm, + llm_prompt=llm_prompt, + llm_response=llm_response, + ) + session.add(new_task_generation) + await session.commit() + await session.refresh(new_task_generation) + return TaskGeneration.model_validate(new_task_generation) diff --git a/skyvern/forge/sdk/db/id.py b/skyvern/forge/sdk/db/id.py index 613f12c9..540bf257 100644 --- a/skyvern/forge/sdk/db/id.py +++ b/skyvern/forge/sdk/db/id.py @@ -40,6 +40,7 @@ WORKFLOW_PARAMETER_PREFIX = "wp" AWS_SECRET_PARAMETER_PREFIX = "asp" OUTPUT_PARAMETER_PREFIX = "op" BITWARDEN_LOGIN_CREDENTIAL_PARAMETER_PREFIX = "blc" +TASK_GENERATION_PREFIX = "tg" def generate_workflow_id() -> str: @@ -107,6 +108,11 @@ def generate_user_id() -> str: return f"{USER_PREFIX}_{int_id}" +def generate_task_generation_id() -> str: + int_id = generate_id() + return f"{TASK_GENERATION_PREFIX}_{int_id}" + + def generate_id() -> int: """ generate a 64-bit int ID diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index fcae00f2..bd8042f8 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -26,6 +26,7 @@ from skyvern.forge.sdk.db.id import ( generate_organization_auth_token_id, generate_output_parameter_id, generate_step_id, + generate_task_generation_id, generate_task_id, generate_workflow_id, generate_workflow_parameter_id, @@ -325,3 +326,27 @@ class WorkflowRunOutputParameterModel(Base): ) value = Column(JSON, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) + + +class TaskGenerationModel(Base): + """ + Generate a task based on the prompt (natural language description of the task) from the user + """ + + __tablename__ = "task_generations" + + task_generation_id = Column(String, primary_key=True, default=generate_task_generation_id) + organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=False) + user_prompt = Column(String, nullable=False, index=True) # The prompt from the user + url = Column(String) + navigation_goal = Column(String) + navigation_payload = Column(JSON) + data_extraction_goal = Column(String) + extracted_information_schema = Column(JSON) + + llm = Column(String) # language model to use + llm_prompt = Column(String) # The prompt sent to the language model + llm_response = Column(String) # The response from the language model + + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) + modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 1bb97f55..1cc939a4 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -9,12 +9,15 @@ from pydantic import BaseModel from skyvern import analytics from skyvern.exceptions import StepNotFound from skyvern.forge import app +from skyvern.forge.prompts import prompt_engine +from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory from skyvern.forge.sdk.models import Organization, Step +from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase from skyvern.forge.sdk.schemas.tasks import ( CreateTaskResponse, ProxyLocation, @@ -660,3 +663,33 @@ async def get_workflow( organization_id=current_org.organization_id, version=version, ) + + +@base_router.post("/generate/task", include_in_schema=False) +@base_router.post("/generate/task/") +async def generate_task( + data: GenerateTaskRequest, + current_org: Organization = Depends(org_auth_service.get_current_org_with_authentication), +) -> TaskGeneration: + llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt) + try: + llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt) + parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response) + + # generate a TaskGenerationModel + task_generation = await app.DATABASE.create_task_generation( + organization_id=current_org.organization_id, + user_prompt=data.prompt, + url=parsed_task_generation_obj.url, + navigation_goal=parsed_task_generation_obj.navigation_goal, + navigation_payload=parsed_task_generation_obj.navigation_payload, + data_extraction_goal=parsed_task_generation_obj.data_extraction_goal, + extracted_information_schema=parsed_task_generation_obj.extracted_information_schema, + llm=SettingsManager.get_settings().LLM_KEY, + llm_prompt=llm_prompt, + llm_response=str(llm_response), + ) + return task_generation + except LLMProviderError: + LOG.error("Failed to generate task", exc_info=True) + raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.") diff --git a/skyvern/forge/sdk/schemas/task_generations.py b/skyvern/forge/sdk/schemas/task_generations.py new file mode 100644 index 00000000..7e1ff4a8 --- /dev/null +++ b/skyvern/forge/sdk/schemas/task_generations.py @@ -0,0 +1,42 @@ +from datetime import datetime +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class LLMType(StrEnum): + OPENAI_GPT4O = "OPENAI_GPT4O" + + +class TaskGenerationBase(BaseModel): + model_config = ConfigDict(from_attributes=True) + + organization_id: str | None = None + user_prompt: str | None = None + url: str | None = None + navigation_goal: str | None = None + navigation_payload: dict[str, Any] | None = None + data_extraction_goal: str | None = None + extracted_information_schema: dict[str, Any] | None = None + llm: LLMType | None = None + llm_prompt: str | None = None + llm_response: str | None = None + + +class TaskGenerationCreate(TaskGenerationBase): + organization_id: str + user_prompt: str + + +class TaskGeneration(TaskGenerationBase): + task_generation_id: str + organization_id: str + user_prompt: str + + created_at: datetime + modified_at: datetime + + +class GenerateTaskRequest(BaseModel): + prompt: str