diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/InstanceSupplier.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/InstanceSupplier.java index 9b7dbfcb8e..6d12abc287 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/InstanceSupplier.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/InstanceSupplier.java @@ -75,8 +75,19 @@ public interface InstanceSupplier extends ThrowingSupplier { default InstanceSupplier andThen( ThrowingBiFunction after) { Assert.notNull(after, "After must not be null"); - return registeredBean -> after.applyWithException(registeredBean, - get(registeredBean)); + return new InstanceSupplier() { + + @Override + public V get(RegisteredBean registeredBean) throws Exception { + return after.applyWithException(registeredBean, InstanceSupplier.this.get(registeredBean)); + } + + @Override + public Method getFactoryMethod() { + return InstanceSupplier.this.getFactoryMethod(); + } + + }; } /** @@ -94,6 +105,35 @@ public interface InstanceSupplier extends ThrowingSupplier { return registeredBean -> supplier.getWithException(); } + /** + * Factory method to create an {@link InstanceSupplier} from a + * {@link ThrowingSupplier}. + * @param the type of instance supplied by this supplier + * @param factoryMethod the factory method being used + * @param supplier the source supplier + * @return a new {@link InstanceSupplier} + */ + static InstanceSupplier using(@Nullable Method factoryMethod, ThrowingSupplier supplier) { + Assert.notNull(supplier, "Supplier must not be null"); + if (supplier instanceof InstanceSupplier instanceSupplier + && instanceSupplier.getFactoryMethod() == factoryMethod) { + return instanceSupplier; + } + return new InstanceSupplier() { + + @Override + public T get(RegisteredBean registeredBean) throws Exception { + return supplier.getWithException(); + } + + @Override + public Method getFactoryMethod() { + return factoryMethod; + } + + }; + } + /** * Lambda friendly method that can be used to create a * {@link InstanceSupplier} and add post processors in a single call. For diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/support/InstanceSupplierTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/support/InstanceSupplierTests.java index a995f1b460..5cb43d2619 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/support/InstanceSupplierTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/support/InstanceSupplierTests.java @@ -16,6 +16,8 @@ package org.springframework.beans.factory.support; +import java.lang.reflect.Method; + import org.junit.jupiter.api.Test; import org.springframework.util.function.ThrowingBiFunction; @@ -55,7 +57,7 @@ class InstanceSupplierTests { } @Test - void andThenWithBiFunctionWhenFunctionIsNullThrowsException() { + void andThenWhenFunctionIsNullThrowsException() { InstanceSupplier supplier = registeredBean -> "test"; ThrowingBiFunction after = null; assertThatIllegalArgumentException().isThrownBy(() -> supplier.andThen(after)) @@ -63,13 +65,23 @@ class InstanceSupplierTests { } @Test - void andThenWithBiFunctionAppliesFunctionToObtainResult() throws Exception { + void andThenAppliesFunctionToObtainResult() throws Exception { InstanceSupplier supplier = registeredBean -> "bean"; supplier = supplier.andThen( (registeredBean, string) -> registeredBean.getBeanName() + "-" + string); assertThat(supplier.get(this.registeredBean)).isEqualTo("test-bean"); } + @Test + void andThenWhenInstanceSupplierHasFactoryMethod() throws Exception { + Method factoryMethod = getClass().getDeclaredMethod("andThenWhenInstanceSupplierHasFactoryMethod"); + InstanceSupplier supplier = InstanceSupplier.using(factoryMethod, () -> "bean"); + supplier = supplier.andThen( + (registeredBean, string) -> registeredBean.getBeanName() + "-" + string); + assertThat(supplier.get(this.registeredBean)).isEqualTo("test-bean"); + assertThat(supplier.getFactoryMethod()).isSameAs(factoryMethod); + } + @Test void ofSupplierWhenInstanceSupplierReturnsSameInstance() { InstanceSupplier supplier = registeredBean -> "test";