task generation (#450)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-06-07 15:59:53 -07:00
committed by GitHub
parent 12b83e009e
commit d18fc5b59c
8 changed files with 226 additions and 4 deletions

View File

@@ -0,0 +1,53 @@
"""add task_generations table
Revision ID: 312d305c6b18
Revises: 04bf06540db6
Create Date: 2024-06-07 22:57:18.228793+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
# `revision` is this migration's ID and `down_revision` is the migration it
# revises; both match the "Revision ID" / "Revises" lines in the module docstring.
revision: str = "312d305c6b18"
down_revision: Union[str, None] = "04bf06540db6"
# No branch label and no cross-migration dependency is declared for this revision.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the ``task_generations`` table and a non-unique index on ``user_prompt``."""
    table_name = "task_generations"

    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        table_name,
        # Identity and ownership.
        sa.Column("task_generation_id", sa.String(), nullable=False),
        sa.Column("organization_id", sa.String(), nullable=False),
        # The only required payload column: the user's prompt.
        sa.Column("user_prompt", sa.String(), nullable=False),
        # Generated task fields; every one of them is nullable.
        sa.Column("url", sa.String(), nullable=True),
        sa.Column("navigation_goal", sa.String(), nullable=True),
        sa.Column("navigation_payload", sa.JSON(), nullable=True),
        sa.Column("data_extraction_goal", sa.String(), nullable=True),
        sa.Column("extracted_information_schema", sa.JSON(), nullable=True),
        # LLM bookkeeping columns.
        sa.Column("llm", sa.String(), nullable=True),
        sa.Column("llm_prompt", sa.String(), nullable=True),
        sa.Column("llm_response", sa.String(), nullable=True),
        # Timestamps.
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.organization_id"],
        ),
        sa.PrimaryKeyConstraint("task_generation_id"),
    )
    # NOTE(review): user_prompt is an unbounded String; some database backends
    # limit the size of an indexed value - confirm very long prompts are fine.
    op.create_index(
        op.f("ix_task_generations_user_prompt"),
        table_name,
        ["user_prompt"],
        unique=False,
    )
    # ### end Alembic commands ###
def downgrade() -> None:
    """Undo ``upgrade``: drop the ``user_prompt`` index, then the ``task_generations`` table."""
    table_name = "task_generations"

    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(
        op.f("ix_task_generations_user_prompt"),
        table_name=table_name,
    )
    op.drop_table(table_name)
    # ### end Alembic commands ###

View File

