add fern sdk (#1786)
This commit is contained in:
3
skyvern/agent/__init__.py
Normal file
3
skyvern/agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from skyvern.agent.local import Agent
|
||||
|
||||
__all__ = ["Agent"]
|
||||
@@ -1,13 +1,14 @@
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from skyvern.agent.parameter import TaskV1Request, TaskV2Request
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import security, skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskStatus
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverTask, ObserverTaskRequest, ObserverTaskStatus
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.services import observer_service
|
||||
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
@@ -40,32 +41,20 @@ class Agent:
|
||||
)
|
||||
return organization
|
||||
|
||||
async def run_task_v1(self, task_request: TaskV1Request) -> TaskResponse:
|
||||
organization = await self._get_organization()
|
||||
|
||||
async def _run_task(self, organization: Organization, task: Task) -> None:
|
||||
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization_id=organization.organization_id,
|
||||
token_type=OrganizationAuthTokenType.api,
|
||||
)
|
||||
|
||||
created_task = await app.agent.create_task(task_request, organization.organization_id)
|
||||
|
||||
skyvern_context.set(
|
||||
SkyvernContext(
|
||||
organization_id=organization.organization_id,
|
||||
task_id=created_task.task_id,
|
||||
max_steps_override=task_request.max_steps,
|
||||
)
|
||||
)
|
||||
|
||||
step = await app.DATABASE.create_step(
|
||||
created_task.task_id,
|
||||
task.task_id,
|
||||
order=0,
|
||||
retry_index=0,
|
||||
organization_id=organization.organization_id,
|
||||
)
|
||||
updated_task = await app.DATABASE.update_task(
|
||||
created_task.task_id,
|
||||
task.task_id,
|
||||
status=TaskStatus.running,
|
||||
organization_id=organization.organization_id,
|
||||
)
|
||||
@@ -77,18 +66,65 @@ class Agent:
|
||||
api_key=org_auth_token.token if org_auth_token else None,
|
||||
)
|
||||
|
||||
refreshed_task = await app.DATABASE.get_task(created_task.task_id, organization.organization_id)
|
||||
if refreshed_task:
|
||||
updated_task = refreshed_task
|
||||
async def _run_observer_task(self, organization: Organization, observer_task: ObserverTask) -> None:
|
||||
# mark observer cruise as queued
|
||||
await app.DATABASE.update_observer_cruise(
|
||||
observer_cruise_id=observer_task.observer_cruise_id,
|
||||
status=ObserverTaskStatus.queued,
|
||||
organization_id=organization.organization_id,
|
||||
)
|
||||
|
||||
assert observer_task.workflow_run_id
|
||||
await app.DATABASE.update_workflow_run(
|
||||
workflow_run_id=observer_task.workflow_run_id,
|
||||
status=WorkflowRunStatus.queued,
|
||||
)
|
||||
|
||||
await observer_service.run_observer_task(
|
||||
organization=organization,
|
||||
observer_cruise_id=observer_task.observer_cruise_id,
|
||||
)
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
task_request: TaskRequest,
|
||||
) -> CreateTaskResponse:
|
||||
organization = await self._get_organization()
|
||||
|
||||
created_task = await app.agent.create_task(task_request, organization.organization_id)
|
||||
skyvern_context.set(
|
||||
SkyvernContext(
|
||||
organization_id=organization.organization_id,
|
||||
task_id=created_task.task_id,
|
||||
max_steps_override=created_task.max_steps_per_run,
|
||||
)
|
||||
)
|
||||
|
||||
asyncio.create_task(self._run_task(organization, created_task))
|
||||
return CreateTaskResponse(task_id=created_task.task_id)
|
||||
|
||||
async def get_task(
|
||||
self,
|
||||
task_id: str,
|
||||
) -> TaskResponse | None:
|
||||
organization = await self._get_organization()
|
||||
task = await app.DATABASE.get_task(task_id, organization.organization_id)
|
||||
|
||||
if task is None:
|
||||
return None
|
||||
|
||||
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization.organization_id)
|
||||
if not latest_step:
|
||||
return await app.agent.build_task_response(task=task)
|
||||
|
||||
failure_reason: str | None = None
|
||||
if updated_task.status == TaskStatus.failed and (step.output or updated_task.failure_reason):
|
||||
if task.status == TaskStatus.failed and (task.failure_reason):
|
||||
failure_reason = ""
|
||||
if updated_task.failure_reason:
|
||||
failure_reason += updated_task.failure_reason or ""
|
||||
if step.output is not None and step.output.actions_and_results is not None:
|
||||
if task.failure_reason:
|
||||
failure_reason += task.failure_reason or ""
|
||||
if latest_step.output is not None and latest_step.output.actions_and_results is not None:
|
||||
action_results_string: list[str] = []
|
||||
for action, results in step.output.actions_and_results:
|
||||
for action, results in latest_step.output.actions_and_results:
|
||||
if len(results) == 0:
|
||||
continue
|
||||
if results[-1].success:
|
||||
@@ -97,11 +133,27 @@ class Agent:
|
||||
|
||||
if len(action_results_string) > 0:
|
||||
failure_reason += "(Exceptions: " + str(action_results_string) + ")"
|
||||
|
||||
return await app.agent.build_task_response(
|
||||
task=updated_task, last_step=step, failure_reason=failure_reason, need_browser_log=True
|
||||
task=task, last_step=latest_step, failure_reason=failure_reason, need_browser_log=True
|
||||
)
|
||||
|
||||
async def run_task_v2(self, task_request: TaskV2Request) -> ObserverTask:
|
||||
async def run_task(
|
||||
self,
|
||||
task_request: TaskRequest,
|
||||
timeout_seconds: int = 600,
|
||||
) -> TaskResponse:
|
||||
created_task = await self.create_task(task_request)
|
||||
|
||||
while True:
|
||||
async with asyncio.timeout(timeout_seconds):
|
||||
task_response = await self.get_task(created_task.task_id)
|
||||
assert task_response is not None
|
||||
if task_response.status.is_final():
|
||||
return task_response
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def observer_task_v_2(self, task_request: ObserverTaskRequest) -> ObserverTask:
|
||||
organization = await self._get_organization()
|
||||
|
||||
observer_task = await observer_service.initialize_observer_task(
|
||||
@@ -118,27 +170,22 @@ class Agent:
|
||||
if not observer_task.workflow_run_id:
|
||||
raise Exception("Observer cruise missing workflow run id")
|
||||
|
||||
# mark observer cruise as queued
|
||||
await app.DATABASE.update_observer_cruise(
|
||||
observer_cruise_id=observer_task.observer_cruise_id,
|
||||
status=ObserverTaskStatus.queued,
|
||||
organization_id=organization.organization_id,
|
||||
)
|
||||
await app.DATABASE.update_workflow_run(
|
||||
workflow_run_id=observer_task.workflow_run_id,
|
||||
status=WorkflowRunStatus.queued,
|
||||
)
|
||||
|
||||
await observer_service.run_observer_task(
|
||||
organization=organization,
|
||||
observer_cruise_id=observer_task.observer_cruise_id,
|
||||
max_iterations_override=task_request.max_iterations,
|
||||
)
|
||||
|
||||
refreshed_observer_task = await app.DATABASE.get_observer_cruise(
|
||||
observer_cruise_id=observer_task.observer_cruise_id, organization_id=organization.organization_id
|
||||
)
|
||||
if refreshed_observer_task:
|
||||
return refreshed_observer_task
|
||||
|
||||
asyncio.create_task(self._run_observer_task(organization, observer_task))
|
||||
return observer_task
|
||||
|
||||
async def get_observer_task_v_2(self, task_id: str) -> ObserverTask | None:
|
||||
organization = await self._get_organization()
|
||||
return await app.DATABASE.get_observer_cruise(task_id, organization.organization_id)
|
||||
|
||||
async def run_observer_task_v_2(
|
||||
self, task_request: ObserverTaskRequest, timeout_seconds: int = 600
|
||||
) -> ObserverTask:
|
||||
observer_task = await self.observer_task_v_2(task_request)
|
||||
|
||||
while True:
|
||||
async with asyncio.timeout(timeout_seconds):
|
||||
refreshed_observer_task = await self.get_observer_task_v_2(observer_task.observer_cruise_id)
|
||||
assert refreshed_observer_task is not None
|
||||
if refreshed_observer_task.status.is_final():
|
||||
return refreshed_observer_task
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverTaskRequest
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskRequest
|
||||
|
||||
|
||||
class TaskV1Request(TaskRequest):
|
||||
max_steps: int = 10
|
||||
|
||||
|
||||
class TaskV2Request(ObserverTaskRequest):
|
||||
max_iterations: int = 10
|
||||
|
||||
|
||||
class RunTaskV1Schema(BaseModel):
|
||||
api_key: str = Field(
|
||||
description="The API key of the Skyvern API. You can get the API key from the Skyvern dashboard.",
|
||||
)
|
||||
endpoint: str = Field(
|
||||
description="The endpoint of the Skyvern API. Don't add any path to the endpoint. Default is https://api.skyvern.com",
|
||||
default="https://api.skyvern.com",
|
||||
)
|
||||
task: TaskV1Request
|
||||
|
||||
|
||||
class RunTaskV2Schema(BaseModel):
|
||||
api_key: str = Field(
|
||||
description="The API key of the Skyvern API. You can get the API key from the Skyvern dashboard."
|
||||
)
|
||||
endpoint: str = Field(
|
||||
description="The endpoint of the Skyvern API. Don't add any path to the endpoint. Default is https://api.skyvern.com",
|
||||
default="https://api.skyvern.com",
|
||||
)
|
||||
task: TaskV2Request
|
||||
|
||||
|
||||
class GetTaskSchema(BaseModel):
|
||||
api_key: str = Field(
|
||||
description="The API key of the Skyvern API. You can get the API key from the Skyvern dashboard."
|
||||
)
|
||||
endpoint: str = Field(
|
||||
description="The endpoint of the Skyvern API. Don't add any path to the endpoint. Default is https://api.skyvern.com",
|
||||
default="https://api.skyvern.com",
|
||||
)
|
||||
task_id: str
|
||||
@@ -1,43 +0,0 @@
|
||||
import httpx
|
||||
|
||||
from skyvern.agent.parameter import TaskV1Request, TaskV2Request
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverTask
|
||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, TaskResponse
|
||||
|
||||
|
||||
class RemoteAgent:
|
||||
def __init__(self, api_key: str, endpoint: str = "https://api.skyvern.com"):
|
||||
self.endpoint = endpoint
|
||||
self.api_key = api_key
|
||||
self.client = httpx.AsyncClient(
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": self.api_key,
|
||||
}
|
||||
)
|
||||
|
||||
async def run_task_v1(self, task: TaskV1Request) -> CreateTaskResponse:
|
||||
url = f"{self.endpoint}/api/v1/tasks"
|
||||
payload = task.model_dump_json()
|
||||
headers = {"x_max_steps_override": str(task.max_steps)}
|
||||
response = await self.client.post(url, headers=headers, data=payload)
|
||||
return CreateTaskResponse.model_validate(response.json())
|
||||
|
||||
async def run_task_v2(self, task: TaskV2Request) -> ObserverTask:
|
||||
url = f"{self.endpoint}/api/v2/tasks"
|
||||
payload = task.model_dump_json()
|
||||
headers = {"x_max_iterations_override": str(task.max_iterations)}
|
||||
response = await self.client.post(url, headers=headers, data=payload)
|
||||
return ObserverTask.model_validate(response.json())
|
||||
|
||||
async def get_task_v1(self, task_id: str) -> TaskResponse:
|
||||
"""Get a task by id."""
|
||||
url = f"{self.endpoint}/api/v1/tasks/{task_id}"
|
||||
response = await self.client.get(url)
|
||||
return TaskResponse.model_validate(response.json())
|
||||
|
||||
async def get_task_v2(self, task_id: str) -> ObserverTask:
|
||||
"""Get a task by id."""
|
||||
url = f"{self.endpoint}/api/v2/tasks/{task_id}"
|
||||
response = await self.client.get(url)
|
||||
return ObserverTask.model_validate(response.json())
|
||||
Reference in New Issue
Block a user