feat: encrypt org auth tokens with AES (#3104)
This commit is contained in:
26
skyvern/forge/sdk/encrypt/__init__.py
Normal file
26
skyvern/forge/sdk/encrypt/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
|
||||
|
||||
|
||||
class Encryptor(BaseModel):
|
||||
def __init__(self) -> None:
|
||||
self._methods: dict[EncryptMethod, BaseEncryptor] = {}
|
||||
|
||||
def add_encrypt_method(self, encrypt_method: BaseEncryptor) -> None:
|
||||
self._methods[encrypt_method.method()] = encrypt_method
|
||||
|
||||
async def encrypt(self, plaintext: str, method: EncryptMethod) -> str:
|
||||
if method not in self._methods:
|
||||
raise ValueError(f"encrypt method not registered: {method}")
|
||||
|
||||
return await self._methods[method].encrypt(plaintext)
|
||||
|
||||
async def decrypt(self, ciphertext: str, method: EncryptMethod) -> str:
|
||||
if method not in self._methods:
|
||||
raise ValueError(f"encrypt method not registered: {method}")
|
||||
|
||||
return await self._methods[method].decrypt(ciphertext)
|
||||
|
||||
|
||||
encryptor = Encryptor()
|
||||
63
skyvern/forge/sdk/encrypt/aes.py
Normal file
63
skyvern/forge/sdk/encrypt/aes.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
|
||||
|
||||
default_iv = hashlib.md5(b"deterministic_iv_0123456789").digest()
|
||||
default_salt = hashlib.md5(b"deterministic_salt_0123456789").digest()
|
||||
|
||||
|
||||
class AES(BaseEncryptor):
|
||||
def __init__(self, *, secret_key: str, salt: str | None = None, iv: str | None = None) -> None:
|
||||
self.secret_key = hashlib.md5(secret_key.encode("utf-8")).digest()
|
||||
self.salt = hashlib.md5(salt.encode("utf-8")).digest() if salt else default_salt
|
||||
self.iv = hashlib.md5(iv.encode("utf-8")).digest() if iv else default_iv
|
||||
|
||||
def method(self) -> EncryptMethod:
|
||||
return EncryptMethod.AES
|
||||
|
||||
def _derive_key(self) -> bytes:
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=self.salt,
|
||||
iterations=100000,
|
||||
)
|
||||
return kdf.derive(self.secret_key)
|
||||
|
||||
async def encrypt(self, plaintext: str) -> str:
|
||||
try:
|
||||
key = self._derive_key()
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
|
||||
encryptor = cipher.encryptor()
|
||||
padded_plaintext = self._pad(plaintext.encode("utf-8"))
|
||||
ciphertext = encryptor.update(padded_plaintext) + encryptor.finalize()
|
||||
return base64.b64encode(ciphertext).decode("utf-8")
|
||||
except Exception as e:
|
||||
raise Exception("Failed to encrypt token") from e
|
||||
|
||||
async def decrypt(self, ciphertext: str) -> str:
|
||||
try:
|
||||
encrypted_data = base64.b64decode(ciphertext.encode("utf-8"))
|
||||
key = self._derive_key()
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
|
||||
decryptor = cipher.decryptor()
|
||||
padded_plaintext = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
plaintext = self._unpad(padded_plaintext)
|
||||
return plaintext.decode("utf-8")
|
||||
except Exception as e:
|
||||
raise Exception("Failed to decrypt token") from e
|
||||
|
||||
def _pad(self, data: bytes) -> bytes:
|
||||
block_size = 16
|
||||
padding_length = block_size - (len(data) % block_size)
|
||||
padding = bytes([padding_length] * padding_length)
|
||||
return data + padding
|
||||
|
||||
def _unpad(self, data: bytes) -> bytes:
|
||||
padding_length = data[-1]
|
||||
return data[:-padding_length]
|
||||
20
skyvern/forge/sdk/encrypt/base.py
Normal file
20
skyvern/forge/sdk/encrypt/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EncryptMethod(Enum):
|
||||
AES = "aes"
|
||||
|
||||
|
||||
class BaseEncryptor(ABC):
|
||||
@abstractmethod
|
||||
def method(self) -> EncryptMethod:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def encrypt(self, plaintext: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def decrypt(self, ciphertext: str) -> str:
|
||||
pass
|
||||
Reference in New Issue
Block a user