Files
Dorod-Sky/skyvern/forge/sdk/encrypt/aes.py
2025-08-05 12:36:24 +08:00

64 lines
2.5 KiB
Python

import base64
import hashlib
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
default_iv = hashlib.md5(b"deterministic_iv_0123456789").digest()
default_salt = hashlib.md5(b"deterministic_salt_0123456789").digest()
class AES(BaseEncryptor):
def __init__(self, *, secret_key: str, salt: str | None = None, iv: str | None = None) -> None:
self.secret_key = hashlib.md5(secret_key.encode("utf-8")).digest()
self.salt = hashlib.md5(salt.encode("utf-8")).digest() if salt else default_salt
self.iv = hashlib.md5(iv.encode("utf-8")).digest() if iv else default_iv
def method(self) -> EncryptMethod:
return EncryptMethod.AES
def _derive_key(self) -> bytes:
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=self.salt,
iterations=100000,
)
return kdf.derive(self.secret_key)
async def encrypt(self, plaintext: str) -> str:
try:
key = self._derive_key()
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
encryptor = cipher.encryptor()
padded_plaintext = self._pad(plaintext.encode("utf-8"))
ciphertext = encryptor.update(padded_plaintext) + encryptor.finalize()
return base64.b64encode(ciphertext).decode("utf-8")
except Exception as e:
raise Exception("Failed to encrypt token") from e
async def decrypt(self, ciphertext: str) -> str:
try:
encrypted_data = base64.b64decode(ciphertext.encode("utf-8"))
key = self._derive_key()
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
decryptor = cipher.decryptor()
padded_plaintext = decryptor.update(encrypted_data) + decryptor.finalize()
plaintext = self._unpad(padded_plaintext)
return plaintext.decode("utf-8")
except Exception as e:
raise Exception("Failed to decrypt token") from e
def _pad(self, data: bytes) -> bytes:
block_size = 16
padding_length = block_size - (len(data) % block_size)
padding = bytes([padding_length] * padding_length)
return data + padding
def _unpad(self, data: bytes) -> bytes:
padding_length = data[-1]
return data[:-padding_length]