shu/fix task url type (#999)

This commit is contained in:
Shuchang Zheng
2024-10-17 23:47:59 -07:00
committed by GitHub
parent dad53e1f6a
commit f69016088b
4 changed files with 43 additions and 7 deletions

View File

@@ -495,3 +495,10 @@ class IllegitComplete(SkyvernException):
class CachedActionPlanError(SkyvernException): class CachedActionPlanError(SkyvernException):
def __init__(self, message: str) -> None: def __init__(self, message: str) -> None:
super().__init__(message) super().__init__(message)
class InvalidUrl(SkyvernHTTPException):
def __init__(self, url: str) -> None:
super().__init__(
f"Invalid URL: {url}. Skyvern supports HTTP and HTTPS urls.", status_code=status.HTTP_400_BAD_REQUEST
)

View File

@@ -36,6 +36,7 @@ from skyvern.forge.sdk.api.files import get_path_for_workflow_download_directory
from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.validators import validate_url
from skyvern.forge.sdk.models import Organization, Step, StepStatus from skyvern.forge.sdk.models import Organization, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -126,6 +127,7 @@ class ForgeAgent:
task_url = working_page.url task_url = working_page.url
task_url = validate_url(task_url)
task = await app.DATABASE.create_task( task = await app.DATABASE.create_task(
url=task_url, url=task_url,
title=task_block.title, title=task_block.title,
@@ -183,10 +185,10 @@ class ForgeAgent:
async def create_task(self, task_request: TaskRequest, organization_id: str | None = None) -> Task: async def create_task(self, task_request: TaskRequest, organization_id: str | None = None) -> Task:
task = await app.DATABASE.create_task( task = await app.DATABASE.create_task(
url=task_request.url, url=str(task_request.url),
title=task_request.title, title=task_request.title,
webhook_callback_url=task_request.webhook_callback_url, webhook_callback_url=str(task_request.webhook_callback_url),
totp_verification_url=task_request.totp_verification_url, totp_verification_url=str(task_request.totp_verification_url),
totp_identifier=task_request.totp_identifier, totp_identifier=task_request.totp_identifier,
navigation_goal=task_request.navigation_goal, navigation_goal=task_request.navigation_goal,
data_extraction_goal=task_request.data_extraction_goal, data_extraction_goal=task_request.data_extraction_goal,

View File

@@ -0,0 +1,13 @@
from pydantic import HttpUrl, ValidationError, parse_obj_as
from skyvern.exceptions import InvalidUrl
def validate_url(url: str) -> str:
try:
# Use parse_obj_as to validate the string as an HttpUrl
parse_obj_as(HttpUrl, url)
return url
except ValidationError:
# Handle the validation error
raise InvalidUrl(url=url)

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, HttpUrl
from skyvern.exceptions import InvalidTaskStatusTransition, TaskAlreadyCanceled from skyvern.exceptions import InvalidTaskStatusTransition, TaskAlreadyCanceled
@@ -22,7 +22,7 @@ class ProxyLocation(StrEnum):
NONE = "NONE" NONE = "NONE"
class TaskRequest(BaseModel): class TaskBase(BaseModel):
title: str | None = Field( title: str | None = Field(
default=None, default=None,
description="The title of the task.", description="The title of the task.",
@@ -76,6 +76,20 @@ class TaskRequest(BaseModel):
) )
class TaskRequest(TaskBase):
url: HttpUrl = Field(
...,
description="Starting URL for the task.",
examples=["https://www.geico.com"],
)
webhook_callback_url: HttpUrl | None = Field(
default=None,
description="The URL to call when the task is completed.",
examples=["https://my-webhook.com"],
)
totp_verification_url: HttpUrl | None = None
class TaskStatus(StrEnum): class TaskStatus(StrEnum):
created = "created" created = "created"
queued = "queued" queued = "queued"
@@ -144,7 +158,7 @@ class TaskStatus(StrEnum):
return self in status_requires_failure_reason return self in status_requires_failure_reason
class Task(TaskRequest): class Task(TaskBase):
created_at: datetime = Field( created_at: datetime = Field(
..., ...,
description="The creation datetime of the task.", description="The creation datetime of the task.",
@@ -229,7 +243,7 @@ class Task(TaskRequest):
class TaskResponse(BaseModel): class TaskResponse(BaseModel):
request: TaskRequest request: TaskBase
task_id: str task_id: str
status: TaskStatus status: TaskStatus
created_at: datetime created_at: datetime