From 4000b244ff7b1db1f4eba9c4dbf2fc97162da254 Mon Sep 17 00:00:00 2001 From: Sam Brannen Date: Tue, 18 Jun 2019 15:56:50 +0300 Subject: [PATCH] Forbid null converters in RestTemplate & HttpMessageConverterExtractor Prior to this commit, RestTemplate and HttpMessageConverterExtractor did not validate that the supplied HttpMessageConverter list contained no null elements, which can lead to a NullPointerException when the converters are accessed. This commit improves the user experience by failing immediately if the supplied HttpMessageConverter list contains a null element. This applies to constructors for RestTemplate and HttpMessageConverterExtractor as well as to RestTemplate#setMessageConverters(). Note, however, that RestTemplate#getMessageConverters() returns a mutable list. Thus, if a user modifies that list so that it contains null values, that will still lead to a NullPointerException when the converters are accessed. This commit also introduces noNullElements() variants for collections in org.springframework.util.Assert. Closes gh-23151 --- .../java/org/springframework/util/Assert.java | 41 +++++++++ .../client/HttpMessageConverterExtractor.java | 4 +- .../web/client/RestTemplate.java | 11 ++- .../HttpMessageConverterExtractorTests.java | 84 ++++++------------- .../web/client/RestTemplateTests.java | 23 +++++ 5 files changed, 102 insertions(+), 61 deletions(-) diff --git a/spring-core/src/main/java/org/springframework/util/Assert.java b/spring-core/src/main/java/org/springframework/util/Assert.java index 1f526d5ddc..8d4fc74e2f 100644 --- a/spring-core/src/main/java/org/springframework/util/Assert.java +++ b/spring-core/src/main/java/org/springframework/util/Assert.java @@ -495,6 +495,47 @@ public abstract class Assert { "[Assertion failed] - this collection must not be empty: it must contain at least 1 element"); } + /** + * Assert that a collection contains no {@code null} elements. + *

Note: Does not complain if the collection is empty! + *

