diff --git a/gateway-service/src/main/java/org/zowe/apiml/gateway/config/WebSecurity.java b/gateway-service/src/main/java/org/zowe/apiml/gateway/config/WebSecurity.java index d92f7b1e95..0367d7a45d 100644 --- a/gateway-service/src/main/java/org/zowe/apiml/gateway/config/WebSecurity.java +++ b/gateway-service/src/main/java/org/zowe/apiml/gateway/config/WebSecurity.java @@ -12,8 +12,11 @@ import jakarta.annotation.PostConstruct; import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.Strings; import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -30,6 +33,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatusCode; import org.springframework.http.ResponseCookie; +import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequestDecorator; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.security.config.annotation.method.configuration.EnableReactiveMethodSecurity; @@ -40,20 +44,11 @@ import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; -import org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.*; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizationRequestResolver; -import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; -import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver; -import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.*; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -68,6 +63,7 @@ import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; +import org.zowe.apiml.config.ApplicationInfo; import org.zowe.apiml.gateway.config.oidc.ClientConfiguration; import org.zowe.apiml.gateway.controllers.GatewayExceptionHandler; import org.zowe.apiml.gateway.filters.proxyheaders.AdditionalRegistrationGatewayRegistry; @@ -90,13 +86,7 @@ import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.security.cert.CertificateException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Base64; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.function.Predicate; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -516,11 +506,12 @@ public HttpHeaders getHeaders() { @Bean StrictServerWebExchangeFirewall httpFirewall() { - StrictServerWebExchangeFirewall firewall = new StrictServerWebExchangeFirewall(); + var strictFirewall = new StrictServerWebExchangeFirewall(); if (isStrictUrlValidationEnabled) { - return firewall; + return strictFirewall; } + StrictServerWebExchangeFirewall firewall = new ApimlStrictServerWebExchangeFirewall(strictFirewall); firewall.setAllowUrlEncodedSlash(true); firewall.setAllowUrlEncodedDoubleSlash(true); firewall.setAllowBackSlash(true); @@ -541,4 +532,70 @@ XForwardedHeadersFilter xForwardedHeadersFilter( return new X509AndGwAwareXForwardedHeadersFilter(httpsConfig, trustedProxies, additionalRegistrationGatewayRegistry); } + @RequiredArgsConstructor + static class ApimlStrictServerWebExchangeFirewall extends StrictServerWebExchangeFirewall { + + private static final String[] BASE_PATH_MICROSERVICES = { + "/gateway", + "/application", + "/images", + "/v3/api-docs" + }; + + private static final String[] BASE_PATHS_MODULITH = ArrayUtils.addAll(BASE_PATH_MICROSERVICES, new String[] { + "/apicatalog", + "/cachingservice" + }); + + @Value("${server.port}") + private int gatewayPort; + + @Autowired + private ApplicationInfo applicationInfo; + + private final StrictServerWebExchangeFirewall nonRoutingFirewall; + + boolean isPathToRoute(ServerHttpRequest request, String[] prefixes) { + var path = request.getPath().value(); + // homepage + if (Strings.CS.equals(path, "/")) { + return false; + } + for (String prefix : prefixes) { + if (Strings.CS.equals(path, prefix)) { + return false; + } + if (Strings.CS.startsWith(path, prefix + "/")) { + return false; + } + } + return true; + } + + boolean isPathToRoute(ServerHttpRequest request) { + if (applicationInfo.isModulith()) { + // check if the request is to DS on the internal port + if (request.getLocalAddress().getPort() != gatewayPort) { + return false; + } + + return isPathToRoute(request, BASE_PATHS_MODULITH); + } + + return isPathToRoute(request, BASE_PATH_MICROSERVICES); + } + + + @Override + public Mono getFirewalledExchange(ServerWebExchange exchange) { + // in case of Gateway and a request to routing use a configured values + if (isPathToRoute(exchange.getRequest())) { + return super.getFirewalledExchange(exchange); + } + + return nonRoutingFirewall.getFirewalledExchange(exchange); + } + + } + } diff --git a/gateway-service/src/test/java/org/zowe/apiml/gateway/config/WebSecurityTest.java b/gateway-service/src/test/java/org/zowe/apiml/gateway/config/WebSecurityTest.java index 5a62111c92..b578ff2e4f 100644 --- a/gateway-service/src/test/java/org/zowe/apiml/gateway/config/WebSecurityTest.java +++ b/gateway-service/src/test/java/org/zowe/apiml/gateway/config/WebSecurityTest.java @@ -14,6 +14,9 @@ import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; @@ -23,7 +26,9 @@ import org.springframework.http.server.RequestPath; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpResponse; +import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.config.Customizer; import org.springframework.security.config.web.server.ServerHttpSecurity; import org.springframework.security.core.GrantedAuthority; @@ -44,15 +49,18 @@ import org.springframework.security.web.server.WebFilterExchange; import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository; +import org.springframework.security.web.server.firewall.StrictServerWebExchangeFirewall; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.web.server.ServerWebExchange; +import org.zowe.apiml.config.ApplicationInfo; import org.zowe.apiml.gateway.config.oidc.ClientConfiguration; import org.zowe.apiml.gateway.service.BasicAuthProvider; import org.zowe.apiml.gateway.service.TokenProvider; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import java.net.InetSocketAddress; import java.util.HashMap; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; @@ -482,4 +490,121 @@ void saveAuthorizationRequest_whenCalled_shouldSaveCookies () { assertThat(cookies.getFirst(WebSecurity.COOKIE_STATE).getValue()).isEqualTo("test-state"); } + + @Nested + @ExtendWith(MockitoExtension.class) + class ApimlStrictServerWebExchangeFirewall { + + private static final int GATEWAY_PORT = 10010; + + @Mock + private StrictServerWebExchangeFirewall nonRoutingFirewall; + private WebSecurity.ApimlStrictServerWebExchangeFirewall apimlStrictServerWebExchangeFirewall; + private ApplicationInfo applicationInfo = ApplicationInfo.builder().build(); + + @BeforeEach + void setUp() { + apimlStrictServerWebExchangeFirewall = new WebSecurity.ApimlStrictServerWebExchangeFirewall(nonRoutingFirewall); + ReflectionTestUtils.setField(apimlStrictServerWebExchangeFirewall, "gatewayPort", GATEWAY_PORT); + ReflectionTestUtils.setField(apimlStrictServerWebExchangeFirewall, "applicationInfo", applicationInfo); + } + + @Nested + class Modulith { + + @BeforeEach + void setUp() { + applicationInfo.setModulith(true); + } + + @ParameterizedTest(name = "givenLocalEndpointPath_whenFirewallCheck_thenDecideToUseStrictOne(port={0}, path={1})") + @CsvSource({ + "10010,/", + "10010,/gateway", + "10010,/gateway/api/v1/anyUrl", + "10010,/images", + "10010,/images/homepage/picture.gif", + "10010,/application", + "10010,/application/health", + "10010,/v3/api-docs", + "10010,/v3/api-docs/apicatalog", + "10010,/apicatalog", + "10010,/apicatalog/index.html", + "10010,/cachingservice", + "10010,/cachingservice/map", + "10011,/", + "10011,/application", + "10011,/eureka" + }) + void givenLocalEndpointPath_whenFirewallCheck_thenDecideToUseStrictOne(int port, String path) { + var request = MockServerHttpRequest.get(path).localAddress(new InetSocketAddress(port)).build(); + var exchange = MockServerWebExchange.from(request); + + apimlStrictServerWebExchangeFirewall.getFirewalledExchange(exchange); + + verify(nonRoutingFirewall).getFirewalledExchange(any()); + } + + @ParameterizedTest(name = "givenSouthBoundServicePath_whenFirewallCheck_thenDecideToUseCustomizedOne(port={0}, path={1})") + @CsvSource({ + "10010,/service/api/v1", + "10010,/v3/a/strange/service" + }) + void givenSouthBoundServicePath_whenFirewallCheck_thenDecideToUseCustomizedOne(int port, String path) { + var request = MockServerHttpRequest.get(path).localAddress(new InetSocketAddress(port)).build(); + var exchange = MockServerWebExchange.from(request); + + apimlStrictServerWebExchangeFirewall.getFirewalledExchange(exchange); + + verify(nonRoutingFirewall, never()).getFirewalledExchange(any()); + } + + } + + @Nested + class Microservices { + + @BeforeEach + void setUp() { + applicationInfo.setModulith(false); + } + + @ParameterizedTest(name = "givenLocalEndpointPath_whenFirewallCheck_thenDecideToUseStrictOne({0})") + @CsvSource({ + "/", + "/gateway", + "/gateway/api/v1/anyUrl", + "/images", + "/images/homepage/picture.gif", + "/application", + "/application/health", + "/v3/api-docs" + }) + void givenLocalEndpointPath_whenFirewallCheck_thenDecideToUseStrictOne(String path) { + var request = MockServerHttpRequest.get(path).localAddress(new InetSocketAddress(12345)).build(); + var exchange = MockServerWebExchange.from(request); + + apimlStrictServerWebExchangeFirewall.getFirewalledExchange(exchange); + + verify(nonRoutingFirewall).getFirewalledExchange(any()); + } + + @ParameterizedTest(name = "givenSouthBoundServicePath_whenFirewallCheck_thenDecideToUseCustomizedOne({0})") + @CsvSource({ + "10010,/service/api/v1", + "10010,/v3/a/strange/service" + }) + void givenSouthBoundServicePath_whenFirewallCheck_thenDecideToUseCustomizedOne(String path) { + var request = MockServerHttpRequest.get(path).localAddress(new InetSocketAddress(54321)).build(); + var exchange = MockServerWebExchange.from(request); + + apimlStrictServerWebExchangeFirewall.getFirewalledExchange(exchange); + + verify(nonRoutingFirewall, never()).getFirewalledExchange(any()); + } + + } + + } + }