Files
Dorod-Sky/skyvern/forge/sdk/encrypt/aes.py

64 lines
2.5 KiB
Python
Raw Normal View History

import base64
import hashlib
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
default_iv = hashlib.md5(b"deterministic_iv_0123456789").digest()
default_salt = hashlib.md5(b"deterministic_salt_0123456789").digest()
class AES(BaseEncryptor):
def __init__(self, *, secret_key: str, salt: str | None = None, iv: str | None = None) -> None:
self.secret_key = hashlib.md5(secret_key.encode("utf-8")).digest()
self.salt = hashlib.md5(salt.encode("utf-8")).digest() if salt else default_salt
self.iv = hashlib.md5(iv.encode("utf-8")).digest() if iv else default_iv
def method(self) -> EncryptMethod:
return EncryptMethod.AES
def _derive_key(self) -> bytes:
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=self.salt,
iterations=100000,
)
return kdf.derive(self.secret_key)
async def encrypt(self, plaintext: str) -> str:
try:
key = self._derive_key()
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
encryptor = cipher.encryptor()
padded_plaintext = self._pad(plaintext.encode("utf-8"))
ciphertext = encryptor.update(padded_plaintext) + encryptor.finalize()
return base64.b64encode(ciphertext).decode("utf-8")
except Exception as e:
raise Exception("Failed to encrypt token") from e
async def decrypt(self, ciphertext: str) -> str:
try:
encrypted_data = base64.b64decode(ciphertext.encode("utf-8"))
key = self._derive_key()
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
decryptor = cipher.decryptor()
padded_plaintext = decryptor.update(encrypted_data) + decryptor.finalize()
plaintext = self._unpad(padded_plaintext)
return plaintext.decode("utf-8")
except Exception as e:
raise Exception("Failed to decrypt token") from e
def _pad(self, data: bytes) -> bytes:
block_size = 16
padding_length = block_size - (len(data) % block_size)
padding = bytes([padding_length] * padding_length)
return data + padding
def _unpad(self, data: bytes) -> bytes:
padding_length = data[-1]
return data[:-padding_length]