@@ -210,6 +210,7 @@ class ForgeAgent:
) )
# Check some conditions before executing the step, throw an exception if the step can't be executed # Check some conditions before executing the step, throw an exception if the step can't be executed
await app.AGENT_FUNCTION.validate_step_execution(task, step) await app.AGENT_FUNCTION.validate_step_execution(task, step)
( (
step, step,
browser_state, browser_state,
@@ -337,27 +338,31 @@ class ForgeAgent:
) )
raise raise
except StepTerminationError as e: except StepTerminationError as e:
LOG.error( LOG.warning(
"Step cannot be executed. Task failed.", "Step cannot be executed. Task failed.",
task_id=task.task_id, task_id=task.task_id,
step_id=step.step_id, step_id=step.step_id,
exc_info=True,
) )
await self.update_step( await self.update_step(
step=step, step=step,
status=StepStatus.failed, status=StepStatus.failed,
force_update=True,
) )
task = await self.update_task( task = await self.update_task(
task, task,
status=TaskStatus.failed, status=TaskStatus.failed,
failure_reason=e.message, failure_reason=e.message,
force_update=True,
) )
await self.send_task_response( await self.send_task_response(
task=task, task=task,
last_step=step, last_step=step,
api_key=api_key, api_key=api_key,
close_browser_on_completion=close_browser_on_completion, close_browser_on_completion=close_browser_on_completion,
skip_cleanup=True,
) )
return step, detailed_output, next_step return step, detailed_output, None
except FailedToSendWebhook: except FailedToSendWebhook:
LOG.exception( LOG.exception(
"Failed to send webhook", "Failed to send webhook",
@@ -939,6 +944,7 @@ class ForgeAgent:
api_key: str | None = None, api_key: str | None = None,
close_browser_on_completion: bool = True, close_browser_on_completion: bool = True,
skip_artifacts: bool = False, skip_artifacts: bool = False,
skip_cleanup: bool = False,
) -> None: ) -> None:
""" """
send the task response to the webhook callback url send the task response to the webhook callback url
@@ -957,6 +963,10 @@ class ForgeAgent:
) )
raise TaskNotFound(task_id=task.task_id) from e raise TaskNotFound(task_id=task.task_id) from e
task = refreshed_task task = refreshed_task
if skip_cleanup:
await self.execute_task_webhook(task=task, last_step=last_step, api_key=api_key)
return
# log the task status as an event # log the task status as an event
analytics.capture("skyvern-oss-agent-task-status", {"status": task.status}) analytics.capture("skyvern-oss-agent-task-status", {"status": task.status})
# We skip the artifacts and send the webhook response directly only when there is an issue with the browser # We skip the artifacts and send the webhook response directly only when there is an issue with the browser
@@ -1165,8 +1175,10 @@ class ForgeAgent:
output: AgentStepOutput | None = None, output: AgentStepOutput | None = None,
is_last: bool | None = None, is_last: bool | None = None,
retry_index: int | None = None, retry_index: int | None = None,
force_update: bool = False,
) -> Step: ) -> Step:
step.validate_update(status, output, is_last) if not force_update:
step.validate_update(status, output, is_last)
updates: dict[str, Any] = {} updates: dict[str, Any] = {}
if status is not None: if status is not None:
updates["status"] = status updates["status"] = status
@@ -1200,8 +1212,10 @@ class ForgeAgent:
status: TaskStatus, status: TaskStatus,
extracted_information: dict[str, Any] | list | str | None = None, extracted_information: dict[str, Any] | list | str | None = None,
failure_reason: str | None = None, failure_reason: str | None = None,
force_update: bool = False,
) -> Task: ) -> Task:
task.validate_update(status, extracted_information, failure_reason) if not force_update:
task.validate_update(status, extracted_information, failure_reason)
updates: dict[str, Any] = {} updates: dict[str, Any] = {}
if status is not None: if status is not None:
updates["status"] = status updates["status"] = status

View File

@@ -0,0 +1,16 @@
We are building an AI agent that can automate browser tasks. The task creation schema is a JSON object with the following fields:
url: required field, the starting URL for the task. This will be the first page the agent visits in order to achieve its goals.
navigation_goal: optional. The value should be a string that we can use as an input to a Large Language Model. It needs to tell the agent the goal in terms of navigating the website. It needs to define a single goal. You can include explicit completion and failure criteria. You can define guardrails that help prevent the agent from taking certain actions or getting derailed.
data_extraction_goal: optional. The value should be a string that we can use as an input to a Large Language Model. It needs to tell the agent the goal in terms of extracting data. It needs to be a single goal.
navigation_payload: optional. The value should be JSON. Use this field if there is any information the agent needs in order to complete the task, such as values that can help fill a form, parameters for queries and so on.
extracted_information_schema: optional. The exact schema of the data to be extracted.
At least one of navigation goal or data extraction goal should be provided. The agent can't proceed without any goals.
If a field is not required to achieve a task, provide the value `null`.
Respond with only JSON output that follows the task creation schema for the following prompt:
```
{{ user_prompt }}
```

View File

