diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java index 1dbc559e2c..535ed21608 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -169,7 +169,7 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements HttpMethod httpMethod = (inputMessage instanceof HttpRequest ? ((HttpRequest) inputMessage).getMethod() : null); Object body = NO_VALUE; - EmptyBodyCheckingHttpInputMessage message; + EmptyBodyCheckingHttpInputMessage message = null; try { message = new EmptyBodyCheckingHttpInputMessage(inputMessage); @@ -196,6 +196,11 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements catch (IOException ex) { throw new HttpMessageNotReadableException("I/O error while reading input message", ex, inputMessage); } + finally { + if (message != null && message.hasBody()) { + closeStreamIfNecessary(message.getBody()); + } + } if (body == NO_VALUE) { if (httpMethod == null || !SUPPORTED_METHODS.contains(httpMethod) || @@ -298,6 +303,15 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements return arg; } + /** + * Allow for closing the body stream if necessary, + * e.g. for part streams in a multipart request. + */ + void closeStreamIfNecessary(InputStream body) { + // No-op by default: A standard HttpInputMessage exposes the HTTP request stream + // (ServletRequest#getInputStream), with its lifecycle managed by the container. + } + private static class EmptyBodyCheckingHttpInputMessage implements HttpInputMessage { diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java index e828239d71..eff6744eba 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.web.servlet.mvc.method.annotation; +import java.io.IOException; +import java.io.InputStream; import java.util.List; import javax.servlet.http.HttpServletRequest; @@ -180,4 +182,17 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM return partName; } + @Override + void closeStreamIfNecessary(InputStream body) { + // RequestPartServletServerHttpRequest exposes individual part streams, + // potentially from temporary files -> explicit close call after resolution + // in order to prevent file descriptor leaks. + try { + body.close(); + } + catch (IOException ex) { + // ignore + } + } + } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java index 9cf34618a6..cbde3c20b2 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ package org.springframework.web.servlet.mvc.method.annotation; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; import java.lang.reflect.Method; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -83,6 +86,8 @@ public class RequestPartMethodArgumentResolverTests { private MultipartFile multipartFile2; + private CloseTrackingInputStream trackedStream; + private MockMultipartHttpServletRequest multipartRequest; private NativeWebRequest webRequest; @@ -116,7 +121,14 @@ public class RequestPartMethodArgumentResolverTests { reset(messageConverter); byte[] content = "doesn't matter as long as not empty".getBytes(StandardCharsets.UTF_8); - multipartFile1 = new MockMultipartFile("requestPart", "", "text/plain", content); + multipartFile1 = new MockMultipartFile("requestPart", "", "text/plain", content) { + @Override + public InputStream getInputStream() throws IOException { + CloseTrackingInputStream in = new CloseTrackingInputStream(super.getInputStream()); + trackedStream = in; + return in; + } + }; multipartFile2 = new MockMultipartFile("requestPart", "", "text/plain", content); multipartRequest = new MockMultipartHttpServletRequest(); multipartRequest.addFile(multipartFile1); @@ -182,8 +194,7 @@ public class RequestPartMethodArgumentResolverTests { @Test public void resolveMultipartFileList() throws Exception { Object actual = resolver.resolveArgument(paramMultipartFileList, null, webRequest, null); - boolean condition = actual instanceof List; - assertThat(condition).isTrue(); + assertThat(actual instanceof List).isTrue(); assertThat(actual).isEqualTo(Arrays.asList(multipartFile1, multipartFile2)); } @@ -191,8 +202,7 @@ public class RequestPartMethodArgumentResolverTests { public void resolveMultipartFileArray() throws Exception { Object actual = resolver.resolveArgument(paramMultipartFileArray, null, webRequest, null); assertThat(actual).isNotNull(); - boolean condition = actual instanceof MultipartFile[]; - assertThat(condition).isTrue(); + assertThat(actual instanceof MultipartFile[]).isTrue(); MultipartFile[] parts = (MultipartFile[]) actual; assertThat(parts.length).isEqualTo(2); assertThat(multipartFile1).isEqualTo(parts[0]); @@ -209,8 +219,7 @@ public class RequestPartMethodArgumentResolverTests { Object result = resolver.resolveArgument(paramMultipartFileNotAnnot, null, webRequest, null); - boolean condition = result instanceof MultipartFile; - assertThat(condition).isTrue(); + assertThat(result instanceof MultipartFile).isTrue(); assertThat(result).as("Invalid result").isEqualTo(expected); } @@ -225,8 +234,7 @@ public class RequestPartMethodArgumentResolverTests { webRequest = new ServletWebRequest(request); Object result = resolver.resolveArgument(paramPart, null, webRequest, null); - boolean condition = result instanceof Part; - assertThat(condition).isTrue(); + assertThat(result instanceof Part).isTrue(); assertThat(result).as("Invalid result").isEqualTo(expected); } @@ -243,8 +251,7 @@ public class RequestPartMethodArgumentResolverTests { webRequest = new ServletWebRequest(request); Object result = resolver.resolveArgument(paramPartList, null, webRequest, null); - boolean condition = result instanceof List; - assertThat(condition).isTrue(); + assertThat(result instanceof List).isTrue(); assertThat(result).isEqualTo(Arrays.asList(part1, part2)); } @@ -261,8 +268,7 @@ public class RequestPartMethodArgumentResolverTests { webRequest = new ServletWebRequest(request); Object result = resolver.resolveArgument(paramPartArray, null, webRequest, null); - boolean condition = result instanceof Part[]; - assertThat(condition).isTrue(); + assertThat(result instanceof Part[]).isTrue(); Part[] parts = (Part[]) result; assertThat(parts.length).isEqualTo(2); assertThat(part1).isEqualTo(parts[0]); @@ -357,8 +363,7 @@ public class RequestPartMethodArgumentResolverTests { assertThat(((Optional) actualValue).get()).as("Invalid result").isEqualTo(expected); actualValue = resolver.resolveArgument(optionalMultipartFile, null, webRequest, null); - boolean condition = actualValue instanceof Optional; - assertThat(condition).isTrue(); + assertThat(actualValue instanceof Optional).isTrue(); assertThat(((Optional) actualValue).get()).as("Invalid result").isEqualTo(expected); } @@ -399,8 +404,7 @@ public class RequestPartMethodArgumentResolverTests { assertThat(((Optional) actualValue).get()).as("Invalid result").isEqualTo(Collections.singletonList(expected)); actualValue = resolver.resolveArgument(optionalMultipartFileList, null, webRequest, null); - boolean condition = actualValue instanceof Optional; - assertThat(condition).isTrue(); + assertThat(actualValue instanceof Optional).isTrue(); assertThat(((Optional) actualValue).get()).as("Invalid result").isEqualTo(Collections.singletonList(expected)); } @@ -443,8 +447,7 @@ public class RequestPartMethodArgumentResolverTests { assertThat(((Optional) actualValue).get()).as("Invalid result").isEqualTo(expected); actualValue = resolver.resolveArgument(optionalPart, null, webRequest, null); - boolean condition = actualValue instanceof Optional; - assertThat(condition).isTrue(); + assertThat(actualValue instanceof Optional).isTrue(); assertThat(((Optional) actualValue).get()).as("Invalid result").isEqualTo(expected); } @@ -489,8 +492,7 @@ public class RequestPartMethodArgumentResolverTests { assertThat(((Optional) actualValue).get()).as("Invalid result").isEqualTo(Collections.singletonList(expected)); actualValue = resolver.resolveArgument(optionalPartList, null, webRequest, null); - boolean condition = actualValue instanceof Optional; - assertThat(condition).isTrue(); + assertThat(actualValue instanceof Optional).isTrue(); assertThat(((Optional) actualValue).get()).as("Invalid result").isEqualTo(Collections.singletonList(expected)); } @@ -572,6 +574,7 @@ public class RequestPartMethodArgumentResolverTests { Object actualValue = resolver.resolveArgument(parameter, mavContainer, webRequest, new ValidatingBinderFactory()); assertThat(actualValue).as("Invalid argument value").isEqualTo(argValue); assertThat(mavContainer.isRequestHandled()).as("The requestHandled flag shouldn't change").isFalse(); + assertThat(trackedStream != null && trackedStream.closed).isTrue(); } @@ -591,7 +594,7 @@ public class RequestPartMethodArgumentResolverTests { } - private final class ValidatingBinderFactory implements WebDataBinderFactory { + private static class ValidatingBinderFactory implements WebDataBinderFactory { @Override public WebDataBinder createBinder(NativeWebRequest webRequest, @Nullable Object target, @@ -606,6 +609,21 @@ public class RequestPartMethodArgumentResolverTests { } + private static class CloseTrackingInputStream extends FilterInputStream { + + public boolean closed = false; + + public CloseTrackingInputStream(InputStream in) { + super(in); + } + + @Override + public void close() { + this.closed = true; + } + } + + @SuppressWarnings("unused") public void handle( @RequestPart SimpleBean requestPart,