12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569 |
- # This file is dual licensed under the terms of the Apache License, Version
- # 2.0, and the BSD License. See the LICENSE file in the root of this repository
- # for complete details.
- from __future__ import annotations
- import binascii
- import enum
- import os
- import re
- import typing
- import warnings
- from base64 import encodebytes as _base64_encode
- from dataclasses import dataclass
- from cryptography import utils
- from cryptography.exceptions import UnsupportedAlgorithm
- from cryptography.hazmat.primitives import hashes
- from cryptography.hazmat.primitives.asymmetric import (
- dsa,
- ec,
- ed25519,
- padding,
- rsa,
- )
- from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
- from cryptography.hazmat.primitives.ciphers import (
- AEADDecryptionContext,
- Cipher,
- algorithms,
- modes,
- )
- from cryptography.hazmat.primitives.serialization import (
- Encoding,
- KeySerializationEncryption,
- NoEncryption,
- PrivateFormat,
- PublicFormat,
- _KeySerializationEncryption,
- )
# bcrypt is optional: _bcrypt_kdf is only called by _init_cipher, i.e.
# when an encrypted private key is loaded or written.
try:
    from bcrypt import kdf as _bcrypt_kdf

    _bcrypt_supported = True
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Stand-in defined when the bcrypt module cannot be imported.

        Always raises UnsupportedAlgorithm.
        """
        raise UnsupportedAlgorithm("Need bcrypt module")
# SSH key type identifiers
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# A key type ending with this suffix denotes a certificate for that
# key type
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# U2F application string suffixed pubkey
_SK_SSH_ED25519 = b"sk-ssh-ed25519@openssh.com"
_SK_SSH_ECDSA_NISTP256 = b"sk-ecdsa-sha2-nistp256@openssh.com"

# These are not key types, only algorithms, so they cannot appear
# as a public key type
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"

# Matches "<key type> <key body>" at the start of a public key line
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")

# Magic prefix of the decoded private key blob, and the PEM-style
# markers that surround its base64 encoding
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"

# KDF / cipher names and the defaults used when writing an encrypted
# private key
_BCRYPT = b"bcrypt"
_NONE = b"none"
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# re is the only way to work on bytes-like data
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# padding for max blocksize: the byte sequence 1, 2, ..., 16
_PADDING = memoryview(bytearray(range(1, 1 + 16)))
@dataclass
class _SSHCipher:
    """Parameters of one cipher usable for private key encryption."""

    # algorithm class; instantiated with the derived key
    alg: type[algorithms.AES]
    # number of KDF output bytes used as the key
    key_len: int
    # mode class; instantiated with the derived IV
    mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
    # encrypted data must be a whole number of blocks of this size
    block_len: int
    # number of KDF output bytes used as the IV
    iv_len: int
    # expected authentication tag length, None when there is no tag
    tag_len: int | None
    # whether an authentication tag follows the encrypted data
    is_aead: bool
# ciphers that are actually used in key wrapping, keyed by the cipher
# name stored in the private key header
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CTR,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CBC,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.GCM,
        block_len=16,
        iv_len=12,
        tag_len=16,
        is_aead=True,
    ),
}
# map local curve name (EllipticCurve.name) to SSH key type
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}
def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
    """Return the SSH key type identifier for a private or public key.

    EC keys are resolved through their curve; every other supported
    key class maps to a fixed identifier. Raises ValueError when the
    key class is not supported.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    if isinstance(
        key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)
    ):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")
def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key type for the curve of an EC public key.

    Raises ValueError when the curve has no SSH key type.
    """
    curve_name = public_key.curve.name
    key_type = _ECDSA_KEY_TYPE.get(curve_name)
    if key_type is None:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        )
    return key_type
def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Base64-encode data and wrap it between prefix and suffix."""
    encoded = _base64_encode(data)
    return b"".join((prefix, encoded, suffix))
