venv added, updated

This commit is contained in:
Norbert
2024-09-13 09:46:28 +02:00
parent 577596d9f3
commit 82af8c809a
4812 changed files with 640223 additions and 2 deletions

View File

@@ -0,0 +1,35 @@
from .oct_key import OctKey
from .rsa_key import RSAKey
from .ec_key import ECKey
from .jws_algs import JWS_ALGORITHMS
from .jwe_algs import JWE_ALG_ALGORITHMS, AESAlgorithm, ECDHESAlgorithm, u32be_len_input
from .jwe_encs import JWE_ENC_ALGORITHMS, CBCHS2EncAlgorithm
from .jwe_zips import DeflateZipAlgorithm
def register_jws_rfc7518(cls):
for algorithm in JWS_ALGORITHMS:
cls.register_algorithm(algorithm)
def register_jwe_rfc7518(cls):
for algorithm in JWE_ALG_ALGORITHMS:
cls.register_algorithm(algorithm)
for algorithm in JWE_ENC_ALGORITHMS:
cls.register_algorithm(algorithm)
cls.register_algorithm(DeflateZipAlgorithm())
__all__ = [
'register_jws_rfc7518',
'register_jwe_rfc7518',
'OctKey',
'RSAKey',
'ECKey',
'u32be_len_input',
'AESAlgorithm',
'ECDHESAlgorithm',
'CBCHS2EncAlgorithm',
]

View File

@@ -0,0 +1,101 @@
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization,
EllipticCurvePrivateNumbers, EllipticCurvePublicNumbers,
SECP256R1, SECP384R1, SECP521R1, SECP256K1,
)
from cryptography.hazmat.backends import default_backend
from authlib.common.encoding import base64_to_int, int_to_base64
from ..rfc7517 import AsymmetricKey
class ECKey(AsymmetricKey):
"""Key class of the ``EC`` key type."""
kty = 'EC'
DSS_CURVES = {
'P-256': SECP256R1,
'P-384': SECP384R1,
'P-521': SECP521R1,
# https://tools.ietf.org/html/rfc8812#section-3.1
'secp256k1': SECP256K1,
}
CURVES_DSS = {
SECP256R1.name: 'P-256',
SECP384R1.name: 'P-384',
SECP521R1.name: 'P-521',
SECP256K1.name: 'secp256k1',
}
REQUIRED_JSON_FIELDS = ['crv', 'x', 'y']
PUBLIC_KEY_FIELDS = REQUIRED_JSON_FIELDS
PRIVATE_KEY_FIELDS = ['crv', 'd', 'x', 'y']
PUBLIC_KEY_CLS = EllipticCurvePublicKey
PRIVATE_KEY_CLS = EllipticCurvePrivateKeyWithSerialization
SSH_PUBLIC_PREFIX = b'ecdsa-sha2-'
def exchange_shared_key(self, pubkey):
# # used in ECDHESAlgorithm
private_key = self.get_private_key()
if private_key:
return private_key.exchange(ec.ECDH(), pubkey)
raise ValueError('Invalid key for exchanging shared key')
@property
def curve_key_size(self):
raw_key = self.get_private_key()
if not raw_key:
raw_key = self.public_key
return raw_key.curve.key_size
def load_private_key(self):
curve = self.DSS_CURVES[self._dict_data['crv']]()
public_numbers = EllipticCurvePublicNumbers(
base64_to_int(self._dict_data['x']),
base64_to_int(self._dict_data['y']),
curve,
)
private_numbers = EllipticCurvePrivateNumbers(
base64_to_int(self.tokens['d']),
public_numbers
)
return private_numbers.private_key(default_backend())
def load_public_key(self):
curve = self.DSS_CURVES[self._dict_data['crv']]()
public_numbers = EllipticCurvePublicNumbers(
base64_to_int(self._dict_data['x']),
base64_to_int(self._dict_data['y']),
curve,
)
return public_numbers.public_key(default_backend())
def dumps_private_key(self):
numbers = self.private_key.private_numbers()
return {
'crv': self.CURVES_DSS[self.private_key.curve.name],
'x': int_to_base64(numbers.public_numbers.x),
'y': int_to_base64(numbers.public_numbers.y),
'd': int_to_base64(numbers.private_value),
}
def dumps_public_key(self):
numbers = self.public_key.public_numbers()
return {
'crv': self.CURVES_DSS[numbers.curve.name],
'x': int_to_base64(numbers.x),
'y': int_to_base64(numbers.y)
}
@classmethod
def generate_key(cls, crv='P-256', options=None, is_private=False) -> 'ECKey':
if crv not in cls.DSS_CURVES:
raise ValueError(f'Invalid crv value: "{crv}"')
raw_key = ec.generate_private_key(
curve=cls.DSS_CURVES[crv](),
backend=default_backend(),
)
if not is_private:
raw_key = raw_key.public_key()
return cls.import_key(raw_key, options=options)

