diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index ddba73594a..e1ba9f80d3 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -17,6 +17,8 @@ package org.springframework.web.client; import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.net.URI; @@ -185,6 +187,61 @@ final class DefaultRestClient implements RestClient { return new DefaultRestClientBuilder(this.builder); } + @SuppressWarnings({"rawtypes", "unchecked"}) + private T readWithMessageConverters(ClientHttpResponse clientResponse, Runnable callback, Type bodyType, Class bodyClass) { + MediaType contentType = getContentType(clientResponse); + + try (clientResponse) { + callback.run(); + + for (HttpMessageConverter messageConverter : this.messageConverters) { + if (messageConverter instanceof GenericHttpMessageConverter genericHttpMessageConverter) { + if (genericHttpMessageConverter.canRead(bodyType, null, contentType)) { + if (logger.isDebugEnabled()) { + logger.debug("Reading to [" + ResolvableType.forType(bodyType) + "]"); + } + return (T) genericHttpMessageConverter.read(bodyType, null, clientResponse); + } + } + if (messageConverter.canRead(bodyClass, contentType)) { + if (logger.isDebugEnabled()) { + logger.debug("Reading to [" + bodyClass.getName() + "] as \"" + contentType + "\""); + } + return (T) messageConverter.read((Class)bodyClass, clientResponse); + } + } + throw new UnknownContentTypeException(bodyType, contentType, + clientResponse.getStatusCode(), clientResponse.getStatusText(), + clientResponse.getHeaders(), RestClientUtils.getBody(clientResponse)); + } + catch (UncheckedIOException | IOException | HttpMessageNotReadableException ex) { + throw new RestClientException("Error while extracting response for type [" + + ResolvableType.forType(bodyType) + "] and content type [" + contentType + "]", ex); + } + } + + private static MediaType getContentType(ClientHttpResponse clientResponse) { + MediaType contentType = clientResponse.getHeaders().getContentType(); + if (contentType == null) { + contentType = MediaType.APPLICATION_OCTET_STREAM; + } + return contentType; + } + + @SuppressWarnings("unchecked") + private static Class bodyClass(Type type) { + if (type instanceof Class clazz) { + return (Class) clazz; + } + if (type instanceof ParameterizedType parameterizedType && + parameterizedType.getRawType() instanceof Class rawType) { + return (Class) rawType; + } + return (Class) Object.class; + } + + + private class DefaultRequestBodyUriSpec implements RequestBodyUriSpec { @@ -409,7 +466,8 @@ final class DefaultRestClient implements RestClient { } clientResponse = clientRequest.execute(); observationContext.setResponse(clientResponse); - return exchangeFunction.exchange(clientRequest, clientResponse); + ConvertibleClientHttpResponse convertibleWrapper = new DefaultConvertibleClientHttpResponse(clientResponse); + return exchangeFunction.exchange(clientRequest, convertibleWrapper); } catch (IOException ex) { ResourceAccessException resourceAccessException = createResourceAccessException(uri, this.httpMethod, ex); @@ -542,14 +600,14 @@ final class DefaultRestClient implements RestClient { @Override public T body(Class bodyType) { - return readWithMessageConverters(bodyType, bodyType); + return readBody(bodyType, bodyType); } @Override public T body(ParameterizedTypeReference bodyType) { Type type = bodyType.getType(); Class bodyClass = bodyClass(type); - return readWithMessageConverters(type, bodyClass); + return readBody(type, bodyClass); } @Override @@ -565,7 +623,7 @@ final class DefaultRestClient implements RestClient { } private ResponseEntity toEntityInternal(Type bodyType, Class bodyClass) { - T body = readWithMessageConverters(bodyType, bodyClass); + T body = readBody(bodyType, bodyClass); try { return ResponseEntity.status(this.clientResponse.getStatusCode()) .headers(this.clientResponse.getHeaders()) @@ -579,77 +637,96 @@ final class DefaultRestClient implements RestClient { @Override public ResponseEntity toBodilessEntity() { try (this.clientResponse) { - applyStatusHandlers(this.clientRequest, this.clientResponse); + applyStatusHandlers(); return ResponseEntity.status(this.clientResponse.getStatusCode()) .headers(this.clientResponse.getHeaders()) .build(); } + catch (UncheckedIOException ex) { + throw new ResourceAccessException("Could not retrieve response status code: " + ex.getMessage(), ex.getCause()); + } catch (IOException ex) { throw new ResourceAccessException("Could not retrieve response status code: " + ex.getMessage(), ex); } } - @SuppressWarnings("unchecked") - private static Class bodyClass(Type type) { - if (type instanceof Class clazz) { - return (Class) clazz; - } - if (type instanceof ParameterizedType parameterizedType && - parameterizedType.getRawType() instanceof Class rawType) { - return (Class) rawType; - } - return (Class) Object.class; - } - @SuppressWarnings({"rawtypes", "unchecked"}) - private T readWithMessageConverters(Type bodyType, Class bodyClass) { - MediaType contentType = getContentType(); + private T readBody(Type bodyType, Class bodyClass) { + return DefaultRestClient.this.readWithMessageConverters(this.clientResponse, this::applyStatusHandlers, + bodyType, bodyClass); - try (this.clientResponse) { - applyStatusHandlers(this.clientRequest, this.clientResponse); - - for (HttpMessageConverter messageConverter : DefaultRestClient.this.messageConverters) { - if (messageConverter instanceof GenericHttpMessageConverter genericHttpMessageConverter) { - if (genericHttpMessageConverter.canRead(bodyType, null, contentType)) { - if (logger.isDebugEnabled()) { - logger.debug("Reading to [" + ResolvableType.forType(bodyType) + "]"); - } - return (T) genericHttpMessageConverter.read(bodyType, null, this.clientResponse); - } - } - if (messageConverter.canRead(bodyClass, contentType)) { - if (logger.isDebugEnabled()) { - logger.debug("Reading to [" + bodyClass.getName() + "] as \"" + contentType + "\""); - } - return (T) messageConverter.read((Class)bodyClass, this.clientResponse); + } + + private void applyStatusHandlers() { + try { + ClientHttpResponse response = this.clientResponse; + if (response instanceof DefaultConvertibleClientHttpResponse convertibleResponse) { + response = convertibleResponse.delegate; + } + for (StatusHandler handler : this.statusHandlers) { + if (handler.test(response)) { + handler.handle(this.clientRequest, response); + return; } } - throw new UnknownContentTypeException(bodyType, contentType, - this.clientResponse.getStatusCode(), this.clientResponse.getStatusText(), - this.clientResponse.getHeaders(), RestClientUtils.getBody(this.clientResponse)); } - catch (IOException | HttpMessageNotReadableException ex) { - throw new RestClientException("Error while extracting response for type [" + - ResolvableType.forType(bodyType) + "] and content type [" + contentType + "]", ex); + catch (IOException ex) { + throw new UncheckedIOException(ex); } } + } - private MediaType getContentType() { - MediaType contentType = this.clientResponse.getHeaders().getContentType(); - if (contentType == null) { - contentType = MediaType.APPLICATION_OCTET_STREAM; - } - return contentType; + + private class DefaultConvertibleClientHttpResponse implements RequestHeadersSpec.ConvertibleClientHttpResponse { + + private final ClientHttpResponse delegate; + + + public DefaultConvertibleClientHttpResponse(ClientHttpResponse delegate) { + this.delegate = delegate; } - private void applyStatusHandlers(HttpRequest request, ClientHttpResponse response) throws IOException { - for (StatusHandler handler : this.statusHandlers) { - if (handler.test(response)) { - handler.handle(request, response); - return; - } - } + + @Nullable + @Override + public T bodyTo(Class bodyType) { + return readWithMessageConverters(this.delegate, () -> {} , bodyType, bodyType); + } + + @Nullable + @Override + public T bodyTo(ParameterizedTypeReference bodyType) { + Type type = bodyType.getType(); + Class bodyClass = bodyClass(type); + return readWithMessageConverters(this.delegate, () -> {} , type, bodyClass); + } + + @Override + public InputStream getBody() throws IOException { + return this.delegate.getBody(); + } + + @Override + public HttpHeaders getHeaders() { + return this.delegate.getHeaders(); + } + + @Override + public HttpStatusCode getStatusCode() throws IOException { + return this.delegate.getStatusCode(); } + + @Override + public String getStatusText() throws IOException { + return this.delegate.getStatusText(); + } + + @Override + public void close() { + this.delegate.close(); + } + } + } diff --git a/spring-web/src/main/java/org/springframework/web/client/RestClient.java b/spring-web/src/main/java/org/springframework/web/client/RestClient.java index 30293c3c9c..0b7751e38e 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestClient.java @@ -623,7 +623,33 @@ public interface RestClient { * @return the exchanged type * @throws IOException in case of I/O errors */ - T exchange(HttpRequest clientRequest, ClientHttpResponse clientResponse) throws IOException; + T exchange(HttpRequest clientRequest, ConvertibleClientHttpResponse clientResponse) throws IOException; + } + + + /** + * Extension of {@link ClientHttpResponse} that can convert the body. + */ + interface ConvertibleClientHttpResponse extends ClientHttpResponse { + + /** + * Extract the response body as an object of the given type. + * @param bodyType the type of return value + * @param the body type + * @return the body, or {@code null} if no response body was available + */ + @Nullable + T bodyTo(Class bodyType); + + /** + * Extract the response body as an object of the given type. + * @param bodyType the type of return value + * @param the body type + * @return the body, or {@code null} if no response body was available + */ + @Nullable + T bodyTo(ParameterizedTypeReference bodyType); + } } diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java index e865758457..de4b46a5c2 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java @@ -661,6 +661,55 @@ class RestClientIntegrationTests { }); } + @ParameterizedRestClientTest + void exchangeForJson(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response + .setHeader("Content-Type", "application/json") + .setBody("{\"bar\":\"barbar\",\"foo\":\"foofoo\"}")); + + Pojo result = this.restClient.get() + .uri("/pojo") + .accept(MediaType.APPLICATION_JSON) + .exchange((request, response) -> response.bodyTo(Pojo.class)); + + assertThat(result.getFoo()).isEqualTo("foofoo"); + assertThat(result.getBar()).isEqualTo("barbar"); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/pojo"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + + @ParameterizedRestClientTest + void exchangeForJsonArray(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response + .setHeader("Content-Type", "application/json") + .setBody("[{\"bar\":\"bar1\",\"foo\":\"foo1\"},{\"bar\":\"bar2\",\"foo\":\"foo2\"}]")); + + List result = this.restClient.get() + .uri("/pojo") + .accept(MediaType.APPLICATION_JSON) + .exchange((request, response) -> response.bodyTo(new ParameterizedTypeReference<>() {})); + + assertThat(result).hasSize(2); + assertThat(result.get(0).getFoo()).isEqualTo("foo1"); + assertThat(result.get(0).getBar()).isEqualTo("bar1"); + assertThat(result.get(1).getFoo()).isEqualTo("foo2"); + assertThat(result.get(1).getBar()).isEqualTo("bar2"); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/pojo"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + }); + } + @ParameterizedRestClientTest void exchangeFor404(ClientHttpRequestFactory requestFactory) { startServer(requestFactory);