From d6de0906f44242d0226ad26022a58d3fc5c4904e Mon Sep 17 00:00:00 2001
From: Pierre Souchay
Date: Wed, 22 Feb 2017 23:54:59 +0100
Subject: [PATCH] Added DNS Fallback for REALM resolution
---
.../sqlserver/jdbc/KerbAuthentication.java | 45 ++++--
.../jdbc/dns/DNSKerberosLocator.java | 33 ++++
.../sqlserver/jdbc/dns/DNSRecordSRV.java | 142 ++++++++++++++++++
.../sqlserver/jdbc/dns/DNSUtilities.java | 63 ++++++++
.../sqlserver/jdbc/dns/DNSRealmsTest.java | 21 +++
5 files changed, 290 insertions(+), 14 deletions(-)
create mode 100644 src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java
create mode 100644 src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java
create mode 100644 src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java
create mode 100644 src/test/java/com/microsoft/sqlserver/jdbc/dns/DNSRealmsTest.java
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java b/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java
index f82d913cf..e0a256703 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/KerbAuthentication.java
@@ -23,6 +23,7 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import javax.naming.NamingException;
import javax.security.auth.Subject;
import javax.security.auth.login.AppConfigurationEntry;
import javax.security.auth.login.Configuration;
@@ -36,6 +37,8 @@
import org.ietf.jgss.GSSName;
import org.ietf.jgss.Oid;
+import com.microsoft.sqlserver.jdbc.dns.DNSKerberosLocator;
+
/**
* KerbAuthentication for int auth.
*/
@@ -268,7 +271,9 @@ private String makeSpn(String server,
spn = makeSpn(address, port);
}
this.spn = enrichSpnWithRealm(spn, null == userSuppliedServerSpn);
- //DEBUG System.err.println("SPN before enrichment: " + spn + " ; AFTER enrichment: " + this.spn);
+ if (!this.spn.equals(spn) && authLogger.isLoggable(Level.FINER)){
+ authLogger.finer(toString() + "SPN enriched: " + spn + " := " + this.spn);
+ }
}
private static final Pattern SPN_PATTERN = Pattern.compile("MSSQLSvc/(.*):([^:@]+)(@.+)?", Pattern.CASE_INSENSITIVE);
@@ -287,7 +292,7 @@ private String enrichSpnWithRealm(String spn, boolean allowHostnameCanonicalizat
}
String dnsName = m.group(1);
String portOrInstance = m.group(2);
- RealmValidator realmValidator = getRealmValidator();
+ RealmValidator realmValidator = getRealmValidator(dnsName);
String realm = findRealmFromHostname(realmValidator, dnsName);
if (realm == null && allowHostnameCanonicalization) {
// We failed, try with canonical host name to find a better match
@@ -315,9 +320,10 @@ private String enrichSpnWithRealm(String spn, boolean allowHostnameCanonicalizat
/**
* Find a suitable way of validating a REALM for given JVM.
*
+ * @param hostnameToTest an example hostname we are gonna use to test our realm validator.
* @return a not null realm Validator.
*/
- static RealmValidator getRealmValidator() {
+ static RealmValidator getRealmValidator(String hostnameToTest) {
if (validator != null) {
return validator;
}
@@ -327,15 +333,11 @@ static RealmValidator getRealmValidator() {
Method getInstance = clz.getMethod("getInstance", new Class[0]);
final Method getKDCList = clz.getMethod("getKDCList", new Class[] { String.class });
final Object instance = getInstance.invoke(null);
- //DEBUG final Method getDefaultRealm = clz.getMethod("getDefaultRealm", new Class[0]);
- //DEBUG final Object realmDefault = getDefaultRealm.invoke(instance);
- //DEBUG System.err.println("default_realm="+realmDefault);
RealmValidator oracleRealmValidator = new RealmValidator() {
@Override
public boolean isRealmValid(String realm) {
try {
- //DEBUG System.err.println("RealmValidator.isRealmValid("+realm+")");
Object ret = getKDCList.invoke(instance, realm);
return ret!=null;
} catch (Exception err) {
@@ -344,15 +346,28 @@ public boolean isRealmValid(String realm) {
}
};
validator = oracleRealmValidator;
- return oracleRealmValidator;
+ // As explained here: https://github.com/Microsoft/mssql-jdbc/pull/40#issuecomment-281509304
+ // The default Oracle Resolution mechanism is not bulletproof
+ // If it resolves a crappy name, drop it.
+ if (!validator.isRealmValid("this.might.not.exist." + hostnameToTest)){
+ // Our realm validator is well working, return it
+ authLogger.fine("Kerberos Realm Validator: Using Built-in Oracle Realm Validation method.");
+ return oracleRealmValidator;
+ }
+ authLogger.fine("Kerberos Realm Validator: Detected buggy Oracle Realm Validator, using DNSKerberosLocator.");
} catch (ReflectiveOperationException notTheRightJVMException) {
// Ignored, we simply are not using the right JVM
+ authLogger.fine("Kerberos Realm Validator: No Oracle Realm Validator Available, using DNSKerberosLocator.");
}
// No implementation found, default one, not any realm is valid
validator = new RealmValidator() {
@Override
public boolean isRealmValid(String realm) {
- return false;
+ try {
+ return DNSKerberosLocator.isRealmValid(realm);
+ } catch (NamingException err){
+ return false;
+ }
}
};
return validator;
@@ -365,13 +380,16 @@ public boolean isRealmValid(String realm) {
* @param hostname the name we are looking a REALM for
* @return the realm if found, null otherwise
*/
- private static String findRealmFromHostname(RealmValidator realmValidator, String hostname) {
+ private String findRealmFromHostname(RealmValidator realmValidator, String hostname) {
if (hostname == null) {
return null;
}
int index = 0;
while (index != -1 && index < hostname.length() - 2) {
String realm = hostname.substring(index);
+ if (authLogger.isLoggable(Level.FINEST)) {
+ authLogger.finest(toString() + " looking up REALM candidate " + realm);
+ }
if (realmValidator.isRealmValid(realm)) {
return realm.toUpperCase();
}
@@ -390,10 +408,9 @@ interface RealmValidator {
boolean isRealmValid(String realm);
}
- byte[] GenerateClientContext(byte[] pin, boolean[] done ) throws SQLServerException
- {
- if(null == peerContext)
- {
+ byte[] GenerateClientContext(byte[] pin,
+ boolean[] done) throws SQLServerException {
+ if (null == peerContext) {
intAuthInit();
}
return intAuthHandShake(pin, done);
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java
new file mode 100644
index 000000000..3e493a13c
--- /dev/null
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSKerberosLocator.java
@@ -0,0 +1,33 @@
+package com.microsoft.sqlserver.jdbc.dns;
+
+import java.util.Set;
+
+import javax.naming.NameNotFoundException;
+import javax.naming.NamingException;
+
+public final class DNSKerberosLocator {
+
+ private DNSKerberosLocator() {}
+
+ /**
+ * Tells whether a realm is valid.
+ *
+ * @param realmName the realm to test
+ * @return true if realm is valid, false otherwise
+ * @throws NamingException if DNS failed, so realm existence cannot be determined
+ */
+ public static boolean isRealmValid(String realmName) throws NamingException {
+ if (realmName == null || realmName.length() < 2) {
+ return false;
+ }
+ if (realmName.startsWith(".")) {
+ realmName = realmName.substring(1);
+ }
+ try {
+ Set records = DNSUtilities.findSrvRecords("_kerberos._udp." + realmName);
+ return !records.isEmpty();
+ } catch (NameNotFoundException wrongDomainException) {
+ return false;
+ }
+ }
+}
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java
new file mode 100644
index 000000000..08342981b
--- /dev/null
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSRecordSRV.java
@@ -0,0 +1,142 @@
+package com.microsoft.sqlserver.jdbc.dns;
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * Describe an DNS SRV Record.
+ */
+public class DNSRecordSRV implements Comparable {
+
+ private static final Pattern PATTERN = Pattern.compile("^([0-9]+) ([0-9]+) ([0-9]+) (.+)$");
+
+ private final int priority;
+
+ /**
+ * Parse a DNS SRC Record from a DNS String record.
+ *
+ * @param record
+ * the record to parse
+ * @return a not null DNS Record
+ * @throws IllegalArgumentException
+ * if record is not correct and cannot be parsed
+ */
+ public static DNSRecordSRV parseFromDNSRecord(String record) throws IllegalArgumentException {
+ Matcher m = PATTERN.matcher(record);
+ if (!m.matches()) {
+ throw new IllegalArgumentException("record '" + record + "' cannot be matched as a valid DNS SRV Record");
+ }
+ try {
+ int priority = Integer.parseInt(m.group(1));
+ int weight = Integer.parseInt(m.group(2));
+ int port = Integer.parseInt(m.group(3));
+ String serverName = m.group(4);
+ // Avoid issues with Kerberos SPN when fully qualified records ends with '.'
+ if (serverName.endsWith(".")) {
+ serverName = serverName.substring(0, serverName.length() - 1);
+ }
+ return new DNSRecordSRV(priority, weight, port, serverName);
+ } catch (IllegalArgumentException err) {
+ throw err;
+ } catch (Exception err) {
+ throw new IllegalArgumentException("Failed to parse DNS SRV record '" + record + "'", err);
+ }
+ }
+
+ @Override
+ public String toString() {
+ return String.format("DNS.SRV[pri=%d w=%d port=%d h='%s']", priority, weight, port, serverName);
+ }
+
+ /**
+ * Constructor.
+ *
+ * @param priority
+ * is lowest
+ * @param weight
+ * 1 at minimum
+ * @param port
+ * the port of service
+ * @param serverName
+ * the host
+ * @throws IllegalArgumentException
+ * if priority < 0 or weight <= 1
+ */
+ public DNSRecordSRV(int priority, int weight, int port, String serverName) throws IllegalArgumentException {
+ if (priority < 0) {
+ throw new IllegalArgumentException("priority must be >= 0, but was: " + priority);
+ }
+ this.priority = priority;
+ if (weight < 0) {
+ // Weight == 0 is OK to disable load balancing, but not below
+ throw new IllegalArgumentException("weight must be >= 0, but was: " + weight);
+ }
+ this.weight = weight;
+ if (port < 0 || port > 65535) {
+ throw new IllegalArgumentException("port must be between 0 and 65535, but was: " + port);
+ }
+ this.port = port;
+ if (serverName == null || serverName.trim().isEmpty()) {
+ throw new IllegalArgumentException("hostname is not supposed to be null or empty in a SRV Record");
+ }
+ this.serverName = serverName;
+ }
+
+ private final int weight;
+ private final int port;
+ private final String serverName;
+
+ @Override
+ public int hashCode() {
+ return serverName.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == this) {
+ return true;
+ }
+ if (!(other instanceof DNSRecordSRV)) {
+ return false;
+ }
+
+ DNSRecordSRV r = (DNSRecordSRV) other;
+ return port == r.port && weight == r.weight && priority == r.priority && serverName.equals(r.serverName);
+ }
+
+ @Override
+ public int compareTo(DNSRecordSRV o) {
+ if (o == null) {
+ return 1;
+ }
+ int p = Integer.compare(priority, o.priority);
+ if (p != 0) {
+ return p;
+ }
+ p = Integer.compare(weight, o.weight);
+ if (p != 0) {
+ return p;
+ }
+ p = Integer.compare(port, o.port);
+ if (p != 0) {
+ return p;
+ }
+ return serverName.compareTo(o.serverName);
+ }
+
+ public int getPriority() {
+ return priority;
+ }
+
+ public int getWeight() {
+ return weight;
+ }
+
+ public int getPort() {
+ return port;
+ }
+
+ public String getServerName() {
+ return serverName;
+ }
+}
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java
new file mode 100644
index 000000000..a923e6048
--- /dev/null
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/dns/DNSUtilities.java
@@ -0,0 +1,63 @@
+package com.microsoft.sqlserver.jdbc.dns;
+
+import java.util.Hashtable;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+import javax.naming.NamingEnumeration;
+import javax.naming.NamingException;
+import javax.naming.directory.Attribute;
+import javax.naming.directory.Attributes;
+import javax.naming.directory.DirContext;
+import javax.naming.directory.InitialDirContext;
+
+public class DNSUtilities {
+
+ private final static Logger LOG = Logger.getLogger(DNSUtilities.class.getName());
+
+ private static final Level DNS_ERR_LOG_LEVEL = Level.FINE;
+
+ /**
+ * Find all SRV Record using DNS.
+ * You can then use {@link DNSRecordsSRVCollection#getBestRecord()} to find
+ * the best candidate (for instance for Round-Robin calls)
+ *
+ * @param dnsSrvRecordToFind
+ * the DNS record, for instance: _ldap._tcp.dc._msdcs.DOMAIN.COM
+ * to find all LDAP servers in DOMAIN.COM
+ * @return the collection of records with facilities to find the best
+ * candidate
+ * @throws NamingException
+ * if DNS is not available
+ */
+ public static Set findSrvRecords(final String dnsSrvRecordToFind) throws NamingException {
+ Hashtable