View File

@@ -0,0 +1,349 @@
import os
import struct
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.keywrap import (
aes_key_wrap,
aes_key_unwrap
)
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import GCM
from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash
from authlib.common.encoding import (
to_bytes, to_native,
urlsafe_b64decode,
urlsafe_b64encode
)
from authlib.jose.rfc7516 import JWEAlgorithm
from .rsa_key import RSAKey
from .ec_key import ECKey
from .oct_key import OctKey
class DirectAlgorithm(JWEAlgorithm):
name = 'dir'
description = 'Direct use of a shared symmetric key'
def prepare_key(self, raw_data):
return OctKey.import_key(raw_data)
def generate_preset(self, enc_alg, key):
return {}
def wrap(self, enc_alg, headers, key, preset=None):
cek = key.get_op_key('encrypt')
if len(cek) * 8 != enc_alg.CEK_SIZE:
raise ValueError('Invalid "cek" length')
return {'ek': b'', 'cek': cek}
def unwrap(self, enc_alg, ek, headers, key):
cek = key.get_op_key('decrypt')
if len(cek) * 8 != enc_alg.CEK_SIZE:
raise ValueError('Invalid "cek" length')
return cek
class RSAAlgorithm(JWEAlgorithm):
#: A key of size 2048 bits or larger MUST be used with these algorithms
#: RSA1_5, RSA-OAEP, RSA-OAEP-256
key_size = 2048
def __init__(self, name, description, pad_fn):
self.name = name
self.description = description
self.padding = pad_fn
def prepare_key(self, raw_data):
return RSAKey.import_key(raw_data)
def generate_preset(self, enc_alg, key):
cek = enc_alg.generate_cek()
return {'cek': cek}
def wrap(self, enc_alg, headers, key, preset=None):
if preset and 'cek' in preset:
cek = preset['cek']
else:
cek = enc_alg.generate_cek()
op_key = key.get_op_key('wrapKey')
if op_key.key_size < self.key_size:
raise ValueError('A key of size 2048 bits or larger MUST be used')
ek = op_key.encrypt(cek, self.padding)
return {'ek': ek, 'cek': cek}
def unwrap(self, enc_alg, ek, headers, key):
# it will raise ValueError if failed
op_key = key.get_op_key('unwrapKey')
cek = op_key.decrypt(ek, self.padding)
if len(cek) * 8 != enc_alg.CEK_SIZE:
raise ValueError('Invalid "cek" length')
return cek
class AESAlgorithm(JWEAlgorithm):
def __init__(self, key_size):
self.name = f'A{key_size}KW'
self.description = f'AES Key Wrap using {key_size}-bit key'
self.key_size = key_size
def prepare_key(self, raw_data):
return OctKey.import_key(raw_data)
def generate_preset(self, enc_alg, key):
cek = enc_alg.generate_cek()
return {'cek': cek}
def _check_key(self, key):
if len(key) * 8 != self.key_size:
raise ValueError(
f'A key of size {self.key_size} bits is required.')
def wrap_cek(self, cek, key):
op_key = key.get_op_key('wrapKey')
self._check_key(op_key)
ek = aes_key_wrap(op_key, cek, default_backend())
return {'ek': ek, 'cek': cek}
def wrap(self, enc_alg, headers, key, preset=None):
if preset and 'cek' in preset:
cek = preset['cek']
else:
cek = enc_alg.generate_cek()
return self.wrap_cek(cek, key)
def unwrap(self, enc_alg, ek, headers, key):
op_key = key.get_op_key('unwrapKey')
self._check_key(op_key)
cek = aes_key_unwrap(op_key, ek, default_backend())
if len(cek) * 8 != enc_alg.CEK_SIZE:
raise ValueError('Invalid "cek" length')
return cek
class AESGCMAlgorithm(JWEAlgorithm):
EXTRA_HEADERS = frozenset(['iv', 'tag'])
def __init__(self, key_size):
self.name = f'A{key_size}GCMKW'
self.description = f'Key wrapping with AES GCM using {key_size}-bit key'
self.key_size = key_size
def prepare_key(self, raw_data):
return OctKey.import_key(raw_data)
def generate_preset(self, enc_alg, key):
cek = enc_alg.generate_cek()
return {'cek': cek}
def _check_key(self, key):
if len(key) * 8 != self.key_size:
raise ValueError(
f'A key of size {self.key_size} bits is required.')
def wrap(self, enc_alg, headers, key, preset=None):
if preset and 'cek' in preset:
cek = preset['cek']
else:
cek = enc_alg.generate_cek()
op_key = key.get_op_key('wrapKey')
self._check_key(op_key)
#: https://tools.ietf.org/html/rfc7518#section-4.7.1.1
#: The "iv" (initialization vector) Header Parameter value is the
#: base64url-encoded representation of the 96-bit IV value
iv_size = 96
iv = os.urandom(iv_size // 8)
cipher = Cipher(AES(op_key), GCM(iv), backend=default_backend())
enc = cipher.encryptor()
ek = enc.update(cek) + enc.finalize()
h = {
'iv': to_native(urlsafe_b64encode(iv)),
'tag': to_native(urlsafe_b64encode(enc.tag))
}
return {'ek': ek, 'cek': cek, 'header': h}
def unwrap(self, enc_alg, ek, headers, key):
op_key = key.get_op_key('unwrapKey')
self._check_key(op_key)
iv = headers.get('iv')
if not iv:
raise ValueError('Missing "iv" in headers')
tag = headers.get('tag')
if not tag:
raise ValueError('Missing "tag" in headers')
iv = urlsafe_b64decode(to_bytes(iv))
tag = urlsafe_b64decode(to_bytes(tag))
cipher = Cipher(AES(op_key), GCM(iv, tag), backend=default_backend())
d = cipher.decryptor()
cek = d.update(ek) + d.finalize()
if len(cek) * 8 != enc_alg.CEK_SIZE:
raise ValueError('Invalid "cek" length')
return cek
class ECDHESAlgorithm(JWEAlgorithm):
EXTRA_HEADERS = ['epk', 'apu', 'apv']
ALLOWED_KEY_CLS = ECKey
# https://tools.ietf.org/html/rfc7518#section-4.6
def __init__(self, key_size=None):
if key_size is None:
self.name = 'ECDH-ES'
self.description = 'ECDH-ES in the Direct Key Agreement mode'
else:
self.name = f'ECDH-ES+A{key_size}KW'
self.description = (
'ECDH-ES using Concat KDF and CEK wrapped '
'with A{}KW').format(key_size)
self.key_size = key_size
self.aeskw = AESAlgorithm(key_size)
def prepare_key(self, raw_data):
if isinstance(raw_data, self.ALLOWED_KEY_CLS):
return raw_data
return ECKey.import_key(raw_data)
def generate_preset(self, enc_alg, key):
epk = self._generate_ephemeral_key(key)
h = self._prepare_headers(epk)
preset = {'epk': epk, 'header': h}
if self.key_size is not None:
cek = enc_alg.generate_cek()
preset['cek'] = cek
return preset
def compute_fixed_info(self, headers, bit_size):
# AlgorithmID
if self.key_size is None:
alg_id = u32be_len_input(headers['enc'])
else:
alg_id = u32be_len_input(headers['alg'])
# PartyUInfo
apu_info = u32be_len_input(headers.get('apu'), True)
# PartyVInfo
apv_info = u32be_len_input(headers.get('apv'), True)
# SuppPubInfo
pub_info = struct.pack('>I', bit_size)
return alg_id + apu_info + apv_info + pub_info
def compute_derived_key(self, shared_key, fixed_info, bit_size):
ckdf = ConcatKDFHash(
algorithm=hashes.SHA256(),
length=bit_size // 8,
otherinfo=fixed_info,
backend=default_backend()
)
return ckdf.derive(shared_key)
def deliver(self, key, pubkey, headers, bit_size):
shared_key = key.exchange_shared_key(pubkey)
fixed_info = self.compute_fixed_info(headers, bit_size)
return self.compute_derived_key(shared_key, fixed_info, bit_size)
def _generate_ephemeral_key(self, key):
return key.generate_key(key['crv'], is_private=True)
def _prepare_headers(self, epk):
# REQUIRED_JSON_FIELDS contains only public fields
pub_epk = {k: epk[k] for k in epk.REQUIRED_JSON_FIELDS}
pub_epk['kty'] = epk.kty
return {'epk': pub_epk}
def wrap(self, enc_alg, headers, key, preset=None):
if self.key_size is None:
bit_size = enc_alg.CEK_SIZE
else:
bit_size = self.key_size
if preset and 'epk' in preset:
epk = preset['epk']
h = {}
else:
epk = self._generate_ephemeral_key(key)
h = self._prepare_headers(epk)
public_key = key.get_op_key('wrapKey')
dk = self.deliver(epk, public_key, headers, bit_size)
if self.key_size is None:
return {'ek': b'', 'cek': dk, 'header': h}
if preset and 'cek' in preset:
preset_for_kw = {'cek': preset['cek']}
else:
preset_for_kw = None
kek = self.aeskw.prepare_key(dk)
rv = self.aeskw.wrap(enc_alg, headers, kek, preset_for_kw)
rv['header'] = h
return rv
def unwrap(self, enc_alg, ek, headers, key):
if 'epk' not in headers:
raise ValueError('Missing "epk" in headers')
if self.key_size is None:
bit_size = enc_alg.CEK_SIZE
else:
bit_size = self.key_size
epk = key.import_key(headers['epk'])
public_key = epk.get_op_key('wrapKey')
dk = self.deliver(key, public_key, headers, bit_size)
if self.key_size is None:
return dk
kek = self.aeskw.prepare_key(dk)
return self.aeskw.unwrap(enc_alg, ek, headers, kek)
def u32be_len_input(s, base64=False):
if not s:
return b'\x00\x00\x00\x00'
if base64:
s = urlsafe_b64decode(to_bytes(s))
else:
s = to_bytes(s)
return struct.pack('>I', len(s)) + s
JWE_ALG_ALGORITHMS = [
DirectAlgorithm(), # dir
RSAAlgorithm('RSA1_5', 'RSAES-PKCS1-v1_5', padding.PKCS1v15()),
RSAAlgorithm(
'RSA-OAEP', 'RSAES OAEP using default parameters',
padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None)),
RSAAlgorithm(
'RSA-OAEP-256', 'RSAES OAEP using SHA-256 and MGF1 with SHA-256',
padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None)),
AESAlgorithm(128), # A128KW
AESAlgorithm(192), # A192KW
AESAlgorithm(256), # A256KW
AESGCMAlgorithm(128), # A128GCMKW
AESGCMAlgorithm(192), # A192GCMKW
AESGCMAlgorithm(256), # A256GCMKW
ECDHESAlgorithm(None), # ECDH-ES
ECDHESAlgorithm(128), # ECDH-ES+A128KW
ECDHESAlgorithm(192), # ECDH-ES+A192KW
ECDHESAlgorithm(256), # ECDH-ES+A256KW
]
# 'PBES2-HS256+A128KW': '',
# 'PBES2-HS384+A192KW': '',
# 'PBES2-HS512+A256KW': '',

