#include "DNSQuery.h"
#include <errno.h>
#include <stdio.h>
#include <ByteOrder.h>
#include <FindDirectory.h>
#include <NetAddress.h>
#include <NetEndpoint.h>
#include <Path.h>
// Debug output helper: PRINT() forwards to printf() when DEBUG is defined
// and expands to nothing otherwise.
#undef PRINT
#ifdef DEBUG
#define PRINT(a...) printf(a)
#else
#define PRINT(a...)
#endif
// Counter that DNSQuery::_GetUniqueID() advances with atomic_add() to derive
// the ID of the next query.
static int32 gID = 1;
/*!	Creates a buffer with both the read and the write position at 0 and no
	initial content.
*/
BRawNetBuffer::BRawNetBuffer()
{
_Init(NULL, 0);
}
/*!	Creates a buffer whose underlying storage is resized to \a size bytes.
	The read and the write position both start at 0.
*/
BRawNetBuffer::BRawNetBuffer(off_t size)
{
_Init(NULL, 0);
// Resize only after _Init() has reset the positions.
fBuffer.SetSize(size);
}
/*!	Creates a buffer and hands \a size bytes from \a buf to _Init(), which
	writes them at offset 0. The read and the write position both start at 0.
*/
BRawNetBuffer::BRawNetBuffer(const void* buf, size_t size)
{
_Init(buf, size);
}
/*!	Writes \a value in big-endian (network) byte order at the current write
	position and moves the write position past it.
	\return B_OK on success, B_NO_MEMORY if the buffer reported that error.
*/
status_t
BRawNetBuffer::AppendUint16(uint16 value)
{
	const uint16 bigEndian = B_HOST_TO_BENDIAN_INT16(value);

	const ssize_t written = fBuffer.WriteAt(fWritePosition, &bigEndian,
		sizeof(uint16));
	if (written != B_NO_MEMORY) {
		fWritePosition += sizeof(uint16);
		return B_OK;
	}

	return B_NO_MEMORY;
}
/*!	Appends the C string \a string, including its terminating null byte, at
	the current write position and advances the write position accordingly.
	\param string The null terminated string to append; must not be NULL.
	\return B_OK on success, B_BAD_VALUE if \a string is NULL, or the
		negative error code reported by the underlying buffer (for example
		B_NO_MEMORY).
*/
status_t
BRawNetBuffer::AppendString(const char* string)
{
	// strlen() must never see a NULL pointer.
	if (string == NULL)
		return B_BAD_VALUE;

	// The terminating null byte is part of what gets written.
	size_t length = strlen(string) + 1;
	ssize_t sizeW = fBuffer.WriteAt(fWritePosition, string, length);
	if (sizeW < 0)
		return (status_t)sizeW;

	fWritePosition += length;
	return B_OK;
}
/*!	Reads a 16 bit value stored in big-endian (network) byte order from the
	current read position, converts it to host byte order and advances the
	read position by two bytes.
	\param value Receives the converted value; untouched on failure.
	\return B_OK on success, B_ERROR if fewer than two bytes could be read.
*/
status_t
BRawNetBuffer::ReadUint16(uint16& value)
{
	uint16 netVal;
	ssize_t sizeR = fBuffer.ReadAt(fReadPosition, &netVal, sizeof(uint16));
	// Anything but a complete read (short read near the end of the buffer,
	// or a negative error code) would leave netVal partly uninitialized.
	if (sizeR != (ssize_t)sizeof(uint16))
		return B_ERROR;

	value = B_BENDIAN_TO_HOST_INT16(netVal);
	fReadPosition += sizeof(uint16);
	return B_OK;
}
/*!	Reads a 32 bit value stored in big-endian (network) byte order from the
	current read position, converts it to host byte order and advances the
	read position by four bytes.
	\param value Receives the converted value; untouched on failure.
	\return B_OK on success, B_ERROR if fewer than four bytes could be read.
*/
status_t
BRawNetBuffer::ReadUint32(uint32& value)
{
	uint32 netVal;
	ssize_t sizeR = fBuffer.ReadAt(fReadPosition, &netVal, sizeof(uint32));
	// Anything but a complete read (short read near the end of the buffer,
	// or a negative error code) would leave netVal partly uninitialized.
	if (sizeR != (ssize_t)sizeof(uint32))
		return B_ERROR;

	value = B_BENDIAN_TO_HOST_INT32(netVal);
	fReadPosition += sizeof(uint32);
	return B_OK;
}
/*!	Clears \a string, then fills it via _ReadStringAt() starting at the
	current read position. On success the read position is advanced by the
	byte count _ReadStringAt() returned.
	\return B_OK on success, B_ERROR if _ReadStringAt() returned a negative
		value.
*/
status_t
BRawNetBuffer::ReadString(BString& string)
{
	string = "";

	const ssize_t consumed = _ReadStringAt(string, fReadPosition);
	if (consumed >= 0) {
		fReadPosition += consumed;
		return B_OK;
	}

	return B_ERROR;
}
/*!	Moves the read position forward by \a skip bytes.
	\return B_OK on success, B_ERROR (leaving the read position unchanged)
		if the new position would lie beyond the buffer length.
*/
status_t
BRawNetBuffer::SkipReading(off_t skip)
{
	const bool beyondEnd
		= fReadPosition + skip > (off_t)fBuffer.BufferLength();
	if (!beyondEnd) {
		fReadPosition += skip;
		return B_OK;
	}

	return B_ERROR;
}
/*!	Resets the read and the write position to 0 and passes \a size bytes
	from \a buf to the buffer's WriteAt() at offset 0. The write position is
	not advanced past that initial content.
*/
void
BRawNetBuffer::_Init(const void* buf, size_t size)
{
	fReadPosition = 0;
	fWritePosition = 0;

	// fWritePosition was just reset, so this writes at offset 0.
	fBuffer.WriteAt(fWritePosition, buf, size);
}
/*!	Decodes the DNS encoded name stored at \a pos and appends it to
	\a string in its raw form, i.e. every label is preceded by its length
	byte and the terminating zero byte is not appended.

	Compression pointers (RFC 1035, 4.1.4: two bytes whose upper two bits
	are set, the remaining 14 bits being an offset from the start of the
	message) are followed. The number of pointers followed is limited so
	that a malformed message containing a pointer loop cannot keep this
	method busy forever.

	\param string The string the raw name is appended to. It may already
		have been partially extended when an error is returned.
	\param pos Offset into the buffer at which the name starts.
	\return The number of bytes the name occupies at \a pos (a compression
		pointer counts as two bytes and ends the name at that position), or
		-1 if \a pos lies outside the buffer or the name is malformed or
		truncated.
*/
ssize_t
BRawNetBuffer::_ReadStringAt(BString& string, off_t pos)
{
	const off_t bufferLength = (off_t)fBuffer.BufferLength();
	if (pos < 0 || pos >= bufferLength)
		return -1;

	const uint8* data = (const uint8*)fBuffer.Buffer();
	ssize_t bytesRead = 0;
	// Once a pointer has been followed, further bytes no longer belong to
	// the name's footprint at the original position.
	bool jumped = false;
	int32 jumpsLeft = 16;

	while (true) {
		if (pos >= bufferLength)
			return -1;

		const uint8 lengthByte = data[pos];
		if (lengthByte == 0) {
			// terminating root label
			if (!jumped)
				bytesRead++;
			break;
		}

		if ((lengthByte & 0xC0) == 0xC0) {
			// compression pointer: needs its second byte
			if (pos + 1 >= bufferLength || jumpsLeft-- <= 0)
				return -1;
			if (!jumped) {
				bytesRead += 2;
				jumped = true;
			}
			pos = (off_t(lengthByte & 0x3F) << 8) | data[pos + 1];
			continue;
		}

		// The bit patterns 01 and 10 are reserved label types.
		if ((lengthByte & 0xC0) != 0)
			return -1;

		// An ordinary label: copy the length byte together with its data.
		const off_t chunk = off_t(lengthByte) + 1;
		if (pos + chunk > bufferLength)
			return -1;
		string.Append((const char*)(data + pos), (int32)chunk);
		if (!jumped)
			bytesRead += chunk;
		pos += chunk;
	}

	return bytesRead;
}
/*!	Collects the name servers listed in "network/resolv.conf" below the
	system settings directory.

	Lines starting with ';' or '#' are ignored. From every line starting
	with "nameserver" followed by a space or tab, the next token is added to
	\a serverList as a newly allocated BString. At most two servers are
	collected.

	\param serverList The list receiving the server strings; must not be
		NULL.
	\return B_OK if the file could be read (even if it named no server),
		B_BAD_VALUE if \a serverList is NULL, B_ENTRY_NOT_FOUND if the
		settings directory or the file could not be found/opened, or
		B_NO_MEMORY if an entry could not be added to the list.
*/
status_t
DNSTools::GetDNSServers(BObjectList<BString, true>* serverList)
{
#define MATCH(line, name) \
	(!strncmp(line, name, sizeof(name) - 1) && \
	(line[sizeof(name) - 1] == ' ' || \
	line[sizeof(name) - 1] == '\t'))

	if (serverList == NULL)
		return B_BAD_VALUE;

	BPath path;
	if (find_directory(B_SYSTEM_SETTINGS_DIRECTORY, &path) != B_OK)
		return B_ENTRY_NOT_FOUND;
	path.Append("network/resolv.conf");

	FILE* fp = fopen(path.Path(), "r");
	if (fp == NULL) {
		fprintf(stderr, "failed to open '%s' to read nameservers: %s\n",
			path.Path(), strerror(errno));
		return B_ENTRY_NOT_FOUND;
	}

	const int kMaxNameServers = 2;
	int nserv = 0;
	char buf[1024];
	status_t status = B_OK;

	// No need to keep reading once the limit has been reached.
	while (nserv < kMaxNameServers && fgets(buf, sizeof(buf), fp) != NULL) {
		// skip comments
		if (*buf == ';' || *buf == '#')
			continue;
		if (!MATCH(buf, "nameserver"))
			continue;

		// skip the keyword and the whitespace following it
		char* cp = buf + sizeof("nameserver") - 1;
		while (*cp == ' ' || *cp == '\t')
			cp++;

		// cut the token at a trailing comment, whitespace or the newline
		cp[strcspn(cp, ";# \t\n")] = '\0';
		if (*cp == '\0')
			continue;

		BString* server = new BString(cp);
		if (!serverList->AddItem(server)) {
			// The list did not take ownership, so free the string here.
			delete server;
			status = B_NO_MEMORY;
			break;
		}
		nserv++;
	}

	fclose(fp);
	return status;
}
/*!	Converts a dotted host name into the DNS label format in which every
	label is preceded by a byte holding its length, e.g. "mail.example.com"
	becomes "\4mail\7example\3com".

	The terminating zero length byte is not part of the returned string; it
	is supplied by the string's null terminator when the name is written
	with BRawNetBuffer::AppendString(). A name without any dot is encoded as
	a single label, and a trailing dot does not add an extra label.

	\param string The dotted name. Labels longer than 63 bytes cannot be
		represented in DNS and are not checked for here.
	\return The name in label format.
*/
BString
DNSTools::ConvertToDNSName(const BString& string)
{
	BString outString;
	// Label lengths are byte counts, so work with Length(), not CountChars().
	const int32 length = string.Length();

	int32 labelStart = 0;
	while (labelStart < length) {
		int32 dot = string.FindFirst(".", labelStart);
		int32 labelEnd = dot == B_ERROR ? length : dot;
		int32 labelLength = labelEnd - labelStart;

		// length byte first, then the label itself
		outString.Append((char)labelLength, 1);
		outString.Append(string.String() + labelStart, labelLength);

		// continue behind the dot
		labelStart = labelEnd + 1;
	}

	return outString;
}
/*!	Converts a name in DNS label format (every label preceded by its length
	byte) back into a dotted name, e.g. "\4mail\7example\3com" becomes
	"mail.example.com".

	The leading length byte is removed and every following length byte is
	replaced by a dot. Conversion stops at a zero length byte or as soon as
	the next length byte would lie beyond the end of the string.

	\param string The name in label format.
	\return The dotted name; an empty input is returned unchanged.
*/
BString
DNSTools::ConvertFromDNSName(const BString& string)
{
	if (string.Length() == 0)
		return string;

	BString outString = string;
	// Length bytes are unsigned. Reading them through a signed char would
	// turn values above 127 into negative offsets and out-of-bounds indices.
	int32 dot = (uint8)string[0];
	int32 nextDot = dot;
	outString.Remove(0, sizeof(char));

	while (nextDot < outString.Length()) {
		dot = (uint8)outString[nextDot];
		if (dot == 0)
			break;
		// replace the length byte by the separating dot
		outString.SetByteAt(nextDot, '.');
		nextDot += dot + 1;
	}

	return outString;
}
/*!	Constructs a DNSQuery; the constructor body does nothing. */
DNSQuery::DNSQuery()
{
}
/*!	Destroys the DNSQuery; the destructor body does nothing. */
DNSQuery::~DNSQuery()
{
}
/*!	Looks up the first name server returned by DNSTools::GetDNSServers()
	and converts it with inet_aton() into \a add.
	\param add Receives the address of the first name server.
	\return B_OK on success, the error of DNSTools::GetDNSServers() if that
		call failed, or B_ERROR if no server is listed or its address could
		not be parsed.
*/
status_t
DNSQuery::ReadDNSServer(in_addr* add)
{
	BObjectList<BString, true> servers(5);

	const status_t result = DNSTools::GetDNSServers(&servers);
	if (result != B_OK)
		return result;

	// Only the first configured server is used.
	BString* primary = servers.ItemAt(0);
	if (primary == NULL)
		return B_ERROR;
	if (inet_aton(primary->String(), add) != 1)
		return B_ERROR;

	PRINT("dns server found: %s \n", primary->String());
	return B_OK;
}
/*!	Asks the first configured DNS server for the MX records of
	\a serverName over UDP port 53 and adds every MX record found in the
	answer section to \a mxList.

	\param serverName The dotted domain name to look up.
	\param mxList Receives newly allocated mx_record objects.
	\param timeout Timeout set on the endpoint before waiting for the reply.
	\return B_OK if at least one MX record was found, B_NO_MEMORY if a
		record could not be added to \a mxList, B_ERROR otherwise (no DNS
		server, network failure, reply not matching the query, or no MX
		record in the reply).
*/
status_t
DNSQuery::GetMXRecords(const BString& serverName,
	BObjectList<mx_record, true>* mxList, bigtime_t timeout)
{
	// get the DNS server to ask for the MX records
	in_addr dnsAddress;
	if (ReadDNSServer(&dnsAddress) != B_OK)
		return B_ERROR;

	// build the query: header, name, query type MX, query class 1
	BRawNetBuffer buffer;
	dns_header header;
	_SetMXHeader(&header);
	_AppendQueryHeader(buffer, &header);

	BString serverNameConv = DNSTools::ConvertToDNSName(serverName);
	buffer.AppendString(serverNameConv);
	buffer.AppendUint16(uint16(MX_RECORD));
	buffer.AppendUint16(uint16(1));

	// send the query
	PRINT("send buffer\n");
	BNetAddress netAddress(dnsAddress, 53);
	BNetEndpoint netEndpoint(SOCK_DGRAM);
	if (netEndpoint.InitCheck() != B_OK)
		return B_ERROR;
	if (netEndpoint.Connect(netAddress) != B_OK)
		return B_ERROR;
	PRINT("Connected\n");

	// Any negative result is a failure, not only B_ERROR itself.
	int32 bytesSend = netEndpoint.Send(buffer.Data(), buffer.Size());
	if (bytesSend < 0)
		return B_ERROR;
	PRINT("bytes send %i\n", int(bytesSend));

	// receive the reply
	BRawNetBuffer receiBuffer(512);
	netEndpoint.SetTimeout(timeout);
	int32 bytesRecei = netEndpoint.ReceiveFrom(receiBuffer.Data(), 512,
		netAddress);
	// Without any received data there is nothing to parse.
	if (bytesRecei <= 0)
		return B_ERROR;
	PRINT("bytes received %i\n", int(bytesRecei));

	dns_header receiHeader;
	_ReadQueryHeader(receiBuffer, &receiHeader);

	// A reply carries the ID of the query it answers; ignore anything else.
	if (receiHeader.id != header.id)
		return B_ERROR;

	PRINT("Package contains :");
	PRINT("%d Questions, ", receiHeader.q_count);
	PRINT("%d Answers, ", receiHeader.ans_count);
	PRINT("%d Authoritative Servers, ", receiHeader.auth_count);
	PRINT("%d Additional records\n", receiHeader.add_count);

	// skip the question section echoed in the reply: name, type and class
	BString dummyS;
	uint16 dummy;
	receiBuffer.ReadString(dummyS);
	receiBuffer.ReadUint16(dummy);
	receiBuffer.ReadUint16(dummy);

	// walk the answer section
	bool mxRecordFound = false;
	for (int i = 0; i < receiHeader.ans_count; i++) {
		resource_record_head rrHead;
		_ReadResourceRecord(receiBuffer, &rrHead);
		if (rrHead.type == MX_RECORD) {
			mx_record* mxRec = new mx_record;
			_ReadMXRecord(receiBuffer, mxRec);
			PRINT("MX record found pri %i, name %s\n",
				mxRec->priority, mxRec->serverName.String());
			if (!mxList->AddItem(mxRec)) {
				// The list did not take ownership of the record.
				delete mxRec;
				return B_NO_MEMORY;
			}
			mxRecordFound = true;
		} else {
			// Skip the record data in the *receive* buffer so that the next
			// record is parsed from the right offset.
			receiBuffer.SkipReading(rrHead.dataLength);
		}
	}

	if (!mxRecordFound)
		return B_ERROR;

	return B_OK;
}
/*!	Returns the ID for the next query.

	The global counter is advanced atomically and its previous value is
	reduced to 16 bits, so IDs wrap around after 65536 queries. Clamping
	large counter values to a constant instead would hand out the same ID
	for every query once the counter has grown past the 16 bit range.
*/
uint16
DNSQuery::_GetUniqueID()
{
	int32 nextId = atomic_add(&gID, 1);
	return (uint16)(nextId & 0xFFFF);
}
/*!	Fills \a header for a query carrying one question: a fresh ID, the rd
	flag set, a question count of 1 and every other field zero.
*/
void
DNSQuery::_SetMXHeader(dns_header* header)
{
	dns_header& query = *header;

	query.id = _GetUniqueID();

	// flags: only rd is set
	query.rd = 1;
	query.qr = 0;
	query.opcode = 0;
	query.aa = 0;
	query.tc = 0;
	query.ra = 0;
	query.z = 0;
	query.rcode = 0;

	// section counts: a single question, nothing else
	query.q_count = 1;
	query.ans_count = 0;
	query.auth_count = 0;
	query.add_count = 0;
}
/*!	Appends \a header to \a buffer as six 16 bit words: the ID, the packed
	flags word and the four section counts.

	Flags word layout (bit 15 down to bit 0):
	qr (15), opcode (from 11), aa (10), tc (9), rd (8), ra (7), z (from 4),
	rcode (from 0).
*/
void
DNSQuery::_AppendQueryHeader(BRawNetBuffer& buffer, const dns_header* header)
{
	buffer.AppendUint16(header->id);

	const uint16 flags = (header->qr << 15)
		| (header->opcode << 11)
		| (header->aa << 10)
		| (header->tc << 9)
		| (header->rd << 8)
		| (header->ra << 7)
		| (header->z << 4)
		| header->rcode;
	buffer.AppendUint16(flags);

	buffer.AppendUint16(header->q_count);
	buffer.AppendUint16(header->ans_count);
	buffer.AppendUint16(header->auth_count);
	buffer.AppendUint16(header->add_count);
}
/*!	Reads six 16 bit words from \a buffer into \a header: the ID, the flags
	word (unpacked into the individual flag fields) and the four section
	counts. The results of the individual reads are not checked.
*/
void
DNSQuery::_ReadQueryHeader(BRawNetBuffer& buffer, dns_header* header)
{
	buffer.ReadUint16(header->id);

	// Stays 0 if the read fails, which leaves all flags cleared.
	uint16 flags = 0;
	buffer.ReadUint16(flags);

	// unpack from the most significant bit downwards
	header->qr = (flags >> 15) & 0x01;
	header->opcode = (flags >> 11) & 0x0F;
	header->aa = (flags >> 10) & 0x01;
	header->tc = (flags >> 9) & 0x01;
	header->rd = (flags >> 8) & 0x01;
	header->ra = (flags >> 7) & 0x01;
	header->z = (flags >> 4) & 0x07;
	header->rcode = flags & 0x0F;

	buffer.ReadUint16(header->q_count);
	buffer.ReadUint16(header->ans_count);
	buffer.ReadUint16(header->auth_count);
	buffer.ReadUint16(header->add_count);
}
/*!	Reads the data of an MX record from \a buffer into \a mxRecord: the
	16 bit priority followed by the server name, which is converted from
	label format into a dotted name.
*/
void
DNSQuery::_ReadMXRecord(BRawNetBuffer& buffer, mx_record* mxRecord)
{
	buffer.ReadUint16(mxRecord->priority);

	// Read the raw label-format name first, then store its dotted form.
	BString rawName;
	buffer.ReadString(rawName);
	mxRecord->serverName = DNSTools::ConvertFromDNSName(rawName);
}
/*!	Reads the head of a resource record from \a buffer into \a rrHead, in
	this order: name, 16 bit type, 16 bit class, 32 bit TTL and 16 bit data
	length. The record data itself is left unread, and the results of the
	individual reads are not checked.
*/
void
DNSQuery::_ReadResourceRecord(BRawNetBuffer& buffer,
resource_record_head *rrHead)
{
// The name is stored in its raw form, as ReadString() delivers it.
buffer.ReadString(rrHead->name);
buffer.ReadUint16(rrHead->type);
buffer.ReadUint16(rrHead->dataClass);
buffer.ReadUint32(rrHead->ttl);
// Number of data bytes following this head.
buffer.ReadUint16(rrHead->dataLength);
}