diff --git a/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/annotation/ServletAnnotationControllerTests.java b/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/annotation/ServletAnnotationControllerTests.java index 68d211983c..734bd08500 100644 --- a/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/annotation/ServletAnnotationControllerTests.java +++ b/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/annotation/ServletAnnotationControllerTests.java @@ -48,6 +48,7 @@ import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; import javax.validation.Valid; import javax.validation.constraints.NotNull; +import javax.xml.bind.annotation.XmlRootElement; import static org.junit.Assert.*; import org.junit.Test; @@ -60,6 +61,7 @@ import org.springframework.beans.DerivedTestBean; import org.springframework.beans.GenericBean; import org.springframework.beans.ITestBean; import org.springframework.beans.TestBean; +import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.config.PropertyPlaceholderConfigurer; @@ -81,10 +83,12 @@ import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageNotReadableException; import org.springframework.http.converter.HttpMessageNotWritableException; import org.springframework.http.converter.StringHttpMessageConverter; +import org.springframework.http.converter.xml.MarshallingHttpMessageConverter; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockServletConfig; import org.springframework.mock.web.MockServletContext; +import org.springframework.oxm.jaxb.Jaxb2Marshaller; import org.springframework.stereotype.Controller; import org.springframework.ui.ExtendedModelMap; import org.springframework.ui.Model; @@ -1256,6 +1260,44 @@ public class ServletAnnotationControllerTests { assertEquals(200, response.getStatus()); } + @Test + public void responseBodyArgMismatch() throws ServletException, IOException { + @SuppressWarnings("serial") DispatcherServlet servlet = new DispatcherServlet() { + @Override + protected WebApplicationContext createWebApplicationContext(WebApplicationContext parent) { + GenericWebApplicationContext wac = new GenericWebApplicationContext(); + wac.registerBeanDefinition("controller", new RootBeanDefinition(RequestBodyArgMismatchController.class)); + + Jaxb2Marshaller marshaller = new Jaxb2Marshaller(); + marshaller.setClassesToBeBound(A.class, B.class); + try { + marshaller.afterPropertiesSet(); + } + catch (Exception ex) { + throw new BeanCreationException(ex.getMessage(), ex); + } + + MarshallingHttpMessageConverter messageConverter = new MarshallingHttpMessageConverter(marshaller); + + RootBeanDefinition adapterDef = new RootBeanDefinition(AnnotationMethodHandlerAdapter.class); + adapterDef.getPropertyValues().add("messageConverters", messageConverter); + wac.registerBeanDefinition("handlerAdapter", adapterDef); + wac.refresh(); + return wac; + } + }; + servlet.init(new MockServletConfig()); + + + MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/something"); + String requestBody = ""; + request.setContent(requestBody.getBytes("UTF-8")); + request.addHeader("Content-Type", "application/xml; charset=utf-8"); + MockHttpServletResponse response = new MockHttpServletResponse(); + servlet.service(request, response); + assertEquals(400, response.getStatus()); + } + @Test public void contentTypeHeaders() throws ServletException, IOException { @@ -2299,6 +2341,25 @@ public class ServletAnnotationControllerTests { } } + @Controller + public static class RequestBodyArgMismatchController { + + @RequestMapping(value = "/something", method = RequestMethod.PUT) + public void handle(@RequestBody A a) throws IOException { + } + } + + @XmlRootElement + public static class A { + + } + + @XmlRootElement + public static class B { + + } + + public static class NotReadableMessageConverter implements HttpMessageConverter { public boolean canRead(Class clazz, MediaType mediaType) {