- def _check_block_size(data: bytes, block_len: int) -> None:
- """Require data to be full blocks"""
- if not data or len(data) % block_len != 0:
- raise ValueError("Corrupt data: missing padding")
- def _check_empty(data: bytes) -> None:
- """All data should have been parsed."""
- if data:
- raise ValueError("Corrupt data: unparsed data")
def _init_cipher(
    ciphername: bytes,
    password: bytes | None,
    salt: bytes,
    rounds: int,
) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
    """Derive key and IV from the password and build the cipher.

    The bcrypt KDF output is split into the key (first key_len bytes)
    and the IV (the remainder). Raises ValueError when no password is
    given.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    spec = _SSH_CIPHERS[ciphername]
    key_len = spec.key_len
    material = _bcrypt_kdf(
        password, salt, key_len + spec.iv_len, rounds, True
    )
    key = material[:key_len]
    iv = material[key_len:]
    return Cipher(spec.alg(key), spec.mode(iv))
- def _get_u32(data: memoryview) -> tuple[int, memoryview]:
- """Uint32"""
- if len(data) < 4:
- raise ValueError("Invalid data")
- return int.from_bytes(data[:4], byteorder="big"), data[4:]
- def _get_u64(data: memoryview) -> tuple[int, memoryview]:
- """Uint64"""
- if len(data) < 8:
- raise ValueError("Invalid data")
- return int.from_bytes(data[:8], byteorder="big"), data[8:]
def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
    """Read a u32 length-prefixed byte string.

    Returns the string and the remaining data; raises ValueError when
    the declared length exceeds the available data.
    """
    length, rest = _get_u32(data)
    if len(rest) < length:
        raise ValueError("Invalid data")
    return rest[:length], rest[length:]
def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
    """Read a length-prefixed big-endian integer.

    Values whose first byte has the high bit set are rejected with
    ValueError.
    """
    raw, rest = _get_sshstr(data)
    if len(raw) > 0 and raw[0] >= 0x80:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), rest
- def _to_mpint(val: int) -> bytes:
- """Storage format for signed bigint."""
- if val < 0:
- raise ValueError("negative mpint not allowed")
- if not val:
- return b""
- nbytes = (val.bit_length() + 8) // 8
- return utils.int_to_bytes(val, nbytes)
- class _FragList:
- """Build recursive structure without data copy."""
- flist: list[bytes]
- def __init__(self, init: list[bytes] | None = None) -> None:
- self.flist = []
- if init:
- self.flist.extend(init)
- def put_raw(self, val: bytes) -> None:
- """Add plain bytes"""
- self.flist.append(val)
- def put_u32(self, val: int) -> None:
- """Big-endian uint32"""
- self.flist.append(val.to_bytes(length=4, byteorder="big"))
- def put_u64(self, val: int) -> None:
- """Big-endian uint64"""
- self.flist.append(val.to_bytes(length=8, byteorder="big"))
- def put_sshstr(self, val: bytes | _FragList) -> None:
- """Bytes prefixed with u32 length"""
- if isinstance(val, (bytes, memoryview, bytearray)):
- self.put_u32(len(val))
- self.flist.append(val)
- else:
- self.put_u32(val.size())
- self.flist.extend(val.flist)
- def put_mpint(self, val: int) -> None:
- """Big-endian bigint prefixed with u32 length"""
- self.put_sshstr(_to_mpint(val))
- def size(self) -> int:
- """Current number of bytes"""
- return sum(map(len, self.flist))
- def render(self, dstbuf: memoryview, pos: int = 0) -> int:
- """Write into bytearray"""
- for frag in self.flist:
- flen = len(frag)
- start, pos = pos, pos + flen
- dstbuf[start:pos] = frag
- return pos
- def tobytes(self) -> bytes:
- """Return as bytes"""
- buf = memoryview(bytearray(self.size()))
- self.render(buf)
- return buf.tobytes()
class _SSHFormatRSA:
    """Format for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[int, int], memoryview]:
        """Read the public fields (e, n) and return the rest."""
        exponent, data = _get_mpint(data)
        modulus, data = _get_mpint(data)
        return (exponent, modulus), data

    def load_public(
        self, data: memoryview
    ) -> tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from data."""
        (exponent, modulus), rest = self.get_public(data)
        return rsa.RSAPublicNumbers(exponent, modulus).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key from data.

        Raises ValueError when (e, n) disagree with pubfields.
        """
        # fields are stored in this order: n, e, d, iqmp, p, q
        values = []
        for _ in range(6):
            value, data = _get_mpint(data)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")

        # the CRT exponents are not stored; recompute them
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, rsa.RSAPublicNumbers(e, n)
        )
        return numbers.private_key(), data

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the RSA public fields (e, n) to f_pub."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the RSA private fields to f_priv."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
class _SSHFormatDSA:
    """Format for DSA keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Read the public fields (p, q, g, y) and return the rest."""
        values = []
        for _ in range(4):
            value, data = _get_mpint(data)
            values.append(value)
        return tuple(values), data

    def load_public(
        self, data: memoryview
    ) -> tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key from data."""
        (p, q, g, y), rest = self.get_public(data)
        numbers = dsa.DSAPublicNumbers(y, dsa.DSAParameterNumbers(p, q, g))
        self._validate(numbers)
        return numbers.public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key from data.

        Raises ValueError when the embedded public fields disagree
        with pubfields.
        """
        fields, data = self.get_public(data)
        x, data = _get_mpint(data)
        p, q, g, y = fields

        if fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")

        pub_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(pub_numbers)
        return dsa.DSAPrivateNumbers(x, pub_numbers).private_key(), data

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the DSA public fields (p, q, g, y) to f_pub."""
        numbers = public_key.public_numbers()
        params = numbers.parameter_numbers
        self._validate(numbers)

        for value in (params.p, params.q, params.g, numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the DSA public fields followed by x to f_priv."""
        self.encode_public(private_key.public_key(), f_priv)
        secret = private_key.private_numbers().x
        f_priv.put_mpint(secret)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Raise ValueError unless p is exactly 1024 bits long."""
        p_bits = public_numbers.parameter_numbers.p.bit_length()
        if p_bits != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")
class _SSHFormatECDSA:
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        # curve name as it appears inside the key data, and the curve
        # object used to build keys
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview, memoryview], memoryview]:
        """Read the public fields (curve name, point).

        Raises ValueError when the curve name does not match this
        format or the point is empty, and NotImplementedError when the
        point is not in uncompressed form.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # A zero-length point would otherwise fail below with an
        # IndexError; report it like any other malformed input.
        if not point:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Make ECDSA public key from data."""
        (_, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Make ECDSA private key from data.

        Raises ValueError when the embedded public fields disagree
        with pubfields.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Write ECDSA public key (curve name, uncompressed point)."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Write ECDSA private key (public fields, then the secret)."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
class _SSHFormatEd25519:
    """Format for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview], memoryview]:
        """Read the single public field (point) and return the rest."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from data."""
        (point,), rest = self.get_public(data)
        raw = point.tobytes()
        return ed25519.Ed25519PublicKey.from_public_bytes(raw), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key from data.

        Raises ValueError when the point stored next to the secret, or
        the leading point, disagrees with pubfields.
        """
        (point,), data = self.get_public(data)
        keypair, data = _get_sshstr(data)

        # keypair is the 32-byte secret followed by the public point
        secret, trailing_point = keypair[:32], keypair[32:]
        consistent = point == trailing_point and (point,) == pubfields
        if not consistent:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), data

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw Ed25519 public key to f_pub."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public key, then secret + public key, to f_priv."""
        public_key = private_key.public_key()
        secret_raw = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        public_raw = public_key.public_bytes(
            Encoding.Raw, PublicFormat.Raw
        )
        keypair = _FragList([secret_raw, public_raw])

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(keypair)
def load_application(data: memoryview) -> tuple[memoryview, memoryview]:
    """
    U2F application strings

    Read a length-prefixed application string from data and return it
    together with the remaining data. Raises ValueError when the
    string does not start with b"ssh:".
    """
    application, data = _get_sshstr(data)
    application_bytes = application.tobytes()
    if not application_bytes.startswith(b"ssh:"):
        # Format the bytes rather than the memoryview, whose string
        # form is only "<memory at 0x...>" and hides the offending
        # value.
        raise ValueError(
            "U2F application string does not start with b'ssh:' "
            f"({application_bytes!r})"
        )
    return application, data
