user defined browser header (#2752)

Co-authored-by: lawyzheng <lawyzheng1106@gmail.com>
This commit is contained in:
Shuchang Zheng
2025-06-19 00:42:34 -07:00
committed by GitHub
parent 2776475ca3
commit df5f40bdb9
15 changed files with 132 additions and 10 deletions

View File

@@ -150,6 +150,7 @@ class AgentDB:
include_action_history_in_verification: bool | None = None,
model: dict[str, Any] | None = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
) -> Task:
try:
async with self.Session() as session:
@@ -178,6 +179,7 @@ class AgentDB:
include_action_history_in_verification=include_action_history_in_verification,
model=model,
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
extra_http_headers=extra_http_headers,
)
session.add(new_task)
await session.commit()
@@ -1300,6 +1302,7 @@ class AgentDB:
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
persist_browser_session: bool = False,
@@ -1320,6 +1323,7 @@ class AgentDB:
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
extra_http_headers=extra_http_headers,
persist_browser_session=persist_browser_session,
model=model,
is_saved_task=is_saved_task,
@@ -1564,6 +1568,7 @@ class AgentDB:
totp_identifier: str | None = None,
parent_workflow_run_id: str | None = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
) -> WorkflowRun:
try:
async with self.Session() as session:
@@ -1578,6 +1583,7 @@ class AgentDB:
totp_identifier=totp_identifier,
parent_workflow_run_id=parent_workflow_run_id,
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
extra_http_headers=extra_http_headers,
)
session.add(workflow_run)
await session.commit()
@@ -2523,6 +2529,7 @@ class AgentDB:
error_code_mapping: dict | None = None,
model: dict[str, Any] | None = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
) -> TaskV2:
async with self.Session() as session:
new_task_v2 = TaskV2Model(
@@ -2540,6 +2547,7 @@ class AgentDB:
organization_id=organization_id,
model=model,
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
extra_http_headers=extra_http_headers,
)
session.add(new_task_v2)
await session.commit()

View File

@@ -77,6 +77,7 @@ class TaskModel(Base):
failure_reason = Column(String)
proxy_location = Column(String)
extracted_information_schema = Column(JSON)
extra_http_headers = Column(JSON, nullable=True)
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), index=True)
order = Column(Integer, nullable=True)
retry = Column(Integer, nullable=True)
@@ -220,6 +221,7 @@ class WorkflowModel(Base):
proxy_location = Column(String)
webhook_callback_url = Column(String)
max_screenshot_scrolling_times = Column(Integer, nullable=True)
extra_http_headers = Column(JSON, nullable=True)
totp_verification_url = Column(String)
totp_identifier = Column(String)
persist_browser_session = Column(Boolean, default=False, nullable=False)
@@ -257,6 +259,7 @@ class WorkflowRunModel(Base):
totp_verification_url = Column(String)
totp_identifier = Column(String)
max_screenshot_scrolling_times = Column(Integer, nullable=True)
extra_http_headers = Column(JSON, nullable=True)
queued_at = Column(DateTime, nullable=True)
started_at = Column(DateTime, nullable=True)
@@ -626,6 +629,7 @@ class TaskV2Model(Base):
error_code_mapping = Column(JSON, nullable=True)
max_steps = Column(Integer, nullable=True)
max_screenshot_scrolling_times = Column(Integer, nullable=True)
extra_http_headers = Column(JSON, nullable=True)
queued_at = Column(DateTime, nullable=True)
started_at = Column(DateTime, nullable=True)

View File

@@ -130,6 +130,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p
organization_id=task_obj.organization_id,
proxy_location=(ProxyLocation(task_obj.proxy_location) if task_obj.proxy_location else None),
extracted_information_schema=task_obj.extracted_information_schema,
extra_http_headers=task_obj.extra_http_headers,
workflow_run_id=task_obj.workflow_run_id,
workflow_permanent_id=workflow_permanent_id,
order=task_obj.order,
@@ -248,6 +249,7 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
modified_at=workflow_model.modified_at,
deleted_at=workflow_model.deleted_at,
status=WorkflowStatus(workflow_model.status),
extra_http_headers=workflow_model.extra_http_headers,
)
@@ -281,6 +283,7 @@ def convert_to_workflow_run(
modified_at=workflow_run_model.modified_at,
workflow_title=workflow_title,
max_screenshot_scrolling_times=workflow_run_model.max_screenshot_scrolling_times,
extra_http_headers=workflow_run_model.extra_http_headers,
)

View File

