hkdf.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. from __future__ import annotations
  5. import typing
  6. from cryptography import utils
  7. from cryptography.exceptions import AlreadyFinalized, InvalidKey
  8. from cryptography.hazmat.primitives import constant_time, hashes, hmac
  9. from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
  10. class HKDF(KeyDerivationFunction):
  11. def __init__(
  12. self,
  13. algorithm: hashes.HashAlgorithm,
  14. length: int,
  15. salt: bytes | None,
  16. info: bytes | None,
  17. backend: typing.Any = None,
  18. ):
  19. self._algorithm = algorithm
  20. if salt is None:
  21. salt = b"\x00" * self._algorithm.digest_size
  22. else:
  23. utils._check_bytes("salt", salt)
  24. self._salt = salt
  25. self._hkdf_expand = HKDFExpand(self._algorithm, length, info)
  26. def _extract(self, key_material: bytes) -> bytes:
  27. h = hmac.HMAC(self._salt, self._algorithm)
  28. h.update(key_material)
  29. return h.finalize()
  30. def derive(self, key_material: bytes) -> bytes:
  31. utils._check_byteslike("key_material", key_material)
  32. return self._hkdf_expand.derive(self._extract(key_material))
  33. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  34. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  35. raise InvalidKey
  36. class HKDFExpand(KeyDerivationFunction):
  37. def __init__(
  38. self,
  39. algorithm: hashes.HashAlgorithm,
  40. length: int,
  41. info: bytes | None,
  42. backend: typing.Any = None,
  43. ):
  44. self._algorithm = algorithm
  45. max_length = 255 * algorithm.digest_size
  46. if length > max_length:
  47. raise ValueError(
  48. f"Cannot derive keys larger than {max_length} octets."
  49. )
  50. self._length = length
  51. if info is None:
  52. info = b""
  53. else:
  54. utils._check_bytes("info", info)
  55. self._info = info
  56. self._used = False
  57. def _expand(self, key_material: bytes) -> bytes:
  58. output = [b""]
  59. counter = 1
  60. while self._algorithm.digest_size * (len(output) - 1) < self._length:
  61. h = hmac.HMAC(key_material, self._algorithm)
  62. h.update(output[-1])
  63. h.update(self._info)
  64. h.update(bytes([counter]))
  65. output.append(h.finalize())
  66. counter += 1
  67. return b"".join(output)[: self._length]
  68. def derive(self, key_material: bytes) -> bytes:
  69. utils._check_byteslike("key_material", key_material)
  70. if self._used:
  71. raise AlreadyFinalized
  72. self._used = True
  73. return self._expand(key_material)
  74. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  75. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  76. raise InvalidKey