class _SSHFormatSKEd25519:
    """
    The format of a sk-ssh-ed25519@openssh.com public key is:

        string		"sk-ssh-ed25519@openssh.com"
        string		public key
        string		application (user-specified, but typically "ssh:")
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key, then consume the application."""
        inner_format = _lookup_kformat(_SSH_ED25519)
        public_key, rest = inner_format.load_public(data)
        # the application string is checked but not kept
        _application, rest = load_application(rest)
        return public_key, rest
class _SSHFormatSKECDSA:
    """
    The format of a sk-ecdsa-sha2-nistp256@openssh.com public key is:

        string		"sk-ecdsa-sha2-nistp256@openssh.com"
        string		curve name
        ec_point	Q
        string		application (user-specified, but typically "ssh:")
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Build an ECDSA public key, then consume the application."""
        inner_format = _lookup_kformat(_ECDSA_NISTP256)
        public_key, rest = inner_format.load_public(data)
        # the application string is checked but not kept
        _application, rest = load_application(rest)
        return public_key, rest
# map SSH key type to the format object that parses/encodes it
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
    # the sk-* formats only implement load_public
    _SK_SSH_ED25519: _SSHFormatSKEd25519(),
    _SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(),
}
def _lookup_kformat(key_type: bytes):
    """Return the format object for key_type.

    key_type may be any bytes-like object. Raises UnsupportedAlgorithm
    when the key type is not known.
    """
    if isinstance(key_type, bytes):
        name = key_type
    else:
        name = memoryview(key_type).tobytes()
    kformat = _KEY_FORMATS.get(name)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return kformat
