feat: self healing skyvern api key (#3614)
Co-authored-by: Suchintan <suchintan@users.noreply.github.com> Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from skyvern.forge.request_logging import log_raw_request_middleware
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.routes import internal_auth
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router, legacy_v2_router
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
@@ -68,6 +69,13 @@ def get_agent_app() -> FastAPI:
|
||||
app.include_router(base_router, prefix="/v1")
|
||||
app.include_router(legacy_base_router, prefix="/api/v1")
|
||||
app.include_router(legacy_v2_router, prefix="/api/v2")
|
||||
|
||||
# local dev endpoints
|
||||
if settings.ENV == "local":
|
||||
app.include_router(internal_auth.router, prefix="/v1")
|
||||
app.include_router(internal_auth.router, prefix="/api/v1")
|
||||
app.include_router(internal_auth.router, prefix="/api/v2")
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
app.add_middleware(
|
||||
|
||||
133
skyvern/forge/sdk/routes/internal_auth.py
Normal file
133
skyvern/forge/sdk/routes/internal_auth.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import ipaddress
|
||||
from enum import Enum
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, HTTPException, Request, status
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.services.local_org_auth_token_service import fingerprint_token, regenerate_local_api_key
|
||||
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
|
||||
|
||||
# Router that groups the endpoints of this module under the "/internal/auth" prefix.
router = APIRouter(prefix="/internal/auth", tags=["internal"])
|
||||
# Module-level structured logger used by the diagnostics helpers below.
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AuthStatus(str, Enum):
    """Outcome categories reported by the local API-key diagnostics.

    The class mixes in ``str``, so each member compares equal to its plain
    string value and can be placed directly into responses and log fields.
    """

    # No usable key was supplied (blank, or the "YOUR_API_KEY" placeholder).
    missing_env = "missing_env"
    # Validation was rejected with a "validate"-style error detail.
    invalid_format = "invalid_format"
    # Validation was rejected for any other forbidden reason.
    invalid = "invalid"
    # Validation was rejected because the key has expired.
    expired = "expired"
    # Validation reported that the referenced record does not exist.
    not_found = "not_found"
    # The key validated successfully.
    ok = "ok"
|
||||
|
||||
|
||||
class DiagnosticsResult(NamedTuple):
    """Immutable record describing the outcome of checking one API key."""

    # Classification of the key that was checked.
    status: AuthStatus
    # Error detail text taken from a rejected validation, when one was available.
    detail: str | None
    # Object returned by a successful validation; None for every failure status.
    validation: Any | None
    # The stripped key itself; only populated when validation succeeded.
    token: str | None
|
||||
|
||||
|
||||
def _is_local_request(request: Request) -> bool:
    """Tell whether the request's peer address is a loopback or private IP.

    Args:
        request: Incoming request; only ``request.client.host`` is inspected.

    Returns:
        ``True`` when the peer host parses as an IP address that is loopback
        or private. ``False`` when there is no client information, the host is
        empty, or the host is not a valid IP literal.
    """
    peer = request.client
    if not peer or not peer.host:
        return False

    try:
        parsed = ipaddress.ip_address(peer.host)
    except ValueError:
        # The host is not an IP literal, so it cannot be classified as local.
        return False

    if parsed.is_loopback:
        return True
    return parsed.is_private
|
||||
|
||||
|
||||
def _require_local_access(request: Request) -> None:
    """Abort with HTTP 403 unless the endpoint is being used locally.

    Two conditions must hold: the configured environment is ``"local"`` and
    the request passes ``_is_local_request``. The environment is checked
    first, so its message wins when both conditions fail.

    Raises:
        HTTPException: 403 with a message naming the condition that failed.
    """
    denial: str | None = None
    if settings.ENV != "local":
        denial = "Endpoint only available in local env"
    elif not _is_local_request(request):
        denial = "Endpoint requires localhost access"

    if denial is not None:
        raise HTTPException(status.HTTP_403_FORBIDDEN, denial)
|
||||
|
||||
|
||||
async def _evaluate_local_api_key(token: str) -> DiagnosticsResult:
    """Classify an API key by validating it against the application database.

    Args:
        token: Raw key text (for example the value of an ``x-api-key``
            header). Surrounding whitespace is ignored.

    Returns:
        A ``DiagnosticsResult`` whose ``status`` is one of:

        * ``missing_env`` - the key is blank or the ``YOUR_API_KEY`` placeholder;
        * ``not_found`` - validation raised an HTTP 404;
        * ``expired`` / ``invalid_format`` / ``invalid`` - validation raised an
          HTTP 403, sub-classified by the wording of the error detail;
        * ``ok`` - validation succeeded; ``validation`` and ``token`` are set.

    Raises:
        HTTPException: Re-raised unchanged when validation fails with an HTTP
            status other than 403/404; raised as a 500 (chained to the
            original error) when validation fails with any other exception.
    """
    token_candidate = token.strip()
    if not token_candidate or token_candidate == "YOUR_API_KEY":
        return DiagnosticsResult(status=AuthStatus.missing_env, detail=None, validation=None, token=None)

    try:
        validation = await resolve_org_from_api_key(token_candidate, app.DATABASE)
    except HTTPException as exc:
        if exc.status_code == status.HTTP_404_NOT_FOUND:
            return DiagnosticsResult(status=AuthStatus.not_found, detail=None, token=None, validation=None)

        # Only string details are inspected/reported; structured details are dropped.
        detail_text = exc.detail if isinstance(exc.detail, str) else None
        if exc.status_code == status.HTTP_403_FORBIDDEN:
            lowered = (detail_text or "").lower()
            # "expired" takes precedence over "validate" when both words appear.
            if "expired" in lowered:
                status_value = AuthStatus.expired
            elif "validate" in lowered:
                status_value = AuthStatus.invalid_format
            else:
                status_value = AuthStatus.invalid
            return DiagnosticsResult(status=status_value, detail=detail_text, token=None, validation=None)

        LOG.error("Unexpected error while diagnosing API key", status_code=exc.status_code, detail=detail_text)
        raise
    except Exception as exc:
        LOG.error("Unexpected exception while diagnosing API key", exc_info=True)
        # Chain the cause so the original failure stays attached to the 500.
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Unable to diagnose API key") from exc

    return DiagnosticsResult(status=AuthStatus.ok, detail=None, validation=validation, token=token_candidate)
|
||||
|
||||
|
||||
def _emit_diagnostics(result: DiagnosticsResult) -> dict[str, object]:
    """Log a diagnostics result and convert it into a response payload.

    Args:
        result: Outcome produced by the key evaluation.

    Returns:
        For a successful result (status ``ok`` with both ``validation`` and
        ``token`` present): the status, organization id, token fingerprint and
        expiry. Otherwise: a dict holding only the status. The detail text of
        a failure is logged but never returned.
    """
    status_text = result.status.value
    succeeded = result.status is AuthStatus.ok and result.validation and result.token

    if not succeeded:
        failure_fields: dict[str, object] = {"status": status_text}
        if result.detail:
            failure_fields["detail"] = result.detail
        LOG.warning("Local auth diagnostics", **failure_fields)
        return {"status": status_text}

    # The same fields are logged and returned; the raw token is never exposed,
    # only its fingerprint.
    payload: dict[str, object] = {
        "status": status_text,
        "organization_id": result.validation.organization.organization_id,
        "fingerprint": fingerprint_token(result.token),
        "expires_at": result.validation.payload.exp,
    }
    LOG.info("Local auth diagnostics", **payload)
    return payload
|
||||
|
||||
|
||||
@router.post("/repair", include_in_schema=False)
async def repair_api_key(request: Request) -> dict[str, object]:
    """Regenerate the local API key and report where it was written.

    Access is restricted by ``_require_local_access``. The response carries
    the new key, its fingerprint, the organization id and the backend env file
    path; the frontend env file path is included only when one was returned.
    """
    _require_local_access(request)

    new_key, org_id, backend_path, frontend_path = await regenerate_local_api_key()

    body: dict[str, object] = {
        "status": AuthStatus.ok.value,
        "organization_id": org_id,
        "fingerprint": fingerprint_token(new_key),
        "api_key": new_key,
        "backend_env_path": backend_path,
    }
    if frontend_path:
        body["frontend_env_path"] = frontend_path
    return body
|
||||
|
||||
|
||||
@router.get("/status", include_in_schema=False)
async def auth_status(request: Request) -> dict[str, object]:
    """Report the diagnostic status of the key sent in the ``x-api-key`` header.

    Access is restricted by ``_require_local_access``. A missing header is
    treated as an empty key.
    """
    _require_local_access(request)
    supplied_key = request.headers.get("x-api-key") or ""
    evaluation = await _evaluate_local_api_key(supplied_key)
    return _emit_diagnostics(evaluation)
|
||||
83
skyvern/forge/sdk/services/local_org_auth_token_service.py
Normal file
83
skyvern/forge/sdk/services/local_org_auth_token_service.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
from dotenv import set_key
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import security
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
||||
from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path
|
||||
|
||||
# Module-level structured logger.
LOG = structlog.get_logger()
|
||||
# Name given to the local development organization when it is created.
SKYVERN_LOCAL_ORG = "Skyvern-local"
|
||||
# Domain used both to look up and to create the local development organization.
SKYVERN_LOCAL_DOMAIN = "skyvern.local"
|
||||
|
||||
|
||||
def _write_env(path: Path, key: str, value: str) -> None:
    """Set ``key`` to ``value`` in the env file at ``path``.

    Missing parent directories and a missing file are created first. The key
    name (but not the value) is logged afterwards.
    """
    env_dir = path.parent
    env_dir.mkdir(parents=True, exist_ok=True)

    already_present = path.exists()
    if not already_present:
        path.touch()

    set_key(path, key, value)
    LOG.info(".env written", path=str(path), key=key)
|
||||
|
||||
|
||||
def fingerprint_token(value: str) -> str:
    """Return a short preview of a token that is safe to log.

    Tokens longer than 12 characters are rendered as their first six and last
    four characters joined by an ellipsis. Anything shorter is fully redacted,
    because a preview would reveal most of it.
    """
    if len(value) <= 12:
        return "[redacted -- token too short]"
    head, tail = value[:6], value[-4:]
    return f"{head}…{tail}"
|
||||
|
||||
|
||||
async def ensure_local_org() -> Organization:
    """Return the local development organization, creating it when absent.

    The organization is looked up by ``SKYVERN_LOCAL_DOMAIN``. If the lookup
    yields nothing, a new one is created with the local name and domain and
    fixed step/retry limits.
    """
    local_org = await app.DATABASE.get_organization_by_domain(SKYVERN_LOCAL_DOMAIN)
    if not local_org:
        local_org = await app.DATABASE.create_organization(
            organization_name=SKYVERN_LOCAL_ORG,
            domain=SKYVERN_LOCAL_DOMAIN,
            max_steps_per_run=10,
            max_retries_per_step=3,
        )
    return local_org
|
||||
|
||||
|
||||
async def regenerate_local_api_key() -> tuple[str, str, str, str | None]:
    """Issue a new API key for the local organization and persist it.

    Steps, in order: ensure the local organization exists, invalidate its
    existing API tokens, create and store a new token, write it to the backend
    env file (and to the frontend env file when a path is resolved), then
    update the in-process settings and environment variable.

    Returns:
        tuple: (api_key, org_id, backend_env_path, frontend_env_path_or_none)
    """
    local_org = await ensure_local_org()
    organization_id = local_org.organization_id
    api_token_type = OrganizationAuthTokenType.api

    # Retire the old keys before issuing the replacement.
    await app.DATABASE.invalidate_org_auth_tokens(
        organization_id=organization_id,
        token_type=api_token_type,
    )

    fresh_key = security.create_access_token(organization_id, expires_delta=API_KEY_LIFETIME)
    await app.DATABASE.create_org_auth_token(
        organization_id=organization_id,
        token=fresh_key,
        token_type=api_token_type,
    )

    backend_env = resolve_backend_env_path()
    _write_env(backend_env, "SKYVERN_API_KEY", fresh_key)

    frontend_env = resolve_frontend_env_path()
    if not frontend_env:
        LOG.warning("Frontend directory not found; skipping VITE_SKYVERN_API_KEY update")
    else:
        _write_env(frontend_env, "VITE_SKYVERN_API_KEY", fresh_key)

    # Make the running process use the new key without a restart.
    settings.SKYVERN_API_KEY = fresh_key
    os.environ["SKYVERN_API_KEY"] = fresh_key

    LOG.info(
        "Local API key regenerated",
        organization_id=organization_id,
        fingerprint=fingerprint_token(fresh_key),
    )

    frontend_env_text = str(frontend_env) if frontend_env else None
    return fresh_key, organization_id, str(backend_env), frontend_env_text
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated
|
||||
|
||||
import structlog
|
||||
@@ -15,7 +16,11 @@ from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
from skyvern.forge.sdk.models import TokenPayload
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.schemas.organizations import (
|
||||
Organization,
|
||||
OrganizationAuthToken,
|
||||
OrganizationAuthTokenType,
|
||||
)
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -24,6 +29,13 @@ CACHE_SIZE = 128
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
@dataclass
class ApiKeyValidationResult:
    """Everything learned while validating an API key."""

    # Organization the key resolved to.
    organization: Organization
    # Decoded payload of the key.
    payload: TokenPayload
    # Stored auth-token record that matched the key in the database.
    token: OrganizationAuthToken
|
||||
|
||||
|
||||
async def get_current_org(
|
||||
x_api_key: Annotated[
|
||||
str | None,
|
||||
@@ -159,11 +171,11 @@ async def _authenticate_user_helper(authorization: str) -> str:
|
||||
return user_id
|
||||
|
||||
|
||||
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
|
||||
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
||||
"""
|
||||
Authentication is cached for one hour
|
||||
"""
|
||||
async def resolve_org_from_api_key(
|
||||
x_api_key: str,
|
||||
db: AgentDB,
|
||||
) -> ApiKeyValidationResult:
|
||||
"""Decode and validate the API key against the database."""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
x_api_key,
|
||||
@@ -188,7 +200,6 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
||||
LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload)
|
||||
raise HTTPException(status_code=404, detail="Organization not found")
|
||||
|
||||
# check if the token exists in the database
|
||||
api_key_db_obj = await db.validate_org_auth_token(
|
||||
organization_id=organization.organization_id,
|
||||
token_type=OrganizationAuthTokenType.api,
|
||||
@@ -207,9 +218,21 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
||||
detail="Your API key has expired. Please retrieve the latest one from https://app.skyvern.com/settings",
|
||||
)
|
||||
|
||||
return ApiKeyValidationResult(
|
||||
organization=organization,
|
||||
payload=api_key_data,
|
||||
token=api_key_db_obj,
|
||||
)
|
||||
|
||||
|
||||
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
|
||||
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
||||
"""Authentication is cached for one hour."""
|
||||
validation = await resolve_org_from_api_key(x_api_key, db)
|
||||
|
||||
# set organization_id in skyvern context and log context
|
||||
context = skyvern_context.current()
|
||||
if context:
|
||||
context.organization_id = organization.organization_id
|
||||
context.organization_name = organization.organization_name
|
||||
return organization
|
||||
context.organization_id = validation.organization.organization_id
|
||||
context.organization_name = validation.organization.organization_name
|
||||
return validation.organization
|
||||
|
||||
Reference in New Issue
Block a user