用java实现一个dns服务器
最近有个需求,一个域名,需要根据不同的IP,返回不同的结果。而且还要提供管理借口。看了一下rfc,实现一个基本版,还是非常简单的。把域名的NS设置成自己的dns服务器,解析起来就太方便了。
DNS协议本身是基于UDP的应用层协议,一开始用了node.js处理,但是发现真处理byte buffer还是不大顺手,最后还是转回到java来解决的。
有一个需要注意的地方是,dns数据段,用了[length][data]这种格式。而且每个「.」都是一段。需要注意。
另外,配合Wireshark,分析起来会方便很多。
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Scanner;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpHandler;
import com.sun.net.httpserver.HttpServer;
public class DNSd {
static Logger log = Logger.getLogger("DNSd");
int httpPort = 9998, dnsPort = 9999;
String hostsFile = "hosts.txt";
String host = "betadata.me";
byte[] defaultIp = ByteBuffer.allocate(4)
.putInt(ipToInt("162.243.157.169")).array();
String controlIp = null;// "12.24.2.2",if not null,only this ip is
// allowed to control the server
boolean seted = false;
static ConcurrentHashMap<Integer, byte[]> resolved = new ConcurrentHashMap<>();
/**
* save resolved to disk(file)
*/
private void save() {
if (seted) {
try {
Files.write(Paths.get(hostsFile), dump().getBytes());
seted = false;
} catch (IOException e) {
log.warning(e.getLocalizedMessage());
}
}
}
/**
* memory resolved map to txt
*
* @return
*/
private String dump() {
StringBuilder sb = new StringBuilder();
for (int key : resolved.keySet()) {
sb.append(intToIp(key)).append("\t")
.append(bytesToIp(resolved.get(key))).append("\n");
}
return sb.toString();
}
/**
* load hosts files into memory(resolved)
*/
private void load() {
try (Scanner scan = new Scanner(new File(hostsFile));) {
while (scan.hasNextLine()) {
String line = scan.nextLine();
int idx = line.indexOf('\t');
resolved.put(ipToInt(line.substring(0, idx)),
ipToBytes(line.substring(idx + 1)));
}
} catch (Exception e) {
log.warning(e.getMessage());
}
}
/**
* start a httpserver to control your dns server
*/
@SuppressWarnings("restriction")
private void startHttpServer() {
try {
HttpServer server = HttpServer.create(new InetSocketAddress(
httpPort), 0);
log.info("http control server started: " + httpPort);
server.createContext("/", new HttpHandler() {
@Override
public void handle(HttpExchange t) {
try {
log.info(t.getRemoteAddress() + "\t"
+ t.getRequestURI());
String[] req = t.getRequestURI().toString()
.substring(1).split("/");
String ret = "error, unkonwn!";
if (req.length > 0
&& ((controlIp != null && bytesToIp(
t.getRemoteAddress().getAddress()
.getAddress())
.equals(controlIp)) || controlIp == null)) {
switch (req[0]) {
case "save":
save();
ret = "saved!";
break;
case "set":// set/192.168.0.1[/202.96.204.3]
// TODO check ip format
if (req.length == 3) {
seted = true;
resolved.put(ipToInt(req[2]),
ipToBytes(req[1]));
ret = "seted: " + host + "->" + req[1]
+ " (on ip " + req[2] + ")";
} else if (req.length == 2) {
seted = true;
resolved.put(bytesToInt(t
.getRemoteAddress().getAddress()
.getAddress()), ipToBytes(req[1]));
ret = "seted: " + host + "->" + req[1]
+ " (on ip "
+ t.getRemoteAddress().getAddress()
+ ")";
} else
ret = "seted: failed! check your url: "
+ t.getRequestURI();
break;
default:
ret = dump();
break;
}
}
byte[] retBytes = ret.getBytes("UTF-8");
t.sendResponseHeaders(200, retBytes.length);
t.getResponseHeaders().add("Content-Type",
"text/plain; charset=utf-8");
OutputStream os = t.getResponseBody();
os.write(retBytes);
os.close();
} catch (Exception e) {
e.printStackTrace();
} finally {
t.close();
}
}
});
server.start();
} catch (IOException e) {
e.printStackTrace();
}
}
// ip string byte[] converts
private int ipToInt(String ipAddress) {
long result = 0;
String[] ipAddressInArray = ipAddress.split("\\.");
for (int i = 3; i >= 0; i--) {
long ip = Long.parseLong(ipAddressInArray[3 - i]);
result |= ip << (i * 8);
}
return (int) result;
}
private byte[] ipToBytes(String ip) {
return ByteBuffer.allocate(4).putInt(ipToInt(ip)).array();
}
private String intToIp(int i) {
return ((i >> 24) & 0xFF) + "." + ((i >> 16) & 0xFF) + "."
+ ((i >> 8) & 0xFF) + "." + (i & 0xFF);
}
private int bytesToInt(byte[] bytes) {
int val = 0;
for (int i = 0; i < bytes.length; i++) {
val <<= 8;
val |= bytes[i] & 0xff;
}
return val;
}
private String bytesToIp(byte[] bytes) {
return ((bytes[0]) & 0xFF) + "." + ((bytes[1]) & 0xFF) + "."
+ ((bytes[2]) & 0xFF) + "." + (bytes[3] & 0xFF);
}
private void startDnsd() {
try (DatagramSocket serverSocket = new DatagramSocket(dnsPort)) {
byte[] receiveData = new byte[512];
log.info("DNSd started at :" + dnsPort);
while (true) {
try {
DatagramPacket receivePacket = new DatagramPacket(
receiveData, receiveData.length);
serverSocket.receive(receivePacket);
StringBuilder qname = new StringBuilder();
int idx = 12;// skip
// transaction/id/flags/questions/answer/authority/additional
int len = receiveData[idx];
while (len > 0) {
qname.append(".").append(
new String(receiveData, idx + 1, len));
idx += len + 1;
len = receiveData[idx];
}
if (qname.length() > 0) {
String name = qname.substring(1).toLowerCase();
int type = receiveData[idx + 1] * 256
+ receiveData[idx + 2];
log.info(receivePacket.getAddress() + ":"
+ receivePacket.getPort() + "\t" + name + "\t"
+ type);
if ((!name.equals(host))
&& (!name.endsWith("." + host))) {
continue;// keep silence
}
if (type != 1 && !name.equals(host)) {
continue;// we only response for A records, except
// for MX
// for host
}
ByteArrayOutputStream bo = new ByteArrayOutputStream();
bo.write(new byte[] { receiveData[0], receiveData[1],
(byte) 0x81, (byte) 0x80, 0x00, 0x01, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00 });
// write query
byte[] req = Arrays.copyOfRange(receiveData, 12,
idx + 5);
bo.write(req);
bo.write(req);
bo.write(ByteBuffer.allocate(4)
.putInt(name.equals(host) ? 3600 : 10).array());// ttl,
if (type == 1) {
bo.write(new byte[] { 0x00, 0x04 });
int val = bytesToInt(receivePacket.getAddress()
.getAddress());
bo.write((!name.equals(host))
&& resolved.containsKey(val) ? resolved
.get(val) : defaultIp);
} else {// for MX
String mx = "mxdomain.qq.com";
bo.write(ByteBuffer.allocate(2)
.putShort((short) (mx.length() + 4))
.array());
bo.write(0x00);
bo.write(0x05);// preference
for (String s : mx.split("\\.")) {
bo.write((byte) s.length());
bo.write(s.getBytes());
}
bo.write(0x00);
}
byte[] sendData = bo.toByteArray();
DatagramPacket sendPacket = new DatagramPacket(
sendData, sendData.length,
receivePacket.getAddress(),
receivePacket.getPort());
serverSocket.send(sendPacket);
}
} catch (Exception e) {
log.warning(e.getMessage());
}
}
} catch (Exception e) {
log.warning(e.getMessage());
}
}
public static void main(String[] args) {
DNSd dnsd = new DNSd();
for (String arg : args) {
if (arg.startsWith("-http")) {
dnsd.httpPort = Integer.parseInt(arg.substring(5));
} else if (arg.startsWith("-dns")) {
dnsd.dnsPort = Integer.parseInt(arg.substring(4));
} else if (arg.startsWith("-host")) {
dnsd.host = arg.substring(5);
} else if (arg.startsWith("-ip")) {
dnsd.defaultIp = dnsd.ipToBytes(arg.substring(3));
} else if (arg.startsWith("-cip")) {
dnsd.controlIp = arg.substring(4);
}
}
System.out.println("i'm go to serve you, master!\n" + dnsd);
dnsd.load();
dnsd.startHttpServer();
// auto save after 10mins
new Thread() {
{
this.setDaemon(true);
}
public void run() {
while (true) {
try {
sleep(600000);
dnsd.save();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
};
}.start();
// auto save when exit
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override
public void run() {
dnsd.save();
}
});
// i'll block by start dns server
dnsd.startDnsd();
}
// for debug
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append("DNSd:\n -http=").append(httpPort).append("\n -dns=")
.append(dnsPort).append("\n -host=").append(host)
.append("\n -ip=").append(bytesToIp(defaultIp))
.append("\n -cip=").append(controlIp);
return builder.toString();
}
final protected static char[] hexArray = "0123456789ABCDEF".toCharArray();
public static String bytesToHex(byte[] bytes) {
char[] hexChars = new char[bytes.length * 3];
StringBuffer sb = new StringBuffer();
StringBuffer sb1 = new StringBuffer();
for (int j = 0; j < bytes.length; j++) {
int v = bytes[j] & 0xFF;
hexChars[j * 3] = hexArray[v >>> 4];
hexChars[j * 3 + 1] = hexArray[v & 0x0F];
hexChars[j * 3 + 2] = ' ';
char c = (char) bytes[j];
sb.append(" ").append(Character.isLetterOrDigit(c) ? c : '!')
.append(" ");
sb1.append(" ").append(j).append(j > 9 ? "" : " ");
}
return new String(hexChars) + "\n" + sb + "\n" + sb1;
}
}
host.txt内容:
127.0.0.1 7.7.7.4
106.186.124.201 8.8.8.8
101.231.192.210 162.243.157.169
这样我们可以通过:
http://server:port/[dump/set/save]
来管理dns服务器,并且服务器还能每隔10分钟,自动save一次。
代码同步在Github上,以后如有更新,以Github为准: https://github.com/100apps/DNSd