# private key classes accepted and returned by the SSH private key
# load/serialize functions
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
def load_ssh_private_key(
    data: bytes,
    password: bytes | None,
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load private key from OpenSSH custom encoding.

    ``data`` must contain a ``-----BEGIN OPENSSH PRIVATE KEY-----``
    block holding exactly one key. ``password`` is needed when the key
    is encrypted. ``backend`` is accepted but not used.

    Raises ValueError for malformed or corrupt data (including a
    missing password for an encrypted key) and UnsupportedAlgorithm
    for an unknown key type, cipher or KDF. Loading a DSA key emits a
    deprecation warning.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # kdfname is a memoryview; format its bytes so the message
            # names the KDF instead of showing "<memory at 0x...>"
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        cipher_spec = _SSH_CIPHERS[ciphername_bytes]
        blklen = cipher_spec.block_len
        tag_len = cipher_spec.tag_len
        # load secret data
        edata, data = _get_sshstr(data)
        # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for
        # information about how OpenSSH handles AEAD tags
        if cipher_spec.is_aead:
            # the tag is whatever follows the encrypted string
            tag = bytes(data)
            if len(tag) != tag_len:
                raise ValueError("Corrupt data: invalid tag length for cipher")
        else:
            _check_empty(data)
        _check_block_size(edata, blklen)
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        dec = ciph.decryptor()
        edata = memoryview(dec.update(edata))
        if cipher_spec.is_aead:
            assert isinstance(dec, AEADDecryptionContext)
            _check_empty(dec.finalize_with_tag(tag))
        else:
            # _check_block_size requires data to be a full block so there
            # should be no output from finalize
            _check_empty(dec.finalize())
    else:
        # load secret data
        edata, data = _get_sshstr(data)
        _check_empty(data)
        blklen = 8
        _check_block_size(edata, blklen)

    # the secret section starts with the same u32 twice; a mismatch
    # is reported as a broken checksum
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    # We don't use the comment
    _, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key
def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    An empty password produces an unencrypted key. Otherwise the
    secret section is encrypted with the default cipher, using a key
    derived by bcrypt from the password and a random salt; the round
    count comes from encryption_algorithm when it specifies one.
    Serializing a DSA key emits a deprecation warning.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername].block_len
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # random value written twice at the start of the secret section
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # pad to a multiple of blklen; an already aligned section still
    # gets one full block of padding
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray
    slen = f_secrets.size()
    mlen = f_main.size()
    # blklen spare bytes at the end give update_into extra output room
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # the secret section is the last slen bytes of the structure
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    return _ssh_pem_encode(buf[:mlen])
# public key classes handled by the SSH public key functions
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# public key classes exposed by certificates: same as above, minus DSA
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]
class SSHCertificateType(enum.Enum):
    """Certificate type; values match the type code in a certificate."""

    USER = 1
    HOST = 2
class SSHCertificate:
    """A parsed SSH certificate.

    Keeps the decoded fields together with the raw signed portion and
    the full body, which are used to verify the signature and to
    re-emit the certificate.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: list[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: dict[bytes, bytes],
        _extensions: dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        # an unknown type code is rejected before anything is stored
        try:
            cert_type = SSHCertificateType(_cctype)
        except ValueError:
            raise ValueError("Invalid certificate type")
        self._type = cert_type

        # certificate fields
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions

        # signature data
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature

        # raw encodings
        self._tbs_cert_body = _tbs_cert_body
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body

    @property
    def nonce(self) -> bytes:
        """The certificate nonce, copied to bytes."""
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        """Return the public key contained in the certificate."""
        # make mypy happy until we remove DSA support entirely and
        # the underlying union won't have a disallowed type
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        """The certificate serial number."""
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        """The certificate type."""
        return self._type

    @property
    def key_id(self) -> bytes:
        """The key id, copied to bytes."""
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> list[bytes]:
        """The list of valid principals."""
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        """The valid_before value of the certificate."""
        return self._valid_before

    @property
    def valid_after(self) -> int:
        """The valid_after value of the certificate."""
        return self._valid_after

    @property
    def critical_options(self) -> dict[bytes, bytes]:
        """The critical options mapping."""
        return self._critical_options

    @property
    def extensions(self) -> dict[bytes, bytes]:
        """The extensions mapping."""
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Load and return the public key that signed the certificate."""
        key_format = _lookup_kformat(self._sig_type)
        signer, leftover = key_format.load_public(self._sig_key)
        _check_empty(leftover)
        return signer

    def public_bytes(self) -> bytes:
        """Return "<cert key type> <base64 body>" as bytes."""
        encoded_body = binascii.b2a_base64(
            bytes(self._cert_body), newline=False
        )
        return b" ".join((bytes(self._cert_key_type), encoded_body))

    def verify_cert_signature(self) -> None:
        """Verify the signature over the signed portion of the body.

        Returns None on success; a failed check surfaces as whatever
        the signing key's verify() raises.
        """
        signer = self.signature_key()
        signed_body = bytes(self._tbs_cert_body)

        if isinstance(signer, ed25519.Ed25519PublicKey):
            signer.verify(bytes(self._signature), signed_body)
            return

        if isinstance(signer, ec.EllipticCurvePublicKey):
            # The signature is encoded as a pair of big-endian integers
            r, rest = _get_mpint(self._signature)
            s, rest = _get_mpint(rest)
            _check_empty(rest)
            dss_signature = asym_utils.encode_dss_signature(r, s)
            digest = _get_ec_hash_alg(signer.curve)
            signer.verify(dss_signature, signed_body, ec.ECDSA(digest))
            return

        assert isinstance(signer, rsa.RSAPublicKey)
        # the inner signature type selects the hash
        inner_type = self._inner_sig_type
        if inner_type == _SSH_RSA:
            digest = hashes.SHA1()
        elif inner_type == _SSH_RSA_SHA256:
            digest = hashes.SHA256()
        else:
            assert inner_type == _SSH_RSA_SHA512
            digest = hashes.SHA512()
        signer.verify(
            bytes(self._signature),
            signed_body,
            padding.PKCS1v15(),
            digest,
        )
