Files
2025-04-02 21:44:17 -07:00

496 lines
17 KiB
Python

import os
import json
import unittest
from binascii import unhexlify
from Cryptodome.Protocol import HPKE
from Cryptodome.Protocol.HPKE import DeserializeError
from Cryptodome.PublicKey import ECC
from Cryptodome.SelfTest.st_common import list_test_cases
from Cryptodome.Protocol import DH
from Cryptodome.Hash import SHA256, SHA384, SHA512
class HPKE_Tests(unittest.TestCase):
    """Functional tests for ``HPKE.new``/``seal``/``unseal``, plus
    known-answer tests for X25519 taken from RFC 9180, Appendix A.1."""

    # P-256 key pairs generated once and shared by the tests below.
    key1 = ECC.generate(curve='p256')
    key2 = ECC.generate(curve='p256')

    # name, size of enc
    curves = {
        'p256': 65,
        'p384': 97,
        'p521': 133,
        'curve25519': 32,
        'curve448': 56,
    }

    # Values that are identical in all four RFC 9180 A.1.x vectors used by
    # the test_x25519_mode_* tests.
    _RFC_PT = bytes.fromhex(
        "4265617574792069732074727574682c20747275746820626561757479")
    _RFC_INFO = bytes.fromhex("4f6465206f6e2061204772656369616e2055726e")
    _RFC_AAD0 = bytes.fromhex("436f756e742d30")
    _RFC_AAD1 = bytes.fromhex("436f756e742d31")

    # PSK material shared by the two PSK-based vectors (modes 1 and 3).
    _RFC_PSK_ID = bytes.fromhex("456e6e796e20447572696e206172616e204d6f726961")
    _RFC_PSK = bytes.fromhex(
        "0247fd33b913760fa1fa51e1892d9f307fbe65eb171e8132c2af18555a738b82")

    def round_trip(self, curve, aead_id):
        """Seal two messages for a fresh key on ``curve`` and unseal them.

        Also checks that the length of ``enc`` produced by the sender
        matches the size recorded in ``self.curves``.
        """
        key1 = ECC.generate(curve=curve)

        encryptor = HPKE.new(receiver_key=key1.public_key(),
                             aead_id=aead_id)
        self.assertEqual(len(encryptor.enc), self.curves[curve])

        # First message
        ct = encryptor.seal(b'ABC', auth_data=b'DEF')

        decryptor = HPKE.new(receiver_key=key1,
                             aead_id=aead_id,
                             enc=encryptor.enc)
        pt = decryptor.unseal(ct, auth_data=b'DEF')
        self.assertEqual(b'ABC', pt)

        # Second message
        ct2 = encryptor.seal(b'GHI')
        pt2 = decryptor.unseal(ct2)
        self.assertEqual(b'GHI', pt2)

    def _check_rfc_unseal(self, decryptor, ct0_hex, ct1_hex):
        """Unseal the seq 0 and seq 1 ciphertexts, in that order, and check
        that each one yields the shared RFC plaintext."""
        pt_X0 = decryptor.unseal(bytes.fromhex(ct0_hex), self._RFC_AAD0)
        self.assertEqual(pt_X0, self._RFC_PT)
        pt_X1 = decryptor.unseal(bytes.fromhex(ct1_hex), self._RFC_AAD1)
        self.assertEqual(pt_X1, self._RFC_PT)

    def test_round_trip(self):
        # One sub-test per combination, so that a failure names the
        # offending pair and does not hide the remaining ones.
        for curve in self.curves:
            for aead_id in HPKE.AEAD:
                with self.subTest(curve=curve, aead_id=aead_id):
                    self.round_trip(curve, aead_id)

    def test_psk(self):
        # A well-formed (psk_id, psk) pair must be accepted.
        aead_id = HPKE.AEAD.AES128_GCM
        HPKE.new(receiver_key=self.key1.public_key(),
                 aead_id=aead_id,
                 psk=(b'a', b'c' * 32))

    def test_info(self):
        # An 'info' byte string must be accepted.
        aead_id = HPKE.AEAD.AES128_GCM
        HPKE.new(receiver_key=self.key1.public_key(),
                 aead_id=aead_id,
                 info=b'baba')

    def test_neg_unsupported_curve(self):
        key3 = ECC.generate(curve='p224')
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=key3.public_key(),
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("Unsupported curve", str(cm.exception))

    def test_neg_too_many_private_keys(self):
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=self.key1,
                     sender_key=self.key2,
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("Exactly 1 private key", str(cm.exception))

    def test_neg_curve_mismatch(self):
        key3 = ECC.generate(curve='p384')
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=self.key1.public_key(),
                     sender_key=key3,
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("but recipient key", str(cm.exception))

    def test_neg_psk(self):
        # Empty PSK id
        with self.assertRaises(ValueError):
            HPKE.new(receiver_key=self.key1.public_key(),
                     psk=(b'', b'G' * 32),
                     aead_id=HPKE.AEAD.AES128_GCM)

        # Empty PSK
        with self.assertRaises(ValueError):
            HPKE.new(receiver_key=self.key1.public_key(),
                     psk=(b'JJJ', b''),
                     aead_id=HPKE.AEAD.AES128_GCM)

        # PSK one byte shorter than the 32 accepted by test_psk
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=self.key1.public_key(),
                     psk=(b'JJJ', b'Y' * 31),
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("at least 32", str(cm.exception))

    def test_neg_wrong_enc(self):
        # A 65-byte blob that must be rejected when deserialized as 'enc'
        wrong_enc = b'\xFF' + b'8' * 64
        with self.assertRaises(DeserializeError):
            HPKE.new(receiver_key=self.key1,
                     aead_id=HPKE.AEAD.AES128_GCM,
                     enc=wrong_enc)

        # 'enc' combined with a public receiver key
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=self.key1.public_key(),
                     enc=self.key1.public_key().export_key(format='raw'),
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("'enc' cannot be an input", str(cm.exception))

        # Private receiver key without 'enc'
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=self.key1,
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("'enc' required", str(cm.exception))

    def test_neg_unseal_wrong_ct(self):
        decryptor = HPKE.new(receiver_key=self.key1,
                             aead_id=HPKE.AEAD.CHACHA20_POLY1305,
                             enc=self.key2.public_key().export_key(format='raw'))
        with self.assertRaises(ValueError):
            decryptor.unseal(b'XYZ' * 20)

    def test_neg_unseal_no_auth_data(self):
        aead_id = HPKE.AEAD.CHACHA20_POLY1305

        encryptor = HPKE.new(receiver_key=self.key1.public_key(),
                             aead_id=aead_id)
        ct = encryptor.seal(b'ABC', auth_data=b'DEF')

        decryptor = HPKE.new(receiver_key=self.key1,
                             aead_id=aead_id,
                             enc=encryptor.enc)
        # The ciphertext was sealed with auth_data, so unsealing without it
        # must fail.
        with self.assertRaises(ValueError):
            decryptor.unseal(ct)

    def test_x25519_mode_0(self):
        # RFC 9180, A.1.1.1, seq 0 and 1
        keyR_hex = "4612c550263fc8ad58375df3f557aac531d26850903e55a9f23f21d8534e8ac8"
        keyR = DH.import_x25519_private_key(bytes.fromhex(keyR_hex))

        enc_hex = "37fda3567bdbd628e88668c3c8d7e97d1d1253b6d4ea6d44c150f741f1bf4431"
        ct0_hex = "f938558b5d72f1a23810b4be2ab4f84331acc02fc97babc53a52ae8218a355a96d8770ac83d07bea87e13c512a"
        ct1_hex = "af2d7e9ac9ae7e270f46ba1f975be53c09f8d875bdc8535458c2494e8a6eab251c03d0c22a56b8ca42c2063b84"

        decryptor = HPKE.new(receiver_key=keyR,
                             aead_id=HPKE.AEAD.AES128_GCM,
                             info=self._RFC_INFO,
                             enc=bytes.fromhex(enc_hex))
        self._check_rfc_unseal(decryptor, ct0_hex, ct1_hex)

    def test_x25519_mode_1(self):
        # RFC 9180, A.1.2.1, seq 0 and 1
        keyR_hex = "c5eb01eb457fe6c6f57577c5413b931550a162c71a03ac8d196babbd4e5ce0fd"
        keyR = DH.import_x25519_private_key(bytes.fromhex(keyR_hex))

        enc_hex = "0ad0950d9fb9588e59690b74f1237ecdf1d775cd60be2eca57af5a4b0471c91b"
        ct0_hex = "e52c6fed7f758d0cf7145689f21bc1be6ec9ea097fef4e959440012f4feb73fb611b946199e681f4cfc34db8ea"
        ct1_hex = "49f3b19b28a9ea9f43e8c71204c00d4a490ee7f61387b6719db765e948123b45b61633ef059ba22cd62437c8ba"

        decryptor = HPKE.new(receiver_key=keyR,
                             aead_id=HPKE.AEAD.AES128_GCM,
                             info=self._RFC_INFO,
                             psk=(self._RFC_PSK_ID, self._RFC_PSK),
                             enc=bytes.fromhex(enc_hex))
        self._check_rfc_unseal(decryptor, ct0_hex, ct1_hex)

    def test_x25519_mode_2(self):
        # RFC 9180, A.1.3.1, seq 0 and 1
        keyR_hex = "fdea67cf831f1ca98d8e27b1f6abeb5b7745e9d35348b80fa407ff6958f9137e"
        keyR = DH.import_x25519_private_key(bytes.fromhex(keyR_hex))

        keyS_hex = "dc4a146313cce60a278a5323d321f051c5707e9c45ba21a3479fecdf76fc69dd"
        keyS = DH.import_x25519_private_key(bytes.fromhex(keyS_hex))

        enc_hex = "23fb952571a14a25e3d678140cd0e5eb47a0961bb18afcf85896e5453c312e76"
        ct0_hex = "5fd92cc9d46dbf8943e72a07e42f363ed5f721212cd90bcfd072bfd9f44e06b80fd17824947496e21b680c141b"
        ct1_hex = "d3736bb256c19bfa93d79e8f80b7971262cb7c887e35c26370cfed62254369a1b52e3d505b79dd699f002bc8ed"

        decryptor = HPKE.new(receiver_key=keyR,
                             sender_key=keyS.public_key(),
                             aead_id=HPKE.AEAD.AES128_GCM,
                             info=self._RFC_INFO,
                             enc=bytes.fromhex(enc_hex))
        self._check_rfc_unseal(decryptor, ct0_hex, ct1_hex)

    def test_x25519_mode_3(self):
        # RFC 9180, A.1.4.1, seq 0 and 1
        keyR_hex = "cb29a95649dc5656c2d054c1aa0d3df0493155e9d5da6d7e344ed8b6a64a9423"
        keyR = DH.import_x25519_private_key(bytes.fromhex(keyR_hex))

        keyS_hex = "fc1c87d2f3832adb178b431fce2ac77c7ca2fd680f3406c77b5ecdf818b119f4"
        keyS = DH.import_x25519_private_key(bytes.fromhex(keyS_hex))

        enc_hex = "820818d3c23993492cc5623ab437a48a0a7ca3e9639c140fe1e33811eb844b7c"
        ct0_hex = "a84c64df1e11d8fd11450039d4fe64ff0c8a99fca0bd72c2d4c3e0400bc14a40f27e45e141a24001697737533e"
        ct1_hex = "4d19303b848f424fc3c3beca249b2c6de0a34083b8e909b6aa4c3688505c05ffe0c8f57a0a4c5ab9da127435d9"

        decryptor = HPKE.new(receiver_key=keyR,
                             sender_key=keyS.public_key(),
                             aead_id=HPKE.AEAD.AES128_GCM,
                             psk=(self._RFC_PSK_ID, self._RFC_PSK),
                             info=self._RFC_INFO,
                             enc=bytes.fromhex(enc_hex))
        self._check_rfc_unseal(decryptor, ct0_hex, ct1_hex)
