496 lines
17 KiB
Python
496 lines
17 KiB
Python
import os
|
|
import json
|
|
import unittest
|
|
from binascii import unhexlify
|
|
|
|
from Cryptodome.Protocol import HPKE
|
|
from Cryptodome.Protocol.HPKE import DeserializeError
|
|
|
|
from Cryptodome.PublicKey import ECC
|
|
from Cryptodome.SelfTest.st_common import list_test_cases
|
|
|
|
from Cryptodome.Protocol import DH
|
|
from Cryptodome.Hash import SHA256, SHA384, SHA512
|
|
|
|
|
|
class HPKE_Tests(unittest.TestCase):
    """Functional and negative tests for ``HPKE.new``, ``seal`` and ``unseal``."""

    # P-256 key pairs shared by the tests that do not iterate over curves
    key1 = ECC.generate(curve='p256')
    key2 = ECC.generate(curve='p256')

    # name, size of enc
    curves = {
        'p256': 65,
        'p384': 97,
        'p521': 133,
        'curve25519': 32,
        'curve448': 56,
    }

    # Values that are identical in the four X25519 known-answer tests below
    _kat_pt = bytes.fromhex(
        "4265617574792069732074727574682c20747275746820626561757479")
    _kat_aad0 = bytes.fromhex("436f756e742d30")
    _kat_aad1 = bytes.fromhex("436f756e742d31")
    _kat_info = bytes.fromhex("4f6465206f6e2061204772656369616e2055726e")
    _kat_psk_id = bytes.fromhex("456e6e796e20447572696e206172616e204d6f726961")
    _kat_psk = bytes.fromhex(
        "0247fd33b913760fa1fa51e1892d9f307fbe65eb171e8132c2af18555a738b82")

    def round_trip(self, curve, aead_id):
        """Seal two messages for a fresh key on ``curve`` and unseal them.

        Also checks that the length of ``enc`` produced by the sender matches
        the size registered for the curve in ``self.curves``.
        """
        receiver = ECC.generate(curve=curve)

        encryptor = HPKE.new(receiver_key=receiver.public_key(),
                             aead_id=aead_id)
        self.assertEqual(len(encryptor.enc), self.curves[curve])

        # First message (with associated data)
        ct = encryptor.seal(b'ABC', auth_data=b'DEF')

        decryptor = HPKE.new(receiver_key=receiver,
                             aead_id=aead_id,
                             enc=encryptor.enc)

        pt = decryptor.unseal(ct, auth_data=b'DEF')
        self.assertEqual(b'ABC', pt)

        # Second message (no associated data)
        ct2 = encryptor.seal(b'GHI')
        pt2 = decryptor.unseal(ct2)
        self.assertEqual(b'GHI', pt2)

    def test_round_trip(self):
        """Run the round trip for every curve / AEAD combination."""
        for curve in self.curves:
            for aead_id in HPKE.AEAD:
                # subTest reports which combination failed and lets the
                # remaining combinations still run
                with self.subTest(curve=curve, aead_id=aead_id):
                    self.round_trip(curve, aead_id)

    def test_psk(self):
        """A sender can be created with a (psk_id, psk) pair."""
        aead_id = HPKE.AEAD.AES128_GCM
        HPKE.new(receiver_key=self.key1.public_key(),
                 aead_id=aead_id,
                 psk=(b'a', b'c' * 32))

    def test_info(self):
        """A sender can be created with an ``info`` string."""
        aead_id = HPKE.AEAD.AES128_GCM
        HPKE.new(receiver_key=self.key1.public_key(),
                 aead_id=aead_id,
                 info=b'baba')

    def test_neg_unsupported_curve(self):
        """A P-224 receiver key is rejected."""
        key3 = ECC.generate(curve='p224')
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=key3.public_key(),
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("Unsupported curve", str(cm.exception))

    def test_neg_too_many_private_keys(self):
        """Passing private keys for both receiver and sender is rejected."""
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=self.key1,
                     sender_key=self.key2,
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("Exactly 1 private key", str(cm.exception))

    def test_neg_curve_mismatch(self):
        """Receiver and sender keys on different curves are rejected."""
        key3 = ECC.generate(curve='p384')
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=self.key1.public_key(),
                     sender_key=key3,
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("but recipient key", str(cm.exception))

    def test_neg_psk(self):
        """Malformed (psk_id, psk) pairs are rejected."""
        # Empty PSK id
        with self.assertRaises(ValueError):
            HPKE.new(receiver_key=self.key1.public_key(),
                     psk=(b'', b'G' * 32),
                     aead_id=HPKE.AEAD.AES128_GCM)

        # Empty PSK
        with self.assertRaises(ValueError):
            HPKE.new(receiver_key=self.key1.public_key(),
                     psk=(b'JJJ', b''),
                     aead_id=HPKE.AEAD.AES128_GCM)

        # PSK one byte short of the minimum
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=self.key1.public_key(),
                     psk=(b'JJJ', b'Y' * 31),
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("at least 32", str(cm.exception))

    def test_neg_wrong_enc(self):
        """Invalid, unexpected or missing ``enc`` values are rejected."""
        # 65 bytes that do not deserialize as a P-256 point
        wrong_enc = b'\xFF' + b'8' * 64
        with self.assertRaises(DeserializeError):
            HPKE.new(receiver_key=self.key1,
                     aead_id=HPKE.AEAD.AES128_GCM,
                     enc=wrong_enc)

        # 'enc' passed together with a public receiver key (sender side)
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=self.key1.public_key(),
                     enc=self.key1.public_key().export_key(format='raw'),
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("'enc' cannot be an input", str(cm.exception))

        # 'enc' omitted with a private receiver key (receiver side)
        with self.assertRaises(ValueError) as cm:
            HPKE.new(receiver_key=self.key1,
                     aead_id=HPKE.AEAD.AES128_GCM)
        self.assertIn("'enc' required", str(cm.exception))

    def test_neg_unseal_wrong_ct(self):
        """Unsealing bytes that were never sealed raises ValueError."""
        decryptor = HPKE.new(receiver_key=self.key1,
                             aead_id=HPKE.AEAD.CHACHA20_POLY1305,
                             enc=self.key2.public_key().export_key(format='raw'))

        with self.assertRaises(ValueError):
            decryptor.unseal(b'XYZ' * 20)

    def test_neg_unseal_no_auth_data(self):
        """Unsealing without the ``auth_data`` used for sealing raises ValueError."""
        aead_id = HPKE.AEAD.CHACHA20_POLY1305

        encryptor = HPKE.new(receiver_key=self.key1.public_key(),
                             aead_id=aead_id)

        ct = encryptor.seal(b'ABC', auth_data=b'DEF')

        decryptor = HPKE.new(receiver_key=self.key1,
                             aead_id=aead_id,
                             enc=encryptor.enc)

        with self.assertRaises(ValueError):
            decryptor.unseal(ct)

    def _check_x25519_kat(self, keyR_hex, enc_hex, ct0_hex, ct1_hex,
                          keyS_hex=None, with_psk=False):
        """Unseal two consecutive known-answer ciphertexts with AES128-GCM.

        keyR_hex -- receiver X25519 private key (hex)
        enc_hex  -- encapsulated key sent by the sender (hex)
        ct0_hex, ct1_hex -- ciphertexts for sequence numbers 0 and 1 (hex)
        keyS_hex -- optional sender X25519 private key (hex); only its
                    public half is handed to the receiver
        with_psk -- if True, use the shared (psk_id, psk) test pair
        """
        keyR = DH.import_x25519_private_key(bytes.fromhex(keyR_hex))

        sender_pub = None
        if keyS_hex is not None:
            keyS = DH.import_x25519_private_key(bytes.fromhex(keyS_hex))
            sender_pub = keyS.public_key()

        psk = (self._kat_psk_id, self._kat_psk) if with_psk else None

        decryptor = HPKE.new(receiver_key=keyR,
                             sender_key=sender_pub,
                             aead_id=HPKE.AEAD.AES128_GCM,
                             psk=psk,
                             info=self._kat_info,
                             enc=bytes.fromhex(enc_hex))

        # The two messages must be unsealed in order
        pt_X0 = decryptor.unseal(bytes.fromhex(ct0_hex), self._kat_aad0)
        self.assertEqual(pt_X0, self._kat_pt)

        pt_X1 = decryptor.unseal(bytes.fromhex(ct1_hex), self._kat_aad1)
        self.assertEqual(pt_X1, self._kat_pt)

    def test_x25519_mode_0(self):
        # RFC 9180, A.1.1.1, seq 0 and 1
        self._check_x25519_kat(
            keyR_hex="4612c550263fc8ad58375df3f557aac531d26850903e55a9f23f21d8534e8ac8",
            enc_hex="37fda3567bdbd628e88668c3c8d7e97d1d1253b6d4ea6d44c150f741f1bf4431",
            ct0_hex="f938558b5d72f1a23810b4be2ab4f84331acc02fc97babc53a52ae8218a355a96d8770ac83d07bea87e13c512a",
            ct1_hex="af2d7e9ac9ae7e270f46ba1f975be53c09f8d875bdc8535458c2494e8a6eab251c03d0c22a56b8ca42c2063b84")

    def test_x25519_mode_1(self):
        # RFC 9180, A.1.2.1, seq 0 and 1
        self._check_x25519_kat(
            keyR_hex="c5eb01eb457fe6c6f57577c5413b931550a162c71a03ac8d196babbd4e5ce0fd",
            enc_hex="0ad0950d9fb9588e59690b74f1237ecdf1d775cd60be2eca57af5a4b0471c91b",
            ct0_hex="e52c6fed7f758d0cf7145689f21bc1be6ec9ea097fef4e959440012f4feb73fb611b946199e681f4cfc34db8ea",
            ct1_hex="49f3b19b28a9ea9f43e8c71204c00d4a490ee7f61387b6719db765e948123b45b61633ef059ba22cd62437c8ba",
            with_psk=True)

    def test_x25519_mode_2(self):
        # RFC 9180, A.1.3.1, seq 0 and 1
        self._check_x25519_kat(
            keyR_hex="fdea67cf831f1ca98d8e27b1f6abeb5b7745e9d35348b80fa407ff6958f9137e",
            keyS_hex="dc4a146313cce60a278a5323d321f051c5707e9c45ba21a3479fecdf76fc69dd",
            enc_hex="23fb952571a14a25e3d678140cd0e5eb47a0961bb18afcf85896e5453c312e76",
            ct0_hex="5fd92cc9d46dbf8943e72a07e42f363ed5f721212cd90bcfd072bfd9f44e06b80fd17824947496e21b680c141b",
            ct1_hex="d3736bb256c19bfa93d79e8f80b7971262cb7c887e35c26370cfed62254369a1b52e3d505b79dd699f002bc8ed")

    def test_x25519_mode_3(self):
        # RFC 9180, A.1.4.1, seq 0 and 1
        self._check_x25519_kat(
            keyR_hex="cb29a95649dc5656c2d054c1aa0d3df0493155e9d5da6d7e344ed8b6a64a9423",
            keyS_hex="fc1c87d2f3832adb178b431fce2ac77c7ca2fd680f3406c77b5ecdf818b119f4",
            enc_hex="820818d3c23993492cc5623ab437a48a0a7ca3e9639c140fe1e33811eb844b7c",
            ct0_hex="a84c64df1e11d8fd11450039d4fe64ff0c8a99fca0bd72c2d4c3e0400bc14a40f27e45e141a24001697737533e",
            ct1_hex="4d19303b848f424fc3c3beca249b2c6de0a34083b8e909b6aa4c3688505c05ffe0c8f57a0a4c5ab9da127435d9",
            with_psk=True)
