From 546e96c65c86edfb935691e123750f5dc13f694e Mon Sep 17 00:00:00 2001 From: Tristan Sweeney Date: Mon, 7 Nov 2022 17:46:26 -0500 Subject: [PATCH 1/3] Add "algorithm mismatch" error to improve jws Upstream libraries that depend on `jws.verify()` break when the upstream keys contain a mixed set of algorithms. This is a nominal occurance for OIDC servers and should be properly handled. --- jose/backends/cryptography_backend.py | 18 +++++++++--------- jose/backends/ecdsa_backend.py | 6 +++--- jose/backends/native.py | 6 +++--- jose/backends/rsa_backend.py | 6 +++--- jose/exceptions.py | 4 ++++ jose/jws.py | 7 +++++-- 6 files changed, 27 insertions(+), 20 deletions(-) diff --git a/jose/backends/cryptography_backend.py b/jose/backends/cryptography_backend.py index abd24260..1117e106 100644 --- a/jose/backends/cryptography_backend.py +++ b/jose/backends/cryptography_backend.py @@ -15,7 +15,7 @@ from cryptography.x509 import load_pem_x509_certificate from ..constants import ALGORITHMS -from ..exceptions import JWEError, JWKError +from ..exceptions import JWEError, JWKError, JWKAlgMismatchError from ..utils import base64_to_long, base64url_decode, base64url_encode, ensure_binary, long_to_base64 from .base import Key @@ -52,7 +52,7 @@ class CryptographyECKey(Key): def __init__(self, key, algorithm, cryptography_backend=default_backend): if algorithm not in ALGORITHMS.EC: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a valid EC algorithm" % algorithm) self.hash_alg = { ALGORITHMS.ES256: self.SHA256, @@ -97,7 +97,7 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "EC": - raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty")) if not all(k in jwk_dict for k in ["x", "y", "crv"]): raise JWKError("Mandatory parameters are missing") @@ -226,7 +226,7 @@ class CryptographyRSAKey(Key): def __init__(self, key, algorithm, cryptography_backend=default_backend): if algorithm not in ALGORITHMS.RSA: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a valid RSA algorithm" % algorithm) self.hash_alg = { ALGORITHMS.RS256: self.SHA256, @@ -273,7 +273,7 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "RSA": - raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty")) e = base64_to_long(jwk_dict.get("e", 256)) n = base64_to_long(jwk_dict.get("n")) @@ -441,9 +441,9 @@ class CryptographyAESKey(Key): def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.AES: - raise JWKError("%s is not a valid AES algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a valid AES algorithm" % algorithm) if algorithm not in ALGORITHMS.SUPPORTED.union(ALGORITHMS.AES_PSEUDO): - raise JWKError("%s is not a supported algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a supported algorithm" % algorithm) self._algorithm = algorithm self._mode = self.MODES.get(self._algorithm) @@ -538,7 +538,7 @@ class CryptographyHMACKey(Key): def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.HMAC: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("hash_alg: %s is not a valid hash algorithm" % algorithm) self._algorithm = algorithm self._hash_alg = self.ALG_MAP.get(algorithm) @@ -569,7 +569,7 @@ def __init__(self, key, algorithm): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "oct": - raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty")) k = jwk_dict.get("k") k = k.encode("utf-8") diff --git a/jose/backends/ecdsa_backend.py b/jose/backends/ecdsa_backend.py index 756c7ea8..ecb5aac6 100644 --- a/jose/backends/ecdsa_backend.py +++ b/jose/backends/ecdsa_backend.py @@ -4,7 +4,7 @@ from jose.backends.base import Key from jose.constants import ALGORITHMS -from jose.exceptions import JWKError +from jose.exceptions import JWKError, JWKAlgMismatchError from jose.utils import base64_to_long, long_to_base64 @@ -35,7 +35,7 @@ class ECDSAECKey(Key): def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.EC: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a valid EC algorithm" % algorithm) self.hash_alg = { ALGORITHMS.ES256: self.SHA256, @@ -75,7 +75,7 @@ def __init__(self, key, algorithm): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "EC": - raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty")) if not all(k in jwk_dict for k in ["x", "y", "crv"]): raise JWKError("Mandatory parameters are missing") diff --git a/jose/backends/native.py b/jose/backends/native.py index eb3a6ae3..7661e2c5 100644 --- a/jose/backends/native.py +++ b/jose/backends/native.py @@ -4,7 +4,7 @@ from jose.backends.base import Key from jose.constants import ALGORITHMS -from jose.exceptions import JWKError +from jose.exceptions import JWKError, JWKAlgMismatchError from jose.utils import base64url_decode, base64url_encode @@ -22,7 +22,7 @@ class HMACKey(Key): def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.HMAC: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("hash_alg: %s is not a valid hash algorithm" % algorithm) self._algorithm = algorithm self._hash_alg = self.HASHES.get(algorithm) @@ -53,7 +53,7 @@ def __init__(self, key, algorithm): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "oct": - raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty")) k = jwk_dict.get("k") k = k.encode("utf-8") diff --git a/jose/backends/rsa_backend.py b/jose/backends/rsa_backend.py index 4e8ccf1c..c908e4b3 100644 --- a/jose/backends/rsa_backend.py +++ b/jose/backends/rsa_backend.py @@ -13,7 +13,7 @@ ) from jose.backends.base import Key from jose.constants import ALGORITHMS -from jose.exceptions import JWEError, JWKError +from jose.exceptions import JWEError, JWKError, JWKAlgMismatchError from jose.utils import base64_to_long, long_to_base64 ALGORITHMS.SUPPORTED.remove(ALGORITHMS.RSA_OAEP) # RSA OAEP not supported @@ -124,7 +124,7 @@ class RSAKey(Key): def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.RSA: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a valid RSA algorithm" % algorithm) if algorithm in ALGORITHMS.RSA_KW and algorithm != ALGORITHMS.RSA1_5: raise JWKError("alg: %s is not supported by the RSA backend" % algorithm) @@ -174,7 +174,7 @@ def __init__(self, key, algorithm): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "RSA": - raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty")) e = base64_to_long(jwk_dict.get("e")) n = base64_to_long(jwk_dict.get("n")) diff --git a/jose/exceptions.py b/jose/exceptions.py index e8edc3b6..630c954e 100644 --- a/jose/exceptions.py +++ b/jose/exceptions.py @@ -29,6 +29,10 @@ class ExpiredSignatureError(JWTError): class JWKError(JOSEError): pass +class JWKAlgMismatchError(JWKError): + '''JWK Key type doesn't support the given algorithm.''' + pass + class JWEError(JOSEError): """Base error for all JWE errors""" diff --git a/jose/jws.py b/jose/jws.py index bfaf6bd0..944130e1 100644 --- a/jose/jws.py +++ b/jose/jws.py @@ -5,7 +5,7 @@ from jose import jwk from jose.backends.base import Key from jose.constants import ALGORITHMS -from jose.exceptions import JWSError, JWSSignatureError +from jose.exceptions import JWSError, JWSSignatureError, JWKAlgMismatchError from jose.utils import base64url_decode, base64url_encode @@ -205,7 +205,10 @@ def _load(jwt): def _sig_matches_keys(keys, signing_input, signature, alg): for key in keys: if not isinstance(key, Key): - key = jwk.construct(key, alg) + try: + key = jwk.construct(key, alg) + except JWKAlgMismatchError: + continue try: if key.verify(signing_input, signature): return True From 534857a91fc4543b79f981e94bb5cc1660ae267a Mon Sep 17 00:00:00 2001 From: Tristan Sweeney Date: Tue, 8 Nov 2022 12:21:04 -0500 Subject: [PATCH 2/3] Update test for code coverage --- tests/test_jws.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_jws.py b/tests/test_jws.py index 01b5fd05..e7660892 100644 --- a/tests/test_jws.py +++ b/tests/test_jws.py @@ -3,7 +3,7 @@ import pytest -from jose import jwk, jws +from jose import jwk, jws, jwt from jose.backends import RSAKey from jose.constants import ALGORITHMS from jose.exceptions import JWSError @@ -25,6 +25,16 @@ def test_unicode_token(self): token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8" jws.verify(token, "secret", ["HS256"]) + def test_hetero_keys(self): + private_key = b"-----BEGIN PRIVATE KEY-----\nMIGEAgEAMBAGByqGSM49AgEGBS..." + public_key = b"-----BEGIN PUBLIC KEY-----\nMHYwEAYHKoZIzj0CAQYFK4EEAC..." + token = jwt.encode({"some": "claims"}, private_key, algorithm="RS256") + + rsa_key = jwk.RSAKey(public_key, "RS256").to_dict() + hmac_key = jwk.HMACKey("secret", "HS256").to_dict() + # RSA key must come second to exercise "JWKAlgMismatchError" + jws.verify(token, {"keys": [hmac_key, rsa_key]}, ["HS256", "RS256"]) + def test_multiple_keys(self): old_jwk_verify = jwk.HMACKey.verify try: From 1ce256e2b1db831ac4fa165b69a9f21a149c86fd Mon Sep 17 00:00:00 2001 From: Tristan Sweeney Date: Tue, 8 Nov 2022 13:34:52 -0500 Subject: [PATCH 3/3] Fix test coverage --- jose/exceptions.py | 1 + tests/test_jws.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/jose/exceptions.py b/jose/exceptions.py index 630c954e..2099208d 100644 --- a/jose/exceptions.py +++ b/jose/exceptions.py @@ -29,6 +29,7 @@ class ExpiredSignatureError(JWTError): class JWKError(JOSEError): pass + class JWKAlgMismatchError(JWKError): '''JWK Key type doesn't support the given algorithm.''' pass diff --git a/tests/test_jws.py b/tests/test_jws.py index e7660892..75a7398d 100644 --- a/tests/test_jws.py +++ b/tests/test_jws.py @@ -3,10 +3,10 @@ import pytest -from jose import jwk, jws, jwt +from jose import jwk, jws from jose.backends import RSAKey from jose.constants import ALGORITHMS -from jose.exceptions import JWSError +from jose.exceptions import JWSError, JWKAlgMismatchError try: from jose.backends.cryptography_backend import CryptographyRSAKey @@ -26,14 +26,14 @@ def test_unicode_token(self): jws.verify(token, "secret", ["HS256"]) def test_hetero_keys(self): - private_key = b"-----BEGIN PRIVATE KEY-----\nMIGEAgEAMBAGByqGSM49AgEGBS..." - public_key = b"-----BEGIN PUBLIC KEY-----\nMHYwEAYHKoZIzj0CAQYFK4EEAC..." - token = jwt.encode({"some": "claims"}, private_key, algorithm="RS256") - - rsa_key = jwk.RSAKey(public_key, "RS256").to_dict() - hmac_key = jwk.HMACKey("secret", "HS256").to_dict() - # RSA key must come second to exercise "JWKAlgMismatchError" - jws.verify(token, {"keys": [hmac_key, rsa_key]}, ["HS256", "RS256"]) + class BadKey(jwk.Key): + def __init__(self, key, algorithm): + if key != "xyzw": + raise JWKAlgMismatchError("%s is not a valid XYZW algorithm" % algorithm) + + token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8" + jwk.register_key("XYZW", BadKey) + jws.verify(token, {"keys": [{"alg": "XYZW"}, "secret"]}, ["XYZW", "HS256"]) def test_multiple_keys(self): old_jwk_verify = jwk.HMACKey.verify