class HPKE_TestVectors(unittest.TestCase):
    """Extended HPKE tests driven by the JSON vectors shipped in the optional
    ``pycryptodome_test_vectors`` package.

    Each test skips itself when the vectors could not be loaded.
    """

    # (kem_id, kdf_id) pairs exercised by these tests, mapped to the hash
    # module handed to ``_encap``. We support only one KDF per curve.
    _SUPPORTED_COMBI = {
        (0x10, 0x1): SHA256,
        (0x11, 0x2): SHA384,
        (0x12, 0x3): SHA512,
        (0x20, 0x1): SHA256,
        (0x21, 0x3): SHA512,
    }

    # Curve name passed to ECC.construct for each NIST-curve kem_id.
    _NIST_CURVES = {
        0x0010: 'p256',
        0x0011: 'p384',
        0x0012: 'p521',
    }

    # aead_id of the export-only pseudo-cipher; such vectors are skipped.
    _EXPORT_ONLY_AEAD = 0xffff

    def setUp(self):
        """Load the JSON vectors into ``self.vectors``.

        When the package or the file is missing, a warning is printed and
        ``self.vectors`` stays empty.
        """
        self.vectors = []
        try:
            import pycryptodome_test_vectors  # type: ignore

            init_dir = os.path.dirname(pycryptodome_test_vectors.__file__)
            full_file_name = os.path.join(init_dir, "Protocol", "HPKE-test-vectors.json")
            with open(full_file_name, "r", encoding="utf-8") as f:
                self.vectors = json.load(f)
        except (FileNotFoundError, ImportError):
            print("\nWarning: skipping extended tests for HPKE (install pycryptodome-test-vectors)")

    def import_private_key(self, key_hex, kem_id):
        """Build a private key object from its hex encoding.

        Args:
            key_hex: hex string holding the private key.
            kem_id: numeric KEM identifier selecting the curve.

        Returns:
            The key built by ``ECC.construct`` (NIST curves, big-endian
            scalar) or by the matching ``DH.import_*_private_key`` function.

        Raises:
            ValueError: if ``kem_id`` is not one of the supported KEMs.
        """
        key_bin = unhexlify(key_hex)

        curve = self._NIST_CURVES.get(kem_id)
        if curve is not None:
            return ECC.construct(curve=curve,
                                 d=int.from_bytes(key_bin, byteorder="big"))
        if kem_id == 0x0020:
            return DH.import_x25519_private_key(key_bin)
        if kem_id == 0x0021:
            return DH.import_x448_private_key(key_bin)

        # Fail loudly instead of returning None, which would only surface
        # later as an unrelated AttributeError.
        raise ValueError("Unsupported KEM ID: 0x%04x" % kem_id)

    def _supported_vectors(self):
        """Yield ``(idx, vector, hashmod)`` for every loaded vector that uses
        a real AEAD and a supported (KEM, KDF) combination."""
        for idx, vector in enumerate(self.vectors):
            # No export-only pseudo-cipher
            if vector["aead_id"] == self._EXPORT_ONLY_AEAD:
                continue

            combi = (vector["kem_id"], vector["kdf_id"])
            hashmod = self._SUPPORTED_COMBI.get(combi)
            if hashmod is None:
                continue

            yield idx, vector, hashmod

    def test_hpke_encap(self):
        """Test HPKE encapsulation using test vectors."""

        if not self.vectors:
            self.skipTest("No test vectors available")

        for idx, vector, hashmod in self._supported_vectors():
            kem_id = vector["kem_id"]
            aead_id = vector["aead_id"]

            with self.subTest(idx=idx, kem_id=kem_id, aead_id=aead_id):
                receiver_pub = self.import_private_key(vector["skRm"],
                                                       kem_id).public_key()

                # The sender key is present only in the authenticated modes
                sender_priv = None
                if "skSm" in vector:
                    sender_priv = self.import_private_key(vector["skSm"],
                                                          kem_id)

                encap_key = self.import_private_key(vector["skEm"], kem_id)

                shared_secret, enc = HPKE.HPKE_Cipher._encap(receiver_pub,
                                                             kem_id,
                                                             hashmod,
                                                             sender_priv,
                                                             encap_key)
                self.assertEqual(enc.hex(), vector["enc"])
                self.assertEqual(shared_secret,
                                 unhexlify(vector["shared_secret"]))

                # Progress marker, one per verified vector
                print(".", end="", flush=True)

    def test_hpke_unseal(self):
        """Test HPKE encryption and decryption using test vectors."""

        if not self.vectors:
            self.skipTest("No test vectors available")

        for idx, vector, _ in self._supported_vectors():
            kem_id = vector["kem_id"]
            aead_id = vector["aead_id"]

            with self.subTest(idx=idx, kem_id=kem_id, aead_id=aead_id):
                receiver_priv = self.import_private_key(vector["skRm"],
                                                        kem_id)

                # The receiver only needs the public half of the sender key
                sender_pub = None
                if "skSm" in vector:
                    sender_priv = self.import_private_key(vector["skSm"],
                                                          kem_id)
                    sender_pub = sender_priv.public_key()

                encap_key = unhexlify(vector["enc"])

                psk = None
                if "psk_id" in vector:
                    psk = unhexlify(vector["psk_id"]), unhexlify(vector["psk"])

                receiver_hpke = HPKE.new(receiver_key=receiver_priv,
                                         aead_id=HPKE.AEAD(aead_id),
                                         enc=encap_key,
                                         sender_key=sender_pub,
                                         psk=psk,
                                         info=unhexlify(vector["info"]))

                # Messages are unsealed in the order listed in the vector,
                # on the same receiver context.
                for encryption in vector['encryptions']:
                    plaintext = unhexlify(encryption["pt"])
                    ciphertext = unhexlify(encryption["ct"])
                    aad = unhexlify(encryption["aad"])

                    # Decrypt (unseal)
                    decrypted = receiver_hpke.unseal(ciphertext, aad)
                    self.assertEqual(decrypted, plaintext, "Decryption failed")

                # Progress marker, one per verified vector
                print(".", end="", flush=True)
# NOTE: no ``unittest.main()`` call here. ``unittest.main()`` exits the
# interpreter by default, so calling it at this point would keep the
# ``get_tests()``-based entry point at the end of this file from ever
# running. That entry point is the only one for script execution.
def get_tests(config=None):
    """Collect the test cases of this module.

    Args:
        config: optional mapping of test options. The vector-driven tests
            (``HPKE_TestVectors``) are included only when it has a truthy
            ``'slow_tests'`` entry.

    Returns:
        The list of test cases to run.
    """
    # Avoid a shared mutable default argument
    if config is None:
        config = {}

    tests = []
    tests += list_test_cases(HPKE_Tests)

    if config.get('slow_tests'):
        tests += list_test_cases(HPKE_TestVectors)

    return tests
if __name__ == '__main__':
    def suite():
        """Build the suite that ``unittest.main`` looks up by name."""
        collected = get_tests()
        return unittest.TestSuite(collected)

    unittest.main(defaultTest='suite')