View File

@@ -0,0 +1,144 @@
"""
authlib.jose.rfc7518
~~~~~~~~~~~~~~~~~~~~
Cryptographic Algorithms for Cryptographic Algorithms for Content
Encryption per `Section 5`_.
.. _`Section 5`: https://tools.ietf.org/html/rfc7518#section-5
"""
import hmac
import hashlib
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import GCM, CBC
from cryptography.hazmat.primitives.padding import PKCS7
from cryptography.exceptions import InvalidTag
from ..rfc7516 import JWEEncAlgorithm
from .util import encode_int
class CBCHS2EncAlgorithm(JWEEncAlgorithm):
# The IV used is a 128-bit value generated randomly or
# pseudo-randomly for use in the cipher.
IV_SIZE = 128
def __init__(self, key_size, hash_type):
self.name = f'A{key_size}CBC-HS{hash_type}'
tpl = 'AES_{}_CBC_HMAC_SHA_{} authenticated encryption algorithm'
self.description = tpl.format(key_size, hash_type)
# bit length
self.key_size = key_size
# byte length
self.key_len = key_size // 8
self.CEK_SIZE = key_size * 2
self.hash_alg = getattr(hashlib, f'sha{hash_type}')
def _hmac(self, ciphertext, aad, iv, key):
al = encode_int(len(aad) * 8, 64)
msg = aad + iv + ciphertext + al
d = hmac.new(key, msg, self.hash_alg).digest()
return d[:self.key_len]
def encrypt(self, msg, aad, iv, key):
"""Key Encryption with AES_CBC_HMAC_SHA2.
:param msg: text to be encrypt in bytes
:param aad: additional authenticated data in bytes
:param iv: initialization vector in bytes
:param key: encrypted key in bytes
:return: (ciphertext, iv, tag)
"""
self.check_iv(iv)
hkey = key[:self.key_len]
ekey = key[self.key_len:]
pad = PKCS7(AES.block_size).padder()
padded_data = pad.update(msg) + pad.finalize()
cipher = Cipher(AES(ekey), CBC(iv), backend=default_backend())
enc = cipher.encryptor()
ciphertext = enc.update(padded_data) + enc.finalize()
tag = self._hmac(ciphertext, aad, iv, hkey)
return ciphertext, tag
def decrypt(self, ciphertext, aad, iv, tag, key):
"""Key Decryption with AES AES_CBC_HMAC_SHA2.
:param ciphertext: ciphertext in bytes
:param aad: additional authenticated data in bytes
:param iv: initialization vector in bytes
:param tag: authentication tag in bytes
:param key: encrypted key in bytes
:return: message
"""
self.check_iv(iv)
hkey = key[:self.key_len]
dkey = key[self.key_len:]
_tag = self._hmac(ciphertext, aad, iv, hkey)
if not hmac.compare_digest(_tag, tag):
raise InvalidTag()
cipher = Cipher(AES(dkey), CBC(iv), backend=default_backend())
d = cipher.decryptor()
data = d.update(ciphertext) + d.finalize()
unpad = PKCS7(AES.block_size).unpadder()
return unpad.update(data) + unpad.finalize()
class GCMEncAlgorithm(JWEEncAlgorithm):
# Use of an IV of size 96 bits is REQUIRED with this algorithm.
# https://tools.ietf.org/html/rfc7518#section-5.3
IV_SIZE = 96
def __init__(self, key_size):
self.name = f'A{key_size}GCM'
self.description = f'AES GCM using {key_size}-bit key'
self.key_size = key_size
self.CEK_SIZE = key_size
def encrypt(self, msg, aad, iv, key):
"""Key Encryption with AES GCM
:param msg: text to be encrypt in bytes
:param aad: additional authenticated data in bytes
:param iv: initialization vector in bytes
:param key: encrypted key in bytes
:return: (ciphertext, iv, tag)
"""
self.check_iv(iv)
cipher = Cipher(AES(key), GCM(iv), backend=default_backend())
enc = cipher.encryptor()
enc.authenticate_additional_data(aad)
ciphertext = enc.update(msg) + enc.finalize()
return ciphertext, enc.tag
def decrypt(self, ciphertext, aad, iv, tag, key):
"""Key Decryption with AES GCM
:param ciphertext: ciphertext in bytes
:param aad: additional authenticated data in bytes
:param iv: initialization vector in bytes
:param tag: authentication tag in bytes
:param key: encrypted key in bytes
:return: message
"""
self.check_iv(iv)
cipher = Cipher(AES(key), GCM(iv, tag), backend=default_backend())
d = cipher.decryptor()
d.authenticate_additional_data(aad)
return d.update(ciphertext) + d.finalize()
JWE_ENC_ALGORITHMS = [
CBCHS2EncAlgorithm(128, 256), # A128CBC-HS256
CBCHS2EncAlgorithm(192, 384), # A192CBC-HS384
CBCHS2EncAlgorithm(256, 512), # A256CBC-HS512
GCMEncAlgorithm(128), # A128GCM
GCMEncAlgorithm(192), # A192GCM
GCMEncAlgorithm(256), # A256GCM
]

