diff --git a/spring-web/src/main/java/org/springframework/http/HttpStatus.java b/spring-web/src/main/java/org/springframework/http/HttpStatus.java index 780376406c..4cf7b7d667 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpStatus.java +++ b/spring-web/src/main/java/org/springframework/http/HttpStatus.java @@ -16,6 +16,8 @@ package org.springframework.http; +import org.springframework.lang.Nullable; + /** * Enumeration of HTTP status codes. * @@ -499,12 +501,28 @@ public enum HttpStatus { * @throws IllegalArgumentException if this enum has no constant for the specified numeric value */ public static HttpStatus valueOf(int statusCode) { + HttpStatus status = resolve(statusCode); + if (status == null) { + throw new IllegalArgumentException("No matching constant for [" + statusCode + "]"); + } + return status; + } + + + /** + * Resolve the given status code to an {@code HttpStatus}, if possible. + * @param statusCode the HTTP status code (potentially non-standard) + * @return the corresponding {@code HttpStatus}, or {@code null} if not found + * @since 5.0 + */ + @Nullable + public static HttpStatus resolve(int statusCode) { for (HttpStatus status : values()) { if (status.value == statusCode) { return status; } } - throw new IllegalArgumentException("No matching constant for [" + statusCode + "]"); + return null; } diff --git a/spring-web/src/main/java/org/springframework/http/client/ClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/ClientHttpResponse.java index 05fa7ddd58..921e367406 100644 --- a/spring-web/src/main/java/org/springframework/http/client/ClientHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/client/ClientHttpResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,13 +38,19 @@ public interface ClientHttpResponse extends HttpInputMessage, Closeable { * Return the HTTP status code of the response. * @return the HTTP status as an HttpStatus enum value * @throws IOException in case of I/O errors + * @throws IllegalArgumentException in case of an unknown HTTP status code + * @see HttpStatus#valueOf(int) */ HttpStatus getStatusCode() throws IOException; /** - * Return the HTTP status code of the response as integer + * Return the HTTP status code (potentially non-standard and not + * resolvable through the {@link HttpStatus} enum) as an integer. * @return the HTTP status as an integer * @throws IOException in case of I/O errors + * @since 3.1.1 + * @see #getStatusCode() + * @see HttpStatus#resolve(int) */ int getRawStatusCode() throws IOException; diff --git a/spring-web/src/main/java/org/springframework/web/client/AsyncRestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/AsyncRestTemplate.java index 79188ce8e2..e3d1da5703 100644 --- a/spring-web/src/main/java/org/springframework/web/client/AsyncRestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/AsyncRestTemplate.java @@ -523,7 +523,7 @@ public class AsyncRestTemplate extends org.springframework.http.client.support.I if (logger.isDebugEnabled()) { try { logger.debug("Async " + method.name() + " request for \"" + url + "\" resulted in " + - response.getStatusCode() + " (" + response.getStatusText() + ")"); + response.getRawStatusCode() + " (" + response.getStatusText() + ")"); } catch (IOException ex) { // ignore @@ -535,7 +535,7 @@ public class AsyncRestTemplate extends org.springframework.http.client.support.I if (logger.isWarnEnabled()) { try { logger.warn("Async " + method.name() + " request for \"" + url + "\" resulted in " + - response.getStatusCode() + " (" + response.getStatusText() + "); invoking error handler"); + response.getRawStatusCode() + " (" + response.getStatusText() + "); invoking error handler"); } catch (IOException ex) { // ignore diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java b/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java index b5dfec7dcb..9c2097f3ab 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultResponseErrorHandler.java @@ -46,18 +46,46 @@ public class DefaultResponseErrorHandler implements ResponseErrorHandler { */ @Override public boolean hasError(ClientHttpResponse response) throws IOException { - return hasError(getHttpStatusCode(response)); + HttpStatus statusCode = HttpStatus.resolve(response.getRawStatusCode()); + return (statusCode != null && hasError(statusCode)); } /** - * This default implementation throws a {@link HttpClientErrorException} if the response status code + * Template method called from {@link #hasError(ClientHttpResponse)}. + *

The default implementation checks if the given status code is + * {@link HttpStatus.Series#CLIENT_ERROR CLIENT_ERROR} or + * {@link HttpStatus.Series#SERVER_ERROR SERVER_ERROR}. + * Can be overridden in subclasses. + * @param statusCode the HTTP status code + * @return {@code true} if the response has an error; {@code false} otherwise + */ + protected boolean hasError(HttpStatus statusCode) { + return (statusCode.series() == HttpStatus.Series.CLIENT_ERROR || + statusCode.series() == HttpStatus.Series.SERVER_ERROR); + } + + /** + * Delegates to {@link #handleError(ClientHttpResponse, HttpStatus)} with the response status code. + */ + @Override + public void handleError(ClientHttpResponse response) throws IOException { + HttpStatus statusCode = HttpStatus.resolve(response.getRawStatusCode()); + if (statusCode == null) { + throw new UnknownHttpStatusCodeException(response.getRawStatusCode(), response.getStatusText(), + response.getHeaders(), getResponseBody(response), getCharset(response)); + } + handleError(response, statusCode); + } + + /** + * Handle the error in the given response with the given resolved status code. + *

This default implementation throws a {@link HttpClientErrorException} if the response status code * is {@link org.springframework.http.HttpStatus.Series#CLIENT_ERROR}, a {@link HttpServerErrorException} * if it is {@link org.springframework.http.HttpStatus.Series#SERVER_ERROR}, * and a {@link RestClientException} in other cases. + * @since 5.0 */ - @Override - public void handleError(ClientHttpResponse response) throws IOException { - HttpStatus statusCode = getHttpStatusCode(response); + protected void handleError(ClientHttpResponse response, HttpStatus statusCode) throws IOException { switch (statusCode.series()) { case CLIENT_ERROR: throw new HttpClientErrorException(statusCode, response.getStatusText(), @@ -66,7 +94,8 @@ public class DefaultResponseErrorHandler implements ResponseErrorHandler { throw new HttpServerErrorException(statusCode, response.getStatusText(), response.getHeaders(), getResponseBody(response), getCharset(response)); default: - throw new RestClientException("Unknown status code [" + statusCode + "]"); + throw new UnknownHttpStatusCodeException(statusCode.value(), response.getStatusText(), + response.getHeaders(), getResponseBody(response), getCharset(response)); } } @@ -79,43 +108,16 @@ public class DefaultResponseErrorHandler implements ResponseErrorHandler { * @throws UnknownHttpStatusCodeException in case of an unknown status code * that cannot be represented with the {@link HttpStatus} enum * @since 4.3.8 + * @deprecated as of 5.0, in favor of {@link #handleError(ClientHttpResponse, HttpStatus)} */ + @Deprecated protected HttpStatus getHttpStatusCode(ClientHttpResponse response) throws IOException { - try { - return response.getStatusCode(); - } - catch (IllegalArgumentException ex) { + HttpStatus statusCode = HttpStatus.resolve(response.getRawStatusCode()); + if (statusCode == null) { throw new UnknownHttpStatusCodeException(response.getRawStatusCode(), response.getStatusText(), response.getHeaders(), getResponseBody(response), getCharset(response)); } - } - - /** - * Template method called from {@link #hasError(ClientHttpResponse)}. - *

The default implementation checks if the given status code is - * {@link org.springframework.http.HttpStatus.Series#CLIENT_ERROR CLIENT_ERROR} - * or {@link org.springframework.http.HttpStatus.Series#SERVER_ERROR SERVER_ERROR}. - * Can be overridden in subclasses. - * @param statusCode the HTTP status code - * @return {@code true} if the response has an error; {@code false} otherwise - * @see #getHttpStatusCode(ClientHttpResponse) - */ - protected boolean hasError(HttpStatus statusCode) { - return (statusCode.series() == HttpStatus.Series.CLIENT_ERROR || - statusCode.series() == HttpStatus.Series.SERVER_ERROR); - } - - /** - * Determine the charset of the response (for inclusion in a status exception). - * @param response the response to inspect - * @return the associated charset, or {@code null} if none - * @since 4.3.8 - */ - @Nullable - protected Charset getCharset(ClientHttpResponse response) { - HttpHeaders headers = response.getHeaders(); - MediaType contentType = headers.getContentType(); - return (contentType != null ? contentType.getCharset() : null); + return statusCode; } /** @@ -135,4 +137,17 @@ public class DefaultResponseErrorHandler implements ResponseErrorHandler { return new byte[0]; } + /** + * Determine the charset of the response (for inclusion in a status exception). + * @param response the response to inspect + * @return the associated charset, or {@code null} if none + * @since 4.3.8 + */ + @Nullable + protected Charset getCharset(ClientHttpResponse response) { + HttpHeaders headers = response.getHeaders(); + MediaType contentType = headers.getContentType(); + return (contentType != null ? contentType.getCharset() : null); + } + } diff --git a/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java b/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java index f26e32aa91..73ab71a5b5 100644 --- a/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java +++ b/spring-web/src/main/java/org/springframework/web/client/ExtractingResponseErrorHandler.java @@ -132,8 +132,7 @@ public class ExtractingResponseErrorHandler extends DefaultResponseErrorHandler } @Override - public void handleError(ClientHttpResponse response) throws IOException { - HttpStatus statusCode = getHttpStatusCode(response); + public void handleError(ClientHttpResponse response, HttpStatus statusCode) throws IOException { if (this.statusMapping.containsKey(statusCode)) { extract(this.statusMapping.get(statusCode), response); } @@ -141,7 +140,7 @@ public class ExtractingResponseErrorHandler extends DefaultResponseErrorHandler extract(this.seriesMapping.get(statusCode.series()), response); } else { - super.handleError(response); + super.handleError(response, statusCode); } } diff --git a/spring-web/src/main/java/org/springframework/web/client/MessageBodyClientHttpResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/client/MessageBodyClientHttpResponseWrapper.java index 9ebe5651cb..5864863e4e 100644 --- a/spring-web/src/main/java/org/springframework/web/client/MessageBodyClientHttpResponseWrapper.java +++ b/spring-web/src/main/java/org/springframework/web/client/MessageBodyClientHttpResponseWrapper.java @@ -58,12 +58,12 @@ class MessageBodyClientHttpResponseWrapper implements ClientHttpResponse { * @throws IOException in case of I/O errors */ public boolean hasMessageBody() throws IOException { - HttpStatus responseStatus = this.getStatusCode(); - if (responseStatus.is1xxInformational() || responseStatus == HttpStatus.NO_CONTENT || - responseStatus == HttpStatus.NOT_MODIFIED) { + HttpStatus status = HttpStatus.resolve(getRawStatusCode()); + if (status != null && (status.is1xxInformational() || status == HttpStatus.NO_CONTENT || + status == HttpStatus.NOT_MODIFIED)) { return false; } - else if (getHeaders().getContentLength() == 0) { + if (getHeaders().getContentLength() == 0) { return false; } return true; diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index 44ccdd0819..171b441489 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -969,10 +969,10 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat public ResponseEntity extractData(ClientHttpResponse response) throws IOException { if (this.delegate != null) { T body = this.delegate.extractData(response); - return new ResponseEntity<>(body, response.getHeaders(), response.getStatusCode()); + return ResponseEntity.status(response.getRawStatusCode()).headers(response.getHeaders()).body(body); } else { - return new ResponseEntity<>(response.getHeaders(), response.getStatusCode()); + return ResponseEntity.status(response.getRawStatusCode()).headers(response.getHeaders()).build(); } } } diff --git a/spring-web/src/main/java/org/springframework/web/client/UnknownHttpStatusCodeException.java b/spring-web/src/main/java/org/springframework/web/client/UnknownHttpStatusCodeException.java index 9b01977052..0eb4f37f37 100644 --- a/spring-web/src/main/java/org/springframework/web/client/UnknownHttpStatusCodeException.java +++ b/spring-web/src/main/java/org/springframework/web/client/UnknownHttpStatusCodeException.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,12 +38,12 @@ public class UnknownHttpStatusCodeException extends RestClientResponseException * {@link HttpStatus}, status text, and response body content. * @param rawStatusCode the raw status code value * @param statusText the status text - * @param responseHeaders the response headers, may be {@code null} - * @param responseBody the response body content, may be {@code null} - * @param responseCharset the response body charset, may be {@code null} + * @param responseHeaders the response headers (may be {@code null}) + * @param responseBody the response body content (may be {@code null}) + * @param responseCharset the response body charset (may be {@code null}) */ - public UnknownHttpStatusCodeException(int rawStatusCode, String statusText, - @Nullable HttpHeaders responseHeaders, @Nullable byte[] responseBody,@Nullable Charset responseCharset) { + public UnknownHttpStatusCodeException(int rawStatusCode, String statusText, @Nullable HttpHeaders responseHeaders, + @Nullable byte[] responseBody, @Nullable Charset responseCharset) { super("Unknown status code [" + String.valueOf(rawStatusCode) + "]" + " " + statusText, rawStatusCode, statusText, responseHeaders, responseBody, responseCharset); diff --git a/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java b/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java index cf1be93a38..40bb6981c6 100644 --- a/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/DefaultResponseErrorHandlerTests.java @@ -43,13 +43,13 @@ public class DefaultResponseErrorHandlerTests { @Test public void hasErrorTrue() throws Exception { - given(response.getStatusCode()).willReturn(HttpStatus.NOT_FOUND); + given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); assertTrue(handler.hasError(response)); } @Test public void hasErrorFalse() throws Exception { - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); assertFalse(handler.hasError(response)); } @@ -58,7 +58,7 @@ public class DefaultResponseErrorHandlerTests { HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.TEXT_PLAIN); - given(response.getStatusCode()).willReturn(HttpStatus.NOT_FOUND); + given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); given(response.getStatusText()).willReturn("Not Found"); given(response.getHeaders()).willReturn(headers); given(response.getBody()).willReturn(new ByteArrayInputStream("Hello World".getBytes("UTF-8"))); @@ -77,7 +77,7 @@ public class DefaultResponseErrorHandlerTests { HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.TEXT_PLAIN); - given(response.getStatusCode()).willReturn(HttpStatus.NOT_FOUND); + given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); given(response.getStatusText()).willReturn("Not Found"); given(response.getHeaders()).willReturn(headers); given(response.getBody()).willThrow(new IOException()); @@ -90,7 +90,7 @@ public class DefaultResponseErrorHandlerTests { HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.TEXT_PLAIN); - given(response.getStatusCode()).willReturn(HttpStatus.NOT_FOUND); + given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); given(response.getStatusText()).willReturn("Not Found"); given(response.getHeaders()).willReturn(headers); @@ -102,7 +102,6 @@ public class DefaultResponseErrorHandlerTests { HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.TEXT_PLAIN); - given(response.getStatusCode()).willThrow(new IllegalArgumentException("No matching constant for 999")); given(response.getRawStatusCode()).willReturn(999); given(response.getStatusText()).willReturn("Custom status code"); given(response.getHeaders()).willReturn(headers); diff --git a/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java b/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java index bc7250f093..9266026512 100644 --- a/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/ExtractingResponseErrorHandlerTests.java @@ -58,13 +58,13 @@ public class ExtractingResponseErrorHandlerTests { @Test public void hasError() throws Exception { - given(this.response.getStatusCode()).willReturn(HttpStatus.I_AM_A_TEAPOT); + given(this.response.getRawStatusCode()).willReturn(HttpStatus.I_AM_A_TEAPOT.value()); assertTrue(this.errorHandler.hasError(this.response)); - given(this.response.getStatusCode()).willReturn(HttpStatus.INTERNAL_SERVER_ERROR); + given(this.response.getRawStatusCode()).willReturn(HttpStatus.INTERNAL_SERVER_ERROR.value()); assertTrue(this.errorHandler.hasError(this.response)); - given(this.response.getStatusCode()).willReturn(HttpStatus.OK); + given(this.response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); assertFalse(this.errorHandler.hasError(this.response)); } @@ -73,19 +73,19 @@ public class ExtractingResponseErrorHandlerTests { this.errorHandler.setSeriesMapping(Collections .singletonMap(HttpStatus.Series.CLIENT_ERROR, null)); - given(this.response.getStatusCode()).willReturn(HttpStatus.I_AM_A_TEAPOT); + given(this.response.getRawStatusCode()).willReturn(HttpStatus.I_AM_A_TEAPOT.value()); assertTrue(this.errorHandler.hasError(this.response)); - given(this.response.getStatusCode()).willReturn(HttpStatus.NOT_FOUND); + given(this.response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); assertFalse(this.errorHandler.hasError(this.response)); - given(this.response.getStatusCode()).willReturn(HttpStatus.OK); + given(this.response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); assertFalse(this.errorHandler.hasError(this.response)); } @Test public void handleErrorStatusMatch() throws Exception { - given(this.response.getStatusCode()).willReturn(HttpStatus.I_AM_A_TEAPOT); + given(this.response.getRawStatusCode()).willReturn(HttpStatus.I_AM_A_TEAPOT.value()); HttpHeaders responseHeaders = new HttpHeaders(); responseHeaders.setContentType(MediaType.APPLICATION_JSON); given(this.response.getHeaders()).willReturn(responseHeaders); @@ -105,7 +105,7 @@ public class ExtractingResponseErrorHandlerTests { @Test public void handleErrorSeriesMatch() throws Exception { - given(this.response.getStatusCode()).willReturn(HttpStatus.INTERNAL_SERVER_ERROR); + given(this.response.getRawStatusCode()).willReturn(HttpStatus.INTERNAL_SERVER_ERROR.value()); HttpHeaders responseHeaders = new HttpHeaders(); responseHeaders.setContentType(MediaType.APPLICATION_JSON); given(this.response.getHeaders()).willReturn(responseHeaders); @@ -125,7 +125,7 @@ public class ExtractingResponseErrorHandlerTests { @Test public void handleNoMatch() throws Exception { - given(this.response.getStatusCode()).willReturn(HttpStatus.NOT_FOUND); + given(this.response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); HttpHeaders responseHeaders = new HttpHeaders(); responseHeaders.setContentType(MediaType.APPLICATION_JSON); given(this.response.getHeaders()).willReturn(responseHeaders); @@ -149,7 +149,7 @@ public class ExtractingResponseErrorHandlerTests { this.errorHandler.setSeriesMapping(Collections .singletonMap(HttpStatus.Series.CLIENT_ERROR, null)); - given(this.response.getStatusCode()).willReturn(HttpStatus.NOT_FOUND); + given(this.response.getRawStatusCode()).willReturn(HttpStatus.NOT_FOUND.value()); HttpHeaders responseHeaders = new HttpHeaders(); responseHeaders.setContentType(MediaType.APPLICATION_JSON); given(this.response.getHeaders()).willReturn(responseHeaders); diff --git a/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java b/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java index eafa4cd2e7..01eadffef9 100644 --- a/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ import java.util.ArrayList; import java.util.List; import org.hamcrest.Matchers; -import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -51,24 +50,19 @@ public class HttpMessageConverterExtractorTests { private HttpMessageConverterExtractor extractor; - private ClientHttpResponse response; + private final ClientHttpResponse response = mock(ClientHttpResponse.class); @Rule public final ExpectedException exception = ExpectedException.none(); - @Before - public void createMocks() { - response = mock(ClientHttpResponse.class); - } @Test public void noContent() throws IOException { HttpMessageConverter converter = mock(HttpMessageConverter.class); extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); - given(response.getStatusCode()).willReturn(HttpStatus.NO_CONTENT); + given(response.getRawStatusCode()).willReturn(HttpStatus.NO_CONTENT.value()); Object result = extractor.extractData(response); - assertNull(result); } @@ -76,10 +70,9 @@ public class HttpMessageConverterExtractorTests { public void notModified() throws IOException { HttpMessageConverter converter = mock(HttpMessageConverter.class); extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); - given(response.getStatusCode()).willReturn(HttpStatus.NOT_MODIFIED); + given(response.getRawStatusCode()).willReturn(HttpStatus.NOT_MODIFIED.value()); Object result = extractor.extractData(response); - assertNull(result); } @@ -87,10 +80,9 @@ public class HttpMessageConverterExtractorTests { public void informational() throws IOException { HttpMessageConverter converter = mock(HttpMessageConverter.class); extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); - given(response.getStatusCode()).willReturn(HttpStatus.CONTINUE); + given(response.getRawStatusCode()).willReturn(HttpStatus.CONTINUE.value()); Object result = extractor.extractData(response); - assertNull(result); } @@ -100,11 +92,10 @@ public class HttpMessageConverterExtractorTests { HttpHeaders responseHeaders = new HttpHeaders(); responseHeaders.setContentLength(0); extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); Object result = extractor.extractData(response); - assertNull(result); } @@ -114,7 +105,7 @@ public class HttpMessageConverterExtractorTests { HttpMessageConverter converter = mock(HttpMessageConverter.class); HttpHeaders responseHeaders = new HttpHeaders(); extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream("".getBytes())); @@ -131,14 +122,13 @@ public class HttpMessageConverterExtractorTests { responseHeaders.setContentType(contentType); String expected = "Foo"; extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream(expected.getBytes())); given(converter.canRead(String.class, contentType)).willReturn(true); given(converter.read(eq(String.class), any(HttpInputMessage.class))).willReturn(expected); Object result = extractor.extractData(response); - assertEquals(expected, result); } @@ -150,7 +140,7 @@ public class HttpMessageConverterExtractorTests { MediaType contentType = MediaType.TEXT_PLAIN; responseHeaders.setContentType(contentType); extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream("Foobar".getBytes())); given(converter.canRead(String.class, contentType)).willReturn(false); @@ -170,7 +160,7 @@ public class HttpMessageConverterExtractorTests { ParameterizedTypeReference> reference = new ParameterizedTypeReference>() {}; Type type = reference.getType(); extractor = new HttpMessageConverterExtractor>(type, createConverterList(converter)); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream(expected.getBytes())); given(converter.canRead(type, null, contentType)).willReturn(true); @@ -181,7 +171,7 @@ public class HttpMessageConverterExtractorTests { assertEquals(expected, result); } - @Test // SPR-13592 + @Test // SPR-13592 @SuppressWarnings("unchecked") public void converterThrowsIOException() throws IOException { HttpMessageConverter converter = mock(HttpMessageConverter.class); @@ -189,7 +179,7 @@ public class HttpMessageConverterExtractorTests { MediaType contentType = MediaType.TEXT_PLAIN; responseHeaders.setContentType(contentType); extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream("Foobar".getBytes())); given(converter.canRead(String.class, contentType)).willThrow(IOException.class); @@ -201,7 +191,7 @@ public class HttpMessageConverterExtractorTests { extractor.extractData(response); } - @Test // SPR-13592 + @Test // SPR-13592 @SuppressWarnings("unchecked") public void converterThrowsHttpMessageNotReadableException() throws IOException { HttpMessageConverter converter = mock(HttpMessageConverter.class); @@ -209,7 +199,7 @@ public class HttpMessageConverterExtractorTests { MediaType contentType = MediaType.TEXT_PLAIN; responseHeaders.setContentType(contentType); extractor = new HttpMessageConverterExtractor<>(String.class, createConverterList(converter)); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream("Foobar".getBytes())); given(converter.canRead(String.class, contentType)).willThrow(HttpMessageNotReadableException.class); @@ -221,8 +211,7 @@ public class HttpMessageConverterExtractorTests { extractor.extractData(response); } - private List> createConverterList( - HttpMessageConverter converter) { + private List> createConverterList(HttpMessageConverter converter) { List> converters = new ArrayList<>(1); converters.add(converter); return converters; diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java index 6508cacdb7..a96e677106 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -45,18 +45,9 @@ import org.springframework.http.converter.HttpMessageConverter; import org.springframework.util.StreamUtils; import org.springframework.web.util.DefaultUriBuilderFactory; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.fail; -import static org.mockito.BDDMockito.any; -import static org.mockito.BDDMockito.eq; -import static org.mockito.BDDMockito.given; -import static org.mockito.BDDMockito.mock; -import static org.mockito.BDDMockito.verify; -import static org.mockito.BDDMockito.willThrow; -import static org.springframework.http.MediaType.parseMediaType; +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; +import static org.springframework.http.MediaType.*; /** * @author Arjen Poutsma @@ -272,15 +263,12 @@ public class RestTemplateTests { HttpHeaders responseHeaders = new HttpHeaders(); responseHeaders.setContentType(textPlain); responseHeaders.setContentLength(10); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getStatusText()).willReturn(HttpStatus.OK.getReasonPhrase()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream(expected.getBytes())); given(converter.canRead(String.class, textPlain)).willReturn(true); given(converter.read(eq(String.class), any(HttpInputMessage.class))).willReturn(expected); - given(response.getStatusCode()).willReturn(HttpStatus.OK); - HttpStatus status = HttpStatus.OK; - given(response.getStatusCode()).willReturn(status); - given(response.getStatusText()).willReturn(status.getReasonPhrase()); ResponseEntity result = template.getForEntity("http://example.com", String.class); assertEquals("Invalid GET result", expected, result.getBody()); @@ -505,15 +493,12 @@ public class RestTemplateTests { HttpHeaders responseHeaders = new HttpHeaders(); responseHeaders.setContentType(textPlain); responseHeaders.setContentLength(10); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getStatusText()).willReturn(HttpStatus.OK.getReasonPhrase()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream(expected.toString().getBytes())); given(converter.canRead(Integer.class, textPlain)).willReturn(true); given(converter.read(eq(Integer.class), any(HttpInputMessage.class))).willReturn(expected); - given(response.getStatusCode()).willReturn(HttpStatus.OK); - HttpStatus status = HttpStatus.OK; - given(response.getStatusCode()).willReturn(status); - given(response.getStatusText()).willReturn(status.getReasonPhrase()); ResponseEntity result = template.postForEntity("http://example.com", request, Integer.class); assertEquals("Invalid POST result", expected, result.getBody()); @@ -567,14 +552,11 @@ public class RestTemplateTests { responseHeaders.setContentType(textPlain); responseHeaders.setContentLength(10); given(response.getHeaders()).willReturn(responseHeaders); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getStatusText()).willReturn(HttpStatus.OK.getReasonPhrase()); given(response.getBody()).willReturn(StreamUtils.emptyInput()); given(converter.canRead(Integer.class, textPlain)).willReturn(true); given(converter.read(Integer.class, response)).willReturn(null); - given(response.getStatusCode()).willReturn(HttpStatus.OK); - HttpStatus status = HttpStatus.OK; - given(response.getStatusCode()).willReturn(status); - given(response.getStatusText()).willReturn(status.getReasonPhrase()); ResponseEntity result = template.postForEntity("http://example.com", null, Integer.class); assertFalse("Invalid POST result", result.hasBody()); @@ -777,16 +759,13 @@ public class RestTemplateTests { HttpHeaders responseHeaders = new HttpHeaders(); responseHeaders.setContentType(MediaType.TEXT_PLAIN); responseHeaders.setContentLength(10); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getStatusText()).willReturn(HttpStatus.OK.getReasonPhrase()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream(expected.toString().getBytes())); given(converter.canRead(Integer.class, MediaType.TEXT_PLAIN)).willReturn(true); given(converter.read(Integer.class, response)).willReturn(expected); given(converter.read(eq(Integer.class), any(HttpInputMessage.class))).willReturn(expected); - given(response.getStatusCode()).willReturn(HttpStatus.OK); - HttpStatus status = HttpStatus.OK; - given(response.getStatusCode()).willReturn(status); - given(response.getStatusText()).willReturn(status.getReasonPhrase()); HttpHeaders entityHeaders = new HttpHeaders(); entityHeaders.set("MyHeader", "MyValue"); @@ -822,15 +801,12 @@ public class RestTemplateTests { HttpHeaders responseHeaders = new HttpHeaders(); responseHeaders.setContentType(MediaType.TEXT_PLAIN); responseHeaders.setContentLength(10); - given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getRawStatusCode()).willReturn(HttpStatus.OK.value()); + given(response.getStatusText()).willReturn(HttpStatus.OK.getReasonPhrase()); given(response.getHeaders()).willReturn(responseHeaders); given(response.getBody()).willReturn(new ByteArrayInputStream(Integer.toString(42).getBytes())); given(converter.canRead(intList.getType(), null, MediaType.TEXT_PLAIN)).willReturn(true); given(converter.read(eq(intList.getType()), eq(null), any(HttpInputMessage.class))).willReturn(expected); - given(response.getStatusCode()).willReturn(HttpStatus.OK); - HttpStatus status = HttpStatus.OK; - given(response.getStatusCode()).willReturn(status); - given(response.getStatusText()).willReturn(status.getReasonPhrase()); HttpHeaders entityHeaders = new HttpHeaders(); entityHeaders.set("MyHeader", "MyValue"); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java index d4ab42b89f..3d21b2c86f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java @@ -38,6 +38,7 @@ import org.springframework.web.client.RequestCallback; import org.springframework.web.client.ResponseExtractor; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; +import org.springframework.web.client.UnknownHttpStatusCodeException; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHandler; @@ -154,7 +155,7 @@ public class RestTemplateXhrTransport extends AbstractXhrTransport { private final static ResponseExtractor> textResponseExtractor = response -> { String body = StreamUtils.copyToString(response.getBody(), SockJsFrame.CHARSET); - return new ResponseEntity<>(body, response.getHeaders(), response.getStatusCode()); + return ResponseEntity.status(response.getRawStatusCode()).headers(response.getHeaders()).body(body); }; @@ -200,14 +201,22 @@ public class RestTemplateXhrTransport extends AbstractXhrTransport { @Override public Object extractData(ClientHttpResponse response) throws IOException { - if (!HttpStatus.OK.equals(response.getStatusCode())) { - throw new HttpServerErrorException(response.getStatusCode()); + HttpStatus httpStatus = HttpStatus.resolve(response.getRawStatusCode()); + if (httpStatus == null) { + throw new UnknownHttpStatusCodeException( + response.getRawStatusCode(), response.getStatusText(), response.getHeaders(), null, null); } + if (httpStatus != HttpStatus.OK) { + throw new HttpServerErrorException( + httpStatus, response.getStatusText(), response.getHeaders(), null, null); + } + if (logger.isTraceEnabled()) { logger.trace("XHR receive headers: " + response.getHeaders()); } InputStream is = response.getBody(); ByteArrayOutputStream os = new ByteArrayOutputStream(); + while (true) { if (this.sockJsSession.isDisconnected()) { if (logger.isDebugEnabled()) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java index 53d6cfebe2..69e7fd888b 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,7 +26,6 @@ import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingDeque; -import org.junit.Before; import org.junit.Test; import org.springframework.core.task.SyncTaskExecutor; @@ -68,13 +67,7 @@ public class RestTemplateXhrTransportTests { private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); - private WebSocketHandler webSocketHandler; - - - @Before - public void setup() throws Exception { - this.webSocketHandler = mock(WebSocketHandler.class); - } + private final WebSocketHandler webSocketHandler = mock(WebSocketHandler.class); @Test @@ -192,7 +185,7 @@ public class RestTemplateXhrTransportTests { private ClientHttpResponse response(HttpStatus status, String body) throws IOException { ClientHttpResponse response = mock(ClientHttpResponse.class); InputStream inputStream = getInputStream(body); - given(response.getStatusCode()).willReturn(status); + given(response.getRawStatusCode()).willReturn(status.value()); given(response.getBody()).willReturn(inputStream); return response; }