diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwks.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwks.java index ef746fbb11e..5dc57dead3a 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwks.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwks.java @@ -31,6 +31,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; + import org.apache.kafka.common.utils.Time; import org.jose4j.jwk.HttpsJwks; import org.jose4j.jwk.JsonWebKey; @@ -44,7 +45,7 @@ import org.slf4j.LoggerFactory; * possible to receive a JWT that contains a kid that points to yet-unknown JWK, * thus requiring a connection to the OAuth/OIDC provider to be made. Hopefully, in practice, * keys are made available for some amount of time before they're used within JWTs. - * + *

* This instance is created and provided to the * {@link org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver} that is used when using * an HTTP-/HTTPS-based {@link org.jose4j.keys.resolvers.VerificationKeyResolver}, which is then @@ -75,7 +76,7 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { * JWKS. In some cases, the call to {@link HttpsJwks#getJsonWebKeys()} will trigger a call * to {@link HttpsJwks#refresh()} which will block the current thread in network I/O. We cache * the JWKS ourselves (see {@link #jsonWebKeys}) to avoid the network I/O. - * + *

* We want to be very careful where we use the {@link HttpsJwks} instance so that we don't * perform any operation (directly or indirectly) that could cause blocking. This is because * the JWKS logic is part of the larger authentication logic which operates on Kafka's network @@ -121,23 +122,17 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { private boolean isInitialized; /** - * Creates a RefreshingHttpsJwks that will be used by the - * {@link RefreshingHttpsJwksVerificationKeyResolver} to resolve new key IDs in JWTs. - * - * @param time {@link Time} instance - * @param httpsJwks {@link HttpsJwks} instance from which to retrieve the JWKS - * based on the OAuth/OIDC standard - * @param refreshMs The number of milliseconds between refresh passes to connect - * to the OAuth/OIDC JWKS endpoint to retrieve the latest set - * @param refreshRetryBackoffMs Time for delay after initial failed attempt to retrieve JWKS - * @param refreshRetryBackoffMaxMs Maximum time to retrieve JWKS + * Creates a RefreshingHttpsJwks. It should only be used for testing to pass in a mock executor + * service. Otherwise the constructor below should be used. */ - public RefreshingHttpsJwks(Time time, - HttpsJwks httpsJwks, - long refreshMs, - long refreshRetryBackoffMs, - long refreshRetryBackoffMaxMs) { + // VisibleForTesting + RefreshingHttpsJwks(Time time, + HttpsJwks httpsJwks, + long refreshMs, + long refreshRetryBackoffMs, + long refreshRetryBackoffMaxMs, + ScheduledExecutorService executorService) { if (refreshMs <= 0) throw new IllegalArgumentException("JWKS validation key refresh configuration value retryWaitMs value must be positive"); @@ -146,7 +141,7 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { this.refreshMs = refreshMs; this.refreshRetryBackoffMs = refreshRetryBackoffMs; this.refreshRetryBackoffMaxMs = refreshRetryBackoffMaxMs; - this.executorService = Executors.newSingleThreadScheduledExecutor(); + this.executorService = executorService; this.missingKeyIds = new LinkedHashMap(MISSING_KEY_ID_CACHE_MAX_ENTRIES, .75f, true) { @Override protected boolean removeEldestEntry(Map.Entry eldest) { @@ -155,6 +150,27 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { }; } + /** + * Creates a RefreshingHttpsJwks that will be used by the + * {@link RefreshingHttpsJwksVerificationKeyResolver} to resolve new key IDs in JWTs. + * + * @param time {@link Time} instance + * @param httpsJwks {@link HttpsJwks} instance from which to retrieve the JWKS + * based on the OAuth/OIDC standard + * @param refreshMs The number of milliseconds between refresh passes to connect + * to the OAuth/OIDC JWKS endpoint to retrieve the latest set + * @param refreshRetryBackoffMs Time for delay after initial failed attempt to retrieve JWKS + * @param refreshRetryBackoffMaxMs Maximum time to retrieve JWKS + */ + + public RefreshingHttpsJwks(Time time, + HttpsJwks httpsJwks, + long refreshMs, + long refreshRetryBackoffMs, + long refreshRetryBackoffMaxMs) { + this(time, httpsJwks, refreshMs, refreshRetryBackoffMs, refreshRetryBackoffMaxMs, Executors.newSingleThreadScheduledExecutor()); + } + @Override public void init() throws IOException { try { @@ -180,9 +196,9 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { // // Note: we refer to this as a _scheduled_ refresh. executorService.scheduleAtFixedRate(this::refresh, - refreshMs, - refreshMs, - TimeUnit.MILLISECONDS); + refreshMs, + refreshMs, + TimeUnit.MILLISECONDS); log.info("JWKS validation key refresh thread started with a refresh interval of {} ms", refreshMs); } finally { @@ -203,7 +219,7 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { if (!executorService.awaitTermination(SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT)) { log.warn("JWKS validation key refresh thread termination did not end after {} {}", - SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT); + SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT); } } catch (InterruptedException e) { log.warn("JWKS validation key refresh thread error during close", e); @@ -217,13 +233,12 @@ public final class RefreshingHttpsJwks implements Initable, Closeable { * Our implementation avoids the blocking call within {@link HttpsJwks#refresh()} that is * sometimes called internal to {@link HttpsJwks#getJsonWebKeys()}. We want to avoid any * blocking I/O as this code is running in the authentication path on the Kafka network thread. - * + *

* The list may be stale up to {@link #refreshMs}. * * @return {@link List} of {@link JsonWebKey} instances - * * @throws JoseException Thrown if a problem is encountered parsing the JSON content into JWKs - * @throws IOException Thrown f a problem is encountered making the HTTP request + * @throws IOException Thrown f a problem is encountered making the HTTP request */ public List getJsonWebKeys() throws JoseException, IOException { diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwksTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwksTest.java index 29e36115d38..705a539d906 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwksTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwksTest.java @@ -25,11 +25,21 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import java.util.Arrays; -import java.util.Collection; import java.util.Collections; +import java.util.Collection; import java.util.List; - +import java.util.ArrayList; +import java.util.Arrays; +import java.util.AbstractMap; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.Callable; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.internals.KafkaFutureImpl; import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.Time; import org.jose4j.http.SimpleResponse; @@ -122,14 +132,36 @@ public class RefreshingHttpsJwksTest extends OAuthBearerTest { @Test public void testSecondaryRefreshAfterElapsedDelay() throws Exception { String keyId = "abc123"; - Time time = MockTime.SYSTEM; // Unfortunately, we can't mock time here because the - // scheduled executor doesn't respect it. + MockTime time = new MockTime(); HttpsJwks httpsJwks = spyHttpsJwks(); - - try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks)) { + MockExecutorService mockExecutorService = new MockExecutorService(time); + ScheduledExecutorService executorService = Mockito.mock(ScheduledExecutorService.class); + Mockito.doAnswer(invocation -> { + Runnable command = invocation.getArgument(0, Runnable.class); + long delay = invocation.getArgument(1, Long.class); + TimeUnit unit = invocation.getArgument(2, TimeUnit.class); + return mockExecutorService.schedule(() -> { + command.run(); + return null; + }, unit.toMillis(delay), null); + }).when(executorService).schedule(Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.any(TimeUnit.class)); + Mockito.doAnswer(invocation -> { + Runnable command = invocation.getArgument(0, Runnable.class); + long initialDelay = invocation.getArgument(1, Long.class); + long period = invocation.getArgument(2, Long.class); + TimeUnit unit = invocation.getArgument(3, TimeUnit.class); + return mockExecutorService.schedule(() -> { + command.run(); + return null; + }, unit.toMillis(initialDelay), period); + }).when(executorService).scheduleAtFixedRate(Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.anyLong(), Mockito.any(TimeUnit.class)); + + try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks, executorService)) { refreshingHttpsJwks.init(); + // We refresh once at the initialization time from getJsonWebKeys. verify(httpsJwks, times(1)).refresh(); assertTrue(refreshingHttpsJwks.maybeExpediteRefresh(keyId)); + verify(httpsJwks, times(2)).refresh(); time.sleep(REFRESH_MS + 1); verify(httpsJwks, times(3)).refresh(); assertFalse(refreshingHttpsJwks.maybeExpediteRefresh(keyId)); @@ -153,6 +185,10 @@ public class RefreshingHttpsJwksTest extends OAuthBearerTest { return new RefreshingHttpsJwks(time, httpsJwks, REFRESH_MS, RETRY_BACKOFF_MS, RETRY_BACKOFF_MAX_MS); } + private RefreshingHttpsJwks getRefreshingHttpsJwks(final Time time, final HttpsJwks httpsJwks, final ScheduledExecutorService executorService) { + return new RefreshingHttpsJwks(time, httpsJwks, REFRESH_MS, RETRY_BACKOFF_MS, RETRY_BACKOFF_MAX_MS, executorService); + } + /** * We *spy* (not *mock*) the {@link HttpsJwks} instance because we want to have it * _partially mocked_ to determine if it's calling its internal refresh method. We want to @@ -195,4 +231,82 @@ public class RefreshingHttpsJwksTest extends OAuthBearerTest { return Mockito.spy(httpsJwks); } -} + /** + * A mock ScheduledExecutorService just for the test. Note that this is not a generally reusable mock as it does not + * implement some interfaces like scheduleWithFixedDelay, etc. And it does not return ScheduledFuture correctly. + */ + private class MockExecutorService implements MockTime.Listener { + private final MockTime time; + + private final TreeMap>>> waiters = new TreeMap<>(); + + public MockExecutorService(MockTime time) { + this.time = time; + time.addListener(this); + } + + /** + * The actual execution and rescheduling logic. Check all internal tasks to see if any one reaches its next + * execution point, call it and optionally reschedule it if it has a specified period. + */ + @Override + public synchronized void onTimeUpdated() { + long timeMs = time.milliseconds(); + while (true) { + Map.Entry>>> entry = waiters.firstEntry(); + if ((entry == null) || (entry.getKey() > timeMs)) { + break; + } + for (AbstractMap.SimpleEntry> pair : entry.getValue()) { + pair.getValue().complete(timeMs); + if (pair.getKey() != null) { + addWaiter(entry.getKey() + pair.getKey(), pair.getKey(), pair.getValue()); + } + } + waiters.remove(entry.getKey()); + } + } + + /** + * Add a task with `delayMs` and optional period to the internal waiter. + * When `delayMs` < 0, we immediately complete the waiter. Otherwise, we add the task metadata to the waiter and + * onTimeUpdated will take care of execute and reschedule it when it reaches its scheduled timestamp. + * + * @param delayMs Delay time in ms. + * @param period Scheduling period, null means no periodic. + * @param waiter A wrapper over a callable function. + */ + private synchronized void addWaiter(long delayMs, Long period, KafkaFutureImpl waiter) { + long timeMs = time.milliseconds(); + if (delayMs <= 0) { + waiter.complete(timeMs); + } else { + long triggerTimeMs = timeMs + delayMs; + List>> futures = + waiters.computeIfAbsent(triggerTimeMs, k -> new ArrayList<>()); + futures.add(new AbstractMap.SimpleEntry<>(period, waiter)); + } + } + + /** + * Internal utility function for periodic or one time refreshes. + * + * @param period null indicates one time refresh, otherwise it is periodic. + */ + public ScheduledFuture schedule(final Callable callable, long delayMs, Long period) { + + KafkaFutureImpl waiter = new KafkaFutureImpl<>(); + waiter.thenApply((KafkaFuture.BaseFunction) now -> { + try { + callable.call(); + } catch (Throwable e) { + e.printStackTrace(); + } + return null; + }); + addWaiter(delayMs, period, waiter); + return null; + } + } + +} \ No newline at end of file