diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java index 45758ec0d9..0e6b46a1e2 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.OutputStream; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import javax.servlet.http.HttpServletResponse; @@ -152,12 +153,14 @@ public class ServletServerHttpResponse implements ServerHttpResponse { @Override @Nullable public String getFirst(String headerName) { - String value = servletResponse.getHeader(headerName); - if (value != null) { - return value; + if (headerName.equalsIgnoreCase(CONTENT_TYPE)) { + // Content-Type is written as an override so check super first + String value = super.getFirst(headerName); + return (value != null ? value : servletResponse.getHeader(headerName)); } else { - return super.getFirst(headerName); + String value = servletResponse.getHeader(headerName); + return (value != null ? value : super.getFirst(headerName)); } } @@ -165,7 +168,13 @@ public class ServletServerHttpResponse implements ServerHttpResponse { public List get(Object key) { Assert.isInstanceOf(String.class, key, "Key must be a String-based header name"); - Collection values1 = servletResponse.getHeaders((String) key); + String headerName = (String) key; + if (headerName.equalsIgnoreCase(CONTENT_TYPE)) { + // Content-Type is written as an override so don't merge + return Collections.singletonList(getFirst(headerName)); + } + + Collection values1 = servletResponse.getHeaders(headerName); if (headersWritten) { return new ArrayList<>(values1); } diff --git a/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java index 4f37e8d193..3dbef26f1b 100644 --- a/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ package org.springframework.http.server; import java.nio.charset.StandardCharsets; -import java.util.Collections; import java.util.List; import org.junit.jupiter.api.BeforeEach; @@ -44,20 +43,20 @@ public class ServletServerHttpResponseTests { @BeforeEach - public void create() throws Exception { + void create() { mockResponse = new MockHttpServletResponse(); response = new ServletServerHttpResponse(mockResponse); } @Test - public void setStatusCode() throws Exception { + void setStatusCode() { response.setStatusCode(HttpStatus.NOT_FOUND); assertThat(mockResponse.getStatus()).as("Invalid status code").isEqualTo(404); } @Test - public void getHeaders() throws Exception { + void getHeaders() { HttpHeaders headers = response.getHeaders(); String headerName = "MyHeader"; String headerValue1 = "value1"; @@ -77,23 +76,32 @@ public class ServletServerHttpResponseTests { } @Test - public void preExistingHeadersFromHttpServletResponse() { + void preExistingHeadersFromHttpServletResponse() { String headerName = "Access-Control-Allow-Origin"; String headerValue = "localhost:8080"; this.mockResponse.addHeader(headerName, headerValue); + this.mockResponse.setContentType("text/csv"); this.response = new ServletServerHttpResponse(this.mockResponse); assertThat(this.response.getHeaders().getFirst(headerName)).isEqualTo(headerValue); - assertThat(this.response.getHeaders().get(headerName)).isEqualTo(Collections.singletonList(headerValue)); - assertThat(this.response.getHeaders().containsKey(headerName)).isTrue(); - assertThat(this.response.getHeaders().getFirst(headerName)).isEqualTo(headerValue); + assertThat(this.response.getHeaders().get(headerName)).containsExactly(headerValue); + assertThat(this.response.getHeaders()).containsKey(headerName); assertThat(this.response.getHeaders().getAccessControlAllowOrigin()).isEqualTo(headerValue); } + @Test // gh-25490 + void preExistingContentTypeIsOverriddenImmediately() { + this.mockResponse.setContentType("text/csv"); + this.response = new ServletServerHttpResponse(this.mockResponse); + this.response.getHeaders().setContentType(MediaType.APPLICATION_JSON); + + assertThat(response.getHeaders().getContentType()).isEqualTo(MediaType.APPLICATION_JSON); + } + @Test - public void getBody() throws Exception { - byte[] content = "Hello World".getBytes("UTF-8"); + void getBody() throws Exception { + byte[] content = "Hello World".getBytes(StandardCharsets.UTF_8); FileCopyUtils.copy(content, response.getBody()); assertThat(mockResponse.getContentAsByteArray()).as("Invalid content written").isEqualTo(content);