From 69c458bd7c1cc662fdf301e4fb091c75a50ec67e Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Tue, 2 Apr 2024 14:43:29 -0700 Subject: [PATCH] Implement get_latest_screenshots, add action_screenshots to TaskResponse (#148) --- skyvern/config.py | 1 + skyvern/forge/agent.py | 27 +++++++++++++++++++++++- skyvern/forge/sdk/db/client.py | 34 +++++++++++++++++++++++++++--- skyvern/forge/sdk/schemas/tasks.py | 8 ++++++- 4 files changed, 65 insertions(+), 5 deletions(-) diff --git a/skyvern/config.py b/skyvern/config.py index 4a8615ec..f9127730 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -23,6 +23,7 @@ class Settings(BaseSettings): DEBUG_MODE: bool = False DATABASE_STRING: str = "postgresql+psycopg://skyvern@localhost/skyvern" PROMPT_ACTION_HISTORY_WINDOW: int = 5 + TASK_RESPONSE_ACTION_SCREENSHOT_COUNT: int = 3 ENV: str = "local" EXECUTE_ALL_STEPS: bool = True diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index c125b743..8dcb5cdc 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -787,6 +787,27 @@ class ForgeAgent(Agent): if recording_artifact: recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact) + # get the artifact of the last screenshot and get the screenshot_url + latest_action_screenshot_artifacts = await app.DATABASE.get_latest_n_artifacts( + task_id=task.task_id, + organization_id=task.organization_id, + artifact_types=[ArtifactType.SCREENSHOT_ACTION], + n=SettingsManager.get_settings().TASK_RESPONSE_ACTION_SCREENSHOT_COUNT, + ) + latest_action_screenshot_urls = [] + if latest_action_screenshot_artifacts: + for artifact in latest_action_screenshot_artifacts: + screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(artifact) + if screenshot_url: + latest_action_screenshot_urls.append(screenshot_url) + else: + LOG.error( + "Failed to get share link for action screenshot", + artifact_id=artifact.artifact_id, + ) + else: + LOG.error("Failed to get latest action screenshots") + # get the latest task from the db to get the latest status, extracted_information, and failure_reason task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id) if not task_from_db: @@ -798,7 +819,11 @@ class ForgeAgent(Agent): LOG.info("Task has no webhook callback url. Not sending task response") return - task_response = task.to_task_response(screenshot_url=screenshot_url, recording_url=recording_url) + task_response = task.to_task_response( + action_screenshot_urls=latest_action_screenshot_urls, + screenshot_url=screenshot_url, + recording_url=recording_url, + ) # send task_response to the webhook callback url # TODO: use async requests (httpx) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 645004f3..83efd152 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -612,6 +612,32 @@ class AgentDB: artifact_types: list[ArtifactType] | None = None, organization_id: str | None = None, ) -> Artifact | None: + try: + artifacts = await self.get_latest_n_artifacts( + task_id=task_id, + step_id=step_id, + artifact_types=artifact_types, + organization_id=organization_id, + n=1, + ) + if artifacts: + return artifacts[0] + return None + except SQLAlchemyError: + LOG.exception("SQLAlchemyError", exc_info=True) + raise + except Exception: + LOG.exception("UnexpectedError", exc_info=True) + raise + + async def get_latest_n_artifacts( + self, + task_id: str, + step_id: str | None = None, + artifact_types: list[ArtifactType] | None = None, + organization_id: str | None = None, + n: int = 1, + ) -> list[Artifact] | None: try: async with self.Session() as session: artifact_query = select(ArtifactModel).filter_by(task_id=task_id) @@ -622,9 +648,11 @@ class AgentDB: if artifact_types: artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types)) - artifact = (await session.scalars(artifact_query.order_by(ArtifactModel.created_at.desc()))).first() - if artifact: - return convert_to_artifact(artifact, self.debug_enabled) + artifacts = (await session.scalars(artifact_query.order_by(ArtifactModel.created_at.desc()))).fetchmany( + n + ) + if artifacts: + return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts] return None except SQLAlchemyError: LOG.exception("SQLAlchemyError", exc_info=True) diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index 962421d8..f468f1c1 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -166,7 +166,11 @@ class Task(TaskRequest): raise ValueError(f"cant_override_failure_reason({self.task_id})") def to_task_response( - self, screenshot_url: str | None = None, recording_url: str | None = None, failure_reason: str | None = None + self, + action_screenshot_urls: list[str] | None = None, + screenshot_url: str | None = None, + recording_url: str | None = None, + failure_reason: str | None = None, ) -> TaskResponse: return TaskResponse( request=self, @@ -176,6 +180,7 @@ class Task(TaskRequest): modified_at=self.modified_at, extracted_information=self.extracted_information, failure_reason=failure_reason or self.failure_reason, + action_screenshot_urls=action_screenshot_urls, screenshot_url=screenshot_url, recording_url=recording_url, errors=self.errors, @@ -189,6 +194,7 @@ class TaskResponse(BaseModel): created_at: datetime modified_at: datetime extracted_information: list | dict[str, Any] | str | None = None + action_screenshot_urls: list[str] | None = None screenshot_url: str | None = None recording_url: str | None = None failure_reason: str | None = None