ssh.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. from __future__ import annotations
  5. import binascii
  6. import enum
  7. import os
  8. import re
  9. import typing
  10. import warnings
  11. from base64 import encodebytes as _base64_encode
  12. from dataclasses import dataclass
  13. from cryptography import utils
  14. from cryptography.exceptions import UnsupportedAlgorithm
  15. from cryptography.hazmat.primitives import hashes
  16. from cryptography.hazmat.primitives.asymmetric import (
  17. dsa,
  18. ec,
  19. ed25519,
  20. padding,
  21. rsa,
  22. )
  23. from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
  24. from cryptography.hazmat.primitives.ciphers import (
  25. AEADDecryptionContext,
  26. Cipher,
  27. algorithms,
  28. modes,
  29. )
  30. from cryptography.hazmat.primitives.serialization import (
  31. Encoding,
  32. KeySerializationEncryption,
  33. NoEncryption,
  34. PrivateFormat,
  35. PublicFormat,
  36. _KeySerializationEncryption,
  37. )
# bcrypt is an optional dependency: it supplies the KDF used by
# _init_cipher to derive the cipher key/IV for password-protected keys.
try:
    from bcrypt import kdf as _bcrypt_kdf

    _bcrypt_supported = True
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Stand-in used when the bcrypt module is not installed.

        Mirrors the imported function's parameters so call sites need no
        special casing; every call fails.

        Raises:
            UnsupportedAlgorithm: always.
        """
        raise UnsupportedAlgorithm("Need bcrypt module")
# SSH public key type names, as written at the start of key blobs.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# OpenSSH certificate type suffix; presumably appended to a base key type
# name at its use site (not visible here) - confirm there.
_CERT_SUFFIX = b"-cert-v01@openssh.com"
# U2F / security-key public key types; their blobs carry a trailing
# application string after the regular public key fields.
_SK_SSH_ED25519 = b"sk-ssh-ed25519@openssh.com"
_SK_SSH_ECDSA_NISTP256 = b"sk-ecdsa-sha2-nistp256@openssh.com"
# These are not key types, only algorithms, so they cannot appear
# as a public key type
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"
# Captures the first two whitespace-separated tokens of a line (presumably
# key type and base64 blob of a public key line).
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
# OpenSSH private key container: binary magic plus the PEM-style armor.
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF / cipher names stored in the private key header.
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Defaults used when serializing a password-protected private key.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16
# A compiled bytes regex locates the armor directly in bytes-like input;
# DOTALL lets the base64 body span multiple lines.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)
# Padding bytes 1, 2, 3, ... up to the largest block size (16); slices of
# this view are written and verified as the private section's padding.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))
@dataclass
class _SSHCipher:
    """Parameters of a cipher usable for wrapping an OpenSSH private key."""

    # cipher algorithm class, instantiated with the KDF-derived key
    alg: type[algorithms.AES]
    # number of key bytes taken from the start of the KDF output
    key_len: int
    # mode class, instantiated with the KDF-derived IV / nonce
    mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
    # encrypted private section length must be a multiple of this
    block_len: int
    # number of IV / nonce bytes taken from the KDF output after the key
    iv_len: int
    # authentication tag length for AEAD modes, None otherwise
    tag_len: int | None
    # True when the mode is authenticated (a tag follows the ciphertext)
    is_aead: bool
# Ciphers that can wrap (encrypt) the private section of an OpenSSH key,
# keyed by the cipher name stored in the key header.
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CTR,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CBC,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    # GCM takes a 12-byte nonce and carries a 16-byte authentication tag.
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.GCM,
        block_len=16,
        iv_len=12,
        tag_len=16,
        is_aead=True,
    ),
}
# Map cryptography curve names (EllipticCurve.name) to SSH ECDSA key types;
# only these three curves are supported.
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}
  122. def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
  123. if isinstance(key, ec.EllipticCurvePrivateKey):
  124. key_type = _ecdsa_key_type(key.public_key())
  125. elif isinstance(key, ec.EllipticCurvePublicKey):
  126. key_type = _ecdsa_key_type(key)
  127. elif isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
  128. key_type = _SSH_RSA
  129. elif isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
  130. key_type = _SSH_DSA
  131. elif isinstance(
  132. key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)
  133. ):
  134. key_type = _SSH_ED25519
  135. else:
  136. raise ValueError("Unsupported key type")
  137. return key_type
  138. def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
  139. """Return SSH key_type and curve_name for private key."""
  140. curve = public_key.curve
  141. if curve.name not in _ECDSA_KEY_TYPE:
  142. raise ValueError(
  143. f"Unsupported curve for ssh private key: {curve.name!r}"
  144. )
  145. return _ECDSA_KEY_TYPE[curve.name]
  146. def _ssh_pem_encode(
  147. data: bytes,
  148. prefix: bytes = _SK_START + b"\n",
  149. suffix: bytes = _SK_END + b"\n",
  150. ) -> bytes:
  151. return b"".join([prefix, _base64_encode(data), suffix])
  152. def _check_block_size(data: bytes, block_len: int) -> None:
  153. """Require data to be full blocks"""
  154. if not data or len(data) % block_len != 0:
  155. raise ValueError("Corrupt data: missing padding")
  156. def _check_empty(data: bytes) -> None:
  157. """All data should have been parsed."""
  158. if data:
  159. raise ValueError("Corrupt data: unparsed data")
  160. def _init_cipher(
  161. ciphername: bytes,
  162. password: bytes | None,
  163. salt: bytes,
  164. rounds: int,
  165. ) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
  166. """Generate key + iv and return cipher."""
  167. if not password:
  168. raise ValueError("Key is password-protected.")
  169. ciph = _SSH_CIPHERS[ciphername]
  170. seed = _bcrypt_kdf(
  171. password, salt, ciph.key_len + ciph.iv_len, rounds, True
  172. )
  173. return Cipher(
  174. ciph.alg(seed[: ciph.key_len]),
  175. ciph.mode(seed[ciph.key_len :]),
  176. )
  177. def _get_u32(data: memoryview) -> tuple[int, memoryview]:
  178. """Uint32"""
  179. if len(data) < 4:
  180. raise ValueError("Invalid data")
  181. return int.from_bytes(data[:4], byteorder="big"), data[4:]
  182. def _get_u64(data: memoryview) -> tuple[int, memoryview]:
  183. """Uint64"""
  184. if len(data) < 8:
  185. raise ValueError("Invalid data")
  186. return int.from_bytes(data[:8], byteorder="big"), data[8:]
  187. def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
  188. """Bytes with u32 length prefix"""
  189. n, data = _get_u32(data)
  190. if n > len(data):
  191. raise ValueError("Invalid data")
  192. return data[:n], data[n:]
  193. def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
  194. """Big integer."""
  195. val, data = _get_sshstr(data)
  196. if val and val[0] > 0x7F:
  197. raise ValueError("Invalid data")
  198. return int.from_bytes(val, "big"), data
  199. def _to_mpint(val: int) -> bytes:
  200. """Storage format for signed bigint."""
  201. if val < 0:
  202. raise ValueError("negative mpint not allowed")
  203. if not val:
  204. return b""
  205. nbytes = (val.bit_length() + 8) // 8
  206. return utils.int_to_bytes(val, nbytes)
  207. class _FragList:
  208. """Build recursive structure without data copy."""
  209. flist: list[bytes]
  210. def __init__(self, init: list[bytes] | None = None) -> None:
  211. self.flist = []
  212. if init:
  213. self.flist.extend(init)
  214. def put_raw(self, val: bytes) -> None:
  215. """Add plain bytes"""
  216. self.flist.append(val)
  217. def put_u32(self, val: int) -> None:
  218. """Big-endian uint32"""
  219. self.flist.append(val.to_bytes(length=4, byteorder="big"))
  220. def put_u64(self, val: int) -> None:
  221. """Big-endian uint64"""
  222. self.flist.append(val.to_bytes(length=8, byteorder="big"))
  223. def put_sshstr(self, val: bytes | _FragList) -> None:
  224. """Bytes prefixed with u32 length"""
  225. if isinstance(val, (bytes, memoryview, bytearray)):
  226. self.put_u32(len(val))
  227. self.flist.append(val)
  228. else:
  229. self.put_u32(val.size())
  230. self.flist.extend(val.flist)
  231. def put_mpint(self, val: int) -> None:
  232. """Big-endian bigint prefixed with u32 length"""
  233. self.put_sshstr(_to_mpint(val))
  234. def size(self) -> int:
  235. """Current number of bytes"""
  236. return sum(map(len, self.flist))
  237. def render(self, dstbuf: memoryview, pos: int = 0) -> int:
  238. """Write into bytearray"""
  239. for frag in self.flist:
  240. flen = len(frag)
  241. start, pos = pos, pos + flen
  242. dstbuf[start:pos] = frag
  243. return pos
  244. def tobytes(self) -> bytes:
  245. """Return as bytes"""
  246. buf = memoryview(bytearray(self.size()))
  247. self.render(buf)
  248. return buf.tobytes()
  249. class _SSHFormatRSA:
  250. """Format for RSA keys.
  251. Public:
  252. mpint e, n
  253. Private:
  254. mpint n, e, d, iqmp, p, q
  255. """
  256. def get_public(
  257. self, data: memoryview
  258. ) -> tuple[tuple[int, int], memoryview]:
  259. """RSA public fields"""
  260. e, data = _get_mpint(data)
  261. n, data = _get_mpint(data)
  262. return (e, n), data
  263. def load_public(
  264. self, data: memoryview
  265. ) -> tuple[rsa.RSAPublicKey, memoryview]:
  266. """Make RSA public key from data."""
  267. (e, n), data = self.get_public(data)
  268. public_numbers = rsa.RSAPublicNumbers(e, n)
  269. public_key = public_numbers.public_key()
  270. return public_key, data
  271. def load_private(
  272. self, data: memoryview, pubfields
  273. ) -> tuple[rsa.RSAPrivateKey, memoryview]:
  274. """Make RSA private key from data."""
  275. n, data = _get_mpint(data)
  276. e, data = _get_mpint(data)
  277. d, data = _get_mpint(data)
  278. iqmp, data = _get_mpint(data)
  279. p, data = _get_mpint(data)
  280. q, data = _get_mpint(data)
  281. if (e, n) != pubfields:
  282. raise ValueError("Corrupt data: rsa field mismatch")
  283. dmp1 = rsa.rsa_crt_dmp1(d, p)
  284. dmq1 = rsa.rsa_crt_dmq1(d, q)
  285. public_numbers = rsa.RSAPublicNumbers(e, n)
  286. private_numbers = rsa.RSAPrivateNumbers(
  287. p, q, d, dmp1, dmq1, iqmp, public_numbers
  288. )
  289. private_key = private_numbers.private_key()
  290. return private_key, data
  291. def encode_public(
  292. self, public_key: rsa.RSAPublicKey, f_pub: _FragList
  293. ) -> None:
  294. """Write RSA public key"""
  295. pubn = public_key.public_numbers()
  296. f_pub.put_mpint(pubn.e)
  297. f_pub.put_mpint(pubn.n)
  298. def encode_private(
  299. self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
  300. ) -> None:
  301. """Write RSA private key"""
  302. private_numbers = private_key.private_numbers()
  303. public_numbers = private_numbers.public_numbers
  304. f_priv.put_mpint(public_numbers.n)
  305. f_priv.put_mpint(public_numbers.e)
  306. f_priv.put_mpint(private_numbers.d)
  307. f_priv.put_mpint(private_numbers.iqmp)
  308. f_priv.put_mpint(private_numbers.p)
  309. f_priv.put_mpint(private_numbers.q)
  310. class _SSHFormatDSA:
  311. """Format for DSA keys.
  312. Public:
  313. mpint p, q, g, y
  314. Private:
  315. mpint p, q, g, y, x
  316. """
  317. def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
  318. """DSA public fields"""
  319. p, data = _get_mpint(data)
  320. q, data = _get_mpint(data)
  321. g, data = _get_mpint(data)
  322. y, data = _get_mpint(data)
  323. return (p, q, g, y), data
  324. def load_public(
  325. self, data: memoryview
  326. ) -> tuple[dsa.DSAPublicKey, memoryview]:
  327. """Make DSA public key from data."""
  328. (p, q, g, y), data = self.get_public(data)
  329. parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
  330. public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
  331. self._validate(public_numbers)
  332. public_key = public_numbers.public_key()
  333. return public_key, data
  334. def load_private(
  335. self, data: memoryview, pubfields
  336. ) -> tuple[dsa.DSAPrivateKey, memoryview]:
  337. """Make DSA private key from data."""
  338. (p, q, g, y), data = self.get_public(data)
  339. x, data = _get_mpint(data)
  340. if (p, q, g, y) != pubfields:
  341. raise ValueError("Corrupt data: dsa field mismatch")
  342. parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
  343. public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
  344. self._validate(public_numbers)
  345. private_numbers = dsa.DSAPrivateNumbers(x, public_numbers)
  346. private_key = private_numbers.private_key()
  347. return private_key, data
  348. def encode_public(
  349. self, public_key: dsa.DSAPublicKey, f_pub: _FragList
  350. ) -> None:
  351. """Write DSA public key"""
  352. public_numbers = public_key.public_numbers()
  353. parameter_numbers = public_numbers.parameter_numbers
  354. self._validate(public_numbers)
  355. f_pub.put_mpint(parameter_numbers.p)
  356. f_pub.put_mpint(parameter_numbers.q)
  357. f_pub.put_mpint(parameter_numbers.g)
  358. f_pub.put_mpint(public_numbers.y)
  359. def encode_private(
  360. self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
  361. ) -> None:
  362. """Write DSA private key"""
  363. self.encode_public(private_key.public_key(), f_priv)
  364. f_priv.put_mpint(private_key.private_numbers().x)
  365. def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
  366. parameter_numbers = public_numbers.parameter_numbers
  367. if parameter_numbers.p.bit_length() != 1024:
  368. raise ValueError("SSH supports only 1024 bit DSA keys")
  369. class _SSHFormatECDSA:
  370. """Format for ECDSA keys.
  371. Public:
  372. str curve
  373. bytes point
  374. Private:
  375. str curve
  376. bytes point
  377. mpint secret
  378. """
  379. def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
  380. self.ssh_curve_name = ssh_curve_name
  381. self.curve = curve
  382. def get_public(
  383. self, data: memoryview
  384. ) -> tuple[tuple[memoryview, memoryview], memoryview]:
  385. """ECDSA public fields"""
  386. curve, data = _get_sshstr(data)
  387. point, data = _get_sshstr(data)
  388. if curve != self.ssh_curve_name:
  389. raise ValueError("Curve name mismatch")
  390. if point[0] != 4:
  391. raise NotImplementedError("Need uncompressed point")
  392. return (curve, point), data
  393. def load_public(
  394. self, data: memoryview
  395. ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
  396. """Make ECDSA public key from data."""
  397. (_, point), data = self.get_public(data)
  398. public_key = ec.EllipticCurvePublicKey.from_encoded_point(
  399. self.curve, point.tobytes()
  400. )
  401. return public_key, data
  402. def load_private(
  403. self, data: memoryview, pubfields
  404. ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
  405. """Make ECDSA private key from data."""
  406. (curve_name, point), data = self.get_public(data)
  407. secret, data = _get_mpint(data)
  408. if (curve_name, point) != pubfields:
  409. raise ValueError("Corrupt data: ecdsa field mismatch")
  410. private_key = ec.derive_private_key(secret, self.curve)
  411. return private_key, data
  412. def encode_public(
  413. self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
  414. ) -> None:
  415. """Write ECDSA public key"""
  416. point = public_key.public_bytes(
  417. Encoding.X962, PublicFormat.UncompressedPoint
  418. )
  419. f_pub.put_sshstr(self.ssh_curve_name)
  420. f_pub.put_sshstr(point)
  421. def encode_private(
  422. self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
  423. ) -> None:
  424. """Write ECDSA private key"""
  425. public_key = private_key.public_key()
  426. private_numbers = private_key.private_numbers()
  427. self.encode_public(public_key, f_priv)
  428. f_priv.put_mpint(private_numbers.private_value)
  429. class _SSHFormatEd25519:
  430. """Format for Ed25519 keys.
  431. Public:
  432. bytes point
  433. Private:
  434. bytes point
  435. bytes secret_and_point
  436. """
  437. def get_public(
  438. self, data: memoryview
  439. ) -> tuple[tuple[memoryview], memoryview]:
  440. """Ed25519 public fields"""
  441. point, data = _get_sshstr(data)
  442. return (point,), data
  443. def load_public(
  444. self, data: memoryview
  445. ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
  446. """Make Ed25519 public key from data."""
  447. (point,), data = self.get_public(data)
  448. public_key = ed25519.Ed25519PublicKey.from_public_bytes(
  449. point.tobytes()
  450. )
  451. return public_key, data
  452. def load_private(
  453. self, data: memoryview, pubfields
  454. ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
  455. """Make Ed25519 private key from data."""
  456. (point,), data = self.get_public(data)
  457. keypair, data = _get_sshstr(data)
  458. secret = keypair[:32]
  459. point2 = keypair[32:]
  460. if point != point2 or (point,) != pubfields:
  461. raise ValueError("Corrupt data: ed25519 field mismatch")
  462. private_key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
  463. return private_key, data
  464. def encode_public(
  465. self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
  466. ) -> None:
  467. """Write Ed25519 public key"""
  468. raw_public_key = public_key.public_bytes(
  469. Encoding.Raw, PublicFormat.Raw
  470. )
  471. f_pub.put_sshstr(raw_public_key)
  472. def encode_private(
  473. self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
  474. ) -> None:
  475. """Write Ed25519 private key"""
  476. public_key = private_key.public_key()
  477. raw_private_key = private_key.private_bytes(
  478. Encoding.Raw, PrivateFormat.Raw, NoEncryption()
  479. )
  480. raw_public_key = public_key.public_bytes(
  481. Encoding.Raw, PublicFormat.Raw
  482. )
  483. f_keypair = _FragList([raw_private_key, raw_public_key])
  484. self.encode_public(public_key, f_priv)
  485. f_priv.put_sshstr(f_keypair)
  486. def load_application(data) -> tuple[memoryview, memoryview]:
  487. """
  488. U2F application strings
  489. """
  490. application, data = _get_sshstr(data)
  491. if not application.tobytes().startswith(b"ssh:"):
  492. raise ValueError(
  493. "U2F application string does not start with b'ssh:' "
  494. f"({application})"
  495. )
  496. return application, data
  497. class _SSHFormatSKEd25519:
  498. """
  499. The format of a sk-ssh-ed25519@openssh.com public key is:
  500. string "sk-ssh-ed25519@openssh.com"
  501. string public key
  502. string application (user-specified, but typically "ssh:")
  503. """
  504. def load_public(
  505. self, data: memoryview
  506. ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
  507. """Make Ed25519 public key from data."""
  508. public_key, data = _lookup_kformat(_SSH_ED25519).load_public(data)
  509. _, data = load_application(data)
  510. return public_key, data
  511. class _SSHFormatSKECDSA:
  512. """
  513. The format of a sk-ecdsa-sha2-nistp256@openssh.com public key is:
  514. string "sk-ecdsa-sha2-nistp256@openssh.com"
  515. string curve name
  516. ec_point Q
  517. string application (user-specified, but typically "ssh:")
  518. """
  519. def load_public(
  520. self, data: memoryview
  521. ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
  522. """Make ECDSA public key from data."""
  523. public_key, data = _lookup_kformat(_ECDSA_NISTP256).load_public(data)
  524. _, data = load_application(data)
  525. return public_key, data
# Registry mapping SSH key type names to the handler that parses/encodes
# that type's wire fields. The "sk-" handlers define only load_public, so
# they support loading security-key public keys and nothing else.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
    _SK_SSH_ED25519: _SSHFormatSKEd25519(),
    _SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(),
}
  536. def _lookup_kformat(key_type: bytes):
  537. """Return valid format or throw error"""
  538. if not isinstance(key_type, bytes):
  539. key_type = memoryview(key_type).tobytes()
  540. if key_type in _KEY_FORMATS:
  541. return _KEY_FORMATS[key_type]
  542. raise UnsupportedAlgorithm(f"Unsupported key type: {key_type!r}")
# Private key types accepted by _serialize_ssh_private_key and returned by
# load_ssh_private_key.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
  549. def load_ssh_private_key(
  550. data: bytes,
  551. password: bytes | None,
  552. backend: typing.Any = None,
  553. ) -> SSHPrivateKeyTypes:
  554. """Load private key from OpenSSH custom encoding."""
  555. utils._check_byteslike("data", data)
  556. if password is not None:
  557. utils._check_bytes("password", password)
  558. m = _PEM_RC.search(data)
  559. if not m:
  560. raise ValueError("Not OpenSSH private key format")
  561. p1 = m.start(1)
  562. p2 = m.end(1)
  563. data = binascii.a2b_base64(memoryview(data)[p1:p2])
  564. if not data.startswith(_SK_MAGIC):
  565. raise ValueError("Not OpenSSH private key format")
  566. data = memoryview(data)[len(_SK_MAGIC) :]
  567. # parse header
  568. ciphername, data = _get_sshstr(data)
  569. kdfname, data = _get_sshstr(data)
  570. kdfoptions, data = _get_sshstr(data)
  571. nkeys, data = _get_u32(data)
  572. if nkeys != 1:
  573. raise ValueError("Only one key supported")
  574. # load public key data
  575. pubdata, data = _get_sshstr(data)
  576. pub_key_type, pubdata = _get_sshstr(pubdata)
  577. kformat = _lookup_kformat(pub_key_type)
  578. pubfields, pubdata = kformat.get_public(pubdata)
  579. _check_empty(pubdata)
  580. if (ciphername, kdfname) != (_NONE, _NONE):
  581. ciphername_bytes = ciphername.tobytes()
  582. if ciphername_bytes not in _SSH_CIPHERS:
  583. raise UnsupportedAlgorithm(
  584. f"Unsupported cipher: {ciphername_bytes!r}"
  585. )
  586. if kdfname != _BCRYPT:
  587. raise UnsupportedAlgorithm(f"Unsupported KDF: {kdfname!r}")
  588. blklen = _SSH_CIPHERS[ciphername_bytes].block_len
  589. tag_len = _SSH_CIPHERS[ciphername_bytes].tag_len
  590. # load secret data
  591. edata, data = _get_sshstr(data)
  592. # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for
  593. # information about how OpenSSH handles AEAD tags
  594. if _SSH_CIPHERS[ciphername_bytes].is_aead:
  595. tag = bytes(data)
  596. if len(tag) != tag_len:
  597. raise ValueError("Corrupt data: invalid tag length for cipher")
  598. else:
  599. _check_empty(data)
  600. _check_block_size(edata, blklen)
  601. salt, kbuf = _get_sshstr(kdfoptions)
  602. rounds, kbuf = _get_u32(kbuf)
  603. _check_empty(kbuf)
  604. ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
  605. dec = ciph.decryptor()
  606. edata = memoryview(dec.update(edata))
  607. if _SSH_CIPHERS[ciphername_bytes].is_aead:
  608. assert isinstance(dec, AEADDecryptionContext)
  609. _check_empty(dec.finalize_with_tag(tag))
  610. else:
  611. # _check_block_size requires data to be a full block so there
  612. # should be no output from finalize
  613. _check_empty(dec.finalize())
  614. else:
  615. # load secret data
  616. edata, data = _get_sshstr(data)
  617. _check_empty(data)
  618. blklen = 8
  619. _check_block_size(edata, blklen)
  620. ck1, edata = _get_u32(edata)
  621. ck2, edata = _get_u32(edata)
  622. if ck1 != ck2:
  623. raise ValueError("Corrupt data: broken checksum")
  624. # load per-key struct
  625. key_type, edata = _get_sshstr(edata)
  626. if key_type != pub_key_type:
  627. raise ValueError("Corrupt data: key type mismatch")
  628. private_key, edata = kformat.load_private(edata, pubfields)
  629. # We don't use the comment
  630. _, edata = _get_sshstr(edata)
  631. # yes, SSH does padding check *after* all other parsing is done.
  632. # need to follow as it writes zero-byte padding too.
  633. if edata != _PADDING[: len(edata)]:
  634. raise ValueError("Corrupt data: invalid padding")
  635. if isinstance(private_key, dsa.DSAPrivateKey):
  636. warnings.warn(
  637. "SSH DSA keys are deprecated and will be removed in a future "
  638. "release.",
  639. utils.DeprecatedIn40,
  640. stacklevel=2,
  641. )
  642. return private_key
def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize a private key in OpenSSH's ``openssh-key-v1`` format.

    An empty *password* writes an unencrypted key (cipher and KDF
    ``none``). Otherwise the private section is encrypted with
    ``_DEFAULT_CIPHER`` under a bcrypt-derived key, using a random 16-byte
    salt and ``_DEFAULT_ROUNDS`` rounds unless *encryption_algorithm* is a
    ``_KeySerializationEncryption`` carrying ``_kdf_rounds``.

    Returns:
        The PEM-armored key bytes.

    Raises:
        ValueError: for unsupported key types or EC curves.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername].block_len
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        salt = os.urandom(16)
        # kdfoptions = string salt, u32 rounds
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # Random check value, written twice; the loader compares the copies.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad with 1, 2, 3, ... to a block multiple (a full block is added
    # when the size is already aligned).
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result info bytearray
    slen = f_secrets.size()
    mlen = f_main.size()
    # blklen extra bytes of headroom: the update_into output view below
    # (buf[ofs:]) extends past mlen, presumably to satisfy its minimum
    # output buffer size.
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # the secrets section is the last slen bytes of the rendered structure
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    return _ssh_pem_encode(buf[:mlen])
# Public key types with an SSH public key encoding in this module.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key types exposed by SSHCertificate.public_key(); DSA is excluded.
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]


class SSHCertificateType(enum.Enum):
    """Kind of SSH certificate; the values are the integer certificate
    type codes that SSHCertificate converts from its ``_cctype``."""

    USER = 1
    HOST = 2
  724. class SSHCertificate:
  725. def __init__(
  726. self,
  727. _nonce: memoryview,
  728. _public_key: SSHPublicKeyTypes,
  729. _serial: int,
  730. _cctype: int,
  731. _key_id: memoryview,
  732. _valid_principals: list[bytes],
  733. _valid_after: int,
  734. _valid_before: int,
  735. _critical_options: dict[bytes, bytes],
  736. _extensions: dict[bytes, bytes],
  737. _sig_type: memoryview,
  738. _sig_key: memoryview,
  739. _inner_sig_type: memoryview,
  740. _signature: memoryview,
  741. _tbs_cert_body: memoryview,
  742. _cert_key_type: bytes,
  743. _cert_body: memoryview,
  744. ):
  745. self._nonce = _nonce
  746. self._public_key = _public_key
  747. self._serial = _serial
  748. try:
  749. self._type = SSHCertificateType(_cctype)
  750. except ValueError:
  751. raise ValueError("Invalid certificate type")
  752. self._key_id = _key_id
  753. self._valid_principals = _valid_principals
  754. self._valid_after = _valid_after
  755. self._valid_before = _valid_before
  756. self._critical_options = _critical_options
  757. self._extensions = _extensions
  758. self._sig_type = _sig_type
  759. self._sig_key = _sig_key
  760. self._inner_sig_type = _inner_sig_type
  761. self._signature = _signature
  762. self._cert_key_type = _cert_key_type
  763. self._cert_body = _cert_body
  764. self._tbs_cert_body = _tbs_cert_body
  765. @property
  766. def nonce(self) -> bytes:
  767. return bytes(self._nonce)
  768. def public_key(self) -> SSHCertPublicKeyTypes:
  769. # make mypy happy until we remove DSA support entirely and
  770. # the underlying union won't have a disallowed type
  771. return typing.cast(SSHCertPublicKeyTypes, self._public_key)
  772. @property
  773. def serial(self) -> int:
  774. return self._serial
  775. @property
  776. def type(self) -> SSHCertificateType:
  777. return self._type
  778. @property
  779. def key_id(self) -> bytes:
  780. return bytes(self._key_id)
  781. @property
  782. def valid_principals(self) -> list[bytes]:
  783. return self._valid_principals
  784. @property
  785. def valid_before(self) -> int:
  786. return self._valid_before
  787. @property
  788. def valid_after(self) -> int:
  789. return self._valid_after
  790. @property
  791. def critical_options(self) -> dict[bytes, bytes]:
  792. return self._critical_options
  793. @property
  794. def extensions(self) -> dict[bytes, bytes]:
  795. return self._extensions
  796. def signature_key(self) -> SSHCertPublicKeyTypes:
  797. sigformat = _lookup_kformat(self._sig_type)
  798. signature_key, sigkey_rest = sigformat.load_public(self._sig_key)
  799. _check_empty(sigkey_rest)
  800. return signature_key
  801. def public_bytes(self) -> bytes:
  802. return (
  803. bytes(self._cert_key_type)
  804. + b" "
  805. + binascii.b2a_base64(bytes(self._cert_body), newline=False)
  806. )
  807. def verify_cert_signature(self) -> None:
  808. signature_key = self.signature_key()
  809. if isinstance(signature_key, ed25519.Ed25519PublicKey):
  810. signature_key.verify(
  811. bytes(self._signature), bytes(self._tbs_cert_body)
  812. )
  813. elif isinstance(signature_key, ec.EllipticCurvePublicKey):
  814. # The signature is encoded as a pair of big-endian integers
  815. r, data = _get_mpint(self._signature)
  816. s, data = _get_mpint(data)
  817. _check_empty(data)
  818. computed_sig = asym_utils.encode_dss_signature(r, s)
  819. hash_alg = _get_ec_hash_alg(signature_key.curve)
  820. signature_key.verify(
  821. computed_sig, bytes(self._tbs_cert_body), ec.ECDSA(hash_alg)
  822. )
  823. else:
  824. assert isinstance(signature_key, rsa.RSAPublicKey)
  825. if self._inner_sig_type == _SSH_RSA:
  826. hash_alg = hashes.SHA1()
  827. elif self._inner_sig_type == _SSH_RSA_SHA256:
  828. hash_alg = hashes.SHA256()
  829. else:
  830. assert self._inner_sig_type == _SSH_RSA_SHA512
  831. hash_alg = hashes.SHA512()
  832. signature_key.verify(
  833. bytes(self._signature),
  834. bytes(self._tbs_cert_body),
  835. padding.PKCS1v15(),
  836. hash_alg,
  837. )
  838. def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
  839. if isinstance(curve, ec.SECP256R1):
  840. return hashes.SHA256()
  841. elif isinstance(curve, ec.SECP384R1):
  842. return hashes.SHA384()
  843. else:
  844. assert isinstance(curve, ec.SECP521R1)
  845. return hashes.SHA512()
  846. def _load_ssh_public_identity(
  847. data: bytes,
  848. _legacy_dsa_allowed=False,
  849. ) -> SSHCertificate | SSHPublicKeyTypes:
  850. utils._check_byteslike("data", data)
  851. m = _SSH_PUBKEY_RC.match(data)
  852. if not m:
  853. raise ValueError("Invalid line format")
  854. key_type = orig_key_type = m.group(1)
  855. key_body = m.group(2)
  856. with_cert = False
  857. if key_type.endswith(_CERT_SUFFIX):
  858. with_cert = True
  859. key_type = key_type[: -len(_CERT_SUFFIX)]
  860. if key_type == _SSH_DSA and not _legacy_dsa_allowed:
  861. raise UnsupportedAlgorithm(
  862. "DSA keys aren't supported in SSH certificates"
  863. )
  864. kformat = _lookup_kformat(key_type)
  865. try:
  866. rest = memoryview(binascii.a2b_base64(key_body))
  867. except (TypeError, binascii.Error):
  868. raise ValueError("Invalid format")
  869. if with_cert:
  870. cert_body = rest
  871. inner_key_type, rest = _get_sshstr(rest)
  872. if inner_key_type != orig_key_type:
  873. raise ValueError("Invalid key format")
  874. if with_cert:
  875. nonce, rest = _get_sshstr(rest)
  876. public_key, rest = kformat.load_public(rest)
  877. if with_cert:
  878. serial, rest = _get_u64(rest)
  879. cctype, rest = _get_u32(rest)
  880. key_id, rest = _get_sshstr(rest)
  881. principals, rest = _get_sshstr(rest)
  882. valid_principals = []
  883. while principals:
  884. principal, principals = _get_sshstr(principals)
  885. valid_principals.append(bytes(principal))
  886. valid_after, rest = _get_u64(rest)
  887. valid_before, rest = _get_u64(rest)
  888. crit_options, rest = _get_sshstr(rest)
  889. critical_options = _parse_exts_opts(crit_options)
  890. exts, rest = _get_sshstr(rest)
  891. extensions = _parse_exts_opts(exts)
  892. # Get the reserved field, which is unused.
  893. _, rest = _get_sshstr(rest)
  894. sig_key_raw, rest = _get_sshstr(rest)
  895. sig_type, sig_key = _get_sshstr(sig_key_raw)
  896. if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
  897. raise UnsupportedAlgorithm(
  898. "DSA signatures aren't supported in SSH certificates"
  899. )
  900. # Get the entire cert body and subtract the signature
  901. tbs_cert_body = cert_body[: -len(rest)]
  902. signature_raw, rest = _get_sshstr(rest)
  903. _check_empty(rest)
  904. inner_sig_type, sig_rest = _get_sshstr(signature_raw)
  905. # RSA certs can have multiple algorithm types
  906. if (
  907. sig_type == _SSH_RSA
  908. and inner_sig_type
  909. not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA]
  910. ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type):
  911. raise ValueError("Signature key type does not match")
  912. signature, sig_rest = _get_sshstr(sig_rest)
  913. _check_empty(sig_rest)
  914. return SSHCertificate(
  915. nonce,
  916. public_key,
  917. serial,
  918. cctype,
  919. key_id,
  920. valid_principals,
  921. valid_after,
  922. valid_before,
  923. critical_options,
  924. extensions,
  925. sig_type,
  926. sig_key,
  927. inner_sig_type,
  928. signature,
  929. tbs_cert_body,
  930. orig_key_type,
  931. cert_body,
  932. )
  933. else:
  934. _check_empty(rest)
  935. return public_key
  936. def load_ssh_public_identity(
  937. data: bytes,
  938. ) -> SSHCertificate | SSHPublicKeyTypes:
  939. return _load_ssh_public_identity(data)
  940. def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]:
  941. result: dict[bytes, bytes] = {}
  942. last_name = None
  943. while exts_opts:
  944. name, exts_opts = _get_sshstr(exts_opts)
  945. bname: bytes = bytes(name)
  946. if bname in result:
  947. raise ValueError("Duplicate name")
  948. if last_name is not None and bname < last_name:
  949. raise ValueError("Fields not lexically sorted")
  950. value, exts_opts = _get_sshstr(exts_opts)
  951. if len(value) > 0:
  952. value, extra = _get_sshstr(value)
  953. if len(extra) > 0:
  954. raise ValueError("Unexpected extra data after value")
  955. result[bname] = bytes(value)
  956. last_name = bname
  957. return result
  958. def load_ssh_public_key(
  959. data: bytes, backend: typing.Any = None
  960. ) -> SSHPublicKeyTypes:
  961. cert_or_key = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
  962. public_key: SSHPublicKeyTypes
  963. if isinstance(cert_or_key, SSHCertificate):
  964. public_key = cert_or_key.public_key()
  965. else:
  966. public_key = cert_or_key
  967. if isinstance(public_key, dsa.DSAPublicKey):
  968. warnings.warn(
  969. "SSH DSA keys are deprecated and will be removed in a future "
  970. "release.",
  971. utils.DeprecatedIn40,
  972. stacklevel=2,
  973. )
  974. return public_key
  975. def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
  976. """One-line public key format for OpenSSH"""
  977. if isinstance(public_key, dsa.DSAPublicKey):
  978. warnings.warn(
  979. "SSH DSA key support is deprecated and will be "
  980. "removed in a future release",
  981. utils.DeprecatedIn40,
  982. stacklevel=4,
  983. )
  984. key_type = _get_ssh_key_type(public_key)
  985. kformat = _lookup_kformat(key_type)
  986. f_pub = _FragList()
  987. f_pub.put_sshstr(key_type)
  988. kformat.encode_public(public_key, f_pub)
  989. pub = binascii.b2a_base64(f_pub.tobytes()).strip()
  990. return b"".join([key_type, b" ", pub])
  991. SSHCertPrivateKeyTypes = typing.Union[
  992. ec.EllipticCurvePrivateKey,
  993. rsa.RSAPrivateKey,
  994. ed25519.Ed25519PrivateKey,
  995. ]
  996. # This is an undocumented limit enforced in the openssh codebase for sshd and
  997. # ssh-keygen, but it is undefined in the ssh certificates spec.
  998. _SSHKEY_CERT_MAX_PRINCIPALS = 256
  999. class SSHCertificateBuilder:
  1000. def __init__(
  1001. self,
  1002. _public_key: SSHCertPublicKeyTypes | None = None,
  1003. _serial: int | None = None,
  1004. _type: SSHCertificateType | None = None,
  1005. _key_id: bytes | None = None,
  1006. _valid_principals: list[bytes] = [],
  1007. _valid_for_all_principals: bool = False,
  1008. _valid_before: int | None = None,
  1009. _valid_after: int | None = None,
  1010. _critical_options: list[tuple[bytes, bytes]] = [],
  1011. _extensions: list[tuple[bytes, bytes]] = [],
  1012. ):
  1013. self._public_key = _public_key
  1014. self._serial = _serial
  1015. self._type = _type
  1016. self._key_id = _key_id
  1017. self._valid_principals = _valid_principals
  1018. self._valid_for_all_principals = _valid_for_all_principals
  1019. self._valid_before = _valid_before
  1020. self._valid_after = _valid_after
  1021. self._critical_options = _critical_options
  1022. self._extensions = _extensions
  1023. def public_key(
  1024. self, public_key: SSHCertPublicKeyTypes
  1025. ) -> SSHCertificateBuilder:
  1026. if not isinstance(
  1027. public_key,
  1028. (
  1029. ec.EllipticCurvePublicKey,
  1030. rsa.RSAPublicKey,
  1031. ed25519.Ed25519PublicKey,
  1032. ),
  1033. ):
  1034. raise TypeError("Unsupported key type")
  1035. if self._public_key is not None:
  1036. raise ValueError("public_key already set")
  1037. return SSHCertificateBuilder(
  1038. _public_key=public_key,
  1039. _serial=self._serial,
  1040. _type=self._type,
  1041. _key_id=self._key_id,
  1042. _valid_principals=self._valid_principals,
  1043. _valid_for_all_principals=self._valid_for_all_principals,
  1044. _valid_before=self._valid_before,
  1045. _valid_after=self._valid_after,
  1046. _critical_options=self._critical_options,
  1047. _extensions=self._extensions,
  1048. )
  1049. def serial(self, serial: int) -> SSHCertificateBuilder:
  1050. if not isinstance(serial, int):
  1051. raise TypeError("serial must be an integer")
  1052. if not 0 <= serial < 2**64:
  1053. raise ValueError("serial must be between 0 and 2**64")
  1054. if self._serial is not None:
  1055. raise ValueError("serial already set")
  1056. return SSHCertificateBuilder(
  1057. _public_key=self._public_key,
  1058. _serial=serial,
  1059. _type=self._type,
  1060. _key_id=self._key_id,
  1061. _valid_principals=self._valid_principals,
  1062. _valid_for_all_principals=self._valid_for_all_principals,
  1063. _valid_before=self._valid_before,
  1064. _valid_after=self._valid_after,
  1065. _critical_options=self._critical_options,
  1066. _extensions=self._extensions,
  1067. )
  1068. def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
  1069. if not isinstance(type, SSHCertificateType):
  1070. raise TypeError("type must be an SSHCertificateType")
  1071. if self._type is not None:
  1072. raise ValueError("type already set")
  1073. return SSHCertificateBuilder(
  1074. _public_key=self._public_key,
  1075. _serial=self._serial,
  1076. _type=type,
  1077. _key_id=self._key_id,
  1078. _valid_principals=self._valid_principals,
  1079. _valid_for_all_principals=self._valid_for_all_principals,
  1080. _valid_before=self._valid_before,
  1081. _valid_after=self._valid_after,
  1082. _critical_options=self._critical_options,
  1083. _extensions=self._extensions,
  1084. )
  1085. def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
  1086. if not isinstance(key_id, bytes):
  1087. raise TypeError("key_id must be bytes")
  1088. if self._key_id is not None:
  1089. raise ValueError("key_id already set")
  1090. return SSHCertificateBuilder(
  1091. _public_key=self._public_key,
  1092. _serial=self._serial,
  1093. _type=self._type,
  1094. _key_id=key_id,
  1095. _valid_principals=self._valid_principals,
  1096. _valid_for_all_principals=self._valid_for_all_principals,
  1097. _valid_before=self._valid_before,
  1098. _valid_after=self._valid_after,
  1099. _critical_options=self._critical_options,
  1100. _extensions=self._extensions,
  1101. )
  1102. def valid_principals(
  1103. self, valid_principals: list[bytes]
  1104. ) -> SSHCertificateBuilder:
  1105. if self._valid_for_all_principals:
  1106. raise ValueError(
  1107. "Principals can't be set because the cert is valid "
  1108. "for all principals"
  1109. )
  1110. if (
  1111. not all(isinstance(x, bytes) for x in valid_principals)
  1112. or not valid_principals
  1113. ):
  1114. raise TypeError(
  1115. "principals must be a list of bytes and can't be empty"
  1116. )
  1117. if self._valid_principals:
  1118. raise ValueError("valid_principals already set")
  1119. if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
  1120. raise ValueError(
  1121. "Reached or exceeded the maximum number of valid_principals"
  1122. )
  1123. return SSHCertificateBuilder(
  1124. _public_key=self._public_key,
  1125. _serial=self._serial,
  1126. _type=self._type,
  1127. _key_id=self._key_id,
  1128. _valid_principals=valid_principals,
  1129. _valid_for_all_principals=self._valid_for_all_principals,
  1130. _valid_before=self._valid_before,
  1131. _valid_after=self._valid_after,
  1132. _critical_options=self._critical_options,
  1133. _extensions=self._extensions,
  1134. )
  1135. def valid_for_all_principals(self):
  1136. if self._valid_principals:
  1137. raise ValueError(
  1138. "valid_principals already set, can't set "
  1139. "valid_for_all_principals"
  1140. )
  1141. if self._valid_for_all_principals:
  1142. raise ValueError("valid_for_all_principals already set")
  1143. return SSHCertificateBuilder(
  1144. _public_key=self._public_key,
  1145. _serial=self._serial,
  1146. _type=self._type,
  1147. _key_id=self._key_id,
  1148. _valid_principals=self._valid_principals,
  1149. _valid_for_all_principals=True,
  1150. _valid_before=self._valid_before,
  1151. _valid_after=self._valid_after,
  1152. _critical_options=self._critical_options,
  1153. _extensions=self._extensions,
  1154. )
  1155. def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
  1156. if not isinstance(valid_before, (int, float)):
  1157. raise TypeError("valid_before must be an int or float")
  1158. valid_before = int(valid_before)
  1159. if valid_before < 0 or valid_before >= 2**64:
  1160. raise ValueError("valid_before must [0, 2**64)")
  1161. if self._valid_before is not None:
  1162. raise ValueError("valid_before already set")
  1163. return SSHCertificateBuilder(
  1164. _public_key=self._public_key,
  1165. _serial=self._serial,
  1166. _type=self._type,
  1167. _key_id=self._key_id,
  1168. _valid_principals=self._valid_principals,
  1169. _valid_for_all_principals=self._valid_for_all_principals,
  1170. _valid_before=valid_before,
  1171. _valid_after=self._valid_after,
  1172. _critical_options=self._critical_options,
  1173. _extensions=self._extensions,
  1174. )
  1175. def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
  1176. if not isinstance(valid_after, (int, float)):
  1177. raise TypeError("valid_after must be an int or float")
  1178. valid_after = int(valid_after)
  1179. if valid_after < 0 or valid_after >= 2**64:
  1180. raise ValueError("valid_after must [0, 2**64)")
  1181. if self._valid_after is not None:
  1182. raise ValueError("valid_after already set")
  1183. return SSHCertificateBuilder(
  1184. _public_key=self._public_key,
  1185. _serial=self._serial,
  1186. _type=self._type,
  1187. _key_id=self._key_id,
  1188. _valid_principals=self._valid_principals,
  1189. _valid_for_all_principals=self._valid_for_all_principals,
  1190. _valid_before=self._valid_before,
  1191. _valid_after=valid_after,
  1192. _critical_options=self._critical_options,
  1193. _extensions=self._extensions,
  1194. )
  1195. def add_critical_option(
  1196. self, name: bytes, value: bytes
  1197. ) -> SSHCertificateBuilder:
  1198. if not isinstance(name, bytes) or not isinstance(value, bytes):
  1199. raise TypeError("name and value must be bytes")
  1200. # This is O(n**2)
  1201. if name in [name for name, _ in self._critical_options]:
  1202. raise ValueError("Duplicate critical option name")
  1203. return SSHCertificateBuilder(
  1204. _public_key=self._public_key,
  1205. _serial=self._serial,
  1206. _type=self._type,
  1207. _key_id=self._key_id,
  1208. _valid_principals=self._valid_principals,
  1209. _valid_for_all_principals=self._valid_for_all_principals,
  1210. _valid_before=self._valid_before,
  1211. _valid_after=self._valid_after,
  1212. _critical_options=[*self._critical_options, (name, value)],
  1213. _extensions=self._extensions,
  1214. )
  1215. def add_extension(
  1216. self, name: bytes, value: bytes
  1217. ) -> SSHCertificateBuilder:
  1218. if not isinstance(name, bytes) or not isinstance(value, bytes):
  1219. raise TypeError("name and value must be bytes")
  1220. # This is O(n**2)
  1221. if name in [name for name, _ in self._extensions]:
  1222. raise ValueError("Duplicate extension name")
  1223. return SSHCertificateBuilder(
  1224. _public_key=self._public_key,
  1225. _serial=self._serial,
  1226. _type=self._type,
  1227. _key_id=self._key_id,
  1228. _valid_principals=self._valid_principals,
  1229. _valid_for_all_principals=self._valid_for_all_principals,
  1230. _valid_before=self._valid_before,
  1231. _valid_after=self._valid_after,
  1232. _critical_options=self._critical_options,
  1233. _extensions=[*self._extensions, (name, value)],
  1234. )
  1235. def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
  1236. if not isinstance(
  1237. private_key,
  1238. (
  1239. ec.EllipticCurvePrivateKey,
  1240. rsa.RSAPrivateKey,
  1241. ed25519.Ed25519PrivateKey,
  1242. ),
  1243. ):
  1244. raise TypeError("Unsupported private key type")
  1245. if self._public_key is None:
  1246. raise ValueError("public_key must be set")
  1247. # Not required
  1248. serial = 0 if self._serial is None else self._serial
  1249. if self._type is None:
  1250. raise ValueError("type must be set")
  1251. # Not required
  1252. key_id = b"" if self._key_id is None else self._key_id
  1253. # A zero length list is valid, but means the certificate
  1254. # is valid for any principal of the specified type. We require
  1255. # the user to explicitly set valid_for_all_principals to get
  1256. # that behavior.
  1257. if not self._valid_principals and not self._valid_for_all_principals:
  1258. raise ValueError(
  1259. "valid_principals must be set if valid_for_all_principals "
  1260. "is False"
  1261. )
  1262. if self._valid_before is None:
  1263. raise ValueError("valid_before must be set")
  1264. if self._valid_after is None:
  1265. raise ValueError("valid_after must be set")
  1266. if self._valid_after > self._valid_before:
  1267. raise ValueError("valid_after must be earlier than valid_before")
  1268. # lexically sort our byte strings
  1269. self._critical_options.sort(key=lambda x: x[0])
  1270. self._extensions.sort(key=lambda x: x[0])
  1271. key_type = _get_ssh_key_type(self._public_key)
  1272. cert_prefix = key_type + _CERT_SUFFIX
  1273. # Marshal the bytes to be signed
  1274. nonce = os.urandom(32)
  1275. kformat = _lookup_kformat(key_type)
  1276. f = _FragList()
  1277. f.put_sshstr(cert_prefix)
  1278. f.put_sshstr(nonce)
  1279. kformat.encode_public(self._public_key, f)
  1280. f.put_u64(serial)
  1281. f.put_u32(self._type.value)
  1282. f.put_sshstr(key_id)
  1283. fprincipals = _FragList()
  1284. for p in self._valid_principals:
  1285. fprincipals.put_sshstr(p)
  1286. f.put_sshstr(fprincipals.tobytes())
  1287. f.put_u64(self._valid_after)
  1288. f.put_u64(self._valid_before)
  1289. fcrit = _FragList()
  1290. for name, value in self._critical_options:
  1291. fcrit.put_sshstr(name)
  1292. if len(value) > 0:
  1293. foptval = _FragList()
  1294. foptval.put_sshstr(value)
  1295. fcrit.put_sshstr(foptval.tobytes())
  1296. else:
  1297. fcrit.put_sshstr(value)
  1298. f.put_sshstr(fcrit.tobytes())
  1299. fext = _FragList()
  1300. for name, value in self._extensions:
  1301. fext.put_sshstr(name)
  1302. if len(value) > 0:
  1303. fextval = _FragList()
  1304. fextval.put_sshstr(value)
  1305. fext.put_sshstr(fextval.tobytes())
  1306. else:
  1307. fext.put_sshstr(value)
  1308. f.put_sshstr(fext.tobytes())
  1309. f.put_sshstr(b"") # RESERVED FIELD
  1310. # encode CA public key
  1311. ca_type = _get_ssh_key_type(private_key)
  1312. caformat = _lookup_kformat(ca_type)
  1313. caf = _FragList()
  1314. caf.put_sshstr(ca_type)
  1315. caformat.encode_public(private_key.public_key(), caf)
  1316. f.put_sshstr(caf.tobytes())
  1317. # Sigs according to the rules defined for the CA's public key
  1318. # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
  1319. # and RFC8032 for Ed25519).
  1320. if isinstance(private_key, ed25519.Ed25519PrivateKey):
  1321. signature = private_key.sign(f.tobytes())
  1322. fsig = _FragList()
  1323. fsig.put_sshstr(ca_type)
  1324. fsig.put_sshstr(signature)
  1325. f.put_sshstr(fsig.tobytes())
  1326. elif isinstance(private_key, ec.EllipticCurvePrivateKey):
  1327. hash_alg = _get_ec_hash_alg(private_key.curve)
  1328. signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
  1329. r, s = asym_utils.decode_dss_signature(signature)
  1330. fsig = _FragList()
  1331. fsig.put_sshstr(ca_type)
  1332. fsigblob = _FragList()
  1333. fsigblob.put_mpint(r)
  1334. fsigblob.put_mpint(s)
  1335. fsig.put_sshstr(fsigblob.tobytes())
  1336. f.put_sshstr(fsig.tobytes())
  1337. else:
  1338. assert isinstance(private_key, rsa.RSAPrivateKey)
  1339. # Just like Golang, we're going to use SHA512 for RSA
  1340. # https://cs.opensource.google/go/x/crypto/+/refs/tags/
  1341. # v0.4.0:ssh/certs.go;l=445
  1342. # RFC 8332 defines SHA256 and 512 as options
  1343. fsig = _FragList()
  1344. fsig.put_sshstr(_SSH_RSA_SHA512)
  1345. signature = private_key.sign(
  1346. f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
  1347. )
  1348. fsig.put_sshstr(signature)
  1349. f.put_sshstr(fsig.tobytes())
  1350. cert_data = binascii.b2a_base64(f.tobytes()).strip()
  1351. # load_ssh_public_identity returns a union, but this is
  1352. # guaranteed to be an SSHCertificate, so we cast to make
  1353. # mypy happy.
  1354. return typing.cast(
  1355. SSHCertificate,
  1356. load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
  1357. )