|
|
|
|
|
|
class HPKE_TestVectors(unittest.TestCase):
    """Known-answer tests driven by the optional ``pycryptodome_test_vectors`` package."""

    # KEM ids whose private key is a big-endian scalar on a NIST curve
    _nist_kem_curves = {
        0x0010: 'p256',
        0x0011: 'p384',
        0x0012: 'p521',
    }

    def setUp(self):
        """Load the JSON vectors; ``self.vectors`` stays empty if they are missing."""
        self.vectors = []
        try:
            import pycryptodome_test_vectors  # type: ignore
            init_dir = os.path.dirname(pycryptodome_test_vectors.__file__)
            full_file_name = os.path.join(init_dir, "Protocol", "HPKE-test-vectors.json")
            # Explicit encoding so that parsing does not depend on the locale
            with open(full_file_name, "r", encoding="utf-8") as f:
                self.vectors = json.load(f)
        except (FileNotFoundError, ImportError):
            print("\nWarning: skipping extended tests for HPKE (install pycryptodome-test-vectors)")

    def import_private_key(self, key_hex, kem_id):
        """Build a private key object from its hex encoding.

        key_hex -- hex string with the private key
        kem_id  -- numeric KEM identifier selecting the key type

        Returns the key object created by ``ECC.construct`` (NIST curves) or
        by the ``DH`` X25519/X448 import functions.

        Raises ValueError if ``kem_id`` is not one of the handled identifiers
        (instead of silently returning None).
        """
        key_bin = unhexlify(key_hex)

        curve = self._nist_kem_curves.get(kem_id)
        if curve is not None:
            return ECC.construct(curve=curve,
                                 d=int.from_bytes(key_bin, byteorder="big"))
        if kem_id == 0x0020:
            return DH.import_x25519_private_key(key_bin)
        if kem_id == 0x0021:
            return DH.import_x448_private_key(key_bin)

        raise ValueError("Unsupported KEM id: %r" % (kem_id,))

    def test_hpke_encap(self):
        """Test HPKE encapsulation using test vectors."""

        if not self.vectors:
            self.skipTest("No test vectors available")

        # We support only one KDF per curve
        # (built once: it does not depend on the vector being processed)
        supported_combi = {
            (0x10, 0x1): SHA256,
            (0x11, 0x2): SHA384,
            (0x12, 0x3): SHA512,
            (0x20, 0x1): SHA256,
            (0x21, 0x3): SHA512,
        }

        for idx, vector in enumerate(self.vectors):

            kem_id = vector["kem_id"]
            kdf_id = vector["kdf_id"]
            aead_id = vector["aead_id"]

            # No export-only pseudo-cipher
            if aead_id == 0xffff:
                continue

            hashmod = supported_combi.get((kem_id, kdf_id))
            if hashmod is None:
                continue

            with self.subTest(idx=idx, kem_id=kem_id, aead_id=aead_id):

                receiver_pub = self.import_private_key(vector["skRm"],
                                                       kem_id).public_key()

                # The sender key is only present in some vectors
                sender_priv = None
                if "skSm" in vector:
                    sender_priv = self.import_private_key(vector["skSm"],
                                                          kem_id)

                # Fixed ephemeral key, so that the output is reproducible
                encap_key = self.import_private_key(vector["skEm"], kem_id)

                shared_secret, enc = HPKE.HPKE_Cipher._encap(receiver_pub,
                                                             kem_id,
                                                             hashmod,
                                                             sender_priv,
                                                             encap_key)
                self.assertEqual(enc.hex(), vector["enc"])
                self.assertEqual(shared_secret,
                                 unhexlify(vector["shared_secret"]))

                print(".", end="", flush=True)

    def test_hpke_unseal(self):
        """Test HPKE encryption and decryption using test vectors."""

        if not self.vectors:
            self.skipTest("No test vectors available")

        # We support only one KDF per curve
        # (built once: it does not depend on the vector being processed)
        supported_combi = {
            (0x10, 0x1),
            (0x11, 0x2),
            (0x12, 0x3),
            (0x20, 0x1),
            (0x21, 0x3),
        }

        for idx, vector in enumerate(self.vectors):

            kem_id = vector["kem_id"]
            kdf_id = vector["kdf_id"]
            aead_id = vector["aead_id"]

            # No export-only pseudo-cipher
            if aead_id == 0xffff:
                continue

            if (kem_id, kdf_id) not in supported_combi:
                continue

            with self.subTest(idx=idx, kem_id=kem_id, aead_id=aead_id):

                receiver_priv = self.import_private_key(vector["skRm"],
                                                        kem_id)

                # The receiver only gets the public half of the sender key
                sender_pub = None
                if "skSm" in vector:
                    sender_priv = self.import_private_key(vector["skSm"],
                                                          kem_id)
                    sender_pub = sender_priv.public_key()

                encap_key = unhexlify(vector["enc"])

                psk = None
                if "psk_id" in vector:
                    psk = unhexlify(vector["psk_id"]), unhexlify(vector["psk"])

                receiver_hpke = HPKE.new(receiver_key=receiver_priv,
                                         aead_id=HPKE.AEAD(aead_id),
                                         enc=encap_key,
                                         sender_key=sender_pub,
                                         psk=psk,
                                         info=unhexlify(vector["info"]))

                # Ciphertexts are unsealed in the order they are listed
                for encryption in vector['encryptions']:

                    plaintext = unhexlify(encryption["pt"])
                    ciphertext = unhexlify(encryption["ct"])
                    aad = unhexlify(encryption["aad"])

                    # Decrypt (unseal)
                    decrypted = receiver_hpke.unseal(ciphertext, aad)
                    self.assertEqual(decrypted, plaintext, "Decryption failed")

                    print(".", end="", flush=True)
