concatkdf.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. from __future__ import annotations
  5. import typing
  6. from cryptography import utils
  7. from cryptography.exceptions import AlreadyFinalized, InvalidKey
  8. from cryptography.hazmat.primitives import constant_time, hashes, hmac
  9. from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
  10. def _int_to_u32be(n: int) -> bytes:
  11. return n.to_bytes(length=4, byteorder="big")
  12. def _common_args_checks(
  13. algorithm: hashes.HashAlgorithm,
  14. length: int,
  15. otherinfo: bytes | None,
  16. ) -> None:
  17. max_length = algorithm.digest_size * (2**32 - 1)
  18. if length > max_length:
  19. raise ValueError(f"Cannot derive keys larger than {max_length} bits.")
  20. if otherinfo is not None:
  21. utils._check_bytes("otherinfo", otherinfo)
  22. def _concatkdf_derive(
  23. key_material: bytes,
  24. length: int,
  25. auxfn: typing.Callable[[], hashes.HashContext],
  26. otherinfo: bytes,
  27. ) -> bytes:
  28. utils._check_byteslike("key_material", key_material)
  29. output = [b""]
  30. outlen = 0
  31. counter = 1
  32. while length > outlen:
  33. h = auxfn()
  34. h.update(_int_to_u32be(counter))
  35. h.update(key_material)
  36. h.update(otherinfo)
  37. output.append(h.finalize())
  38. outlen += len(output[-1])
  39. counter += 1
  40. return b"".join(output)[:length]
  41. class ConcatKDFHash(KeyDerivationFunction):
  42. def __init__(
  43. self,
  44. algorithm: hashes.HashAlgorithm,
  45. length: int,
  46. otherinfo: bytes | None,
  47. backend: typing.Any = None,
  48. ):
  49. _common_args_checks(algorithm, length, otherinfo)
  50. self._algorithm = algorithm
  51. self._length = length
  52. self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
  53. self._used = False
  54. def _hash(self) -> hashes.Hash:
  55. return hashes.Hash(self._algorithm)
  56. def derive(self, key_material: bytes) -> bytes:
  57. if self._used:
  58. raise AlreadyFinalized
  59. self._used = True
  60. return _concatkdf_derive(
  61. key_material, self._length, self._hash, self._otherinfo
  62. )
  63. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  64. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  65. raise InvalidKey
  66. class ConcatKDFHMAC(KeyDerivationFunction):
  67. def __init__(
  68. self,
  69. algorithm: hashes.HashAlgorithm,
  70. length: int,
  71. salt: bytes | None,
  72. otherinfo: bytes | None,
  73. backend: typing.Any = None,
  74. ):
  75. _common_args_checks(algorithm, length, otherinfo)
  76. self._algorithm = algorithm
  77. self._length = length
  78. self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
  79. if algorithm.block_size is None:
  80. raise TypeError(f"{algorithm.name} is unsupported for ConcatKDF")
  81. if salt is None:
  82. salt = b"\x00" * algorithm.block_size
  83. else:
  84. utils._check_bytes("salt", salt)
  85. self._salt = salt
  86. self._used = False
  87. def _hmac(self) -> hmac.HMAC:
  88. return hmac.HMAC(self._salt, self._algorithm)
  89. def derive(self, key_material: bytes) -> bytes:
  90. if self._used:
  91. raise AlreadyFinalized
  92. self._used = True
  93. return _concatkdf_derive(
  94. key_material, self._length, self._hmac, self._otherinfo
  95. )
  96. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  97. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  98. raise InvalidKey