svg conversion (#717)

This commit is contained in:
LawyZheng
2024-08-23 11:17:01 +08:00
committed by GitHub
parent e5b0d734b8
commit 76ee91ecdd
6 changed files with 150 additions and 1 deletions

View File

@@ -1,22 +1,106 @@
from typing import Awaitable, Callable import copy
import hashlib
from typing import Awaitable, Callable, Dict, List
import structlog
from playwright.async_api import Page from playwright.async_api import Page
from skyvern.constants import SKYVERN_ID_ATTR
from skyvern.exceptions import StepUnableToExecuteError from skyvern.exceptions import StepUnableToExecuteError
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.async_operations import AsyncOperation from skyvern.forge.async_operations import AsyncOperation
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.models import Organization, Step, StepStatus from skyvern.forge.sdk.models import Organization, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ELEMENT_NODE_ATTRIBUTES, json_to_html
# Signature of an async element-tree cleanup callback: it is called with a
# string and a list of element dicts, and awaited for a list of element dicts.
# NOTE(review): what the str argument carries is not visible here — confirm against callers.
CleanupElementTreeFunc = Callable[[str, list[dict]], Awaitable[list[dict]]] CleanupElementTreeFunc = Callable[[str, list[dict]], Awaitable[list[dict]]]
# Module-level structlog logger used by the SVG conversion helpers in this file.
LOG = structlog.get_logger()
def _remove_rect(element: dict) -> None: def _remove_rect(element: dict) -> None:
if "rect" in element: if "rect" in element:
del element["rect"] del element["rect"]
def _get_svg_cache_key(hash: str) -> str:
return f"skyvern:svg:{hash}"
def _remove_skyvern_attributes(element: Dict) -> Dict:
    """
    Return a deep copy of ``element`` with the skyvern-specific data removed.

    Every key listed in ``ELEMENT_NODE_ATTRIBUTES`` is dropped from the element
    and from each descendant reachable through ``children``, and the
    ``SKYVERN_ID_ATTR`` entry is dropped from each ``attributes`` mapping.
    The element passed in is never modified.
    """
    # Copy the whole tree exactly once; the stripping below then mutates only
    # the copy instead of re-copying every subtree at each recursion level.
    element_copied = copy.deepcopy(element)
    _strip_skyvern_attributes_in_place(element_copied)
    return element_copied


def _strip_skyvern_attributes_in_place(element: Dict) -> None:
    """Remove skyvern-specific keys from ``element`` and its descendants, mutating them."""
    for attr in ELEMENT_NODE_ATTRIBUTES:
        # pop() removes the key even when its value is falsy (False, "", 0),
        # which a truthiness check before ``del`` would leave behind.
        element.pop(attr, None)

    attributes = element.get("attributes")
    if attributes and SKYVERN_ID_ATTR in attributes:
        del attributes[SKYVERN_ID_ATTR]

    children: List[Dict] | None = element.get("children", None)
    if children is None:
        return
    for child in children:
        _strip_skyvern_attributes_in_place(child)
async def _convert_svg_to_string(task: Task, step: Step, organization: Organization | None, element: Dict) -> None:
    """
    Replace an ``svg`` element's markup with a short text description, in place.

    Elements whose ``tagName`` is not ``"svg"`` are left untouched. For an SVG,
    the skyvern attributes are stripped, the element is rendered to HTML, and
    the SHA-256 hex digest of that HTML forms the cache key. The description is
    read from ``app.CACHE`` when present; otherwise the secondary LLM is asked
    through the ``svg-convert`` prompt and the answer is written to the cache.

    On success the element's ``attributes`` become ``{"alt": <shape>}`` and its
    ``children`` entry is removed. If the LLM call fails or replies with an
    empty shape, the failure is logged and the element is left unchanged.
    Cache read/write failures are logged and never abort the conversion.

    ``task`` and ``organization`` are accepted but not used by this function.
    """
    if element.get("tagName") != "svg":
        return

    element_id = element.get("id", "")
    svg_element = _remove_skyvern_attributes(element)
    svg_html = json_to_html(svg_element)
    svg_hash = hashlib.sha256(svg_html.encode("utf-8")).hexdigest()
    svg_key = _get_svg_cache_key(svg_hash)

    svg_shape: str | None = None
    try:
        svg_shape = await app.CACHE.get(svg_key)
    except Exception:
        # Best effort: a broken cache only costs an extra LLM call.
        LOG.warning(
            "Failed to loaded SVG cache",
            exc_info=True,
            key=svg_key,
        )

    if svg_shape:
        LOG.debug("SVG loaded from cache", element_id=element_id, shape=svg_shape)
    else:
        LOG.debug("call LLM to convert SVG to string shape", element_id=element_id)
        svg_convert_prompt = prompt_engine.load_prompt("svg-convert", svg_element=svg_html)

        try:
            json_response = await app.SECONDARY_LLM_API_HANDLER(prompt=svg_convert_prompt, step=step)
            svg_shape = json_response.get("shape", "")
            if not svg_shape:
                raise Exception("Empty SVG shape replied by secondary llm")
            LOG.info("SVG converted by LLM", element_id=element_id, shape=svg_shape)
        except Exception:
            LOG.exception(
                "Failed to convert SVG to string shape by secondary llm",
                element=element,
                svg_html=svg_html,
            )
            return

        # Store the result separately from the LLM call so that a cache-write
        # failure does not throw away a conversion that already succeeded.
        try:
            await app.CACHE.set(svg_key, svg_shape)
        except Exception:
            LOG.warning(
                "Failed to save SVG cache",
                exc_info=True,
                key=svg_key,
            )

    element["attributes"] = {"alt": svg_shape}
    # An svg element without a "children" key must not raise here.
    element.pop("children", None)
class AgentFunction: class AgentFunction:
async def validate_step_execution( async def validate_step_execution(
self, self,
@@ -87,6 +171,7 @@ class AgentFunction:
while queue: while queue:
queue_ele = queue.pop(0) queue_ele = queue.pop(0)
_remove_rect(queue_ele) _remove_rect(queue_ele)
await _convert_svg_to_string(task, step, organization, queue_ele)
# TODO: we can come back to test removing the unique_id # TODO: we can come back to test removing the unique_id
# from element attributes to make sure this won't increase hallucination # from element attributes to make sure this won't increase hallucination
# _remove_unique_id(queue_ele) # _remove_unique_id(queue_ele)

View File

@@ -8,6 +8,7 @@ from skyvern.forge.agent_functions import AgentFunction
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.artifact.manager import ArtifactManager from skyvern.forge.sdk.artifact.manager import ArtifactManager
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
from skyvern.forge.sdk.cache.factory import CacheFactory
from skyvern.forge.sdk.db.client import AgentDB from skyvern.forge.sdk.db.client import AgentDB
from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider
from skyvern.forge.sdk.models import Organization from skyvern.forge.sdk.models import Organization
@@ -22,6 +23,7 @@ DATABASE = AgentDB(
debug_enabled=SettingsManager.get_settings().DEBUG_MODE, debug_enabled=SettingsManager.get_settings().DEBUG_MODE,
) )
STORAGE = StorageFactory.get_storage() STORAGE = StorageFactory.get_storage()
CACHE = CacheFactory.get_cache()
ARTIFACT_MANAGER = ArtifactManager() ARTIFACT_MANAGER = ArtifactManager()
BROWSER_MANAGER = BrowserManager() BROWSER_MANAGER = BrowserManager()
EXPERIMENTATION_PROVIDER: BaseExperimentationProvider = NoOpExperimentationProvider() EXPERIMENTATION_PROVIDER: BaseExperimentationProvider = NoOpExperimentationProvider()

View File

@@ -0,0 +1,12 @@
You are given an SVG element. You need to figure out what its shape means.
SVG Element:
```
{{svg_element}}
```
MAKE SURE YOU OUTPUT VALID JSON. No text before or after JSON, no trailing commas, no comments (//), no unnecessary quotes, etc.
Reply in JSON format with the following keys:
{
"confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
"shape": string, // A short description of the shape of SVG and its meaning
}

16
skyvern/forge/sdk/cache/base.py vendored Normal file
View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Any
# Expiry window for cache entries; LocalCache passes its total_seconds() as the TTLCache ttl.
CACHE_EXPIRE_TIME = timedelta(weeks=1)
# Upper bound on the number of cached entries; LocalCache passes it as the TTLCache maxsize.
MAX_CACHE_ITEM = 1000
class BaseCache(ABC):
    """Abstract interface for an asynchronous key/value cache."""

    @abstractmethod
    async def get(self, key: str) -> Any:
        """Look up ``key`` in the cache; concrete subclasses define the result."""

    @abstractmethod
    async def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``; concrete subclasses define how."""

14
skyvern/forge/sdk/cache/factory.py vendored Normal file
View File

@@ -0,0 +1,14 @@
from skyvern.forge.sdk.cache.base import BaseCache
from skyvern.forge.sdk.cache.local import LocalCache
class CacheFactory:
    """Holds the shared cache instance, which is a ``LocalCache`` until replaced.

    Code that stored an earlier ``get_cache()`` result keeps that instance;
    ``set_cache`` only changes what later ``get_cache()`` calls return.
    """

    # Name-mangled class attribute, created once when the class body runs.
    __cache: BaseCache = LocalCache()

    @staticmethod
    def get_cache() -> BaseCache:
        """Return the cache instance currently registered on the factory."""
        return CacheFactory.__cache

    @staticmethod
    def set_cache(cache: BaseCache) -> None:
        """Register ``cache`` as the instance returned by ``get_cache``."""
        CacheFactory.__cache = cache

20
skyvern/forge/sdk/cache/local.py vendored Normal file
View File

@@ -0,0 +1,20 @@
from typing import Any
from cachetools import TTLCache
from skyvern.forge.sdk.cache.base import CACHE_EXPIRE_TIME, MAX_CACHE_ITEM, BaseCache
class LocalCache(BaseCache):
    """In-memory ``BaseCache`` backed by a ``cachetools.TTLCache``.

    The cache is created with ``maxsize=MAX_CACHE_ITEM`` and a ttl of
    ``CACHE_EXPIRE_TIME`` expressed in seconds.
    """

    def __init__(self) -> None:
        self.cache: TTLCache = TTLCache(maxsize=MAX_CACHE_ITEM, ttl=CACHE_EXPIRE_TIME.total_seconds())

    async def get(self, key: str) -> Any:
        """Return the value stored under ``key``, or ``None`` when it is absent or expired.

        A hit stores the value again under the same key, which restarts its
        time-to-live.
        """
        try:
            # Single lookup: with a separate ``in`` check the entry could
            # expire between the check and the read, raising KeyError.
            value = self.cache[key]
        except KeyError:
            return None
        await self.set(key, value)
        return value

    async def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, replacing any existing entry."""
        self.cache[key] = value