From c52526ad42a6e29ef8353e5da6fa713bd38dbaba Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 3 Feb 2021 21:55:29 +0000 Subject: [PATCH] Fix in MockMultipartHttpServletRequest#getMultipartHeaders Previously this method returned headers only when a Content-Type part header was present. Now it is guaranteed to return headers (possibly empty) as long as there is a MultipartFile or Part with the given name. Closes gh-26501 --- .../web/MockMultipartHttpServletRequest.java | 27 ++++++++++++++----- .../MultipartHttpServletRequest.java | 9 ++++--- .../MockMultipartHttpServletRequest.java | 27 ++++++++++++++----- ...equestPartMethodArgumentResolverTests.java | 23 ++++++++++++++-- 4 files changed, 68 insertions(+), 18 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java index f9607f5bb2..15153dcea3 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.mock.web; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.Iterator; @@ -33,6 +34,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartHttpServletRequest; @@ -155,15 +157,28 @@ public class MockMultipartHttpServletRequest extends MockHttpServletRequest impl @Override public HttpHeaders getMultipartHeaders(String paramOrFileName) { - String contentType = getMultipartContentType(paramOrFileName); - if (contentType != null) { + MultipartFile file = getFile(paramOrFileName); + if (file != null) { HttpHeaders headers = new HttpHeaders(); - headers.add(HttpHeaders.CONTENT_TYPE, contentType); + if (file.getContentType() != null) { + headers.add(HttpHeaders.CONTENT_TYPE, file.getContentType()); + } return headers; } - else { - return null; + try { + Part part = getPart(paramOrFileName); + if (part != null) { + HttpHeaders headers = new HttpHeaders(); + for (String headerName : part.getHeaderNames()) { + headers.put(headerName, new ArrayList<>(part.getHeaders(headerName))); + } + return headers; + } } + catch (Throwable ex) { + throw new MultipartException("Could not access multipart servlet request", ex); + } + return null; } } diff --git a/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java index c083b42fdc..7788b1e704 100644 --- a/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java +++ b/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2011 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -60,9 +60,10 @@ public interface MultipartHttpServletRequest extends HttpServletRequest, Multipa HttpHeaders getRequestHeaders(); /** - * Return the headers associated with the specified part of the multipart request. - *

If the underlying implementation supports access to headers, then all headers are returned. - * Otherwise, the returned headers will include a 'Content-Type' header at the very least. + * Return the headers for the specified part of the multipart request. + *

If the underlying implementation supports access to part headers, + * then all headers are returned. Otherwise, e.g. for a file upload, the + * returned headers may expose a 'Content-Type' if available. */ @Nullable HttpHeaders getMultipartHeaders(String paramOrFileName); diff --git a/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockMultipartHttpServletRequest.java b/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockMultipartHttpServletRequest.java index ace2e125c5..c36eb3ac7b 100644 --- a/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockMultipartHttpServletRequest.java +++ b/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockMultipartHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.web.testfixture.servlet; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.Iterator; @@ -33,6 +34,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartHttpServletRequest; @@ -155,15 +157,28 @@ public class MockMultipartHttpServletRequest extends MockHttpServletRequest impl @Override public HttpHeaders getMultipartHeaders(String paramOrFileName) { - String contentType = getMultipartContentType(paramOrFileName); - if (contentType != null) { + MultipartFile file = getFile(paramOrFileName); + if (file != null) { HttpHeaders headers = new HttpHeaders(); - headers.add(HttpHeaders.CONTENT_TYPE, contentType); + if (file.getContentType() != null) { + headers.add(HttpHeaders.CONTENT_TYPE, file.getContentType()); + } return headers; } - else { - return null; + try { + Part part = getPart(paramOrFileName); + if (part != null) { + HttpHeaders headers = new HttpHeaders(); + for (String headerName : part.getHeaderNames()) { + headers.put(headerName, new ArrayList<>(part.getHeaders(headerName))); + } + return headers; + } } + catch (Throwable ex) { + throw new MultipartException("Could not access multipart servlet request", ex); + } + return null; } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java index f35a3817f5..9cf34618a6 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,6 +36,7 @@ import org.springframework.core.annotation.SynthesizingMethodParameter; import org.springframework.http.HttpInputMessage; import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.lang.Nullable; import org.springframework.util.ReflectionUtils; import org.springframework.validation.BindingResult; @@ -51,6 +52,7 @@ import org.springframework.web.method.support.ModelAndViewContainer; import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.support.MissingServletRequestPartException; +import org.springframework.web.testfixture.method.ResolvableMethod; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletResponse; import org.springframework.web.testfixture.servlet.MockMultipartFile; @@ -311,6 +313,22 @@ public class RequestPartMethodArgumentResolverTests { testResolveArgument(new SimpleBean("foo"), paramValidRequestPart); } + @Test // gh-26501 + public void resolveRequestPartWithoutContentType() throws Exception { + MockMultipartHttpServletRequest servletRequest = new MockMultipartHttpServletRequest(); + servletRequest.addPart(new MockPart("requestPartString", "part value".getBytes(StandardCharsets.UTF_8))); + ServletWebRequest webRequest = new ServletWebRequest(servletRequest, new MockHttpServletResponse()); + + List> converters = Collections.singletonList(new StringHttpMessageConverter()); + RequestPartMethodArgumentResolver resolver = new RequestPartMethodArgumentResolver(converters); + MethodParameter parameter = ResolvableMethod.on(getClass()).named("handle").build().arg(String.class); + + Object actualValue = resolver.resolveArgument( + parameter, new ModelAndViewContainer(), webRequest, new ValidatingBinderFactory()); + + assertThat(actualValue).isEqualTo("part value"); + } + @Test public void isMultipartRequest() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); @@ -606,7 +624,8 @@ public class RequestPartMethodArgumentResolverTests { @RequestPart("requestPart") Optional> optionalMultipartFileList, Optional optionalPart, @RequestPart("requestPart") Optional> optionalPartList, - @RequestPart("requestPart") Optional optionalRequestPart) { + @RequestPart("requestPart") Optional optionalRequestPart, + @RequestPart("requestPartString") String requestPartString) { } }