From 54e6103defb99c067f3f4d5a6054aa24a29a379e Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Mon, 20 Apr 2020 16:54:56 +0200 Subject: [PATCH] Add ServerRequest::multipartData in WebMvc.fn This commit adds the multipartData method to ServerRequest in WebMvc.fn, returning a MultiValueMap. Closes gh-24909 --- .../function/DefaultServerRequest.java | 17 +++++++++++++++ .../function/DefaultServerRequestBuilder.java | 12 ++++++++++- .../servlet/function/RequestPredicates.java | 11 ++++++++++ .../web/servlet/function/ServerRequest.java | 12 +++++++++++ .../function/DefaultServerRequestTests.java | 21 +++++++++++++++++++ 5 files changed, 72 insertions(+), 1 deletion(-) diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java index fa930797f2..7f6e815ccc 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java @@ -44,6 +44,7 @@ import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; +import javax.servlet.http.Part; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; @@ -85,6 +86,9 @@ class DefaultServerRequest implements ServerRequest { private final Map attributes; + @Nullable + private MultiValueMap parts; + public DefaultServerRequest(HttpServletRequest servletRequest, List> messageConverters) { this.serverHttpRequest = new ServletServerHttpRequest(servletRequest); @@ -228,6 +232,19 @@ class DefaultServerRequest implements ServerRequest { return this.params; } + @Override + public MultiValueMap multipartData() throws IOException, ServletException { + MultiValueMap result = this.parts; + if (result == null) { + result = servletRequest().getParts().stream() + .collect(Collectors.groupingBy(Part::getName, + LinkedMultiValueMap::new, + Collectors.toList())); + this.parts = result; + } + return result; + } + @Override @SuppressWarnings("unchecked") public Map pathVariables() { diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java index cc8c44dca7..1c86ab9eca 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.Consumer; +import java.util.stream.Collectors; import javax.servlet.ReadListener; import javax.servlet.ServletException; @@ -37,6 +38,7 @@ import javax.servlet.ServletInputStream; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; +import javax.servlet.http.Part; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; @@ -199,6 +201,14 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder { return this.methodName; } + @Override + public MultiValueMap multipartData() throws IOException, ServletException { + return servletRequest().getParts().stream() + .collect(Collectors.groupingBy(Part::getName, + LinkedMultiValueMap::new, + Collectors.toList())); + } + @Override public URI uri() { return this.uri; diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java index efc3fd9789..307a603abb 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java @@ -40,6 +40,7 @@ import javax.servlet.ServletException; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; +import javax.servlet.http.Part; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -994,6 +995,16 @@ public abstract class RequestPredicates { return this.request.params(); } + @Override + public MultiValueMap multipartData() throws IOException, ServletException { + return this.request.multipartData(); + } + + @Override + public String pathVariable(String name) { + return this.request.pathVariable(name); + } + @Override @SuppressWarnings("unchecked") public Map pathVariables() { diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerRequest.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerRequest.java index 1fee8d25cb..a76a1cab42 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerRequest.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerRequest.java @@ -33,6 +33,7 @@ import javax.servlet.ServletException; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; +import javax.servlet.http.Part; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.buffer.DataBuffer; @@ -186,6 +187,17 @@ public interface ServerRequest { */ MultiValueMap params(); + /** + * Get the parts of a multipart request, provided the Content-Type is + * {@code "multipart/form-data"}, or an exception otherwise. + * @return the multipart data, mapping from name to part(s) + * @throws IOException if an I/O error occurred during the retrieval + * @throws ServletException if this request is not of type {@code "multipart/form-data"} + * @since 5.3 + * @see HttpServletRequest#getParts() + */ + MultiValueMap multipartData() throws IOException, ServletException; + /** * Get the path variable with the given name, if present. * @param name the variable name diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java index 2741c5a2e3..04ace77cff 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java @@ -33,6 +33,7 @@ import java.util.Optional; import java.util.OptionalLong; import javax.servlet.http.Cookie; +import javax.servlet.http.Part; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -53,6 +54,7 @@ import org.springframework.web.HttpMediaTypeNotSupportedException; import org.springframework.web.testfixture.server.MockServerWebExchange; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpSession; +import org.springframework.web.testfixture.servlet.MockPart; import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; @@ -129,6 +131,25 @@ public class DefaultServerRequestTests { assertThat(request.param("foo")).isEqualTo(Optional.of("bar")); } + @Test + public void multipartData() throws Exception { + MockPart formPart = new MockPart("form", "foo".getBytes(UTF_8)); + MockPart filePart = new MockPart("file", "foo.txt", "foo".getBytes(UTF_8)); + + MockHttpServletRequest servletRequest = new MockHttpServletRequest("POST", "/"); + servletRequest.addPart(formPart); + servletRequest.addPart(filePart); + + DefaultServerRequest request = + new DefaultServerRequest(servletRequest, this.messageConverters); + + MultiValueMap result = request.multipartData(); + + assertThat(result).hasSize(2); + assertThat(result.get("form")).containsExactly(formPart); + assertThat(result.get("file")).containsExactly(filePart); + } + @Test public void emptyQueryParam() { MockHttpServletRequest servletRequest = new MockHttpServletRequest("GET", "/");