diff --git a/skyvern/forge/agent_functions.py b/skyvern/forge/agent_functions.py index d3776606..152bd8d4 100644 --- a/skyvern/forge/agent_functions.py +++ b/skyvern/forge/agent_functions.py @@ -1,22 +1,106 @@ -from typing import Awaitable, Callable +import copy +import hashlib +from typing import Awaitable, Callable, Dict, List +import structlog from playwright.async_api import Page +from skyvern.constants import SKYVERN_ID_ATTR from skyvern.exceptions import StepUnableToExecuteError from skyvern.forge import app from skyvern.forge.async_operations import AsyncOperation +from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.models import Organization, Step, StepStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from skyvern.webeye.browser_factory import BrowserState +from skyvern.webeye.scraper.scraper import ELEMENT_NODE_ATTRIBUTES, json_to_html CleanupElementTreeFunc = Callable[[str, list[dict]], Awaitable[list[dict]]] +LOG = structlog.get_logger() + def _remove_rect(element: dict) -> None: if "rect" in element: del element["rect"] +def _get_svg_cache_key(hash: str) -> str: + return f"skyvern:svg:{hash}" + + +def _remove_skyvern_attributes(element: Dict) -> Dict: + """ + To get the original HTML element without skyvern attributes + """ + element_copied = copy.deepcopy(element) + for attr in ELEMENT_NODE_ATTRIBUTES: + if element_copied.get(attr): + del element_copied[attr] + + if element_copied.get("attributes") and SKYVERN_ID_ATTR in element_copied.get("attributes", {}): + del element_copied["attributes"][SKYVERN_ID_ATTR] + + children: List[Dict] | None = element_copied.get("children", None) + if children is None: + return element_copied + + trimmed_children = [] + for child in children: + trimmed_children.append(_remove_skyvern_attributes(child)) + + element_copied["children"] = trimmed_children + return element_copied + + +async def _convert_svg_to_string(task: Task, step: Step, organization: Organization | None, element: Dict) -> None: + if element.get("tagName") != "svg": + return + + element_id = element.get("id", "") + svg_element = _remove_skyvern_attributes(element) + svg_html = json_to_html(svg_element) + hash_object = hashlib.sha256() + hash_object.update(svg_html.encode("utf-8")) + svg_hash = hash_object.hexdigest() + svg_key = _get_svg_cache_key(svg_hash) + + svg_shape: str | None = None + try: + svg_shape = await app.CACHE.get(svg_key) + except Exception: + LOG.warning( + "Failed to loaded SVG cache", + exc_info=True, + key=svg_key, + ) + + if svg_shape: + LOG.debug("SVG loaded from cache", element_id=element_id, shape=svg_shape) + else: + LOG.debug("call LLM to convert SVG to string shape", element_id=element_id) + svg_convert_prompt = prompt_engine.load_prompt("svg-convert", svg_element=svg_html) + try: + json_response = await app.SECONDARY_LLM_API_HANDLER(prompt=svg_convert_prompt, step=step) + svg_shape = json_response.get("shape", "") + if not svg_shape: + raise Exception("Empty SVG shape replied by secondary llm") + LOG.info("SVG converted by LLM", element_id=element_id, shape=svg_shape) + await app.CACHE.set(svg_key, svg_shape) + except Exception: + LOG.exception( + "Failed to convert SVG to string shape by secondary llm", + element=element, + svg_html=svg_html, + ) + return + + element["attributes"] = dict() + element["attributes"]["alt"] = svg_shape + del element["children"] + return + + class AgentFunction: async def validate_step_execution( self, @@ -87,6 +171,7 @@ class AgentFunction: while queue: queue_ele = queue.pop(0) _remove_rect(queue_ele) + await _convert_svg_to_string(task, step, organization, queue_ele) # TODO: we can come back to test removing the unique_id # from element attributes to make sure this won't increase hallucination # _remove_unique_id(queue_ele) diff --git a/skyvern/forge/app.py b/skyvern/forge/app.py index b1480156..f6d17f97 100644 --- a/skyvern/forge/app.py +++ b/skyvern/forge/app.py @@ -8,6 +8,7 @@ from skyvern.forge.agent_functions import AgentFunction from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory from skyvern.forge.sdk.artifact.manager import ArtifactManager from skyvern.forge.sdk.artifact.storage.factory import StorageFactory +from skyvern.forge.sdk.cache.factory import CacheFactory from skyvern.forge.sdk.db.client import AgentDB from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider from skyvern.forge.sdk.models import Organization @@ -22,6 +23,7 @@ DATABASE = AgentDB( debug_enabled=SettingsManager.get_settings().DEBUG_MODE, ) STORAGE = StorageFactory.get_storage() +CACHE = CacheFactory.get_cache() ARTIFACT_MANAGER = ArtifactManager() BROWSER_MANAGER = BrowserManager() EXPERIMENTATION_PROVIDER: BaseExperimentationProvider = NoOpExperimentationProvider() diff --git a/skyvern/forge/prompts/skyvern/svg-convert.j2 b/skyvern/forge/prompts/skyvern/svg-convert.j2 new file mode 100644 index 00000000..5d32271c --- /dev/null +++ b/skyvern/forge/prompts/skyvern/svg-convert.j2 @@ -0,0 +1,12 @@ +You are given a svg element. You need to figure out what its shape means. +SVG Element: +``` +{{svg_element}} +``` + +MAKE SURE YOU OUTPUT VALID JSON. No text before or after JSON, no trailing commas, no comments (//), no unnecessary quotes, etc. +Reply in JSON format with the following keys: +{ + "confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence + "shape": string, // A short description of the shape of SVG and its meaning +} \ No newline at end of file diff --git a/skyvern/forge/sdk/cache/base.py b/skyvern/forge/sdk/cache/base.py new file mode 100644 index 00000000..396146c4 --- /dev/null +++ b/skyvern/forge/sdk/cache/base.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from datetime import timedelta +from typing import Any + +CACHE_EXPIRE_TIME = timedelta(weeks=1) +MAX_CACHE_ITEM = 1000 + + +class BaseCache(ABC): + @abstractmethod + async def set(self, key: str, value: Any) -> None: + pass + + @abstractmethod + async def get(self, key: str) -> Any: + pass diff --git a/skyvern/forge/sdk/cache/factory.py b/skyvern/forge/sdk/cache/factory.py new file mode 100644 index 00000000..88a513e2 --- /dev/null +++ b/skyvern/forge/sdk/cache/factory.py @@ -0,0 +1,14 @@ +from skyvern.forge.sdk.cache.base import BaseCache +from skyvern.forge.sdk.cache.local import LocalCache + + +class CacheFactory: + __cache: BaseCache = LocalCache() + + @staticmethod + def set_cache(cache: BaseCache) -> None: + CacheFactory.__cache = cache + + @staticmethod + def get_cache() -> BaseCache: + return CacheFactory.__cache diff --git a/skyvern/forge/sdk/cache/local.py b/skyvern/forge/sdk/cache/local.py new file mode 100644 index 00000000..f95e5039 --- /dev/null +++ b/skyvern/forge/sdk/cache/local.py @@ -0,0 +1,20 @@ +from typing import Any + +from cachetools import TTLCache + +from skyvern.forge.sdk.cache.base import CACHE_EXPIRE_TIME, MAX_CACHE_ITEM, BaseCache + + +class LocalCache(BaseCache): + def __init__(self) -> None: + self.cache: TTLCache = TTLCache(maxsize=MAX_CACHE_ITEM, ttl=CACHE_EXPIRE_TIME.total_seconds()) + + async def get(self, key: str) -> Any: + if key not in self.cache: + return None + value = self.cache[key] + await self.set(key, value) + return value + + async def set(self, key: str, value: Any) -> None: + self.cache[key] = value