diff --git a/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolver.java b/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolver.java index a5eb2e738e..8bdde212d4 100644 --- a/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolver.java +++ b/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolver.java @@ -28,7 +28,6 @@ import org.springframework.http.HttpInputMessage; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.util.Assert; import org.springframework.validation.BindingResult; -import org.springframework.web.bind.ServletRequestBindingException; import org.springframework.web.bind.WebDataBinder; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestParam; @@ -37,10 +36,11 @@ import org.springframework.web.bind.support.WebDataBinderFactory; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.method.annotation.support.MethodArgumentNotValidException; import org.springframework.web.method.support.ModelAndViewContainer; +import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartHttpServletRequest; -import org.springframework.web.multipart.MultipartRequest; import org.springframework.web.multipart.MultipartResolver; +import org.springframework.web.multipart.support.MissingServletRequestPartException; import org.springframework.web.multipart.support.RequestPartServletServerHttpRequest; import org.springframework.web.servlet.mvc.support.DefaultHandlerExceptionResolver; import org.springframework.web.util.WebUtils; @@ -66,8 +66,8 @@ import org.springframework.web.util.WebUtils; * *

Automatic validation can be applied to a @{@link RequestPart} method argument * through the use of {@code @Valid}. In case of validation failure, a - * {@link RequestPartNotValidException} is thrown and handled automatically through - * the {@link DefaultHandlerExceptionResolver}. + * {@link MethodArgumentNotValidException} is thrown and handled automatically by + * the {@link DefaultHandlerExceptionResolver}. * * @author Rossen Stoyanchev * @since 3.1 @@ -111,44 +111,65 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM NativeWebRequest request, WebDataBinderFactory binderFactory) throws Exception { - String partName = getPartName(parameter); - Object arg; - HttpServletRequest servletRequest = request.getNativeRequest(HttpServletRequest.class); + if (!isMultipartRequest(servletRequest)) { + throw new MultipartException("The current request is not a multipart request."); + } + MultipartHttpServletRequest multipartRequest = WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class); + String partName = getPartName(parameter); + Object arg; + if (MultipartFile.class.equals(parameter.getParameterType())) { - assertMultipartRequest(multipartRequest, request); + Assert.notNull(multipartRequest, "Expected MultipartHttpServletRequest: is a MultipartResolver configured?"); arg = multipartRequest.getFile(partName); } else if (isMultipartFileCollection(parameter)) { - assertMultipartRequest(multipartRequest, request); + Assert.notNull(multipartRequest, "Expected MultipartHttpServletRequest: is a MultipartResolver configured?"); arg = multipartRequest.getFiles(partName); } else if ("javax.servlet.http.Part".equals(parameter.getParameterType().getName())) { arg = servletRequest.getPart(partName); } else { - HttpInputMessage inputMessage = new RequestPartServletServerHttpRequest(servletRequest, partName); - arg = readWithMessageConverters(inputMessage, parameter, parameter.getParameterType()); - if (isValidationApplicable(arg, parameter)) { - WebDataBinder binder = binderFactory.createBinder(request, arg, partName); - binder.validate(); - BindingResult bindingResult = binder.getBindingResult(); - if (bindingResult.hasErrors()) { - throw new MethodArgumentNotValidException(parameter, bindingResult); + try { + HttpInputMessage inputMessage = new RequestPartServletServerHttpRequest(servletRequest, partName); + arg = readWithMessageConverters(inputMessage, parameter, parameter.getParameterType()); + if (isValidationApplicable(arg, parameter)) { + WebDataBinder binder = binderFactory.createBinder(request, arg, partName); + binder.validate(); + BindingResult bindingResult = binder.getBindingResult(); + if (bindingResult.hasErrors()) { + throw new MethodArgumentNotValidException(parameter, bindingResult); + } } + } + catch (MissingServletRequestPartException e) { + // handled below + arg = null; } } - if (arg == null) { - handleMissingValue(partName, parameter); + RequestPart annot = parameter.getParameterAnnotation(RequestPart.class); + boolean isRequired = (annot != null) ? annot.required() : true; + + if (arg == null && isRequired) { + throw new MissingServletRequestPartException(partName); } return arg; } + private boolean isMultipartRequest(HttpServletRequest request) { + if (!"post".equals(request.getMethod().toLowerCase())) { + return false; + } + String contentType = request.getContentType(); + return (contentType != null && contentType.toLowerCase().startsWith("multipart/")); + } + private String getPartName(MethodParameter parameter) { RequestPart annot = parameter.getParameterAnnotation(RequestPart.class); String partName = (annot != null) ? annot.value() : ""; @@ -159,13 +180,6 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM } return partName; } - - private void assertMultipartRequest(MultipartHttpServletRequest multipartRequest, NativeWebRequest request) { - if (multipartRequest == null) { - throw new IllegalStateException("Current request is not of type [" + MultipartRequest.class.getName() - + "]: " + request + ". Do you have a MultipartResolver configured?"); - } - } private boolean isMultipartFileCollection(MethodParameter parameter) { Class paramType = parameter.getParameterType(); @@ -178,22 +192,6 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM return false; } - /** - * Invoked if the resolved argument value is {@code null}. The default implementation raises - * a {@link ServletRequestBindingException} if the method parameter is required. - * @param partName the name used to look up the request part - * @param param the method argument - */ - protected void handleMissingValue(String partName, MethodParameter param) throws ServletRequestBindingException { - RequestPart annot = param.getParameterAnnotation(RequestPart.class); - boolean isRequired = (annot != null) ? annot.required() : true; - if (isRequired) { - String paramType = param.getParameterType().getName(); - throw new ServletRequestBindingException( - "Missing request part '" + partName + "' for method parameter type [" + paramType + "]"); - } - } - /** * Whether to validate the given @{@link RequestPart} method argument. * The default implementation return {@code true} if the argument value is not {@code null} diff --git a/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/support/DefaultHandlerExceptionResolver.java b/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/support/DefaultHandlerExceptionResolver.java index f197485359..6de0aeced4 100644 --- a/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/support/DefaultHandlerExceptionResolver.java +++ b/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/support/DefaultHandlerExceptionResolver.java @@ -40,6 +40,8 @@ import org.springframework.web.bind.ServletRequestBindingException; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestPart; import org.springframework.web.method.annotation.support.MethodArgumentNotValidException; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.support.MissingServletRequestPartException; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.handler.AbstractHandlerExceptionResolver; import org.springframework.web.servlet.mvc.multiaction.NoSuchRequestHandlingMethodException; @@ -131,6 +133,9 @@ public class DefaultHandlerExceptionResolver extends AbstractHandlerExceptionRes else if (ex instanceof MethodArgumentNotValidException) { return handleMethodArgumentNotValidException((MethodArgumentNotValidException) ex, request, response, handler); } + else if (ex instanceof MissingServletRequestPartException) { + return handleMissingServletRequestPartException((MissingServletRequestPartException) ex, request, response, handler); + } } catch (Exception handlerException) { logger.warn("Handling of [" + ex.getClass().getName() + "] resulted in Exception", handlerException); @@ -356,4 +361,20 @@ public class DefaultHandlerExceptionResolver extends AbstractHandlerExceptionRes return new ModelAndView(); } + /** + * Handle the case where an @{@link RequestPart}, a {@link MultipartFile}, + * or a {@code javax.servlet.http.Part} argument is required but missing. + * An HTTP 400 error is sent back to the client. + * @param request current HTTP request + * @param response current HTTP response + * @param handler the executed handler + * @return an empty ModelAndView indicating the exception was handled + * @throws IOException potentially thrown from response.sendError() + */ + protected ModelAndView handleMissingServletRequestPartException(MissingServletRequestPartException ex, + HttpServletRequest request, HttpServletResponse response, Object handler) throws IOException { + response.sendError(HttpServletResponse.SC_BAD_REQUEST); + return new ModelAndView(); + } + } diff --git a/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartIntegrationTests.java b/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartIntegrationTests.java index 3115a97edd..2695e06a30 100644 --- a/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartIntegrationTests.java +++ b/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartIntegrationTests.java @@ -24,6 +24,7 @@ import java.util.Arrays; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; +import org.junit.Ignore; import org.junit.Test; import org.mortbay.jetty.Server; import org.mortbay.jetty.servlet.Context; @@ -52,6 +53,7 @@ import org.springframework.web.context.support.AnnotationConfigWebApplicationCon import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartResolver; import org.springframework.web.multipart.commons.CommonsMultipartResolver; +import org.springframework.web.multipart.support.StandardServletMultipartResolver; import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter; @@ -69,6 +71,7 @@ public class RequestPartIntegrationTests { private static String baseUrl; + @BeforeClass public static void startServer() throws Exception { @@ -78,10 +81,19 @@ public class RequestPartIntegrationTests { server = new Server(port); Context context = new Context(server, "/"); + Class config = CommonsMultipartResolverTestConfig.class; ServletHolder commonsResolverServlet = new ServletHolder(DispatcherServlet.class); - commonsResolverServlet.setInitParameter("contextConfigLocation", CommonsMultipartResolverTestConfig.class.getName()); + commonsResolverServlet.setInitParameter("contextConfigLocation", config.getName()); commonsResolverServlet.setInitParameter("contextClass", AnnotationConfigWebApplicationContext.class.getName()); - context.addServlet(commonsResolverServlet, "/commons/*"); + context.addServlet(commonsResolverServlet, "/commons-resolver/*"); + + config = StandardMultipartResolverTestConfig.class; + ServletHolder standardResolverServlet = new ServletHolder(DispatcherServlet.class); + standardResolverServlet.setInitParameter("contextConfigLocation", config.getName()); + standardResolverServlet.setInitParameter("contextClass", AnnotationConfigWebApplicationContext.class.getName()); + context.addServlet(standardResolverServlet, "/standard-resolver/*"); + + // TODO: add Servlet 3.0 test case without MultipartResolver server.start(); } @@ -103,17 +115,28 @@ public class RequestPartIntegrationTests { } } + @Test public void commonsMultipartResolver() throws Exception { + testCreate(baseUrl + "/commons-resolver/test"); + } + @Test + @Ignore("jetty 6.1.9 doesn't support Servlet 3.0") + public void standardMultipartResolver() throws Exception { + testCreate(baseUrl + "/standard-resolver/test"); + } + + private void testCreate(String url) { MultiValueMap parts = new LinkedMultiValueMap(); HttpEntity jsonEntity = new HttpEntity(new TestData("Jason")); parts.add("json-data", jsonEntity); parts.add("file-data", new ClassPathResource("logo.jpg", this.getClass())); - URI location = restTemplate.postForLocation(baseUrl + "/commons/test", parts); + URI location = restTemplate.postForLocation(url, parts); assertEquals("http://localhost:8080/test/Jason/logo.jpg", location.toString()); - } + } + @Configuration @EnableWebMvc @@ -125,13 +148,22 @@ public class RequestPartIntegrationTests { } } + @Configuration static class CommonsMultipartResolverTestConfig extends RequestPartTestConfig { @Bean public MultipartResolver multipartResolver() { return new CommonsMultipartResolver(); } - + } + + @Configuration + static class StandardMultipartResolverTestConfig extends RequestPartTestConfig { + + @Bean + public MultipartResolver multipartResolver() { + return new StandardServletMultipartResolver(); + } } @SuppressWarnings("unused") diff --git a/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolverTests.java b/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolverTests.java index 705c499a3b..966b33e209 100644 --- a/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolverTests.java +++ b/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolverTests.java @@ -61,7 +61,9 @@ import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.ServletWebRequest; import org.springframework.web.method.annotation.support.MethodArgumentNotValidException; import org.springframework.web.method.support.ModelAndViewContainer; +import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.support.MissingServletRequestPartException; import org.springframework.web.multipart.support.RequestPartServletServerHttpRequest; /** @@ -171,6 +173,8 @@ public class RequestPartMethodArgumentResolverTests { public void resolveServlet30PartArgument() throws Exception { MockPart expected = new MockPart("servlet30Part", "Hello World".getBytes()); MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("POST"); + request.setContentType("multipart/form-data"); request.addPart(expected); webRequest = new ServletWebRequest(request); @@ -212,8 +216,8 @@ public class RequestPartMethodArgumentResolverTests { try { testResolveArgument(null, paramValidRequestPart); fail("Expected exception"); - } catch (ServletRequestBindingException e) { - assertTrue(e.getMessage().contains("Missing request part")); + } catch (MissingServletRequestPartException e) { + assertEquals("requestPart", e.getRequestPartName()); } } @@ -221,19 +225,26 @@ public class RequestPartMethodArgumentResolverTests { public void resolveRequestPartNotRequired() throws Exception { testResolveArgument(new SimpleBean("foo"), paramValidRequestPart); } + + @Test(expected=MultipartException.class) + public void notMultipartRequest() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + resolver.resolveArgument(paramMultipartFile, new ModelAndViewContainer(), new ServletWebRequest(request), null); + fail("Expected exception"); + } - private void testResolveArgument(SimpleBean expectedValue, MethodParameter parameter) throws IOException, Exception { + private void testResolveArgument(SimpleBean argValue, MethodParameter parameter) throws IOException, Exception { MediaType contentType = MediaType.TEXT_PLAIN; multipartRequest.addHeader("Content-Type", contentType.toString()); expect(messageConverter.canRead(SimpleBean.class, contentType)).andReturn(true); - expect(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).andReturn(expectedValue); + expect(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).andReturn(argValue); replay(messageConverter); ModelAndViewContainer mavContainer = new ModelAndViewContainer(); Object actualValue = resolver.resolveArgument(parameter, mavContainer, webRequest, new ValidatingBinderFactory()); - assertEquals("Invalid argument value", expectedValue, actualValue); + assertEquals("Invalid argument value", argValue, actualValue); assertTrue("The ResolveView flag shouldn't change", mavContainer.isResolveView()); verify(messageConverter); diff --git a/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/support/DefaultHandlerExceptionResolverTests.java b/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/support/DefaultHandlerExceptionResolverTests.java index be15c5630e..1d9d90d98b 100644 --- a/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/support/DefaultHandlerExceptionResolverTests.java +++ b/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/support/DefaultHandlerExceptionResolverTests.java @@ -38,6 +38,7 @@ import org.springframework.web.HttpRequestMethodNotSupportedException; import org.springframework.web.bind.MissingServletRequestParameterException; import org.springframework.web.bind.ServletRequestBindingException; import org.springframework.web.method.annotation.support.MethodArgumentNotValidException; +import org.springframework.web.multipart.support.MissingServletRequestPartException; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.mvc.multiaction.NoSuchRequestHandlingMethodException; @@ -147,6 +148,15 @@ public class DefaultHandlerExceptionResolverTests { assertEquals("Invalid status code", 400, response.getStatus()); } + @Test + public void handleMissingServletRequestPartException() throws Exception { + MissingServletRequestPartException ex = new MissingServletRequestPartException("name"); + ModelAndView mav = exceptionResolver.resolveException(request, response, null, ex); + assertNotNull("No ModelAndView returned", mav); + assertTrue("No Empty ModelAndView returned", mav.isEmpty()); + assertEquals("Invalid status code", 400, response.getStatus()); + } + public void handle(String arg) { } diff --git a/org.springframework.web/src/main/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolver.java b/org.springframework.web/src/main/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolver.java index 91b8640a4c..10cc58b0c3 100644 --- a/org.springframework.web/src/main/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolver.java +++ b/org.springframework.web/src/main/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolver.java @@ -29,6 +29,7 @@ import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.core.GenericCollectionTypeResolver; import org.springframework.core.MethodParameter; import org.springframework.core.convert.converter.Converter; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.bind.MissingServletRequestParameterException; import org.springframework.web.bind.WebDataBinder; @@ -36,9 +37,9 @@ import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RequestPart; import org.springframework.web.bind.annotation.ValueConstants; import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartHttpServletRequest; -import org.springframework.web.multipart.MultipartRequest; import org.springframework.web.multipart.MultipartResolver; import org.springframework.web.util.WebUtils; @@ -58,8 +59,8 @@ import org.springframework.web.util.WebUtils; * parameter type, the {@link RequestParamMapMethodArgumentResolver} is used instead * providing access to all request parameters in the form of a map. * - *

A {@link WebDataBinder} is invoked to apply type conversion to resolved request header values that - * don't yet match the method parameter type. + *

A {@link WebDataBinder} is invoked to apply type conversion to resolved request + * header values that don't yet match the method parameter type. * * @author Arjen Poutsma * @author Rossen Stoyanchev @@ -71,11 +72,13 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueMethod private final boolean useDefaultResolution; /** - * @param beanFactory a bean factory to use for resolving ${...} placeholder and #{...} SpEL expressions - * in default values, or {@code null} if default values are not expected to contain expressions - * @param useDefaultResolution in default resolution mode a method argument that is a simple type, as - * defined in {@link BeanUtils#isSimpleProperty(Class)}, is treated as a request parameter even if it doesn't have - * an @{@link RequestParam} annotation, the request parameter name is derived from the method parameter name. + * @param beanFactory a bean factory used for resolving ${...} placeholder + * and #{...} SpEL expressions in default values, or {@code null} if default + * values are not expected to contain expressions + * @param useDefaultResolution in default resolution mode a method argument + * that is a simple type, as defined in {@link BeanUtils#isSimpleProperty}, + * is treated as a request parameter even if it itsn't annotated, the + * request parameter name is derived from the method parameter name. */ public RequestParamMethodArgumentResolver(ConfigurableBeanFactory beanFactory, boolean useDefaultResolution) { @@ -87,11 +90,15 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueMethod * Supports the following: *

*/ public boolean supportsParameter(MethodParameter parameter) { @@ -139,14 +146,17 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueMethod WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class); if (MultipartFile.class.equals(parameter.getParameterType())) { - assertMultipartRequest(multipartRequest, webRequest); + assertIsMultipartRequest(servletRequest); + Assert.notNull(multipartRequest, "Expected MultipartHttpServletRequest: is a MultipartResolver configured?"); arg = multipartRequest.getFile(name); } else if (isMultipartFileCollection(parameter)) { - assertMultipartRequest(multipartRequest, webRequest); + assertIsMultipartRequest(servletRequest); + Assert.notNull(multipartRequest, "Expected MultipartHttpServletRequest: is a MultipartResolver configured?"); arg = multipartRequest.getFiles(name); } else if ("javax.servlet.http.Part".equals(parameter.getParameterType().getName())) { + assertIsMultipartRequest(servletRequest); arg = servletRequest.getPart(name); } else { @@ -168,13 +178,20 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueMethod return arg; } - private void assertMultipartRequest(MultipartHttpServletRequest multipartRequest, NativeWebRequest request) { - if (multipartRequest == null) { - throw new IllegalStateException("Current request is not of type [" + MultipartRequest.class.getName() - + "]: " + request + ". Do you have a MultipartResolver configured?"); + private void assertIsMultipartRequest(HttpServletRequest request) { + if (!isMultipartRequest(request)) { + throw new MultipartException("The current request is not a multipart request."); } } + private boolean isMultipartRequest(HttpServletRequest request) { + if (!"post".equals(request.getMethod().toLowerCase())) { + return false; + } + String contentType = request.getContentType(); + return (contentType != null && contentType.toLowerCase().startsWith("multipart/")); + } + private boolean isMultipartFileCollection(MethodParameter parameter) { Class paramType = parameter.getParameterType(); if (Collection.class.equals(paramType) || List.class.isAssignableFrom(paramType)){ diff --git a/org.springframework.web/src/main/java/org/springframework/web/multipart/support/MissingServletRequestPartException.java b/org.springframework.web/src/main/java/org/springframework/web/multipart/support/MissingServletRequestPartException.java new file mode 100644 index 0000000000..0ee85e75cc --- /dev/null +++ b/org.springframework.web/src/main/java/org/springframework/web/multipart/support/MissingServletRequestPartException.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2011 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.multipart.support; + +import javax.servlet.ServletException; + +import org.springframework.web.multipart.MultipartResolver; + +/** + * Raised when the part of a "multipart/form-data" request identified by its + * name cannot be found. + * + *

This may be because the request is not a multipart/form-data + * + * either because the part is not present in the request, or + * because the web application is not configured correctly for processing + * multipart requests -- e.g. no {@link MultipartResolver}. + * + * @author Rossen Stoyanchev + * @since 3.1 + */ +public class MissingServletRequestPartException extends ServletException { + + private static final long serialVersionUID = -1255077391966870705L; + + private final String partName; + + public MissingServletRequestPartException(String partName) { + super("Request part '" + partName + "' not found."); + this.partName = partName; + } + + public String getRequestPartName() { + return this.partName; + } +} diff --git a/org.springframework.web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java b/org.springframework.web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java index 234f22b43c..4cffa524bb 100644 --- a/org.springframework.web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java +++ b/org.springframework.web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java @@ -19,18 +19,23 @@ package org.springframework.web.multipart.support; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; + import javax.servlet.http.HttpServletRequest; import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.util.ClassUtils; import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartHttpServletRequest; +import org.springframework.web.multipart.MultipartResolver; /** - * {@link ServerHttpRequest} implementation that is based on a part of a {@link MultipartHttpServletRequest}. - * The part is accessed as {@link MultipartFile} and adapted to the ServerHttpRequest contract. + * {@link ServerHttpRequest} implementation that accesses one part of a multipart + * request. If using {@link MultipartResolver} configuration the part is accessed + * through a {@link MultipartFile}. Or if using Servlet 3.0 multipart processing + * the part is accessed through {@code ServletRequest.getPart}. * * @author Rossen Stoyanchev * @author Juergen Hoeller @@ -46,23 +51,44 @@ public class RequestPartServletServerHttpRequest extends ServletServerHttpReques /** - * Create a new {@link RequestPartServletServerHttpRequest} instance. - * @param request the multipart request + * Create a new instance. + * @param request the current request * @param partName the name of the part to adapt to the {@link ServerHttpRequest} contract + * @throws MissingServletRequestPartException if the request part cannot be found + * @throws IllegalArgumentException if MultipartHttpServletRequest cannot be initialized */ - public RequestPartServletServerHttpRequest(HttpServletRequest request, String partName) { + public RequestPartServletServerHttpRequest(HttpServletRequest request, String partName) + throws MissingServletRequestPartException { + super(request); - this.multipartRequest = (request instanceof MultipartHttpServletRequest ? - (MultipartHttpServletRequest) request : new StandardMultipartHttpServletRequest(request)); + this.multipartRequest = asMultipartRequest(request); this.partName = partName; this.headers = this.multipartRequest.getMultipartHeaders(this.partName); if (this.headers == null) { - throw new IllegalArgumentException("No request part found for name '" + this.partName + "'"); + if (request instanceof MultipartHttpServletRequest) { + throw new MissingServletRequestPartException(partName); + } + else { + throw new IllegalArgumentException( + "Failed to obtain request part: " + partName + ". " + + "The part is missing or multipart processing is not configured. " + + "Check for a MultipartResolver bean or if Servlet 3.0 multipart processing is enabled."); + } } } - + + private static MultipartHttpServletRequest asMultipartRequest(HttpServletRequest request) { + if (request instanceof MultipartHttpServletRequest) { + return (MultipartHttpServletRequest) request; + } + else if (ClassUtils.hasMethod(HttpServletRequest.class, "getParts")) { + // Servlet 3.0 available .. + return new StandardMultipartHttpServletRequest(request); + } + throw new IllegalArgumentException("Expected MultipartHttpServletRequest: is a MultipartResolver configured?"); + } @Override public HttpHeaders getHeaders() { diff --git a/org.springframework.web/src/test/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolverTests.java b/org.springframework.web/src/test/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolverTests.java index 98f4e3c0d3..32689e6ae3 100644 --- a/org.springframework.web/src/test/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolverTests.java +++ b/org.springframework.web/src/test/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolverTests.java @@ -16,10 +16,10 @@ package org.springframework.web.method.annotation.support; -import static org.junit.Assert.*; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -45,6 +45,7 @@ import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RequestPart; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; /** @@ -180,16 +181,26 @@ public class RequestParamMethodArgumentResolverTests { assertEquals(Arrays.asList(expected1, expected2), result); } - @Test(expected = IllegalStateException.class) + @Test(expected = MultipartException.class) + public void notMultipartRequest() throws Exception { + resolver.resolveArgument(paramMultiPartFile, null, webRequest, null); + fail("Expected exception: request is not a multipart request"); + } + + @Test(expected = IllegalArgumentException.class) public void missingMultipartFile() throws Exception { + request.setMethod("POST"); + request.setContentType("multipart/form-data"); resolver.resolveArgument(paramMultiPartFile, null, webRequest, null); - fail("Expected exception"); + fail("Expected exception: request is not MultiPartHttpServletRequest but param is MultipartFile"); } @Test public void resolveServlet30Part() throws Exception { MockPart expected = new MockPart("servlet30Part", "Hello World".getBytes()); MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("POST"); + request.setContentType("multipart/form-data"); request.addPart(expected); webRequest = new ServletWebRequest(request);