Browse Source

Add dependency ordering for @ModelAttribute methods

Before this change @ModelAttribute methods were not invoked in any
particular order other than ensuring global @ControllerAdvice methods
are called first and local @Controller methods second.

This change introduces a simple algorithm that selects the next
@ModelAttribute method to invoke by making a pass over all methods and
looking for one that has no dependencies (i.e. @ModelAttribute
input arguments) or has all dependencies resolved (i.e. available in
the model). The process is repeated until no more @ModelAttribute
methods remain.

If the next @ModelAttribute method cannot be determined because all
remaining methods have unresolved dependencies, the first available
method is picked anyway just as before, i.e. with required
dependencies created through the default constructor.

Examples in ModelFactoryOrderingTests.

Issue: SPR-6299
pull/547/merge
Rossen Stoyanchev 11 years ago
parent
commit
56a82c1cbe
  1. 86
      spring-web/src/main/java/org/springframework/web/method/annotation/ModelFactory.java
  2. 336
      spring-web/src/test/java/org/springframework/web/method/annotation/ModelFactoryOrderingTests.java

86
spring-web/src/main/java/org/springframework/web/method/annotation/ModelFactory.java

@ -19,9 +19,13 @@ package org.springframework.web.method.annotation; @@ -19,9 +19,13 @@ package org.springframework.web.method.annotation;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.core.Conventions;
import org.springframework.core.GenericTypeResolver;
@ -54,7 +58,9 @@ import org.springframework.web.method.support.ModelAndViewContainer; @@ -54,7 +58,9 @@ import org.springframework.web.method.support.ModelAndViewContainer;
*/
public final class ModelFactory {
private final List<InvocableHandlerMethod> handlerMethods;
private static final Log logger = LogFactory.getLog(ModelFactory.class);
private final List<ModelMethod> modelMethods = new ArrayList<ModelMethod>();
private final WebDataBinderFactory dataBinderFactory;
@ -63,14 +69,18 @@ public final class ModelFactory { @@ -63,14 +69,18 @@ public final class ModelFactory {
/**
* Create a new instance with the given {@code @ModelAttribute} methods.
* @param handlerMethods the {@code @ModelAttribute} methods to invoke
* @param invocableMethods the {@code @ModelAttribute} methods to invoke
* @param dataBinderFactory for preparation of {@link BindingResult} attributes
* @param sessionAttributesHandler for access to session attributes
*/
public ModelFactory(List<InvocableHandlerMethod> handlerMethods, WebDataBinderFactory dataBinderFactory,
public ModelFactory(List<InvocableHandlerMethod> invocableMethods, WebDataBinderFactory dataBinderFactory,
SessionAttributesHandler sessionAttributesHandler) {
this.handlerMethods = (handlerMethods != null) ? handlerMethods : new ArrayList<InvocableHandlerMethod>();
if (invocableMethods != null) {
for (InvocableHandlerMethod method : invocableMethods) {
this.modelMethods.add(new ModelMethod(method));
}
}
this.dataBinderFactory = dataBinderFactory;
this.sessionAttributesHandler = sessionAttributesHandler;
}
@ -115,7 +125,8 @@ public final class ModelFactory { @@ -115,7 +125,8 @@ public final class ModelFactory {
private void invokeModelAttributeMethods(NativeWebRequest request, ModelAndViewContainer mavContainer)
throws Exception {
for (InvocableHandlerMethod attrMethod : this.handlerMethods) {
while (!this.modelMethods.isEmpty()) {
InvocableHandlerMethod attrMethod = getNextModelMethod(mavContainer).getHandlerMethod();
String modelName = attrMethod.getMethodAnnotation(ModelAttribute.class).value();
if (mavContainer.containsAttribute(modelName)) {
continue;
@ -132,6 +143,25 @@ public final class ModelFactory { @@ -132,6 +143,25 @@ public final class ModelFactory {
}
}
private ModelMethod getNextModelMethod(ModelAndViewContainer mavContainer) {
for (ModelMethod modelMethod : this.modelMethods) {
if (modelMethod.checkDependencies(mavContainer)) {
if (logger.isTraceEnabled()) {
logger.trace("Selected @ModelAttribute method " + modelMethod);
}
this.modelMethods.remove(modelMethod);
return modelMethod;
}
}
ModelMethod modelMethod = this.modelMethods.get(0);
if (logger.isTraceEnabled()) {
logger.trace("Selected @ModelAttribute method (not present: " +
modelMethod.getUnresolvedDependencies(mavContainer)+ ") " + modelMethod);
}
this.modelMethods.remove(modelMethod);
return modelMethod;
}
/**
* Find {@code @ModelAttribute} arguments also listed as {@code @SessionAttributes}.
*/
@ -240,4 +270,50 @@ public final class ModelFactory { @@ -240,4 +270,50 @@ public final class ModelFactory {
!(value instanceof Map) && !BeanUtils.isSimpleValueType(value.getClass()));
}
private static class ModelMethod {
private final InvocableHandlerMethod handlerMethod;
private final Set<String> dependencies = new HashSet<String>();
private ModelMethod(InvocableHandlerMethod handlerMethod) {
this.handlerMethod = handlerMethod;
for (MethodParameter parameter : handlerMethod.getMethodParameters()) {
if (parameter.hasParameterAnnotation(ModelAttribute.class)) {
this.dependencies.add(getNameForParameter(parameter));
}
}
}
public InvocableHandlerMethod getHandlerMethod() {
return this.handlerMethod;
}
public boolean checkDependencies(ModelAndViewContainer mavContainer) {
for (String name : this.dependencies) {
if (!mavContainer.containsAttribute(name)) {
return false;
}
}
return true;
}
public List<String> getUnresolvedDependencies(ModelAndViewContainer mavContainer) {
List<String> result = new ArrayList<String>(this.dependencies.size());
for (String name : this.dependencies) {
if (!mavContainer.containsAttribute(name)) {
result.add(name);
}
}
return result;
}
@Override
public String toString() {
return this.handlerMethod.getMethod().toGenericString();
}
}
}

336
spring-web/src/test/java/org/springframework/web/method/annotation/ModelFactoryOrderingTests.java

@ -0,0 +1,336 @@ @@ -0,0 +1,336 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.method.annotation;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Before;
import org.junit.Test;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.ui.Model;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.support.DefaultDataBinderFactory;
import org.springframework.web.bind.support.DefaultSessionAttributeStore;
import org.springframework.web.bind.support.SessionAttributeStore;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.HandlerMethodSelector;
import org.springframework.web.method.support.HandlerMethodArgumentResolverComposite;
import org.springframework.web.method.support.InvocableHandlerMethod;
import org.springframework.web.method.support.ModelAndViewContainer;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import static org.junit.Assert.assertTrue;
/**
* Unit tests verifying {@code @ModelAttribute} method inter-dependencies.
*
* @author Rossen Stoyanchev
*/
public class ModelFactoryOrderingTests {
private static final Log logger = LogFactory.getLog(ModelFactoryOrderingTests.class);
private NativeWebRequest webRequest;
private ModelAndViewContainer mavContainer;
private SessionAttributeStore sessionAttributeStore;
@Before
public void setup() {
this.sessionAttributeStore = new DefaultSessionAttributeStore();
this.webRequest = new ServletWebRequest(new MockHttpServletRequest(), new MockHttpServletResponse());
this.mavContainer = new ModelAndViewContainer();
this.mavContainer.addAttribute("methods", new ArrayList<String>());
}
@Test
public void straightLineDependency() throws Exception {
runTest(new StraightLineDependencyController());
assertInvokedBefore("getA", "getB1", "getB2", "getC1", "getC2", "getC3", "getC4");
assertInvokedBefore("getB1", "getB2", "getC1", "getC2", "getC3", "getC4");
assertInvokedBefore("getB2", "getC1", "getC2", "getC3", "getC4");
assertInvokedBefore("getC1", "getC2", "getC3", "getC4");
assertInvokedBefore("getC2", "getC3", "getC4");
assertInvokedBefore("getC3", "getC4");
}
@Test
public void treeDependency() throws Exception {
runTest(new TreeDependencyController());
assertInvokedBefore("getA", "getB1", "getB2", "getC1", "getC2", "getC3", "getC4");
assertInvokedBefore("getB1", "getC1", "getC2");
assertInvokedBefore("getB2", "getC3", "getC4");
}
@Test
public void InvertedTreeDependency() throws Exception {
runTest(new InvertedTreeDependencyController());
assertInvokedBefore("getC1", "getA", "getB1");
assertInvokedBefore("getC2", "getA", "getB1");
assertInvokedBefore("getC3", "getA", "getB2");
assertInvokedBefore("getC4", "getA", "getB2");
assertInvokedBefore("getB1", "getA");
assertInvokedBefore("getB2", "getA");
}
@Test
public void unresolvedDependency() throws Exception {
runTest(new UnresolvedDependencyController());
assertInvokedBefore("getA", "getC1", "getC2", "getC3", "getC4");
// No other order guarantees for methods with unresolvable dependencies (and methods that depend on them),
// Required dependencies will be created via default constructor.
}
private void runTest(Object controller) throws Exception {
HandlerMethodArgumentResolverComposite resolvers = new HandlerMethodArgumentResolverComposite();
resolvers.addResolver(new ModelAttributeMethodProcessor(false));
resolvers.addResolver(new ModelMethodProcessor());
WebDataBinderFactory dataBinderFactory = new DefaultDataBinderFactory(null);
Class<?> type = controller.getClass();
Set<Method> methods = HandlerMethodSelector.selectMethods(type, METHOD_FILTER);
List<InvocableHandlerMethod> modelMethods = new ArrayList<InvocableHandlerMethod>();
for (Method method : methods) {
InvocableHandlerMethod modelMethod = new InvocableHandlerMethod(controller, method);
modelMethod.setHandlerMethodArgumentResolvers(resolvers);
modelMethod.setDataBinderFactory(dataBinderFactory);
modelMethods.add(modelMethod);
}
Collections.shuffle(modelMethods);
SessionAttributesHandler sessionHandler = new SessionAttributesHandler(type, this.sessionAttributeStore);
ModelFactory factory = new ModelFactory(modelMethods, dataBinderFactory, sessionHandler);
factory.initModel(this.webRequest, this.mavContainer, new HandlerMethod(controller, "handle"));
if (logger.isDebugEnabled()) {
StringBuilder sb = new StringBuilder();
for (String name : getInvokedMethods()) {
sb.append(" >> ").append(name);
}
logger.debug(sb);
}
}
private void assertInvokedBefore(String beforeMethod, String... afterMethods) {
List<String> actual = getInvokedMethods();
for (String afterMethod : afterMethods) {
assertTrue(beforeMethod + " should be before " + afterMethod + ". Actual order: " +
actual.toString(), actual.indexOf(beforeMethod) < actual.indexOf(afterMethod));
}
}
@SuppressWarnings("unchecked")
private List<String> getInvokedMethods() {
return (List<String>) this.mavContainer.getModel().get("methods");
}
private static class AbstractController {
@RequestMapping
public void handle() {
}
@SuppressWarnings("unchecked")
<T> T updateAndReturn(Model model, String methodName, T returnValue) throws IOException {
((List<String>) model.asMap().get("methods")).add(methodName);
return returnValue;
}
}
@SuppressWarnings("unused")
private static class StraightLineDependencyController extends AbstractController {
@ModelAttribute
public A getA(Model model) throws IOException {
return updateAndReturn(model, "getA", new A());
}
@ModelAttribute
public B1 getB1(@ModelAttribute A a, Model model) throws IOException {
return updateAndReturn(model, "getB1", new B1());
}
@ModelAttribute
public B2 getB2(@ModelAttribute B1 b1, Model model) throws IOException {
return updateAndReturn(model, "getB2", new B2());
}
@ModelAttribute
public C1 getC1(@ModelAttribute B2 b2, Model model) throws IOException {
return updateAndReturn(model, "getC1", new C1());
}
@ModelAttribute
public C2 getC2(@ModelAttribute C1 c1, Model model) throws IOException {
return updateAndReturn(model, "getC2", new C2());
}
@ModelAttribute
public C3 getC3(@ModelAttribute C2 c2, Model model) throws IOException {
return updateAndReturn(model, "getC3", new C3());
}
@ModelAttribute
public C4 getC4(@ModelAttribute C3 c3, Model model) throws IOException {
return updateAndReturn(model, "getC4", new C4());
}
}
@SuppressWarnings("unused")
private static class TreeDependencyController extends AbstractController {
@ModelAttribute
public A getA(Model model) throws IOException {
return updateAndReturn(model, "getA", new A());
}
@ModelAttribute
public B1 getB1(@ModelAttribute A a, Model model) throws IOException {
return updateAndReturn(model, "getB1", new B1());
}
@ModelAttribute
public B2 getB2(@ModelAttribute A a, Model model) throws IOException {
return updateAndReturn(model, "getB2", new B2());
}
@ModelAttribute
public C1 getC1(@ModelAttribute B1 b1, Model model) throws IOException {
return updateAndReturn(model, "getC1", new C1());
}
@ModelAttribute
public C2 getC2(@ModelAttribute B1 b1, Model model) throws IOException {
return updateAndReturn(model, "getC2", new C2());
}
@ModelAttribute
public C3 getC3(@ModelAttribute B2 b2, Model model) throws IOException {
return updateAndReturn(model, "getC3", new C3());
}
@ModelAttribute
public C4 getC4(@ModelAttribute B2 b2, Model model) throws IOException {
return updateAndReturn(model, "getC4", new C4());
}
}
@SuppressWarnings("unused")
private static class InvertedTreeDependencyController extends AbstractController {
@ModelAttribute
public C1 getC1(Model model) throws IOException {
return updateAndReturn(model, "getC1", new C1());
}
@ModelAttribute
public C2 getC2(Model model) throws IOException {
return updateAndReturn(model, "getC2", new C2());
}
@ModelAttribute
public C3 getC3(Model model) throws IOException {
return updateAndReturn(model, "getC3", new C3());
}
@ModelAttribute
public C4 getC4(Model model) throws IOException {
return updateAndReturn(model, "getC4", new C4());
}
@ModelAttribute
public B1 getB1(@ModelAttribute C1 c1, @ModelAttribute C2 c2, Model model) throws IOException {
return updateAndReturn(model, "getB1", new B1());
}
@ModelAttribute
public B2 getB2(@ModelAttribute C3 c3, @ModelAttribute C4 c4, Model model) throws IOException {
return updateAndReturn(model, "getB2", new B2());
}
@ModelAttribute
public A getA(@ModelAttribute B1 b1, @ModelAttribute B2 b2, Model model) throws IOException {
return updateAndReturn(model, "getA", new A());
}
}
@SuppressWarnings("unused")
private static class UnresolvedDependencyController extends AbstractController {
@ModelAttribute
public A getA(Model model) throws IOException {
return updateAndReturn(model, "getA", new A());
}
@ModelAttribute
public C1 getC1(@ModelAttribute B1 b1, Model model) throws IOException {
return updateAndReturn(model, "getC1", new C1());
}
@ModelAttribute
public C2 getC2(@ModelAttribute B1 b1, Model model) throws IOException {
return updateAndReturn(model, "getC2", new C2());
}
@ModelAttribute
public C3 getC3(@ModelAttribute B2 b2, Model model) throws IOException {
return updateAndReturn(model, "getC3", new C3());
}
@ModelAttribute
public C4 getC4(@ModelAttribute B2 b2, Model model) throws IOException {
return updateAndReturn(model, "getC4", new C4());
}
}
private static class A { }
private static class B1 { }
private static class B2 { }
private static class C1 { }
private static class C2 { }
private static class C3 { }
private static class C4 { }
private static final ReflectionUtils.MethodFilter METHOD_FILTER = new ReflectionUtils.MethodFilter() {
@Override
public boolean matches(Method method) {
return ((AnnotationUtils.findAnnotation(method, RequestMapping.class) == null) &&
(AnnotationUtils.findAnnotation(method, ModelAttribute.class) != null));
}
};
}
Loading…
Cancel
Save