All downloads are free. Search and download functionality uses the official Maven repository.

org.ldaptive.provider.DnsSrvConnectionStrategy Maven / Gradle / Ivy

There is a newer version: 2.4.0
Show newest version
/* See LICENSE for licensing and NOTICE for copyright. */
package org.ldaptive.provider;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;
import javax.naming.Context;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.directory.Attributes;
import javax.naming.directory.DirContext;
import javax.naming.directory.InitialDirContext;
import org.ldaptive.LdapUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * DNS SRV connection strategy. Queries a DNS server for SRV records and uses those records to construct a list of URLs.
 * A time to live can be set to control how often the DNS server is consulted. See http://www.ietf.org/rfc/rfc2782.txt.
 *
 * @author  Middleware Services
 */
public class DnsSrvConnectionStrategy implements ConnectionStrategy
{

  /** JNDI context factory for DNS. */
  private static final String DNS_CONTEXT_FACTORY = "com.sun.jndi.dns.DnsContextFactory";

  /** JNDI context factory for DNS. */
  private static final String DNS_PROVIDER_URL = "dns:";

  /** Default time to live for DNS results. Value is {@value}. */
  private static final long DEFAULT_TTL = 60L * 60L * 1000L;

  /** Logger for this class. */
  protected final Logger logger = LoggerFactory.getLogger(getClass());

  /** JNDI environment. */
  private Map jndiEnv = new HashMap<>();

  /** Time to live for SRV records in milliseconds. */
  private long srvTtl;

  /** SRV records from the last DNS lookup. */
  private List srvRecords;


  /** Creates a new DNS SRV connection strategy. */
  public DnsSrvConnectionStrategy()
  {
    this(null, DEFAULT_TTL);
  }


  /**
   * Creates a new DNS SRV connection strategy.
   *
   * @param  ttl  time to live in milliseconds for SRV records
   */
  public DnsSrvConnectionStrategy(final long ttl)
  {
    this(null, ttl);
  }


  /**
   * Creates a new DNS SRV connection strategy.
   *
   * @param  env  JNDI environment
   * @param  ttl  time to live in milliseconds for SRV records
   */
  public DnsSrvConnectionStrategy(final Map env, final long ttl)
  {
    if (env != null) {
      setJndiEnvironment(env);
    }
    setTimeToLive(ttl);
  }


  /**
   * Returns the JNDI environment used for DNS lookup.
   *
   * @return  jndi environment
   */
  public Map getJndiEnvironment()
  {
    return jndiEnv;
  }


  /**
   * Returns the time that DNS lookups will be cached.
   *
   * @return  time to live in milliseconds
   */
  public long getTimeToLive()
  {
    return srvTtl;
  }


  /**
   * Sets the JNDI environment used for DNS lookups. If no {@link Context#INITIAL_CONTEXT_FACTORY} is set, it is
   * defaulted to {@link #DNS_CONTEXT_FACTORY}. If no {@link Context#PROVIDER_URL} is set, it is defaulted to {@link
   * #DNS_PROVIDER_URL}.
   *
   * @param  env  jndi environment or null
   */
  public void setJndiEnvironment(final Map env)
  {
    jndiEnv = new HashMap<>(env);
  }


  /**
   * Sets the time that DNS lookups will be cached.
   *
   * @param  ttl  time to live in milliseconds
   */
  public void setTimeToLive(final long ttl)
  {
    srvTtl = ttl;
  }


  /**
   * Returns a list of URLs retrieved from DNS SRV records. The LDAP URL in the supplied metadata can be a space
   * delimited list of DNS servers, each will be tried in order.
   *
   * @param  metadata  which can be used to produce the URL list
   *
   * @return  list of URLs to attempt connections to
   */
  @Override
  public String[] getLdapUrls(final ConnectionFactoryMetadata metadata)
  {
    if (metadata == null || metadata.getLdapUrl() == null) {
      return null;
    }
    if (srvRecords == null ||
        srvRecords.isEmpty() ||
        System.currentTimeMillis() >= srvRecords.get(0).getExpirationTime()) {
      try {
        srvRecords = sortSrvRecords(retrieveDNSRecords(metadata.getLdapUrl(), jndiEnv, srvTtl));
      } catch (NamingException e) {
        throw new IllegalArgumentException("Could not retrieve DNS SRV record for " + metadata.getLdapUrl(), e);
      }
      if (srvRecords.isEmpty()) {
        throw new IllegalArgumentException("No DNS SRV records found for " + metadata.getLdapUrl());
      }
      logger.debug("Retrieved SRV records from DNS: {}", srvRecords);
    } else {
      logger.debug("Using SRV records from internal cache: {}", srvRecords);
    }

    final String[] urls = new String[srvRecords.size()];
    for (int i = 0; i < srvRecords.size(); i++) {
      urls[i] = srvRecords.get(i).getLdapURL();
    }
    return urls;
  }


