Browse Source

KAFKA-14780: Fix flaky test 'testSecondaryRefreshAfterElapsedDelay' (#14078)

"The test RefreshingHttpsJwksTest#testSecondaryRefreshAfterElapsedDelay relies on the actual system clock, which makes it frequently fail. The fix adds a second constructor that allows for passing a ScheduledExecutorService to manually execute the scheduled tasks before refreshing. The fixed task is much more robust and stable.

Co-authored-by: Fei Xie <feixie@MacBook-Pro.attlocal.net>

Reviewers: Divij Vaidya <diviv@amazon.com>, Luke Chen <showuon@gmail.com>
pull/13953/merge
olalamichelle 1 year ago committed by GitHub
parent
commit
9972297e51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 65
      clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwks.java
  2. 130
      clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwksTest.java

65
clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwks.java

@ -31,6 +31,7 @@ import java.util.concurrent.TimeUnit; @@ -31,6 +31,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.kafka.common.utils.Time;
import org.jose4j.jwk.HttpsJwks;
import org.jose4j.jwk.JsonWebKey;
@ -44,7 +45,7 @@ import org.slf4j.LoggerFactory; @@ -44,7 +45,7 @@ import org.slf4j.LoggerFactory;
* possible to receive a JWT that contains a <code>kid</code> that points to yet-unknown JWK,
* thus requiring a connection to the OAuth/OIDC provider to be made. Hopefully, in practice,
* keys are made available for some amount of time before they're used within JWTs.
*
* <p>
* This instance is created and provided to the
* {@link org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver} that is used when using
* an HTTP-/HTTPS-based {@link org.jose4j.keys.resolvers.VerificationKeyResolver}, which is then
@ -75,7 +76,7 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { @@ -75,7 +76,7 @@ public final class RefreshingHttpsJwks implements Initable, Closeable {
* JWKS. In some cases, the call to {@link HttpsJwks#getJsonWebKeys()} will trigger a call
* to {@link HttpsJwks#refresh()} which will block the current thread in network I/O. We cache
* the JWKS ourselves (see {@link #jsonWebKeys}) to avoid the network I/O.
*
* <p>
* We want to be very careful where we use the {@link HttpsJwks} instance so that we don't
* perform any operation (directly or indirectly) that could cause blocking. This is because
* the JWKS logic is part of the larger authentication logic which operates on Kafka's network
@ -121,23 +122,17 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { @@ -121,23 +122,17 @@ public final class RefreshingHttpsJwks implements Initable, Closeable {
private boolean isInitialized;
/**
* Creates a <code>RefreshingHttpsJwks</code> that will be used by the
* {@link RefreshingHttpsJwksVerificationKeyResolver} to resolve new key IDs in JWTs.
*
* @param time {@link Time} instance
* @param httpsJwks {@link HttpsJwks} instance from which to retrieve the JWKS
* based on the OAuth/OIDC standard
* @param refreshMs The number of milliseconds between refresh passes to connect
* to the OAuth/OIDC JWKS endpoint to retrieve the latest set
* @param refreshRetryBackoffMs Time for delay after initial failed attempt to retrieve JWKS
* @param refreshRetryBackoffMaxMs Maximum time to retrieve JWKS
* Creates a <code>RefreshingHttpsJwks</code>. It should only be used for testing to pass in a mock executor
* service. Otherwise the constructor below should be used.
*/
public RefreshingHttpsJwks(Time time,
HttpsJwks httpsJwks,
long refreshMs,
long refreshRetryBackoffMs,
long refreshRetryBackoffMaxMs) {
// VisibleForTesting
RefreshingHttpsJwks(Time time,
HttpsJwks httpsJwks,
long refreshMs,
long refreshRetryBackoffMs,
long refreshRetryBackoffMaxMs,
ScheduledExecutorService executorService) {
if (refreshMs <= 0)
throw new IllegalArgumentException("JWKS validation key refresh configuration value retryWaitMs value must be positive");
@ -146,7 +141,7 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { @@ -146,7 +141,7 @@ public final class RefreshingHttpsJwks implements Initable, Closeable {
this.refreshMs = refreshMs;
this.refreshRetryBackoffMs = refreshRetryBackoffMs;
this.refreshRetryBackoffMaxMs = refreshRetryBackoffMaxMs;
this.executorService = Executors.newSingleThreadScheduledExecutor();
this.executorService = executorService;
this.missingKeyIds = new LinkedHashMap<String, Long>(MISSING_KEY_ID_CACHE_MAX_ENTRIES, .75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<String, Long> eldest) {
@ -155,6 +150,27 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { @@ -155,6 +150,27 @@ public final class RefreshingHttpsJwks implements Initable, Closeable {
};
}
/**
* Creates a <code>RefreshingHttpsJwks</code> that will be used by the
* {@link RefreshingHttpsJwksVerificationKeyResolver} to resolve new key IDs in JWTs.
*
* @param time {@link Time} instance
* @param httpsJwks {@link HttpsJwks} instance from which to retrieve the JWKS
* based on the OAuth/OIDC standard
* @param refreshMs The number of milliseconds between refresh passes to connect
* to the OAuth/OIDC JWKS endpoint to retrieve the latest set
* @param refreshRetryBackoffMs Time for delay after initial failed attempt to retrieve JWKS
* @param refreshRetryBackoffMaxMs Maximum time to retrieve JWKS
*/
public RefreshingHttpsJwks(Time time,
HttpsJwks httpsJwks,
long refreshMs,
long refreshRetryBackoffMs,
long refreshRetryBackoffMaxMs) {
this(time, httpsJwks, refreshMs, refreshRetryBackoffMs, refreshRetryBackoffMaxMs, Executors.newSingleThreadScheduledExecutor());
}
@Override
public void init() throws IOException {
try {
@ -180,9 +196,9 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { @@ -180,9 +196,9 @@ public final class RefreshingHttpsJwks implements Initable, Closeable {
//
// Note: we refer to this as a _scheduled_ refresh.
executorService.scheduleAtFixedRate(this::refresh,
refreshMs,
refreshMs,
TimeUnit.MILLISECONDS);
refreshMs,
refreshMs,
TimeUnit.MILLISECONDS);
log.info("JWKS validation key refresh thread started with a refresh interval of {} ms", refreshMs);
} finally {
@ -203,7 +219,7 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { @@ -203,7 +219,7 @@ public final class RefreshingHttpsJwks implements Initable, Closeable {
if (!executorService.awaitTermination(SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT)) {
log.warn("JWKS validation key refresh thread termination did not end after {} {}",
SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT);
SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT);
}
} catch (InterruptedException e) {
log.warn("JWKS validation key refresh thread error during close", e);
@ -217,13 +233,12 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { @@ -217,13 +233,12 @@ public final class RefreshingHttpsJwks implements Initable, Closeable {
* Our implementation avoids the blocking call within {@link HttpsJwks#refresh()} that is
* sometimes called internal to {@link HttpsJwks#getJsonWebKeys()}. We want to avoid any
* blocking I/O as this code is running in the authentication path on the Kafka network thread.
*
* <p>
* The list may be stale up to {@link #refreshMs}.
*
* @return {@link List} of {@link JsonWebKey} instances
*
* @throws JoseException Thrown if a problem is encountered parsing the JSON content into JWKs
* @throws IOException Thrown f a problem is encountered making the HTTP request
* @throws IOException Thrown f a problem is encountered making the HTTP request
*/
public List<JsonWebKey> getJsonWebKeys() throws JoseException, IOException {

130
clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwksTest.java

@ -25,11 +25,21 @@ import static org.junit.jupiter.api.Assertions.assertTrue; @@ -25,11 +25,21 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Collection;
import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.AbstractMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.apache.kafka.common.KafkaFuture;
import org.apache.kafka.common.internals.KafkaFutureImpl;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.jose4j.http.SimpleResponse;
@ -122,14 +132,36 @@ public class RefreshingHttpsJwksTest extends OAuthBearerTest { @@ -122,14 +132,36 @@ public class RefreshingHttpsJwksTest extends OAuthBearerTest {
@Test
public void testSecondaryRefreshAfterElapsedDelay() throws Exception {
String keyId = "abc123";
Time time = MockTime.SYSTEM; // Unfortunately, we can't mock time here because the
// scheduled executor doesn't respect it.
MockTime time = new MockTime();
HttpsJwks httpsJwks = spyHttpsJwks();
try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks)) {
MockExecutorService mockExecutorService = new MockExecutorService(time);
ScheduledExecutorService executorService = Mockito.mock(ScheduledExecutorService.class);
Mockito.doAnswer(invocation -> {
Runnable command = invocation.getArgument(0, Runnable.class);
long delay = invocation.getArgument(1, Long.class);
TimeUnit unit = invocation.getArgument(2, TimeUnit.class);
return mockExecutorService.schedule(() -> {
command.run();
return null;
}, unit.toMillis(delay), null);
}).when(executorService).schedule(Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.any(TimeUnit.class));
Mockito.doAnswer(invocation -> {
Runnable command = invocation.getArgument(0, Runnable.class);
long initialDelay = invocation.getArgument(1, Long.class);
long period = invocation.getArgument(2, Long.class);
TimeUnit unit = invocation.getArgument(3, TimeUnit.class);
return mockExecutorService.schedule(() -> {
command.run();
return null;
}, unit.toMillis(initialDelay), period);
}).when(executorService).scheduleAtFixedRate(Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.anyLong(), Mockito.any(TimeUnit.class));
try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks, executorService)) {
refreshingHttpsJwks.init();
// We refresh once at the initialization time from getJsonWebKeys.
verify(httpsJwks, times(1)).refresh();
assertTrue(refreshingHttpsJwks.maybeExpediteRefresh(keyId));
verify(httpsJwks, times(2)).refresh();
time.sleep(REFRESH_MS + 1);
verify(httpsJwks, times(3)).refresh();
assertFalse(refreshingHttpsJwks.maybeExpediteRefresh(keyId));
@ -153,6 +185,10 @@ public class RefreshingHttpsJwksTest extends OAuthBearerTest { @@ -153,6 +185,10 @@ public class RefreshingHttpsJwksTest extends OAuthBearerTest {
return new RefreshingHttpsJwks(time, httpsJwks, REFRESH_MS, RETRY_BACKOFF_MS, RETRY_BACKOFF_MAX_MS);
}
private RefreshingHttpsJwks getRefreshingHttpsJwks(final Time time, final HttpsJwks httpsJwks, final ScheduledExecutorService executorService) {
return new RefreshingHttpsJwks(time, httpsJwks, REFRESH_MS, RETRY_BACKOFF_MS, RETRY_BACKOFF_MAX_MS, executorService);
}
/**
* We *spy* (not *mock*) the {@link HttpsJwks} instance because we want to have it
* _partially mocked_ to determine if it's calling its internal refresh method. We want to
@ -195,4 +231,82 @@ public class RefreshingHttpsJwksTest extends OAuthBearerTest { @@ -195,4 +231,82 @@ public class RefreshingHttpsJwksTest extends OAuthBearerTest {
return Mockito.spy(httpsJwks);
}
}
/**
* A mock ScheduledExecutorService just for the test. Note that this is not a generally reusable mock as it does not
* implement some interfaces like scheduleWithFixedDelay, etc. And it does not return ScheduledFuture correctly.
*/
private class MockExecutorService implements MockTime.Listener {
private final MockTime time;
private final TreeMap<Long, List<AbstractMap.SimpleEntry<Long, KafkaFutureImpl<Long>>>> waiters = new TreeMap<>();
public MockExecutorService(MockTime time) {
this.time = time;
time.addListener(this);
}
/**
* The actual execution and rescheduling logic. Check all internal tasks to see if any one reaches its next
* execution point, call it and optionally reschedule it if it has a specified period.
*/
@Override
public synchronized void onTimeUpdated() {
long timeMs = time.milliseconds();
while (true) {
Map.Entry<Long, List<AbstractMap.SimpleEntry<Long, KafkaFutureImpl<Long>>>> entry = waiters.firstEntry();
if ((entry == null) || (entry.getKey() > timeMs)) {
break;
}
for (AbstractMap.SimpleEntry<Long, KafkaFutureImpl<Long>> pair : entry.getValue()) {
pair.getValue().complete(timeMs);
if (pair.getKey() != null) {
addWaiter(entry.getKey() + pair.getKey(), pair.getKey(), pair.getValue());
}
}
waiters.remove(entry.getKey());
}
}
/**
* Add a task with `delayMs` and optional period to the internal waiter.
* When `delayMs` < 0, we immediately complete the waiter. Otherwise, we add the task metadata to the waiter and
* onTimeUpdated will take care of execute and reschedule it when it reaches its scheduled timestamp.
*
* @param delayMs Delay time in ms.
* @param period Scheduling period, null means no periodic.
* @param waiter A wrapper over a callable function.
*/
private synchronized void addWaiter(long delayMs, Long period, KafkaFutureImpl<Long> waiter) {
long timeMs = time.milliseconds();
if (delayMs <= 0) {
waiter.complete(timeMs);
} else {
long triggerTimeMs = timeMs + delayMs;
List<AbstractMap.SimpleEntry<Long, KafkaFutureImpl<Long>>> futures =
waiters.computeIfAbsent(triggerTimeMs, k -> new ArrayList<>());
futures.add(new AbstractMap.SimpleEntry<>(period, waiter));
}
}
/**
* Internal utility function for periodic or one time refreshes.
*
* @param period null indicates one time refresh, otherwise it is periodic.
*/
public <T> ScheduledFuture<T> schedule(final Callable<T> callable, long delayMs, Long period) {
KafkaFutureImpl<Long> waiter = new KafkaFutureImpl<>();
waiter.thenApply((KafkaFuture.BaseFunction<Long, Void>) now -> {
try {
callable.call();
} catch (Throwable e) {
e.printStackTrace();
}
return null;
});
addWaiter(delayMs, period, waiter);
return null;
}
}
}
Loading…
Cancel
Save