fix skyvern agent local (#2050)

This commit is contained in:
Shuchang Zheng
2025-03-31 10:16:33 -07:00
committed by GitHub
parent 3bcd7db2bb
commit a75d5c947d
3 changed files with 11 additions and 11 deletions

View File

@@ -1,5 +1,4 @@
import asyncio
import os
import subprocess
from typing import Any, cast
@@ -32,7 +31,7 @@ class SkyvernAgent:
browser_path: str | None = None,
browser_type: str | None = None,
) -> None:
self.skyvern_client: SkyvernClient | None = None
self.client: SkyvernClient | None = None
if base_url is None and api_key is None:
# TODO: run at the root wherever the code is initiated
load_dotenv(".env")
@@ -61,9 +60,9 @@ class SkyvernAgent:
)
elif base_url is None and api_key is None:
if not browser_type:
if "BROWSER_TYPE" not in os.environ:
raise Exception("browser type is missing")
browser_type = os.environ["BROWSER_TYPE"]
# if "BROWSER_TYPE" not in os.environ:
# raise Exception("browser type is missing")
browser_type = "chromium-headful"
settings.BROWSER_TYPE = browser_type
elif base_url and api_key:
@@ -253,7 +252,7 @@ class SkyvernAgent:
async def run_task(
self,
prompt: str,
engine: RunEngine = RunEngine.skyvern_v1,
engine: RunEngine = RunEngine.skyvern_v2,
url: str | None = None,
webhook_url: str | None = None,
totp_identifier: str | None = None,

View File

@@ -122,7 +122,7 @@ class TaskRunRequest(BaseModel):
engine: RunEngine = Field(
default=RunEngine.skyvern_v2, description="The Skyvern engine version to use for this task"
)
proxy_location: ProxyLocation = Field(
proxy_location: ProxyLocation | None = Field(
default=ProxyLocation.RESIDENTIAL, description="Geographic Proxy location to route the browser traffic through"
)
data_extraction_schema: dict | list | str | None = Field(

View File

@@ -9,7 +9,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
if run.task_run_type == RunType.task_v1:
# fetch task v1 from db and transform to task run response
task_v1 = await app.DATABASE.get_task(run.task_v1_id, organization_id=organization_id)
task_v1 = await app.DATABASE.get_task(run.run_id, organization_id=organization_id)
if not task_v1:
return None
return TaskRunResponse(
@@ -34,7 +34,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
),
)
elif run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(run.task_v2_id, organization_id=organization_id)
task_v2 = await app.DATABASE.get_task_v2(run.run_id, organization_id=organization_id)
if not task_v2:
return None
return TaskRunResponse(
@@ -42,7 +42,8 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
run_type=run.task_run_type,
status=task_v2.status,
output=task_v2.output,
failure_reason=task_v2.failure_reason,
# TODO: add failure reason
# failure_reason=task_v2.failure_reason,
created_at=task_v2.created_at,
modified_at=task_v2.modified_at,
run_request=TaskRunRequest(
@@ -53,7 +54,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
totp_identifier=task_v2.totp_identifier,
totp_url=task_v2.totp_verification_url,
proxy_location=task_v2.proxy_location,
data_extraction_schema=task_v2.data_extraction_schema,
data_extraction_schema=task_v2.extracted_information_schema,
error_code_mapping=task_v2.error_code_mapping,
),
)