task generation (#450)
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -0,0 +1,53 @@
|
|||||||
|
"""add task_generations table
|
||||||
|
|
||||||
|
Revision ID: 312d305c6b18
|
||||||
|
Revises: 04bf06540db6
|
||||||
|
Create Date: 2024-06-07 22:57:18.228793+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "312d305c6b18"
|
||||||
|
down_revision: Union[str, None] = "04bf06540db6"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create the ``task_generations`` table and the index on its ``user_prompt`` column."""
    table_name = "task_generations"

    # Columns, in declaration order. Only the IDs, the user prompt and the
    # two timestamps are NOT NULL; every other column is nullable.
    columns = [
        sa.Column("task_generation_id", sa.String(), nullable=False),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.Column("user_prompt", sa.String(), nullable=False),
        sa.Column("url", sa.String(), nullable=True),
        sa.Column("navigation_goal", sa.String(), nullable=True),
        sa.Column("navigation_payload", sa.JSON(), nullable=True),
        sa.Column("data_extraction_goal", sa.String(), nullable=True),
        sa.Column("extracted_information_schema", sa.JSON(), nullable=True),
        sa.Column("llm", sa.String(), nullable=True),
        sa.Column("llm_prompt", sa.String(), nullable=True),
        sa.Column("llm_response", sa.String(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("modified_at", sa.DateTime(), nullable=False),
    ]

    # organization_id references organizations.organization_id;
    # task_generation_id is the primary key.
    constraints = [
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.organization_id"],
        ),
        sa.PrimaryKeyConstraint("task_generation_id"),
    ]

    op.create_table(table_name, *columns, *constraints)

    # Non-unique index on the user prompt column.
    op.create_index(
        op.f("ix_task_generations_user_prompt"),
        table_name,
        ["user_prompt"],
        unique=False,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the ``user_prompt`` index and then the ``task_generations`` table."""
    table_name = "task_generations"
    index_name = op.f("ix_task_generations_user_prompt")

    # The index is removed first, then the table it belongs to.
    op.drop_index(index_name, table_name=table_name)
    op.drop_table(table_name)
|
||||||
@@ -210,6 +210,7 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
# Check some conditions before executing the step, throw an exception if the step can't be executed
|
# Check some conditions before executing the step, throw an exception if the step can't be executed
|
||||||
await app.AGENT_FUNCTION.validate_step_execution(task, step)
|
await app.AGENT_FUNCTION.validate_step_execution(task, step)
|
||||||
|
|
||||||
(
|
(
|
||||||
step,
|
step,
|
||||||
browser_state,
|
browser_state,
|
||||||
@@ -337,27 +338,31 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
except StepTerminationError as e:
|
except StepTerminationError as e:
|
||||||
LOG.error(
|
LOG.warning(
|
||||||
"Step cannot be executed. Task failed.",
|
"Step cannot be executed. Task failed.",
|
||||||
task_id=task.task_id,
|
task_id=task.task_id,
|
||||||
step_id=step.step_id,
|
step_id=step.step_id,
|
||||||
|
exc_info=True,
|
||||||
)
|
)
|
||||||
await self.update_step(
|
await self.update_step(
|
||||||
step=step,
|
step=step,
|
||||||
status=StepStatus.failed,
|
status=StepStatus.failed,
|
||||||
|
force_update=True,
|
||||||
)
|
)
|
||||||
task = await self.update_task(
|
task = await self.update_task(
|
||||||
task,
|
task,
|
||||||
status=TaskStatus.failed,
|
status=TaskStatus.failed,
|
||||||
failure_reason=e.message,
|
failure_reason=e.message,
|
||||||
|
force_update=True,
|
||||||
)
|
)
|
||||||
await self.send_task_response(
|
await self.send_task_response(
|
||||||
task=task,
|
task=task,
|
||||||
last_step=step,
|
last_step=step,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
close_browser_on_completion=close_browser_on_completion,
|
close_browser_on_completion=close_browser_on_completion,
|
||||||
|
skip_cleanup=True,
|
||||||
)
|
)
|
||||||
return step, detailed_output, next_step
|
return step, detailed_output, None
|
||||||
except FailedToSendWebhook:
|
except FailedToSendWebhook:
|
||||||
LOG.exception(
|
LOG.exception(
|
||||||
"Failed to send webhook",
|
"Failed to send webhook",
|
||||||
@@ -939,6 +944,7 @@ class ForgeAgent:
|
|||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
close_browser_on_completion: bool = True,
|
close_browser_on_completion: bool = True,
|
||||||
skip_artifacts: bool = False,
|
skip_artifacts: bool = False,
|
||||||
|
skip_cleanup: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
send the task response to the webhook callback url
|
send the task response to the webhook callback url
|
||||||
@@ -957,6 +963,10 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
raise TaskNotFound(task_id=task.task_id) from e
|
raise TaskNotFound(task_id=task.task_id) from e
|
||||||
task = refreshed_task
|
task = refreshed_task
|
||||||
|
if skip_cleanup:
|
||||||
|
await self.execute_task_webhook(task=task, last_step=last_step, api_key=api_key)
|
||||||
|
return
|
||||||
|
|
||||||
# log the task status as an event
|
# log the task status as an event
|
||||||
analytics.capture("skyvern-oss-agent-task-status", {"status": task.status})
|
analytics.capture("skyvern-oss-agent-task-status", {"status": task.status})
|
||||||
# We skip the artifacts and send the webhook response directly only when there is an issue with the browser
|
# We skip the artifacts and send the webhook response directly only when there is an issue with the browser
|
||||||
@@ -1165,8 +1175,10 @@ class ForgeAgent:
|
|||||||
output: AgentStepOutput | None = None,
|
output: AgentStepOutput | None = None,
|
||||||
is_last: bool | None = None,
|
is_last: bool | None = None,
|
||||||
retry_index: int | None = None,
|
retry_index: int | None = None,
|
||||||
|
force_update: bool = False,
|
||||||
) -> Step:
|
) -> Step:
|
||||||
step.validate_update(status, output, is_last)
|
if not force_update:
|
||||||
|
step.validate_update(status, output, is_last)
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if status is not None:
|
if status is not None:
|
||||||
updates["status"] = status
|
updates["status"] = status
|
||||||
@@ -1200,8 +1212,10 @@ class ForgeAgent:
|
|||||||
status: TaskStatus,
|
status: TaskStatus,
|
||||||
extracted_information: dict[str, Any] | list | str | None = None,
|
extracted_information: dict[str, Any] | list | str | None = None,
|
||||||
failure_reason: str | None = None,
|
failure_reason: str | None = None,
|
||||||
|
force_update: bool = False,
|
||||||
) -> Task:
|
) -> Task:
|
||||||
task.validate_update(status, extracted_information, failure_reason)
|
if not force_update:
|
||||||
|
task.validate_update(status, extracted_information, failure_reason)
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if status is not None:
|
if status is not None:
|
||||||
updates["status"] = status
|
updates["status"] = status
|
||||||
|
|||||||
16
skyvern/forge/prompts/skyvern/generate-task.j2
Normal file
16
skyvern/forge/prompts/skyvern/generate-task.j2
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
We are building an AI agent that can automate browser tasks. The task creation schema is a JSON object with the following fields:
|
||||||
|
|
||||||
|
url: required field, the starting URL for the task. This will be the first page the agent visits in order to achieve its goals.
|
||||||
|
navigation_goal: optional. The value should be a string that we can use as an input to a Large Language Model. It needs to tell the agent the goal in terms of navigating the website. It needs to define a single goal. You can include explicit completion and failure criteria. You can define guardrails that keep the agent from taking certain actions or getting derailed.
|
||||||
|
data_extraction_goal: optional. The value should be a string that we can use as an input to a Large Language Model. It needs to tell the agent the goal in terms of extracting data. It needs to be a single goal.
|
||||||
|
navigation_payload: optional. The value should be JSON. Use this field if there is any information for the agent to be able to complete the task such as values that can help fill a form, parameters for queries and so on.
|
||||||
|
extracted_information_schema: optional. The exact schema of the data to be extracted.
|
||||||
|
|
||||||
|
At least one of navigation goal or data extraction goal should be provided. The agent can't proceed without any goals.
|
||||||
|
|
||||||
|
If a field is not required to achieve a task, provide the value `null`.
|
||||||
|
|
||||||
|
Respond with only JSON output that follows the task creation schema for the following prompt:
|
||||||
|
```
|
||||||
|
{{ user_prompt }}
|
||||||
|
```
|
||||||
@@ -18,6 +18,7 @@ from skyvern.forge.sdk.db.models import (
|
|||||||
OrganizationModel,
|
OrganizationModel,
|
||||||
OutputParameterModel,
|
OutputParameterModel,
|
||||||
StepModel,
|
StepModel,
|
||||||
|
TaskGenerationModel,
|
||||||
TaskModel,
|
TaskModel,
|
||||||
WorkflowModel,
|
WorkflowModel,
|
||||||
WorkflowParameterModel,
|
WorkflowParameterModel,
|
||||||
@@ -42,6 +43,7 @@ from skyvern.forge.sdk.db.utils import (
|
|||||||
convert_to_workflow_run_parameter,
|
convert_to_workflow_run_parameter,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
||||||
|
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
|
||||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
||||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||||
AWSSecretParameter,
|
AWSSecretParameter,
|
||||||
@@ -1236,3 +1238,34 @@ class AgentDB:
|
|||||||
)
|
)
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
async def create_task_generation(
    self,
    organization_id: str,
    user_prompt: str,
    url: str | None = None,
    navigation_goal: str | None = None,
    navigation_payload: dict[str, Any] | None = None,
    data_extraction_goal: str | None = None,
    extracted_information_schema: dict[str, Any] | None = None,
    llm: str | None = None,
    llm_prompt: str | None = None,
    llm_response: str | None = None,
) -> TaskGeneration:
    """
    Insert one ``TaskGenerationModel`` row and return it as a ``TaskGeneration``.

    ``organization_id`` and ``user_prompt`` are required; every other argument
    defaults to ``None`` and is passed to the model unchanged.
    """
    column_values: dict[str, Any] = {
        "organization_id": organization_id,
        "user_prompt": user_prompt,
        "url": url,
        "navigation_goal": navigation_goal,
        "navigation_payload": navigation_payload,
        "data_extraction_goal": data_extraction_goal,
        "extracted_information_schema": extracted_information_schema,
        "llm": llm,
        "llm_prompt": llm_prompt,
        "llm_response": llm_response,
    }

    async with self.Session() as db_session:
        record = TaskGenerationModel(**column_values)
        db_session.add(record)
        await db_session.commit()
        # Re-read the committed row before converting it to the schema object.
        await db_session.refresh(record)
        return TaskGeneration.model_validate(record)
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ WORKFLOW_PARAMETER_PREFIX = "wp"
|
|||||||
AWS_SECRET_PARAMETER_PREFIX = "asp"
|
AWS_SECRET_PARAMETER_PREFIX = "asp"
|
||||||
OUTPUT_PARAMETER_PREFIX = "op"
|
OUTPUT_PARAMETER_PREFIX = "op"
|
||||||
BITWARDEN_LOGIN_CREDENTIAL_PARAMETER_PREFIX = "blc"
|
BITWARDEN_LOGIN_CREDENTIAL_PARAMETER_PREFIX = "blc"
|
||||||
|
TASK_GENERATION_PREFIX = "tg"
|
||||||
|
|
||||||
|
|
||||||
def generate_workflow_id() -> str:
|
def generate_workflow_id() -> str:
|
||||||
@@ -107,6 +108,11 @@ def generate_user_id() -> str:
|
|||||||
return f"{USER_PREFIX}_{int_id}"
|
return f"{USER_PREFIX}_{int_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_task_generation_id() -> str:
    """Return a new ID of the form ``<TASK_GENERATION_PREFIX>_<generate_id()>``."""
    return f"{TASK_GENERATION_PREFIX}_{generate_id()}"
|
||||||
|
|
||||||
|
|
||||||
def generate_id() -> int:
|
def generate_id() -> int:
|
||||||
"""
|
"""
|
||||||
generate a 64-bit int ID
|
generate a 64-bit int ID
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from skyvern.forge.sdk.db.id import (
|
|||||||
generate_organization_auth_token_id,
|
generate_organization_auth_token_id,
|
||||||
generate_output_parameter_id,
|
generate_output_parameter_id,
|
||||||
generate_step_id,
|
generate_step_id,
|
||||||
|
generate_task_generation_id,
|
||||||
generate_task_id,
|
generate_task_id,
|
||||||
generate_workflow_id,
|
generate_workflow_id,
|
||||||
generate_workflow_parameter_id,
|
generate_workflow_parameter_id,
|
||||||
@@ -325,3 +326,27 @@ class WorkflowRunOutputParameterModel(Base):
|
|||||||
)
|
)
|
||||||
value = Column(JSON, nullable=False)
|
value = Column(JSON, nullable=False)
|
||||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskGenerationModel(Base):
    """
    A task generated from a user's natural-language prompt, stored in ``task_generations``.
    """

    __tablename__ = "task_generations"

    # Identity and ownership.
    task_generation_id = Column(String, default=generate_task_generation_id, primary_key=True)
    organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=False)

    # The prompt from the user (indexed).
    user_prompt = Column(String, index=True, nullable=False)

    # Task fields; all nullable.
    url = Column(String)
    navigation_goal = Column(String)
    navigation_payload = Column(JSON)
    data_extraction_goal = Column(String)
    extracted_information_schema = Column(JSON)

    # Language-model bookkeeping.
    llm = Column(String)  # language model to use
    llm_prompt = Column(String)  # prompt sent to the language model
    llm_response = Column(String)  # response from the language model

    # Timestamps default to the current UTC time; modified_at is also reset on update.
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime,
        nullable=False,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    )
|
||||||
|
|||||||
@@ -9,12 +9,15 @@ from pydantic import BaseModel
|
|||||||
from skyvern import analytics
|
from skyvern import analytics
|
||||||
from skyvern.exceptions import StepNotFound
|
from skyvern.exceptions import StepNotFound
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.prompts import prompt_engine
|
||||||
|
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
|
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
|
||||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||||
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
||||||
from skyvern.forge.sdk.models import Organization, Step
|
from skyvern.forge.sdk.models import Organization, Step
|
||||||
|
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
|
||||||
from skyvern.forge.sdk.schemas.tasks import (
|
from skyvern.forge.sdk.schemas.tasks import (
|
||||||
CreateTaskResponse,
|
CreateTaskResponse,
|
||||||
ProxyLocation,
|
ProxyLocation,
|
||||||
@@ -660,3 +663,33 @@ async def get_workflow(
|
|||||||
organization_id=current_org.organization_id,
|
organization_id=current_org.organization_id,
|
||||||
version=version,
|
version=version,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.post("/generate/task", include_in_schema=False)
@base_router.post("/generate/task/")
async def generate_task(
    data: GenerateTaskRequest,
    current_org: Organization = Depends(org_auth_service.get_current_org_with_authentication),
) -> TaskGeneration:
    """
    Generate task fields from a natural-language prompt and persist the result.

    Renders the ``generate-task`` prompt with ``data.prompt``, sends it to the
    LLM handler, validates the response against ``TaskGenerationBase`` and
    stores the parsed fields together with the prompt and raw response.

    Args:
        data: Request body carrying the user's prompt.
        current_org: Authenticated organization the generation is stored under.

    Returns:
        The stored ``TaskGeneration`` record.

    Raises:
        HTTPException: 400 when the LLM call fails or its response does not
            match the task creation schema.
    """
    # Local import: the handler needs ValidationError regardless of what the
    # module-level import block pulls in from pydantic.
    from pydantic import ValidationError

    llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)

    # Only the LLM round trip and the parsing of its output are expected to
    # fail here; the database write stays outside the try block.
    try:
        llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt)
        parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
    except (LLMProviderError, ValidationError) as e:
        # A response that does not fit the schema is treated the same as a
        # provider failure rather than escaping as an unhandled error.
        LOG.error("Failed to generate task", exc_info=True)
        raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.") from e

    # generate a TaskGenerationModel
    task_generation = await app.DATABASE.create_task_generation(
        organization_id=current_org.organization_id,
        user_prompt=data.prompt,
        url=parsed_task_generation_obj.url,
        navigation_goal=parsed_task_generation_obj.navigation_goal,
        navigation_payload=parsed_task_generation_obj.navigation_payload,
        data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
        extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
        llm=SettingsManager.get_settings().LLM_KEY,
        llm_prompt=llm_prompt,
        llm_response=str(llm_response),
    )
    return task_generation
|
||||||
|
|||||||
42
skyvern/forge/sdk/schemas/task_generations.py
Normal file
42
skyvern/forge/sdk/schemas/task_generations.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class LLMType(StrEnum):
|
||||||
|
OPENAI_GPT4O = "OPENAI_GPT4O"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskGenerationBase(BaseModel):
    # Fields shared by the task-generation schemas. Every field is optional
    # and defaults to None.

    # from_attributes=True allows validation from objects exposing these
    # names as attributes, not only from mappings.
    model_config = ConfigDict(from_attributes=True)

    # Ownership and the user's original prompt.
    organization_id: str | None = None
    user_prompt: str | None = None

    # Task fields.
    url: str | None = None
    navigation_goal: str | None = None
    navigation_payload: dict[str, Any] | None = None
    data_extraction_goal: str | None = None
    extracted_information_schema: dict[str, Any] | None = None

    # Language-model bookkeeping.
    # NOTE(review): a value outside LLMType fails validation here -- confirm
    # that every value written to the stored `llm` field is an LLMType member.
    llm: LLMType | None = None
    llm_prompt: str | None = None
    llm_response: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskGenerationCreate(TaskGenerationBase):
    # Same fields as the base schema, except that organization_id and
    # user_prompt are required strings here.
    organization_id: str
    user_prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
class TaskGeneration(TaskGenerationBase):
    # A task generation with its ID and timestamps: the ID, organization_id,
    # user_prompt and both timestamps are required on top of the optional
    # base fields.
    task_generation_id: str
    organization_id: str
    user_prompt: str

    created_at: datetime
    modified_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateTaskRequest(BaseModel):
    # Request schema with a single required field: the prompt text.
    prompt: str
|
||||||
Reference in New Issue
Block a user