From da7a8db29e2073e3bc50f5a8192119af97ef7115 Mon Sep 17 00:00:00 2001 From: Randall Hauch Date: Fri, 30 Jul 2021 17:48:03 -0500 Subject: [PATCH] MINOR: Use time constant algorithms when comparing passwords or keys (#10978) Author: Randall Hauch Reviewers: Manikumar Reddy , Rajini Sivaram , Mickael Maison , Ismael Juma --- .../internals/PlainServerCallbackHandler.java | 4 +- .../scram/internals/ScramSaslClient.java | 3 +- .../scram/internals/ScramSaslServer.java | 3 +- .../token/delegation/DelegationToken.java | 3 +- .../org/apache/kafka/common/utils/Utils.java | 36 +++++++++++++++ .../apache/kafka/common/utils/UtilsTest.java | 44 +++++++++++++++++++ 6 files changed, 88 insertions(+), 5 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java index 842f986abd7..10f5817a9b3 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java +++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java @@ -22,9 +22,9 @@ import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.security.plain.PlainAuthenticateCallback; import org.apache.kafka.common.security.plain.PlainLoginModule; +import org.apache.kafka.common.utils.Utils; import java.io.IOException; -import java.util.Arrays; import java.util.List; import java.util.Map; @@ -65,7 +65,7 @@ public class PlainServerCallbackHandler implements AuthenticateCallbackHandler { String expectedPassword = JaasContext.configEntryOption(jaasConfigEntries, JAAS_USER_PREFIX + username, PlainLoginModule.class.getName()); - return expectedPassword != null && Arrays.equals(password, expectedPassword.toCharArray()); + return expectedPassword != null && Utils.isEqualConstantTime(password, expectedPassword.toCharArray()); } } diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java index c21a52e1116..2e6191b799a 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java @@ -18,6 +18,7 @@ package org.apache.kafka.common.security.scram.internals; import java.nio.charset.StandardCharsets; import java.security.InvalidKeyException; +import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Arrays; import java.util.Collection; @@ -204,7 +205,7 @@ public class ScramSaslClient implements SaslClient { try { byte[] serverKey = formatter.serverKey(saltedPassword); byte[] serverSignature = formatter.serverSignature(serverKey, clientFirstMessage, serverFirstMessage, clientFinalMessage); - if (!Arrays.equals(signature, serverSignature)) + if (!MessageDigest.isEqual(signature, serverSignature)) throw new SaslException("Invalid server signature in server final message"); } catch (InvalidKeyException e) { throw new SaslException("Sasl server signature verification failed", e); diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java index f6286a60f15..3cc8ff0b399 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java @@ -17,6 +17,7 @@ package org.apache.kafka.common.security.scram.internals; import java.security.InvalidKeyException; +import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Arrays; import java.util.Collection; @@ -226,7 +227,7 @@ public class ScramSaslServer implements SaslServer { byte[] expectedStoredKey = scramCredential.storedKey(); byte[] clientSignature = formatter.clientSignature(expectedStoredKey, clientFirstMessage, serverFirstMessage, clientFinalMessage); byte[] computedStoredKey = formatter.storedKey(clientSignature, clientFinalMessage.proof()); - if (!Arrays.equals(computedStoredKey, expectedStoredKey)) + if (!MessageDigest.isEqual(computedStoredKey, expectedStoredKey)) throw new SaslException("Invalid client credentials"); } catch (InvalidKeyException e) { throw new SaslException("Sasl client verification failed", e); diff --git a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java index b389a199a92..a2141b5549b 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java +++ b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java @@ -18,6 +18,7 @@ package org.apache.kafka.common.security.token.delegation; import org.apache.kafka.common.annotation.InterfaceStability; +import java.security.MessageDigest; import java.util.Arrays; import java.util.Base64; import java.util.Objects; @@ -59,7 +60,7 @@ public class DelegationToken { DelegationToken token = (DelegationToken) o; - return Objects.equals(tokenInformation, token.tokenInformation) && Arrays.equals(hmac, token.hmac); + return Objects.equals(tokenInformation, token.tokenInformation) && MessageDigest.isEqual(hmac, token.hmac); } @Override diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java index 5fa32b73ea1..921ce3c703a 100755 --- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java +++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java @@ -298,6 +298,42 @@ public final class Utils { return t; } + /** + * Compares two character arrays for equality using a constant-time algorithm, which is needed + * for comparing passwords. Two arrays are equal if they have the same length and all + * characters at corresponding positions are equal. + * + * All characters in the first array are examined to determine equality. + * The calculation time depends only on the length of this first character array; it does not + * depend on the length of the second character array or the contents of either array. + * + * @param first the first array to compare + * @param second the second array to compare + * @return true if the arrays are equal, or false otherwise + */ + public static boolean isEqualConstantTime(char[] first, char[] second) { + if (first == second) { + return true; + } + if (first == null || second == null) { + return false; + } + + if (second.length == 0) { + return first.length == 0; + } + + // time-constant comparison that always compares all characters in first array + boolean matches = first.length == second.length; + for (int i = 0; i < first.length; ++i) { + int j = i < second.length ? i : 0; + if (first[i] != second[j]) { + matches = false; + } + } + return matches; + } + /** * Sleep for a bit * @param ms The duration of the sleep diff --git a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java index 172a992b1bd..5762be4a0d8 100755 --- a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java +++ b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java @@ -47,6 +47,8 @@ import static org.apache.kafka.common.utils.Utils.validHostPattern; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; @@ -487,4 +489,46 @@ public class UtilsTest { } catch (IllegalArgumentException e) { } } + + @Test + public void testCharacterArrayEquality() { + assertCharacterArraysAreNotEqual(null, "abc"); + assertCharacterArraysAreNotEqual(null, ""); + assertCharacterArraysAreNotEqual("abc", null); + assertCharacterArraysAreNotEqual("", null); + assertCharacterArraysAreNotEqual("", "abc"); + assertCharacterArraysAreNotEqual("abc", "abC"); + assertCharacterArraysAreNotEqual("abc", "abcd"); + assertCharacterArraysAreNotEqual("abc", "abcdefg"); + assertCharacterArraysAreNotEqual("abcdefg", "abc"); + assertCharacterArraysAreEqual("abc", "abc"); + assertCharacterArraysAreEqual("a", "a"); + assertCharacterArraysAreEqual("", ""); + assertCharacterArraysAreEqual("", ""); + assertCharacterArraysAreEqual(null, null); + } + + private void assertCharacterArraysAreNotEqual(String a, String b) { + char[] first = a != null ? a.toCharArray() : null; + char[] second = b != null ? b.toCharArray() : null; + if (a == null) { + assertNotNull(b); + } else { + assertFalse(a.equals(b)); + } + assertFalse(Utils.isEqualConstantTime(first, second)); + assertFalse(Utils.isEqualConstantTime(second, first)); + } + + private void assertCharacterArraysAreEqual(String a, String b) { + char[] first = a != null ? a.toCharArray() : null; + char[] second = b != null ? b.toCharArray() : null; + if (a == null) { + assertNull(b); + } else { + assertTrue(a.equals(b)); + } + assertTrue(Utils.isEqualConstantTime(first, second)); + assertTrue(Utils.isEqualConstantTime(second, first)); + } }