Assert.noNullElements(collection, "Collection must contain non-null elements");
+ * @param collection the collection to check + * @param message the exception message to use if the assertion fails + * @throws IllegalArgumentException if the collection contains a {@code null} element + * @since 5.2 + */ + public static void noNullElements(@Nullable Collection collection, String message) { + if (collection != null) { + for (Object element : collection) { + if (element == null) { + throw new IllegalArgumentException(message); + } + } + } + } + + /** + * Assert that a collection contains no {@code null} elements. + *

Note: Does not complain if the collection is empty! + *

+	 * Assert.noNullElements(collection, () -> "Collection " + collectionName + " must contain non-null elements");
+	 * 
+ * @param collection the collection to check + * @param messageSupplier a supplier for the exception message to use if the + * assertion fails + * @throws IllegalArgumentException if the collection contains a {@code null} element + * @since 5.2 + */ + public static void noNullElements(@Nullable Collection collection, Supplier messageSupplier) { + if (collection != null) { + for (Object element : collection) { + if (element == null) { + throw new IllegalArgumentException(nullSafeGet(messageSupplier)); + } + } + } + } + /** * Assert that a Map contains entries; that is, it must not be {@code null} * and must contain at least one entry. diff --git a/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java b/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java index f13766e510..3e640b2c9a 100644 --- a/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java +++ b/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ import org.springframework.util.Assert; * to convert the response into a type {@code T}. * * @author Arjen Poutsma + * @author Sam Brannen * @since 3.0 * @param the data type * @see RestTemplate @@ -73,6 +74,7 @@ public class HttpMessageConverterExtractor implements ResponseExtractor { HttpMessageConverterExtractor(Type responseType, List> messageConverters, Log logger) { Assert.notNull(responseType, "'responseType' must not be null"); Assert.notEmpty(messageConverters, "'messageConverters' must not be empty"); + Assert.noNullElements(messageConverters, "'messageConverters' must not contain null elements"); this.responseType = responseType; this.responseClass = (responseType instanceof Class ? (Class) responseType : null); this.messageConverters = messageConverters; diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index a1444b05d9..9ccd5e4dd7 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -84,6 +84,7 @@ import org.springframework.web.util.UriTemplateHandler; * @author Brian Clozel * @author Roy Clarkson * @author Juergen Hoeller + * @author Sam Brannen * @since 3.0 * @see HttpMessageConverter * @see RequestCallback @@ -198,11 +199,12 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat * @since 3.2.7 */ public RestTemplate(List> messageConverters) { - Assert.notEmpty(messageConverters, "At least one HttpMessageConverter required"); + validateConverters(messageConverters); this.messageConverters.addAll(messageConverters); this.uriTemplateHandler = initUriTemplateHandler(); } + private static DefaultUriBuilderFactory initUriTemplateHandler() { DefaultUriBuilderFactory uriFactory = new DefaultUriBuilderFactory(); uriFactory.setEncodingMode(EncodingMode.URI_COMPONENT); // for backwards compatibility.. @@ -215,7 +217,7 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat *

These converters are used to convert from and to HTTP requests and responses. */ public void setMessageConverters(List> messageConverters) { - Assert.notEmpty(messageConverters, "At least one HttpMessageConverter required"); + validateConverters(messageConverters); // Take getMessageConverters() List as-is when passed in here if (this.messageConverters != messageConverters) { this.messageConverters.clear(); @@ -223,6 +225,11 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat } } + private void validateConverters(List> messageConverters) { + Assert.notEmpty(messageConverters, "At least one HttpMessageConverter is required"); + Assert.noNullElements(messageConverters, "The HttpMessageConverter list must not contain null elements"); + } + /** * Return the list of message body converters. *

The returned {@link List} is active and may get appended to. diff --git a/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java b/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java index 3a94179772..45329d15f5 100644 --- a/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java @@ -19,7 +19,7 @@ package org.springframework.web.client; import java.io.ByteArrayInputStream; import java.io.IOException; import java.lang.reflect.Type; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.junit.Test; @@ -34,8 +34,10 @@ import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageNotReadableException; +import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; @@ -46,18 +48,30 @@ import static org.mockito.Mockito.mock; * * @author Arjen Poutsma * @author Brian Clozel + * @author Sam Brannen */ public class HttpMessageConverterExtractorTests { - private HttpMessageConverterExtractor extractor; - + @SuppressWarnings("unchecked") + private final HttpMessageConverter converter = mock(HttpMessageConverter.class); + private final HttpMessageConverterExtractor extractor = new HttpMessageConverterExtractor<>(String.class, asList(converter)); + private final MediaType contentType = MediaType.TEXT_PLAIN; + private final HttpHeaders responseHeaders = new HttpHeaders(); private final ClientHttpResponse response = mock(ClientHttpResponse.class); + @Test + public void constructorPreconditions() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new HttpMessageConverterExtractor<>(String.class, (List>) null)) + .withMessage("'messageConverters' must not be empty"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new HttpMessageConverterExtractor<>(String.class, Arrays.asList(null, this.converter))) + .withMessage("'messageConverters' must not contain null elements"); + } + @Test public void noContent() throws IOException { - HttpMessageConverter converter = mock(HttpMessageConverter.class); - extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); given(response.getRawStatusCode()).willReturn(HttpStatus.NO_CONTENT.value()); Object result = extractor.extractData(response); @@ -66,8 +80,6 @@ public class HttpMessageConverterExtractorTests { @Test public void notModified() throws IOException { - HttpMessageConverter converter = mock(HttpMessageConverter.class); - extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_MODIFIED.value()); Object result = extractor.extractData(response); @@ -76,8 +88,6 @@ public class HttpMessageConverterExtractorTests { @Test public void informational() throws IOException { - HttpMessageConverter converter = mock(HttpMessageConverter.class); - extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); given(response.getRawStatusCode()).willReturn(HttpStatus.CONTINUE.value()); Object result = extractor.extractData(response); @@ -86,10 +96,7 @@ public class HttpMessageConverterExtractorTests { @Test public void zeroContentLength() throws IOException { - HttpMessageConverter converter = mock(HttpMessageConverter.class); - HttpHeaders responseHeaders = new HttpHeaders(); responseHeaders.setContentLength(0); - extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); @@ -98,11 +105,7 @@ public class HttpMessageConverterExtractorTests { } @Test - @SuppressWarnings("unchecked") public void emptyMessageBody() throws IOException { - HttpMessageConverter converter = mock(HttpMessageConverter.class); - HttpHeaders responseHeaders = new HttpHeaders(); - extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream("".getBytes())); @@ -112,11 +115,7 @@ public class HttpMessageConverterExtractorTests { } @Test // gh-22265 - @SuppressWarnings("unchecked") public void nullMessageBody() throws IOException { - HttpMessageConverter converter = mock(HttpMessageConverter.class); - HttpHeaders responseHeaders = new HttpHeaders(); - extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(null); @@ -126,14 +125,9 @@ public class HttpMessageConverterExtractorTests { } @Test - @SuppressWarnings("unchecked") public void normal() throws IOException { - HttpMessageConverter converter = mock(HttpMessageConverter.class); - HttpHeaders responseHeaders = new HttpHeaders(); - MediaType contentType = MediaType.TEXT_PLAIN; responseHeaders.setContentType(contentType); String expected = "Foo"; - extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream(expected.getBytes())); @@ -145,32 +139,26 @@ public class HttpMessageConverterExtractorTests { } @Test - @SuppressWarnings("unchecked") public void cannotRead() throws IOException { - HttpMessageConverter converter = mock(HttpMessageConverter.class); - HttpHeaders responseHeaders = new HttpHeaders(); - MediaType contentType = MediaType.TEXT_PLAIN; responseHeaders.setContentType(contentType); - extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream("Foobar".getBytes())); given(converter.canRead(String.class, contentType)).willReturn(false); - assertThatExceptionOfType(RestClientException.class).isThrownBy(() -> - extractor.extractData(response)); + assertThatExceptionOfType(RestClientException.class).isThrownBy(() -> extractor.extractData(response)); } @Test @SuppressWarnings("unchecked") public void generics() throws IOException { - GenericHttpMessageConverter converter = mock(GenericHttpMessageConverter.class); - HttpHeaders responseHeaders = new HttpHeaders(); - MediaType contentType = MediaType.TEXT_PLAIN; responseHeaders.setContentType(contentType); String expected = "Foo"; ParameterizedTypeReference> reference = new ParameterizedTypeReference>() {}; Type type = reference.getType(); - extractor = new HttpMessageConverterExtractor>(type, createConverterList(converter)); + + GenericHttpMessageConverter converter = mock(GenericHttpMessageConverter.class); + HttpMessageConverterExtractor extractor = new HttpMessageConverterExtractor>(type, asList(converter)); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream(expected.getBytes())); @@ -182,48 +170,28 @@ public class HttpMessageConverterExtractorTests { } @Test // SPR-13592 - @SuppressWarnings("unchecked") public void converterThrowsIOException() throws IOException { - HttpMessageConverter converter = mock(HttpMessageConverter.class); - HttpHeaders responseHeaders = new HttpHeaders(); - MediaType contentType = MediaType.TEXT_PLAIN; responseHeaders.setContentType(contentType); - extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream("Foobar".getBytes())); given(converter.canRead(String.class, contentType)).willReturn(true); given(converter.read(eq(String.class), any(HttpInputMessage.class))).willThrow(IOException.class); - assertThatExceptionOfType(RestClientException.class).isThrownBy(() -> - extractor.extractData(response)) + assertThatExceptionOfType(RestClientException.class).isThrownBy(() -> extractor.extractData(response)) .withMessageContaining("Error while extracting response for type [class java.lang.String] and content type [text/plain]") .withCauseInstanceOf(IOException.class); } @Test // SPR-13592 - @SuppressWarnings("unchecked") public void converterThrowsHttpMessageNotReadableException() throws IOException { - HttpMessageConverter converter = mock(HttpMessageConverter.class); - HttpHeaders responseHeaders = new HttpHeaders(); - MediaType contentType = MediaType.TEXT_PLAIN; responseHeaders.setContentType(contentType); - extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream("Foobar".getBytes())); given(converter.canRead(String.class, contentType)).willThrow(HttpMessageNotReadableException.class); - assertThatExceptionOfType(RestClientException.class).isThrownBy(() -> - extractor.extractData(response)) + assertThatExceptionOfType(RestClientException.class).isThrownBy(() -> extractor.extractData(response)) .withMessageContaining("Error while extracting response for type [class java.lang.String] and content type [text/plain]") .withCauseInstanceOf(HttpMessageNotReadableException.class); - - } - - private List> createConverterList(HttpMessageConverter converter) { - List> converters = new ArrayList<>(1); - converters.add(converter); - return converters; } - } diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java index bc8a3a0b41..62b1c2e1e4 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -49,6 +49,7 @@ import org.springframework.web.util.DefaultUriBuilderFactory; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; @@ -65,9 +66,12 @@ import static org.springframework.http.HttpMethod.PUT; import static org.springframework.http.MediaType.parseMediaType; /** + * Unit tests for {@link RestTemplate}. + * * @author Arjen Poutsma * @author Rossen Stoyanchev * @author Brian Clozel + * @author Sam Brannen */ @SuppressWarnings("unchecked") public class RestTemplateTests { @@ -98,6 +102,25 @@ public class RestTemplateTests { template.setErrorHandler(errorHandler); } + @Test + public void constructorPreconditions() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new RestTemplate((List>) null)) + .withMessage("At least one HttpMessageConverter is required"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new RestTemplate(Arrays.asList(null, this.converter))) + .withMessage("The HttpMessageConverter list must not contain null elements"); + } + + @Test + public void setMessageConvertersPreconditions() { + assertThatIllegalArgumentException() + .isThrownBy(() -> template.setMessageConverters((List>) null)) + .withMessage("At least one HttpMessageConverter is required"); + assertThatIllegalArgumentException() + .isThrownBy(() -> template.setMessageConverters(Arrays.asList(null, this.converter))) + .withMessage("The HttpMessageConverter list must not contain null elements"); + } @Test public void varArgsTemplateVariables() throws Exception {