Skip to content

Commit

Permalink
Merge pull request #55 from raupachz/hsts
Browse files Browse the repository at this point in the history
[SHIRO-740] SslFilter with HTTP Strict Transport Security (HSTS)
  • Loading branch information
fpapon authored Aug 21, 2020
2 parents b0091df + de6200f commit 4aa4e3e
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 2 deletions.
101 changes: 99 additions & 2 deletions web/src/main/java/org/apache/shiro/web/filter/authz/SslFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletResponse;

/**
* Filter which requires a request to be over SSL. Access is allowed if the request is received on the configured
Expand All @@ -30,21 +31,50 @@
* The {@link #getPort() port} property defaults to {@code 443} and also additionally guarantees that the
* request scheme is always 'https' (except for port 80, which retains the 'http' scheme).
* <p/>
* Example config:
* In addition the filter allows enabling HTTP Strict Transport Security (HSTS).
* This feature is opt-in and disabled by default. If enabled HSTS
* will prevent <b>any</b> communications from being sent over HTTP to the
* specified domain and will instead send all communications over HTTPS.
* </p>
* The {@link #getMaxAge() maxAge} property defaults {@code 31536000}, and
* {@link #isIncludeSubDomains includeSubDomains} is {@code false}.
* </p>
* <b>Warning:</b> Use this setting with care and only if you plan to enable
* SSL on every path.
* </p>
* Example configs:
* <pre>
* [urls]
* /secure/path/** = ssl
* </pre>
*
* with HSTS enabled
* <pre>
* [main]
* ssl.hsts.enabled = true
* [urls]
* /** = ssl
* </pre>
* @since 1.0
* @see <a href="https://tools.ietf.org/html/rfc6797">HTTP Strict Transport Security (HSTS)</a>
*/
public class SslFilter extends PortFilter {

public static final int DEFAULT_HTTPS_PORT = 443;
public static final String HTTPS_SCHEME = "https";

private HSTS hsts;

public SslFilter() {
setPort(DEFAULT_HTTPS_PORT);
this.hsts = new HSTS();
}

public HSTS getHsts() {
return hsts;
}

public void setHsts(HSTS hsts) {
this.hsts = hsts;
}

@Override
Expand Down Expand Up @@ -73,4 +103,71 @@ protected String getScheme(String requestScheme, int port) {
protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) throws Exception {
return super.isAccessAllowed(request, response, mappedValue) && request.isSecure();
}

/**
* If HTTP Strict Transport Security (HSTS) is enabled the HTTP header
* will be written, otherwise this method does nothing.
* @param request the incoming {@code ServletRequest}
* @param response the outgoing {@code ServletResponse}
*/
@Override
protected void postHandle(ServletRequest request, ServletResponse response) {
if (hsts.isEnabled()) {
StringBuilder directives = new StringBuilder(64)
.append("max-age=").append(hsts.getMaxAge());

if (hsts.isIncludeSubDomains()) {
directives.append("; includeSubDomains");
}

HttpServletResponse resp = (HttpServletResponse) response;
resp.addHeader(HSTS.HTTP_HEADER, directives.toString());
}
}

/**
* Helper class for HTTP Strict Transport Security (HSTS)
*/
public class HSTS {

public static final String HTTP_HEADER = "Strict-Transport-Security";

public static final boolean DEFAULT_ENABLED = false;
public static final int DEFAULT_MAX_AGE = 31536000; // approx. one year in seconds
public static final boolean DEFAULT_INCLUDE_SUB_DOMAINS = false;

private boolean enabled;
private int maxAge;
private boolean includeSubDomains;

public HSTS() {
this.enabled = DEFAULT_ENABLED;
this.maxAge = DEFAULT_MAX_AGE;
this.includeSubDomains = DEFAULT_INCLUDE_SUB_DOMAINS;
}

public boolean isEnabled() {
return enabled;
}

public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public int getMaxAge() {
return maxAge;
}

public void setMaxAge(int maxAge) {
this.maxAge = maxAge;
}

public boolean isIncludeSubDomains() {
return includeSubDomains;
}

public void setIncludeSubDomains(boolean includeSubDomains) {
this.includeSubDomains = includeSubDomains;
}
}
}
102 changes: 102 additions & 0 deletions web/src/test/java/org/apache/shiro/web/filter/authz/SslFilterTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.shiro.web.filter.authz;

import java.util.HashMap;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Test;

import static org.apache.shiro.web.filter.authz.SslFilter.HSTS.*;
import org.easymock.Capture;
import org.easymock.CaptureType;
import static org.easymock.EasyMock.*;
import org.easymock.IAnswer;
import static org.junit.Assert.*;
import org.junit.Before;

public class SslFilterTest {

private HttpServletRequest request;
private HttpServletResponse response;
private SslFilter sslFilter;

@Before
public void before() {
request = createNiceMock(HttpServletRequest.class);
response = createNiceMock(HttpServletResponse.class);
sslFilter = new SslFilter();

final Map<String,String> headers = new HashMap<String,String>();

final Capture<String> capturedName = newCapture();
final Capture<String> capturedValue = newCapture();

// mock HttpServletResponse.getHeader
expect(response.getHeader(capture(capturedName))).andAnswer(new IAnswer<String>() {
@Override
public String answer() throws Throwable {
String name = capturedName.getValue();
return headers.get(name);
}

});

// mock HttpServletResponse.addHeader
response.addHeader(capture(capturedName), capture(capturedValue));
expectLastCall().andAnswer(new IAnswer<Void>() {
@Override
public Void answer() throws Throwable {
String name = capturedName.getValue();
String value = capturedValue.getValue();
headers.put(name, value);
return (null);
}
});

replay(response);
}

@Test
public void testDisabledByDefault() {
sslFilter.postHandle(request, response);
assertNull(response.getHeader(HTTP_HEADER));
}

@Test
public void testDefaultValues() {
sslFilter.getHsts().setEnabled(true);
sslFilter.postHandle(request, response);
assertEquals("max-age=" + DEFAULT_MAX_AGE, response.getHeader(HTTP_HEADER));
}

@Test
public void testSetProperties() {
sslFilter.getHsts().setEnabled(true);
sslFilter.getHsts().setMaxAge(7776000);
sslFilter.getHsts().setIncludeSubDomains(true);
sslFilter.postHandle(request, response);

String expected = "max-age=" + 7776000 + "; includeSubDomains";

assertEquals(expected, response.getHeader(HTTP_HEADER));
}

}

0 comments on commit 4aa4e3e

Please sign in to comment.