Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ public void onApplicationEvent(RefreshRoutesResultEvent event) {
routeLocator.getRoutes().collectList().subscribe(routes -> {
// pre-populate with pre-existing global cors configurations to combine with.
Map<String, CorsConfiguration> corsConfigurations = new LinkedHashMap<>();
Map<String, CorsConfiguration> routeCorsConfigurations = new LinkedHashMap<>();

routes.forEach(route -> {
Optional<CorsConfiguration> corsConfiguration = getCorsConfiguration(route);
corsConfiguration.ifPresent(configuration -> {
routeCorsConfigurations.put(route.getId(), configuration);
String pathPredicate = getPathPredicate(route);
corsConfigurations.put(pathPredicate, configuration);
});
Expand All @@ -101,6 +103,7 @@ public void onApplicationEvent(RefreshRoutesResultEvent event) {
corsConfigurations.put(path, config);
}
});
routePredicateHandlerMapping.setRouteCorsConfigurations(routeCorsConfigurations);
routePredicateHandlerMapping.setCorsConfigurations(corsConfigurations);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.cloud.gateway.handler;

import java.util.Collections;
import java.util.Map;
import java.util.function.Function;

import reactor.core.publisher.Mono;
Expand All @@ -27,6 +29,7 @@
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.env.Environment;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.reactive.UrlBasedCorsConfigurationSource;
import org.springframework.web.reactive.handler.AbstractHandlerMapping;
import org.springframework.web.server.ServerWebExchange;

Expand All @@ -51,6 +54,8 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping {

private final ManagementPortType managementPortType;

private volatile Map<String, CorsConfiguration> routeCorsConfigurations = Collections.emptyMap();

public RoutePredicateHandlerMapping(FilteringWebHandler webHandler, RouteLocator routeLocator,
GlobalCorsProperties globalCorsProperties, Environment environment) {
this.webHandler = webHandler;
Expand Down Expand Up @@ -108,6 +113,30 @@ protected Mono<?> getHandlerInternal(ServerWebExchange exchange) {
});
}

public void setRouteCorsConfigurations(Map<String, CorsConfiguration> routeCorsConfigurations) {
this.routeCorsConfigurations = routeCorsConfigurations;
}

@Override
public void setCorsConfigurations(Map<String, CorsConfiguration> corsConfigurations) {
if (this.routeCorsConfigurations.isEmpty()) {
super.setCorsConfigurations(corsConfigurations);
return;
}
UrlBasedCorsConfigurationSource pathBasedSource = new UrlBasedCorsConfigurationSource(getPathPatternParser());
pathBasedSource.setCorsConfigurations(corsConfigurations);
setCorsConfigurationSource(exchange -> {
Route route = exchange.getAttribute(GATEWAY_ROUTE_ATTR);
if (route != null) {
CorsConfiguration routeConfig = this.routeCorsConfigurations.get(route.getId());
if (routeConfig != null) {
return routeConfig;
}
}
return pathBasedSource.getCorsConfiguration(exchange);
});
}

@Override
protected CorsConfiguration getCorsConfiguration(Object handler, ServerWebExchange exchange) {
// TODO: support cors configuration via properties on a route see gh-229
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,50 @@ public void testPreFlightCorsRequestJavaConfig() {
});
}

@Test
public void testPreFlightCorsRequestHostPredicateA() {
testClient.options()
.uri("/anything")
.header("Origin", "https://origin-a.com")
.header("Host", "hosta.example.com")
.header("Access-Control-Request-Method", "GET")
.exchange()
.expectBody(Map.class)
.consumeWith(result -> {
assertThat(result.getResponseBody()).isNull();
assertThat(result.getStatus()).isEqualTo(HttpStatus.OK);

HttpHeaders responseHeaders = result.getResponseHeaders();
assertThat(responseHeaders.getAccessControlAllowOrigin()).as(missingHeader(ACCESS_CONTROL_ALLOW_ORIGIN))
.isEqualTo("https://origin-a.com");
assertThat(responseHeaders.getAccessControlAllowMethods())
.as(missingHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS))
.containsExactly(HttpMethod.GET);
});
}

@Test
public void testPreFlightCorsRequestHostPredicateB() {
testClient.options()
.uri("/anything")
.header("Origin", "https://origin-b.com")
.header("Host", "hostb.example.com")
.header("Access-Control-Request-Method", "POST")
.exchange()
.expectBody(Map.class)
.consumeWith(result -> {
assertThat(result.getResponseBody()).isNull();
assertThat(result.getStatus()).isEqualTo(HttpStatus.OK);

HttpHeaders responseHeaders = result.getResponseHeaders();
assertThat(responseHeaders.getAccessControlAllowOrigin()).as(missingHeader(ACCESS_CONTROL_ALLOW_ORIGIN))
.isEqualTo("https://origin-b.com");
assertThat(responseHeaders.getAccessControlAllowMethods())
.as(missingHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS))
.containsExactly(HttpMethod.POST);
});
}

