@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2021 the original author or authors .
* Copyright 2002 - 2022 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -16,6 +16,9 @@
@@ -16,6 +16,9 @@
package org.springframework.web.servlet.mvc.method.annotation ;
import java.io.FilterInputStream ;
import java.io.IOException ;
import java.io.InputStream ;
import java.lang.reflect.Method ;
import java.nio.charset.StandardCharsets ;
import java.util.Arrays ;
@ -83,6 +86,8 @@ public class RequestPartMethodArgumentResolverTests {
@@ -83,6 +86,8 @@ public class RequestPartMethodArgumentResolverTests {
private MultipartFile multipartFile2 ;
private CloseTrackingInputStream trackedStream ;
private MockMultipartHttpServletRequest multipartRequest ;
private NativeWebRequest webRequest ;
@ -116,7 +121,14 @@ public class RequestPartMethodArgumentResolverTests {
@@ -116,7 +121,14 @@ public class RequestPartMethodArgumentResolverTests {
reset ( messageConverter ) ;
byte [ ] content = "doesn't matter as long as not empty" . getBytes ( StandardCharsets . UTF_8 ) ;
multipartFile1 = new MockMultipartFile ( "requestPart" , "" , "text/plain" , content ) ;
multipartFile1 = new MockMultipartFile ( "requestPart" , "" , "text/plain" , content ) {
@Override
public InputStream getInputStream ( ) throws IOException {
CloseTrackingInputStream in = new CloseTrackingInputStream ( super . getInputStream ( ) ) ;
trackedStream = in ;
return in ;
}
} ;
multipartFile2 = new MockMultipartFile ( "requestPart" , "" , "text/plain" , content ) ;
multipartRequest = new MockMultipartHttpServletRequest ( ) ;
multipartRequest . addFile ( multipartFile1 ) ;
@ -182,8 +194,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -182,8 +194,7 @@ public class RequestPartMethodArgumentResolverTests {
@Test
public void resolveMultipartFileList ( ) throws Exception {
Object actual = resolver . resolveArgument ( paramMultipartFileList , null , webRequest , null ) ;
boolean condition = actual instanceof List ;
assertThat ( condition ) . isTrue ( ) ;
assertThat ( actual instanceof List ) . isTrue ( ) ;
assertThat ( actual ) . isEqualTo ( Arrays . asList ( multipartFile1 , multipartFile2 ) ) ;
}
@ -191,8 +202,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -191,8 +202,7 @@ public class RequestPartMethodArgumentResolverTests {
public void resolveMultipartFileArray ( ) throws Exception {
Object actual = resolver . resolveArgument ( paramMultipartFileArray , null , webRequest , null ) ;
assertThat ( actual ) . isNotNull ( ) ;
boolean condition = actual instanceof MultipartFile [ ] ;
assertThat ( condition ) . isTrue ( ) ;
assertThat ( actual instanceof MultipartFile [ ] ) . isTrue ( ) ;
MultipartFile [ ] parts = ( MultipartFile [ ] ) actual ;
assertThat ( parts . length ) . isEqualTo ( 2 ) ;
assertThat ( multipartFile1 ) . isEqualTo ( parts [ 0 ] ) ;
@ -209,8 +219,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -209,8 +219,7 @@ public class RequestPartMethodArgumentResolverTests {
Object result = resolver . resolveArgument ( paramMultipartFileNotAnnot , null , webRequest , null ) ;
boolean condition = result instanceof MultipartFile ;
assertThat ( condition ) . isTrue ( ) ;
assertThat ( result instanceof MultipartFile ) . isTrue ( ) ;
assertThat ( result ) . as ( "Invalid result" ) . isEqualTo ( expected ) ;
}
@ -225,8 +234,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -225,8 +234,7 @@ public class RequestPartMethodArgumentResolverTests {
webRequest = new ServletWebRequest ( request ) ;
Object result = resolver . resolveArgument ( paramPart , null , webRequest , null ) ;
boolean condition = result instanceof Part ;
assertThat ( condition ) . isTrue ( ) ;
assertThat ( result instanceof Part ) . isTrue ( ) ;
assertThat ( result ) . as ( "Invalid result" ) . isEqualTo ( expected ) ;
}
@ -243,8 +251,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -243,8 +251,7 @@ public class RequestPartMethodArgumentResolverTests {
webRequest = new ServletWebRequest ( request ) ;
Object result = resolver . resolveArgument ( paramPartList , null , webRequest , null ) ;
boolean condition = result instanceof List ;
assertThat ( condition ) . isTrue ( ) ;
assertThat ( result instanceof List ) . isTrue ( ) ;
assertThat ( result ) . isEqualTo ( Arrays . asList ( part1 , part2 ) ) ;
}
@ -261,8 +268,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -261,8 +268,7 @@ public class RequestPartMethodArgumentResolverTests {
webRequest = new ServletWebRequest ( request ) ;
Object result = resolver . resolveArgument ( paramPartArray , null , webRequest , null ) ;
boolean condition = result instanceof Part [ ] ;
assertThat ( condition ) . isTrue ( ) ;
assertThat ( result instanceof Part [ ] ) . isTrue ( ) ;
Part [ ] parts = ( Part [ ] ) result ;
assertThat ( parts . length ) . isEqualTo ( 2 ) ;
assertThat ( part1 ) . isEqualTo ( parts [ 0 ] ) ;
@ -357,8 +363,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -357,8 +363,7 @@ public class RequestPartMethodArgumentResolverTests {
assertThat ( ( ( Optional < ? > ) actualValue ) . get ( ) ) . as ( "Invalid result" ) . isEqualTo ( expected ) ;
actualValue = resolver . resolveArgument ( optionalMultipartFile , null , webRequest , null ) ;
boolean condition = actualValue instanceof Optional ;
assertThat ( condition ) . isTrue ( ) ;
assertThat ( actualValue instanceof Optional ) . isTrue ( ) ;
assertThat ( ( ( Optional < ? > ) actualValue ) . get ( ) ) . as ( "Invalid result" ) . isEqualTo ( expected ) ;
}
@ -399,8 +404,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -399,8 +404,7 @@ public class RequestPartMethodArgumentResolverTests {
assertThat ( ( ( Optional < ? > ) actualValue ) . get ( ) ) . as ( "Invalid result" ) . isEqualTo ( Collections . singletonList ( expected ) ) ;
actualValue = resolver . resolveArgument ( optionalMultipartFileList , null , webRequest , null ) ;
boolean condition = actualValue instanceof Optional ;
assertThat ( condition ) . isTrue ( ) ;
assertThat ( actualValue instanceof Optional ) . isTrue ( ) ;
assertThat ( ( ( Optional < ? > ) actualValue ) . get ( ) ) . as ( "Invalid result" ) . isEqualTo ( Collections . singletonList ( expected ) ) ;
}
@ -443,8 +447,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -443,8 +447,7 @@ public class RequestPartMethodArgumentResolverTests {
assertThat ( ( ( Optional < ? > ) actualValue ) . get ( ) ) . as ( "Invalid result" ) . isEqualTo ( expected ) ;
actualValue = resolver . resolveArgument ( optionalPart , null , webRequest , null ) ;
boolean condition = actualValue instanceof Optional ;
assertThat ( condition ) . isTrue ( ) ;
assertThat ( actualValue instanceof Optional ) . isTrue ( ) ;
assertThat ( ( ( Optional < ? > ) actualValue ) . get ( ) ) . as ( "Invalid result" ) . isEqualTo ( expected ) ;
}
@ -489,8 +492,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -489,8 +492,7 @@ public class RequestPartMethodArgumentResolverTests {
assertThat ( ( ( Optional < ? > ) actualValue ) . get ( ) ) . as ( "Invalid result" ) . isEqualTo ( Collections . singletonList ( expected ) ) ;
actualValue = resolver . resolveArgument ( optionalPartList , null , webRequest , null ) ;
boolean condition = actualValue instanceof Optional ;
assertThat ( condition ) . isTrue ( ) ;
assertThat ( actualValue instanceof Optional ) . isTrue ( ) ;
assertThat ( ( ( Optional < ? > ) actualValue ) . get ( ) ) . as ( "Invalid result" ) . isEqualTo ( Collections . singletonList ( expected ) ) ;
}
@ -572,6 +574,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -572,6 +574,7 @@ public class RequestPartMethodArgumentResolverTests {
Object actualValue = resolver . resolveArgument ( parameter , mavContainer , webRequest , new ValidatingBinderFactory ( ) ) ;
assertThat ( actualValue ) . as ( "Invalid argument value" ) . isEqualTo ( argValue ) ;
assertThat ( mavContainer . isRequestHandled ( ) ) . as ( "The requestHandled flag shouldn't change" ) . isFalse ( ) ;
assertThat ( trackedStream ! = null & & trackedStream . closed ) . isTrue ( ) ;
}
@ -591,7 +594,7 @@ public class RequestPartMethodArgumentResolverTests {
@@ -591,7 +594,7 @@ public class RequestPartMethodArgumentResolverTests {
}
private final class ValidatingBinderFactory implements WebDataBinderFactory {
private static class ValidatingBinderFactory implements WebDataBinderFactory {
@Override
public WebDataBinder createBinder ( NativeWebRequest webRequest , @Nullable Object target ,
@ -606,6 +609,21 @@ public class RequestPartMethodArgumentResolverTests {
@@ -606,6 +609,21 @@ public class RequestPartMethodArgumentResolverTests {
}
private static class CloseTrackingInputStream extends FilterInputStream {
public boolean closed = false ;
public CloseTrackingInputStream ( InputStream in ) {
super ( in ) ;
}
@Override
public void close ( ) {
this . closed = true ;
}
}
@SuppressWarnings ( "unused" )
public void handle (
@RequestPart SimpleBean requestPart ,