def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the hash algorithm paired with an ECDSA curve.

    SECP256R1 maps to SHA-256, SECP384R1 to SHA-384 and SECP521R1 to
    SHA-512.  Any other curve trips the assertion.
    """
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()
def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed=False,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse a one-line OpenSSH public key or certificate.

    ``data`` must match the ``<key type> <base64 body>`` line pattern.
    A key type ending in the certificate suffix is parsed as a
    certificate and returned as an ``SSHCertificate``; anything else is
    returned as a bare public key.

    DSA keys and DSA certificate signatures raise
    ``UnsupportedAlgorithm`` unless ``_legacy_dsa_allowed`` is true.
    Malformed input raises ``ValueError``.
    """
    utils._check_byteslike("data", data)

    match = _SSH_PUBKEY_RC.match(data)
    if match is None:
        raise ValueError("Invalid line format")
    orig_key_type = match.group(1)
    key_body = match.group(2)

    # A certificate's type is the plain key type plus a fixed suffix.
    with_cert = orig_key_type.endswith(_CERT_SUFFIX)
    if with_cert:
        key_type = orig_key_type[: -len(_CERT_SUFFIX)]
    else:
        key_type = orig_key_type

    if not _legacy_dsa_allowed and key_type == _SSH_DSA:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        blob = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    # The type named inside the blob must agree with the one on the line.
    inner_key_type, rest = _get_sshstr(blob)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")
    if with_cert:
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)

    if not with_cert:
        _check_empty(rest)
        return public_key

    serial, rest = _get_u64(rest)
    cctype, rest = _get_u32(rest)
    key_id, rest = _get_sshstr(rest)

    # Principals are a sequence of strings packed inside one string.
    packed_principals, rest = _get_sshstr(rest)
    valid_principals = []
    while packed_principals:
        principal, packed_principals = _get_sshstr(packed_principals)
        valid_principals.append(bytes(principal))

    valid_after, rest = _get_u64(rest)
    valid_before, rest = _get_u64(rest)
    raw_critical_options, rest = _get_sshstr(rest)
    critical_options = _parse_exts_opts(raw_critical_options)
    raw_extensions, rest = _get_sshstr(rest)
    extensions = _parse_exts_opts(raw_extensions)
    # Get the reserved field, which is unused.
    _, rest = _get_sshstr(rest)

    sig_key_raw, rest = _get_sshstr(rest)
    sig_type, sig_key = _get_sshstr(sig_key_raw)
    if not _legacy_dsa_allowed and sig_type == _SSH_DSA:
        raise UnsupportedAlgorithm(
            "DSA signatures aren't supported in SSH certificates"
        )

    # Get the entire cert body and subtract the signature
    tbs_cert_body = blob[: -len(rest)]
    signature_raw, rest = _get_sshstr(rest)
    _check_empty(rest)
    inner_sig_type, sig_rest = _get_sshstr(signature_raw)

    # RSA certs can have multiple algorithm types; every other key type
    # must be signed with exactly its own type.
    if sig_type == _SSH_RSA:
        sig_type_matches = inner_sig_type in [
            _SSH_RSA_SHA256,
            _SSH_RSA_SHA512,
            _SSH_RSA,
        ]
    else:
        sig_type_matches = inner_sig_type == sig_type
    if not sig_type_matches:
        raise ValueError("Signature key type does not match")

    signature, sig_rest = _get_sshstr(sig_rest)
    _check_empty(sig_rest)
    return SSHCertificate(
        nonce,
        public_key,
        serial,
        cctype,
        key_id,
        valid_principals,
        valid_after,
        valid_before,
        critical_options,
        extensions,
        sig_type,
        sig_key,
        inner_sig_type,
        signature,
        tbs_cert_body,
        orig_key_type,
        blob,
    )
