diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolver.java index 7078d63950..985058bae9 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolver.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolver.java @@ -97,6 +97,8 @@ class ControllerMethodResolver { private final Map exceptionHandlerAdviceCache = new LinkedHashMap<>(64); + private final Map, SessionAttributesHandler> sessionAttributesHandlerCache = new ConcurrentHashMap<>(64); + ControllerMethodResolver(ArgumentResolverConfigurer argumentResolvers, List> messageReaders, ReactiveAdapterRegistry reactiveRegistry, @@ -154,6 +156,7 @@ class ControllerMethodResolver { registrar.addIfModelAttribute(() -> new ErrorsMethodArgumentResolver(reactiveRegistry)); registrar.add(new ServerWebExchangeArgumentResolver(reactiveRegistry)); registrar.add(new PrincipalArgumentResolver(reactiveRegistry)); + registrar.addIfRequestBody(readers -> new SessionStatusMethodArgumentResolver()); registrar.add(new WebSessionArgumentResolver(reactiveRegistry)); // Custom... @@ -315,6 +318,25 @@ class ControllerMethodResolver { return invocable; } + /** + * Return the handler for the type-level {@code @SessionAttributes} annotation + * based on the given controller method. + */ + public SessionAttributesHandler getSessionAttributesHandler(HandlerMethod handlerMethod) { + Class handlerType = handlerMethod.getBeanType(); + SessionAttributesHandler result = this.sessionAttributesHandlerCache.get(handlerType); + if (result == null) { + synchronized (this.sessionAttributesHandlerCache) { + result = this.sessionAttributesHandlerCache.get(handlerType); + if (result == null) { + result = new SessionAttributesHandler(handlerType); + this.sessionAttributesHandlerCache.put(handlerType, result); + } + } + } + return result; + } + /** Filter for {@link InitBinder @InitBinder} methods. */ private static final ReflectionUtils.MethodFilter BINDER_METHODS = method -> @@ -336,6 +358,7 @@ class ControllerMethodResolver { private final List result = new ArrayList<>(); + private ArgumentResolverRegistrar(ArgumentResolverConfigurer resolvers, List> messageReaders, boolean modelAttribute) { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContext.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContext.java index 54c5c716d5..839fe9db06 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContext.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContext.java @@ -23,12 +23,15 @@ import java.util.List; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.web.bind.annotation.InitBinder; +import org.springframework.web.bind.support.SessionStatus; +import org.springframework.web.bind.support.SimpleSessionStatus; import org.springframework.web.bind.support.WebBindingInitializer; import org.springframework.web.bind.support.WebExchangeDataBinder; import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod; import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; /** * Extends {@link BindingContext} with {@code @InitBinder} method initialization. @@ -43,6 +46,11 @@ class InitBinderBindingContext extends BindingContext { /* Simple BindingContext to help with the invoking @InitBinder methods */ private final BindingContext binderMethodContext; + private final SessionStatus sessionStatus = new SimpleSessionStatus(); + + @Nullable + private Runnable saveModelOperation; + InitBinderBindingContext(@Nullable WebBindingInitializer initializer, List binderMethods) { @@ -53,6 +61,15 @@ class InitBinderBindingContext extends BindingContext { } + /** + * Return the {@link SessionStatus} instance to use that can be used to + * signal that session processing is complete. + */ + public SessionStatus getSessionStatus() { + return this.sessionStatus; + } + + @Override protected WebExchangeDataBinder initDataBinder(WebExchangeDataBinder dataBinder, ServerWebExchange exchange) { @@ -87,4 +104,29 @@ class InitBinderBindingContext extends BindingContext { } } + /** + * Provide the context required to apply {@link #saveModel()} after the + * controller method has been invoked. + */ + public void setSessionContext(SessionAttributesHandler attributesHandler, WebSession session) { + this.saveModelOperation = () -> { + if (getSessionStatus().isComplete()) { + attributesHandler.cleanupAttributes(session); + } + else { + attributesHandler.storeAttributes(session, getModel().asMap()); + } + }; + } + + /** + * Save model attributes in the session based on a type-level declarations + * in an {@code @SessionAttributes} annotation. + */ + public void saveModel() { + if (this.saveModelOperation != null) { + this.saveModelOperation.run(); + } + } + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelAttributeMethodArgumentResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelAttributeMethodArgumentResolver.java index e3e8c15e88..2848cae2e0 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelAttributeMethodArgumentResolver.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelAttributeMethodArgumentResolver.java @@ -21,13 +21,11 @@ import java.lang.annotation.Annotation; import java.lang.reflect.Constructor; import java.util.List; import java.util.Map; -import java.util.Optional; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; import org.springframework.beans.BeanUtils; -import org.springframework.core.Conventions; import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.MethodParameter; import org.springframework.core.ParameterNameDiscoverer; @@ -39,7 +37,6 @@ import org.springframework.lang.Nullable; import org.springframework.ui.Model; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; -import org.springframework.util.StringUtils; import org.springframework.validation.BindingResult; import org.springframework.validation.Errors; import org.springframework.validation.annotation.Validated; @@ -115,7 +112,7 @@ public class ModelAttributeMethodArgumentResolver extends HandlerMethodArgumentR () -> getClass().getSimpleName() + " does not support multi-value reactive type wrapper: " + parameter.getGenericParameterType()); - String name = getAttributeName(parameter); + String name = ModelInitializer.getNameForParameter(parameter); Mono valueMono = prepareAttributeMono(name, valueType, context, exchange); Map model = context.getModel().asMap(); @@ -150,13 +147,6 @@ public class ModelAttributeMethodArgumentResolver extends HandlerMethodArgumentR }); } - private String getAttributeName(MethodParameter parameter) { - return Optional.ofNullable(parameter.getParameterAnnotation(ModelAttribute.class)) - .filter(ann -> StringUtils.hasText(ann.value())) - .map(ModelAttribute::value) - .orElse(Conventions.getVariableNameForParameter(parameter)); - } - private Mono prepareAttributeMono(String attributeName, ResolvableType attributeType, BindingContext context, ServerWebExchange exchange) { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelInitializer.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelInitializer.java index 9d7e3ef653..fd3de94883 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelInitializer.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelInitializer.java @@ -19,9 +19,11 @@ package org.springframework.web.reactive.result.method.annotation; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +import org.jetbrains.annotations.NotNull; import reactor.core.publisher.Mono; import org.springframework.core.Conventions; @@ -31,8 +33,10 @@ import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.ResolvableType; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.method.HandlerMethod; import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.result.method.InvocableHandlerMethod; @@ -47,40 +51,72 @@ import org.springframework.web.server.ServerWebExchange; */ class ModelInitializer { + private final ControllerMethodResolver methodResolver; + private final ReactiveAdapterRegistry adapterRegistry; - public ModelInitializer(ReactiveAdapterRegistry adapterRegistry) { + public ModelInitializer(ControllerMethodResolver methodResolver, ReactiveAdapterRegistry adapterRegistry) { + Assert.notNull(methodResolver, "ControllerMethodResolver is required"); + Assert.notNull(adapterRegistry, "ReactiveAdapterRegistry is required"); + this.methodResolver = methodResolver; this.adapterRegistry = adapterRegistry; } /** - * Initialize the default model in the given {@code BindingContext} through - * the {@code @ModelAttribute} methods and indicate when complete. - *

This will wait for {@code @ModelAttribute} methods that return - * {@code Mono} since those may be adding attributes asynchronously. - * However if methods return async attributes, those will be added to the - * model as-is and without waiting for them to be resolved. - * @param bindingContext the BindingContext with the default model - * @param attributeMethods the {@code @ModelAttribute} methods + * Initialize the {@link org.springframework.ui.Model Model} based on a + * (type-level) {@code @SessionAttributes} annotation and + * {@code @ModelAttribute} methods. + * @param handlerMethod the target controller method + * @param bindingContext the context containing the model * @param exchange the current exchange * @return a {@code Mono} for when the model is populated. */ @SuppressWarnings("Convert2MethodRef") - public Mono initModel(BindingContext bindingContext, - List attributeMethods, ServerWebExchange exchange) { + public Mono initModel(HandlerMethod handlerMethod, InitBinderBindingContext bindingContext, + ServerWebExchange exchange) { + + List modelMethods = + this.methodResolver.getModelAttributeMethods(handlerMethod); + + SessionAttributesHandler sessionAttributesHandler = + this.methodResolver.getSessionAttributesHandler(handlerMethod); + + if (!sessionAttributesHandler.hasSessionAttributes()) { + return invokeModelAttributeMethods(bindingContext, modelMethods, exchange); + } + + return exchange.getSession() + .flatMap(session -> { + Map attributes = sessionAttributesHandler.retrieveAttributes(session); + bindingContext.getModel().mergeAttributes(attributes); + bindingContext.setSessionContext(sessionAttributesHandler, session); + return invokeModelAttributeMethods(bindingContext, modelMethods, exchange) + .doOnSuccess(aVoid -> { + findModelAttributes(handlerMethod, sessionAttributesHandler).forEach(name -> { + if (!bindingContext.getModel().containsAttribute(name)) { + Object value = session.getRequiredAttribute(name); + bindingContext.getModel().addAttribute(name, value); + } + }); + }); + }); + } + + @NotNull + private Mono invokeModelAttributeMethods(BindingContext bindingContext, + List modelMethods, ServerWebExchange exchange) { List> resultList = new ArrayList<>(); - attributeMethods.forEach(invocable -> resultList.add(invocable.invoke(exchange, bindingContext))); + modelMethods.forEach(invocable -> resultList.add(invocable.invoke(exchange, bindingContext))); return Mono - .zip(resultList, objectArray -> { - return Arrays.stream(objectArray) - .map(object -> handleResult(((HandlerResult) object), bindingContext)) - .collect(Collectors.toList()); - }) - .flatMap(completionList -> Mono.when(completionList)); + .zip(resultList, objectArray -> + Arrays.stream(objectArray) + .map(object -> handleResult(((HandlerResult) object), bindingContext)) + .collect(Collectors.toList())) + .flatMap(Mono::when); } private Mono handleResult(HandlerResult handlerResult, BindingContext bindingContext) { @@ -109,4 +145,35 @@ class ModelInitializer { .orElse(Conventions.getVariableNameForParameter(param)); } + /** Find {@code @ModelAttribute} arguments also listed as {@code @SessionAttributes}. */ + private List findModelAttributes(HandlerMethod handlerMethod, + SessionAttributesHandler sessionAttributesHandler) { + + List result = new ArrayList<>(); + for (MethodParameter parameter : handlerMethod.getMethodParameters()) { + if (parameter.hasParameterAnnotation(ModelAttribute.class)) { + String name = getNameForParameter(parameter); + Class paramType = parameter.getParameterType(); + if (sessionAttributesHandler.isHandlerSessionAttribute(name, paramType)) { + result.add(name); + } + } + } + return result; + } + + /** + * Derive the model attribute name for the given method parameter based on + * a {@code @ModelAttribute} parameter annotation (if present) or falling + * back on parameter type based conventions. + * @param parameter a descriptor for the method parameter + * @return the derived name + * @see Conventions#getVariableNameForParameter(MethodParameter) + */ + public static String getNameForParameter(MethodParameter parameter) { + ModelAttribute ann = parameter.getParameterAnnotation(ModelAttribute.class); + String name = (ann != null ? ann.value() : null); + return (StringUtils.hasText(name) ? name : Conventions.getVariableNameForParameter(parameter)); + } + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java index 0279ebad89..c6136e0717 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java @@ -169,7 +169,7 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, Application this.methodResolver = new ControllerMethodResolver(this.argumentResolverConfigurer, this.messageReaders, this.reactiveAdapterRegistry, this.applicationContext); - this.modelInitializer = new ModelInitializer(this.reactiveAdapterRegistry); + this.modelInitializer = new ModelInitializer(this.methodResolver, this.reactiveAdapterRegistry); } @@ -183,21 +183,20 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, Application HandlerMethod handlerMethod = (HandlerMethod) handler; Assert.state(this.methodResolver != null && this.modelInitializer != null, "Not initialized"); - BindingContext bindingContext = new InitBinderBindingContext( + InitBinderBindingContext bindingContext = new InitBinderBindingContext( getWebBindingInitializer(), this.methodResolver.getInitBinderMethods(handlerMethod)); - List modelAttributeMethods = - this.methodResolver.getModelAttributeMethods(handlerMethod); + InvocableHandlerMethod invocableMethod = this.methodResolver.getRequestMappingMethod(handlerMethod); Function> exceptionHandler = ex -> handleException(ex, handlerMethod, bindingContext, exchange); return this.modelInitializer - .initModel(bindingContext, modelAttributeMethods, exchange) - .then(Mono.defer(() -> this.methodResolver.getRequestMappingMethod(handlerMethod) - .invoke(exchange, bindingContext) - .doOnNext(result -> result.setExceptionHandler(exceptionHandler)) - .onErrorResume(exceptionHandler))); + .initModel(handlerMethod, bindingContext, exchange) + .then(Mono.defer(() -> invocableMethod.invoke(exchange, bindingContext))) + .doOnNext(result -> result.setExceptionHandler(exceptionHandler)) + .doOnNext(result -> bindingContext.saveModel()) + .onErrorResume(exceptionHandler); } private Mono handleException(Throwable exception, HandlerMethod handlerMethod, diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/SessionAttributesHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/SessionAttributesHandler.java new file mode 100644 index 0000000000..a03c35e700 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/SessionAttributesHandler.java @@ -0,0 +1,136 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.reactive.result.method.annotation; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.util.Assert; +import org.springframework.web.bind.annotation.SessionAttributes; +import org.springframework.web.server.WebSession; + +/** + * Package-private class to assist {@link ModelInitializer} with managing model + * attributes in the {@link WebSession} based on model attribute names and types + * declared va {@link SessionAttributes @SessionAttributes}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +class SessionAttributesHandler { + + private final Set attributeNames = new HashSet<>(); + + private final Set> attributeTypes = new HashSet<>(); + + private final Set knownAttributeNames = Collections.newSetFromMap(new ConcurrentHashMap<>(4)); + + + /** + * Create a new instance for a controller type. Session attribute names and + * types are extracted from the {@code @SessionAttributes} annotation, if + * present, on the given type. + * @param handlerType the controller type + */ + public SessionAttributesHandler(Class handlerType) { + SessionAttributes annotation = + AnnotatedElementUtils.findMergedAnnotation(handlerType, SessionAttributes.class); + if (annotation != null) { + this.attributeNames.addAll(Arrays.asList(annotation.names())); + this.attributeTypes.addAll(Arrays.asList(annotation.types())); + } + this.knownAttributeNames.addAll(this.attributeNames); + } + + + /** + * Whether the controller represented by this instance has declared any + * session attributes through an {@link SessionAttributes} annotation. + */ + public boolean hasSessionAttributes() { + return (!this.attributeNames.isEmpty() || !this.attributeTypes.isEmpty()); + } + + /** + * Whether the attribute name or type match the names and types specified + * via {@code @SessionAttributes} on the underlying controller. + *

Attributes successfully resolved through this method are "remembered" + * and subsequently used in {@link #retrieveAttributes(WebSession)} + * and also {@link #cleanupAttributes(WebSession)}. + * @param attributeName the attribute name to check + * @param attributeType the type for the attribute + */ + public boolean isHandlerSessionAttribute(String attributeName, Class attributeType) { + Assert.notNull(attributeName, "Attribute name must not be null"); + if (this.attributeNames.contains(attributeName) || this.attributeTypes.contains(attributeType)) { + this.knownAttributeNames.add(attributeName); + return true; + } + else { + return false; + } + } + + /** + * Retrieve "known" attributes from the session, i.e. attributes listed + * by name in {@code @SessionAttributes} or attributes previously stored + * in the model that matched by type. + * @param session the current session + * @return a map with handler session attributes, possibly empty + */ + public Map retrieveAttributes(WebSession session) { + Map attributes = new HashMap<>(); + this.knownAttributeNames.forEach(name -> { + Object value = session.getAttribute(name); + if (value != null) { + attributes.put(name, value); + } + }); + return attributes; + } + + /** + * Store a subset of the given attributes in the session. Attributes not + * declared as session attributes via {@code @SessionAttributes} are ignored. + * @param session the current session + * @param attributes candidate attributes for session storage + */ + public void storeAttributes(WebSession session, Map attributes) { + attributes.keySet().forEach(name -> { + Object value = attributes.get(name); + if (value != null && isHandlerSessionAttribute(name, value.getClass())) { + session.getAttributes().put(name, value); + } + }); + } + + /** + * Remove "known" attributes from the session, i.e. attributes listed + * by name in {@code @SessionAttributes} or attributes previously stored + * in the model that matched by type. + * @param session the current session + */ + public void cleanupAttributes(WebSession session) { + this.knownAttributeNames.forEach(name -> session.getAttributes().remove(name)); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/SessionStatusMethodArgumentResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/SessionStatusMethodArgumentResolver.java new file mode 100644 index 0000000000..c6c847164d --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/SessionStatusMethodArgumentResolver.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.reactive.result.method.annotation; + +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.bind.support.SessionStatus; +import org.springframework.web.reactive.BindingContext; +import org.springframework.web.reactive.result.method.SyncHandlerMethodArgumentResolver; +import org.springframework.web.server.ServerWebExchange; + +/** + * Resolver for a {@link SessionStatus} argument obtaining it from the + * {@link BindingContext}. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class SessionStatusMethodArgumentResolver implements SyncHandlerMethodArgumentResolver { + + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return SessionStatus.class == parameter.getParameterType(); + } + + @Nullable + @Override + public Object resolveArgumentValue(MethodParameter parameter, BindingContext bindingContext, + ServerWebExchange exchange) { + + Assert.isInstanceOf(InitBinderBindingContext.class, bindingContext); + return ((InitBinderBindingContext) bindingContext).getSessionStatus(); + } + +} diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolverTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolverTests.java index c07bcc43f3..87c448b7d0 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolverTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ControllerMethodResolverTests.java @@ -46,7 +46,8 @@ import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ServerWebExchange; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; /** * Unit tests for {@link ControllerMethodResolver}. @@ -108,6 +109,7 @@ public class ControllerMethodResolverTests { assertEquals(ErrorsMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(ServerWebExchangeArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(PrincipalArgumentResolver.class, next(resolvers, index).getClass()); + assertEquals(SessionStatusMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(WebSessionArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(CustomArgumentResolver.class, next(resolvers, index).getClass()); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ModelInitializerTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ModelInitializerTests.java index 21b0027fbb..528ee9300e 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ModelInitializerTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ModelInitializerTests.java @@ -23,30 +23,40 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import org.jetbrains.annotations.NotNull; +import org.junit.Before; import org.junit.Test; import reactor.core.publisher.Mono; import rx.Single; +import org.springframework.context.support.StaticApplicationContext; import org.springframework.core.MethodIntrospector; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.lang.Nullable; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.ui.Model; import org.springframework.util.ReflectionUtils; import org.springframework.validation.Validator; import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.ModelAttribute; +import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.SessionAttributes; import org.springframework.web.bind.support.ConfigurableWebBindingInitializer; import org.springframework.web.bind.support.WebBindingInitializer; import org.springframework.web.bind.support.WebExchangeDataBinder; -import org.springframework.web.reactive.BindingContext; -import org.springframework.web.reactive.result.method.InvocableHandlerMethod; +import org.springframework.web.method.HandlerMethod; +import org.springframework.web.method.ResolvableMethod; import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod; import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; /** @@ -55,31 +65,55 @@ import static org.mockito.Mockito.mock; */ public class ModelInitializerTests { - private final ModelInitializer modelInitializer = new ModelInitializer(new ReactiveAdapterRegistry()); + private ModelInitializer modelInitializer; private final ServerWebExchange exchange = MockServerHttpRequest.get("/path").toExchange(); + @Before + public void setUp() throws Exception { + + ReactiveAdapterRegistry adapterRegistry = new ReactiveAdapterRegistry(); + + ArgumentResolverConfigurer resolverConfigurer = new ArgumentResolverConfigurer(); + resolverConfigurer.addCustomResolver(new ModelArgumentResolver(adapterRegistry)); + + ControllerMethodResolver methodResolver = new ControllerMethodResolver( + resolverConfigurer, Collections.emptyList(), adapterRegistry, new StaticApplicationContext()); + + this.modelInitializer = new ModelInitializer(methodResolver, adapterRegistry); + } + + @SuppressWarnings("unchecked") @Test - public void basic() throws Exception { - TestController controller = new TestController(); + public void initBinderMethod() throws Exception { Validator validator = mock(Validator.class); + + TestController controller = new TestController(); controller.setValidator(validator); + InitBinderBindingContext context = getBindingContext(controller); - List binderMethods = getBinderMethods(controller); - List attributeMethods = getAttributeMethods(controller); + Method method = ResolvableMethod.on(TestController.class).annotPresent(GetMapping.class).resolveMethod(); + HandlerMethod handlerMethod = new HandlerMethod(controller, method); + this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000)); - WebBindingInitializer bindingInitializer = new ConfigurableWebBindingInitializer(); - BindingContext bindingContext = new InitBinderBindingContext(bindingInitializer, binderMethods); + WebExchangeDataBinder binder = context.createDataBinder(this.exchange, "name"); + assertEquals(Collections.singletonList(validator), binder.getValidators()); + } - this.modelInitializer.initModel(bindingContext, attributeMethods, this.exchange).block(Duration.ofMillis(5000)); + @SuppressWarnings("unchecked") + @Test + public void modelAttributeMethods() throws Exception { + TestController controller = new TestController(); + InitBinderBindingContext context = getBindingContext(controller); - WebExchangeDataBinder binder = bindingContext.createDataBinder(this.exchange, "name"); - assertEquals(Collections.singletonList(validator), binder.getValidators()); + Method method = ResolvableMethod.on(TestController.class).annotPresent(GetMapping.class).resolveMethod(); + HandlerMethod handlerMethod = new HandlerMethod(controller, method); + this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000)); - Map model = bindingContext.getModel().asMap(); + Map model = context.getModel().asMap(); assertEquals(5, model.size()); Object value = model.get("bean"); @@ -98,31 +132,101 @@ public class ModelInitializerTests { assertEquals("Void Mono Method Bean", ((TestBean) value).getName()); } - private List getBinderMethods(Object controller) { - return MethodIntrospector - .selectMethods(controller.getClass(), BINDER_METHODS).stream() - .map(method -> new SyncInvocableHandlerMethod(controller, method)) - .collect(Collectors.toList()); + @Test + public void saveModelAttributeToSession() throws Exception { + TestController controller = new TestController(); + InitBinderBindingContext context = getBindingContext(controller); + + Method method = ResolvableMethod.on(TestController.class).annotPresent(GetMapping.class).resolveMethod(); + HandlerMethod handlerMethod = new HandlerMethod(controller, method); + this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000)); + + WebSession session = this.exchange.getSession().block(Duration.ZERO); + assertNotNull(session); + assertEquals(0, session.getAttributes().size()); + + context.saveModel(); + assertEquals(1, session.getAttributes().size()); + assertEquals("Bean", ((TestBean) session.getRequiredAttribute("bean")).getName()); + } + + @Test + public void retrieveModelAttributeFromSession() throws Exception { + WebSession session = this.exchange.getSession().block(Duration.ZERO); + assertNotNull(session); + + TestBean testBean = new TestBean("Session Bean"); + session.getAttributes().put("bean", testBean); + + TestController controller = new TestController(); + InitBinderBindingContext context = getBindingContext(controller); + + Method method = ResolvableMethod.on(TestController.class).annotPresent(GetMapping.class).resolveMethod(); + HandlerMethod handlerMethod = new HandlerMethod(controller, method); + this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000)); + + context.saveModel(); + assertEquals(1, session.getAttributes().size()); + assertEquals("Session Bean", ((TestBean) session.getRequiredAttribute("bean")).getName()); + } + + @Test + public void requiredSessionAttributeMissing() throws Exception { + TestController controller = new TestController(); + InitBinderBindingContext context = getBindingContext(controller); + + Method method = ResolvableMethod.on(TestController.class).annotPresent(PostMapping.class).resolveMethod(); + HandlerMethod handlerMethod = new HandlerMethod(controller, method); + try { + this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000)); + fail(); + } + catch (IllegalArgumentException ex) { + assertEquals("Required attribute 'missing-bean' is missing.", ex.getMessage()); + } } - private List getAttributeMethods(Object controller) { - return MethodIntrospector - .selectMethods(controller.getClass(), ATTRIBUTE_METHODS).stream() - .map(method -> toInvocable(controller, method)) - .collect(Collectors.toList()); + @Test + public void clearModelAttributeFromSession() throws Exception { + WebSession session = this.exchange.getSession().block(Duration.ZERO); + assertNotNull(session); + + TestBean testBean = new TestBean("Session Bean"); + session.getAttributes().put("bean", testBean); + + TestController controller = new TestController(); + InitBinderBindingContext context = getBindingContext(controller); + + Method method = ResolvableMethod.on(TestController.class).annotPresent(GetMapping.class).resolveMethod(); + HandlerMethod handlerMethod = new HandlerMethod(controller, method); + this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000)); + + context.getSessionStatus().setComplete(); + context.saveModel(); + + assertEquals(0, session.getAttributes().size()); } - private InvocableHandlerMethod toInvocable(Object controller, Method method) { - ModelArgumentResolver resolver = new ModelArgumentResolver(new ReactiveAdapterRegistry()); - InvocableHandlerMethod handlerMethod = new InvocableHandlerMethod(controller, method); - handlerMethod.setArgumentResolvers(Collections.singletonList(resolver)); - return handlerMethod; + + @NotNull + private InitBinderBindingContext getBindingContext(Object controller) { + + List binderMethods = + MethodIntrospector.selectMethods(controller.getClass(), BINDER_METHODS) + .stream() + .map(method -> new SyncInvocableHandlerMethod(controller, method)) + .collect(Collectors.toList());; + + WebBindingInitializer bindingInitializer = new ConfigurableWebBindingInitializer(); + return new InitBinderBindingContext(bindingInitializer, binderMethods); } @SuppressWarnings("unused") + @SessionAttributes({"bean", "missing-bean"}) private static class TestController { + @Nullable private Validator validator; @@ -165,8 +269,12 @@ public class ModelInitializerTests { .then(); } - @RequestMapping - public void handle() {} + @GetMapping + public void handleGet() {} + + @PostMapping + public void handlePost(@ModelAttribute("missing-bean") TestBean testBean) {} + } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/SessionAttributesHandlerTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/SessionAttributesHandlerTests.java new file mode 100644 index 0000000000..761ab62c3c --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/SessionAttributesHandlerTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.result.method.annotation; + + +import java.time.Duration; +import java.util.HashSet; + +import org.junit.Test; + +import org.springframework.tests.sample.beans.TestBean; +import org.springframework.ui.ModelMap; +import org.springframework.web.bind.annotation.SessionAttributes; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.session.InMemoryWebSessionStore; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** + * Test fixture with {@link SessionAttributesHandler}. + * @author Rossen Stoyanchev + */ +public class SessionAttributesHandlerTests { + + private final SessionAttributesHandler sessionAttributesHandler = + new SessionAttributesHandler(TestController.class); + + + @Test + public void isSessionAttribute() throws Exception { + assertTrue(this.sessionAttributesHandler.isHandlerSessionAttribute("attr1", String.class)); + assertTrue(this.sessionAttributesHandler.isHandlerSessionAttribute("attr2", String.class)); + assertTrue(this.sessionAttributesHandler.isHandlerSessionAttribute("simple", TestBean.class)); + assertFalse(this.sessionAttributesHandler.isHandlerSessionAttribute("simple", String.class)); + } + + @Test + public void retrieveAttributes() throws Exception { + WebSession session = new InMemoryWebSessionStore().createWebSession().block(Duration.ZERO); + assertNotNull(session); + + session.getAttributes().put("attr1", "value1"); + session.getAttributes().put("attr2", "value2"); + session.getAttributes().put("attr3", new TestBean()); + session.getAttributes().put("attr4", new TestBean()); + + assertEquals("Named attributes (attr1, attr2) should be 'known' right away", + new HashSet<>(asList("attr1", "attr2")), + sessionAttributesHandler.retrieveAttributes(session).keySet()); + + // Resolve 'attr3' by type + sessionAttributesHandler.isHandlerSessionAttribute("attr3", TestBean.class); + + assertEquals("Named attributes (attr1, attr2) and resolved attribute (att3) should be 'known'", + new HashSet<>(asList("attr1", "attr2", "attr3")), + sessionAttributesHandler.retrieveAttributes(session).keySet()); + } + + @Test + public void cleanupAttributes() throws Exception { + WebSession session = new InMemoryWebSessionStore().createWebSession().block(Duration.ZERO); + assertNotNull(session); + + session.getAttributes().put("attr1", "value1"); + session.getAttributes().put("attr2", "value2"); + session.getAttributes().put("attr3", new TestBean()); + + this.sessionAttributesHandler.cleanupAttributes(session); + + assertNull(session.getAttributes().get("attr1")); + assertNull(session.getAttributes().get("attr2")); + assertNotNull(session.getAttributes().get("attr3")); + + // Resolve 'attr3' by type + this.sessionAttributesHandler.isHandlerSessionAttribute("attr3", TestBean.class); + this.sessionAttributesHandler.cleanupAttributes(session); + + assertNull(session.getAttributes().get("attr3")); + } + + @Test + public void storeAttributes() throws Exception { + WebSession session = new InMemoryWebSessionStore().createWebSession().block(Duration.ZERO); + assertNotNull(session); + + ModelMap model = new ModelMap(); + model.put("attr1", "value1"); + model.put("attr2", "value2"); + model.put("attr3", new TestBean()); + + sessionAttributesHandler.storeAttributes(session, model); + + assertEquals("value1", session.getAttributes().get("attr1")); + assertEquals("value2", session.getAttributes().get("attr2")); + assertTrue(session.getAttributes().get("attr3") instanceof TestBean); + } + + + @SessionAttributes(names = { "attr1", "attr2" }, types = { TestBean.class }) + private static class TestController { + } + +} \ No newline at end of file