@@ -167,6 +167,7 @@ async def run_task(
include_action_history_in_verification=run_request.include_action_history_in_verification,
model=run_request.model,
max_screenshot_scrolling_times=run_request.max_screenshot_scrolling_times,
extra_http_headers=run_request.extra_http_headers,
)
task_v1_response = await task_v1_service.run_task(
task=task_v1_request,
@@ -224,6 +225,7 @@ async def run_task(
create_task_run=True,
model=run_request.model,
max_screenshot_scrolling_times=run_request.max_screenshot_scrolling_times,
extra_http_headers=run_request.extra_http_headers,
)
except MissingBrowserAddressError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
@@ -320,9 +322,10 @@ async def run_workflow(
proxy_location=workflow_run_request.proxy_location,
webhook_callback_url=workflow_run_request.webhook_url,
totp_identifier=workflow_run_request.totp_identifier,
totp_url=workflow_run_request.totp_url,
totp_verification_url=workflow_run_request.totp_url,
browser_session_id=workflow_run_request.browser_session_id,
max_screenshot_scrolling_times=workflow_run_request.max_screenshot_scrolling_times,
extra_http_headers=workflow_run_request.extra_http_headers,
)
try:
@@ -1822,6 +1825,7 @@ async def run_task_v2(
error_code_mapping=data.error_code_mapping,
max_screenshot_scrolling_times=data.max_screenshot_scrolling_times,
browser_session_id=data.browser_session_id,
extra_http_headers=data.extra_http_headers,
)
except MissingBrowserAddressError as e:
raise HTTPException(status_code=400, detail=str(e)) from e

View File

@@ -49,6 +49,7 @@ class TaskV2(BaseModel):
started_at: datetime | None = None
finished_at: datetime | None = None
max_screenshot_scrolling_times: int | None = None
extra_http_headers: dict[str, str] | None = None
created_at: datetime
modified_at: datetime
@@ -150,6 +151,7 @@ class TaskV2Request(BaseModel):
extracted_information_schema: dict | list | str | None = None
error_code_mapping: dict[str, str] | None = None
max_screenshot_scrolling_times: int | None = None
extra_http_headers: dict[str, str] | None = None
@field_validator("url", "webhook_callback_url", "totp_verification_url")
@classmethod

View File

@@ -73,6 +73,9 @@ class TaskBase(BaseModel):
default=None,
description="The requested schema of the extracted information.",
)
extra_http_headers: dict[str, str] | None = Field(
None, description="The extra HTTP headers for the requests in browser."
)
complete_criterion: str | None = Field(
default=None, description="Criterion to complete", examples=["Complete if 'hello world' shows up on the page"]
)

View File

@@ -23,6 +23,7 @@ class WorkflowRequestBody(BaseModel):
totp_identifier: str | None = None
browser_session_id: str | None = None
max_screenshot_scrolling_times: int | None = None
extra_http_headers: dict[str, str] | None = None
@field_validator("webhook_callback_url", "totp_verification_url")
@classmethod
@@ -78,6 +79,7 @@ class Workflow(BaseModel):
model: dict[str, Any] | None = None
status: WorkflowStatus = WorkflowStatus.published
max_screenshot_scrolling_times: int | None = None
extra_http_headers: dict[str, str] | None = None
created_at: datetime
modified_at: datetime
@@ -110,6 +112,7 @@ class WorkflowRun(BaseModel):
workflow_permanent_id: str
organization_id: str
status: WorkflowRunStatus
extra_http_headers: dict[str, str] | None = None
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
totp_verification_url: str | None = None

View File

@@ -425,4 +425,5 @@ class WorkflowCreateYAMLRequest(BaseModel):
workflow_definition: WorkflowDefinitionYAML
is_saved_task: bool = False
max_screenshot_scrolling_times: int | None = None
extra_http_headers: dict[str, str] | None = None
status: WorkflowStatus = WorkflowStatus.published

View File

@@ -599,6 +599,7 @@ class WorkflowService:
version: int | None = None,
is_saved_task: bool = False,
status: WorkflowStatus = WorkflowStatus.published,
extra_http_headers: dict[str, str] | None = None,
) -> Workflow:
return await app.DATABASE.create_workflow(
title=title,
@@ -616,6 +617,7 @@ class WorkflowService:
version=version,
is_saved_task=is_saved_task,
status=status,
extra_http_headers=extra_http_headers,
)
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow:
@@ -782,6 +784,7 @@ class WorkflowService:
totp_identifier=workflow_request.totp_identifier,
parent_workflow_run_id=parent_workflow_run_id,
max_screenshot_scrolling_times=workflow_request.max_screenshot_scrolling_times,
extra_http_headers=workflow_request.extra_http_headers,
)
async def mark_workflow_run_as_completed(self, workflow_run_id: str) -> WorkflowRun:
@@ -1470,6 +1473,7 @@ class WorkflowService:
persist_browser_session=request.persist_browser_session,
model=request.model,
max_screenshot_scrolling_times=request.max_screenshot_scrolling_times,
extra_http_headers=request.extra_http_headers,
workflow_permanent_id=workflow_permanent_id,
version=existing_version + 1,
is_saved_task=request.is_saved_task,
@@ -1488,6 +1492,7 @@ class WorkflowService:
persist_browser_session=request.persist_browser_session,
model=request.model,
max_screenshot_scrolling_times=request.max_screenshot_scrolling_times,
extra_http_headers=request.extra_http_headers,
is_saved_task=request.is_saved_task,
status=request.status,
)
@@ -2069,6 +2074,8 @@ class WorkflowService:
organization: Organization,
title: str,
proxy_location: ProxyLocation | None = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
status: WorkflowStatus = WorkflowStatus.published,
) -> Workflow:
"""
@@ -2083,6 +2090,8 @@ class WorkflowService:
),
proxy_location=proxy_location,
status=status,
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
extra_http_headers=extra_http_headers,
)
return await app.WORKFLOW_SERVICE.create_workflow_from_request(
organization=organization,