add generate_script to the task v2 table (#3332)
This commit is contained in:
@@ -0,0 +1,33 @@
|
|||||||
|
"""add generate_script to task v2 table
|
||||||
|
|
||||||
|
Revision ID: 8de03b8cb83a
|
||||||
|
Revises: ee2f523ea454
|
||||||
|
Create Date: 2025-08-31 06:18:37.032634+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "8de03b8cb83a"
|
||||||
|
down_revision: Union[str, None] = "ee2f523ea454"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column(
|
||||||
|
"observer_cruises", sa.Column("generate_script", sa.Boolean(), nullable=False, server_default=sa.false())
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column("observer_cruises", "generate_script")
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -2639,6 +2639,7 @@ class AgentDB:
|
|||||||
max_screenshot_scrolling_times: int | None = None,
|
max_screenshot_scrolling_times: int | None = None,
|
||||||
extra_http_headers: dict[str, str] | None = None,
|
extra_http_headers: dict[str, str] | None = None,
|
||||||
browser_address: str | None = None,
|
browser_address: str | None = None,
|
||||||
|
generate_script: bool = False,
|
||||||
) -> TaskV2:
|
) -> TaskV2:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
new_task_v2 = TaskV2Model(
|
new_task_v2 = TaskV2Model(
|
||||||
@@ -2658,6 +2659,7 @@ class AgentDB:
|
|||||||
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
|
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
|
||||||
extra_http_headers=extra_http_headers,
|
extra_http_headers=extra_http_headers,
|
||||||
browser_address=browser_address,
|
browser_address=browser_address,
|
||||||
|
generate_script=generate_script,
|
||||||
)
|
)
|
||||||
session.add(new_task_v2)
|
session.add(new_task_v2)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -664,6 +664,7 @@ class TaskV2Model(Base):
|
|||||||
max_screenshot_scrolling_times = Column(Integer, nullable=True)
|
max_screenshot_scrolling_times = Column(Integer, nullable=True)
|
||||||
extra_http_headers = Column(JSON, nullable=True)
|
extra_http_headers = Column(JSON, nullable=True)
|
||||||
browser_address = Column(String, nullable=True)
|
browser_address = Column(String, nullable=True)
|
||||||
|
generate_script = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
queued_at = Column(DateTime, nullable=True)
|
queued_at = Column(DateTime, nullable=True)
|
||||||
started_at = Column(DateTime, nullable=True)
|
started_at = Column(DateTime, nullable=True)
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ class TaskV2(BaseModel):
|
|||||||
max_screenshot_scrolls: int | None = Field(default=None, alias="max_screenshot_scrolling_times")
|
max_screenshot_scrolls: int | None = Field(default=None, alias="max_screenshot_scrolling_times")
|
||||||
extra_http_headers: dict[str, str] | None = None
|
extra_http_headers: dict[str, str] | None = None
|
||||||
browser_address: str | None = None
|
browser_address: str | None = None
|
||||||
|
generate_script: bool = False
|
||||||
|
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
modified_at: datetime
|
modified_at: datetime
|
||||||
@@ -155,6 +156,7 @@ class TaskV2Request(BaseModel):
|
|||||||
max_screenshot_scrolls: int | None = None
|
max_screenshot_scrolls: int | None = None
|
||||||
extra_http_headers: dict[str, str] | None = None
|
extra_http_headers: dict[str, str] | None = None
|
||||||
browser_address: str | None = None
|
browser_address: str | None = None
|
||||||
|
generate_script: bool = False
|
||||||
|
|
||||||
@field_validator("url", "webhook_callback_url", "totp_verification_url")
|
@field_validator("url", "webhook_callback_url", "totp_verification_url")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -2273,6 +2273,7 @@ class WorkflowService:
|
|||||||
proxy_location: ProxyLocation | None = None,
|
proxy_location: ProxyLocation | None = None,
|
||||||
max_screenshot_scrolling_times: int | None = None,
|
max_screenshot_scrolling_times: int | None = None,
|
||||||
extra_http_headers: dict[str, str] | None = None,
|
extra_http_headers: dict[str, str] | None = None,
|
||||||
|
generate_script: bool = False,
|
||||||
status: WorkflowStatus = WorkflowStatus.published,
|
status: WorkflowStatus = WorkflowStatus.published,
|
||||||
) -> Workflow:
|
) -> Workflow:
|
||||||
"""
|
"""
|
||||||
@@ -2289,6 +2290,7 @@ class WorkflowService:
|
|||||||
status=status,
|
status=status,
|
||||||
max_screenshot_scrolls=max_screenshot_scrolling_times,
|
max_screenshot_scrolls=max_screenshot_scrolling_times,
|
||||||
extra_http_headers=extra_http_headers,
|
extra_http_headers=extra_http_headers,
|
||||||
|
generate_script=generate_script,
|
||||||
)
|
)
|
||||||
return await app.WORKFLOW_SERVICE.create_workflow_from_request(
|
return await app.WORKFLOW_SERVICE.create_workflow_from_request(
|
||||||
organization=organization,
|
organization=organization,
|
||||||
|
|||||||
@@ -164,7 +164,11 @@ async def initialize_task_v2(
|
|||||||
browser_session_id: str | None = None,
|
browser_session_id: str | None = None,
|
||||||
extra_http_headers: dict[str, str] | None = None,
|
extra_http_headers: dict[str, str] | None = None,
|
||||||
browser_address: str | None = None,
|
browser_address: str | None = None,
|
||||||
|
generate_script: bool = False,
|
||||||
) -> TaskV2:
|
) -> TaskV2:
|
||||||
|
if generate_script:
|
||||||
|
publish_workflow = True
|
||||||
|
|
||||||
task_v2 = await app.DATABASE.create_task_v2(
|
task_v2 = await app.DATABASE.create_task_v2(
|
||||||
prompt=user_prompt,
|
prompt=user_prompt,
|
||||||
organization_id=organization.organization_id,
|
organization_id=organization.organization_id,
|
||||||
@@ -178,6 +182,7 @@ async def initialize_task_v2(
|
|||||||
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
|
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
|
||||||
extra_http_headers=extra_http_headers,
|
extra_http_headers=extra_http_headers,
|
||||||
browser_address=browser_address,
|
browser_address=browser_address,
|
||||||
|
generate_script=generate_script,
|
||||||
)
|
)
|
||||||
# set task_v2_id in context
|
# set task_v2_id in context
|
||||||
context = skyvern_context.current()
|
context = skyvern_context.current()
|
||||||
@@ -224,6 +229,7 @@ async def initialize_task_v2(
|
|||||||
status=workflow_status,
|
status=workflow_status,
|
||||||
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
|
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
|
||||||
extra_http_headers=extra_http_headers,
|
extra_http_headers=extra_http_headers,
|
||||||
|
generate_script=generate_script,
|
||||||
)
|
)
|
||||||
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
|
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
|
||||||
request_id=None,
|
request_id=None,
|
||||||
@@ -886,6 +892,11 @@ async def run_task_v2_helper(
|
|||||||
context=context,
|
context=context,
|
||||||
screenshots=completion_screenshots,
|
screenshots=completion_screenshots,
|
||||||
)
|
)
|
||||||
|
if task_v2.generate_script:
|
||||||
|
await app.WORKFLOW_SERVICE.generate_script_if_needed(
|
||||||
|
workflow=workflow,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
# total step number validation
|
# total step number validation
|
||||||
|
|||||||
Reference in New Issue
Block a user