diff --git a/alembic/versions/2025_02_09_1226-60d0743274c9_add_task_runs.py b/alembic/versions/2025_02_09_1226-60d0743274c9_add_task_runs.py
new file mode 100644
index 00000000..21235c86
--- /dev/null
+++ b/alembic/versions/2025_02_09_1226-60d0743274c9_add_task_runs.py
@@ -0,0 +1,48 @@
+"""add task_runs
+
+Revision ID: 60d0743274c9
+Revises: ebf093461132
+Create Date: 2025-02-09 12:26:55.725475+00:00
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "60d0743274c9"
+down_revision: Union[str, None] = "ebf093461132"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Create the task_runs table and its (organization_id, url_hash, cached) index."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "task_runs",
+        sa.Column("task_run_id", sa.String(), nullable=False),
+        sa.Column("organization_id", sa.String(), nullable=False),
+        sa.Column("task_run_type", sa.String(), nullable=False),
+        sa.Column("run_id", sa.String(), nullable=False),
+        sa.Column("title", sa.String(), nullable=True),
+        sa.Column("url", sa.String(), nullable=True),
+        sa.Column("url_hash", sa.String(), nullable=True),
+        sa.Column("cached", sa.Boolean(), nullable=False),
+        sa.Column("created_at", sa.DateTime(), nullable=False),
+        sa.Column("modified_at", sa.DateTime(), nullable=False),
+        sa.PrimaryKeyConstraint("task_run_id"),
+    )
+    op.create_index("task_run_org_url_index", "task_runs", ["organization_id", "url_hash", "cached"], unique=False)
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    """Drop the task_runs index and then the table, reversing upgrade()."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index("task_run_org_url_index", table_name="task_runs")
+    op.drop_table("task_runs")
+    # ### end Alembic commands ###
diff --git a/skyvern/forge/sdk/core/hashing.py b/skyvern/forge/sdk/core/hashing.py
new file mode 100644
index 00000000..4323ca2f
--- /dev/null
+++ b/skyvern/forge/sdk/core/hashing.py
@@ -0,0 +1,6 @@
+import hashlib
+
+
+def generate_url_hash(url: str) -> str:
+    """Return the hex SHA-256 digest of ``url`` (encoded with str.encode()'s default, UTF-8)."""
+    return hashlib.sha256(url.encode()).hexdigest()
diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py
index 0f209516..4368be7c 100644
--- a/skyvern/forge/sdk/db/client.py
+++ b/skyvern/forge/sdk/db/client.py
@@ -29,6 +29,7 @@ from skyvern.forge.sdk.db.models import (
     StepModel,
     TaskGenerationModel,
     TaskModel,
+    TaskRunModel,
     TOTPCodeModel,
     WorkflowModel,
     WorkflowParameterModel,
@@ -62,6 +63,7 @@ from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskStatus
 from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
 from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
 from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
+from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType
 from skyvern.forge.sdk.schemas.tasks import OrderBy, ProxyLocation, SortDirection, Task, TaskStatus
 from skyvern.forge.sdk.schemas.totp_codes import TOTPCode
 from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
@@ -2647,3 +2649,34 @@ class AgentDB:
         except Exception:
             LOG.error("UnexpectedError", exc_info=True)
             raise
+
+    async def create_task_run(
+        self,
+        task_run_type: TaskRunType,
+        organization_id: str,
+        run_id: str,
+        title: str | None = None,
+        url: str | None = None,
+        url_hash: str | None = None,
+    ) -> TaskRun:
+        """Persist a new ``task_runs`` row and return it as a ``TaskRun``.
+
+        Any failure is logged and re-raised, like the method directly above.
+        """
+        try:
+            async with self.Session() as session:
+                task_run = TaskRunModel(
+                    task_run_type=task_run_type,
+                    organization_id=organization_id,
+                    run_id=run_id,
+                    title=title,
+                    url=url,
+                    url_hash=url_hash,
+                )
+                session.add(task_run)
+                await session.commit()
+                await session.refresh(task_run)
+                return TaskRun.model_validate(task_run)
+        except Exception:
+            LOG.error("UnexpectedError", exc_info=True)
+            raise
diff --git a/skyvern/forge/sdk/db/id.py b/skyvern/forge/sdk/db/id.py
index 54c66447..05c46d6c 100644
--- a/skyvern/forge/sdk/db/id.py
+++ b/skyvern/forge/sdk/db/id.py
@@ -43,6 +43,7 @@ PERSISTENT_BROWSER_SESSION_ID = "pbs"
 STEP_PREFIX = "stp"
 TASK_GENERATION_PREFIX = "tg"
 TASK_PREFIX = "tsk"
+TASK_RUN_PREFIX = "tr"
 TOTP_CODE_PREFIX = "totp"
 USER_PREFIX = "u"
 WORKFLOW_PARAMETER_PREFIX = "wp"
@@ -167,6 +168,12 @@ def generate_persistent_browser_session_id() -> str:
     return f"{PERSISTENT_BROWSER_SESSION_ID}_{int_id}"
 
 
+def generate_task_run_id() -> str:
+    """Return a new task-run ID: TASK_RUN_PREFIX, an underscore, then a generate_id() value."""
+    int_id = generate_id()
+    return f"{TASK_RUN_PREFIX}_{int_id}"
+
+
 def generate_id() -> int:
     """
     generate a 64-bit int ID
diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py
index ffbb267a..74dda787 100644
--- a/skyvern/forge/sdk/db/models.py
+++ b/skyvern/forge/sdk/db/models.py
@@ -36,6 +36,7 @@ from skyvern.forge.sdk.db.id import (
     generate_step_id,
     generate_task_generation_id,
     generate_task_id,
+    generate_task_run_id,
     generate_totp_code_id,
     generate_workflow_id,
     generate_workflow_parameter_id,
@@ -609,3 +610,21 @@ class PersistentBrowserSessionModel(Base):
     created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
     modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
     deleted_at = Column(DateTime, nullable=True)
+
+
+class TaskRunModel(Base):
+    """Row in ``task_runs``, indexed on (organization_id, url_hash, cached)."""
+
+    __tablename__ = "task_runs"
+    __table_args__ = (Index("task_run_org_url_index", "organization_id", "url_hash", "cached"),)
+
+    task_run_id = Column(String, primary_key=True, default=generate_task_run_id)
+    organization_id = Column(String, nullable=False)
+    task_run_type = Column(String, nullable=False)
+    run_id = Column(String, nullable=False)
+    title = Column(String, nullable=True)
+    url = Column(String, nullable=True)
+    url_hash = Column(String, nullable=True)
+    cached = Column(Boolean, nullable=False, default=False)
+    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
+    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py
index 0588c257..b5a5dbf4 100644
--- a/skyvern/forge/sdk/routes/agent_protocol.py
+++ b/skyvern/forge/sdk/routes/agent_protocol.py
@@ -32,6 +32,7 @@ from skyvern.forge.sdk.api.aws import aws_client
 from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
 from skyvern.forge.sdk.artifact.models import Artifact
 from skyvern.forge.sdk.core import skyvern_context
+from skyvern.forge.sdk.core.hashing import generate_url_hash
 from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
 from skyvern.forge.sdk.core.security import generate_skyvern_signature
 from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
@@ -46,6 +47,7 @@ from skyvern.forge.sdk.schemas.organizations import (
     OrganizationUpdate,
 )
 from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
+from skyvern.forge.sdk.schemas.task_runs import TaskRunType
 from skyvern.forge.sdk.schemas.tasks import (
     CreateTaskResponse,
     OrderBy,
@@ -149,6 +151,15 @@ async def create_agent_task(
     await PermissionCheckerFactory.get_instance().check(current_org)
 
     created_task = await app.agent.create_task(task, current_org.organization_id)
+    url_hash = generate_url_hash(task.url)
+    await app.DATABASE.create_task_run(
+        task_run_type=TaskRunType.task_v1,
+        organization_id=current_org.organization_id,
+        run_id=created_task.task_id,
+        title=task.title,
+        url=task.url,
+        url_hash=url_hash,
+    )
     if x_max_steps_override:
         LOG.info(
             "Overriding max steps per run",
@@ -676,6 +687,17 @@ async def execute_workflow(
         max_steps_override=x_max_steps_override,
         is_template_workflow=template,
     )
+    workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
+        workflow_permanent_id=workflow_id,
+        organization_id=current_org.organization_id,
+        version=version,
+    )
+    await app.DATABASE.create_task_run(
+        task_run_type=TaskRunType.workflow_run,
+        organization_id=current_org.organization_id,
+        run_id=workflow_run.workflow_run_id,
+        title=workflow.title,
+    )
     if x_max_steps_override:
         LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
     await AsyncExecutorFactory.get_executor().execute_workflow(
@@ -1208,6 +1230,7 @@ async def observer_task(
             webhook_callback_url=data.webhook_callback_url,
             proxy_location=data.proxy_location,
             publish_workflow=data.publish_workflow,
+            create_task_run=True,
         )
     except LLMProviderError:
         LOG.error("LLM failure to initialize observer cruise", exc_info=True)
diff --git a/skyvern/forge/sdk/schemas/task_runs.py b/skyvern/forge/sdk/schemas/task_runs.py
new file mode 100644
index 00000000..bbb344af
--- /dev/null
+++ b/skyvern/forge/sdk/schemas/task_runs.py
@@ -0,0 +1,28 @@
+from datetime import datetime
+from enum import StrEnum
+
+from pydantic import BaseModel, ConfigDict
+
+
+class TaskRunType(StrEnum):
+    """Kinds of run that a task run record can refer to."""
+
+    task_v1 = "task_v1"
+    task_v2 = "task_v2"
+    workflow_run = "workflow_run"
+
+
+class TaskRun(BaseModel):
+    """Task run schema, buildable from attribute-bearing objects; url_hash is not exposed."""
+
+    model_config = ConfigDict(from_attributes=True)
+
+    task_run_id: str
+    task_run_type: TaskRunType
+    run_id: str
+    organization_id: str | None = None
+    title: str | None = None
+    url: str | None = None
+    cached: bool = False
+    created_at: datetime
+    modified_at: datetime
diff --git a/skyvern/forge/sdk/services/observer_service.py b/skyvern/forge/sdk/services/observer_service.py
index 10b7cfd4..a75e9f48 100644
--- a/skyvern/forge/sdk/services/observer_service.py
+++ b/skyvern/forge/sdk/services/observer_service.py
@@ -13,6 +13,7 @@ from skyvern.forge import app
 from skyvern.forge.prompts import prompt_engine
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
+from skyvern.forge.sdk.core.hashing import generate_url_hash
 from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
 from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
 from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
@@ -24,6 +25,7 @@ from skyvern.forge.sdk.schemas.observers import (
     ObserverThoughtType,
 )
 from skyvern.forge.sdk.schemas.organizations import Organization
+from skyvern.forge.sdk.schemas.task_runs import TaskRunType
 from skyvern.forge.sdk.schemas.tasks import ProxyLocation
 from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline, WorkflowRunTimelineType
 from skyvern.forge.sdk.workflow.models.block import (
@@ -97,6 +99,7 @@ async def initialize_observer_task(
     webhook_callback_url: str | None = None,
     publish_workflow: bool = False,
     parent_workflow_run_id: str | None = None,
+    create_task_run: bool = False,
 ) -> ObserverTask:
     observer_task = await app.DATABASE.create_observer_cruise(
         prompt=user_prompt,
@@ -189,6 +192,16 @@ async def initialize_observer_task(
             url=url,
             organization_id=organization.organization_id,
         )
+        if create_task_run:
+            await app.DATABASE.create_task_run(
+                task_run_type=TaskRunType.task_v2,
+                organization_id=organization.organization_id,
+                run_id=observer_task.observer_cruise_id,
+                title=new_workflow.title,
+                url=url,
+                # Hash only a non-empty URL; create_task_run accepts url_hash=None.
+                url_hash=generate_url_hash(url) if url else None,
+            )
     except Exception:
         LOG.warning("Failed to update task 2.0", exc_info=True)
         # fail the workflow run