Browse Source

ServletServerHttpResponse reflects Content-Type override

Closes gh-25490
pull/26211/head
Rossen Stoyanchev 4 years ago
parent
commit
8ac39a50fe
  1. 19
      spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java
  2. 30
      spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java

19
spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java

@ -20,6 +20,7 @@ import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
@ -152,12 +153,14 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
@Override @Override
@Nullable @Nullable
public String getFirst(String headerName) { public String getFirst(String headerName) {
String value = servletResponse.getHeader(headerName); if (headerName.equalsIgnoreCase(CONTENT_TYPE)) {
if (value != null) { // Content-Type is written as an override so check super first
return value; String value = super.getFirst(headerName);
return (value != null ? value : servletResponse.getHeader(headerName));
} }
else { else {
return super.getFirst(headerName); String value = servletResponse.getHeader(headerName);
return (value != null ? value : super.getFirst(headerName));
} }
} }
@ -165,7 +168,13 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
public List<String> get(Object key) { public List<String> get(Object key) {
Assert.isInstanceOf(String.class, key, "Key must be a String-based header name"); Assert.isInstanceOf(String.class, key, "Key must be a String-based header name");
Collection<String> values1 = servletResponse.getHeaders((String) key); String headerName = (String) key;
if (headerName.equalsIgnoreCase(CONTENT_TYPE)) {
// Content-Type is written as an override so don't merge
return Collections.singletonList(getFirst(headerName));
}
Collection<String> values1 = servletResponse.getHeaders(headerName);
if (headersWritten) { if (headersWritten) {
return new ArrayList<>(values1); return new ArrayList<>(values1);
} }

30
spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,7 +17,6 @@
package org.springframework.http.server; package org.springframework.http.server;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List; import java.util.List;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -44,20 +43,20 @@ public class ServletServerHttpResponseTests {
@BeforeEach @BeforeEach
public void create() throws Exception { void create() {
mockResponse = new MockHttpServletResponse(); mockResponse = new MockHttpServletResponse();
response = new ServletServerHttpResponse(mockResponse); response = new ServletServerHttpResponse(mockResponse);
} }
@Test @Test
public void setStatusCode() throws Exception { void setStatusCode() {
response.setStatusCode(HttpStatus.NOT_FOUND); response.setStatusCode(HttpStatus.NOT_FOUND);
assertThat(mockResponse.getStatus()).as("Invalid status code").isEqualTo(404); assertThat(mockResponse.getStatus()).as("Invalid status code").isEqualTo(404);
} }
@Test @Test
public void getHeaders() throws Exception { void getHeaders() {
HttpHeaders headers = response.getHeaders(); HttpHeaders headers = response.getHeaders();
String headerName = "MyHeader"; String headerName = "MyHeader";
String headerValue1 = "value1"; String headerValue1 = "value1";
@ -77,23 +76,32 @@ public class ServletServerHttpResponseTests {
} }
@Test @Test
public void preExistingHeadersFromHttpServletResponse() { void preExistingHeadersFromHttpServletResponse() {
String headerName = "Access-Control-Allow-Origin"; String headerName = "Access-Control-Allow-Origin";
String headerValue = "localhost:8080"; String headerValue = "localhost:8080";
this.mockResponse.addHeader(headerName, headerValue); this.mockResponse.addHeader(headerName, headerValue);
this.mockResponse.setContentType("text/csv");
this.response = new ServletServerHttpResponse(this.mockResponse); this.response = new ServletServerHttpResponse(this.mockResponse);
assertThat(this.response.getHeaders().getFirst(headerName)).isEqualTo(headerValue); assertThat(this.response.getHeaders().getFirst(headerName)).isEqualTo(headerValue);
assertThat(this.response.getHeaders().get(headerName)).isEqualTo(Collections.singletonList(headerValue)); assertThat(this.response.getHeaders().get(headerName)).containsExactly(headerValue);
assertThat(this.response.getHeaders().containsKey(headerName)).isTrue(); assertThat(this.response.getHeaders()).containsKey(headerName);
assertThat(this.response.getHeaders().getFirst(headerName)).isEqualTo(headerValue);
assertThat(this.response.getHeaders().getAccessControlAllowOrigin()).isEqualTo(headerValue); assertThat(this.response.getHeaders().getAccessControlAllowOrigin()).isEqualTo(headerValue);
} }
@Test // gh-25490
void preExistingContentTypeIsOverriddenImmediately() {
this.mockResponse.setContentType("text/csv");
this.response = new ServletServerHttpResponse(this.mockResponse);
this.response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
assertThat(response.getHeaders().getContentType()).isEqualTo(MediaType.APPLICATION_JSON);
}
@Test @Test
public void getBody() throws Exception { void getBody() throws Exception {
byte[] content = "Hello World".getBytes("UTF-8"); byte[] content = "Hello World".getBytes(StandardCharsets.UTF_8);
FileCopyUtils.copy(content, response.getBody()); FileCopyUtils.copy(content, response.getBody());
assertThat(mockResponse.getContentAsByteArray()).as("Invalid content written").isEqualTo(content); assertThat(mockResponse.getContentAsByteArray()).as("Invalid content written").isEqualTo(content);

Loading…
Cancel
Save