def load_ssh_public_identity(
    data: bytes,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Load an OpenSSH public key or certificate from ``data``.

    Delegates to ``_load_ssh_public_identity`` with legacy DSA support
    left disabled.
    """
    return _load_ssh_public_identity(data, _legacy_dsa_allowed=False)
- def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]:
- result: dict[bytes, bytes] = {}
- last_name = None
- while exts_opts:
- name, exts_opts = _get_sshstr(exts_opts)
- bname: bytes = bytes(name)
- if bname in result:
- raise ValueError("Duplicate name")
- if last_name is not None and bname < last_name:
- raise ValueError("Fields not lexically sorted")
- value, exts_opts = _get_sshstr(exts_opts)
- if len(value) > 0:
- value, extra = _get_sshstr(value)
- if len(extra) > 0:
- raise ValueError("Unexpected extra data after value")
- result[bname] = bytes(value)
- last_name = bname
- return result
def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load an OpenSSH public key, unwrapping certificates.

    The line is parsed with legacy DSA support enabled.  If it holds a
    certificate, the certificate's public key is returned instead of
    the certificate.  A ``DeprecatedIn40`` warning is issued when the
    resulting key is a DSA key.  ``backend`` is not used.
    """
    identity = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        identity.public_key()
        if isinstance(identity, SSHCertificate)
        else identity
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key
def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """One-line public key format for OpenSSH

    Returns ``<key type> <base64 blob>``, where the blob is the key
    type string followed by the key's encoded public fields.  A
    ``DeprecatedIn40`` warning is issued for DSA keys.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(public_key)
    kformat = _lookup_kformat(key_type)

    blob = _FragList()
    blob.put_sshstr(key_type)
    kformat.encode_public(public_key, blob)

    encoded_blob = binascii.b2a_base64(blob.tobytes()).strip()
    return key_type + b" " + encoded_blob
# Private key types accepted as the signing key for an SSH certificate.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]

# This is an undocumented limit enforced in the openssh codebase for sshd and
# ssh-keygen, but it is undefined in the ssh certificates spec.
_SSHKEY_CERT_MAX_PRINCIPALS = 256
- class SSHCertificateBuilder:
- def __init__(
- self,
- _public_key: SSHCertPublicKeyTypes | None = None,
- _serial: int | None = None,
- _type: SSHCertificateType | None = None,
- _key_id: bytes | None = None,
- _valid_principals: list[bytes] = [],
- _valid_for_all_principals: bool = False,
- _valid_before: int | None = None,
- _valid_after: int | None = None,
- _critical_options: list[tuple[bytes, bytes]] = [],
- _extensions: list[tuple[bytes, bytes]] = [],
- ):
- self._public_key = _public_key
- self._serial = _serial
- self._type = _type
- self._key_id = _key_id
- self._valid_principals = _valid_principals
- self._valid_for_all_principals = _valid_for_all_principals
- self._valid_before = _valid_before
- self._valid_after = _valid_after
- self._critical_options = _critical_options
- self._extensions = _extensions
- def public_key(
- self, public_key: SSHCertPublicKeyTypes
- ) -> SSHCertificateBuilder:
- if not isinstance(
- public_key,
- (
- ec.EllipticCurvePublicKey,
- rsa.RSAPublicKey,
- ed25519.Ed25519PublicKey,
- ),
- ):
- raise TypeError("Unsupported key type")
- if self._public_key is not None:
- raise ValueError("public_key already set")
- return SSHCertificateBuilder(
- _public_key=public_key,
- _serial=self._serial,
- _type=self._type,
- _key_id=self._key_id,
- _valid_principals=self._valid_principals,
- _valid_for_all_principals=self._valid_for_all_principals,
- _valid_before=self._valid_before,
- _valid_after=self._valid_after,
- _critical_options=self._critical_options,
- _extensions=self._extensions,
- )
- def serial(self, serial: int) -> SSHCertificateBuilder:
- if not isinstance(serial, int):
- raise TypeError("serial must be an integer")
- if not 0 <= serial < 2**64:
- raise ValueError("serial must be between 0 and 2**64")
- if self._serial is not None:
- raise ValueError("serial already set")
- return SSHCertificateBuilder(
- _public_key=self._public_key,
- _serial=serial,
- _type=self._type,
- _key_id=self._key_id,
- _valid_principals=self._valid_principals,
- _valid_for_all_principals=self._valid_for_all_principals,
- _valid_before=self._valid_before,
- _valid_after=self._valid_after,
- _critical_options=self._critical_options,
- _extensions=self._extensions,
- )
- def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
- if not isinstance(type, SSHCertificateType):
- raise TypeError("type must be an SSHCertificateType")
- if self._type is not None:
- raise ValueError("type already set")
- return SSHCertificateBuilder(
- _public_key=self._public_key,
- _serial=self._serial,
- _type=type,
- _key_id=self._key_id,
- _valid_principals=self._valid_principals,
- _valid_for_all_principals=self._valid_for_all_principals,
- _valid_before=self._valid_before,
- _valid_after=self._valid_after,
- _critical_options=self._critical_options,
- _extensions=self._extensions,
- )
- def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
- if not isinstance(key_id, bytes):
- raise TypeError("key_id must be bytes")
- if self._key_id is not None:
- raise ValueError("key_id already set")
- return SSHCertificateBuilder(
- _public_key=self._public_key,
- _serial=self._serial,
- _type=self._type,
- _key_id=key_id,
- _valid_principals=self._valid_principals,
- _valid_for_all_principals=self._valid_for_all_principals,
- _valid_before=self._valid_before,
- _valid_after=self._valid_after,
- _critical_options=self._critical_options,
- _extensions=self._extensions,
- )
- def valid_principals(
- self, valid_principals: list[bytes]
- ) -> SSHCertificateBuilder:
- if self._valid_for_all_principals:
- raise ValueError(
- "Principals can't be set because the cert is valid "
- "for all principals"
- )
- if (
- not all(isinstance(x, bytes) for x in valid_principals)
- or not valid_principals
- ):
- raise TypeError(
- "principals must be a list of bytes and can't be empty"
- )
- if self._valid_principals:
- raise ValueError("valid_principals already set")
- if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
- raise ValueError(
- "Reached or exceeded the maximum number of valid_principals"
- )
- return SSHCertificateBuilder(
- _public_key=self._public_key,
- _serial=self._serial,
- _type=self._type,
- _key_id=self._key_id,
- _valid_principals=valid_principals,
- _valid_for_all_principals=self._valid_for_all_principals,
- _valid_before=self._valid_before,
- _valid_after=self._valid_after,
- _critical_options=self._critical_options,
- _extensions=self._extensions,
- )
- def valid_for_all_principals(self):
- if self._valid_principals:
- raise ValueError(
- "valid_principals already set, can't set "
- "valid_for_all_principals"
- )
- if self._valid_for_all_principals:
- raise ValueError("valid_for_all_principals already set")
- return SSHCertificateBuilder(
- _public_key=self._public_key,
- _serial=self._serial,
- _type=self._type,
- _key_id=self._key_id,
- _valid_principals=self._valid_principals,
- _valid_for_all_principals=True,
- _valid_before=self._valid_before,
- _valid_after=self._valid_after,
- _critical_options=self._critical_options,
- _extensions=self._extensions,
- )
- def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
- if not isinstance(valid_before, (int, float)):
- raise TypeError("valid_before must be an int or float")
- valid_before = int(valid_before)
- if valid_before < 0 or valid_before >= 2**64:
- raise ValueError("valid_before must [0, 2**64)")
- if self._valid_before is not None:
- raise ValueError("valid_before already set")
- return SSHCertificateBuilder(
- _public_key=self._public_key,
- _serial=self._serial,
- _type=self._type,
- _key_id=self._key_id,
- _valid_principals=self._valid_principals,
- _valid_for_all_principals=self._valid_for_all_principals,
- _valid_before=valid_before,
- _valid_after=self._valid_after,
- _critical_options=self._critical_options,
- _extensions=self._extensions,
- )
- def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
- if not isinstance(valid_after, (int, float)):
- raise TypeError("valid_after must be an int or float")
- valid_after = int(valid_after)
- if valid_after < 0 or valid_after >= 2**64:
- raise ValueError("valid_after must [0, 2**64)")
- if self._valid_after is not None:
- raise ValueError("valid_after already set")
- return SSHCertificateBuilder(
- _public_key=self._public_key,
- _serial=self._serial,
- _type=self._type,
- _key_id=self._key_id,
- _valid_principals=self._valid_principals,
- _valid_for_all_principals=self._valid_for_all_principals,
- _valid_before=self._valid_before,
- _valid_after=valid_after,
- _critical_options=self._critical_options,
- _extensions=self._extensions,
- )
- def add_critical_option(
- self, name: bytes, value: bytes
- ) -> SSHCertificateBuilder:
- if not isinstance(name, bytes) or not isinstance(value, bytes):
- raise TypeError("name and value must be bytes")
- # This is O(n**2)
- if name in [name for name, _ in self._critical_options]:
- raise ValueError("Duplicate critical option name")
- return SSHCertificateBuilder(
- _public_key=self._public_key,
- _serial=self._serial,
- _type=self._type,
- _key_id=self._key_id,
- _valid_principals=self._valid_principals,
- _valid_for_all_principals=self._valid_for_all_principals,
- _valid_before=self._valid_before,
- _valid_after=self._valid_after,
- _critical_options=[*self._critical_options, (name, value)],
- _extensions=self._extensions,
- )
- def add_extension(
- self, name: bytes, value: bytes
- ) -> SSHCertificateBuilder:
- if not isinstance(name, bytes) or not isinstance(value, bytes):
- raise TypeError("name and value must be bytes")
- # This is O(n**2)
- if name in [name for name, _ in self._extensions]:
- raise ValueError("Duplicate extension name")
- return SSHCertificateBuilder(
- _public_key=self._public_key,
- _serial=self._serial,
- _type=self._type,
- _key_id=self._key_id,
- _valid_principals=self._valid_principals,
- _valid_for_all_principals=self._valid_for_all_principals,
- _valid_before=self._valid_before,
- _valid_after=self._valid_after,
- _critical_options=self._critical_options,
- _extensions=[*self._extensions, (name, value)],
- )
- def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
- if not isinstance(
- private_key,
- (
- ec.EllipticCurvePrivateKey,
- rsa.RSAPrivateKey,
- ed25519.Ed25519PrivateKey,
- ),
- ):
- raise TypeError("Unsupported private key type")
- if self._public_key is None:
- raise ValueError("public_key must be set")
- # Not required
- serial = 0 if self._serial is None else self._serial
- if self._type is None:
- raise ValueError("type must be set")
- # Not required
- key_id = b"" if self._key_id is None else self._key_id
- # A zero length list is valid, but means the certificate
- # is valid for any principal of the specified type. We require
- # the user to explicitly set valid_for_all_principals to get
- # that behavior.
- if not self._valid_principals and not self._valid_for_all_principals:
- raise ValueError(
- "valid_principals must be set if valid_for_all_principals "
- "is False"
- )
- if self._valid_before is None:
- raise ValueError("valid_before must be set")
- if self._valid_after is None:
- raise ValueError("valid_after must be set")
- if self._valid_after > self._valid_before:
- raise ValueError("valid_after must be earlier than valid_before")
- # lexically sort our byte strings
- self._critical_options.sort(key=lambda x: x[0])
- self._extensions.sort(key=lambda x: x[0])
- key_type = _get_ssh_key_type(self._public_key)
- cert_prefix = key_type + _CERT_SUFFIX
- # Marshal the bytes to be signed
- nonce = os.urandom(32)
- kformat = _lookup_kformat(key_type)
- f = _FragList()
- f.put_sshstr(cert_prefix)
- f.put_sshstr(nonce)
- kformat.encode_public(self._public_key, f)
- f.put_u64(serial)
- f.put_u32(self._type.value)
- f.put_sshstr(key_id)
- fprincipals = _FragList()
- for p in self._valid_principals:
- fprincipals.put_sshstr(p)
- f.put_sshstr(fprincipals.tobytes())
- f.put_u64(self._valid_after)
- f.put_u64(self._valid_before)
- fcrit = _FragList()
- for name, value in self._critical_options:
- fcrit.put_sshstr(name)
- if len(value) > 0:
- foptval = _FragList()
- foptval.put_sshstr(value)
- fcrit.put_sshstr(foptval.tobytes())
- else:
- fcrit.put_sshstr(value)
- f.put_sshstr(fcrit.tobytes())
- fext = _FragList()
- for name, value in self._extensions:
- fext.put_sshstr(name)
- if len(value) > 0:
- fextval = _FragList()
- fextval.put_sshstr(value)
- fext.put_sshstr(fextval.tobytes())
- else:
- fext.put_sshstr(value)
- f.put_sshstr(fext.tobytes())
- f.put_sshstr(b"") # RESERVED FIELD
- # encode CA public key
- ca_type = _get_ssh_key_type(private_key)
- caformat = _lookup_kformat(ca_type)
- caf = _FragList()
- caf.put_sshstr(ca_type)
- caformat.encode_public(private_key.public_key(), caf)
- f.put_sshstr(caf.tobytes())
- # Sigs according to the rules defined for the CA's public key
- # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
- # and RFC8032 for Ed25519).
- if isinstance(private_key, ed25519.Ed25519PrivateKey):
- signature = private_key.sign(f.tobytes())
- fsig = _FragList()
- fsig.put_sshstr(ca_type)
- fsig.put_sshstr(signature)
- f.put_sshstr(fsig.tobytes())
- elif isinstance(private_key, ec.EllipticCurvePrivateKey):
- hash_alg = _get_ec_hash_alg(private_key.curve)
- signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
- r, s = asym_utils.decode_dss_signature(signature)
- fsig = _FragList()
- fsig.put_sshstr(ca_type)
- fsigblob = _FragList()
- fsigblob.put_mpint(r)
- fsigblob.put_mpint(s)
- fsig.put_sshstr(fsigblob.tobytes())
- f.put_sshstr(fsig.tobytes())
- else:
- assert isinstance(private_key, rsa.RSAPrivateKey)
- # Just like Golang, we're going to use SHA512 for RSA
- # https://cs.opensource.google/go/x/crypto/+/refs/tags/
- # v0.4.0:ssh/certs.go;l=445
- # RFC 8332 defines SHA256 and 512 as options
- fsig = _FragList()
- fsig.put_sshstr(_SSH_RSA_SHA512)
- signature = private_key.sign(
- f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
- )
- fsig.put_sshstr(signature)
- f.put_sshstr(fsig.tobytes())
- cert_data = binascii.b2a_base64(f.tobytes()).strip()
- # load_ssh_public_identity returns a union, but this is
- # guaranteed to be an SSHCertificate, so we cast to make
- # mypy happy.
- return typing.cast(
- SSHCertificate,
- load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
- )
|