|
|
|
|
|
|
# Direct-execution entry point: let unittest discover and run the TestCase
# classes of this module.
# NOTE(review): unittest.main() exits the interpreter by default when the run
# is over, so the second ``__main__`` guard at the bottom of this file is
# presumably never reached when the file is run as a script - confirm which
# of the two entry points is the intended one.
if __name__ == "__main__":
    unittest.main()
|
|
|
|
|
|
def get_tests(config=None):
    """Return the list of HPKE tests to run.

    config -- optional dictionary of options; when its ``'slow_tests'`` entry
              is truthy, the tests of ``HPKE_TestVectors`` are included too.
              ``None`` (the default) is treated as an empty dictionary.
    """
    # A None sentinel avoids sharing one mutable default dict across calls
    if config is None:
        config = {}

    # Copy into a fresh list, so that the object returned by the helper is
    # never extended in place
    tests = list(list_test_cases(HPKE_Tests))

    if config.get('slow_tests'):
        tests += list_test_cases(HPKE_TestVectors)

    return tests
|
|
|
|
|
|
# Script entry point that runs the tests selected by get_tests()
if __name__ == '__main__':
    def suite():
        """Wrap the tests returned by get_tests() into a single TestSuite."""
        return unittest.TestSuite(get_tests())

    # unittest.main looks up the callable by the name given in defaultTest
    unittest.main(defaultTest='suite')
|