From 5cee607f2872f59859fd553221ae0e79d5f58c92 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 11 May 2018 18:38:29 -0400 Subject: [PATCH] WebFlux @RequestPart support for List and Flux arguments The resolver now supports List, Flux, and List. Issue: SPR-16621 --- .../RequestPartMethodArgumentResolver.java | 57 ++-- ...equestPartMethodArgumentResolverTests.java | 257 ++++++++++++++++++ 2 files changed, 298 insertions(+), 16 deletions(-) create mode 100644 spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolverTests.java diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolver.java index dbba8e07a5..640cfbbb79 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolver.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolver.java @@ -74,35 +74,49 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageReaderArgu boolean isRequired = (requestPart == null || requestPart.required()); String name = getPartName(parameter, requestPart); - Flux values = exchange.getMultipartData() + Flux parts = exchange.getMultipartData() .flatMapMany(map -> { - List parts = map.get(name); - if (CollectionUtils.isEmpty(parts)) { + List list = map.get(name); + if (CollectionUtils.isEmpty(list)) { return isRequired ? Flux.error(getMissingPartException(name, parameter)) : Flux.empty(); } - return Flux.fromIterable(parts); + return Flux.fromIterable(list); }); - ReactiveAdapter adapter = getAdapterRegistry().getAdapter(parameter.getParameterType()); - MethodParameter elementType = adapter != null ? parameter.nested() : parameter; + if (Part.class.isAssignableFrom(parameter.getParameterType())) { + return parts.next().cast(Object.class); + } - if (Part.class.isAssignableFrom(elementType.getNestedParameterType())) { - if (adapter != null) { - values = adapter.isMultiValue() ? values : values.take(1); - return Mono.just(adapter.fromPublisher(values)); + if (List.class.isAssignableFrom(parameter.getParameterType())) { + MethodParameter elementType = parameter.nested(); + if (Part.class.isAssignableFrom(elementType.getNestedParameterType())) { + return parts.collectList().cast(Object.class); } else { - return values.next().cast(Object.class); + return decodePartValues(parts, elementType, bindingContext, exchange, isRequired) + .collectList().cast(Object.class); } } - return values.next().flatMap(part -> { - ServerHttpRequest partRequest = new PartServerHttpRequest(exchange.getRequest(), part); - ServerWebExchange partExchange = exchange.mutate().request(partRequest).build(); - return readBody(parameter, isRequired, bindingContext, partExchange); - }); + ReactiveAdapter adapter = getAdapterRegistry().getAdapter(parameter.getParameterType()); + if (adapter != null) { + // Mono or Flux + MethodParameter elementType = parameter.nested(); + if (Part.class.isAssignableFrom(elementType.getNestedParameterType())) { + parts = adapter.isMultiValue() ? parts : parts.take(1); + return Mono.just(adapter.fromPublisher(parts)); + } + // We have to decode the content for each part, one at a time + if (adapter.isMultiValue()) { + return Mono.just(decodePartValues(parts, elementType, bindingContext, exchange, isRequired)); + } + } + + // or Mono + return decodePartValues(parts, parameter, bindingContext, exchange, isRequired) + .next().cast(Object.class); } private String getPartName(MethodParameter methodParam, @Nullable RequestPart requestPart) { @@ -124,6 +138,17 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageReaderArgu } + private Flux decodePartValues(Flux parts, MethodParameter elementType, BindingContext bindingContext, + ServerWebExchange exchange, boolean isRequired) { + + return parts.flatMap(part -> { + ServerHttpRequest partRequest = new PartServerHttpRequest(exchange.getRequest(), part); + ServerWebExchange partExchange = exchange.mutate().request(partRequest).build(); + return readBody(elementType, isRequired, bindingContext, partExchange); + }); + } + + private static class PartServerHttpRequest extends ServerHttpRequestDecorator { private final Part part; diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolverTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolverTests.java new file mode 100644 index 0000000000..54a6c85f1a --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolverTests.java @@ -0,0 +1,257 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.result.method.annotation; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collections; +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.MethodParameter; +import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.support.DataBufferTestUtils; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.client.MultipartBodyBuilder; +import org.springframework.http.codec.ClientCodecConfigurer; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.http.codec.multipart.Part; +import org.springframework.mock.http.client.reactive.test.MockClientHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; +import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.RequestPart; +import org.springframework.web.method.ResolvableMethod; +import org.springframework.web.reactive.BindingContext; +import org.springframework.web.server.ServerWebExchange; + +import static org.junit.Assert.*; +import static org.springframework.core.ResolvableType.*; +import static org.springframework.web.method.MvcAnnotationPredicates.*; + +/** + * Unit tests for {@link RequestPartMethodArgumentResolver}. + * @author Rossen Stoyanchev + */ +public class RequestPartMethodArgumentResolverTests { + + private RequestPartMethodArgumentResolver resolver; + + private ResolvableMethod testMethod = ResolvableMethod.on(getClass()).named("handle").build(); + + private MultipartHttpMessageWriter writer; + + + @Before + public void setup() throws Exception { + List> readers = ServerCodecConfigurer.create().getReaders(); + ReactiveAdapterRegistry registry = ReactiveAdapterRegistry.getSharedInstance(); + this.resolver = new RequestPartMethodArgumentResolver(readers, registry); + + List> writers = ClientCodecConfigurer.create().getWriters(); + this.writer = new MultipartHttpMessageWriter(writers); + } + + + @Test + public void supportsParameter() { + + MethodParameter param; + + param = this.testMethod.annot(requestPart().name("name")).arg(Person.class); + assertTrue(this.resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestPart().name("name")).arg(Mono.class, Person.class); + assertTrue(this.resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestPart().name("name")).arg(Flux.class, Person.class); + assertTrue(this.resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestPart().name("name")).arg(Part.class); + assertTrue(this.resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestPart().name("name")).arg(Mono.class, Part.class); + assertTrue(this.resolver.supportsParameter(param)); + + param = this.testMethod.annot(requestPart().name("name")).arg(Flux.class, Part.class); + assertTrue(this.resolver.supportsParameter(param)); + } + + + @Test + public void person() { + MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Person.class); + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("name", new Person("Jones")); + Person actual = resolveArgument(param, bodyBuilder); + + assertEquals("Jones", actual.getName()); + } + + @Test + public void listPerson() { + MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(List.class, Person.class); + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("name", new Person("Jones")); + bodyBuilder.part("name", new Person("James")); + List actual = resolveArgument(param, bodyBuilder); + + assertEquals("Jones", actual.get(0).getName()); + assertEquals("James", actual.get(1).getName()); + } + + @Test + public void monoPerson() { + MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Mono.class, Person.class); + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("name", new Person("Jones")); + Mono actual = resolveArgument(param, bodyBuilder); + + assertEquals("Jones", actual.block().getName()); + } + + @Test + public void fluxPerson() { + MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Flux.class, Person.class); + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("name", new Person("Jones")); + bodyBuilder.part("name", new Person("James")); + Flux actual = resolveArgument(param, bodyBuilder); + + List persons = actual.collectList().block(); + assertEquals("Jones", persons.get(0).getName()); + assertEquals("James", persons.get(1).getName()); + } + + @Test + public void part() { + MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Part.class); + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("name", new Person("Jones")); + Part actual = resolveArgument(param, bodyBuilder); + + DataBuffer buffer = DataBufferUtils.join(actual.content()).block(); + assertEquals("{\"name\":\"Jones\"}", DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8)); + } + + @Test + public void listPart() { + MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(List.class, Part.class); + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("name", new Person("Jones")); + bodyBuilder.part("name", new Person("James")); + List actual = resolveArgument(param, bodyBuilder); + + assertEquals("{\"name\":\"Jones\"}", partToUtf8String(actual.get(0))); + assertEquals("{\"name\":\"James\"}", partToUtf8String(actual.get(1))); + } + + @Test + public void monoPart() { + MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Mono.class, Part.class); + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("name", new Person("Jones")); + Mono actual = resolveArgument(param, bodyBuilder); + + Part part = actual.block(); + assertEquals("{\"name\":\"Jones\"}", partToUtf8String(part)); + } + + @Test + public void fluxPart() { + MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Flux.class, Part.class); + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("name", new Person("Jones")); + bodyBuilder.part("name", new Person("James")); + Flux actual = resolveArgument(param, bodyBuilder); + + List parts = actual.collectList().block(); + assertEquals("{\"name\":\"Jones\"}", partToUtf8String(parts.get(0))); + assertEquals("{\"name\":\"James\"}", partToUtf8String(parts.get(1))); + } + + @SuppressWarnings("unchecked") + private T resolveArgument(MethodParameter param, MultipartBodyBuilder builder) { + ServerWebExchange exchange = createExchange(builder); + Mono result = this.resolver.resolveArgument(param, new BindingContext(), exchange); + Object value = result.block(Duration.ofSeconds(5)); + + assertNotNull(value); + assertTrue(param.getParameterType().isAssignableFrom(value.getClass())); + return (T) value; + } + + @SuppressWarnings("ConstantConditions") + private ServerWebExchange createExchange(MultipartBodyBuilder builder) { + + MockClientHttpRequest clientRequest = new MockClientHttpRequest(HttpMethod.POST, "/"); + this.writer.write(Mono.just(builder.build()), forClass(MultiValueMap.class), + MediaType.MULTIPART_FORM_DATA, clientRequest, Collections.emptyMap()).block(); + + MockServerHttpRequest serverRequest = MockServerHttpRequest.post("/") + .contentType(clientRequest.getHeaders().getContentType()) + .body(clientRequest.getBody()); + + return MockServerWebExchange.from(serverRequest); + } + + private String partToUtf8String(Part part) { + DataBuffer buffer = DataBufferUtils.join(part.content()).block(); + return DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8); + } + + + @SuppressWarnings("unused") + void handle( + @RequestPart("name") Person person, + @RequestPart("name") Mono personMono, + @RequestPart("name") Flux personFlux, + @RequestPart("name") List personList, + @RequestPart("name") Part part, + @RequestPart("name") Mono partMono, + @RequestPart("name") Flux partFlux, + @RequestPart("name") List partList, + String notAnnotated) {} + + + private static class Person { + + private String name; + + @JsonCreator + public Person(@JsonProperty("name") String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + +}