email relay (#598)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-07-11 21:34:00 -07:00
committed by GitHub
parent f6bb4981fc
commit ea1039277f
15 changed files with 191 additions and 18 deletions

View File

@@ -2,16 +2,18 @@ import asyncio
import json
import os
import uuid
from datetime import datetime, timedelta
from typing import Any, Awaitable, Callable, List
import structlog
from deprecation import deprecated
from playwright.async_api import FileChooser, Locator, Page, TimeoutError
from skyvern.constants import REPO_ROOT_DIR
from skyvern.constants import REPO_ROOT_DIR, VERIFICATION_CODE_PLACEHOLDER, VERIFICATION_CODE_POLLING_TIMEOUT_MINS
from skyvern.exceptions import (
EmptySelect,
ErrFoundSelectableElement,
FailedToFetchSecret,
FailToClick,
FailToSelectByIndex,
FailToSelectByLabel,
@@ -33,6 +35,9 @@ from skyvern.forge.sdk.api.files import (
get_number_of_files_in_directory,
get_path_for_workflow_download_directory,
)
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.services.bitwarden import BitwardenConstants
@@ -277,7 +282,9 @@ async def handle_input_text_action(
# before filling text, we need to validate if the element can be filled if it's not one of COMMON_INPUT_TAGS
tag_name = scraped_page.id_to_element_dict[action.element_id]["tagName"].lower()
text = get_actual_value_of_parameter_if_secret(task, action.text)
text = await get_actual_value_of_parameter_if_secret(task, action.text)
if text is None:
return [ActionFailure(FailedToFetchSecret())]
try:
await skyvern_element.input_clear()
@@ -312,7 +319,7 @@ async def handle_upload_file_action(
# After this point if the file_url is a secret, it will be replaced with the actual value
# In order to make sure we don't log the secret value, we log the action with the original value action.file_url
# ************************************************************************************************************** #
file_url = get_actual_value_of_parameter_if_secret(task, action.file_url)
file_url = await get_actual_value_of_parameter_if_secret(task, action.file_url)
if file_url not in str(task.navigation_payload):
LOG.warning(
"LLM might be imagining the file url, which is not in navigation payload",
@@ -665,7 +672,7 @@ ActionHandler.register_action_type(ActionType.TERMINATE, handle_terminate_action
ActionHandler.register_action_type(ActionType.COMPLETE, handle_complete_action)
def get_actual_value_of_parameter_if_secret(task: Task, parameter: str) -> Any:
async def get_actual_value_of_parameter_if_secret(task: Task, parameter: str) -> Any:
"""
Get the actual value of a parameter if it's a secret. If it's not a secret, return the parameter value as is.
@@ -673,6 +680,13 @@ def get_actual_value_of_parameter_if_secret(task: Task, parameter: str) -> Any:
This is only used for InputTextAction, UploadFileAction, and ClickAction (if it has a file_url).
"""
if task.totp_verification_url and task.organization_id and VERIFICATION_CODE_PLACEHOLDER == parameter:
# if parameter is the secret code in the navigation playload,
# fetch the real verification from totp_verification_url
# do polling every 10 seconds to fetch the verification code
verification_code = await poll_verification_code(task.task_id, task.organization_id, task.totp_verification_url)
return verification_code
if task.workflow_run_id is None:
return parameter
@@ -702,7 +716,7 @@ async def chain_click(
LOG.info("Chain click starts", action=action, locator=locator)
file: list[str] | str = []
if action.file_url:
file_url = get_actual_value_of_parameter_if_secret(task, action.file_url)
file_url = await get_actual_value_of_parameter_if_secret(task, action.file_url)
try:
file = await download_file(file_url)
except Exception:
@@ -1095,7 +1109,6 @@ async def click_listbox_option(
LOG.error(
"Failed to click on the option",
action=action,
locator=locator,
exc_info=True,
)
if "children" in child:
@@ -1108,3 +1121,38 @@ async def get_input_value(tag_name: str, locator: Locator) -> str | None:
return await locator.input_value()
# for span, div, p or other tags:
return await locator.inner_text()
async def poll_verification_code(task_id: str, organization_id: str, url: str) -> str | None:
timeout = timedelta(minutes=VERIFICATION_CODE_POLLING_TIMEOUT_MINS)
start_datetime = datetime.utcnow()
timeout_datetime = start_datetime + timeout
org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api)
if not org_token:
LOG.error("Failed to get organization token when trying to get verification code")
return None
while True:
# check timeout
if datetime.utcnow() > timeout_datetime:
return None
request_data = {
"task_id": task_id,
}
payload = json.dumps(request_data)
signature = generate_skyvern_signature(
payload=payload,
api_key=org_token.token,
)
timestamp = str(int(datetime.utcnow().timestamp()))
headers = {
"x-skyvern-timestamp": timestamp,
"x-skyvern-signature": signature,
"Content-Type": "application/json",
}
json_resp = await aiohttp_post(url=url, data=request_data, headers=headers, raise_exception=False)
verification_code = json_resp.get("verification_code", None)
if verification_code:
LOG.info("Got verification code", verification_code=verification_code)
return verification_code
await asyncio.sleep(10)