Azure ClientSecretCredential support (#3456)
Co-authored-by: Suchintan <suchintan@users.noreply.github.com> Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -194,6 +194,30 @@ export type CreateOnePasswordTokenResponse = {
|
|||||||
token: OnePasswordTokenApiResponse;
|
token: OnePasswordTokenApiResponse;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export interface AzureClientSecretCredential {
|
||||||
|
tenant_id: string;
|
||||||
|
client_id: string;
|
||||||
|
client_secret: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AzureOrganizationAuthToken {
|
||||||
|
id: string;
|
||||||
|
organization_id: string;
|
||||||
|
credential: AzureClientSecretCredential;
|
||||||
|
created_at: string;
|
||||||
|
modified_at: string;
|
||||||
|
token_type: string;
|
||||||
|
valid: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface CreateAzureClientSecretCredentialRequest {
|
||||||
|
credential: AzureClientSecretCredential;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AzureClientSecretCredentialResponse {
|
||||||
|
token: AzureOrganizationAuthToken;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO complete this
|
// TODO complete this
|
||||||
export const ActionTypes = {
|
export const ActionTypes = {
|
||||||
InputText: "input_text",
|
InputText: "input_text",
|
||||||
|
|||||||
@@ -0,0 +1,202 @@
|
|||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { useForm } from "react-hook-form";
|
||||||
|
import { zodResolver } from "@hookform/resolvers/zod";
|
||||||
|
import * as z from "zod";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import {
|
||||||
|
Form,
|
||||||
|
FormControl,
|
||||||
|
FormField,
|
||||||
|
FormItem,
|
||||||
|
FormLabel,
|
||||||
|
FormMessage,
|
||||||
|
} from "@/components/ui/form";
|
||||||
|
import { useAzureClientCredentialToken } from "@/hooks/useAzureClientCredentialToken";
|
||||||
|
import { EyeOpenIcon, EyeClosedIcon } from "@radix-ui/react-icons";
|
||||||
|
|
||||||
|
const AzureClientSecretCredentialSchema = z
|
||||||
|
.object({
|
||||||
|
tenant_id: z.string().min(1, "tenant_id is required"),
|
||||||
|
client_id: z.string().min(1, "client_id is required"),
|
||||||
|
client_secret: z.string().min(1, "client_secret is required"),
|
||||||
|
})
|
||||||
|
.strict();
|
||||||
|
|
||||||
|
const formSchema = z
|
||||||
|
.object({
|
||||||
|
credential: AzureClientSecretCredentialSchema,
|
||||||
|
})
|
||||||
|
.strict();
|
||||||
|
|
||||||
|
type FormData = z.infer<typeof formSchema>;
|
||||||
|
|
||||||
|
export function AzureClientSecretCredentialTokenForm() {
|
||||||
|
const [showClientSecret, setShowClientSecret] = useState(false);
|
||||||
|
const {
|
||||||
|
azureOrganizationAuthToken,
|
||||||
|
isLoading,
|
||||||
|
createOrUpdateToken,
|
||||||
|
isUpdating,
|
||||||
|
} = useAzureClientCredentialToken();
|
||||||
|
|
||||||
|
const form = useForm<FormData>({
|
||||||
|
resolver: zodResolver(formSchema),
|
||||||
|
defaultValues: {
|
||||||
|
credential: azureOrganizationAuthToken?.credential || {
|
||||||
|
tenant_id: "",
|
||||||
|
client_id: "",
|
||||||
|
client_secret: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const onSubmit = (data: FormData) => {
|
||||||
|
createOrUpdateToken(data);
|
||||||
|
};
|
||||||
|
|
||||||
|
const toggleClientSecretVisibility = () => {
|
||||||
|
setShowClientSecret((v) => !v);
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (azureOrganizationAuthToken?.credential) {
|
||||||
|
form.reset({ credential: azureOrganizationAuthToken.credential });
|
||||||
|
}
|
||||||
|
}, [azureOrganizationAuthToken, form]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<h3 className="text-lg font-medium">
|
||||||
|
Azure Client Secret Credential
|
||||||
|
</h3>
|
||||||
|
<p className="text-sm text-muted-foreground">
|
||||||
|
Configure your Azure Client Secret Credential to give access to your
|
||||||
|
Azure account.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
{azureOrganizationAuthToken && (
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="text-sm text-muted-foreground">Status:</span>
|
||||||
|
<span
|
||||||
|
className={`text-sm ${azureOrganizationAuthToken.valid ? "text-green-600" : "text-red-600"}`}
|
||||||
|
>
|
||||||
|
{azureOrganizationAuthToken.valid ? "Active" : "Inactive"}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Form {...form}>
|
||||||
|
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="credential.tenant_id"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem>
|
||||||
|
<FormLabel>Tenant ID</FormLabel>
|
||||||
|
<div className="relative">
|
||||||
|
<FormControl>
|
||||||
|
<Input
|
||||||
|
{...field}
|
||||||
|
type="text"
|
||||||
|
placeholder="tenant_id"
|
||||||
|
disabled={isLoading || isUpdating}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</div>
|
||||||
|
<FormMessage />
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="credential.client_id"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem>
|
||||||
|
<FormLabel>Client ID</FormLabel>
|
||||||
|
<div className="relative">
|
||||||
|
<FormControl>
|
||||||
|
<Input
|
||||||
|
{...field}
|
||||||
|
type="text"
|
||||||
|
placeholder="client_id"
|
||||||
|
disabled={isLoading || isUpdating}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</div>
|
||||||
|
<FormMessage />
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="credential.client_secret"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem>
|
||||||
|
<FormLabel>Client Secret</FormLabel>
|
||||||
|
<div className="relative">
|
||||||
|
<FormControl>
|
||||||
|
<Input
|
||||||
|
{...field}
|
||||||
|
type={showClientSecret ? "text" : "password"}
|
||||||
|
placeholder="client_secret"
|
||||||
|
disabled={isLoading || isUpdating}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
className="absolute right-0 top-0 h-full px-3 py-2 hover:bg-transparent"
|
||||||
|
onClick={toggleClientSecretVisibility}
|
||||||
|
disabled={isLoading || isUpdating}
|
||||||
|
>
|
||||||
|
{showClientSecret ? (
|
||||||
|
<EyeClosedIcon className="h-4 w-4" />
|
||||||
|
) : (
|
||||||
|
<EyeOpenIcon className="h-4 w-4" />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
<FormMessage />
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<div className="flex items-center gap-4">
|
||||||
|
<Button type="submit" disabled={isLoading || isUpdating}>
|
||||||
|
{isUpdating ? "Updating..." : "Update Credential"}
|
||||||
|
</Button>
|
||||||
|
{azureOrganizationAuthToken && (
|
||||||
|
<div className="text-sm text-muted-foreground">
|
||||||
|
Last updated:{" "}
|
||||||
|
{new Date(
|
||||||
|
azureOrganizationAuthToken.modified_at,
|
||||||
|
).toLocaleDateString()}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</Form>
|
||||||
|
|
||||||
|
{azureOrganizationAuthToken && (
|
||||||
|
<div className="rounded-md bg-muted p-4">
|
||||||
|
<h4 className="mb-2 text-sm font-medium">Credential Information</h4>
|
||||||
|
<div className="space-y-1 text-sm text-muted-foreground">
|
||||||
|
<div>ID: {azureOrganizationAuthToken.id}</div>
|
||||||
|
<div>Type: {azureOrganizationAuthToken.token_type}</div>
|
||||||
|
<div>
|
||||||
|
Created:{" "}
|
||||||
|
{new Date(
|
||||||
|
azureOrganizationAuthToken.created_at,
|
||||||
|
).toLocaleDateString()}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
66
skyvern-frontend/src/hooks/useAzureClientCredentialToken.ts
Normal file
66
skyvern-frontend/src/hooks/useAzureClientCredentialToken.ts
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
|
||||||
|
import { getClient } from "@/api/AxiosClient";
|
||||||
|
import { useCredentialGetter } from "./useCredentialGetter";
|
||||||
|
import {
|
||||||
|
AzureClientSecretCredentialResponse,
|
||||||
|
AzureOrganizationAuthToken,
|
||||||
|
CreateAzureClientSecretCredentialRequest,
|
||||||
|
} from "@/api/types";
|
||||||
|
import { useToast } from "@/components/ui/use-toast";
|
||||||
|
|
||||||
|
export function useAzureClientCredentialToken() {
|
||||||
|
const credentialGetter = useCredentialGetter();
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
const { toast } = useToast();
|
||||||
|
|
||||||
|
const { data: azureOrganizationAuthToken, isLoading } =
|
||||||
|
useQuery<AzureOrganizationAuthToken>({
|
||||||
|
queryKey: ["azureOrganizationAuthToken"],
|
||||||
|
queryFn: async () => {
|
||||||
|
const client = await getClient(credentialGetter, "sans-api-v1");
|
||||||
|
return await client
|
||||||
|
.get("/credentials/azure_credential/get")
|
||||||
|
.then((response) => response.data.token)
|
||||||
|
.catch(() => null);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const createOrUpdateTokenMutation = useMutation({
|
||||||
|
mutationFn: async (data: CreateAzureClientSecretCredentialRequest) => {
|
||||||
|
const client = await getClient(credentialGetter, "sans-api-v1");
|
||||||
|
return await client
|
||||||
|
.post("/credentials/azure_credential/create", data)
|
||||||
|
.then(
|
||||||
|
(response) => response.data as AzureClientSecretCredentialResponse,
|
||||||
|
);
|
||||||
|
},
|
||||||
|
onSuccess: () => {
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: ["azureOrganizationAuthToken"],
|
||||||
|
});
|
||||||
|
toast({
|
||||||
|
title: "Success",
|
||||||
|
description: "Azure Client Secret Credential updated successfully",
|
||||||
|
});
|
||||||
|
},
|
||||||
|
onError: (error: unknown) => {
|
||||||
|
const message =
|
||||||
|
(error as { response?: { data?: { detail?: string } } })?.response?.data
|
||||||
|
?.detail ||
|
||||||
|
(error as Error)?.message ||
|
||||||
|
"Failed to update Azure Client Secret Credential";
|
||||||
|
toast({
|
||||||
|
title: "Error",
|
||||||
|
description: message,
|
||||||
|
variant: "destructive",
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
azureOrganizationAuthToken,
|
||||||
|
isLoading,
|
||||||
|
createOrUpdateToken: createOrUpdateTokenMutation.mutate,
|
||||||
|
isUpdating: createOrUpdateTokenMutation.isPending,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -17,6 +17,7 @@ import {
|
|||||||
import { envCredential } from "@/util/env";
|
import { envCredential } from "@/util/env";
|
||||||
import { HiddenCopyableInput } from "@/components/ui/hidden-copyable-input";
|
import { HiddenCopyableInput } from "@/components/ui/hidden-copyable-input";
|
||||||
import { OnePasswordTokenForm } from "@/components/OnePasswordTokenForm";
|
import { OnePasswordTokenForm } from "@/components/OnePasswordTokenForm";
|
||||||
|
import { AzureClientSecretCredentialTokenForm } from "@/components/AzureClientSecretCredentialTokenForm";
|
||||||
|
|
||||||
function Settings() {
|
function Settings() {
|
||||||
const { environment, organization, setEnvironment, setOrganization } =
|
const { environment, organization, setEnvironment, setOrganization } =
|
||||||
@@ -87,6 +88,15 @@ function Settings() {
|
|||||||
<OnePasswordTokenForm />
|
<OnePasswordTokenForm />
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
<Card>
|
||||||
|
<CardHeader className="border-b-2">
|
||||||
|
<CardTitle className="text-lg">Azure Integration</CardTitle>
|
||||||
|
<CardDescription>Manage your Azure integration</CardDescription>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent className="p-8">
|
||||||
|
<AzureClientSecretCredentialTokenForm />
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,39 +1,24 @@
|
|||||||
import structlog
|
import structlog
|
||||||
from azure.identity.aio import DefaultAzureCredential
|
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
|
||||||
from azure.keyvault.secrets.aio import SecretClient
|
from azure.keyvault.secrets.aio import SecretClient
|
||||||
from azure.storage.blob.aio import BlobServiceClient
|
from azure.storage.blob.aio import BlobServiceClient
|
||||||
|
|
||||||
from skyvern.exceptions import AzureConfigurationError
|
from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
class AsyncAzureClient:
|
class AsyncAzureVaultClient:
|
||||||
def __init__(self, storage_account_name: str | None, storage_account_key: str | None):
|
def __init__(self, credential: ClientSecretCredential | DefaultAzureCredential):
|
||||||
self.storage_account_name = storage_account_name
|
self.credential = credential
|
||||||
self.storage_account_key = storage_account_key
|
|
||||||
|
|
||||||
if storage_account_name and storage_account_key:
|
|
||||||
self.blob_service_client = BlobServiceClient(
|
|
||||||
account_url=f"https://{storage_account_name}.blob.core.windows.net",
|
|
||||||
credential=storage_account_key,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.blob_service_client = None
|
|
||||||
|
|
||||||
self.credential = DefaultAzureCredential()
|
|
||||||
|
|
||||||
async def get_secret(self, secret_name: str, vault_name: str | None = None) -> str | None:
|
|
||||||
vault_subdomain = vault_name or self.storage_account_name
|
|
||||||
if not vault_subdomain:
|
|
||||||
raise AzureConfigurationError("Missing vault")
|
|
||||||
|
|
||||||
|
async def get_secret(self, secret_name: str, vault_name: str) -> str | None:
|
||||||
try:
|
try:
|
||||||
# Azure Key Vault URL format: https://<your-key-vault-name>.vault.azure.net
|
# Azure Key Vault URL format: https://<your-key-vault-name>.vault.azure.net
|
||||||
# Assuming the secret_name is actually the Key Vault URL and the secret name
|
# Assuming the secret_name is actually the Key Vault URL and the secret name
|
||||||
# This needs to be clarified or passed as separate parameters
|
# This needs to be clarified or passed as separate parameters
|
||||||
# For now, let's assume secret_name is the actual secret name and Key Vault URL is in settings.
|
# For now, let's assume secret_name is the actual secret name and Key Vault URL is in settings.
|
||||||
key_vault_url = f"https://{vault_subdomain}.vault.azure.net" # Placeholder, adjust as needed
|
key_vault_url = f"https://{vault_name}.vault.azure.net" # Placeholder, adjust as needed
|
||||||
secret_client = SecretClient(vault_url=key_vault_url, credential=self.credential)
|
secret_client = SecretClient(vault_url=key_vault_url, credential=self.credential)
|
||||||
secret = await secret_client.get_secret(secret_name)
|
secret = await secret_client.get_secret(secret_name)
|
||||||
return secret.value
|
return secret.value
|
||||||
@@ -43,10 +28,34 @@ class AsyncAzureClient:
|
|||||||
finally:
|
finally:
|
||||||
await self.credential.close()
|
await self.credential.close()
|
||||||
|
|
||||||
async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None:
|
async def close(self) -> None:
|
||||||
if not self.blob_service_client:
|
await self.credential.close()
|
||||||
raise AzureConfigurationError("Storage is not configured")
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_default(cls) -> "AsyncAzureVaultClient":
|
||||||
|
return cls(DefaultAzureCredential())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_client_secret(
|
||||||
|
cls,
|
||||||
|
credential: AzureClientSecretCredential,
|
||||||
|
) -> "AsyncAzureVaultClient":
|
||||||
|
cred = ClientSecretCredential(
|
||||||
|
tenant_id=credential.tenant_id,
|
||||||
|
client_id=credential.client_id,
|
||||||
|
client_secret=credential.client_secret,
|
||||||
|
)
|
||||||
|
return cls(cred)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncAzureStorageClient:
|
||||||
|
def __init__(self, storage_account_name: str, storage_account_key: str):
|
||||||
|
self.blob_service_client = BlobServiceClient(
|
||||||
|
account_url=f"https://{storage_account_name}.blob.core.windows.net",
|
||||||
|
credential=storage_account_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None:
|
||||||
try:
|
try:
|
||||||
container_client = self.blob_service_client.get_container_client(container_name)
|
container_client = self.blob_service_client.get_container_client(container_name)
|
||||||
# Create the container if it doesn't exist
|
# Create the container if it doesn't exist
|
||||||
@@ -69,4 +78,3 @@ class AsyncAzureClient:
|
|||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
await self.blob_service_client.close()
|
await self.blob_service_client.close()
|
||||||
await self.credential.close()
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, List, Sequence
|
from typing import Any, List, Literal, Sequence, overload
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from sqlalchemy import and_, asc, delete, distinct, func, or_, pool, select, tuple_, update
|
from sqlalchemy import and_, asc, delete, distinct, func, or_, pool, select, tuple_, update
|
||||||
@@ -79,7 +79,12 @@ from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
|||||||
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
|
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
|
||||||
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession
|
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession
|
||||||
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
|
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
|
||||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
from skyvern.forge.sdk.schemas.organizations import (
|
||||||
|
AzureClientSecretCredential,
|
||||||
|
AzureOrganizationAuthToken,
|
||||||
|
Organization,
|
||||||
|
OrganizationAuthToken,
|
||||||
|
)
|
||||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
|
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
|
||||||
from skyvern.forge.sdk.schemas.runs import Run
|
from skyvern.forge.sdk.schemas.runs import Run
|
||||||
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
|
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
|
||||||
@@ -865,11 +870,25 @@ class AgentDB:
|
|||||||
await session.refresh(organization)
|
await session.refresh(organization)
|
||||||
return Organization.model_validate(organization)
|
return Organization.model_validate(organization)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def get_valid_org_auth_token(
|
||||||
|
self,
|
||||||
|
organization_id: str,
|
||||||
|
token_type: Literal[OrganizationAuthTokenType.api, OrganizationAuthTokenType.onepassword_service_account],
|
||||||
|
) -> OrganizationAuthToken | None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def get_valid_org_auth_token(
|
||||||
|
self,
|
||||||
|
organization_id: str,
|
||||||
|
token_type: Literal[OrganizationAuthTokenType.azure_client_secret_credential],
|
||||||
|
) -> AzureOrganizationAuthToken | None: ...
|
||||||
|
|
||||||
async def get_valid_org_auth_token(
|
async def get_valid_org_auth_token(
|
||||||
self,
|
self,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
token_type: OrganizationAuthTokenType,
|
token_type: OrganizationAuthTokenType,
|
||||||
) -> OrganizationAuthToken | None:
|
) -> OrganizationAuthToken | AzureOrganizationAuthToken | None:
|
||||||
try:
|
try:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
if token := (
|
if token := (
|
||||||
@@ -881,7 +900,7 @@ class AgentDB:
|
|||||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||||
)
|
)
|
||||||
).first():
|
).first():
|
||||||
return await convert_to_organization_auth_token(token)
|
return await convert_to_organization_auth_token(token, token_type)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
@@ -907,7 +926,7 @@ class AgentDB:
|
|||||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
return [await convert_to_organization_auth_token(token) for token in tokens]
|
return [await convert_to_organization_auth_token(token, token_type) for token in tokens]
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
LOG.error("SQLAlchemyError", exc_info=True)
|
LOG.error("SQLAlchemyError", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -941,7 +960,7 @@ class AgentDB:
|
|||||||
if valid is not None:
|
if valid is not None:
|
||||||
query = query.filter_by(valid=valid)
|
query = query.filter_by(valid=valid)
|
||||||
if token_obj := (await session.scalars(query)).first():
|
if token_obj := (await session.scalars(query)).first():
|
||||||
return await convert_to_organization_auth_token(token_obj)
|
return await convert_to_organization_auth_token(token_obj, token_type)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
@@ -955,14 +974,22 @@ class AgentDB:
|
|||||||
self,
|
self,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
token_type: OrganizationAuthTokenType,
|
token_type: OrganizationAuthTokenType,
|
||||||
token: str,
|
token: str | AzureClientSecretCredential,
|
||||||
encrypted_method: EncryptMethod | None = None,
|
encrypted_method: EncryptMethod | None = None,
|
||||||
) -> OrganizationAuthToken:
|
) -> OrganizationAuthToken:
|
||||||
|
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
|
||||||
|
if not isinstance(token, AzureClientSecretCredential):
|
||||||
|
raise TypeError("Expected AzureClientSecretCredential for this token_type")
|
||||||
|
plaintext_token = token.model_dump_json()
|
||||||
|
else:
|
||||||
|
if not isinstance(token, str):
|
||||||
|
raise TypeError("Expected str token for this token_type")
|
||||||
plaintext_token = token
|
plaintext_token = token
|
||||||
|
|
||||||
encrypted_token = ""
|
encrypted_token = ""
|
||||||
|
|
||||||
if encrypted_method is not None:
|
if encrypted_method is not None:
|
||||||
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
|
||||||
plaintext_token = ""
|
plaintext_token = ""
|
||||||
|
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
@@ -977,7 +1004,7 @@ class AgentDB:
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(auth_token)
|
await session.refresh(auth_token)
|
||||||
|
|
||||||
return await convert_to_organization_auth_token(auth_token)
|
return await convert_to_organization_auth_token(auth_token, token_type)
|
||||||
|
|
||||||
async def invalidate_org_auth_tokens(
|
async def invalidate_org_auth_tokens(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from enum import StrEnum
|
|||||||
class OrganizationAuthTokenType(StrEnum):
|
class OrganizationAuthTokenType(StrEnum):
|
||||||
api = "api"
|
api = "api"
|
||||||
onepassword_service_account = "onepassword_service_account"
|
onepassword_service_account = "onepassword_service_account"
|
||||||
|
azure_client_secret_credential = "azure_client_secret_credential"
|
||||||
|
|
||||||
|
|
||||||
class TaskType(StrEnum):
|
class TaskType(StrEnum):
|
||||||
|
|||||||
@@ -30,7 +30,12 @@ from skyvern.forge.sdk.db.models import (
|
|||||||
from skyvern.forge.sdk.encrypt import encryptor
|
from skyvern.forge.sdk.encrypt import encryptor
|
||||||
from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
||||||
from skyvern.forge.sdk.models import Step, StepStatus
|
from skyvern.forge.sdk.models import Step, StepStatus
|
||||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
from skyvern.forge.sdk.schemas.organizations import (
|
||||||
|
AzureClientSecretCredential,
|
||||||
|
AzureOrganizationAuthToken,
|
||||||
|
Organization,
|
||||||
|
OrganizationAuthToken,
|
||||||
|
)
|
||||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
||||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||||
@@ -195,12 +200,24 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
|
|||||||
|
|
||||||
|
|
||||||
async def convert_to_organization_auth_token(
|
async def convert_to_organization_auth_token(
|
||||||
org_auth_token: OrganizationAuthTokenModel,
|
org_auth_token: OrganizationAuthTokenModel, token_type: OrganizationAuthTokenType
|
||||||
) -> OrganizationAuthToken:
|
) -> OrganizationAuthToken | AzureOrganizationAuthToken:
|
||||||
token = org_auth_token.token
|
token = org_auth_token.token
|
||||||
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:
|
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:
|
||||||
token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method))
|
token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method))
|
||||||
|
|
||||||
|
if token_type == OrganizationAuthTokenType.azure_client_secret_credential:
|
||||||
|
credential = AzureClientSecretCredential.model_validate_json(token)
|
||||||
|
return AzureOrganizationAuthToken(
|
||||||
|
id=org_auth_token.id,
|
||||||
|
organization_id=org_auth_token.organization_id,
|
||||||
|
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
|
||||||
|
credential=credential,
|
||||||
|
valid=org_auth_token.valid,
|
||||||
|
created_at=org_auth_token.created_at,
|
||||||
|
modified_at=org_auth_token.modified_at,
|
||||||
|
)
|
||||||
|
else:
|
||||||
return OrganizationAuthToken(
|
return OrganizationAuthToken(
|
||||||
id=org_auth_token.id,
|
id=org_auth_token.id,
|
||||||
organization_id=org_auth_token.organization_id,
|
organization_id=org_auth_token.organization_id,
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ from skyvern.forge.sdk.schemas.credentials import (
|
|||||||
PasswordCredentialResponse,
|
PasswordCredentialResponse,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.schemas.organizations import (
|
from skyvern.forge.sdk.schemas.organizations import (
|
||||||
|
AzureClientSecretCredentialResponse,
|
||||||
|
CreateAzureClientSecretCredentialRequest,
|
||||||
CreateOnePasswordTokenRequest,
|
CreateOnePasswordTokenRequest,
|
||||||
CreateOnePasswordTokenResponse,
|
CreateOnePasswordTokenResponse,
|
||||||
Organization,
|
Organization,
|
||||||
@@ -478,3 +480,106 @@ async def update_onepassword_token(
|
|||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"Failed to create or update OnePassword service account token: {str(e)}",
|
detail=f"Failed to create or update OnePassword service account token: {str(e)}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.get(
|
||||||
|
"/credentials/azure_credential/get",
|
||||||
|
response_model=AzureClientSecretCredentialResponse,
|
||||||
|
summary="Get Azure Client Secret Credential",
|
||||||
|
description="Retrieves the current Azure Client Secret Credential for the organization.",
|
||||||
|
include_in_schema=False,
|
||||||
|
)
|
||||||
|
@base_router.get(
|
||||||
|
"/credentials/azure_credential/get/",
|
||||||
|
response_model=AzureClientSecretCredentialResponse,
|
||||||
|
include_in_schema=False,
|
||||||
|
)
|
||||||
|
async def get_azure_client_secret_credential(
|
||||||
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
|
) -> AzureClientSecretCredentialResponse:
|
||||||
|
"""
|
||||||
|
Get the current Azure Client Secret Credential for the organization.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||||
|
organization_id=current_org.organization_id,
|
||||||
|
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
|
||||||
|
)
|
||||||
|
if not auth_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="No Azure Client Secret Credential found for this organization",
|
||||||
|
)
|
||||||
|
|
||||||
|
return AzureClientSecretCredentialResponse(token=auth_token)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error(
|
||||||
|
"Failed to get Azure Client Secret Credential",
|
||||||
|
organization_id=current_org.organization_id,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"Failed to get Azure Client Secret Credential: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.post(
|
||||||
|
"/credentials/azure_credential/create",
|
||||||
|
response_model=AzureClientSecretCredentialResponse,
|
||||||
|
summary="Create or update Azure Client Secret Credential",
|
||||||
|
description="Creates or updates a Azure Client Secret Credential for the current organization. Only one valid record is allowed per organization.",
|
||||||
|
include_in_schema=False,
|
||||||
|
)
|
||||||
|
@base_router.post(
|
||||||
|
"/credentials/azure_credential/create/",
|
||||||
|
response_model=AzureClientSecretCredentialResponse,
|
||||||
|
include_in_schema=False,
|
||||||
|
)
|
||||||
|
async def update_azure_client_secret_credential(
|
||||||
|
request: CreateAzureClientSecretCredentialRequest,
|
||||||
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
|
) -> AzureClientSecretCredentialResponse:
|
||||||
|
"""
|
||||||
|
Create or update an Azure Client Secret Credential for the current organization.
|
||||||
|
|
||||||
|
This endpoint ensures only one valid Azure Client Secret Credential exists per organization.
|
||||||
|
If a valid token already exists, it will be invalidated before creating the new one.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Invalidate any existing valid Azure Client Secret Credential for this organization
|
||||||
|
await app.DATABASE.invalidate_org_auth_tokens(
|
||||||
|
organization_id=current_org.organization_id,
|
||||||
|
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the new Azure token
|
||||||
|
auth_token = await app.DATABASE.create_org_auth_token(
|
||||||
|
organization_id=current_org.organization_id,
|
||||||
|
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
|
||||||
|
token=request.credential,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"Created or updated Azure Client Secret Credential",
|
||||||
|
organization_id=current_org.organization_id,
|
||||||
|
token_id=auth_token.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return AzureClientSecretCredentialResponse(token=auth_token)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error(
|
||||||
|
"Failed to create or update Azure Client Secret Credential",
|
||||||
|
organization_id=current_org.organization_id,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"Failed to create or update Azure Client Secret Credential: {str(e)}",
|
||||||
|
)
|
||||||
|
|||||||
@@ -21,16 +21,31 @@ class Organization(BaseModel):
|
|||||||
modified_at: datetime
|
modified_at: datetime
|
||||||
|
|
||||||
|
|
||||||
class OrganizationAuthToken(BaseModel):
|
class OrganizationAuthTokenBase(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
organization_id: str
|
organization_id: str
|
||||||
token_type: OrganizationAuthTokenType
|
token_type: OrganizationAuthTokenType
|
||||||
token: str
|
|
||||||
valid: bool
|
valid: bool
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
modified_at: datetime
|
modified_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationAuthToken(OrganizationAuthTokenBase):
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
|
class AzureClientSecretCredential(BaseModel):
|
||||||
|
tenant_id: str
|
||||||
|
client_id: str
|
||||||
|
client_secret: str
|
||||||
|
|
||||||
|
|
||||||
|
class AzureOrganizationAuthToken(OrganizationAuthTokenBase):
|
||||||
|
"""Represents OrganizationAuthToken for Azure; defined by 3 fields: tenant_id, client_id, and client_secret"""
|
||||||
|
|
||||||
|
credential: AzureClientSecretCredential
|
||||||
|
|
||||||
|
|
||||||
class CreateOnePasswordTokenRequest(BaseModel):
|
class CreateOnePasswordTokenRequest(BaseModel):
|
||||||
"""Request model for creating or updating a 1Password service account token."""
|
"""Request model for creating or updating a 1Password service account token."""
|
||||||
|
|
||||||
@@ -50,6 +65,21 @@ class CreateOnePasswordTokenResponse(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AzureClientSecretCredentialResponse(BaseModel):
|
||||||
|
"""Response model for Azure ClientSecretCredential operations."""
|
||||||
|
|
||||||
|
token: AzureOrganizationAuthToken = Field(
|
||||||
|
...,
|
||||||
|
description="The created or updated Azure ClientSecretCredential",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateAzureClientSecretCredentialRequest(BaseModel):
|
||||||
|
"""Request model for creating or updating an Azure ClientSecretCredential."""
|
||||||
|
|
||||||
|
credential: AzureClientSecretCredential
|
||||||
|
|
||||||
|
|
||||||
class GetOrganizationsResponse(BaseModel):
|
class GetOrganizationsResponse(BaseModel):
|
||||||
organizations: list[Organization]
|
organizations: list[Organization]
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from onepassword.client import Client as OnePasswordClient
|
|||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.exceptions import (
|
from skyvern.exceptions import (
|
||||||
|
AzureConfigurationError,
|
||||||
BitwardenBaseError,
|
BitwardenBaseError,
|
||||||
CredentialParameterNotFoundError,
|
CredentialParameterNotFoundError,
|
||||||
SkyvernException,
|
SkyvernException,
|
||||||
@@ -14,7 +15,7 @@ from skyvern.exceptions import (
|
|||||||
)
|
)
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||||
from skyvern.forge.sdk.api.azure import AsyncAzureClient
|
from skyvern.forge.sdk.api.azure import AsyncAzureVaultClient
|
||||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
from skyvern.forge.sdk.schemas.credentials import PasswordCredential
|
from skyvern.forge.sdk.schemas.credentials import PasswordCredential
|
||||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||||
@@ -56,7 +57,6 @@ class WorkflowRunContext:
|
|||||||
async def init(
|
async def init(
|
||||||
cls,
|
cls,
|
||||||
aws_client: AsyncAWSClient,
|
aws_client: AsyncAWSClient,
|
||||||
azure_client: AsyncAzureClient,
|
|
||||||
organization: Organization,
|
organization: Organization,
|
||||||
workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]],
|
workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]],
|
||||||
workflow_output_parameters: list[OutputParameter],
|
workflow_output_parameters: list[OutputParameter],
|
||||||
@@ -71,7 +71,7 @@ class WorkflowRunContext:
|
|||||||
block_outputs: dict[str, Any] | None = None,
|
block_outputs: dict[str, Any] | None = None,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
# key is label name
|
# key is label name
|
||||||
workflow_run_context = cls(aws_client=aws_client, azure_client=azure_client)
|
workflow_run_context = cls(aws_client=aws_client)
|
||||||
for parameter, run_parameter in workflow_parameter_tuples:
|
for parameter, run_parameter in workflow_parameter_tuples:
|
||||||
if parameter.workflow_parameter_type == WorkflowParameterType.CREDENTIAL_ID:
|
if parameter.workflow_parameter_type == WorkflowParameterType.CREDENTIAL_ID:
|
||||||
await workflow_run_context.register_secret_workflow_parameter_value(
|
await workflow_run_context.register_secret_workflow_parameter_value(
|
||||||
@@ -109,7 +109,9 @@ class WorkflowRunContext:
|
|||||||
secret_parameter, organization
|
secret_parameter, organization
|
||||||
)
|
)
|
||||||
elif isinstance(secret_parameter, AzureVaultCredentialParameter):
|
elif isinstance(secret_parameter, AzureVaultCredentialParameter):
|
||||||
await workflow_run_context.register_azure_vault_credential_parameter_value(secret_parameter)
|
await workflow_run_context.register_azure_vault_credential_parameter_value(
|
||||||
|
secret_parameter, organization
|
||||||
|
)
|
||||||
elif isinstance(secret_parameter, BitwardenLoginCredentialParameter):
|
elif isinstance(secret_parameter, BitwardenLoginCredentialParameter):
|
||||||
await workflow_run_context.register_bitwarden_login_credential_parameter_value(
|
await workflow_run_context.register_bitwarden_login_credential_parameter_value(
|
||||||
secret_parameter, organization
|
secret_parameter, organization
|
||||||
@@ -131,13 +133,12 @@ class WorkflowRunContext:
|
|||||||
|
|
||||||
return workflow_run_context
|
return workflow_run_context
|
||||||
|
|
||||||
def __init__(self, aws_client: AsyncAWSClient, azure_client: AsyncAzureClient) -> None:
|
def __init__(self, aws_client: AsyncAWSClient) -> None:
|
||||||
self.blocks_metadata: dict[str, BlockMetadata] = {}
|
self.blocks_metadata: dict[str, BlockMetadata] = {}
|
||||||
self.parameters: dict[str, PARAMETER_TYPE] = {}
|
self.parameters: dict[str, PARAMETER_TYPE] = {}
|
||||||
self.values: dict[str, Any] = {}
|
self.values: dict[str, Any] = {}
|
||||||
self.secrets: dict[str, Any] = {}
|
self.secrets: dict[str, Any] = {}
|
||||||
self._aws_client = aws_client
|
self._aws_client = aws_client
|
||||||
self._azure_client = azure_client
|
|
||||||
|
|
||||||
def get_parameter(self, key: str) -> Parameter:
|
def get_parameter(self, key: str) -> Parameter:
|
||||||
return self.parameters[key]
|
return self.parameters[key]
|
||||||
@@ -343,10 +344,16 @@ class WorkflowRunContext:
|
|||||||
self,
|
self,
|
||||||
parameter: AzureSecretParameter,
|
parameter: AzureSecretParameter,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
vault_name = settings.AZURE_STORAGE_ACCOUNT_NAME
|
||||||
|
if vault_name is None:
|
||||||
|
LOG.error("AZURE_STORAGE_ACCOUNT_NAME is not configured, cannot register Azure secret parameter value")
|
||||||
|
raise AzureConfigurationError("AZURE_STORAGE_ACCOUNT_NAME is not configured")
|
||||||
|
|
||||||
# If the parameter is an Azure secret, fetch the secret value and store it in the secrets dict
|
# If the parameter is an Azure secret, fetch the secret value and store it in the secrets dict
|
||||||
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
|
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
|
||||||
# We'll replace the random secret id with the actual secret value when we need to use it.
|
# We'll replace the random secret id with the actual secret value when we need to use it.
|
||||||
secret_value = await self._azure_client.get_secret(parameter.azure_key)
|
azure_vault_client = AsyncAzureVaultClient.create_default()
|
||||||
|
secret_value = await azure_vault_client.get_secret(parameter.azure_key, vault_name)
|
||||||
if secret_value is not None:
|
if secret_value is not None:
|
||||||
random_secret_id = self.generate_random_secret_id()
|
random_secret_id = self.generate_random_secret_id()
|
||||||
self.secrets[random_secret_id] = secret_value
|
self.secrets[random_secret_id] = secret_value
|
||||||
@@ -491,7 +498,11 @@ class WorkflowRunContext:
|
|||||||
LOG.error(f"Failed to get secret from Bitwarden. Error: {e}")
|
LOG.error(f"Failed to get secret from Bitwarden. Error: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def register_azure_vault_credential_parameter_value(self, parameter: AzureVaultCredentialParameter) -> None:
|
async def register_azure_vault_credential_parameter_value(
|
||||||
|
self,
|
||||||
|
parameter: AzureVaultCredentialParameter,
|
||||||
|
organization: Organization,
|
||||||
|
) -> None:
|
||||||
vault_name = self._resolve_parameter_value(parameter.vault_name)
|
vault_name = self._resolve_parameter_value(parameter.vault_name)
|
||||||
if not vault_name:
|
if not vault_name:
|
||||||
raise ValueError("Azure Vault Name is missing")
|
raise ValueError("Azure Vault Name is missing")
|
||||||
@@ -504,18 +515,28 @@ class WorkflowRunContext:
|
|||||||
|
|
||||||
totp_secret_key = self._resolve_parameter_value(parameter.totp_secret_key)
|
totp_secret_key = self._resolve_parameter_value(parameter.totp_secret_key)
|
||||||
|
|
||||||
secret_login = await self._azure_client.get_secret(username_key, vault_name)
|
azure_vault_client = await self._get_azure_vault_client_for_organization(organization)
|
||||||
secret_password = await self._azure_client.get_secret(password_key, vault_name)
|
|
||||||
|
secret_username = await azure_vault_client.get_secret(username_key, vault_name)
|
||||||
|
if not secret_username:
|
||||||
|
raise ValueError(f"Azure Vault username not found by key: {username_key}")
|
||||||
|
|
||||||
|
secret_password = await azure_vault_client.get_secret(password_key, vault_name)
|
||||||
|
if not secret_password:
|
||||||
|
raise ValueError(f"Azure Vault password not found by key: {password_key}")
|
||||||
|
|
||||||
if totp_secret_key:
|
if totp_secret_key:
|
||||||
totp_secret = await self._azure_client.get_secret(totp_secret_key, vault_name)
|
totp_secret = await azure_vault_client.get_secret(totp_secret_key, vault_name)
|
||||||
|
if not totp_secret:
|
||||||
|
raise ValueError(f"Azure Vault TOTP not found by key: {totp_secret_key}")
|
||||||
else:
|
else:
|
||||||
totp_secret = None
|
totp_secret = None
|
||||||
|
|
||||||
if secret_login is not None and secret_password is not None:
|
if secret_username is not None and secret_password is not None:
|
||||||
random_secret_id = self.generate_random_secret_id()
|
random_secret_id = self.generate_random_secret_id()
|
||||||
# login secret
|
# login secret
|
||||||
username_secret_id = f"{random_secret_id}_username"
|
username_secret_id = f"{random_secret_id}_username"
|
||||||
self.secrets[username_secret_id] = secret_login
|
self.secrets[username_secret_id] = secret_username
|
||||||
# password secret
|
# password secret
|
||||||
password_secret_id = f"{random_secret_id}_password"
|
password_secret_id = f"{random_secret_id}_password"
|
||||||
self.secrets[password_secret_id] = secret_password
|
self.secrets[password_secret_id] = secret_password
|
||||||
@@ -895,10 +916,21 @@ class WorkflowRunContext:
|
|||||||
else:
|
else:
|
||||||
return jinja_sandbox_env.from_string(parameter_value).render(self.values)
|
return jinja_sandbox_env.from_string(parameter_value).render(self.values)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _get_azure_vault_client_for_organization(organization: Organization) -> AsyncAzureVaultClient:
|
||||||
|
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||||
|
organization.organization_id, OrganizationAuthTokenType.azure_client_secret_credential
|
||||||
|
)
|
||||||
|
if org_auth_token:
|
||||||
|
azure_vault_client = AsyncAzureVaultClient.create_from_client_secret(org_auth_token.credential)
|
||||||
|
else:
|
||||||
|
# Use the DefaultAzureCredential if not configured on organization level
|
||||||
|
azure_vault_client = AsyncAzureVaultClient.create_default()
|
||||||
|
return azure_vault_client
|
||||||
|
|
||||||
|
|
||||||
class WorkflowContextManager:
|
class WorkflowContextManager:
|
||||||
aws_client: AsyncAWSClient
|
aws_client: AsyncAWSClient
|
||||||
azure_client: AsyncAzureClient
|
|
||||||
workflow_run_contexts: dict[str, WorkflowRunContext]
|
workflow_run_contexts: dict[str, WorkflowRunContext]
|
||||||
|
|
||||||
parameters: dict[str, PARAMETER_TYPE]
|
parameters: dict[str, PARAMETER_TYPE]
|
||||||
@@ -907,10 +939,6 @@ class WorkflowContextManager:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.aws_client = AsyncAWSClient()
|
self.aws_client = AsyncAWSClient()
|
||||||
self.azure_client = AsyncAzureClient(
|
|
||||||
storage_account_name=settings.AZURE_STORAGE_ACCOUNT_NAME,
|
|
||||||
storage_account_key=settings.AZURE_STORAGE_ACCOUNT_KEY,
|
|
||||||
)
|
|
||||||
self.workflow_run_contexts = {}
|
self.workflow_run_contexts = {}
|
||||||
|
|
||||||
def _validate_workflow_run_context(self, workflow_run_id: str) -> None:
|
def _validate_workflow_run_context(self, workflow_run_id: str) -> None:
|
||||||
@@ -935,7 +963,6 @@ class WorkflowContextManager:
|
|||||||
) -> WorkflowRunContext:
|
) -> WorkflowRunContext:
|
||||||
workflow_run_context = await WorkflowRunContext.init(
|
workflow_run_context = await WorkflowRunContext.init(
|
||||||
self.aws_client,
|
self.aws_client,
|
||||||
self.azure_client,
|
|
||||||
organization,
|
organization,
|
||||||
workflow_parameter_tuples,
|
workflow_parameter_tuples,
|
||||||
workflow_output_parameters,
|
workflow_output_parameters,
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from skyvern.constants import (
|
|||||||
MAX_UPLOAD_FILE_COUNT,
|
MAX_UPLOAD_FILE_COUNT,
|
||||||
)
|
)
|
||||||
from skyvern.exceptions import (
|
from skyvern.exceptions import (
|
||||||
|
AzureConfigurationError,
|
||||||
ContextParameterValueNotFound,
|
ContextParameterValueNotFound,
|
||||||
MissingBrowserState,
|
MissingBrowserState,
|
||||||
MissingBrowserStatePage,
|
MissingBrowserStatePage,
|
||||||
@@ -44,7 +45,7 @@ from skyvern.exceptions import (
|
|||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.prompts import prompt_engine
|
from skyvern.forge.prompts import prompt_engine
|
||||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||||
from skyvern.forge.sdk.api.azure import AsyncAzureClient
|
from skyvern.forge.sdk.api.azure import AsyncAzureStorageClient
|
||||||
from skyvern.forge.sdk.api.files import (
|
from skyvern.forge.sdk.api.files import (
|
||||||
calculate_sha256_for_file,
|
calculate_sha256_for_file,
|
||||||
create_named_temporary_file,
|
create_named_temporary_file,
|
||||||
@@ -2061,7 +2062,10 @@ class FileUploadBlock(Block):
|
|||||||
workflow_run_context.get_original_secret_value_or_none(self.azure_storage_account_key)
|
workflow_run_context.get_original_secret_value_or_none(self.azure_storage_account_key)
|
||||||
or self.azure_storage_account_key
|
or self.azure_storage_account_key
|
||||||
)
|
)
|
||||||
azure_client = AsyncAzureClient(
|
if actual_azure_storage_account_name is None or actual_azure_storage_account_key is None:
|
||||||
|
raise AzureConfigurationError("Azure Storage is not configured")
|
||||||
|
|
||||||
|
azure_client = AsyncAzureStorageClient(
|
||||||
storage_account_name=actual_azure_storage_account_name,
|
storage_account_name=actual_azure_storage_account_name,
|
||||||
storage_account_key=actual_azure_storage_account_key,
|
storage_account_key=actual_azure_storage_account_key,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user