add webhook support for observer (#1546)

This commit is contained in:
Shuchang Zheng
2025-01-14 08:59:53 -08:00
committed by GitHub
parent 950a4a54f3
commit 0392763998
8 changed files with 160 additions and 13 deletions

View File

@@ -0,0 +1,59 @@
"""observer webhook_callback_url, totp_verification_url, totp_identifier, proxy_location
Revision ID: 46e38fc53f64
Revises: 6a947c379c02
Create Date: 2025-01-14 16:41:46.037751+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "46e38fc53f64"
down_revision: Union[str, None] = "6a947c379c02"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("observer_cruises", sa.Column("webhook_callback_url", sa.String(), nullable=True))
op.add_column("observer_cruises", sa.Column("totp_verification_url", sa.String(), nullable=True))
op.add_column("observer_cruises", sa.Column("totp_identifier", sa.String(), nullable=True))
op.add_column(
"observer_cruises",
sa.Column(
"proxy_location",
sa.Enum(
"US_CA",
"US_NY",
"US_TX",
"US_FL",
"US_WA",
"RESIDENTIAL",
"RESIDENTIAL_ES",
"RESIDENTIAL_IE",
"RESIDENTIAL_GB",
"RESIDENTIAL_IN",
"RESIDENTIAL_JP",
"RESIDENTIAL_FR",
"NONE",
name="proxylocation",
),
nullable=True,
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("observer_cruises", "proxy_location")
op.drop_column("observer_cruises", "totp_identifier")
op.drop_column("observer_cruises", "totp_verification_url")
op.drop_column("observer_cruises", "webhook_callback_url")
# ### end Alembic commands ###

View File

@@ -29,11 +29,13 @@ class FailedToSendWebhook(SkyvernException):
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
observer_cruise_id: str | None = None,
):
workflow_run_str = f"workflow_run_id={workflow_run_id}" if workflow_run_id else ""
workflow_str = f"workflow_id={workflow_id}" if workflow_id else ""
task_str = f"task_id={task_id}" if task_id else ""
super().__init__(f"Failed to send webhook. {workflow_run_str} {workflow_str} {task_str}")
observer_cruise_str = f"observer_cruise_id={observer_cruise_id}" if observer_cruise_id else ""
super().__init__(f"Failed to send webhook. {workflow_run_str} {workflow_str} {task_str} {observer_cruise_str}")
class ProxyLocationNotSupportedError(SkyvernException):

View File

@@ -1613,9 +1613,10 @@ class ForgeAgent:
headers=headers,
)
try:
resp = await httpx.AsyncClient().post(
task.webhook_callback_url, data=payload, headers=headers, timeout=httpx.Timeout(30.0)
)
async with httpx.AsyncClient() as client:
resp = await client.post(
task.webhook_callback_url, data=payload, headers=headers, timeout=httpx.Timeout(30.0)
)
if resp.status_code == 200:
LOG.info(
"Webhook sent successfully",

View File

@@ -550,6 +550,10 @@ class ObserverCruiseModel(Base):
url = Column(String, nullable=True)
summary = Column(String, nullable=True)
output = Column(JSON, nullable=True)
webhook_callback_url = Column(String, nullable=True)
totp_verification_url = Column(String, nullable=True)
totp_identifier = Column(String, nullable=True)
proxy_location = Column(Enum(ProxyLocation), nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)

View File

@@ -91,5 +91,16 @@ class ObserverMetadata(BaseModel):
class CruiseRequest(BaseModel):
user_prompt: str
url: HttpUrl | None = None
url: str | None = None
browser_session_id: str | None = None
webhook_callback_url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
@field_validator("url", "webhook_callback_url", "totp_verification_url")
@classmethod
def validate_urls(cls, url: str | None) -> str | None:
if url is None:
return None
return validate_url(url)

View File

@@ -4,15 +4,18 @@ import string
from datetime import datetime
from typing import Any
import httpx
import structlog
from sqlalchemy.exc import OperationalError
from skyvern.exceptions import UrlGenerationFailure
from skyvern.exceptions import FailedToSendWebhook, UrlGenerationFailure
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.schemas.observers import (
ObserverCruise,
ObserverCruiseStatus,
@@ -1061,6 +1064,9 @@ async def mark_observer_cruise_as_failed(
await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed(
workflow_run_id, failure_reason=failure_reason or "Observer cruise failed"
)
observer_cruise = await get_observer_cruise(observer_cruise_id, organization_id=organization_id)
if observer_cruise:
await send_observer_cruise_webhook(observer_cruise)
async def mark_observer_cruise_as_completed(
@@ -1080,6 +1086,10 @@ async def mark_observer_cruise_as_completed(
if workflow_run_id:
await app.WORKFLOW_SERVICE.mark_workflow_run_as_completed(workflow_run_id)
observer_cruise = await get_observer_cruise(observer_cruise_id, organization_id=organization_id)
if observer_cruise:
await send_observer_cruise_webhook(observer_cruise)
def _get_extracted_data_from_block_result(
block_result: BlockResult,
@@ -1195,3 +1205,52 @@ async def _summarize_observer_cruise(
summary=thought,
output=summarized_output,
)
async def send_observer_cruise_webhook(observer_cruise: ObserverCruise) -> None:
if not observer_cruise.webhook_callback_url:
return
organization_id = observer_cruise.organization_id
if not organization_id:
return
api_key = await app.DATABASE.get_valid_org_auth_token(
organization_id,
OrganizationAuthTokenType.api,
)
if not api_key:
LOG.warning(
"No valid API key found for the organization of observer cruise",
observer_cruise_id=observer_cruise.observer_cruise_id,
)
return
# build the observer cruise response
payload = observer_cruise.model_dump_json()
headers = generate_skyvern_webhook_headers(payload=payload, api_key=api_key.token)
LOG.info(
"Sending observer cruise response to webhook callback url",
observer_cruise_id=observer_cruise.observer_cruise_id,
webhook_callback_url=observer_cruise.webhook_callback_url,
payload=payload,
headers=headers,
)
try:
resp = await httpx.AsyncClient().post(
observer_cruise.webhook_callback_url, data=payload, headers=headers, timeout=httpx.Timeout(30.0)
)
if resp.status_code == 200:
LOG.info(
"Observer cruise webhook sent successfully",
observer_cruise_id=observer_cruise.observer_cruise_id,
resp_code=resp.status_code,
resp_text=resp.text,
)
else:
LOG.info(
"Observer cruise webhook failed",
observer_cruise_id=observer_cruise.observer_cruise_id,
resp=resp,
resp_code=resp.status_code,
resp_text=resp.text,
)
except Exception as e:
raise FailedToSendWebhook(observer_cruise_id=observer_cruise.observer_cruise_id) from e

View File

@@ -987,9 +987,10 @@ class WorkflowService:
headers=headers,
)
try:
resp = await httpx.AsyncClient().post(
url=workflow_run.webhook_callback_url, data=payload, headers=headers, timeout=httpx.Timeout(30.0)
)
async with httpx.AsyncClient() as client:
resp = await client.post(
url=workflow_run.webhook_callback_url, data=payload, headers=headers, timeout=httpx.Timeout(30.0)
)
if resp.status_code == 200:
LOG.info(
"Webhook sent successfully",

View File

@@ -2762,6 +2762,8 @@ async def poll_verification_code(
task_id: str,
organization_id: str,
workflow_id: str | None = None,
workflow_run_id: str | None = None,
workflow_permanent_id: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
) -> str | None:
@@ -2793,10 +2795,18 @@ async def poll_verification_code(
await asyncio.sleep(10)
async def _get_verification_code_from_url(task_id: str, url: str, api_key: str) -> str | None:
request_data = {
"task_id": task_id,
}
async def _get_verification_code_from_url(
task_id: str,
url: str,
api_key: str,
workflow_run_id: str | None = None,
workflow_permanent_id: str | None = None,
) -> str | None:
request_data = {"task_id": task_id}
if workflow_run_id:
request_data["workflow_run_id"] = workflow_run_id
if workflow_permanent_id:
request_data["workflow_permanent_id"] = workflow_permanent_id
payload = json.dumps(request_data)
signature = generate_skyvern_signature(
payload=payload,