diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt index 61803b8695..1cee19ecef 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt @@ -21,6 +21,7 @@ import org.springframework.http.HttpMethod import org.springframework.http.HttpStatus import org.springframework.http.MediaType import reactor.core.publisher.Mono +import reactor.core.publisher.cast import java.net.URI /** @@ -153,8 +154,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given request predicate applies. * @see RouterFunctions.route */ - fun GET(pattern: String, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.GET(pattern), HandlerFunction { f(it) }) + fun GET(pattern: String, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.GET(pattern), HandlerFunction { f(it).cast() }) } /** @@ -168,8 +169,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given request predicate applies. * @see RouterFunctions.route */ - fun HEAD(pattern: String, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.HEAD(pattern), HandlerFunction { f(it) }) + fun HEAD(pattern: String, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.HEAD(pattern), HandlerFunction { f(it).cast() }) } /** @@ -183,8 +184,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given POST predicate applies. * @see RouterFunctions.route */ - fun POST(pattern: String, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.POST(pattern), HandlerFunction { f(it) }) + fun POST(pattern: String, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.POST(pattern), HandlerFunction { f(it).cast() }) } /** @@ -198,8 +199,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given PUT predicate applies. * @see RouterFunctions.route */ - fun PUT(pattern: String, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.PUT(pattern), HandlerFunction { f(it) }) + fun PUT(pattern: String, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.PUT(pattern), HandlerFunction { f(it).cast() }) } /** @@ -213,8 +214,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given PATCH predicate applies. * @see RouterFunctions.route */ - fun PATCH(pattern: String, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.PATCH(pattern), HandlerFunction { f(it) }) + fun PATCH(pattern: String, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.PATCH(pattern), HandlerFunction { f(it).cast() }) } /** @@ -230,8 +231,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given DELETE predicate applies. * @see RouterFunctions.route */ - fun DELETE(pattern: String, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.DELETE(pattern), HandlerFunction { f(it) }) + fun DELETE(pattern: String, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.DELETE(pattern), HandlerFunction { f(it).cast() }) } /** @@ -247,8 +248,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given OPTIONS predicate applies. * @see RouterFunctions.route */ - fun OPTIONS(pattern: String, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.OPTIONS(pattern), HandlerFunction { f(it) }) + fun OPTIONS(pattern: String, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.OPTIONS(pattern), HandlerFunction { f(it).cast() }) } /** @@ -264,8 +265,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given accept predicate applies. * @see RouterFunctions.route */ - fun accept(mediaType: MediaType, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.accept(mediaType), HandlerFunction { f(it) }) + fun accept(mediaType: MediaType, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.accept(mediaType), HandlerFunction { f(it).cast() }) } /** @@ -281,8 +282,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given contentType predicate applies. * @see RouterFunctions.route */ - fun contentType(mediaType: MediaType, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.contentType(mediaType), HandlerFunction { f(it) }) + fun contentType(mediaType: MediaType, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.contentType(mediaType), HandlerFunction { f(it).cast() }) } /** @@ -298,8 +299,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given headers predicate applies. * @see RouterFunctions.route */ - fun headers(headersPredicate: (ServerRequest.Headers) -> Boolean, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.headers(headersPredicate), HandlerFunction { f(it) }) + fun headers(headersPredicate: (ServerRequest.Headers) -> Boolean, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.headers(headersPredicate), HandlerFunction { f(it).cast() }) } /** @@ -314,8 +315,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given method predicate applies. * @see RouterFunctions.route */ - fun method(httpMethod: HttpMethod, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.method(httpMethod), HandlerFunction { f(it) }) + fun method(httpMethod: HttpMethod, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.method(httpMethod), HandlerFunction { f(it).cast() }) } /** @@ -329,8 +330,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given path predicate applies. * @see RouterFunctions.route */ - fun path(pattern: String, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.path(pattern), HandlerFunction { f(it) }) + fun path(pattern: String, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.path(pattern), HandlerFunction { f(it).cast() }) } /** @@ -343,8 +344,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given pathExtension predicate applies. * @see RouterFunctions.route */ - fun pathExtension(extension: String, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.pathExtension(extension), HandlerFunction { f(it) }) + fun pathExtension(extension: String, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.pathExtension(extension), HandlerFunction { f(it).cast() }) } /** @@ -358,8 +359,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given pathExtension predicate applies. * @see RouterFunctions.route */ - fun pathExtension(predicate: (String) -> Boolean, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.pathExtension(predicate), HandlerFunction { f(it) }) + fun pathExtension(predicate: (String) -> Boolean, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.pathExtension(predicate), HandlerFunction { f(it).cast() }) } /** @@ -374,8 +375,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given queryParam predicate applies. * @see RouterFunctions.route */ - fun queryParam(name: String, predicate: (String) -> Boolean, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.queryParam(name, predicate), HandlerFunction { f(it) }) + fun queryParam(name: String, predicate: (String) -> Boolean, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.queryParam(name, predicate), HandlerFunction { f(it).cast() }) } /** @@ -393,8 +394,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * Route to the given handler function if the given request predicate applies. * @see RouterFunctions.route */ - operator fun RequestPredicate.invoke(f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(this, HandlerFunction { f(it) }) + operator fun RequestPredicate.invoke(f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(this, HandlerFunction { f(it).cast() }) } /** @@ -402,8 +403,8 @@ open class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) : ( * processed as a path predicate) applies. * @see RouterFunctions.route */ - operator fun String.invoke(f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.path(this), HandlerFunction { f(it) }) + operator fun String.invoke(f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.path(this), HandlerFunction { f(it).cast() }) } /** diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt index 2eceb5b59f..ce32b91b77 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ class RouterFunctionDslTests { @Test fun accept() { - val request = builder().header(ACCEPT, APPLICATION_ATOM_XML_VALUE).build() + val request = builder().uri(URI("/content")).header(ACCEPT, APPLICATION_ATOM_XML_VALUE).build() StepVerifier.create(sampleRouter().route(request)) .expectNextCount(1) .verifyComplete() @@ -63,7 +63,7 @@ class RouterFunctionDslTests { @Test fun contentType() { - val request = builder().header(CONTENT_TYPE, APPLICATION_OCTET_STREAM_VALUE).build() + val request = builder().uri(URI("/content")).header(CONTENT_TYPE, APPLICATION_OCTET_STREAM_VALUE).build() StepVerifier.create(sampleRouter().route(request)) .expectNextCount(1) .verifyComplete() @@ -112,6 +112,14 @@ class RouterFunctionDslTests { .verifyComplete() } + @Test + fun rendering() { + val request = builder().uri(URI("/rendering")).build() + StepVerifier.create(sampleRouter().route(request).flatMap { it.handle(request) }) + .expectNextMatches { it is RenderingResponse} + .verifyComplete() + } + private fun sampleRouter() = router { (GET("/foo/") or GET("/foos/")) { req -> handle(req) } @@ -123,8 +131,10 @@ class RouterFunctionDslTests { } "/foo/" { handleFromClass(it) } } - accept(APPLICATION_ATOM_XML, ::handle) - contentType(APPLICATION_OCTET_STREAM, ::handle) + "/content".nest { + accept(APPLICATION_ATOM_XML, ::handle) + contentType(APPLICATION_OCTET_STREAM, ::handle) + } method(PATCH, ::handle) headers { it.accept().contains(APPLICATION_JSON) }.nest { GET("/api/foo/", ::handle) @@ -141,6 +151,7 @@ class RouterFunctionDslTests { } } path("/baz", ::handle) + GET("/rendering") { RenderingResponse.create("index").build() } } }