@@ -18,6 +18,7 @@ from skyvern.forge.sdk.db.models import (
OrganizationModel, OrganizationModel,
OutputParameterModel, OutputParameterModel,
StepModel, StepModel,
TaskGenerationModel,
TaskModel, TaskModel,
WorkflowModel, WorkflowModel,
WorkflowParameterModel, WorkflowParameterModel,
@@ -42,6 +43,7 @@ from skyvern.forge.sdk.db.utils import (
convert_to_workflow_run_parameter, convert_to_workflow_run_parameter,
) )
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import ( from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter, AWSSecretParameter,
@@ -1236,3 +1238,34 @@ class AgentDB:
) )
await session.execute(stmt) await session.execute(stmt)
await session.commit() await session.commit()
async def create_task_generation(
    self,
    organization_id: str,
    user_prompt: str,
    url: str | None = None,
    navigation_goal: str | None = None,
    navigation_payload: dict[str, Any] | None = None,
    data_extraction_goal: str | None = None,
    extracted_information_schema: dict[str, Any] | None = None,
    llm: str | None = None,
    llm_prompt: str | None = None,
    llm_response: str | None = None,
) -> TaskGeneration:
    """
    Insert one row into ``task_generations`` and return it as a ``TaskGeneration``.

    Only ``organization_id`` and ``user_prompt`` are required; every other
    argument defaults to ``None`` and is stored as given.
    """
    record = TaskGenerationModel(
        organization_id=organization_id,
        user_prompt=user_prompt,
        url=url,
        navigation_goal=navigation_goal,
        navigation_payload=navigation_payload,
        data_extraction_goal=data_extraction_goal,
        extracted_information_schema=extracted_information_schema,
        llm=llm,
        llm_prompt=llm_prompt,
        llm_response=llm_response,
    )
    async with self.Session() as session:
        session.add(record)
        await session.commit()
        # Reload the row after the commit before converting it to the schema object.
        await session.refresh(record)
        return TaskGeneration.model_validate(record)

View File

@@ -40,6 +40,7 @@ WORKFLOW_PARAMETER_PREFIX = "wp"
AWS_SECRET_PARAMETER_PREFIX = "asp" AWS_SECRET_PARAMETER_PREFIX = "asp"
OUTPUT_PARAMETER_PREFIX = "op" OUTPUT_PARAMETER_PREFIX = "op"
BITWARDEN_LOGIN_CREDENTIAL_PARAMETER_PREFIX = "blc" BITWARDEN_LOGIN_CREDENTIAL_PARAMETER_PREFIX = "blc"
TASK_GENERATION_PREFIX = "tg"
def generate_workflow_id() -> str: def generate_workflow_id() -> str:
@@ -107,6 +108,11 @@ def generate_user_id() -> str:
return f"{USER_PREFIX}_{int_id}" return f"{USER_PREFIX}_{int_id}"
def generate_task_generation_id() -> str:
    """Return a new task generation ID: ``TASK_GENERATION_PREFIX``, an underscore, then a generated int ID."""
    return f"{TASK_GENERATION_PREFIX}_{generate_id()}"
def generate_id() -> int: def generate_id() -> int:
""" """
generate a 64-bit int ID generate a 64-bit int ID

View File

@@ -26,6 +26,7 @@ from skyvern.forge.sdk.db.id import (
generate_organization_auth_token_id, generate_organization_auth_token_id,
generate_output_parameter_id, generate_output_parameter_id,
generate_step_id, generate_step_id,
generate_task_generation_id,
generate_task_id, generate_task_id,
generate_workflow_id, generate_workflow_id,
generate_workflow_parameter_id, generate_workflow_parameter_id,
@@ -325,3 +326,27 @@ class WorkflowRunOutputParameterModel(Base):
) )
value = Column(JSON, nullable=False) value = Column(JSON, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
class TaskGenerationModel(Base):
    """
    Generate a task based on the prompt (natural language description of the task) from the user

    One row holds the user's prompt, the task fields generated for it, and the
    LLM name, prompt and response recorded alongside.
    """

    __tablename__ = "task_generations"

    # Identity and ownership.
    task_generation_id = Column(
        String,
        primary_key=True,
        default=generate_task_generation_id,
    )
    organization_id = Column(
        String,
        ForeignKey("organizations.organization_id"),
        nullable=False,
    )

    # The prompt from the user (indexed).
    user_prompt = Column(String, nullable=False, index=True)

    # Generated task fields; all nullable.
    url = Column(String)
    navigation_goal = Column(String)
    navigation_payload = Column(JSON)
    data_extraction_goal = Column(String)
    extracted_information_schema = Column(JSON)

    # LLM bookkeeping.
    llm = Column(String)  # language model to use
    llm_prompt = Column(String)  # The prompt sent to the language model
    llm_response = Column(String)  # The response from the language model

    # Timestamps; modified_at is also refreshed on every update.
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
        nullable=False,
    )