View File

@@ -0,0 +1,21 @@
import zlib
from ..rfc7516 import JWEZipAlgorithm, JsonWebEncryption
class DeflateZipAlgorithm(JWEZipAlgorithm):
name = 'DEF'
description = 'DEFLATE'
def compress(self, s):
"""Compress bytes data with DEFLATE algorithm."""
data = zlib.compress(s)
# drop gzip headers and tail
return data[2:-4]
def decompress(self, s):
"""Decompress DEFLATE bytes data."""
return zlib.decompress(s, -zlib.MAX_WBITS)
def register_jwe_rfc7518():
JsonWebEncryption.register_algorithm(DeflateZipAlgorithm())

View File

@@ -0,0 +1,215 @@
"""
authlib.jose.rfc7518
~~~~~~~~~~~~~~~~~~~~
"alg" (Algorithm) Header Parameter Values for JWS per `Section 3`_.
.. _`Section 3`: https://tools.ietf.org/html/rfc7518#section-3
"""
import hmac
import hashlib
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.utils import (
decode_dss_signature, encode_dss_signature
)
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.exceptions import InvalidSignature
from ..rfc7515 import JWSAlgorithm
from .oct_key import OctKey
from .rsa_key import RSAKey
from .ec_key import ECKey
from .util import encode_int, decode_int
class NoneAlgorithm(JWSAlgorithm):
name = 'none'
description = 'No digital signature or MAC performed'
def prepare_key(self, raw_data):
return None
def sign(self, msg, key):
return b''
def verify(self, msg, sig, key):
return False
class HMACAlgorithm(JWSAlgorithm):
"""HMAC using SHA algorithms for JWS. Available algorithms:
- HS256: HMAC using SHA-256
- HS384: HMAC using SHA-384
- HS512: HMAC using SHA-512
"""
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
def __init__(self, sha_type):
self.name = f'HS{sha_type}'
self.description = f'HMAC using SHA-{sha_type}'
self.hash_alg = getattr(self, f'SHA{sha_type}')
def prepare_key(self, raw_data):
return OctKey.import_key(raw_data)
def sign(self, msg, key):
# it is faster than the one in cryptography
op_key = key.get_op_key('sign')
return hmac.new(op_key, msg, self.hash_alg).digest()
def verify(self, msg, sig, key):
op_key = key.get_op_key('verify')
v_sig = hmac.new(op_key, msg, self.hash_alg).digest()
return hmac.compare_digest(sig, v_sig)
class RSAAlgorithm(JWSAlgorithm):
"""RSA using SHA algorithms for JWS. Available algorithms:
- RS256: RSASSA-PKCS1-v1_5 using SHA-256
- RS384: RSASSA-PKCS1-v1_5 using SHA-384
- RS512: RSASSA-PKCS1-v1_5 using SHA-512
"""
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
def __init__(self, sha_type):
self.name = f'RS{sha_type}'
self.description = f'RSASSA-PKCS1-v1_5 using SHA-{sha_type}'
self.hash_alg = getattr(self, f'SHA{sha_type}')
self.padding = padding.PKCS1v15()
def prepare_key(self, raw_data):
return RSAKey.import_key(raw_data)
def sign(self, msg, key):
op_key = key.get_op_key('sign')
return op_key.sign(msg, self.padding, self.hash_alg())
def verify(self, msg, sig, key):
op_key = key.get_op_key('verify')
try:
op_key.verify(sig, msg, self.padding, self.hash_alg())
return True
except InvalidSignature:
return False
class ECAlgorithm(JWSAlgorithm):
"""ECDSA using SHA algorithms for JWS. Available algorithms:
- ES256: ECDSA using P-256 and SHA-256
- ES384: ECDSA using P-384 and SHA-384
- ES512: ECDSA using P-521 and SHA-512
"""
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
def __init__(self, name, curve, sha_type):
self.name = name
self.curve = curve
self.description = f'ECDSA using {self.curve} and SHA-{sha_type}'
self.hash_alg = getattr(self, f'SHA{sha_type}')
def prepare_key(self, raw_data):
key = ECKey.import_key(raw_data)
if key['crv'] != self.curve:
raise ValueError(f'Key for "{self.name}" not supported, only "{self.curve}" allowed')
return key
def sign(self, msg, key):
op_key = key.get_op_key('sign')
der_sig = op_key.sign(msg, ECDSA(self.hash_alg()))
r, s = decode_dss_signature(der_sig)
size = key.curve_key_size
return encode_int(r, size) + encode_int(s, size)
def verify(self, msg, sig, key):
key_size = key.curve_key_size
length = (key_size + 7) // 8
if len(sig) != 2 * length:
return False
r = decode_int(sig[:length])
s = decode_int(sig[length:])
der_sig = encode_dss_signature(r, s)
try:
op_key = key.get_op_key('verify')
op_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
return True
except InvalidSignature:
return False
class RSAPSSAlgorithm(JWSAlgorithm):
"""RSASSA-PSS using SHA algorithms for JWS. Available algorithms:
- PS256: RSASSA-PSS using SHA-256 and MGF1 with SHA-256
- PS384: RSASSA-PSS using SHA-384 and MGF1 with SHA-384
- PS512: RSASSA-PSS using SHA-512 and MGF1 with SHA-512
"""
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
def __init__(self, sha_type):
self.name = f'PS{sha_type}'
tpl = 'RSASSA-PSS using SHA-{} and MGF1 with SHA-{}'
self.description = tpl.format(sha_type, sha_type)
self.hash_alg = getattr(self, f'SHA{sha_type}')
def prepare_key(self, raw_data):
return RSAKey.import_key(raw_data)
def sign(self, msg, key):
op_key = key.get_op_key('sign')
return op_key.sign(
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size
),
self.hash_alg()
)
def verify(self, msg, sig, key):
op_key = key.get_op_key('verify')
try:
op_key.verify(
sig,
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size
),
self.hash_alg()
)
return True
except InvalidSignature:
return False
JWS_ALGORITHMS = [
NoneAlgorithm(), # none
HMACAlgorithm(256), # HS256
HMACAlgorithm(384), # HS384
HMACAlgorithm(512), # HS512
RSAAlgorithm(256), # RS256
RSAAlgorithm(384), # RS384
RSAAlgorithm(512), # RS512
ECAlgorithm('ES256', 'P-256', 256),
ECAlgorithm('ES384', 'P-384', 384),
ECAlgorithm('ES512', 'P-521', 512),
ECAlgorithm('ES256K', 'secp256k1', 256), # defined in RFC8812
RSAPSSAlgorithm(256), # PS256
RSAPSSAlgorithm(384), # PS384
RSAPSSAlgorithm(512), # PS512
]

