add fern sdk (#1786)

This commit is contained in:
LawyZheng
2025-02-19 00:58:48 +08:00
committed by GitHub
parent e7c3e4e37a
commit a258406a86
153 changed files with 17372 additions and 255 deletions

View File

@@ -0,0 +1,3 @@
from skyvern.agent.local import Agent
__all__ = ["Agent"]

View File

@@ -1,13 +1,14 @@
import asyncio
from dotenv import load_dotenv
from skyvern.agent.parameter import TaskV1Request, TaskV2Request
from skyvern.forge import app
from skyvern.forge.sdk.core import security, skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskStatus
from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskRequest, ObserverTaskStatus
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import TaskResponse, TaskStatus
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.services import observer_service
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
@@ -40,32 +41,20 @@ class Agent:
)
return organization
async def run_task_v1(self, task_request: TaskV1Request) -> TaskResponse:
organization = await self._get_organization()
async def _run_task(self, organization: Organization, task: Task) -> None:
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
organization_id=organization.organization_id,
token_type=OrganizationAuthTokenType.api,
)
created_task = await app.agent.create_task(task_request, organization.organization_id)
skyvern_context.set(
SkyvernContext(
organization_id=organization.organization_id,
task_id=created_task.task_id,
max_steps_override=task_request.max_steps,
)
)
step = await app.DATABASE.create_step(
created_task.task_id,
task.task_id,
order=0,
retry_index=0,
organization_id=organization.organization_id,
)
updated_task = await app.DATABASE.update_task(
created_task.task_id,
task.task_id,
status=TaskStatus.running,
organization_id=organization.organization_id,
)
@@ -77,18 +66,65 @@ class Agent:
api_key=org_auth_token.token if org_auth_token else None,
)
refreshed_task = await app.DATABASE.get_task(created_task.task_id, organization.organization_id)
if refreshed_task:
updated_task = refreshed_task
async def _run_observer_task(self, organization: Organization, observer_task: ObserverTask) -> None:
# mark observer cruise as queued
await app.DATABASE.update_observer_cruise(
observer_cruise_id=observer_task.observer_cruise_id,
status=ObserverTaskStatus.queued,
organization_id=organization.organization_id,
)
assert observer_task.workflow_run_id
await app.DATABASE.update_workflow_run(
workflow_run_id=observer_task.workflow_run_id,
status=WorkflowRunStatus.queued,
)
await observer_service.run_observer_task(
organization=organization,
observer_cruise_id=observer_task.observer_cruise_id,
)
async def create_task(
self,
task_request: TaskRequest,
) -> CreateTaskResponse:
organization = await self._get_organization()
created_task = await app.agent.create_task(task_request, organization.organization_id)
skyvern_context.set(
SkyvernContext(
organization_id=organization.organization_id,
task_id=created_task.task_id,
max_steps_override=created_task.max_steps_per_run,
)
)
asyncio.create_task(self._run_task(organization, created_task))
return CreateTaskResponse(task_id=created_task.task_id)
async def get_task(
self,
task_id: str,
) -> TaskResponse | None:
organization = await self._get_organization()
task = await app.DATABASE.get_task(task_id, organization.organization_id)
if task is None:
return None
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization.organization_id)
if not latest_step:
return await app.agent.build_task_response(task=task)
failure_reason: str | None = None
if updated_task.status == TaskStatus.failed and (step.output or updated_task.failure_reason):
if task.status == TaskStatus.failed and (task.failure_reason):
failure_reason = ""
if updated_task.failure_reason:
failure_reason += updated_task.failure_reason or ""
if step.output is not None and step.output.actions_and_results is not None:
if task.failure_reason:
failure_reason += task.failure_reason or ""
if latest_step.output is not None and latest_step.output.actions_and_results is not None:
action_results_string: list[str] = []
for action, results in step.output.actions_and_results:
for action, results in latest_step.output.actions_and_results:
if len(results) == 0:
continue
if results[-1].success:
@@ -97,11 +133,27 @@ class Agent:
if len(action_results_string) > 0:
failure_reason += "(Exceptions: " + str(action_results_string) + ")"
return await app.agent.build_task_response(
task=updated_task, last_step=step, failure_reason=failure_reason, need_browser_log=True
task=task, last_step=latest_step, failure_reason=failure_reason, need_browser_log=True
)
async def run_task_v2(self, task_request: TaskV2Request) -> ObserverTask:
async def run_task(
self,
task_request: TaskRequest,
timeout_seconds: int = 600,
) -> TaskResponse:
created_task = await self.create_task(task_request)
while True:
async with asyncio.timeout(timeout_seconds):
task_response = await self.get_task(created_task.task_id)
assert task_response is not None
if task_response.status.is_final():
return task_response
await asyncio.sleep(1)
async def observer_task_v_2(self, task_request: ObserverTaskRequest) -> ObserverTask:
organization = await self._get_organization()
observer_task = await observer_service.initialize_observer_task(
@@ -118,27 +170,22 @@ class Agent:
if not observer_task.workflow_run_id:
raise Exception("Observer cruise missing workflow run id")
# mark observer cruise as queued
await app.DATABASE.update_observer_cruise(
observer_cruise_id=observer_task.observer_cruise_id,
status=ObserverTaskStatus.queued,
organization_id=organization.organization_id,
)
await app.DATABASE.update_workflow_run(
workflow_run_id=observer_task.workflow_run_id,
status=WorkflowRunStatus.queued,
)
await observer_service.run_observer_task(
organization=organization,
observer_cruise_id=observer_task.observer_cruise_id,
max_iterations_override=task_request.max_iterations,
)
refreshed_observer_task = await app.DATABASE.get_observer_cruise(
observer_cruise_id=observer_task.observer_cruise_id, organization_id=organization.organization_id
)
if refreshed_observer_task:
return refreshed_observer_task
asyncio.create_task(self._run_observer_task(organization, observer_task))
return observer_task
async def get_observer_task_v_2(self, task_id: str) -> ObserverTask | None:
organization = await self._get_organization()
return await app.DATABASE.get_observer_cruise(task_id, organization.organization_id)
async def run_observer_task_v_2(
self, task_request: ObserverTaskRequest, timeout_seconds: int = 600
) -> ObserverTask:
observer_task = await self.observer_task_v_2(task_request)
while True:
async with asyncio.timeout(timeout_seconds):
refreshed_observer_task = await self.get_observer_task_v_2(observer_task.observer_cruise_id)
assert refreshed_observer_task is not None
if refreshed_observer_task.status.is_final():
return refreshed_observer_task
await asyncio.sleep(1)

View File

@@ -1,45 +0,0 @@
from pydantic import BaseModel, Field
from skyvern.forge.sdk.schemas.observers import ObserverTaskRequest
from skyvern.forge.sdk.schemas.tasks import TaskRequest
class TaskV1Request(TaskRequest):
max_steps: int = 10
class TaskV2Request(ObserverTaskRequest):
max_iterations: int = 10
class RunTaskV1Schema(BaseModel):
api_key: str = Field(
description="The API key of the Skyvern API. You can get the API key from the Skyvern dashboard.",
)
endpoint: str = Field(
description="The endpoint of the Skyvern API. Don't add any path to the endpoint. Default is https://api.skyvern.com",
default="https://api.skyvern.com",
)
task: TaskV1Request
class RunTaskV2Schema(BaseModel):
api_key: str = Field(
description="The API key of the Skyvern API. You can get the API key from the Skyvern dashboard."
)
endpoint: str = Field(
description="The endpoint of the Skyvern API. Don't add any path to the endpoint. Default is https://api.skyvern.com",
default="https://api.skyvern.com",
)
task: TaskV2Request
class GetTaskSchema(BaseModel):
api_key: str = Field(
description="The API key of the Skyvern API. You can get the API key from the Skyvern dashboard."
)
endpoint: str = Field(
description="The endpoint of the Skyvern API. Don't add any path to the endpoint. Default is https://api.skyvern.com",
default="https://api.skyvern.com",
)
task_id: str

View File

@@ -1,43 +0,0 @@
import httpx
from skyvern.agent.parameter import TaskV1Request, TaskV2Request
from skyvern.forge.sdk.schemas.observers import ObserverTask
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, TaskResponse
class RemoteAgent:
def __init__(self, api_key: str, endpoint: str = "https://api.skyvern.com"):
self.endpoint = endpoint
self.api_key = api_key
self.client = httpx.AsyncClient(
headers={
"Content-Type": "application/json",
"x-api-key": self.api_key,
}
)
async def run_task_v1(self, task: TaskV1Request) -> CreateTaskResponse:
url = f"{self.endpoint}/api/v1/tasks"
payload = task.model_dump_json()
headers = {"x_max_steps_override": str(task.max_steps)}
response = await self.client.post(url, headers=headers, data=payload)
return CreateTaskResponse.model_validate(response.json())
async def run_task_v2(self, task: TaskV2Request) -> ObserverTask:
url = f"{self.endpoint}/api/v2/tasks"
payload = task.model_dump_json()
headers = {"x_max_iterations_override": str(task.max_iterations)}
response = await self.client.post(url, headers=headers, data=payload)
return ObserverTask.model_validate(response.json())
async def get_task_v1(self, task_id: str) -> TaskResponse:
"""Get a task by id."""
url = f"{self.endpoint}/api/v1/tasks/{task_id}"
response = await self.client.get(url)
return TaskResponse.model_validate(response.json())
async def get_task_v2(self, task_id: str) -> ObserverTask:
"""Get a task by id."""
url = f"{self.endpoint}/api/v2/tasks/{task_id}"
response = await self.client.get(url)
return ObserverTask.model_validate(response.json())