feat: self healing skyvern api key (#3614)
Co-authored-by: Suchintan <suchintan@users.noreply.github.com> Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -1,4 +1,10 @@
|
|||||||
import { apiBaseUrl, artifactApiBaseUrl, envCredential } from "@/util/env";
|
import {
|
||||||
|
apiBaseUrl,
|
||||||
|
artifactApiBaseUrl,
|
||||||
|
getRuntimeApiKey,
|
||||||
|
persistRuntimeApiKey,
|
||||||
|
clearRuntimeApiKey,
|
||||||
|
} from "@/util/env";
|
||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
|
|
||||||
type ApiVersion = "sans-api-v1" | "v1" | "v2";
|
type ApiVersion = "sans-api-v1" | "v1" | "v2";
|
||||||
@@ -9,12 +15,15 @@ const url = new URL(apiBaseUrl);
|
|||||||
const pathname = url.pathname.replace("/api", "");
|
const pathname = url.pathname.replace("/api", "");
|
||||||
const apiSansApiV1BaseUrl = `${url.origin}${pathname}`;
|
const apiSansApiV1BaseUrl = `${url.origin}${pathname}`;
|
||||||
|
|
||||||
|
const initialApiKey = getRuntimeApiKey();
|
||||||
|
const apiKeyHeader = initialApiKey ? { "X-API-Key": initialApiKey } : {};
|
||||||
|
|
||||||
const client = axios.create({
|
const client = axios.create({
|
||||||
baseURL: apiV1BaseUrl,
|
baseURL: apiV1BaseUrl,
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"x-user-agent": "skyvern-ui",
|
"x-user-agent": "skyvern-ui",
|
||||||
"x-api-key": envCredential,
|
...apiKeyHeader,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -23,7 +32,7 @@ const v2Client = axios.create({
|
|||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"x-user-agent": "skyvern-ui",
|
"x-user-agent": "skyvern-ui",
|
||||||
"x-api-key": envCredential,
|
...apiKeyHeader,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -32,7 +41,7 @@ const clientSansApiV1 = axios.create({
|
|||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"x-user-agent": "skyvern-ui",
|
"x-user-agent": "skyvern-ui",
|
||||||
"x-api-key": envCredential,
|
...apiKeyHeader,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -55,12 +64,14 @@ export function removeAuthorizationHeader() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function setApiKeyHeader(apiKey: string) {
|
export function setApiKeyHeader(apiKey: string) {
|
||||||
|
persistRuntimeApiKey(apiKey);
|
||||||
client.defaults.headers.common["X-API-Key"] = apiKey;
|
client.defaults.headers.common["X-API-Key"] = apiKey;
|
||||||
v2Client.defaults.headers.common["X-API-Key"] = apiKey;
|
v2Client.defaults.headers.common["X-API-Key"] = apiKey;
|
||||||
clientSansApiV1.defaults.headers.common["X-API-Key"] = apiKey;
|
clientSansApiV1.defaults.headers.common["X-API-Key"] = apiKey;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function removeApiKeyHeader() {
|
export function removeApiKeyHeader() {
|
||||||
|
clearRuntimeApiKey();
|
||||||
if (client.defaults.headers.common["X-API-Key"]) {
|
if (client.defaults.headers.common["X-API-Key"]) {
|
||||||
delete client.defaults.headers.common["X-API-Key"];
|
delete client.defaults.headers.common["X-API-Key"];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,10 +16,10 @@ import { useCredentialGetter } from "@/hooks/useCredentialGetter";
|
|||||||
import { statusIsNotFinalized } from "@/routes/tasks/types";
|
import { statusIsNotFinalized } from "@/routes/tasks/types";
|
||||||
import { useClientIdStore } from "@/store/useClientIdStore";
|
import { useClientIdStore } from "@/store/useClientIdStore";
|
||||||
import {
|
import {
|
||||||
envCredential,
|
|
||||||
environment,
|
environment,
|
||||||
wssBaseUrl,
|
wssBaseUrl,
|
||||||
newWssBaseUrl,
|
newWssBaseUrl,
|
||||||
|
getRuntimeApiKey,
|
||||||
} from "@/util/env";
|
} from "@/util/env";
|
||||||
import { cn } from "@/util/utils";
|
import { cn } from "@/util/utils";
|
||||||
|
|
||||||
@@ -140,22 +140,18 @@ function BrowserStream({
|
|||||||
|
|
||||||
const getWebSocketParams = useCallback(async () => {
|
const getWebSocketParams = useCallback(async () => {
|
||||||
const clientIdQueryParam = `client_id=${clientId}`;
|
const clientIdQueryParam = `client_id=${clientId}`;
|
||||||
let credentialQueryParam = "";
|
const runtimeApiKey = getRuntimeApiKey();
|
||||||
|
|
||||||
if (environment === "local") {
|
let credentialQueryParam = runtimeApiKey ? `apikey=${runtimeApiKey}` : "";
|
||||||
credentialQueryParam = `apikey=${envCredential}`;
|
|
||||||
} else {
|
if (environment !== "local" && credentialGetter) {
|
||||||
if (credentialGetter) {
|
const token = await credentialGetter();
|
||||||
const token = await credentialGetter();
|
credentialQueryParam = token ? `token=Bearer ${token}` : "";
|
||||||
credentialQueryParam = `token=Bearer ${token}`;
|
|
||||||
} else {
|
|
||||||
credentialQueryParam = `apikey=${envCredential}`;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const params = [credentialQueryParam, clientIdQueryParam].join("&");
|
return credentialQueryParam
|
||||||
|
? `${credentialQueryParam}&${clientIdQueryParam}`
|
||||||
return `${params}`;
|
: clientIdQueryParam;
|
||||||
}, [clientId, credentialGetter]);
|
}, [clientId, credentialGetter]);
|
||||||
|
|
||||||
// browser is ready
|
// browser is ready
|
||||||
|
|||||||
202
skyvern-frontend/src/components/SelfHealApiKeyBanner.tsx
Normal file
202
skyvern-frontend/src/components/SelfHealApiKeyBanner.tsx
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
import { useState } from "react";
|
||||||
|
|
||||||
|
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { useToast } from "@/components/ui/use-toast";
|
||||||
|
import { getClient, setApiKeyHeader } from "@/api/AxiosClient";
|
||||||
|
import {
|
||||||
|
AuthStatusValue,
|
||||||
|
useAuthDiagnostics,
|
||||||
|
} from "@/hooks/useAuthDiagnostics";
|
||||||
|
|
||||||
|
type BannerStatus = Exclude<AuthStatusValue, "ok"> | "error";
|
||||||
|
|
||||||
|
function getCopy(status: BannerStatus): { title: string; description: string } {
|
||||||
|
switch (status) {
|
||||||
|
case "missing_env":
|
||||||
|
return {
|
||||||
|
title: "Skyvern API key missing",
|
||||||
|
description:
|
||||||
|
"All requests from the UI to the local backend will fail until a valid key is configured.",
|
||||||
|
};
|
||||||
|
case "invalid_format":
|
||||||
|
return {
|
||||||
|
title: "Skyvern API key is invalid",
|
||||||
|
description:
|
||||||
|
"The configured key cannot be decoded. Regenerate a new key to continue using the UI.",
|
||||||
|
};
|
||||||
|
case "invalid":
|
||||||
|
return {
|
||||||
|
title: "Skyvern API key not recognized",
|
||||||
|
description:
|
||||||
|
"The backend rejected the configured key. Regenerate it to refresh local auth.",
|
||||||
|
};
|
||||||
|
case "expired":
|
||||||
|
return {
|
||||||
|
title: "Skyvern API key expired",
|
||||||
|
description:
|
||||||
|
"The current key is no longer valid. Generate a fresh key to restore connectivity.",
|
||||||
|
};
|
||||||
|
case "not_found":
|
||||||
|
return {
|
||||||
|
title: "Local organization missing",
|
||||||
|
description:
|
||||||
|
"The backend could not find the Skyvern-local organization. Regenerate the key to recreate it.",
|
||||||
|
};
|
||||||
|
case "error":
|
||||||
|
default:
|
||||||
|
return {
|
||||||
|
title: "Unable to verify Skyvern API key",
|
||||||
|
description:
|
||||||
|
"The UI could not reach the diagnostics endpoint. Ensure the backend is running locally.",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function SelfHealApiKeyBanner() {
|
||||||
|
const diagnosticsQuery = useAuthDiagnostics();
|
||||||
|
const { toast } = useToast();
|
||||||
|
const [isRepairing, setIsRepairing] = useState(false);
|
||||||
|
const [errorMessage, setErrorMessage] = useState<string | null>(null);
|
||||||
|
const isProductionBuild = !import.meta.env.DEV;
|
||||||
|
|
||||||
|
const { data, error, isLoading, refetch } = diagnosticsQuery;
|
||||||
|
|
||||||
|
const rawStatus = data?.status;
|
||||||
|
const bannerStatus: BannerStatus | null = error
|
||||||
|
? "error"
|
||||||
|
: rawStatus && rawStatus !== "ok"
|
||||||
|
? rawStatus
|
||||||
|
: null;
|
||||||
|
|
||||||
|
if (!bannerStatus && !errorMessage) {
|
||||||
|
if (isLoading) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const copy = getCopy(bannerStatus ?? "missing_env");
|
||||||
|
const queryErrorMessage = error?.message ?? null;
|
||||||
|
|
||||||
|
const handleRepair = async () => {
|
||||||
|
setIsRepairing(true);
|
||||||
|
setErrorMessage(null);
|
||||||
|
try {
|
||||||
|
const client = await getClient(null);
|
||||||
|
const response = await client.post<{
|
||||||
|
fingerprint?: string;
|
||||||
|
api_key?: string;
|
||||||
|
backend_env_path?: string;
|
||||||
|
frontend_env_path?: string;
|
||||||
|
}>("/internal/auth/repair");
|
||||||
|
|
||||||
|
const {
|
||||||
|
fingerprint,
|
||||||
|
api_key: apiKey,
|
||||||
|
backend_env_path: backendEnvPath,
|
||||||
|
frontend_env_path: frontendEnvPath,
|
||||||
|
} = response.data;
|
||||||
|
|
||||||
|
if (!apiKey) {
|
||||||
|
throw new Error("Repair succeeded but no API key was returned.");
|
||||||
|
}
|
||||||
|
|
||||||
|
setApiKeyHeader(apiKey);
|
||||||
|
|
||||||
|
const fingerprintSuffix = fingerprint
|
||||||
|
? ` (fingerprint ${fingerprint})`
|
||||||
|
: "";
|
||||||
|
|
||||||
|
const pathsElements = [];
|
||||||
|
if (backendEnvPath) {
|
||||||
|
pathsElements.push(<div key="backend">Backend: {backendEnvPath}</div>);
|
||||||
|
}
|
||||||
|
if (frontendEnvPath) {
|
||||||
|
pathsElements.push(
|
||||||
|
<div key="frontend">Frontend: {frontendEnvPath}</div>,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
toast({
|
||||||
|
title: "API key regenerated",
|
||||||
|
description: (
|
||||||
|
<div>
|
||||||
|
<div>
|
||||||
|
Requests now use the updated key automatically{fingerprintSuffix}{" "}
|
||||||
|
persisted to sessionStorage and written to the following .env
|
||||||
|
paths:
|
||||||
|
</div>
|
||||||
|
{pathsElements.length > 0 && (
|
||||||
|
<div className="mt-2 space-y-2">{pathsElements}</div>
|
||||||
|
)}
|
||||||
|
{isProductionBuild && (
|
||||||
|
<div className="mt-3">
|
||||||
|
Restart the UI server for more robust API key persistence.
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
),
|
||||||
|
});
|
||||||
|
|
||||||
|
await refetch({ throwOnError: false });
|
||||||
|
} catch (fetchError) {
|
||||||
|
const message =
|
||||||
|
fetchError instanceof Error
|
||||||
|
? fetchError.message
|
||||||
|
: "Unable to repair API key";
|
||||||
|
setErrorMessage(message);
|
||||||
|
} finally {
|
||||||
|
setIsRepairing(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="px-4 pt-4">
|
||||||
|
<Alert className="flex flex-col items-center gap-2 border-slate-700 bg-slate-900 text-slate-50">
|
||||||
|
<AlertTitle className="text-center text-base font-semibold tracking-wide">
|
||||||
|
{copy.title}
|
||||||
|
</AlertTitle>
|
||||||
|
<AlertDescription className="space-y-3 text-center text-sm leading-6">
|
||||||
|
{bannerStatus !== "error" ? (
|
||||||
|
<>
|
||||||
|
<p>
|
||||||
|
{copy.description} Update <code>VITE_SKYVERN_API_KEY</code> in{" "}
|
||||||
|
<code className="mx-1">skyvern-frontend/.env</code>
|
||||||
|
by running <code>skyvern init</code> or click the button below
|
||||||
|
to regenerate it automatically.
|
||||||
|
</p>
|
||||||
|
{isProductionBuild && (
|
||||||
|
<p className="text-yellow-300">
|
||||||
|
When running a production build, the regenerated API key is
|
||||||
|
stored in sessionStorage. Closing this tab or browser window
|
||||||
|
will lose the key. Restart the UI server for more robust
|
||||||
|
persistence.
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
<div className="flex justify-center">
|
||||||
|
<Button
|
||||||
|
onClick={handleRepair}
|
||||||
|
disabled={isRepairing}
|
||||||
|
variant="secondary"
|
||||||
|
>
|
||||||
|
{isRepairing ? "Regenerating…" : "Regenerate API key"}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<p>{copy.description}</p>
|
||||||
|
)}
|
||||||
|
{errorMessage ? (
|
||||||
|
<p className="text-xs text-rose-200">{errorMessage}</p>
|
||||||
|
) : null}
|
||||||
|
{queryErrorMessage && !errorMessage ? (
|
||||||
|
<p className="text-xs text-rose-200">{queryErrorMessage}</p>
|
||||||
|
) : null}
|
||||||
|
</AlertDescription>
|
||||||
|
</Alert>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export { SelfHealApiKeyBanner };
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
import { useQuery } from "@tanstack/react-query";
|
import { useQuery } from "@tanstack/react-query";
|
||||||
import { useCredentialGetter } from "./useCredentialGetter";
|
import { useCredentialGetter } from "./useCredentialGetter";
|
||||||
import { getClient } from "@/api/AxiosClient";
|
import { getClient } from "@/api/AxiosClient";
|
||||||
import { envCredential } from "@/util/env";
|
import { getRuntimeApiKey } from "@/util/env";
|
||||||
import { ApiKeyApiResponse, OrganizationApiResponse } from "@/api/types";
|
import { ApiKeyApiResponse, OrganizationApiResponse } from "@/api/types";
|
||||||
|
|
||||||
function useApiCredential() {
|
function useApiCredential() {
|
||||||
const credentialGetter = useCredentialGetter();
|
const credentialGetter = useCredentialGetter();
|
||||||
const credentialsFromEnv = envCredential;
|
const credentialsFromEnv = getRuntimeApiKey();
|
||||||
|
|
||||||
const { data: organizations } = useQuery<Array<OrganizationApiResponse>>({
|
const { data: organizations } = useQuery<Array<OrganizationApiResponse>>({
|
||||||
queryKey: ["organizations"],
|
queryKey: ["organizations"],
|
||||||
@@ -16,7 +16,7 @@ function useApiCredential() {
|
|||||||
.get("/organizations/")
|
.get("/organizations/")
|
||||||
.then((response) => response.data.organizations);
|
.then((response) => response.data.organizations);
|
||||||
},
|
},
|
||||||
enabled: envCredential === null,
|
enabled: credentialsFromEnv === null,
|
||||||
});
|
});
|
||||||
|
|
||||||
const organization = organizations?.[0];
|
const organization = organizations?.[0];
|
||||||
|
|||||||
46
skyvern-frontend/src/hooks/useAuthDiagnostics.ts
Normal file
46
skyvern-frontend/src/hooks/useAuthDiagnostics.ts
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import { useQuery } from "@tanstack/react-query";
|
||||||
|
import axios from "axios";
|
||||||
|
|
||||||
|
import { getClient } from "@/api/AxiosClient";
|
||||||
|
|
||||||
|
export type AuthStatusValue =
|
||||||
|
| "missing_env"
|
||||||
|
| "invalid_format"
|
||||||
|
| "invalid"
|
||||||
|
| "expired"
|
||||||
|
| "not_found"
|
||||||
|
| "ok";
|
||||||
|
|
||||||
|
export type AuthDiagnosticsResponse = {
|
||||||
|
status: AuthStatusValue;
|
||||||
|
fingerprint?: string;
|
||||||
|
organization_id?: string;
|
||||||
|
expires_at?: number;
|
||||||
|
api_key?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
async function fetchDiagnostics(): Promise<AuthDiagnosticsResponse> {
|
||||||
|
const client = await getClient(null);
|
||||||
|
try {
|
||||||
|
const response = await client.get<AuthDiagnosticsResponse>(
|
||||||
|
"/internal/auth/status",
|
||||||
|
);
|
||||||
|
return response.data;
|
||||||
|
} catch (error) {
|
||||||
|
if (axios.isAxiosError(error) && error.response?.status === 404) {
|
||||||
|
return { status: "ok" };
|
||||||
|
}
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function useAuthDiagnostics() {
|
||||||
|
return useQuery<AuthDiagnosticsResponse, Error>({
|
||||||
|
queryKey: ["internal", "auth", "status"],
|
||||||
|
queryFn: fetchDiagnostics,
|
||||||
|
retry: false,
|
||||||
|
refetchOnWindowFocus: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export { useAuthDiagnostics };
|
||||||
@@ -5,6 +5,7 @@ import { Outlet } from "react-router-dom";
|
|||||||
import { Header } from "./Header";
|
import { Header } from "./Header";
|
||||||
import { Sidebar } from "./Sidebar";
|
import { Sidebar } from "./Sidebar";
|
||||||
import { useDebugStore } from "@/store/useDebugStore";
|
import { useDebugStore } from "@/store/useDebugStore";
|
||||||
|
import { SelfHealApiKeyBanner } from "@/components/SelfHealApiKeyBanner";
|
||||||
|
|
||||||
function RootLayout() {
|
function RootLayout() {
|
||||||
const collapsed = useSidebarStore((state) => state.collapsed);
|
const collapsed = useSidebarStore((state) => state.collapsed);
|
||||||
@@ -12,15 +13,21 @@ function RootLayout() {
|
|||||||
const isEmbedded = embed === "true";
|
const isEmbedded = embed === "true";
|
||||||
const debugStore = useDebugStore();
|
const debugStore = useDebugStore();
|
||||||
|
|
||||||
|
const horizontalPadding = cn("lg:pl-64", {
|
||||||
|
"lg:pl-28": collapsed,
|
||||||
|
"lg:pl-4": isEmbedded,
|
||||||
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{!isEmbedded && <Sidebar />}
|
{!isEmbedded && <Sidebar />}
|
||||||
<div className="h-full w-full">
|
<div className="h-full w-full">
|
||||||
|
<div className={horizontalPadding}>
|
||||||
|
<SelfHealApiKeyBanner />
|
||||||
|
</div>
|
||||||
<Header />
|
<Header />
|
||||||
<main
|
<main
|
||||||
className={cn("lg:pb-4 lg:pl-64", {
|
className={cn("lg:pb-4", horizontalPadding, {
|
||||||
"lg:pl-28": collapsed,
|
|
||||||
"lg:pl-4": isEmbedded,
|
|
||||||
"lg:pb-0": debugStore.isDebugMode,
|
"lg:pb-0": debugStore.isDebugMode,
|
||||||
})}
|
})}
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import {
|
|||||||
CardHeader,
|
CardHeader,
|
||||||
CardTitle,
|
CardTitle,
|
||||||
} from "@/components/ui/card";
|
} from "@/components/ui/card";
|
||||||
import { envCredential } from "@/util/env";
|
import { getRuntimeApiKey } from "@/util/env";
|
||||||
import { HiddenCopyableInput } from "@/components/ui/hidden-copyable-input";
|
import { HiddenCopyableInput } from "@/components/ui/hidden-copyable-input";
|
||||||
import { OnePasswordTokenForm } from "@/components/OnePasswordTokenForm";
|
import { OnePasswordTokenForm } from "@/components/OnePasswordTokenForm";
|
||||||
import { AzureClientSecretCredentialTokenForm } from "@/components/AzureClientSecretCredentialTokenForm";
|
import { AzureClientSecretCredentialTokenForm } from "@/components/AzureClientSecretCredentialTokenForm";
|
||||||
@@ -22,7 +22,7 @@ import { AzureClientSecretCredentialTokenForm } from "@/components/AzureClientSe
|
|||||||
function Settings() {
|
function Settings() {
|
||||||
const { environment, organization, setEnvironment, setOrganization } =
|
const { environment, organization, setEnvironment, setOrganization } =
|
||||||
useSettingsStore();
|
useSettingsStore();
|
||||||
const apiKey = envCredential;
|
const apiKey = getRuntimeApiKey();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-8">
|
<div className="flex flex-col gap-8">
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import { toast } from "@/components/ui/use-toast";
|
|||||||
import { ZoomableImage } from "@/components/ZoomableImage";
|
import { ZoomableImage } from "@/components/ZoomableImage";
|
||||||
import { useCostCalculator } from "@/hooks/useCostCalculator";
|
import { useCostCalculator } from "@/hooks/useCostCalculator";
|
||||||
import { useCredentialGetter } from "@/hooks/useCredentialGetter";
|
import { useCredentialGetter } from "@/hooks/useCredentialGetter";
|
||||||
import { envCredential } from "@/util/env";
|
import { getRuntimeApiKey } from "@/util/env";
|
||||||
import {
|
import {
|
||||||
keepPreviousData,
|
keepPreviousData,
|
||||||
useQuery,
|
useQuery,
|
||||||
@@ -79,7 +79,8 @@ function TaskActions() {
|
|||||||
const token = await credentialGetter();
|
const token = await credentialGetter();
|
||||||
credential = `?token=Bearer ${token}`;
|
credential = `?token=Bearer ${token}`;
|
||||||
} else {
|
} else {
|
||||||
credential = `?apikey=${envCredential}`;
|
const apiKey = getRuntimeApiKey();
|
||||||
|
credential = apiKey ? `?apikey=${apiKey}` : "";
|
||||||
}
|
}
|
||||||
if (socket) {
|
if (socket) {
|
||||||
socket.close();
|
socket.close();
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import { useEffect, useState } from "react";
|
|||||||
import { statusIsNotFinalized } from "@/routes/tasks/types";
|
import { statusIsNotFinalized } from "@/routes/tasks/types";
|
||||||
import { useCredentialGetter } from "@/hooks/useCredentialGetter";
|
import { useCredentialGetter } from "@/hooks/useCredentialGetter";
|
||||||
import { useParams } from "react-router-dom";
|
import { useParams } from "react-router-dom";
|
||||||
import { envCredential } from "@/util/env";
|
import { getRuntimeApiKey } from "@/util/env";
|
||||||
import { toast } from "@/components/ui/use-toast";
|
import { toast } from "@/components/ui/use-toast";
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
|
|
||||||
@@ -45,7 +45,8 @@ function WorkflowRunStream(props?: Props) {
|
|||||||
const token = await credentialGetter();
|
const token = await credentialGetter();
|
||||||
credential = `?token=Bearer ${token}`;
|
credential = `?token=Bearer ${token}`;
|
||||||
} else {
|
} else {
|
||||||
credential = `?apikey=${envCredential}`;
|
const apiKey = getRuntimeApiKey();
|
||||||
|
credential = apiKey ? `?apikey=${apiKey}` : "";
|
||||||
}
|
}
|
||||||
if (socket) {
|
if (socket) {
|
||||||
socket.close();
|
socket.close();
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ if (!environment) {
|
|||||||
console.warn("environment environment variable was not set");
|
console.warn("environment environment variable was not set");
|
||||||
}
|
}
|
||||||
|
|
||||||
const envCredential: string | null =
|
const buildTimeApiKey: string | null =
|
||||||
import.meta.env.VITE_SKYVERN_API_KEY ?? null;
|
typeof import.meta.env.VITE_SKYVERN_API_KEY === "string"
|
||||||
|
? import.meta.env.VITE_SKYVERN_API_KEY
|
||||||
|
: null;
|
||||||
|
|
||||||
const artifactApiBaseUrl = import.meta.env.VITE_ARTIFACT_API_BASE_URL;
|
const artifactApiBaseUrl = import.meta.env.VITE_ARTIFACT_API_BASE_URL;
|
||||||
|
|
||||||
@@ -21,8 +23,11 @@ if (!artifactApiBaseUrl) {
|
|||||||
|
|
||||||
const apiPathPrefix = import.meta.env.VITE_API_PATH_PREFIX ?? "";
|
const apiPathPrefix = import.meta.env.VITE_API_PATH_PREFIX ?? "";
|
||||||
|
|
||||||
|
const API_KEY_STORAGE_KEY = "skyvern.apiKey";
|
||||||
|
|
||||||
const lsKeys = {
|
const lsKeys = {
|
||||||
browserSessionId: "skyvern.browserSessionId",
|
browserSessionId: "skyvern.browserSessionId",
|
||||||
|
apiKey: API_KEY_STORAGE_KEY,
|
||||||
};
|
};
|
||||||
|
|
||||||
const wssBaseUrl = import.meta.env.VITE_WSS_BASE_URL;
|
const wssBaseUrl = import.meta.env.VITE_WSS_BASE_URL;
|
||||||
@@ -38,13 +43,53 @@ try {
|
|||||||
newWssBaseUrl = wssBaseUrl.replace("/api", "");
|
newWssBaseUrl = wssBaseUrl.replace("/api", "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let runtimeApiKey: string | null | undefined;
|
||||||
|
|
||||||
|
function readPersistedApiKey(): string | null {
|
||||||
|
if (typeof window === "undefined") {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return window.sessionStorage.getItem(API_KEY_STORAGE_KEY);
|
||||||
|
}
|
||||||
|
|
||||||
|
function getRuntimeApiKey(): string | null {
|
||||||
|
if (runtimeApiKey !== undefined) {
|
||||||
|
return runtimeApiKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
const persisted = readPersistedApiKey();
|
||||||
|
const candidate = persisted ?? buildTimeApiKey;
|
||||||
|
|
||||||
|
// Treat YOUR_API_KEY as missing. We may inherit this from .env.example
|
||||||
|
// in some cases of misconfiguration.
|
||||||
|
runtimeApiKey = candidate === "YOUR_API_KEY" ? null : candidate;
|
||||||
|
return runtimeApiKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
function persistRuntimeApiKey(value: string): void {
|
||||||
|
runtimeApiKey = value;
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
window.sessionStorage.setItem(API_KEY_STORAGE_KEY, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function clearRuntimeApiKey(): void {
|
||||||
|
runtimeApiKey = null;
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
window.sessionStorage.removeItem(API_KEY_STORAGE_KEY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export {
|
export {
|
||||||
apiBaseUrl,
|
apiBaseUrl,
|
||||||
environment,
|
environment,
|
||||||
envCredential,
|
|
||||||
artifactApiBaseUrl,
|
artifactApiBaseUrl,
|
||||||
apiPathPrefix,
|
apiPathPrefix,
|
||||||
lsKeys,
|
lsKeys,
|
||||||
wssBaseUrl,
|
wssBaseUrl,
|
||||||
newWssBaseUrl,
|
newWssBaseUrl,
|
||||||
|
getRuntimeApiKey,
|
||||||
|
persistRuntimeApiKey,
|
||||||
|
clearRuntimeApiKey,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from skyvern.forge.request_logging import log_raw_request_middleware
|
|||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||||
|
from skyvern.forge.sdk.routes import internal_auth
|
||||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router, legacy_v2_router
|
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router, legacy_v2_router
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
@@ -68,6 +69,13 @@ def get_agent_app() -> FastAPI:
|
|||||||
app.include_router(base_router, prefix="/v1")
|
app.include_router(base_router, prefix="/v1")
|
||||||
app.include_router(legacy_base_router, prefix="/api/v1")
|
app.include_router(legacy_base_router, prefix="/api/v1")
|
||||||
app.include_router(legacy_v2_router, prefix="/api/v2")
|
app.include_router(legacy_v2_router, prefix="/api/v2")
|
||||||
|
|
||||||
|
# local dev endpoints
|
||||||
|
if settings.ENV == "local":
|
||||||
|
app.include_router(internal_auth.router, prefix="/v1")
|
||||||
|
app.include_router(internal_auth.router, prefix="/api/v1")
|
||||||
|
app.include_router(internal_auth.router, prefix="/api/v2")
|
||||||
|
|
||||||
app.openapi = custom_openapi
|
app.openapi = custom_openapi
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
|||||||
133
skyvern/forge/sdk/routes/internal_auth.py
Normal file
133
skyvern/forge/sdk/routes/internal_auth.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
import ipaddress
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import APIRouter, HTTPException, Request, status
|
||||||
|
|
||||||
|
from skyvern.config import settings
|
||||||
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.sdk.services.local_org_auth_token_service import fingerprint_token, regenerate_local_api_key
|
||||||
|
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/internal/auth", tags=["internal"])
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class AuthStatus(str, Enum):
|
||||||
|
missing_env = "missing_env"
|
||||||
|
invalid_format = "invalid_format"
|
||||||
|
invalid = "invalid"
|
||||||
|
expired = "expired"
|
||||||
|
not_found = "not_found"
|
||||||
|
ok = "ok"
|
||||||
|
|
||||||
|
|
||||||
|
class DiagnosticsResult(NamedTuple):
|
||||||
|
status: AuthStatus
|
||||||
|
detail: str | None
|
||||||
|
validation: Any | None
|
||||||
|
token: str | None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_local_request(request: Request) -> bool:
|
||||||
|
host = request.client.host if request.client else None
|
||||||
|
if not host:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
addr = ipaddress.ip_address(host)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return addr.is_loopback or addr.is_private
|
||||||
|
|
||||||
|
|
||||||
|
def _require_local_access(request: Request) -> None:
|
||||||
|
if settings.ENV != "local":
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN, "Endpoint only available in local env")
|
||||||
|
if not _is_local_request(request):
|
||||||
|
raise HTTPException(status.HTTP_403_FORBIDDEN, "Endpoint requires localhost access")
|
||||||
|
|
||||||
|
|
||||||
|
async def _evaluate_local_api_key(token: str) -> DiagnosticsResult:
|
||||||
|
token_candidate = token.strip()
|
||||||
|
if not token_candidate or token_candidate == "YOUR_API_KEY":
|
||||||
|
return DiagnosticsResult(status=AuthStatus.missing_env, detail=None, validation=None, token=None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
validation = await resolve_org_from_api_key(token_candidate, app.DATABASE)
|
||||||
|
except HTTPException as exc:
|
||||||
|
if exc.status_code == status.HTTP_404_NOT_FOUND:
|
||||||
|
return DiagnosticsResult(status=AuthStatus.not_found, detail=None, token=None, validation=None)
|
||||||
|
|
||||||
|
detail_text = exc.detail if isinstance(exc.detail, str) else None
|
||||||
|
if exc.status_code == status.HTTP_403_FORBIDDEN:
|
||||||
|
status_value = AuthStatus.invalid
|
||||||
|
if detail_text and "expired" in detail_text.lower():
|
||||||
|
status_value = AuthStatus.expired
|
||||||
|
elif detail_text and "validate" in detail_text.lower():
|
||||||
|
status_value = AuthStatus.invalid_format
|
||||||
|
return DiagnosticsResult(status=status_value, detail=detail_text, token=None, validation=None)
|
||||||
|
|
||||||
|
LOG.error("Unexpected error while diagnosing API key", status_code=exc.status_code, detail=detail_text)
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
LOG.error("Unexpected exception while diagnosing API key", exc_info=True)
|
||||||
|
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Unable to diagnose API key")
|
||||||
|
|
||||||
|
return DiagnosticsResult(status=AuthStatus.ok, detail=None, validation=validation, token=token_candidate)
|
||||||
|
|
||||||
|
|
||||||
|
def _emit_diagnostics(result: DiagnosticsResult) -> dict[str, object]:
|
||||||
|
status_value = result.status.value
|
||||||
|
|
||||||
|
if result.status is AuthStatus.ok and result.validation and result.token:
|
||||||
|
fingerprint = fingerprint_token(result.token)
|
||||||
|
LOG.info(
|
||||||
|
"Local auth diagnostics",
|
||||||
|
status=status_value,
|
||||||
|
organization_id=result.validation.organization.organization_id,
|
||||||
|
fingerprint=fingerprint,
|
||||||
|
expires_at=result.validation.payload.exp,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": status_value,
|
||||||
|
"organization_id": result.validation.organization.organization_id,
|
||||||
|
"fingerprint": fingerprint,
|
||||||
|
"expires_at": result.validation.payload.exp,
|
||||||
|
}
|
||||||
|
|
||||||
|
log_kwargs: dict[str, object] = {"status": status_value}
|
||||||
|
if result.detail:
|
||||||
|
log_kwargs["detail"] = result.detail
|
||||||
|
|
||||||
|
LOG.warning("Local auth diagnostics", **log_kwargs)
|
||||||
|
|
||||||
|
return {"status": status_value}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/repair", include_in_schema=False)
|
||||||
|
async def repair_api_key(request: Request) -> dict[str, object]:
|
||||||
|
_require_local_access(request)
|
||||||
|
|
||||||
|
token, organization_id, backend_env_path, frontend_env_path = await regenerate_local_api_key()
|
||||||
|
|
||||||
|
response: dict[str, object] = {
|
||||||
|
"status": AuthStatus.ok.value,
|
||||||
|
"organization_id": organization_id,
|
||||||
|
"fingerprint": fingerprint_token(token),
|
||||||
|
"api_key": token,
|
||||||
|
"backend_env_path": backend_env_path,
|
||||||
|
}
|
||||||
|
|
||||||
|
if frontend_env_path:
|
||||||
|
response["frontend_env_path"] = frontend_env_path
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status", include_in_schema=False)
|
||||||
|
async def auth_status(request: Request) -> dict[str, object]:
|
||||||
|
_require_local_access(request)
|
||||||
|
token_candidate = request.headers.get("x-api-key") or ""
|
||||||
|
result = await _evaluate_local_api_key(token_candidate)
|
||||||
|
return _emit_diagnostics(result)
|
||||||
83
skyvern/forge/sdk/services/local_org_auth_token_service.py
Normal file
83
skyvern/forge/sdk/services/local_org_auth_token_service.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from dotenv import set_key
|
||||||
|
|
||||||
|
from skyvern.config import settings
|
||||||
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.sdk.core import security
|
||||||
|
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthTokenType
|
||||||
|
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
||||||
|
from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
SKYVERN_LOCAL_ORG = "Skyvern-local"
|
||||||
|
SKYVERN_LOCAL_DOMAIN = "skyvern.local"
|
||||||
|
|
||||||
|
|
||||||
|
def _write_env(path: Path, key: str, value: str) -> None:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if not path.exists():
|
||||||
|
path.touch()
|
||||||
|
set_key(path, key, value)
|
||||||
|
LOG.info(".env written", path=str(path), key=key)
|
||||||
|
|
||||||
|
|
||||||
|
def fingerprint_token(value: str) -> str:
|
||||||
|
return f"{value[:6]}…{value[-4:]}" if len(value) > 12 else "[redacted -- token too short]"
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_local_org() -> Organization:
|
||||||
|
"""Ensure the local development organization exists and return it."""
|
||||||
|
organization = await app.DATABASE.get_organization_by_domain(SKYVERN_LOCAL_DOMAIN)
|
||||||
|
if organization:
|
||||||
|
return organization
|
||||||
|
|
||||||
|
return await app.DATABASE.create_organization(
|
||||||
|
organization_name=SKYVERN_LOCAL_ORG,
|
||||||
|
domain=SKYVERN_LOCAL_DOMAIN,
|
||||||
|
max_steps_per_run=10,
|
||||||
|
max_retries_per_step=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def regenerate_local_api_key() -> tuple[str, str, str, str | None]:
|
||||||
|
"""Create a fresh API key for the local organization and persist it to env files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (api_key, org_id, backend_env_path, frontend_env_path_or_none)
|
||||||
|
"""
|
||||||
|
organization = await ensure_local_org()
|
||||||
|
org_id = organization.organization_id
|
||||||
|
|
||||||
|
await app.DATABASE.invalidate_org_auth_tokens(
|
||||||
|
organization_id=org_id,
|
||||||
|
token_type=OrganizationAuthTokenType.api,
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = security.create_access_token(org_id, expires_delta=API_KEY_LIFETIME)
|
||||||
|
await app.DATABASE.create_org_auth_token(
|
||||||
|
organization_id=org_id,
|
||||||
|
token=api_key,
|
||||||
|
token_type=OrganizationAuthTokenType.api,
|
||||||
|
)
|
||||||
|
|
||||||
|
backend_env_path = resolve_backend_env_path()
|
||||||
|
_write_env(backend_env_path, "SKYVERN_API_KEY", api_key)
|
||||||
|
|
||||||
|
frontend_env_path = resolve_frontend_env_path()
|
||||||
|
if frontend_env_path:
|
||||||
|
_write_env(frontend_env_path, "VITE_SKYVERN_API_KEY", api_key)
|
||||||
|
else:
|
||||||
|
LOG.warning("Frontend directory not found; skipping VITE_SKYVERN_API_KEY update")
|
||||||
|
|
||||||
|
settings.SKYVERN_API_KEY = api_key
|
||||||
|
os.environ["SKYVERN_API_KEY"] = api_key
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"Local API key regenerated",
|
||||||
|
organization_id=org_id,
|
||||||
|
fingerprint=fingerprint_token(api_key),
|
||||||
|
)
|
||||||
|
return api_key, org_id, str(backend_env_path), str(frontend_env_path) if frontend_env_path else None
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
@@ -15,7 +16,11 @@ from skyvern.forge import app
|
|||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.db.client import AgentDB
|
from skyvern.forge.sdk.db.client import AgentDB
|
||||||
from skyvern.forge.sdk.models import TokenPayload
|
from skyvern.forge.sdk.models import TokenPayload
|
||||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthTokenType
|
from skyvern.forge.sdk.schemas.organizations import (
|
||||||
|
Organization,
|
||||||
|
OrganizationAuthToken,
|
||||||
|
OrganizationAuthTokenType,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
@@ -24,6 +29,13 @@ CACHE_SIZE = 128
|
|||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ApiKeyValidationResult:
|
||||||
|
organization: Organization
|
||||||
|
payload: TokenPayload
|
||||||
|
token: OrganizationAuthToken
|
||||||
|
|
||||||
|
|
||||||
async def get_current_org(
|
async def get_current_org(
|
||||||
x_api_key: Annotated[
|
x_api_key: Annotated[
|
||||||
str | None,
|
str | None,
|
||||||
@@ -159,11 +171,11 @@ async def _authenticate_user_helper(authorization: str) -> str:
|
|||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
|
async def resolve_org_from_api_key(
|
||||||
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
x_api_key: str,
|
||||||
"""
|
db: AgentDB,
|
||||||
Authentication is cached for one hour
|
) -> ApiKeyValidationResult:
|
||||||
"""
|
"""Decode and validate the API key against the database."""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
x_api_key,
|
x_api_key,
|
||||||
@@ -188,7 +200,6 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
|||||||
LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload)
|
LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload)
|
||||||
raise HTTPException(status_code=404, detail="Organization not found")
|
raise HTTPException(status_code=404, detail="Organization not found")
|
||||||
|
|
||||||
# check if the token exists in the database
|
|
||||||
api_key_db_obj = await db.validate_org_auth_token(
|
api_key_db_obj = await db.validate_org_auth_token(
|
||||||
organization_id=organization.organization_id,
|
organization_id=organization.organization_id,
|
||||||
token_type=OrganizationAuthTokenType.api,
|
token_type=OrganizationAuthTokenType.api,
|
||||||
@@ -207,9 +218,21 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
|||||||
detail="Your API key has expired. Please retrieve the latest one from https://app.skyvern.com/settings",
|
detail="Your API key has expired. Please retrieve the latest one from https://app.skyvern.com/settings",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return ApiKeyValidationResult(
|
||||||
|
organization=organization,
|
||||||
|
payload=api_key_data,
|
||||||
|
token=api_key_db_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
|
||||||
|
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
||||||
|
"""Authentication is cached for one hour."""
|
||||||
|
validation = await resolve_org_from_api_key(x_api_key, db)
|
||||||
|
|
||||||
# set organization_id in skyvern context and log context
|
# set organization_id in skyvern context and log context
|
||||||
context = skyvern_context.current()
|
context = skyvern_context.current()
|
||||||
if context:
|
if context:
|
||||||
context.organization_id = organization.organization_id
|
context.organization_id = validation.organization.organization_id
|
||||||
context.organization_name = organization.organization_name
|
context.organization_name = validation.organization.organization_name
|
||||||
return organization
|
return validation.organization
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
|||||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Status
|
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Status
|
||||||
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
|
||||||
|
from skyvern.forge.sdk.services.local_org_auth_token_service import SKYVERN_LOCAL_DOMAIN, SKYVERN_LOCAL_ORG
|
||||||
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||||
from skyvern.library.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
|
from skyvern.library.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
|
||||||
@@ -96,11 +97,11 @@ class Skyvern(AsyncSkyvern):
|
|||||||
raise ValueError("Initializing Skyvern failed: api_key must be provided")
|
raise ValueError("Initializing Skyvern failed: api_key must be provided")
|
||||||
|
|
||||||
async def get_organization(self) -> Organization:
|
async def get_organization(self) -> Organization:
|
||||||
organization = await app.DATABASE.get_organization_by_domain("skyvern.local")
|
organization = await app.DATABASE.get_organization_by_domain(SKYVERN_LOCAL_DOMAIN)
|
||||||
if not organization:
|
if not organization:
|
||||||
organization = await app.DATABASE.create_organization(
|
organization = await app.DATABASE.create_organization(
|
||||||
organization_name="Skyvern-local",
|
organization_name=SKYVERN_LOCAL_ORG,
|
||||||
domain="skyvern.local",
|
domain=SKYVERN_LOCAL_DOMAIN,
|
||||||
max_steps_per_run=10,
|
max_steps_per_run=10,
|
||||||
max_retries_per_step=3,
|
max_retries_per_step=3,
|
||||||
)
|
)
|
||||||
|
|||||||
35
tests/unit_tests/test_internal_auth.py
Normal file
35
tests/unit_tests/test_internal_auth.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.routes.internal_auth import _is_local_request
|
||||||
|
|
||||||
|
|
||||||
|
def _make_request(host: str | None) -> Request:
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"client": (host, 12345) if host else None,
|
||||||
|
"headers": [],
|
||||||
|
"method": "GET",
|
||||||
|
"path": "/",
|
||||||
|
"scheme": "http",
|
||||||
|
}
|
||||||
|
return Request(scope)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_local_request_returns_false_for_public_ip() -> None:
|
||||||
|
request = _make_request("8.8.8.8") # public IPv4 address
|
||||||
|
assert _is_local_request(request) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_local_request_accepts_loopback() -> None:
|
||||||
|
request = _make_request("127.0.0.1")
|
||||||
|
assert _is_local_request(request) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_local_request_accepts_private_networks() -> None:
|
||||||
|
request = _make_request("192.168.1.20")
|
||||||
|
assert _is_local_request(request) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_local_request_handles_missing_client() -> None:
|
||||||
|
request = _make_request(None)
|
||||||
|
assert _is_local_request(request) is False
|
||||||
Reference in New Issue
Block a user