From d6de0906f44242d0226ad26022a58d3fc5c4904e Mon Sep 17 00:00:00 2001 From: Pierre Souchay Date: Wed, 22 Feb 2017 23:54:59 +0100 Subject: [PATCH] Added DNS Fallback for REALM resolution --- .../sqlserver/jdbc/KerbAuthentication.java | 45 ++++-- .../jdbc/dns/DNSKerberosLocator.java | 33 ++++ .../sqlserver/jdbc/dns/DNSRecordSRV.java | 142 ++++++++++++++++++ .../sqlserver/jdbc/dns/DNSUtilities.java | 63 ++++++++ .../sqlserver/jdbc/dns/DNSRealmsTest.java | 21 +++ 5 files changed, 290 insertions(+), 14 deletions(-) create mode 100644 src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java create mode 100644 src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java create mode 100644 src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java create mode 100644 src/test/java/com/microsoft/sqlserver/jdbc/dns/DNSRealmsTest.java diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java b/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java index f82d913cf..e0a256703 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java @@ -23,6 +23,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +import javax.naming.NamingException; import javax.security.auth.Subject; import javax.security.auth.login.AppConfigurationEntry; import javax.security.auth.login.Configuration; @@ -36,6 +37,8 @@ import org.ietf.jgss.GSSName; import org.ietf.jgss.Oid; +import com.microsoft.sqlserver.jdbc.dns.DNSKerberosLocator; + /** * KerbAuthentication for int auth. */ @@ -268,7 +271,9 @@ private String makeSpn(String server, spn = makeSpn(address, port); } this.spn = enrichSpnWithRealm(spn, null == userSuppliedServerSpn); - //DEBUG System.err.println("SPN before enrichment: " + spn + " ; AFTER enrichment: " + this.spn); + if (!this.spn.equals(spn) && authLogger.isLoggable(Level.FINER)){ + authLogger.finer(toString() + "SPN enriched: " + spn + " := " + this.spn); + } } private static final Pattern SPN_PATTERN = Pattern.compile("MSSQLSvc/(.*):([^:@]+)(@.+)?", Pattern.CASE_INSENSITIVE); @@ -287,7 +292,7 @@ private String enrichSpnWithRealm(String spn, boolean allowHostnameCanonicalizat } String dnsName = m.group(1); String portOrInstance = m.group(2); - RealmValidator realmValidator = getRealmValidator(); + RealmValidator realmValidator = getRealmValidator(dnsName); String realm = findRealmFromHostname(realmValidator, dnsName); if (realm == null && allowHostnameCanonicalization) { // We failed, try with canonical host name to find a better match @@ -315,9 +320,10 @@ private String enrichSpnWithRealm(String spn, boolean allowHostnameCanonicalizat /** * Find a suitable way of validating a REALM for given JVM. * + * @param hostnameToTest an example hostname we are gonna use to test our realm validator. * @return a not null realm Validator. */ - static RealmValidator getRealmValidator() { + static RealmValidator getRealmValidator(String hostnameToTest) { if (validator != null) { return validator; } @@ -327,15 +333,11 @@ static RealmValidator getRealmValidator() { Method getInstance = clz.getMethod("getInstance", new Class[0]); final Method getKDCList = clz.getMethod("getKDCList", new Class[] { String.class }); final Object instance = getInstance.invoke(null); - //DEBUG final Method getDefaultRealm = clz.getMethod("getDefaultRealm", new Class[0]); - //DEBUG final Object realmDefault = getDefaultRealm.invoke(instance); - //DEBUG System.err.println("default_realm="+realmDefault); RealmValidator oracleRealmValidator = new RealmValidator() { @Override public boolean isRealmValid(String realm) { try { - //DEBUG System.err.println("RealmValidator.isRealmValid("+realm+")"); Object ret = getKDCList.invoke(instance, realm); return ret!=null; } catch (Exception err) { @@ -344,15 +346,28 @@ public boolean isRealmValid(String realm) { } }; validator = oracleRealmValidator; - return oracleRealmValidator; + // As explained here: https://github.com/Microsoft/mssql-jdbc/pull/40#issuecomment-281509304 + // The default Oracle Resolution mechanism is not bulletproof + // If it resolves a crappy name, drop it. + if (!validator.isRealmValid("this.might.not.exist." + hostnameToTest)){ + // Our realm validator is well working, return it + authLogger.fine("Kerberos Realm Validator: Using Built-in Oracle Realm Validation method."); + return oracleRealmValidator; + } + authLogger.fine("Kerberos Realm Validator: Detected buggy Oracle Realm Validator, using DNSKerberosLocator."); } catch (ReflectiveOperationException notTheRightJVMException) { // Ignored, we simply are not using the right JVM + authLogger.fine("Kerberos Realm Validator: No Oracle Realm Validator Available, using DNSKerberosLocator."); } // No implementation found, default one, not any realm is valid validator = new RealmValidator() { @Override public boolean isRealmValid(String realm) { - return false; + try { + return DNSKerberosLocator.isRealmValid(realm); + } catch (NamingException err){ + return false; + } } }; return validator; @@ -365,13 +380,16 @@ public boolean isRealmValid(String realm) { * @param hostname the name we are looking a REALM for * @return the realm if found, null otherwise */ - private static String findRealmFromHostname(RealmValidator realmValidator, String hostname) { + private String findRealmFromHostname(RealmValidator realmValidator, String hostname) { if (hostname == null) { return null; } int index = 0; while (index != -1 && index < hostname.length() - 2) { String realm = hostname.substring(index); + if (authLogger.isLoggable(Level.FINEST)) { + authLogger.finest(toString() + " looking up REALM candidate " + realm); + } if (realmValidator.isRealmValid(realm)) { return realm.toUpperCase(); } @@ -390,10 +408,9 @@ interface RealmValidator { boolean isRealmValid(String realm); } - byte[] GenerateClientContext(byte[] pin, boolean[] done ) throws SQLServerException - { - if(null == peerContext) - { + byte[] GenerateClientContext(byte[] pin, + boolean[] done) throws SQLServerException { + if (null == peerContext) { intAuthInit(); } return intAuthHandShake(pin, done); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java new file mode 100644 index 000000000..3e493a13c --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java @@ -0,0 +1,33 @@ +package com.microsoft.sqlserver.jdbc.dns; + +import java.util.Set; + +import javax.naming.NameNotFoundException; +import javax.naming.NamingException; + +public final class DNSKerberosLocator { + + private DNSKerberosLocator() {} + + /** + * Tells whether a realm is valid. + * + * @param realmName the realm to test + * @return true if realm is valid, false otherwise + * @throws NamingException if DNS failed, so realm existence cannot be determined + */ + public static boolean isRealmValid(String realmName) throws NamingException { + if (realmName == null || realmName.length() < 2) { + return false; + } + if (realmName.startsWith(".")) { + realmName = realmName.substring(1); + } + try { + Set records = DNSUtilities.findSrvRecords("_kerberos._udp." + realmName); + return !records.isEmpty(); + } catch (NameNotFoundException wrongDomainException) { + return false; + } + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java new file mode 100644 index 000000000..08342981b --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java @@ -0,0 +1,142 @@ +package com.microsoft.sqlserver.jdbc.dns; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Describe an DNS SRV Record. + */ +public class DNSRecordSRV implements Comparable { + + private static final Pattern PATTERN = Pattern.compile("^([0-9]+) ([0-9]+) ([0-9]+) (.+)$"); + + private final int priority; + + /** + * Parse a DNS SRC Record from a DNS String record. + * + * @param record + * the record to parse + * @return a not null DNS Record + * @throws IllegalArgumentException + * if record is not correct and cannot be parsed + */ + public static DNSRecordSRV parseFromDNSRecord(String record) throws IllegalArgumentException { + Matcher m = PATTERN.matcher(record); + if (!m.matches()) { + throw new IllegalArgumentException("record '" + record + "' cannot be matched as a valid DNS SRV Record"); + } + try { + int priority = Integer.parseInt(m.group(1)); + int weight = Integer.parseInt(m.group(2)); + int port = Integer.parseInt(m.group(3)); + String serverName = m.group(4); + // Avoid issues with Kerberos SPN when fully qualified records ends with '.' + if (serverName.endsWith(".")) { + serverName = serverName.substring(0, serverName.length() - 1); + } + return new DNSRecordSRV(priority, weight, port, serverName); + } catch (IllegalArgumentException err) { + throw err; + } catch (Exception err) { + throw new IllegalArgumentException("Failed to parse DNS SRV record '" + record + "'", err); + } + } + + @Override + public String toString() { + return String.format("DNS.SRV[pri=%d w=%d port=%d h='%s']", priority, weight, port, serverName); + } + + /** + * Constructor. + * + * @param priority + * is lowest + * @param weight + * 1 at minimum + * @param port + * the port of service + * @param serverName + * the host + * @throws IllegalArgumentException + * if priority < 0 or weight <= 1 + */ + public DNSRecordSRV(int priority, int weight, int port, String serverName) throws IllegalArgumentException { + if (priority < 0) { + throw new IllegalArgumentException("priority must be >= 0, but was: " + priority); + } + this.priority = priority; + if (weight < 0) { + // Weight == 0 is OK to disable load balancing, but not below + throw new IllegalArgumentException("weight must be >= 0, but was: " + weight); + } + this.weight = weight; + if (port < 0 || port > 65535) { + throw new IllegalArgumentException("port must be between 0 and 65535, but was: " + port); + } + this.port = port; + if (serverName == null || serverName.trim().isEmpty()) { + throw new IllegalArgumentException("hostname is not supposed to be null or empty in a SRV Record"); + } + this.serverName = serverName; + } + + private final int weight; + private final int port; + private final String serverName; + + @Override + public int hashCode() { + return serverName.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (!(other instanceof DNSRecordSRV)) { + return false; + } + + DNSRecordSRV r = (DNSRecordSRV) other; + return port == r.port && weight == r.weight && priority == r.priority && serverName.equals(r.serverName); + } + + @Override + public int compareTo(DNSRecordSRV o) { + if (o == null) { + return 1; + } + int p = Integer.compare(priority, o.priority); + if (p != 0) { + return p; + } + p = Integer.compare(weight, o.weight); + if (p != 0) { + return p; + } + p = Integer.compare(port, o.port); + if (p != 0) { + return p; + } + return serverName.compareTo(o.serverName); + } + + public int getPriority() { + return priority; + } + + public int getWeight() { + return weight; + } + + public int getPort() { + return port; + } + + public String getServerName() { + return serverName; + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java new file mode 100644 index 000000000..a923e6048 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java @@ -0,0 +1,63 @@ +package com.microsoft.sqlserver.jdbc.dns; + +import java.util.Hashtable; +import java.util.Set; +import java.util.TreeSet; +import java.util.logging.Level; +import java.util.logging.Logger; + +import javax.naming.NamingEnumeration; +import javax.naming.NamingException; +import javax.naming.directory.Attribute; +import javax.naming.directory.Attributes; +import javax.naming.directory.DirContext; +import javax.naming.directory.InitialDirContext; + +public class DNSUtilities { + + private final static Logger LOG = Logger.getLogger(DNSUtilities.class.getName()); + + private static final Level DNS_ERR_LOG_LEVEL = Level.FINE; + + /** + * Find all SRV Record using DNS. + * You can then use {@link DNSRecordsSRVCollection#getBestRecord()} to find + * the best candidate (for instance for Round-Robin calls) + * + * @param dnsSrvRecordToFind + * the DNS record, for instance: _ldap._tcp.dc._msdcs.DOMAIN.COM + * to find all LDAP servers in DOMAIN.COM + * @return the collection of records with facilities to find the best + * candidate + * @throws NamingException + * if DNS is not available + */ + public static Set findSrvRecords(final String dnsSrvRecordToFind) throws NamingException { + Hashtable env = new Hashtable(); + env.put("java.naming.factory.initial", "com.sun.jndi.dns.DnsContextFactory"); + env.put("java.naming.provider.url", "dns:"); + DirContext ctx = new InitialDirContext(env); + Attributes attrs = ctx.getAttributes(dnsSrvRecordToFind, new String[] { "SRV" }); + NamingEnumeration allServers = attrs.getAll(); + TreeSet records = new TreeSet(); + while (allServers.hasMoreElements()) { + Attribute a = allServers.nextElement(); + NamingEnumeration srvRecord = a.getAll(); + while (srvRecord.hasMore()) { + final String record = String.valueOf(srvRecord.nextElement()); + try { + DNSRecordSRV rec = DNSRecordSRV.parseFromDNSRecord(record); + if (rec != null) { + records.add(rec); + } + } catch (IllegalArgumentException errorParsingRecord) { + if (LOG.isLoggable(DNS_ERR_LOG_LEVEL)) { + LOG.log(DNS_ERR_LOG_LEVEL, String.format("Failed to parse SRV DNS Record: '%s'", record), + errorParsingRecord); + } + } + } + } + return records; + } +} diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/dns/DNSRealmsTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/dns/DNSRealmsTest.java new file mode 100644 index 000000000..9753dadff --- /dev/null +++ b/src/test/java/com/microsoft/sqlserver/jdbc/dns/DNSRealmsTest.java @@ -0,0 +1,21 @@ +package com.microsoft.sqlserver.jdbc.dns; + +import javax.naming.NamingException; + +public class DNSRealmsTest { + + public static void main(String... args) { + if (args.length < 1) { + System.err.println("USAGE: list of domains to test for kerberos realms"); + } + for (String realmName : args) { + try { + System.out.print(DNSKerberosLocator.isRealmValid(realmName) ? "[ VALID ] " : "[INVALID] "); + } catch (NamingException err) { + System.err.print("[ FAILED] : " + err.getClass().getName() + ":" + err.getMessage()); + } + System.out.println(realmName); + } + } + +}