All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.milton.dns.server.JNameServer Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package io.milton.dns.server;
//import jnamed;

import io.milton.common.Service;
import io.milton.dns.resource.ZoneDomainResource;
import io.milton.dns.resource.DomainResourceFactory;
import io.milton.dns.resource.DomainResource;
import io.milton.dns.resource.NonAuthoritativeException;
import io.milton.dns.Name;
import io.milton.dns.TextParseException;
import io.milton.dns.record.CNAMERecord;
import io.milton.dns.record.DClass;
import io.milton.dns.record.DNAMERecord;
import io.milton.dns.record.ExtendedFlags;
import io.milton.dns.record.Flags;
import io.milton.dns.record.Header;
import io.milton.dns.record.Message;
import io.milton.dns.record.NameTooLongException;
import io.milton.dns.record.OPTRecord;
import io.milton.dns.record.Opcode;
import io.milton.dns.record.RRset;
import io.milton.dns.record.Rcode;
import io.milton.dns.record.Record;
import io.milton.dns.record.SOARecord;
import io.milton.dns.record.Section;
import io.milton.dns.record.SetResponse;
import io.milton.dns.record.TSIG;
import io.milton.dns.record.TSIGRecord;
import io.milton.dns.record.Type;
import io.milton.dns.resource.DomainResourceRecord;
import io.milton.dns.utils.RecordTypes;
import io.milton.dns.utils.Utils;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Minimal authoritative DNS server whose records come from a
 * {@link DomainResourceFactory}. Its structure follows dnsjava's {@code jnamed}
 * sample server.
 *
 * <p>{@link #start()} spawns one TCP and one UDP listener thread per configured
 * socket address; {@link #stop()} clears the running flag and closes every
 * listener socket. AXFR requests are refused and TSIG-signed queries are
 * answered with FORMERR.
 */
public class JNameServer implements Service {

	private static final Logger log = LoggerFactory.getLogger(JNameServer.class.getName());
	/** The query's OPT record has the DO bit set: include signature records. */
	static final int FLAG_DNSSECOK = 1;
	/** The query asked for SIG/RRSIG: emit only signatures, not the data records. */
	static final int FLAG_SIGONLY = 2;
	/** Maximum recursion depth when following CNAME/DNAME chains. */
	private static final int MAX_ITERATIONS = 6;
	/** Size of the UDP receive buffer. */
	private static final int UDP_LENGTH = 512;

	final DomainResourceFactory drf;
	final List<InetSocketAddress> sockAddrs = new ArrayList<>();
	final List<TcpListener> tcpListeners = new ArrayList<>();
	final List<UdpListener> udpListeners = new ArrayList<>();
	private final RecordTypes recordTypes = new RecordTypes();
	volatile boolean running;

	/**
	 * Creates a server that listens on the given addresses once started.
	 *
	 * @param drf       source of the domain records that are served
	 * @param sockAddrs addresses to listen on; when none are given (or the array
	 *                  is {@code null}) the server listens on port 53 on the
	 *                  wildcard address
	 */
	public JNameServer(DomainResourceFactory drf, InetSocketAddress... sockAddrs) {
		this.drf = drf;
		// Calling the constructor with no addresses yields an empty array, not
		// null; both mean "use the default DNS port".
		if (sockAddrs == null || sockAddrs.length == 0) {
			this.sockAddrs.add(new InetSocketAddress(53));
		} else {
			this.sockAddrs.addAll(Arrays.asList(sockAddrs));
		}
	}

	/** Starts a TCP and a UDP listener thread for every configured address. */
	@Override
	public void start() {
		log.info("Starting DNS server");
		running = true;
		for (InetSocketAddress sockAddr : sockAddrs) {
			log.info("Listening on interface: " + sockAddr);
			TcpListener tl = new TcpListener(sockAddr);
			tcpListeners.add(tl);
			new Thread(tl, "dns-tcp-" + sockAddr).start();

			UdpListener ul = new UdpListener(sockAddr);
			udpListeners.add(ul);
			new Thread(ul, "dns-udp-" + sockAddr).start();
		}
		log.info("started server");
	}

	/**
	 * Stops accepting queries: clears the running flag, closes every listener
	 * socket (which unblocks the listener threads) and forgets the listeners so a
	 * later {@link #start()} does not accumulate closed ones.
	 */
	@Override
	public void stop() {
		running = false;
		for (TcpListener tl : tcpListeners) {
			tl.close();
		}
		for (UdpListener ul : udpListeners) {
			ul.close();
		}
		tcpListeners.clear();
		udpListeners.clear();
	}

	/**
	 * Adds {@code rrset} to {@code section} of the response under {@code name},
	 * unless an RRset of the same name and type is already present in any section
	 * numbered 1 through {@code section}.
	 *
	 * <p>Records owned by a wildcard name are rewritten to the (non-wildcard)
	 * query name. Data records are skipped when {@link #FLAG_SIGONLY} is set;
	 * signature records are added when {@link #FLAG_SIGONLY} or
	 * {@link #FLAG_DNSSECOK} is set.
	 */
	private void addRRset(Name name, Message response, RRset rrset, int section, int flags) {
		for (int s = 1; s <= section; s++) {
			if (response.findRRset(name, rrset.getType(), s)) {
				return;
			}
		}
		if ((flags & FLAG_SIGONLY) == 0) {
			Iterator it = rrset.rrs();
			while (it.hasNext()) {
				Record r = (Record) it.next();
				if (r.getName().isWild() && !name.isWild()) {
					r = r.withName(name);
				}
				response.addRecord(r, section);
			}
		}
		if ((flags & (FLAG_SIGONLY | FLAG_DNSSECOK)) != 0) {
			Iterator it = rrset.sigs();
			while (it.hasNext()) {
				Record r = (Record) it.next();
				if (r.getName().isWild() && !name.isWild()) {
					r = r.withName(name);
				}
				response.addRecord(r, section);
			}
		}
	}

	/** Adds an SOA record to the authority section. */
	private void addSOA(Message response, SOARecord soaRecord) {
		response.addRecord(soaRecord, Section.AUTHORITY);
	}

	/**
	 * Adds an NS RRset to the authority section under its own owner name.
	 * Does nothing when {@code nsRecords} is null (e.g. a zone without NS records).
	 */
	private void addNS(Message response, RRset nsRecords, int flags) {
		if (nsRecords == null) {
			return;
		}
		addRRset(nsRecords.getName(), response, nsRecords, Section.AUTHORITY, flags);
	}

	/** Adds the A records of {@code name}, if it is served here, to the additional section. */
	private void addGlue(Message response, Name name, int flags) {
		DomainResource dr = getDomainResource(name.toString());
		if (dr == null) {
			return;
		}
		RRset a = getRRset(toName(dr.getName()), dr, Type.A, DClass.IN);
		if (a == null) {
			return;
		}
		addRRset(name, response, a, Section.ADDITIONAL, flags);
	}

	/** Adds glue for every record in {@code section} that names an additional host. */
	private void addAdditional2(Message response, int section, int flags) {
		for (Record r : response.getSectionArray(section)) {
			Name glueName = r.getAdditionalName();
			if (glueName != null) {
				addGlue(response, glueName, flags);
			}
		}
	}

	/** Adds glue records for the answer and authority sections. */
	private void addAdditional(Message response, int flags) {
		addAdditional2(response, Section.ANSWER, flags);
		addAdditional2(response, Section.AUTHORITY, flags);
	}

	/**
	 * Fills the response with the answer for {@code name}/{@code type},
	 * recursing through CNAME and DNAME chains (at most {@link #MAX_ITERATIONS} hops).
	 *
	 * @param iterations recursion depth; the AA flag is only set at depth 0
	 * @return the response code for this lookup ({@code NOERROR}, {@code NXDOMAIN},
	 *         or {@code YXDOMAIN} when a DNAME substitution makes the name too long)
	 */
	private byte addAnswer(Message response, Name name, int type, int dclass, int iterations, int flags) {
		log.debug("addAnswer: {} type={} class={}", name, type, dclass);
		if (iterations > MAX_ITERATIONS) {
			log.warn("iterations too high");
			return Rcode.NOERROR;
		}

		if (type == Type.SIG || type == Type.RRSIG) {
			type = Type.ANY;
			flags |= FLAG_SIGONLY;
		}

		SetResponse sr = generateSetResponse(name, type);
		ZoneDomainResource zdr = findBestZone(drf, name);
		// Null when no enclosing zone is served here; every use below must
		// tolerate that instead of dereferencing zdr unconditionally.
		Name zoneName = (zdr == null) ? null : toName(zdr.getName());
		byte rcode = Rcode.NOERROR;

		if (sr.isNXDOMAIN()) {
			log.info("is NX domain");
			response.getHeader().setRcode(Rcode.NXDOMAIN);
			addZoneSOA(response, zdr, zoneName, iterations);
			rcode = Rcode.NXDOMAIN;
		} else if (sr.isNXRRSET()) {
			log.info("isNXRRSET");
			addZoneSOA(response, zdr, zoneName, iterations);
		} else if (sr.isDelegation()) {
			log.info("delegation");
			addNS(response, sr.getNS(), flags);
		} else if (sr.isCNAME()) {
			log.info("isCNAME");
			CNAMERecord cname = sr.getCNAME();
			addRRset(name, response, new RRset(cname), Section.ANSWER, flags);
			if (zdr != null && iterations == 0) {
				response.getHeader().setFlag(Flags.AA);
			}
			rcode = addAnswer(response, cname.getTarget(), type, dclass, iterations + 1, flags);
		} else if (sr.isDNAME()) {
			log.info("isDNAME");
			DNAMERecord dname = sr.getDNAME();
			addRRset(name, response, new RRset(dname), Section.ANSWER, flags);
			Name newname;
			try {
				newname = name.fromDNAME(dname);
			} catch (NameTooLongException e) {
				return Rcode.YXDOMAIN;
			}
			// Synthesize the CNAME implied by the DNAME substitution.
			RRset synthesized = new RRset(new CNAMERecord(name, dclass, 0, newname));
			addRRset(name, response, synthesized, Section.ANSWER, flags);
			if (zdr != null && iterations == 0) {
				response.getHeader().setFlag(Flags.AA);
			}
			rcode = addAnswer(response, newname, type, dclass, iterations + 1, flags);
		} else if (sr.isSuccessful()) {
			for (RRset rrset : sr.answers()) {
				addRRset(name, response, rrset, Section.ANSWER, flags);
			}
			if (zdr != null) {
				addNS(response, getRRset(zoneName, zdr, Type.NS, DClass.IN), flags);
				if (iterations == 0) {
					response.getHeader().setFlag(Flags.AA);
				}
			}
		}
		log.info(" = " + rcode);
		return rcode;
	}

	/**
	 * For negative answers: adds the enclosing zone's SOA (if it has one) to the
	 * authority section and, at recursion depth 0, sets the AA flag.
	 * Does nothing when {@code zdr} is null.
	 */
	private void addZoneSOA(Message response, ZoneDomainResource zdr, Name zoneName, int iterations) {
		if (zdr == null) {
			return;
		}
		RRset soa = getRRset(zoneName, zdr, Type.SOA, DClass.IN);
		if (soa != null) {
			addSOA(response, (SOARecord) soa.first());
		}
		if (iterations == 0) {
			response.getHeader().setFlag(Flags.AA);
		}
	}

	/**
	 * Classifies a lookup of {@code name}/{@code type} against the resource
	 * factory: delegation, ANY, exact match, CNAME, DNAME, NXRRSET, then a
	 * wildcard fallback, and finally NXDOMAIN.
	 */
	private SetResponse generateSetResponse(Name name, int type) {
		Name origin = new Name(name, name.labels());
		int olabels = origin.labels();
		int labels = name.labels();

		// Walk from the shortest suffix of the name towards the full name.
		for (int tlabels = olabels; tlabels <= labels; tlabels++) {
			boolean isOrigin = (tlabels == olabels);
			boolean isExact = (tlabels == labels);

			Name tname;
			if (isOrigin) {
				tname = origin;
			} else if (isExact) {
				tname = name;
			} else {
				tname = new Name(name, labels - tlabels);
			}

			DomainResource dr = getDomainResource(tname.toString());
			if (dr == null) {
				continue;
			}
			Name domainName = toName(dr.getName());

			/* If this is a delegation, return that. */
			if (!(dr instanceof ZoneDomainResource)) {
				RRset ns = getRRset(domainName, dr, Type.NS, DClass.IN);
				if (ns != null) {
					return new SetResponse(SetResponse.DELEGATION, ns);
				}
			}

			if (isExact) {
				/* If this is an ANY lookup, return everything. */
				if (type == Type.ANY) {
					SetResponse sr = new SetResponse(SetResponse.SUCCESSFUL);
					for (RRset set : getAllRRsets(domainName, dr, DClass.IN)) {
						sr.addRRset(set);
					}
					return sr;
				}
				/* Look for the actual type, then a CNAME. */
				RRset rrset = getRRset(domainName, dr, type, DClass.IN);
				if (rrset != null) {
					SetResponse sr = new SetResponse(SetResponse.SUCCESSFUL);
					sr.addRRset(rrset);
					return sr;
				}
				rrset = getRRset(domainName, dr, Type.CNAME, DClass.IN);
				if (rrset != null) {
					return new SetResponse(SetResponse.CNAME, rrset);
				}
				/* We found the name, but not the type. */
				return SetResponse.ofType(SetResponse.NXRRSET);
			}

			/* A proper suffix of the name: only a DNAME applies. */
			RRset dname = getRRset(domainName, dr, Type.DNAME, DClass.IN);
			if (dname != null) {
				return new SetResponse(SetResponse.DNAME, dname);
			}
		}

		// Wildcard fallback: replace the first n labels with "*" for n = 1..labels-1.
		for (int n = 1; n < labels; n++) {
			DomainResource dr = getDomainResource(name.wild(n).toString());
			if (dr == null) {
				continue;
			}
			RRset rrset = getRRset(toName(dr.getName()), dr, type, DClass.IN);
			if (rrset != null) {
				SetResponse sr = new SetResponse(SetResponse.SUCCESSFUL);
				sr.addRRset(rrset);
				return sr;
			}
		}

		return SetResponse.ofType(SetResponse.NXDOMAIN);
	}

	/**
	 * Builds the wire-format reply to {@code query}.
	 *
	 * <p>Note: a null return value means that the caller doesn't need to do
	 * anything. That happens when the message is itself a response (QR set).
	 *
	 * @param in     the raw query bytes (used to build FORMERR replies)
	 * @param length length of the raw query (currently unused)
	 * @param s      the TCP connection, or null for UDP; selects the size limit
	 */
	byte[] generateReply(Message query, byte[] in, int length, Socket s)
			throws IOException {
		Header header = query.getHeader();
		if (header.getFlag(Flags.QR)) {
			return null;
		}
		if (header.getRcode() != Rcode.NOERROR) {
			return errorMessage(query, Rcode.FORMERR);
		}
		if (header.getOpcode() != Opcode.QUERY) {
			return errorMessage(query, Rcode.NOTIMP);
		}

		Record queryRecord = query.getQuestion();
		if (queryRecord == null) {
			// No question section: nothing can be answered, and continuing would
			// dereference the missing question.
			return errorMessage(query, Rcode.FORMERR);
		}

		TSIGRecord queryTSIG = query.getTSIG();
		TSIG tsig = null;
		if (queryTSIG != null) {
			// TSIG verification is not implemented, so signed queries are rejected.
			return formerrMessage(in);
		}

		OPTRecord queryOPT = query.getOPT();

		int maxLength;
		if (s != null) {
			maxLength = 65535;
		} else if (queryOPT != null) {
			maxLength = Math.max(queryOPT.getPayloadSize(), 512);
		} else {
			maxLength = 512;
		}

		int flags = 0;
		if (queryOPT != null && (queryOPT.getFlags() & ExtendedFlags.DO) != 0) {
			flags = FLAG_DNSSECOK;
		}

		Message response = new Message(query.getHeader().getID());
		response.getHeader().setFlag(Flags.QR);
		if (query.getHeader().getFlag(Flags.RD)) {
			response.getHeader().setFlag(Flags.RD);
		}
		response.addRecord(queryRecord, Section.QUESTION);

		Name name = queryRecord.getName();
		int type = queryRecord.getType();
		int dclass = queryRecord.getDClass();
		if (type == Type.AXFR && s != null) {
			return doAXFR(name, query, tsig, queryTSIG, s);
		}
		if (!Type.isRR(type) && type != Type.ANY) {
			return errorMessage(query, Rcode.NOTIMP);
		}

		byte rcode = addAnswer(response, name, type, dclass, 0, flags);
		if (rcode != Rcode.NOERROR && rcode != Rcode.NXDOMAIN) {
			return errorMessage(query, rcode);
		}

		addAdditional(response, flags);

		if (queryOPT != null) {
			int optflags = (flags == FLAG_DNSSECOK) ? ExtendedFlags.DO : 0;
			OPTRecord opt = new OPTRecord((short) 4096, rcode, (byte) 0, optflags);
			response.addRecord(opt, Section.ADDITIONAL);
		}

		response.setTSIG(tsig, Rcode.NOERROR, queryTSIG);
		return response.toWire(maxLength);
	}

	/**
	 * Builds an error reply that reuses {@code header} (whose rcode is
	 * overwritten) and has all sections cleared. The question is echoed only for
	 * SERVFAIL, and only when one is available.
	 */
	byte[] buildErrorMessage(Header header, int rcode, Record question) {
		Message response = new Message();
		response.setHeader(header);
		for (int i = 0; i < 4; i++) {
			response.removeAllRecords(i);
		}
		if (rcode == Rcode.SERVFAIL && question != null) {
			response.addRecord(question, Section.QUESTION);
		}
		header.setRcode(rcode);
		return response.toWire();
	}

	/** Zone transfers are not supported; always answers REFUSED. */
	byte[] doAXFR(Name name, Message query, TSIG tsig, TSIGRecord qtsig,
			Socket s) {
		return errorMessage(query, Rcode.REFUSED);
	}

	/**
	 * Builds a FORMERR reply from the raw query bytes.
	 *
	 * @return the reply, or null when not even a header can be parsed from {@code in}
	 */
	public byte[] formerrMessage(byte[] in) {
		Header header;
		try {
			header = new Header(in);
		} catch (IOException e) {
			return null;
		}
		return buildErrorMessage(header, Rcode.FORMERR, null);
	}

	/** Builds an error reply with the given rcode for a parsed query. */
	public byte[] errorMessage(Message query, int rcode) {
		return buildErrorMessage(query.getHeader(), rcode, query.getQuestion());
	}

	/**
	 * Looks up a domain resource, ignoring a trailing dot in {@code domainName}.
	 *
	 * @return the resource, or null when the factory has none or reports it is
	 *         not authoritative for the name
	 */
	private DomainResource getDomainResource(String domainName) {
		if (domainName.endsWith(".")) {
			domainName = domainName.substring(0, domainName.length() - 1);
		}
		try {
			return drf.getDomainResource(domainName);
		} catch (NonAuthoritativeException e) {
			return null;
		}
	}

	/**
	 * Parses raw query bytes and builds the reply. Parse failures become FORMERR
	 * and other runtime failures become SERVFAIL, so one bad query cannot break a
	 * listener.
	 *
	 * @param s the TCP connection, or null for UDP
	 * @return the reply bytes, or null when nothing should be sent back
	 */
	private byte[] processQuery(byte[] in, Socket s) {
		Message query = null;
		try {
			query = new Message(in);
			return generateReply(query, in, in.length, s);
		} catch (IOException e) {
			log.error("Exception generating DNS response", e);
			return formerrMessage(in);
		} catch (RuntimeException e) {
			log.error("Exception generating DNS response", e);
			return (query != null) ? errorMessage(query, Rcode.SERVFAIL) : formerrMessage(in);
		}
	}

	/** Accepts TCP connections on one address and answers each on its own thread. */
	class TcpListener implements Runnable {

		final InetAddress addr;
		final int port;
		/** Assigned by the listener thread, closed from {@link JNameServer#stop()}. */
		volatile ServerSocket sock;

		TcpListener(InetSocketAddress sa) {
			this.addr = sa.getAddress();
			this.port = sa.getPort();
		}

		@Override
		public void run() {
			serveTCP(addr, port);
		}

		/** Closes the server socket, which makes a blocked accept() fail. */
		public void close() {
			ServerSocket server = sock;
			if (server != null) {
				try {
					server.close();
				} catch (IOException e) {
					log.warn("Error closing TCP socket " + addrport(addr, port), e);
				}
			}
		}

		/** Reads one length-prefixed query from {@code s}, writes the reply, and closes {@code s}. */
		public void TCPclient(Socket s) {
			try {
				InputStream is = s.getInputStream();
				DataInputStream dataIn = new DataInputStream(is);
				int inLength = dataIn.readUnsignedShort();
				byte[] in = new byte[inLength];
				dataIn.readFully(in);

				byte[] response = processQuery(in, s);
				if (response == null) {
					return;
				}
				DataOutputStream dataOut = new DataOutputStream(s.getOutputStream());
				dataOut.writeShort(response.length);
				dataOut.write(response);
			} catch (IOException e) {
				log.warn("TCPclient(" + addrport(s.getLocalAddress(), s.getLocalPort()) + "): " + e);
			} finally {
				try {
					s.close();
				} catch (IOException e) {
					// Best effort: the exchange is over either way.
				}
			}
		}

		/** Accept loop; runs until {@link JNameServer#stop()} closes the socket. */
		public void serveTCP(InetAddress addr, int port) {
			try {
				ServerSocket server = new ServerSocket(port, 128, addr);
				sock = server;
				while (running) {
					final Socket s = server.accept();
					new Thread(() -> TCPclient(s)).start();
				}
			} catch (IOException e) {
				if (running) {
					log.error("serveTCP(" + addrport(addr, port) + "): " + e);
				} else {
					log.debug("serveTCP(" + addrport(addr, port) + ") stopped: " + e);
				}
			} finally {
				close();
			}
		}
	}

	/** Receives UDP queries on one address and answers them inline. */
	class UdpListener implements Runnable {

		/** Assigned by the listener thread, closed from {@link JNameServer#stop()}. */
		volatile DatagramSocket sock;
		final InetAddress addr;
		final int port;

		UdpListener(InetSocketAddress sa) {
			this.addr = sa.getAddress();
			this.port = sa.getPort();
		}

		@Override
		public void run() {
			serveUDP(addr, port);
		}

		/** Closes the socket, which makes a blocked receive() fail. */
		public void close() {
			DatagramSocket socket = sock;
			if (socket != null) {
				socket.close();
			}
		}

		/** Receive loop; runs until {@link JNameServer#stop()} closes the socket. */
		public void serveUDP(InetAddress addr, int port) {
			try {
				// Store into the field (not a shadowing local) so close() can reach it.
				DatagramSocket socket = new DatagramSocket(port, addr);
				sock = socket;

				byte[] in = new byte[UDP_LENGTH];
				DatagramPacket indp = new DatagramPacket(in, in.length);
				while (running) {
					indp.setLength(in.length);
					try {
						socket.receive(indp);
					} catch (InterruptedIOException e) {
						continue;
					}
					// Only this datagram's bytes; the buffer may still hold an older, longer one.
					byte[] packet = Arrays.copyOf(in, indp.getLength());
					byte[] response = processQuery(packet, null);
					if (response == null) {
						continue;
					}
					DatagramPacket outdp = new DatagramPacket(response, response.length,
							indp.getAddress(), indp.getPort());
					try {
						socket.send(outdp);
					} catch (IOException e) {
						// One failed reply must not stop the listener; if the socket
						// was closed, the next receive() ends the loop.
						log.warn("serveUDP(" + addrport(addr, port) + "): failed to send reply: " + e);
					}
				}
			} catch (IOException e) {
				if (running) {
					log.error("serveUDP(" + addrport(addr, port) + "): " + e);
				} else {
					log.debug("serveUDP(" + addrport(addr, port) + ") stopped: " + e);
				}
			} finally {
				close();
			}
		}
	}

	/** Parses a domain-name string, rethrowing parse failures unchecked. */
	private static Name toName(String name) {
		try {
			return Utils.stringToName(name);
		} catch (TextParseException ex) {
			throw new RuntimeException(ex);
		}
	}

	/** Converts one stored record into a DNS record owned by {@code domainName}. */
	private Record toRecord(Name domainName, DomainResourceRecord dnsRec) {
		try {
			return recordTypes.map(domainName, dnsRec);
		} catch (TextParseException ex) {
			throw new RuntimeException(ex);
		}
	}

	/**
	 * Collects the records of {@code dr} with the given type and class.
	 *
	 * @return the matching RRset, or null when {@code dr} is null or nothing matches
	 */
	private RRset getRRset(Name domainName, DomainResource dr, int type, int dclass) {
		if (dr == null) {
			return null;
		}
		List<? extends DomainResourceRecord> allRecords = dr.getRecords();
		if (allRecords == null || allRecords.isEmpty()) {
			return null;
		}

		RRset rrset = null;
		for (DomainResourceRecord dnsRec : allRecords) {
			log.trace("check: {}", dnsRec.getClass());
			Record rr = toRecord(domainName, dnsRec);
			if (rr != null && rr.getType() == type && rr.getDClass() == dclass) {
				if (rrset == null) {
					rrset = new RRset();
				}
				rrset.addRR(rr);
			}
		}
		return rrset;
	}

	/**
	 * Groups all records of {@code dr} into RRsets by type and class, in
	 * first-seen order.
	 *
	 * <p>The {@code dclass} argument is not used for filtering.
	 *
	 * @return the RRsets, never null (empty when {@code dr} or its records are absent)
	 */
	private RRset[] getAllRRsets(Name domainName, DomainResource dr, int dclass) {
		if (dr == null) {
			return new RRset[0];
		}
		List<? extends DomainResourceRecord> allRecords = dr.getRecords();
		if (allRecords == null) {
			return new RRset[0];
		}

		List<RRset> rrSets = new ArrayList<>();
		for (DomainResourceRecord dnsRec : allRecords) {
			Record rr = toRecord(domainName, dnsRec);
			if (rr == null) {
				continue;
			}
			RRset target = null;
			for (RRset rrSet : rrSets) {
				if (rrSet.getType() == rr.getType() && rrSet.getDClass() == rr.getDClass()) {
					target = rrSet;
					break;
				}
			}
			if (target == null) {
				target = new RRset();
				rrSets.add(target);
			}
			target.addRR(rr);
		}
		return rrSets.toArray(new RRset[0]);
	}

	/**
	 * Finds the most specific zone enclosing {@code name}: the full name first,
	 * then each parent in turn.
	 *
	 * <p>Lookups use this server's own factory; the {@code drf} argument is not consulted.
	 *
	 * @return the zone resource, or null when no enclosing zone is served
	 */
	public ZoneDomainResource findBestZone(DomainResourceFactory drf, Name name) {
		for (int tlabels = name.labels(); tlabels > 0; tlabels--) {
			Name tname = new Name(name, name.labels() - tlabels);
			DomainResource dr = getDomainResource(Utils.nameToString(tname));
			if (dr instanceof ZoneDomainResource) {
				return (ZoneDomainResource) dr;
			}
		}
		return null;
	}

	/** Formats an address as {@code host#port}, using "*" when the address is unresolved. */
	private String addrport(InetAddress addr, int port) {
		String host = (addr == null) ? "*" : addr.getHostAddress();
		return host + "#" + port;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy