import os import json import unittest from binascii import unhexlify from Cryptodome.Protocol import HPKE from Cryptodome.Protocol.HPKE import DeserializeError from Cryptodome.PublicKey import ECC from Cryptodome.SelfTest.st_common import list_test_cases from Cryptodome.Protocol import DH from Cryptodome.Hash import SHA256, SHA384, SHA512 class HPKE_Tests(unittest.TestCase): key1 = ECC.generate(curve='p256') key2 = ECC.generate(curve='p256') # name, size of enc curves = { 'p256': 65, 'p384': 97, 'p521': 133, 'curve25519': 32, 'curve448': 56, } def round_trip(self, curve, aead_id): key1 = ECC.generate(curve=curve) aead_id = aead_id encryptor = HPKE.new(receiver_key=key1.public_key(), aead_id=aead_id) self.assertEqual(len(encryptor.enc), self.curves[curve]) # First message ct = encryptor.seal(b'ABC', auth_data=b'DEF') decryptor = HPKE.new(receiver_key=key1, aead_id=aead_id, enc=encryptor.enc) pt = decryptor.unseal(ct, auth_data=b'DEF') self.assertEqual(b'ABC', pt) # Second message ct2 = encryptor.seal(b'GHI') pt2 = decryptor.unseal(ct2) self.assertEqual(b'GHI', pt2) def test_round_trip(self): for curve in self.curves.keys(): for aead_id in HPKE.AEAD: self.round_trip(curve, aead_id) def test_psk(self): aead_id = HPKE.AEAD.AES128_GCM HPKE.new(receiver_key=self.key1.public_key(), aead_id=aead_id, psk=(b'a', b'c' * 32)) def test_info(self): aead_id = HPKE.AEAD.AES128_GCM HPKE.new(receiver_key=self.key1.public_key(), aead_id=aead_id, info=b'baba') def test_neg_unsupported_curve(self): key3 = ECC.generate(curve='p224') with self.assertRaises(ValueError) as cm: HPKE.new(receiver_key=key3.public_key(), aead_id=HPKE.AEAD.AES128_GCM) self.assertIn("Unsupported curve", str(cm.exception)) def test_neg_too_many_private_keys(self): with self.assertRaises(ValueError) as cm: HPKE.new(receiver_key=self.key1, sender_key=self.key2, aead_id=HPKE.AEAD.AES128_GCM) self.assertIn("Exactly 1 private key", str(cm.exception)) def test_neg_curve_mismatch(self): key3 = ECC.generate(curve='p384') with self.assertRaises(ValueError) as cm: HPKE.new(receiver_key=self.key1.public_key(), sender_key=key3, aead_id=HPKE.AEAD.AES128_GCM) self.assertIn("but recipient key", str(cm.exception)) def test_neg_psk(self): with self.assertRaises(ValueError) as cm: HPKE.new(receiver_key=self.key1.public_key(), psk=(b'', b'G' * 32), aead_id=HPKE.AEAD.AES128_GCM) with self.assertRaises(ValueError) as cm: HPKE.new(receiver_key=self.key1.public_key(), psk=(b'JJJ', b''), aead_id=HPKE.AEAD.AES128_GCM) with self.assertRaises(ValueError) as cm: HPKE.new(receiver_key=self.key1.public_key(), psk=(b'JJJ', b'Y' * 31), aead_id=HPKE.AEAD.AES128_GCM) self.assertIn("at least 32", str(cm.exception)) def test_neg_wrong_enc(self): wrong_enc = b'\xFF' + b'8' * 64 with self.assertRaises(DeserializeError): HPKE.new(receiver_key=self.key1, aead_id=HPKE.AEAD.AES128_GCM, enc=wrong_enc) with self.assertRaises(ValueError) as cm: HPKE.new(receiver_key=self.key1.public_key(), enc=self.key1.public_key().export_key(format='raw'), aead_id=HPKE.AEAD.AES128_GCM) self.assertIn("'enc' cannot be an input", str(cm.exception)) with self.assertRaises(ValueError) as cm: HPKE.new(receiver_key=self.key1, aead_id=HPKE.AEAD.AES128_GCM) self.assertIn("'enc' required", str(cm.exception)) def test_neg_unseal_wrong_ct(self): decryptor = HPKE.new(receiver_key=self.key1, aead_id=HPKE.AEAD.CHACHA20_POLY1305, enc=self.key2.public_key().export_key(format='raw')) with self.assertRaises(ValueError): decryptor.unseal(b'XYZ' * 20) def test_neg_unseal_no_auth_data(self): aead_id = HPKE.AEAD.CHACHA20_POLY1305 encryptor = HPKE.new(receiver_key=self.key1.public_key(), aead_id=aead_id) ct = encryptor.seal(b'ABC', auth_data=b'DEF') decryptor = HPKE.new(receiver_key=self.key1, aead_id=aead_id, enc=encryptor.enc) with self.assertRaises(ValueError): decryptor.unseal(ct) def test_x25519_mode_0(self): # RFC x9180, A.1.1.1, seq 0 and 1 keyR_hex = "4612c550263fc8ad58375df3f557aac531d26850903e55a9f23f21d8534e8ac8" keyR = DH.import_x25519_private_key(bytes.fromhex(keyR_hex)) pt_hex = "4265617574792069732074727574682c20747275746820626561757479" pt = bytes.fromhex(pt_hex) ct0_hex = "f938558b5d72f1a23810b4be2ab4f84331acc02fc97babc53a52ae8218a355a96d8770ac83d07bea87e13c512a" ct0 = bytes.fromhex(ct0_hex) enc_hex = "37fda3567bdbd628e88668c3c8d7e97d1d1253b6d4ea6d44c150f741f1bf4431" enc = bytes.fromhex(enc_hex) aad0_hex = "436f756e742d30" aad0 = bytes.fromhex(aad0_hex) aad1_hex = "436f756e742d31" aad1 = bytes.fromhex(aad1_hex) info_hex = "4f6465206f6e2061204772656369616e2055726e" info = bytes.fromhex(info_hex) ct1_hex = "af2d7e9ac9ae7e270f46ba1f975be53c09f8d875bdc8535458c2494e8a6eab251c03d0c22a56b8ca42c2063b84" ct1 = bytes.fromhex(ct1_hex) aead_id = HPKE.AEAD.AES128_GCM decryptor = HPKE.new(receiver_key=keyR, aead_id=aead_id, info=info, enc=enc) pt_X0 = decryptor.unseal(ct0, aad0) self.assertEqual(pt_X0, pt) pt_X1 = decryptor.unseal(ct1, aad1) self.assertEqual(pt_X1, pt) def test_x25519_mode_1(self): # RFC x9180, A.1.2.1, seq 0 and 1 keyR_hex = "c5eb01eb457fe6c6f57577c5413b931550a162c71a03ac8d196babbd4e5ce0fd" keyR = DH.import_x25519_private_key(bytes.fromhex(keyR_hex)) psk_id_hex = "456e6e796e20447572696e206172616e204d6f726961" psk_id = bytes.fromhex(psk_id_hex) psk_hex = "0247fd33b913760fa1fa51e1892d9f307fbe65eb171e8132c2af18555a738b82" psk = bytes.fromhex(psk_hex) pt_hex = "4265617574792069732074727574682c20747275746820626561757479" pt = bytes.fromhex(pt_hex) ct0_hex = "e52c6fed7f758d0cf7145689f21bc1be6ec9ea097fef4e959440012f4feb73fb611b946199e681f4cfc34db8ea" ct0 = bytes.fromhex(ct0_hex) enc_hex = "0ad0950d9fb9588e59690b74f1237ecdf1d775cd60be2eca57af5a4b0471c91b" enc = bytes.fromhex(enc_hex) aad0_hex = "436f756e742d30" aad0 = bytes.fromhex(aad0_hex) aad1_hex = "436f756e742d31" aad1 = bytes.fromhex(aad1_hex) info_hex = "4f6465206f6e2061204772656369616e2055726e" info = bytes.fromhex(info_hex) ct1_hex = "49f3b19b28a9ea9f43e8c71204c00d4a490ee7f61387b6719db765e948123b45b61633ef059ba22cd62437c8ba" ct1 = bytes.fromhex(ct1_hex) aead_id = HPKE.AEAD.AES128_GCM decryptor = HPKE.new(receiver_key=keyR, aead_id=aead_id, info=info, psk=(psk_id, psk), enc=enc) pt_X0 = decryptor.unseal(ct0, aad0) self.assertEqual(pt_X0, pt) pt_X1 = decryptor.unseal(ct1, aad1) self.assertEqual(pt_X1, pt) def test_x25519_mode_2(self): # RFC x9180, A.1.3.1, seq 0 and 1 keyR_hex = "fdea67cf831f1ca98d8e27b1f6abeb5b7745e9d35348b80fa407ff6958f9137e" keyR = DH.import_x25519_private_key(bytes.fromhex(keyR_hex)) keyS_hex = "dc4a146313cce60a278a5323d321f051c5707e9c45ba21a3479fecdf76fc69dd" keyS = DH.import_x25519_private_key(bytes.fromhex(keyS_hex)) pt_hex = "4265617574792069732074727574682c20747275746820626561757479" pt = bytes.fromhex(pt_hex) ct0_hex = "5fd92cc9d46dbf8943e72a07e42f363ed5f721212cd90bcfd072bfd9f44e06b80fd17824947496e21b680c141b" ct0 = bytes.fromhex(ct0_hex) enc_hex = "23fb952571a14a25e3d678140cd0e5eb47a0961bb18afcf85896e5453c312e76" enc = bytes.fromhex(enc_hex) aad0_hex = "436f756e742d30" aad0 = bytes.fromhex(aad0_hex) aad1_hex = "436f756e742d31" aad1 = bytes.fromhex(aad1_hex) info_hex = "4f6465206f6e2061204772656369616e2055726e" info = bytes.fromhex(info_hex) ct1_hex = "d3736bb256c19bfa93d79e8f80b7971262cb7c887e35c26370cfed62254369a1b52e3d505b79dd699f002bc8ed" ct1 = bytes.fromhex(ct1_hex) aead_id = HPKE.AEAD.AES128_GCM decryptor = HPKE.new(receiver_key=keyR, sender_key=keyS.public_key(), aead_id=aead_id, info=info, enc=enc) pt_X0 = decryptor.unseal(ct0, aad0) self.assertEqual(pt_X0, pt) pt_X1 = decryptor.unseal(ct1, aad1) self.assertEqual(pt_X1, pt) def test_x25519_mode_3(self): # RFC x9180, A.1.4.1, seq 0 and 1 keyR_hex = "cb29a95649dc5656c2d054c1aa0d3df0493155e9d5da6d7e344ed8b6a64a9423" keyR = DH.import_x25519_private_key(bytes.fromhex(keyR_hex)) keyS_hex = "fc1c87d2f3832adb178b431fce2ac77c7ca2fd680f3406c77b5ecdf818b119f4" keyS = DH.import_x25519_private_key(bytes.fromhex(keyS_hex)) psk_id_hex = "456e6e796e20447572696e206172616e204d6f726961" psk_id = bytes.fromhex(psk_id_hex) psk_hex = "0247fd33b913760fa1fa51e1892d9f307fbe65eb171e8132c2af18555a738b82" psk = bytes.fromhex(psk_hex) pt_hex = "4265617574792069732074727574682c20747275746820626561757479" pt = bytes.fromhex(pt_hex) ct0_hex = "a84c64df1e11d8fd11450039d4fe64ff0c8a99fca0bd72c2d4c3e0400bc14a40f27e45e141a24001697737533e" ct0 = bytes.fromhex(ct0_hex) enc_hex = "820818d3c23993492cc5623ab437a48a0a7ca3e9639c140fe1e33811eb844b7c" enc = bytes.fromhex(enc_hex) aad0_hex = "436f756e742d30" aad0 = bytes.fromhex(aad0_hex) aad1_hex = "436f756e742d31" aad1 = bytes.fromhex(aad1_hex) info_hex = "4f6465206f6e2061204772656369616e2055726e" info = bytes.fromhex(info_hex) ct1_hex = "4d19303b848f424fc3c3beca249b2c6de0a34083b8e909b6aa4c3688505c05ffe0c8f57a0a4c5ab9da127435d9" ct1 = bytes.fromhex(ct1_hex) aead_id = HPKE.AEAD.AES128_GCM decryptor = HPKE.new(receiver_key=keyR, sender_key=keyS.public_key(), aead_id=aead_id, psk=(psk_id, psk), info=info, enc=enc) pt_X0 = decryptor.unseal(ct0, aad0) self.assertEqual(pt_X0, pt) pt_X1 = decryptor.unseal(ct1, aad1) self.assertEqual(pt_X1, pt) class HPKE_TestVectors(unittest.TestCase): def setUp(self): self.vectors = [] try: import pycryptodome_test_vectors # type: ignore init_dir = os.path.dirname(pycryptodome_test_vectors.__file__) full_file_name = os.path.join(init_dir, "Protocol", "HPKE-test-vectors.json") with open(full_file_name, "r") as f: self.vectors = json.load(f) except (FileNotFoundError, ImportError): print("\nWarning: skipping extended tests for HPKE (install pycryptodome-test-vectors)") def import_private_key(self, key_hex, kem_id): key_bin = unhexlify(key_hex) if kem_id == 0x0010: return ECC.construct(curve='p256', d=int.from_bytes(key_bin, byteorder="big")) elif kem_id == 0x0011: return ECC.construct(curve='p384', d=int.from_bytes(key_bin, byteorder="big")) elif kem_id == 0x0012: return ECC.construct(curve='p521', d=int.from_bytes(key_bin, byteorder="big")) elif kem_id == 0x0020: return DH.import_x25519_private_key(key_bin) elif kem_id == 0x0021: return DH.import_x448_private_key(key_bin) def test_hpke_encap(self): """Test HPKE encapsulation using test vectors.""" if not self.vectors: self.skipTest("No test vectors available") for idx, vector in enumerate(self.vectors): kem_id = vector["kem_id"] kdf_id = vector["kdf_id"] aead_id = vector["aead_id"] # No export-only pseudo-cipher if aead_id == 0xffff: continue # We support only one KDF per curve supported_combi = { (0x10, 0x1): SHA256, (0x11, 0x2): SHA384, (0x12, 0x3): SHA512, (0x20, 0x1): SHA256, (0x21, 0x3): SHA512, } hashmod = supported_combi.get((kem_id, kdf_id)) if hashmod is None: continue with self.subTest(idx=idx, kem_id=kem_id, aead_id=aead_id): receiver_pub = self.import_private_key(vector["skRm"], kem_id).public_key() sender_priv = None if "skSm" in vector: sender_priv = self.import_private_key(vector["skSm"], kem_id) encap_key = self.import_private_key(vector["skEm"], kem_id) shared_secret, enc = HPKE.HPKE_Cipher._encap(receiver_pub, kem_id, hashmod, sender_priv, encap_key) self.assertEqual(enc.hex(), vector["enc"]) self.assertEqual(shared_secret, unhexlify(vector["shared_secret"])) print(".", end="", flush=True) def test_hpke_unseal(self): """Test HPKE encryption and decryption using test vectors.""" if not self.vectors: self.skipTest("No test vectors available") for idx, vector in enumerate(self.vectors): kem_id = vector["kem_id"] kdf_id = vector["kdf_id"] aead_id = vector["aead_id"] # No export-only pseudo-cipher if aead_id == 0xffff: continue # We support only one KDF per curve supported_combi = ( (0x10, 0x1), (0x11, 0x2), (0x12, 0x3), (0x20, 0x1), (0x21, 0x3), ) if (kem_id, kdf_id) not in supported_combi: continue with self.subTest(idx=idx, kem_id=kem_id, aead_id=aead_id): receiver_priv = self.import_private_key(vector["skRm"], kem_id) sender_pub = None if "skSm" in vector: sender_priv = self.import_private_key(vector["skSm"], kem_id) sender_pub = sender_priv.public_key() encap_key = unhexlify(vector["enc"]) psk = None if "psk_id" in vector: psk = unhexlify(vector["psk_id"]), unhexlify(vector["psk"]) receiver_hpke = HPKE.new(receiver_key=receiver_priv, aead_id=HPKE.AEAD(aead_id), enc=encap_key, sender_key=sender_pub, psk=psk, info=unhexlify(vector["info"])) for encryption in vector['encryptions']: plaintext = unhexlify(encryption["pt"]) ciphertext = unhexlify(encryption["ct"]) aad = unhexlify(encryption["aad"]) # Decrypt (unseal) decrypted = receiver_hpke.unseal(ciphertext, aad) self.assertEqual(decrypted, plaintext, "Decryption failed") print(".", end="", flush=True) if __name__ == "__main__": unittest.main() def get_tests(config={}): tests = [] tests += list_test_cases(HPKE_Tests) if config.get('slow_tests'): tests += list_test_cases(HPKE_TestVectors) return tests if __name__ == '__main__': def suite(): return unittest.TestSuite(get_tests()) unittest.main(defaultTest='suite')