feat: self healing skyvern api key (#3614)

Co-authored-by: Suchintan <suchintan@users.noreply.github.com>
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
greg niemeyer
2025-10-13 07:55:59 -07:00
committed by GitHub
parent a8179ae61c
commit 2faf4e102f
16 changed files with 638 additions and 46 deletions

View File

@@ -0,0 +1,83 @@
import os
from pathlib import Path
import structlog
from dotenv import set_key
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.core import security
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthTokenType
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path
LOG = structlog.get_logger()
SKYVERN_LOCAL_ORG = "Skyvern-local"
SKYVERN_LOCAL_DOMAIN = "skyvern.local"
def _write_env(path: Path, key: str, value: str) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
if not path.exists():
path.touch()
set_key(path, key, value)
LOG.info(".env written", path=str(path), key=key)
def fingerprint_token(value: str) -> str:
return f"{value[:6]}{value[-4:]}" if len(value) > 12 else "[redacted -- token too short]"
async def ensure_local_org() -> Organization:
"""Ensure the local development organization exists and return it."""
organization = await app.DATABASE.get_organization_by_domain(SKYVERN_LOCAL_DOMAIN)
if organization:
return organization
return await app.DATABASE.create_organization(
organization_name=SKYVERN_LOCAL_ORG,
domain=SKYVERN_LOCAL_DOMAIN,
max_steps_per_run=10,
max_retries_per_step=3,
)
async def regenerate_local_api_key() -> tuple[str, str, str, str | None]:
"""Create a fresh API key for the local organization and persist it to env files.
Returns:
tuple: (api_key, org_id, backend_env_path, frontend_env_path_or_none)
"""
organization = await ensure_local_org()
org_id = organization.organization_id
await app.DATABASE.invalidate_org_auth_tokens(
organization_id=org_id,
token_type=OrganizationAuthTokenType.api,
)
api_key = security.create_access_token(org_id, expires_delta=API_KEY_LIFETIME)
await app.DATABASE.create_org_auth_token(
organization_id=org_id,
token=api_key,
token_type=OrganizationAuthTokenType.api,
)
backend_env_path = resolve_backend_env_path()
_write_env(backend_env_path, "SKYVERN_API_KEY", api_key)
frontend_env_path = resolve_frontend_env_path()
if frontend_env_path:
_write_env(frontend_env_path, "VITE_SKYVERN_API_KEY", api_key)
else:
LOG.warning("Frontend directory not found; skipping VITE_SKYVERN_API_KEY update")
settings.SKYVERN_API_KEY = api_key
os.environ["SKYVERN_API_KEY"] = api_key
LOG.info(
"Local API key regenerated",
organization_id=org_id,
fingerprint=fingerprint_token(api_key),
)
return api_key, org_id, str(backend_env_path), str(frontend_env_path) if frontend_env_path else None

View File

@@ -1,4 +1,5 @@
import time
from dataclasses import dataclass
from typing import Annotated
import structlog
@@ -15,7 +16,11 @@ from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.db.client import AgentDB
from skyvern.forge.sdk.models import TokenPayload
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthTokenType
from skyvern.forge.sdk.schemas.organizations import (
Organization,
OrganizationAuthToken,
OrganizationAuthTokenType,
)
LOG = structlog.get_logger()
@@ -24,6 +29,13 @@ CACHE_SIZE = 128
ALGORITHM = "HS256"
@dataclass
class ApiKeyValidationResult:
organization: Organization
payload: TokenPayload
token: OrganizationAuthToken
async def get_current_org(
x_api_key: Annotated[
str | None,
@@ -159,11 +171,11 @@ async def _authenticate_user_helper(authorization: str) -> str:
return user_id
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
"""
Authentication is cached for one hour
"""
async def resolve_org_from_api_key(
x_api_key: str,
db: AgentDB,
) -> ApiKeyValidationResult:
"""Decode and validate the API key against the database."""
try:
payload = jwt.decode(
x_api_key,
@@ -188,7 +200,6 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload)
raise HTTPException(status_code=404, detail="Organization not found")
# check if the token exists in the database
api_key_db_obj = await db.validate_org_auth_token(
organization_id=organization.organization_id,
token_type=OrganizationAuthTokenType.api,
@@ -207,9 +218,21 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
detail="Your API key has expired. Please retrieve the latest one from https://app.skyvern.com/settings",
)
return ApiKeyValidationResult(
organization=organization,
payload=api_key_data,
token=api_key_db_obj,
)
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
"""Authentication is cached for one hour."""
validation = await resolve_org_from_api_key(x_api_key, db)
# set organization_id in skyvern context and log context
context = skyvern_context.current()
if context:
context.organization_id = organization.organization_id
context.organization_name = organization.organization_name
return organization
context.organization_id = validation.organization.organization_id
context.organization_name = validation.organization.organization_name
return validation.organization