View File

@@ -9,12 +9,15 @@ from pydantic import BaseModel
from skyvern import analytics from skyvern import analytics
from skyvern.exceptions import StepNotFound from skyvern.exceptions import StepNotFound
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.models import Organization, Step from skyvern.forge.sdk.models import Organization, Step
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
from skyvern.forge.sdk.schemas.tasks import ( from skyvern.forge.sdk.schemas.tasks import (
CreateTaskResponse, CreateTaskResponse,
ProxyLocation, ProxyLocation,
@@ -660,3 +663,33 @@ async def get_workflow(
organization_id=current_org.organization_id, organization_id=current_org.organization_id,
version=version, version=version,
) )
@base_router.post("/generate/task", include_in_schema=False)
@base_router.post("/generate/task/")
async def generate_task(
    data: GenerateTaskRequest,
    current_org: Organization = Depends(org_auth_service.get_current_org_with_authentication),
) -> TaskGeneration:
    """
    Turn a natural-language prompt into a task definition and persist it.

    The "generate-task" prompt template is rendered with the user's prompt and
    sent to the LLM; the response is validated against ``TaskGenerationBase``
    and stored through ``app.DATABASE.create_task_generation``.

    Args:
        data: request body carrying the user's prompt.
        current_org: the authenticated organization that owns the record.

    Returns:
        The stored ``TaskGeneration``.

    Raises:
        HTTPException: status 400 when the LLM call fails or when its response
            does not match the task generation schema.
    """
    # Imported locally so this handler does not depend on a file-level import change.
    from pydantic import ValidationError

    llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)

    # Keep the try body to the two steps that depend on the LLM: the call itself
    # and parsing its output. LLM output is not guaranteed to follow the schema,
    # so a validation failure is reported like a provider failure instead of
    # escaping as an unhandled error.
    try:
        llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt)
        parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
    except (LLMProviderError, ValidationError):
        LOG.error("Failed to generate task", exc_info=True)
        raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")

    # generate a TaskGenerationModel
    task_generation = await app.DATABASE.create_task_generation(
        organization_id=current_org.organization_id,
        user_prompt=data.prompt,
        url=parsed_task_generation_obj.url,
        navigation_goal=parsed_task_generation_obj.navigation_goal,
        navigation_payload=parsed_task_generation_obj.navigation_payload,
        data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
        extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
        llm=SettingsManager.get_settings().LLM_KEY,
        llm_prompt=llm_prompt,
        llm_response=str(llm_response),
    )
    return task_generation

View File

@@ -0,0 +1,42 @@
from datetime import datetime
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, ConfigDict
class LLMType(StrEnum):
OPENAI_GPT4O = "OPENAI_GPT4O"
class TaskGenerationBase(BaseModel):
    """
    Shared, all-optional fields of a task generation.

    ``from_attributes`` is enabled so instances can be validated from objects
    that expose these fields as attributes, as well as from dicts.
    """

    model_config = ConfigDict(from_attributes=True)

    organization_id: str | None = None
    user_prompt: str | None = None
    url: str | None = None
    navigation_goal: str | None = None
    navigation_payload: dict[str, Any] | None = None
    data_extraction_goal: str | None = None
    extracted_information_schema: dict[str, Any] | None = None
    # Plain string rather than LLMType: the value recorded here is whatever LLM
    # key the deployment is configured with, which is not limited to the
    # members of LLMType. An enum-typed field would reject every other key
    # when a stored record is validated.
    llm: str | None = None
    llm_prompt: str | None = None
    llm_response: str | None = None
class TaskGenerationCreate(TaskGenerationBase):
    """Creation payload: the base fields, with organization and user prompt made required."""

    organization_id: str
    user_prompt: str
class TaskGeneration(TaskGenerationBase):
    """A stored task generation: base fields plus its ID, required owner/prompt and timestamps."""

    task_generation_id: str
    organization_id: str
    user_prompt: str

    created_at: datetime
    modified_at: datetime
class GenerateTaskRequest(BaseModel):
    """Request body for task generation: a single required ``prompt`` string."""

    prompt: str