Browse Source

TestDispatcherServlet unwraps to find mock request

Issue: SPR-16695
pull/1779/head
Rossen Stoyanchev 7 years ago
parent
commit
7a896d7d80
  1. 10
      spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java
  2. 98
      spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java

10
spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java

@ -26,6 +26,8 @@ import javax.servlet.http.HttpServletResponse; @@ -26,6 +26,8 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockAsyncContext;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.util.Assert;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.async.CallableProcessingInterceptor;
@ -35,6 +37,7 @@ import org.springframework.web.context.request.async.WebAsyncUtils; @@ -35,6 +37,7 @@ import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.util.WebUtils;
/**
* A sub-class of {@code DispatcherServlet} that saves the result in an
@ -68,8 +71,13 @@ final class TestDispatcherServlet extends DispatcherServlet { @@ -68,8 +71,13 @@ final class TestDispatcherServlet extends DispatcherServlet {
super.service(request, response);
if (request.getAsyncContext() != null) {
MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class);
Assert.notNull(mockRequest, "Expected MockHttpServletRequest");
MockAsyncContext mockAsyncContext = ((MockAsyncContext) mockRequest.getAsyncContext());
Assert.notNull(mockAsyncContext, "MockAsyncContext not found. Did request wrapper not delegate startAsync?");
CountDownLatch dispatchLatch = new CountDownLatch(1);
((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(dispatchLatch::countDown);
mockAsyncContext.addDispatchHandler(dispatchLatch::countDown);
getMvcResult(request).setAsyncDispatchLatch(dispatchLatch);
}
}

98
spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,9 +19,14 @@ package org.springframework.test.web.servlet.samples.standalone; @@ -19,9 +19,14 @@ package org.springframework.test.web.servlet.samples.standalone;
import java.io.IOException;
import java.security.Principal;
import java.util.concurrent.CompletableFuture;
import javax.servlet.AsyncContext;
import javax.servlet.AsyncListener;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
@ -120,10 +125,10 @@ public class FilterTests { @@ -120,10 +125,10 @@ public class FilterTests {
.andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME));
}
@Test // SPR-16067
public void filterWrapsRequestResponseWithAsyncDispatch() throws Exception {
@Test // SPR-16067, SPR-16695
public void filterWrapsRequestResponseAndPerformsAsyncDispatch() throws Exception {
MockMvc mockMvc = standaloneSetup(new PersonController())
.addFilters(new ShallowEtagHeaderFilter())
.addFilters(new WrappingRequestResponseFilter(), new ShallowEtagHeaderFilter())
.build();
MvcResult mvcResult = mockMvc.perform(get("/persons/1").accept(MediaType.APPLICATION_JSON))
@ -189,10 +194,20 @@ public class FilterTests { @@ -189,10 +194,20 @@ public class FilterTests {
FilterChain filterChain) throws ServletException, IOException {
filterChain.doFilter(new HttpServletRequestWrapper(request) {
@Override
public Principal getUserPrincipal() {
return () -> PRINCIPAL_NAME;
}
// Like Spring Security does in HttpServlet3RequestFactory..
@Override
public AsyncContext getAsyncContext() {
return super.getAsyncContext() != null ?
new AsyncContextWrapper(super.getAsyncContext()) : null;
}
}, new HttpServletResponseWrapper(response));
}
}
@ -206,4 +221,79 @@ public class FilterTests { @@ -206,4 +221,79 @@ public class FilterTests {
response.sendRedirect("/login");
}
}
private static class AsyncContextWrapper implements AsyncContext {
private final AsyncContext delegate;
public AsyncContextWrapper(AsyncContext delegate) {
this.delegate = delegate;
}
@Override
public ServletRequest getRequest() {
return this.delegate.getRequest();
}
@Override
public ServletResponse getResponse() {
return this.delegate.getResponse();
}
@Override
public boolean hasOriginalRequestAndResponse() {
return this.delegate.hasOriginalRequestAndResponse();
}
@Override
public void dispatch() {
this.delegate.dispatch();
}
@Override
public void dispatch(String path) {
this.delegate.dispatch(path);
}
@Override
public void dispatch(ServletContext context, String path) {
this.delegate.dispatch(context, path);
}
@Override
public void complete() {
this.delegate.complete();
}
@Override
public void start(Runnable run) {
this.delegate.start(run);
}
@Override
public void addListener(AsyncListener listener) {
this.delegate.addListener(listener);
}
@Override
public void addListener(AsyncListener listener, ServletRequest req, ServletResponse res) {
this.delegate.addListener(listener, req, res);
}
@Override
public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
return this.delegate.createListener(clazz);
}
@Override
public void setTimeout(long timeout) {
this.delegate.setTimeout(timeout);
}
@Override
public long getTimeout() {
return this.delegate.getTimeout();
}
}
}

Loading…
Cancel
Save