diff --git a/skyvern-frontend/src/api/AxiosClient.ts b/skyvern-frontend/src/api/AxiosClient.ts index 74635e44..44ec83e5 100644 --- a/skyvern-frontend/src/api/AxiosClient.ts +++ b/skyvern-frontend/src/api/AxiosClient.ts @@ -1,4 +1,10 @@ -import { apiBaseUrl, artifactApiBaseUrl, envCredential } from "@/util/env"; +import { + apiBaseUrl, + artifactApiBaseUrl, + getRuntimeApiKey, + persistRuntimeApiKey, + clearRuntimeApiKey, +} from "@/util/env"; import axios from "axios"; type ApiVersion = "sans-api-v1" | "v1" | "v2"; @@ -9,12 +15,15 @@ const url = new URL(apiBaseUrl); const pathname = url.pathname.replace("/api", ""); const apiSansApiV1BaseUrl = `${url.origin}${pathname}`; +const initialApiKey = getRuntimeApiKey(); +const apiKeyHeader = initialApiKey ? { "X-API-Key": initialApiKey } : {}; + const client = axios.create({ baseURL: apiV1BaseUrl, headers: { "Content-Type": "application/json", "x-user-agent": "skyvern-ui", - "x-api-key": envCredential, + ...apiKeyHeader, }, }); @@ -23,7 +32,7 @@ const v2Client = axios.create({ headers: { "Content-Type": "application/json", "x-user-agent": "skyvern-ui", - "x-api-key": envCredential, + ...apiKeyHeader, }, }); @@ -32,7 +41,7 @@ const clientSansApiV1 = axios.create({ headers: { "Content-Type": "application/json", "x-user-agent": "skyvern-ui", - "x-api-key": envCredential, + ...apiKeyHeader, }, }); @@ -55,12 +64,14 @@ export function removeAuthorizationHeader() { } export function setApiKeyHeader(apiKey: string) { + persistRuntimeApiKey(apiKey); client.defaults.headers.common["X-API-Key"] = apiKey; v2Client.defaults.headers.common["X-API-Key"] = apiKey; clientSansApiV1.defaults.headers.common["X-API-Key"] = apiKey; } export function removeApiKeyHeader() { + clearRuntimeApiKey(); if (client.defaults.headers.common["X-API-Key"]) { delete client.defaults.headers.common["X-API-Key"]; } diff --git a/skyvern-frontend/src/components/BrowserStream.tsx b/skyvern-frontend/src/components/BrowserStream.tsx index 565c2ae3..34cda4a3 100644 --- a/skyvern-frontend/src/components/BrowserStream.tsx +++ b/skyvern-frontend/src/components/BrowserStream.tsx @@ -16,10 +16,10 @@ import { useCredentialGetter } from "@/hooks/useCredentialGetter"; import { statusIsNotFinalized } from "@/routes/tasks/types"; import { useClientIdStore } from "@/store/useClientIdStore"; import { - envCredential, environment, wssBaseUrl, newWssBaseUrl, + getRuntimeApiKey, } from "@/util/env"; import { cn } from "@/util/utils"; @@ -140,22 +140,18 @@ function BrowserStream({ const getWebSocketParams = useCallback(async () => { const clientIdQueryParam = `client_id=${clientId}`; - let credentialQueryParam = ""; + const runtimeApiKey = getRuntimeApiKey(); - if (environment === "local") { - credentialQueryParam = `apikey=${envCredential}`; - } else { - if (credentialGetter) { - const token = await credentialGetter(); - credentialQueryParam = `token=Bearer ${token}`; - } else { - credentialQueryParam = `apikey=${envCredential}`; - } + let credentialQueryParam = runtimeApiKey ? `apikey=${runtimeApiKey}` : ""; + + if (environment !== "local" && credentialGetter) { + const token = await credentialGetter(); + credentialQueryParam = token ? `token=Bearer ${token}` : ""; } - const params = [credentialQueryParam, clientIdQueryParam].join("&"); - - return `${params}`; + return credentialQueryParam + ? `${credentialQueryParam}&${clientIdQueryParam}` + : clientIdQueryParam; }, [clientId, credentialGetter]); // browser is ready diff --git a/skyvern-frontend/src/components/SelfHealApiKeyBanner.tsx b/skyvern-frontend/src/components/SelfHealApiKeyBanner.tsx new file mode 100644 index 00000000..79dffade --- /dev/null +++ b/skyvern-frontend/src/components/SelfHealApiKeyBanner.tsx @@ -0,0 +1,202 @@ +import { useState } from "react"; + +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { Button } from "@/components/ui/button"; +import { useToast } from "@/components/ui/use-toast"; +import { getClient, setApiKeyHeader } from "@/api/AxiosClient"; +import { + AuthStatusValue, + useAuthDiagnostics, +} from "@/hooks/useAuthDiagnostics"; + +type BannerStatus = Exclude | "error"; + +function getCopy(status: BannerStatus): { title: string; description: string } { + switch (status) { + case "missing_env": + return { + title: "Skyvern API key missing", + description: + "All requests from the UI to the local backend will fail until a valid key is configured.", + }; + case "invalid_format": + return { + title: "Skyvern API key is invalid", + description: + "The configured key cannot be decoded. Regenerate a new key to continue using the UI.", + }; + case "invalid": + return { + title: "Skyvern API key not recognized", + description: + "The backend rejected the configured key. Regenerate it to refresh local auth.", + }; + case "expired": + return { + title: "Skyvern API key expired", + description: + "The current key is no longer valid. Generate a fresh key to restore connectivity.", + }; + case "not_found": + return { + title: "Local organization missing", + description: + "The backend could not find the Skyvern-local organization. Regenerate the key to recreate it.", + }; + case "error": + default: + return { + title: "Unable to verify Skyvern API key", + description: + "The UI could not reach the diagnostics endpoint. Ensure the backend is running locally.", + }; + } +} + +function SelfHealApiKeyBanner() { + const diagnosticsQuery = useAuthDiagnostics(); + const { toast } = useToast(); + const [isRepairing, setIsRepairing] = useState(false); + const [errorMessage, setErrorMessage] = useState(null); + const isProductionBuild = !import.meta.env.DEV; + + const { data, error, isLoading, refetch } = diagnosticsQuery; + + const rawStatus = data?.status; + const bannerStatus: BannerStatus | null = error + ? "error" + : rawStatus && rawStatus !== "ok" + ? rawStatus + : null; + + if (!bannerStatus && !errorMessage) { + if (isLoading) { + return null; + } + return null; + } + + const copy = getCopy(bannerStatus ?? "missing_env"); + const queryErrorMessage = error?.message ?? null; + + const handleRepair = async () => { + setIsRepairing(true); + setErrorMessage(null); + try { + const client = await getClient(null); + const response = await client.post<{ + fingerprint?: string; + api_key?: string; + backend_env_path?: string; + frontend_env_path?: string; + }>("/internal/auth/repair"); + + const { + fingerprint, + api_key: apiKey, + backend_env_path: backendEnvPath, + frontend_env_path: frontendEnvPath, + } = response.data; + + if (!apiKey) { + throw new Error("Repair succeeded but no API key was returned."); + } + + setApiKeyHeader(apiKey); + + const fingerprintSuffix = fingerprint + ? ` (fingerprint ${fingerprint})` + : ""; + + const pathsElements = []; + if (backendEnvPath) { + pathsElements.push(
Backend: {backendEnvPath}
); + } + if (frontendEnvPath) { + pathsElements.push( +
Frontend: {frontendEnvPath}
, + ); + } + + toast({ + title: "API key regenerated", + description: ( +
+
+ Requests now use the updated key automatically{fingerprintSuffix}{" "} + persisted to sessionStorage and written to the following .env + paths: +
+ {pathsElements.length > 0 && ( +
{pathsElements}
+ )} + {isProductionBuild && ( +
+ Restart the UI server for more robust API key persistence. +
+ )} +
+ ), + }); + + await refetch({ throwOnError: false }); + } catch (fetchError) { + const message = + fetchError instanceof Error + ? fetchError.message + : "Unable to repair API key"; + setErrorMessage(message); + } finally { + setIsRepairing(false); + } + }; + + return ( +
+ + + {copy.title} + + + {bannerStatus !== "error" ? ( + <> +

+ {copy.description} Update VITE_SKYVERN_API_KEY in{" "} + skyvern-frontend/.env + by running skyvern init or click the button below + to regenerate it automatically. +

+ {isProductionBuild && ( +

+ When running a production build, the regenerated API key is + stored in sessionStorage. Closing this tab or browser window + will lose the key. Restart the UI server for more robust + persistence. +

+ )} +
+ +
+ + ) : ( +

{copy.description}

+ )} + {errorMessage ? ( +

{errorMessage}

+ ) : null} + {queryErrorMessage && !errorMessage ? ( +

{queryErrorMessage}

+ ) : null} +
+
+
+ ); +} + +export { SelfHealApiKeyBanner }; diff --git a/skyvern-frontend/src/hooks/useApiCredential.ts b/skyvern-frontend/src/hooks/useApiCredential.ts index 0d63f0f7..477d2f0d 100644 --- a/skyvern-frontend/src/hooks/useApiCredential.ts +++ b/skyvern-frontend/src/hooks/useApiCredential.ts @@ -1,12 +1,12 @@ import { useQuery } from "@tanstack/react-query"; import { useCredentialGetter } from "./useCredentialGetter"; import { getClient } from "@/api/AxiosClient"; -import { envCredential } from "@/util/env"; +import { getRuntimeApiKey } from "@/util/env"; import { ApiKeyApiResponse, OrganizationApiResponse } from "@/api/types"; function useApiCredential() { const credentialGetter = useCredentialGetter(); - const credentialsFromEnv = envCredential; + const credentialsFromEnv = getRuntimeApiKey(); const { data: organizations } = useQuery>({ queryKey: ["organizations"], @@ -16,7 +16,7 @@ function useApiCredential() { .get("/organizations/") .then((response) => response.data.organizations); }, - enabled: envCredential === null, + enabled: credentialsFromEnv === null, }); const organization = organizations?.[0]; diff --git a/skyvern-frontend/src/hooks/useAuthDiagnostics.ts b/skyvern-frontend/src/hooks/useAuthDiagnostics.ts new file mode 100644 index 00000000..608b99cd --- /dev/null +++ b/skyvern-frontend/src/hooks/useAuthDiagnostics.ts @@ -0,0 +1,46 @@ +import { useQuery } from "@tanstack/react-query"; +import axios from "axios"; + +import { getClient } from "@/api/AxiosClient"; + +export type AuthStatusValue = + | "missing_env" + | "invalid_format" + | "invalid" + | "expired" + | "not_found" + | "ok"; + +export type AuthDiagnosticsResponse = { + status: AuthStatusValue; + fingerprint?: string; + organization_id?: string; + expires_at?: number; + api_key?: string; +}; + +async function fetchDiagnostics(): Promise { + const client = await getClient(null); + try { + const response = await client.get( + "/internal/auth/status", + ); + return response.data; + } catch (error) { + if (axios.isAxiosError(error) && error.response?.status === 404) { + return { status: "ok" }; + } + throw error; + } +} + +function useAuthDiagnostics() { + return useQuery({ + queryKey: ["internal", "auth", "status"], + queryFn: fetchDiagnostics, + retry: false, + refetchOnWindowFocus: false, + }); +} + +export { useAuthDiagnostics }; diff --git a/skyvern-frontend/src/routes/root/RootLayout.tsx b/skyvern-frontend/src/routes/root/RootLayout.tsx index 0fcee1a2..33bfffa5 100644 --- a/skyvern-frontend/src/routes/root/RootLayout.tsx +++ b/skyvern-frontend/src/routes/root/RootLayout.tsx @@ -5,6 +5,7 @@ import { Outlet } from "react-router-dom"; import { Header } from "./Header"; import { Sidebar } from "./Sidebar"; import { useDebugStore } from "@/store/useDebugStore"; +import { SelfHealApiKeyBanner } from "@/components/SelfHealApiKeyBanner"; function RootLayout() { const collapsed = useSidebarStore((state) => state.collapsed); @@ -12,15 +13,21 @@ function RootLayout() { const isEmbedded = embed === "true"; const debugStore = useDebugStore(); + const horizontalPadding = cn("lg:pl-64", { + "lg:pl-28": collapsed, + "lg:pl-4": isEmbedded, + }); + return ( <> {!isEmbedded && }
+
+ +
diff --git a/skyvern-frontend/src/routes/settings/Settings.tsx b/skyvern-frontend/src/routes/settings/Settings.tsx index 38f48765..bb7df454 100644 --- a/skyvern-frontend/src/routes/settings/Settings.tsx +++ b/skyvern-frontend/src/routes/settings/Settings.tsx @@ -14,7 +14,7 @@ import { CardHeader, CardTitle, } from "@/components/ui/card"; -import { envCredential } from "@/util/env"; +import { getRuntimeApiKey } from "@/util/env"; import { HiddenCopyableInput } from "@/components/ui/hidden-copyable-input"; import { OnePasswordTokenForm } from "@/components/OnePasswordTokenForm"; import { AzureClientSecretCredentialTokenForm } from "@/components/AzureClientSecretCredentialTokenForm"; @@ -22,7 +22,7 @@ import { AzureClientSecretCredentialTokenForm } from "@/components/AzureClientSe function Settings() { const { environment, organization, setEnvironment, setOrganization } = useSettingsStore(); - const apiKey = envCredential; + const apiKey = getRuntimeApiKey(); return (
diff --git a/skyvern-frontend/src/routes/tasks/detail/TaskActions.tsx b/skyvern-frontend/src/routes/tasks/detail/TaskActions.tsx index c6cf61e5..a29bbefe 100644 --- a/skyvern-frontend/src/routes/tasks/detail/TaskActions.tsx +++ b/skyvern-frontend/src/routes/tasks/detail/TaskActions.tsx @@ -5,7 +5,7 @@ import { toast } from "@/components/ui/use-toast"; import { ZoomableImage } from "@/components/ZoomableImage"; import { useCostCalculator } from "@/hooks/useCostCalculator"; import { useCredentialGetter } from "@/hooks/useCredentialGetter"; -import { envCredential } from "@/util/env"; +import { getRuntimeApiKey } from "@/util/env"; import { keepPreviousData, useQuery, @@ -79,7 +79,8 @@ function TaskActions() { const token = await credentialGetter(); credential = `?token=Bearer ${token}`; } else { - credential = `?apikey=${envCredential}`; + const apiKey = getRuntimeApiKey(); + credential = apiKey ? `?apikey=${apiKey}` : ""; } if (socket) { socket.close(); diff --git a/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStream.tsx b/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStream.tsx index 11ac6bc6..c67e3f8c 100644 --- a/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStream.tsx +++ b/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStream.tsx @@ -5,7 +5,7 @@ import { useEffect, useState } from "react"; import { statusIsNotFinalized } from "@/routes/tasks/types"; import { useCredentialGetter } from "@/hooks/useCredentialGetter"; import { useParams } from "react-router-dom"; -import { envCredential } from "@/util/env"; +import { getRuntimeApiKey } from "@/util/env"; import { toast } from "@/components/ui/use-toast"; import { useQueryClient } from "@tanstack/react-query"; @@ -45,7 +45,8 @@ function WorkflowRunStream(props?: Props) { const token = await credentialGetter(); credential = `?token=Bearer ${token}`; } else { - credential = `?apikey=${envCredential}`; + const apiKey = getRuntimeApiKey(); + credential = apiKey ? `?apikey=${apiKey}` : ""; } if (socket) { socket.close(); diff --git a/skyvern-frontend/src/util/env.ts b/skyvern-frontend/src/util/env.ts index 99b4f090..9c5441b9 100644 --- a/skyvern-frontend/src/util/env.ts +++ b/skyvern-frontend/src/util/env.ts @@ -10,8 +10,10 @@ if (!environment) { console.warn("environment environment variable was not set"); } -const envCredential: string | null = - import.meta.env.VITE_SKYVERN_API_KEY ?? null; +const buildTimeApiKey: string | null = + typeof import.meta.env.VITE_SKYVERN_API_KEY === "string" + ? import.meta.env.VITE_SKYVERN_API_KEY + : null; const artifactApiBaseUrl = import.meta.env.VITE_ARTIFACT_API_BASE_URL; @@ -21,8 +23,11 @@ if (!artifactApiBaseUrl) { const apiPathPrefix = import.meta.env.VITE_API_PATH_PREFIX ?? ""; +const API_KEY_STORAGE_KEY = "skyvern.apiKey"; + const lsKeys = { browserSessionId: "skyvern.browserSessionId", + apiKey: API_KEY_STORAGE_KEY, }; const wssBaseUrl = import.meta.env.VITE_WSS_BASE_URL; @@ -38,13 +43,53 @@ try { newWssBaseUrl = wssBaseUrl.replace("/api", ""); } +let runtimeApiKey: string | null | undefined; + +function readPersistedApiKey(): string | null { + if (typeof window === "undefined") { + return null; + } + + return window.sessionStorage.getItem(API_KEY_STORAGE_KEY); +} + +function getRuntimeApiKey(): string | null { + if (runtimeApiKey !== undefined) { + return runtimeApiKey; + } + + const persisted = readPersistedApiKey(); + const candidate = persisted ?? buildTimeApiKey; + + // Treat YOUR_API_KEY as missing. We may inherit this from .env.example + // in some cases of misconfiguration. + runtimeApiKey = candidate === "YOUR_API_KEY" ? null : candidate; + return runtimeApiKey; +} + +function persistRuntimeApiKey(value: string): void { + runtimeApiKey = value; + if (typeof window !== "undefined") { + window.sessionStorage.setItem(API_KEY_STORAGE_KEY, value); + } +} + +function clearRuntimeApiKey(): void { + runtimeApiKey = null; + if (typeof window !== "undefined") { + window.sessionStorage.removeItem(API_KEY_STORAGE_KEY); + } +} + export { apiBaseUrl, environment, - envCredential, artifactApiBaseUrl, apiPathPrefix, lsKeys, wssBaseUrl, newWssBaseUrl, + getRuntimeApiKey, + persistRuntimeApiKey, + clearRuntimeApiKey, }; diff --git a/skyvern/forge/api_app.py b/skyvern/forge/api_app.py index cb97ffcb..2354300f 100644 --- a/skyvern/forge/api_app.py +++ b/skyvern/forge/api_app.py @@ -19,6 +19,7 @@ from skyvern.forge.request_logging import log_raw_request_middleware from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.db.exceptions import NotFoundError +from skyvern.forge.sdk.routes import internal_auth from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router, legacy_v2_router LOG = structlog.get_logger() @@ -68,6 +69,13 @@ def get_agent_app() -> FastAPI: app.include_router(base_router, prefix="/v1") app.include_router(legacy_base_router, prefix="/api/v1") app.include_router(legacy_v2_router, prefix="/api/v2") + + # local dev endpoints + if settings.ENV == "local": + app.include_router(internal_auth.router, prefix="/v1") + app.include_router(internal_auth.router, prefix="/api/v1") + app.include_router(internal_auth.router, prefix="/api/v2") + app.openapi = custom_openapi app.add_middleware( diff --git a/skyvern/forge/sdk/routes/internal_auth.py b/skyvern/forge/sdk/routes/internal_auth.py new file mode 100644 index 00000000..4362d6de --- /dev/null +++ b/skyvern/forge/sdk/routes/internal_auth.py @@ -0,0 +1,133 @@ +import ipaddress +from enum import Enum +from typing import Any, NamedTuple + +import structlog +from fastapi import APIRouter, HTTPException, Request, status + +from skyvern.config import settings +from skyvern.forge import app +from skyvern.forge.sdk.services.local_org_auth_token_service import fingerprint_token, regenerate_local_api_key +from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key + +router = APIRouter(prefix="/internal/auth", tags=["internal"]) +LOG = structlog.get_logger() + + +class AuthStatus(str, Enum): + missing_env = "missing_env" + invalid_format = "invalid_format" + invalid = "invalid" + expired = "expired" + not_found = "not_found" + ok = "ok" + + +class DiagnosticsResult(NamedTuple): + status: AuthStatus + detail: str | None + validation: Any | None + token: str | None + + +def _is_local_request(request: Request) -> bool: + host = request.client.host if request.client else None + if not host: + return False + try: + addr = ipaddress.ip_address(host) + except ValueError: + return False + return addr.is_loopback or addr.is_private + + +def _require_local_access(request: Request) -> None: + if settings.ENV != "local": + raise HTTPException(status.HTTP_403_FORBIDDEN, "Endpoint only available in local env") + if not _is_local_request(request): + raise HTTPException(status.HTTP_403_FORBIDDEN, "Endpoint requires localhost access") + + +async def _evaluate_local_api_key(token: str) -> DiagnosticsResult: + token_candidate = token.strip() + if not token_candidate or token_candidate == "YOUR_API_KEY": + return DiagnosticsResult(status=AuthStatus.missing_env, detail=None, validation=None, token=None) + + try: + validation = await resolve_org_from_api_key(token_candidate, app.DATABASE) + except HTTPException as exc: + if exc.status_code == status.HTTP_404_NOT_FOUND: + return DiagnosticsResult(status=AuthStatus.not_found, detail=None, token=None, validation=None) + + detail_text = exc.detail if isinstance(exc.detail, str) else None + if exc.status_code == status.HTTP_403_FORBIDDEN: + status_value = AuthStatus.invalid + if detail_text and "expired" in detail_text.lower(): + status_value = AuthStatus.expired + elif detail_text and "validate" in detail_text.lower(): + status_value = AuthStatus.invalid_format + return DiagnosticsResult(status=status_value, detail=detail_text, token=None, validation=None) + + LOG.error("Unexpected error while diagnosing API key", status_code=exc.status_code, detail=detail_text) + raise + except Exception: + LOG.error("Unexpected exception while diagnosing API key", exc_info=True) + raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Unable to diagnose API key") + + return DiagnosticsResult(status=AuthStatus.ok, detail=None, validation=validation, token=token_candidate) + + +def _emit_diagnostics(result: DiagnosticsResult) -> dict[str, object]: + status_value = result.status.value + + if result.status is AuthStatus.ok and result.validation and result.token: + fingerprint = fingerprint_token(result.token) + LOG.info( + "Local auth diagnostics", + status=status_value, + organization_id=result.validation.organization.organization_id, + fingerprint=fingerprint, + expires_at=result.validation.payload.exp, + ) + return { + "status": status_value, + "organization_id": result.validation.organization.organization_id, + "fingerprint": fingerprint, + "expires_at": result.validation.payload.exp, + } + + log_kwargs: dict[str, object] = {"status": status_value} + if result.detail: + log_kwargs["detail"] = result.detail + + LOG.warning("Local auth diagnostics", **log_kwargs) + + return {"status": status_value} + + +@router.post("/repair", include_in_schema=False) +async def repair_api_key(request: Request) -> dict[str, object]: + _require_local_access(request) + + token, organization_id, backend_env_path, frontend_env_path = await regenerate_local_api_key() + + response: dict[str, object] = { + "status": AuthStatus.ok.value, + "organization_id": organization_id, + "fingerprint": fingerprint_token(token), + "api_key": token, + "backend_env_path": backend_env_path, + } + + if frontend_env_path: + response["frontend_env_path"] = frontend_env_path + + return response + + +@router.get("/status", include_in_schema=False) +async def auth_status(request: Request) -> dict[str, object]: + _require_local_access(request) + token_candidate = request.headers.get("x-api-key") or "" + result = await _evaluate_local_api_key(token_candidate) + return _emit_diagnostics(result) diff --git a/skyvern/forge/sdk/services/local_org_auth_token_service.py b/skyvern/forge/sdk/services/local_org_auth_token_service.py new file mode 100644 index 00000000..808e4311 --- /dev/null +++ b/skyvern/forge/sdk/services/local_org_auth_token_service.py @@ -0,0 +1,83 @@ +import os +from pathlib import Path + +import structlog +from dotenv import set_key + +from skyvern.config import settings +from skyvern.forge import app +from skyvern.forge.sdk.core import security +from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthTokenType +from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME +from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path + +LOG = structlog.get_logger() +SKYVERN_LOCAL_ORG = "Skyvern-local" +SKYVERN_LOCAL_DOMAIN = "skyvern.local" + + +def _write_env(path: Path, key: str, value: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + if not path.exists(): + path.touch() + set_key(path, key, value) + LOG.info(".env written", path=str(path), key=key) + + +def fingerprint_token(value: str) -> str: + return f"{value[:6]}…{value[-4:]}" if len(value) > 12 else "[redacted -- token too short]" + + +async def ensure_local_org() -> Organization: + """Ensure the local development organization exists and return it.""" + organization = await app.DATABASE.get_organization_by_domain(SKYVERN_LOCAL_DOMAIN) + if organization: + return organization + + return await app.DATABASE.create_organization( + organization_name=SKYVERN_LOCAL_ORG, + domain=SKYVERN_LOCAL_DOMAIN, + max_steps_per_run=10, + max_retries_per_step=3, + ) + + +async def regenerate_local_api_key() -> tuple[str, str, str, str | None]: + """Create a fresh API key for the local organization and persist it to env files. + + Returns: + tuple: (api_key, org_id, backend_env_path, frontend_env_path_or_none) + """ + organization = await ensure_local_org() + org_id = organization.organization_id + + await app.DATABASE.invalidate_org_auth_tokens( + organization_id=org_id, + token_type=OrganizationAuthTokenType.api, + ) + + api_key = security.create_access_token(org_id, expires_delta=API_KEY_LIFETIME) + await app.DATABASE.create_org_auth_token( + organization_id=org_id, + token=api_key, + token_type=OrganizationAuthTokenType.api, + ) + + backend_env_path = resolve_backend_env_path() + _write_env(backend_env_path, "SKYVERN_API_KEY", api_key) + + frontend_env_path = resolve_frontend_env_path() + if frontend_env_path: + _write_env(frontend_env_path, "VITE_SKYVERN_API_KEY", api_key) + else: + LOG.warning("Frontend directory not found; skipping VITE_SKYVERN_API_KEY update") + + settings.SKYVERN_API_KEY = api_key + os.environ["SKYVERN_API_KEY"] = api_key + + LOG.info( + "Local API key regenerated", + organization_id=org_id, + fingerprint=fingerprint_token(api_key), + ) + return api_key, org_id, str(backend_env_path), str(frontend_env_path) if frontend_env_path else None diff --git a/skyvern/forge/sdk/services/org_auth_service.py b/skyvern/forge/sdk/services/org_auth_service.py index 92bcd35d..8d96d1b1 100644 --- a/skyvern/forge/sdk/services/org_auth_service.py +++ b/skyvern/forge/sdk/services/org_auth_service.py @@ -1,4 +1,5 @@ import time +from dataclasses import dataclass from typing import Annotated import structlog @@ -15,7 +16,11 @@ from skyvern.forge import app from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.db.client import AgentDB from skyvern.forge.sdk.models import TokenPayload -from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthTokenType +from skyvern.forge.sdk.schemas.organizations import ( + Organization, + OrganizationAuthToken, + OrganizationAuthTokenType, +) LOG = structlog.get_logger() @@ -24,6 +29,13 @@ CACHE_SIZE = 128 ALGORITHM = "HS256" +@dataclass +class ApiKeyValidationResult: + organization: Organization + payload: TokenPayload + token: OrganizationAuthToken + + async def get_current_org( x_api_key: Annotated[ str | None, @@ -159,11 +171,11 @@ async def _authenticate_user_helper(authorization: str) -> str: return user_id -@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL)) -async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization: - """ - Authentication is cached for one hour - """ +async def resolve_org_from_api_key( + x_api_key: str, + db: AgentDB, +) -> ApiKeyValidationResult: + """Decode and validate the API key against the database.""" try: payload = jwt.decode( x_api_key, @@ -188,7 +200,6 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization: LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload) raise HTTPException(status_code=404, detail="Organization not found") - # check if the token exists in the database api_key_db_obj = await db.validate_org_auth_token( organization_id=organization.organization_id, token_type=OrganizationAuthTokenType.api, @@ -207,9 +218,21 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization: detail="Your API key has expired. Please retrieve the latest one from https://app.skyvern.com/settings", ) + return ApiKeyValidationResult( + organization=organization, + payload=api_key_data, + token=api_key_db_obj, + ) + + +@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL)) +async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization: + """Authentication is cached for one hour.""" + validation = await resolve_org_from_api_key(x_api_key, db) + # set organization_id in skyvern context and log context context = skyvern_context.current() if context: - context.organization_id = organization.organization_id - context.organization_name = organization.organization_name - return organization + context.organization_id = validation.organization.organization_id + context.organization_name = validation.organization.organization_name + return validation.organization diff --git a/skyvern/library/skyvern.py b/skyvern/library/skyvern.py index d60910a4..f85e346d 100644 --- a/skyvern/library/skyvern.py +++ b/skyvern/library/skyvern.py @@ -20,6 +20,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Status from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus +from skyvern.forge.sdk.services.local_org_auth_token_service import SKYVERN_LOCAL_DOMAIN, SKYVERN_LOCAL_ORG from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus from skyvern.library.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT @@ -96,11 +97,11 @@ class Skyvern(AsyncSkyvern): raise ValueError("Initializing Skyvern failed: api_key must be provided") async def get_organization(self) -> Organization: - organization = await app.DATABASE.get_organization_by_domain("skyvern.local") + organization = await app.DATABASE.get_organization_by_domain(SKYVERN_LOCAL_DOMAIN) if not organization: organization = await app.DATABASE.create_organization( - organization_name="Skyvern-local", - domain="skyvern.local", + organization_name=SKYVERN_LOCAL_ORG, + domain=SKYVERN_LOCAL_DOMAIN, max_steps_per_run=10, max_retries_per_step=3, ) diff --git a/tests/unit_tests/test_internal_auth.py b/tests/unit_tests/test_internal_auth.py new file mode 100644 index 00000000..2723f4d9 --- /dev/null +++ b/tests/unit_tests/test_internal_auth.py @@ -0,0 +1,35 @@ +from starlette.requests import Request + +from skyvern.forge.sdk.routes.internal_auth import _is_local_request + + +def _make_request(host: str | None) -> Request: + scope = { + "type": "http", + "client": (host, 12345) if host else None, + "headers": [], + "method": "GET", + "path": "/", + "scheme": "http", + } + return Request(scope) + + +def test_is_local_request_returns_false_for_public_ip() -> None: + request = _make_request("8.8.8.8") # public IPv4 address + assert _is_local_request(request) is False + + +def test_is_local_request_accepts_loopback() -> None: + request = _make_request("127.0.0.1") + assert _is_local_request(request) is True + + +def test_is_local_request_accepts_private_networks() -> None: + request = _make_request("192.168.1.20") + assert _is_local_request(request) is True + + +def test_is_local_request_handles_missing_client() -> None: + request = _make_request(None) + assert _is_local_request(request) is False