@Test
public void testPreFlightForbiddenCorsRequest() {
testClient.get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.cloud.gateway.config.GlobalCorsProperties;
import org.springframework.cloud.gateway.event.RefreshRoutesResultEvent;
import org.springframework.cloud.gateway.handler.RoutePredicateHandlerMapping;
import org.springframework.cloud.gateway.handler.predicate.HostRoutePredicateFactory;
import org.springframework.cloud.gateway.handler.predicate.PathRoutePredicateFactory;
import org.springframework.cloud.gateway.route.Route;
import org.springframework.cloud.gateway.route.RouteLocator;
Expand Down Expand Up @@ -104,6 +105,9 @@ class CorsGatewayFilterApplicationListenerTests {
@Captor
private ArgumentCaptor<Map<String, CorsConfiguration>> corsConfigurations;

@Captor
private ArgumentCaptor<Map<String, CorsConfiguration>> routeCorsConfigurations;

private GlobalCorsProperties globalCorsProperties;

private CorsGatewayFilterApplicationListener listener;
Expand Down Expand Up @@ -150,6 +154,60 @@ private CorsConfiguration createCorsConfig(String origin) {
return config;
}

@Test
void testOnApplicationEvent_hostOnlyRoutes_storesRouteCorsConfigurations() {

String hostA = "hosta.example.com";
String hostB = "hostb.example.com";
String originA = "https://originA.com";
String originB = "https://originB.com";
String routeIdA = "host-route-a";
String routeIdB = "host-route-b";

Route routeA = buildHostRoute(routeIdA, hostA, originA);
Route routeB = buildHostRoute(routeIdB, hostB, originB);

when(routeLocator.getRoutes()).thenReturn(Flux.just(routeA, routeB));

listener.onApplicationEvent(new RefreshRoutesResultEvent(this));

Awaitility.await().atMost(Duration.ofSeconds(2)).untilAsserted(() -> {

verify(handlerMapping).setRouteCorsConfigurations(routeCorsConfigurations.capture());

Map<String, CorsConfiguration> routeConfigs = routeCorsConfigurations.getValue();
assertThat(routeConfigs).containsKeys(routeIdA, routeIdB);
assertThat(routeConfigs.get(routeIdA).getAllowedOrigins()).containsExactly(originA);
assertThat(routeConfigs.get(routeIdB).getAllowedOrigins()).containsExactly(originB);
});
}

@Test
void testOnApplicationEvent_pathRoutes_alsoStoresRouteCorsConfigurations() {

Route route1 = buildRoute(ROUTE_ID_1, ROUTE_PATH_1, ORIGIN_ROUTE_1);
Route route2 = buildRoute(ROUTE_ID_2, ROUTE_PATH_2, ORIGIN_ROUTE_2);

when(routeLocator.getRoutes()).thenReturn(Flux.just(route1, route2));

listener.onApplicationEvent(new RefreshRoutesResultEvent(this));

Awaitility.await().atMost(Duration.ofSeconds(2)).untilAsserted(() -> {

verify(handlerMapping).setRouteCorsConfigurations(routeCorsConfigurations.capture());
verify(handlerMapping).setCorsConfigurations(corsConfigurations.capture());

Map<String, CorsConfiguration> routeConfigs = routeCorsConfigurations.getValue();
assertThat(routeConfigs).containsKeys(ROUTE_ID_1, ROUTE_ID_2);
assertThat(routeConfigs.get(ROUTE_ID_1).getAllowedOrigins()).containsExactly(ORIGIN_ROUTE_1);
assertThat(routeConfigs.get(ROUTE_ID_2).getAllowedOrigins()).containsExactly(ORIGIN_ROUTE_2);

// path-based configurations should still work
Map<String, CorsConfiguration> pathConfigs = corsConfigurations.getValue();
assertThat(pathConfigs).containsKeys(ROUTE_PATH_1, ROUTE_PATH_2);
});
}

private Route buildRoute(String id, String path, String allowedOrigin) {

return Route.async()
Expand All @@ -160,4 +218,14 @@ private Route buildRoute(String id, String path, String allowedOrigin) {
.build();
}

private Route buildHostRoute(String id, String host, String allowedOrigin) {

return Route.async()
.id(id)
.uri(ROUTE_URI)
.predicate(new HostRoutePredicateFactory().apply(config -> config.setPatterns(List.of(host))))
.metadata(METADATA_KEY, Map.of(ALLOWED_ORIGINS_KEY, List.of(allowedOrigin)))
.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,22 @@ spring:
allowedMethods:
- GET
- PUT
allowedHeaders: '*'
allowedHeaders: '*'
- id: cors_host_a
uri: ${test.uri}
predicates:
- Host=hosta.example.com
metadata:
cors:
allowedOrigins: 'https://origin-a.com'
allowedMethods:
- GET
- id: cors_host_b
uri: ${test.uri}
predicates:
- Host=hostb.example.com
metadata:
cors:
allowedOrigins: 'https://origin-b.com'
allowedMethods:
- POST