Browse Source

KAFKA-7902: Replace original loginContext if SASL/OAUTHBEARER refresh login fails (#6233)

Replaces original loginContext if login fails in the refresh thread to ensure that the refresh thread is left in a clean state when there are exceptions while connecting to an OAuth server. Also makes client callback handler more robust by using the token with the longest remaining time for expiry instead of throwing an exception if multiple tokens are found.

Reviewers: Rajini Sivaram <rajinisivaram@googlemail.com>
pull/6240/head
Ron Dagostino 6 years ago committed by Rajini Sivaram
parent
commit
9f7e6b2913
  1. 41
      clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java
  2. 22
      clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLogin.java
  3. 102
      clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerSaslClienCallbackHandlerTest.java
  4. 79
      clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java

41
clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java

@ -19,9 +19,13 @@ package org.apache.kafka.common.security.oauthbearer.internals; @@ -19,9 +19,13 @@ package org.apache.kafka.common.security.oauthbearer.internals;
import java.io.IOException;
import java.security.AccessController;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
@ -34,6 +38,8 @@ import org.apache.kafka.common.security.auth.SaslExtensions; @@ -34,6 +38,8 @@ import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* An implementation of {@code AuthenticateCallbackHandler} that recognizes
@ -49,6 +55,7 @@ import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; @@ -49,6 +55,7 @@ import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
* configuration property.
*/
public class OAuthBearerSaslClientCallbackHandler implements AuthenticateCallbackHandler {
private static final Logger log = LoggerFactory.getLogger(OAuthBearerSaslClientCallbackHandler.class);
private boolean configured = false;
/**
@ -93,11 +100,35 @@ public class OAuthBearerSaslClientCallbackHandler implements AuthenticateCallbac @@ -93,11 +100,35 @@ public class OAuthBearerSaslClientCallbackHandler implements AuthenticateCallbac
Set<OAuthBearerToken> privateCredentials = subject != null
? subject.getPrivateCredentials(OAuthBearerToken.class)
: Collections.emptySet();
if (privateCredentials.size() != 1)
throw new IOException(
String.format("Unable to find OAuth Bearer token in Subject's private credentials (size=%d)",
privateCredentials.size()));
callback.token(privateCredentials.iterator().next());
if (privateCredentials.size() == 0)
throw new IOException("No OAuth Bearer tokens in Subject's private credentials");
if (privateCredentials.size() == 1)
callback.token(privateCredentials.iterator().next());
else {
/*
* There a very small window of time upon token refresh (on the order of milliseconds)
* where both an old and a new token appear on the Subject's private credentials.
* Rather than implement a lock to eliminate this window, we will deal with it by
* checking for the existence of multiple tokens and choosing the one that has the
* longest lifetime. It is also possible that a bug could cause multiple tokens to
* exist (e.g. KAFKA-7902), so dealing with the unlikely possibility that occurs
* during normal operation also allows us to deal more robustly with potential bugs.
*/
SortedSet<OAuthBearerToken> sortedByLifetime =
new TreeSet<>(
new Comparator<OAuthBearerToken>() {
@Override
public int compare(OAuthBearerToken o1, OAuthBearerToken o2) {
return Long.compare(o1.lifetimeMs(), o2.lifetimeMs());
}
});
sortedByLifetime.addAll(privateCredentials);
log.warn("Found {} OAuth Bearer tokens in Subject's private credentials; the oldest expires at {}, will use the newest, which expires at {}",
sortedByLifetime.size(),
new Date(sortedByLifetime.first().lifetimeMs()),
new Date(sortedByLifetime.last().lifetimeMs()));
callback.token(sortedByLifetime.last());
}
}
/**

22
clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLogin.java

@ -375,13 +375,21 @@ public abstract class ExpiringCredentialRefreshingLogin implements AutoCloseable @@ -375,13 +375,21 @@ public abstract class ExpiringCredentialRefreshingLogin implements AutoCloseable
*/
ExpiringCredential optionalCredentialToLogout = expiringCredential;
LoginContext optionalLoginContextToLogout = loginContext;
loginContext = loginContextFactory.createLoginContext(ExpiringCredentialRefreshingLogin.this);
log.info("Initiating re-login for {}, logout() still needs to be called on a previous login = {}",
principalName, optionalCredentialToLogout != null);
loginContext.login();
// Perform a logout() on any original credential if necessary
if (optionalCredentialToLogout != null)
optionalLoginContextToLogout.logout();
boolean cleanLogin = false; // remember to restore the original if necessary
try {
loginContext = loginContextFactory.createLoginContext(ExpiringCredentialRefreshingLogin.this);
log.info("Initiating re-login for {}, logout() still needs to be called on a previous login = {}",
principalName, optionalCredentialToLogout != null);
loginContext.login();
cleanLogin = true; // no need to restore the original
// Perform a logout() on any original credential if necessary
if (optionalCredentialToLogout != null)
optionalLoginContextToLogout.logout();
} finally {
if (!cleanLogin)
// restore the original
loginContext = optionalLoginContextToLogout;
}
/*
* Get the new credential and make sure it is not any old one that required a
* logout() after the login()

102
clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerSaslClienCallbackHandlerTest.java

@ -0,0 +1,102 @@ @@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer;
import static org.junit.Assert.assertEquals;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Collections;
import java.util.Set;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslClientCallbackHandler;
import org.junit.Test;
public class OAuthBearerSaslClienCallbackHandlerTest {
private static OAuthBearerToken createTokenWithLifetimeMillis(final long lifetimeMillis) {
return new OAuthBearerToken() {
@Override
public String value() {
return null;
}
@Override
public Long startTimeMs() {
return null;
}
@Override
public Set<String> scope() {
return null;
}
@Override
public String principalName() {
return null;
}
@Override
public long lifetimeMs() {
return lifetimeMillis;
}
};
}
@Test(expected = IOException.class)
public void testWithZeroTokens() throws Throwable {
OAuthBearerSaslClientCallbackHandler handler = createCallbackHandler();
try {
Subject.doAs(new Subject(), (PrivilegedExceptionAction<Void>) () -> {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
handler.handle(new Callback[] {callback});
return null;
});
} catch (PrivilegedActionException e) {
throw e.getCause();
}
}
@Test()
public void testWithPotentiallyMultipleTokens() throws Exception {
OAuthBearerSaslClientCallbackHandler handler = createCallbackHandler();
Subject.doAs(new Subject(), (PrivilegedExceptionAction<Void>) () -> {
final int maxTokens = 4;
final Set<Object> privateCredentials = Subject.getSubject(AccessController.getContext())
.getPrivateCredentials();
privateCredentials.clear();
for (int num = 1; num <= maxTokens; ++num) {
privateCredentials.add(createTokenWithLifetimeMillis(num));
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
handler.handle(new Callback[] {callback});
assertEquals(num, callback.token().lifetimeMs());
}
return null;
});
}
private static OAuthBearerSaslClientCallbackHandler createCallbackHandler() {
OAuthBearerSaslClientCallbackHandler handler = new OAuthBearerSaslClientCallbackHandler();
handler.configure(Collections.emptyMap(), OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
Collections.emptyList());
return handler;
}
}

79
clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java

@ -47,6 +47,7 @@ import org.apache.kafka.common.utils.MockTime; @@ -47,6 +47,7 @@ import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.junit.Test;
import org.mockito.InOrder;
import org.mockito.Mockito;
public class ExpiringCredentialRefreshingLoginTest {
private static final Configuration EMPTY_WILDCARD_CONFIGURATION;
@ -232,8 +233,28 @@ public class ExpiringCredentialRefreshingLoginTest { @@ -232,8 +233,28 @@ public class ExpiringCredentialRefreshingLoginTest {
}
@Override
public LoginContext createLoginContext(ExpiringCredentialRefreshingLogin expiringCredentialRefreshingLogin) {
return testLoginContext;
public LoginContext createLoginContext(ExpiringCredentialRefreshingLogin expiringCredentialRefreshingLogin) throws LoginException {
return new LoginContext("", null, null, EMPTY_WILDCARD_CONFIGURATION) {
private boolean loginSuccess = false;
@Override
public void login() throws LoginException {
testLoginContext.login();
loginSuccess = true;
}
@Override
public void logout() throws LoginException {
if (!loginSuccess)
// will cause the refresher thread to exit
throw new IllegalStateException("logout called without a successful login");
testLoginContext.logout();
}
@Override
public Subject getSubject() {
return testLoginContext.getSubject();
}
};
}
@Override
@ -582,6 +603,60 @@ public class ExpiringCredentialRefreshingLoginTest { @@ -582,6 +603,60 @@ public class ExpiringCredentialRefreshingLoginTest {
}
}
@Test
public void testLoginExceptionCausesCorrectLogout() throws Exception {
int numExpectedRefreshes = 3;
boolean clientReloginAllowedBeforeLogout = true;
Subject subject = new Subject();
final LoginContext mockLoginContext = mock(LoginContext.class);
when(mockLoginContext.getSubject()).thenReturn(subject);
Mockito.doNothing().doThrow(new LoginException()).doNothing().when(mockLoginContext).login();
MockTime mockTime = new MockTime();
long startMs = mockTime.milliseconds();
/*
* Identify the lifetime of each expiring credential
*/
long lifetimeMinutes = 100L;
/*
* Identify the point at which refresh will occur in that lifetime
*/
long refreshEveryMinutes = 80L;
/*
* Set an absolute last refresh time that will cause the login thread to exit
* after a certain number of re-logins (by adding an extra half of a refresh
* interval).
*/
long absoluteLastRefreshMs = startMs + (1 + numExpectedRefreshes) * 1000 * 60 * refreshEveryMinutes
- 1000 * 60 * refreshEveryMinutes / 2;
/*
* Identify buffer time on either side for the refresh algorithm
*/
short minPeriodSeconds = (short) 0;
short bufferSeconds = minPeriodSeconds;
// Create the ExpiringCredentialRefreshingLogin instance under test
TestLoginContextFactory testLoginContextFactory = new TestLoginContextFactory();
TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin = new TestExpiringCredentialRefreshingLogin(
refreshConfigThatPerformsReloginEveryGivenPercentageOfLifetime(
1.0 * refreshEveryMinutes / lifetimeMinutes, minPeriodSeconds, bufferSeconds,
clientReloginAllowedBeforeLogout),
testLoginContextFactory, mockTime, 1000 * 60 * lifetimeMinutes, absoluteLastRefreshMs,
clientReloginAllowedBeforeLogout);
testLoginContextFactory.configure(mockLoginContext, testExpiringCredentialRefreshingLogin);
/*
* Perform the login and wait up to a certain amount of time for the refresher
* thread to exit. A timeout indicates the thread died due to logout()
* being invoked on an instance where the login() invocation had failed.
*/
assertFalse(testLoginContextFactory.refresherThreadStartedFuture().isDone());
assertFalse(testLoginContextFactory.refresherThreadDoneFuture().isDone());
testExpiringCredentialRefreshingLogin.login();
assertTrue(testLoginContextFactory.refresherThreadStartedFuture().isDone());
testLoginContextFactory.refresherThreadDoneFuture().get(1L, TimeUnit.SECONDS);
}
private static List<KafkaFutureImpl<Long>> addWaiters(MockScheduler mockScheduler, long refreshEveryMillis,
int numWaiters) {
List<KafkaFutureImpl<Long>> retvalWaiters = new ArrayList<>(numWaiters);

Loading…
Cancel
Save