diff --git a/spring-core/src/main/java/org/springframework/util/ClassUtils.java b/spring-core/src/main/java/org/springframework/util/ClassUtils.java index 52853dd19a..61c54a84a1 100644 --- a/spring-core/src/main/java/org/springframework/util/ClassUtils.java +++ b/spring-core/src/main/java/org/springframework/util/ClassUtils.java @@ -406,24 +406,40 @@ public abstract class ClassUtils { Assert.notNull(clazz, "Class must not be null"); try { ClassLoader target = clazz.getClassLoader(); - if (target == null) { + // Common cases + if (target == classLoader || target == null) { return true; } - ClassLoader cur = classLoader; - if (cur == target) { - return true; + if (classLoader == null) { + return false; } - while (cur != null) { - cur = cur.getParent(); - if (cur == target) { + // Check for match in ancestors -> positive + ClassLoader current = classLoader; + while (current != null) { + current = current.getParent(); + if (current == target) { return true; } } - return false; + // Check for match in children -> negative + while (target != null) { + target = target.getParent(); + if (target == classLoader) { + return false; + } + } } catch (SecurityException ex) { - // Probably from the system ClassLoader - let's consider it safe. - return true; + // Fall through to Class reference comparison below + } + + try { + // Fallback for ClassLoaders without parent/child relationship: + // safe if same Class can be loaded from given ClassLoader + return (clazz == forName(clazz.getName(), classLoader)); + } + catch (ClassNotFoundException ex) { + return false; } } diff --git a/spring-core/src/test/java/org/springframework/util/ClassUtilsTests.java b/spring-core/src/test/java/org/springframework/util/ClassUtilsTests.java index 8786363c68..5d783df70d 100644 --- a/spring-core/src/test/java/org/springframework/util/ClassUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/util/ClassUtilsTests.java @@ -16,6 +16,7 @@ package org.springframework.util; +import java.io.Externalizable; import java.io.Serializable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -44,20 +45,21 @@ import static org.junit.Assert.*; * @author Rob Harrop * @author Rick Evans */ -@SuppressWarnings({ "rawtypes", "unchecked" }) public class ClassUtilsTests { private ClassLoader classLoader = getClass().getClassLoader(); + @Before - public void setUp() { + public void clearStatics() { InnerClass.noArgCalled = false; InnerClass.argCalled = false; InnerClass.overloadedCalled = false; } + @Test - public void testIsPresent() throws Exception { + public void testIsPresent() { assertTrue(ClassUtils.isPresent("java.lang.String", classLoader)); assertFalse(ClassUtils.isPresent("java.lang.MySpecialString", classLoader)); } @@ -114,6 +116,36 @@ public class ClassUtilsTests { assertEquals(double[].class, ClassUtils.forName(double[].class.getName(), classLoader)); } + @Test + public void testIsCacheSafe() { + ClassLoader childLoader1 = new ClassLoader(classLoader) {}; + ClassLoader childLoader2 = new ClassLoader(classLoader) {}; + ClassLoader childLoader3 = new ClassLoader(classLoader) { + @Override + public Class loadClass(String name) throws ClassNotFoundException { + return childLoader1.loadClass(name); + } + }; + Class composite = ClassUtils.createCompositeInterface( + new Class[] {Serializable.class, Externalizable.class}, childLoader1); + + assertTrue(ClassUtils.isCacheSafe(String.class, null)); + assertTrue(ClassUtils.isCacheSafe(String.class, classLoader)); + assertTrue(ClassUtils.isCacheSafe(String.class, childLoader1)); + assertTrue(ClassUtils.isCacheSafe(String.class, childLoader2)); + assertTrue(ClassUtils.isCacheSafe(String.class, childLoader3)); + assertFalse(ClassUtils.isCacheSafe(InnerClass.class, null)); + assertTrue(ClassUtils.isCacheSafe(InnerClass.class, classLoader)); + assertTrue(ClassUtils.isCacheSafe(InnerClass.class, childLoader1)); + assertTrue(ClassUtils.isCacheSafe(InnerClass.class, childLoader2)); + assertTrue(ClassUtils.isCacheSafe(InnerClass.class, childLoader3)); + assertFalse(ClassUtils.isCacheSafe(composite, null)); + assertFalse(ClassUtils.isCacheSafe(composite, classLoader)); + assertTrue(ClassUtils.isCacheSafe(composite, childLoader1)); + assertFalse(ClassUtils.isCacheSafe(composite, childLoader2)); + assertTrue(ClassUtils.isCacheSafe(composite, childLoader3)); + } + @Test public void testGetShortName() { String className = ClassUtils.getShortName(getClass()); @@ -199,7 +231,7 @@ public class ClassUtilsTests { } @Test - public void testHasMethod() throws Exception { + public void testHasMethod() { assertTrue(ClassUtils.hasMethod(Collection.class, "size")); assertTrue(ClassUtils.hasMethod(Collection.class, "remove", Object.class)); assertFalse(ClassUtils.hasMethod(Collection.class, "remove")); @@ -207,7 +239,7 @@ public class ClassUtilsTests { } @Test - public void testGetMethodIfAvailable() throws Exception { + public void testGetMethodIfAvailable() { Method method = ClassUtils.getMethodIfAvailable(Collection.class, "size"); assertNotNull(method); assertEquals("size", method.getName()); @@ -278,7 +310,7 @@ public class ClassUtilsTests { @Test public void testClassPackageAsResourcePath() { String result = ClassUtils.classPackageAsResourcePath(Proxy.class); - assertTrue(result.equals("java/lang/reflect")); + assertEquals("java/lang/reflect", result); } @Test @@ -294,7 +326,7 @@ public class ClassUtilsTests { @Test public void testGetAllInterfaces() { DerivedTestObject testBean = new DerivedTestObject(); - List ifcs = Arrays.asList(ClassUtils.getAllInterfaces(testBean)); + List> ifcs = Arrays.asList(ClassUtils.getAllInterfaces(testBean)); assertEquals("Correct number of interfaces", 4, ifcs.size()); assertTrue("Contains Serializable", ifcs.contains(Serializable.class)); assertTrue("Contains ITestBean", ifcs.contains(ITestObject.class)); @@ -303,13 +335,13 @@ public class ClassUtilsTests { @Test public void testClassNamesToString() { - List ifcs = new LinkedList(); + List> ifcs = new LinkedList<>(); ifcs.add(Serializable.class); ifcs.add(Runnable.class); assertEquals("[interface java.io.Serializable, interface java.lang.Runnable]", ifcs.toString()); assertEquals("[java.io.Serializable, java.lang.Runnable]", ClassUtils.classNamesToString(ifcs)); - List classes = new LinkedList(); + List> classes = new LinkedList<>(); classes.add(LinkedList.class); classes.add(Integer.class); assertEquals("[class java.util.LinkedList, class java.lang.Integer]", classes.toString()); @@ -319,7 +351,7 @@ public class ClassUtilsTests { assertEquals("[java.util.List]", ClassUtils.classNamesToString(List.class)); assertEquals("[]", Collections.EMPTY_LIST.toString()); - assertEquals("[]", ClassUtils.classNamesToString(Collections.EMPTY_LIST)); + assertEquals("[]", ClassUtils.classNamesToString(Collections.emptyList())); } @Test