  /**
   * Uses JNDI to retrieve the DNS SRV record from the supplied url. The supplied properties are passed into the JNDI
   * context.
   *
   * @param  name  of the SRV records
   * @param  props  for the JNDI context
   * @param  ttl  time to live for each SRV record
   *
   * @return  list of LDAP URLs
   *
   * @throws  NamingException  if the DNS record cannot be retrieved
   */
  protected List retrieveDNSRecords(final String name, final Map props, final long ttl)
    throws NamingException
  {
    final List records = new ArrayList<>();
    DirContext context = null;
    NamingEnumeration en = null;
    try {
      // CheckStyle:IllegalType OFF
      final Hashtable env = new Hashtable<>(props);
      // CheckStyle:IllegalType ON
      if (!env.containsKey(Context.INITIAL_CONTEXT_FACTORY)) {
        env.put(Context.INITIAL_CONTEXT_FACTORY, DNS_CONTEXT_FACTORY);
      }
      if (!env.containsKey(Context.PROVIDER_URL)) {
        env.put(Context.PROVIDER_URL, DNS_PROVIDER_URL);
      }
      context = new InitialDirContext(env);

      final Attributes attrs = context.getAttributes(name, new String[] {"SRV", });
      if (attrs != null) {
        final Attribute attr = attrs.get("SRV");
        if (attr != null) {
          en = attr.getAll();

          final long expTime = System.currentTimeMillis() + ttl;
          while (en.hasMore()) {
            records.add(new SrvRecord((String) en.next(), expTime));
          }
        }
      }
    } finally {
      if (en != null) {
        en.close();
      }
      if (context != null) {
        context.close();
      }
    }
    return records;
  }


  /**
   * Sorts the supplied SRV records according to RFC 2782. Records with the lowest priority are first. Records with the
   * same priority are arranged by weight with higher weights having a greater chance to be ordered first.
   *
   * @param  records  to sort
   *
   * @return  sorted records
   */
  protected List sortSrvRecords(final List records)
  {
    // group records and order them by priority
    final Map> priorityRecords = new TreeMap<>();
    for (SrvRecord record : records) {
      final List priority;
      if (!priorityRecords.containsKey(record.getPriority())) {
        priority = new ArrayList<>();
        priorityRecords.put(record.getPriority(), priority);
      } else {
        priority = priorityRecords.get(record.getPriority());
      }
      priority.add(record);
    }

    // order records by priority then by weight
    // unweighted records are ordered last
    final List sortedRecords = new ArrayList<>();
    for (Map.Entry> entry : priorityRecords.entrySet()) {
      final Map weighted = new HashMap<>();
      final List unweighted = new ArrayList<>();
      long totalWeight = 0;
      for (SrvRecord record : entry.getValue()) {
        if (record.getWeight() == 0) {
          unweighted.add(record);
        } else {
          totalWeight += record.getWeight();
          weighted.put(totalWeight, record);
        }
      }

      while (!weighted.isEmpty()) {
        SrvRecord record = null;
        final Iterator i = weighted.keySet().iterator();
        final long random = ThreadLocalRandom.current().nextLong(totalWeight + 1);
        while (i.hasNext()) {
          final Long weight = i.next();
          if (weight >= random) {
            record = weighted.get(weight);
            totalWeight -= record.getWeight();
            i.remove();
            break;
          }
        }
        sortedRecords.add(record);
      }
      if (!unweighted.isEmpty()) {
        sortedRecords.addAll(unweighted);
      }
    }

    return sortedRecords;
  }


  @Override
  public String toString()
  {
    return
      String.format(
        "[%s@%d::jndiEnv=%s, srvTtl=%s, srvRecords=%s]",
        getClass().getName(),
        hashCode(),
        jndiEnv,
        srvTtl,
        srvRecords);
  }


  /** SRV record. */
  protected static class SrvRecord
  {

    /** hash code seed. */
    private static final int HASH_CODE_SEED = 1201;

    /** SRV priority. */
    private final long priority;

    /** SRV weight. */
    private final long weight;

    /** SRV port. */
    private final int port;

    /** SRV target. */
    private final String target;

    /** expiration time. */
    private final long expirationTime;


    /**
     * Creates a new SRV record.
     *
     * @param  record  from DNS
     * @param  time  that this record should expire
     */
    public SrvRecord(final String record, final long time)
    {
      final String[] parts = record.split(" ");
      int i = 0;
      priority = Long.parseLong(parts[i++]);
      weight = Long.parseLong(parts[i++]);
      port = Integer.parseInt(parts[i++]);
      target = parts[i].endsWith(".") ? parts[i].substring(0, parts[i].length() - 1) : parts[i];
      expirationTime = time;
    }


    /**
     * Returns the priority.
     *
     * @return  priority
     */
    public long getPriority()
    {
      return priority;
    }


    /**
     * Returns the weight.
     *
     * @return  weight
     */
    public long getWeight()
    {
      return weight;
    }


    /**
     * Returns the port.
     *
     * @return  port
     */
    public int getPort()
    {
      return port;
    }


    /**
     * Returns the target.
     *
     * @return  target
     */
    public String getTarget()
    {
      return target;
    }


    /**
     * Returns the target properly formatted as an LDAP URL.
     *
     * @return  LDAP URL
     */
    public String getLdapURL()
    {
      return String.format("ldap://%s:%s", target, port);
    }


    /**
     * Returns the time in milliseconds that this record should expire.
     *
     * @return  expiration time
     */
    public long getExpirationTime()
    {
      return expirationTime;
    }


    @Override
    public boolean equals(final Object o)
    {
      return LdapUtils.areEqual(this, o);
    }


    @Override
    public int hashCode()
    {
      return LdapUtils.computeHashCode(HASH_CODE_SEED, priority, weight, port, target, expirationTime);
    }


    @Override
    public String toString()
    {
      return
        String.format(
          "[%s@%d::priority=%s, weight=%s, port=%s, target=%s, " +
          "expirationTime=%s]",
          getClass().getName(),
          hashCode(),
          priority,
          weight,
          port,
          target,
          expirationTime);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy