svg conversion (#717)
This commit is contained in:
@@ -1,22 +1,106 @@
|
|||||||
from typing import Awaitable, Callable
|
import copy
|
||||||
|
import hashlib
|
||||||
|
from typing import Awaitable, Callable, Dict, List
|
||||||
|
|
||||||
|
import structlog
|
||||||
from playwright.async_api import Page
|
from playwright.async_api import Page
|
||||||
|
|
||||||
|
from skyvern.constants import SKYVERN_ID_ATTR
|
||||||
from skyvern.exceptions import StepUnableToExecuteError
|
from skyvern.exceptions import StepUnableToExecuteError
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.async_operations import AsyncOperation
|
from skyvern.forge.async_operations import AsyncOperation
|
||||||
|
from skyvern.forge.prompts import prompt_engine
|
||||||
from skyvern.forge.sdk.models import Organization, Step, StepStatus
|
from skyvern.forge.sdk.models import Organization, Step, StepStatus
|
||||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||||
from skyvern.webeye.browser_factory import BrowserState
|
from skyvern.webeye.browser_factory import BrowserState
|
||||||
|
from skyvern.webeye.scraper.scraper import ELEMENT_NODE_ATTRIBUTES, json_to_html
|
||||||
|
|
||||||
CleanupElementTreeFunc = Callable[[str, list[dict]], Awaitable[list[dict]]]
|
CleanupElementTreeFunc = Callable[[str, list[dict]], Awaitable[list[dict]]]
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
def _remove_rect(element: dict) -> None:
|
def _remove_rect(element: dict) -> None:
|
||||||
if "rect" in element:
|
if "rect" in element:
|
||||||
del element["rect"]
|
del element["rect"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_svg_cache_key(hash: str) -> str:
|
||||||
|
return f"skyvern:svg:{hash}"
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_skyvern_attributes(element: Dict) -> Dict:
|
||||||
|
"""
|
||||||
|
To get the original HTML element without skyvern attributes
|
||||||
|
"""
|
||||||
|
element_copied = copy.deepcopy(element)
|
||||||
|
for attr in ELEMENT_NODE_ATTRIBUTES:
|
||||||
|
if element_copied.get(attr):
|
||||||
|
del element_copied[attr]
|
||||||
|
|
||||||
|
if element_copied.get("attributes") and SKYVERN_ID_ATTR in element_copied.get("attributes", {}):
|
||||||
|
del element_copied["attributes"][SKYVERN_ID_ATTR]
|
||||||
|
|
||||||
|
children: List[Dict] | None = element_copied.get("children", None)
|
||||||
|
if children is None:
|
||||||
|
return element_copied
|
||||||
|
|
||||||
|
trimmed_children = []
|
||||||
|
for child in children:
|
||||||
|
trimmed_children.append(_remove_skyvern_attributes(child))
|
||||||
|
|
||||||
|
element_copied["children"] = trimmed_children
|
||||||
|
return element_copied
|
||||||
|
|
||||||
|
|
||||||
|
async def _convert_svg_to_string(task: Task, step: Step, organization: Organization | None, element: Dict) -> None:
|
||||||
|
if element.get("tagName") != "svg":
|
||||||
|
return
|
||||||
|
|
||||||
|
element_id = element.get("id", "")
|
||||||
|
svg_element = _remove_skyvern_attributes(element)
|
||||||
|
svg_html = json_to_html(svg_element)
|
||||||
|
hash_object = hashlib.sha256()
|
||||||
|
hash_object.update(svg_html.encode("utf-8"))
|
||||||
|
svg_hash = hash_object.hexdigest()
|
||||||
|
svg_key = _get_svg_cache_key(svg_hash)
|
||||||
|
|
||||||
|
svg_shape: str | None = None
|
||||||
|
try:
|
||||||
|
svg_shape = await app.CACHE.get(svg_key)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning(
|
||||||
|
"Failed to loaded SVG cache",
|
||||||
|
exc_info=True,
|
||||||
|
key=svg_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
if svg_shape:
|
||||||
|
LOG.debug("SVG loaded from cache", element_id=element_id, shape=svg_shape)
|
||||||
|
else:
|
||||||
|
LOG.debug("call LLM to convert SVG to string shape", element_id=element_id)
|
||||||
|
svg_convert_prompt = prompt_engine.load_prompt("svg-convert", svg_element=svg_html)
|
||||||
|
try:
|
||||||
|
json_response = await app.SECONDARY_LLM_API_HANDLER(prompt=svg_convert_prompt, step=step)
|
||||||
|
svg_shape = json_response.get("shape", "")
|
||||||
|
if not svg_shape:
|
||||||
|
raise Exception("Empty SVG shape replied by secondary llm")
|
||||||
|
LOG.info("SVG converted by LLM", element_id=element_id, shape=svg_shape)
|
||||||
|
await app.CACHE.set(svg_key, svg_shape)
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(
|
||||||
|
"Failed to convert SVG to string shape by secondary llm",
|
||||||
|
element=element,
|
||||||
|
svg_html=svg_html,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
element["attributes"] = dict()
|
||||||
|
element["attributes"]["alt"] = svg_shape
|
||||||
|
del element["children"]
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
class AgentFunction:
|
class AgentFunction:
|
||||||
async def validate_step_execution(
|
async def validate_step_execution(
|
||||||
self,
|
self,
|
||||||
@@ -87,6 +171,7 @@ class AgentFunction:
|
|||||||
while queue:
|
while queue:
|
||||||
queue_ele = queue.pop(0)
|
queue_ele = queue.pop(0)
|
||||||
_remove_rect(queue_ele)
|
_remove_rect(queue_ele)
|
||||||
|
await _convert_svg_to_string(task, step, organization, queue_ele)
|
||||||
# TODO: we can come back to test removing the unique_id
|
# TODO: we can come back to test removing the unique_id
|
||||||
# from element attributes to make sure this won't increase hallucination
|
# from element attributes to make sure this won't increase hallucination
|
||||||
# _remove_unique_id(queue_ele)
|
# _remove_unique_id(queue_ele)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from skyvern.forge.agent_functions import AgentFunction
|
|||||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||||
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
||||||
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
||||||
|
from skyvern.forge.sdk.cache.factory import CacheFactory
|
||||||
from skyvern.forge.sdk.db.client import AgentDB
|
from skyvern.forge.sdk.db.client import AgentDB
|
||||||
from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider
|
from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider
|
||||||
from skyvern.forge.sdk.models import Organization
|
from skyvern.forge.sdk.models import Organization
|
||||||
@@ -22,6 +23,7 @@ DATABASE = AgentDB(
|
|||||||
debug_enabled=SettingsManager.get_settings().DEBUG_MODE,
|
debug_enabled=SettingsManager.get_settings().DEBUG_MODE,
|
||||||
)
|
)
|
||||||
STORAGE = StorageFactory.get_storage()
|
STORAGE = StorageFactory.get_storage()
|
||||||
|
CACHE = CacheFactory.get_cache()
|
||||||
ARTIFACT_MANAGER = ArtifactManager()
|
ARTIFACT_MANAGER = ArtifactManager()
|
||||||
BROWSER_MANAGER = BrowserManager()
|
BROWSER_MANAGER = BrowserManager()
|
||||||
EXPERIMENTATION_PROVIDER: BaseExperimentationProvider = NoOpExperimentationProvider()
|
EXPERIMENTATION_PROVIDER: BaseExperimentationProvider = NoOpExperimentationProvider()
|
||||||
|
|||||||
12
skyvern/forge/prompts/skyvern/svg-convert.j2
Normal file
12
skyvern/forge/prompts/skyvern/svg-convert.j2
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
You are given a svg element. You need to figure out what its shape means.
|
||||||
|
SVG Element:
|
||||||
|
```
|
||||||
|
{{svg_element}}
|
||||||
|
```
|
||||||
|
|
||||||
|
MAKE SURE YOU OUTPUT VALID JSON. No text before or after JSON, no trailing commas, no comments (//), no unnecessary quotes, etc.
|
||||||
|
Reply in JSON format with the following keys:
|
||||||
|
{
|
||||||
|
"confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
|
||||||
|
"shape": string, // A short description of the shape of SVG and its meaning
|
||||||
|
}
|
||||||
16
skyvern/forge/sdk/cache/base.py
vendored
Normal file
16
skyvern/forge/sdk/cache/base.py
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
CACHE_EXPIRE_TIME = timedelta(weeks=1)
|
||||||
|
MAX_CACHE_ITEM = 1000
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCache(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def set(self, key: str, value: Any) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get(self, key: str) -> Any:
|
||||||
|
pass
|
||||||
14
skyvern/forge/sdk/cache/factory.py
vendored
Normal file
14
skyvern/forge/sdk/cache/factory.py
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from skyvern.forge.sdk.cache.base import BaseCache
|
||||||
|
from skyvern.forge.sdk.cache.local import LocalCache
|
||||||
|
|
||||||
|
|
||||||
|
class CacheFactory:
|
||||||
|
__cache: BaseCache = LocalCache()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_cache(cache: BaseCache) -> None:
|
||||||
|
CacheFactory.__cache = cache
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_cache() -> BaseCache:
|
||||||
|
return CacheFactory.__cache
|
||||||
20
skyvern/forge/sdk/cache/local.py
vendored
Normal file
20
skyvern/forge/sdk/cache/local.py
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from cachetools import TTLCache
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.cache.base import CACHE_EXPIRE_TIME, MAX_CACHE_ITEM, BaseCache
|
||||||
|
|
||||||
|
|
||||||
|
class LocalCache(BaseCache):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.cache: TTLCache = TTLCache(maxsize=MAX_CACHE_ITEM, ttl=CACHE_EXPIRE_TIME.total_seconds())
|
||||||
|
|
||||||
|
async def get(self, key: str) -> Any:
|
||||||
|
if key not in self.cache:
|
||||||
|
return None
|
||||||
|
value = self.cache[key]
|
||||||
|
await self.set(key, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def set(self, key: str, value: Any) -> None:
|
||||||
|
self.cache[key] = value
|
||||||
Reference in New Issue
Block a user