View File

@@ -0,0 +1,95 @@
from authlib.common.encoding import (
to_bytes, to_unicode,
urlsafe_b64encode, urlsafe_b64decode,
)
from authlib.common.security import generate_token
from ..rfc7517 import Key
POSSIBLE_UNSAFE_KEYS = (
b"-----BEGIN ",
b"---- BEGIN ",
b"ssh-rsa ",
b"ssh-dss ",
b"ssh-ed25519 ",
b"ecdsa-sha2-",
)
class OctKey(Key):
"""Key class of the ``oct`` key type."""
kty = 'oct'
REQUIRED_JSON_FIELDS = ['k']
def __init__(self, raw_key=None, options=None):
super().__init__(options)
self.raw_key = raw_key
@property
def public_only(self):
return False
def get_op_key(self, operation):
"""Get the raw key for the given key_op. This method will also
check if the given key_op is supported by this key.
:param operation: key operation value, such as "sign", "encrypt".
:return: raw key
"""
self.check_key_op(operation)
if not self.raw_key:
self.load_raw_key()
return self.raw_key
def load_raw_key(self):
self.raw_key = urlsafe_b64decode(to_bytes(self.tokens['k']))
def load_dict_key(self):
k = to_unicode(urlsafe_b64encode(self.raw_key))
self._dict_data = {'kty': self.kty, 'k': k}
def as_dict(self, is_private=False, **params):
tokens = self.tokens
if 'kid' not in tokens:
tokens['kid'] = self.thumbprint()
tokens.update(params)
return tokens
@classmethod
def validate_raw_key(cls, key):
return isinstance(key, bytes)
@classmethod
def import_key(cls, raw, options=None):
"""Import a key from bytes, string, or dict data."""
if isinstance(raw, cls):
if options is not None:
raw.options.update(options)
return raw
if isinstance(raw, dict):
cls.check_required_fields(raw)
key = cls(options=options)
key._dict_data = raw
else:
raw_key = to_bytes(raw)
# security check
if raw_key.startswith(POSSIBLE_UNSAFE_KEYS):
raise ValueError("This key may not be safe to import")
key = cls(raw_key=raw_key, options=options)
return key
@classmethod
def generate_key(cls, key_size=256, options=None, is_private=True):
"""Generate a ``OctKey`` with the given bit size."""
if not is_private:
raise ValueError('oct key can not be generated as public')
if key_size % 8 != 0:
raise ValueError('Invalid bit size for oct key')
return cls.import_key(generate_token(key_size // 8), options)

View File

@@ -0,0 +1,123 @@
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPublicKey, RSAPrivateKeyWithSerialization,
RSAPrivateNumbers, RSAPublicNumbers,
rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
)
from cryptography.hazmat.backends import default_backend
from authlib.common.encoding import base64_to_int, int_to_base64
from ..rfc7517 import AsymmetricKey
class RSAKey(AsymmetricKey):
"""Key class of the ``RSA`` key type."""
kty = 'RSA'
PUBLIC_KEY_CLS = RSAPublicKey
PRIVATE_KEY_CLS = RSAPrivateKeyWithSerialization
PUBLIC_KEY_FIELDS = ['e', 'n']
PRIVATE_KEY_FIELDS = ['d', 'dp', 'dq', 'e', 'n', 'p', 'q', 'qi']
REQUIRED_JSON_FIELDS = ['e', 'n']
SSH_PUBLIC_PREFIX = b'ssh-rsa'
def dumps_private_key(self):
numbers = self.private_key.private_numbers()
return {
'n': int_to_base64(numbers.public_numbers.n),
'e': int_to_base64(numbers.public_numbers.e),
'd': int_to_base64(numbers.d),
'p': int_to_base64(numbers.p),
'q': int_to_base64(numbers.q),
'dp': int_to_base64(numbers.dmp1),
'dq': int_to_base64(numbers.dmq1),
'qi': int_to_base64(numbers.iqmp)
}
def dumps_public_key(self):
numbers = self.public_key.public_numbers()
return {
'n': int_to_base64(numbers.n),
'e': int_to_base64(numbers.e)
}
def load_private_key(self):
obj = self._dict_data
if 'oth' in obj: # pragma: no cover
# https://tools.ietf.org/html/rfc7518#section-6.3.2.7
raise ValueError('"oth" is not supported yet')
public_numbers = RSAPublicNumbers(
base64_to_int(obj['e']), base64_to_int(obj['n']))
if has_all_prime_factors(obj):
numbers = RSAPrivateNumbers(
d=base64_to_int(obj['d']),
p=base64_to_int(obj['p']),
q=base64_to_int(obj['q']),
dmp1=base64_to_int(obj['dp']),
dmq1=base64_to_int(obj['dq']),
iqmp=base64_to_int(obj['qi']),
public_numbers=public_numbers)
else:
d = base64_to_int(obj['d'])
p, q = rsa_recover_prime_factors(
public_numbers.n, d, public_numbers.e)
numbers = RSAPrivateNumbers(
d=d,
p=p,
q=q,
dmp1=rsa_crt_dmp1(d, p),
dmq1=rsa_crt_dmq1(d, q),
iqmp=rsa_crt_iqmp(p, q),
public_numbers=public_numbers)
return numbers.private_key(default_backend())
def load_public_key(self):
numbers = RSAPublicNumbers(
base64_to_int(self._dict_data['e']),
base64_to_int(self._dict_data['n'])
)
return numbers.public_key(default_backend())
@classmethod
def generate_key(cls, key_size=2048, options=None, is_private=False) -> 'RSAKey':
if key_size < 512:
raise ValueError('key_size must not be less than 512')
if key_size % 8 != 0:
raise ValueError('Invalid key_size for RSAKey')
raw_key = rsa.generate_private_key(
public_exponent=65537,
key_size=key_size,
backend=default_backend(),
)
if not is_private:
raw_key = raw_key.public_key()
return cls.import_key(raw_key, options=options)
@classmethod
def import_dict_key(cls, raw, options=None):
cls.check_required_fields(raw)
key = cls(options=options)
key._dict_data = raw
if 'd' in raw and not has_all_prime_factors(raw):
# reload dict key
key.load_raw_key()
key.load_dict_key()
return key
def has_all_prime_factors(obj):
props = ['p', 'q', 'dp', 'dq', 'qi']
props_found = [prop in obj for prop in props]
if all(props_found):
return True
if any(props_found):
raise ValueError(
'RSA key must include all parameters '
'if any are present besides d')
return False

View File

@@ -0,0 +1,12 @@
import binascii
def encode_int(num, bits):
length = ((bits + 7) // 8) * 2
padded_hex = '%0*x' % (length, num)
big_endian = binascii.a2b_hex(padded_hex.encode('ascii'))
return big_endian
def decode_int(b):
return int(binascii.b2a_hex(b), 16)