/*
 * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>.
 * All rights reserved. Distributed under the terms of the MIT License.
 */
#include <boot/net/IP.h>
#include <stdio.h>
#include <KernelExport.h>
#include <boot/net/ARP.h>
#include <boot/net/ChainBuffer.h>
#ifdef TRACE_IP
# define TRACE(x) dprintf x
#else
# define TRACE(x) ;
#endif
// Base class for protocol handlers (UDP, TCP, ...) sitting on top of IP.
// Only forwards the service name to NetService.
IPSubService::IPSubService(const char *serviceName)
	:
	NetService(serviceName)
{
	// no initialization beyond the NetService base
}
// Nothing is released at this level.
IPSubService::~IPSubService()
{
	// intentionally empty
}
// Stores the Ethernet service (frame transport) and the ARP service (used by
// Send() to resolve destination MAC addresses). Registration with the
// Ethernet layer is not done here but in Init().
IPService::IPService(EthernetService *ethernet, ARPService *arpService)
	:
	EthernetSubService(kIPServiceName),
	fEthernet(ethernet),
	fARPService(arpService)
{
}
// Detaches this service from the Ethernet layer (the counterpart of the
// registration in Init()). Done whenever an Ethernet service is set.
IPService::~IPService()
{
	if (fEthernet == NULL)
		return;

	fEthernet->UnregisterEthernetSubService(this);
}
// Registers this service with the Ethernet layer.
// Returns B_BAD_VALUE if no Ethernet service was supplied, B_NO_MEMORY if the
// registration fails, and B_OK otherwise.
status_t
IPService::Init()
{
	if (fEthernet == NULL)
		return B_BAD_VALUE;

	// Registration reports success as a bool; failure is mapped to
	// B_NO_MEMORY for the caller.
	bool registered = fEthernet->RegisterEthernetSubService(this);
	return registered ? B_OK : B_NO_MEMORY;
}
// Returns the local IP address as provided by the Ethernet service, or
// INADDR_ANY when there is no Ethernet service.
ip_addr_t
IPService::IPAddress() const
{
	if (fEthernet == NULL)
		return INADDR_ANY;

	return fEthernet->IPAddress();
}
// The EtherType this sub-service claims from the Ethernet layer.
uint16
IPService::EthernetProtocol() const
{
	const uint16 kIPEtherType = ETHERTYPE_IP;
	return kIPEtherType;
}
// Validates an incoming IPv4 datagram and hands its payload to the first
// registered IP sub-service whose IPProtocol() matches the header's protocol.
// Malformed packets, packets not addressed to us (or to broadcast), packets
// with a bad header checksum, and fragments are dropped silently.
//
// ethernet      - the delivering Ethernet service (not used here).
// targetAddress - destination MAC of the frame (not used here).
// data, size   - the IP datagram as received (header + payload + possibly
//                 link-layer padding beyond total_length).
void
IPService::HandleEthernetPacket(EthernetService *ethernet,
	const mac_addr_t &targetAddress, const void *data, size_t size)
{
	TRACE(("IPService::HandleEthernetPacket(): %lu - %lu bytes\n", size,
		sizeof(ip_header)));

	if (!data || size < sizeof(ip_header))
		return;

	const ip_header *header = (const ip_header*)data;

	// header_length counts 32-bit words; a valid IPv4 header has at least 5
	// words (20 bytes) and must fit inside the received data.
	int headerLength = header->header_length * 4;

	// total_length covers header + payload. It must not exceed what was
	// received (the frame may be padded beyond it), and it must be at least
	// the header length -- otherwise the payload size computed below would
	// go negative and be passed on as a huge size.
	size_t totalLength = ntohs(header->total_length);

	if (headerLength < 20 || headerLength > (int)size
		|| header->version != IP_PROTOCOL_VERSION_4
		|| totalLength > size
		|| totalLength < (size_t)headerLength
		|| (header->destination != htonl(INADDR_BROADCAST)
			&& (fEthernet->IPAddress() == INADDR_ANY
				|| header->destination != htonl(fEthernet->IPAddress())))
		|| _Checksum(*header) != 0) {
		return;
	}

	// Nothing here reassembles fragmented datagrams. Drop every fragment (More
	// Fragments flag set or non-zero fragment offset) so sub-services never
	// mistake a partial payload for a complete datagram. In the host-order
	// flags/offset word, 0x2000 is MF and 0x1fff the 13-bit offset (Send()
	// uses the same word for IP_DONT_FRAGMENT).
	if ((ntohs(header->fragment_offset) & 0x3fff) != 0)
		return;

	int serviceCount = fServices.Count();
	for (int i = 0; i < serviceCount; i++) {
		IPSubService *service = fServices.ElementAt(i);
		if (service->IPProtocol() == header->protocol) {
			service->HandleIPPacket(this, ntohl(header->source),
				ntohl(header->destination),
				(uint8*)data + headerLength,
				totalLength - headerLength);
			break;
		}
	}
}
// Sends the payload chain `buffer` as a single IPv4 datagram (DF set, no
// options) to `destination` (host byte order) with the given protocol
// number. The destination MAC is resolved via the ARP service.
//
// Returns B_BAD_VALUE for a NULL buffer or a payload too large to fit the
// 16-bit total_length field, B_NO_INIT if the Ethernet or ARP service is
// missing, the ARP lookup error if resolution fails, and otherwise the
// result of the Ethernet send.
status_t
IPService::Send(ip_addr_t destination, uint8 protocol, ChainBuffer *buffer)
{
	TRACE(("IPService::Send(to: %08lx, proto: %lu, %lu bytes)\n", destination,
		(uint32)protocol, (buffer ? buffer->TotalSize() : 0)));

	if (!buffer)
		return B_BAD_VALUE;

	if (!fEthernet || !fARPService)
		return B_NO_INIT;

	// total_length is a 16-bit field and DF is always set, so a datagram
	// larger than 65535 bytes cannot be represented. Reject it instead of
	// emitting a header whose length field has silently wrapped.
	if (buffer->TotalSize() > 0xffff - sizeof(ip_header))
		return B_BAD_VALUE;

	// Prepend the header to the caller's chain. The header fields are filled
	// in afterwards because total_length needs the size of the whole chain.
	ip_header header;
	ChainBuffer headerBuffer(&header, sizeof(header), buffer);

	header.header_length = 5;
	header.version = IP_PROTOCOL_VERSION_4;
	header.type_of_service = 0;
	header.total_length = htons(headerBuffer.TotalSize());
	header.identifier = 0;
	header.fragment_offset = htons(IP_DONT_FRAGMENT);
	header.time_to_live = IP_DEFAULT_TIME_TO_LIVE;
	header.protocol = protocol;
	header.checksum = 0;
	header.source = htonl(fEthernet->IPAddress());
	header.destination = htonl(destination);

	// The checksum must be computed over the header with checksum == 0.
	header.checksum = htons(_Checksum(header));

	mac_addr_t targetMAC;
	status_t error = fARPService->GetMACForIP(destination, targetMAC);
	if (error != B_OK)
		return error;

	return fEthernet->Send(targetMAC, ETHERTYPE_IP, &headerBuffer);
}
void
IPService::ProcessIncomingPackets()
{
if (fEthernet)
fEthernet->ProcessIncomingPackets();
}
// Adds a protocol handler to the dispatch list used by HandleEthernetPacket().
// Returns false for NULL or if the list could not be extended.
bool
IPService::RegisterIPSubService(IPSubService *service)
{
	if (service == NULL)
		return false;

	status_t error = fServices.Add(service);
	return error == B_OK;
}
// Removes a protocol handler from the dispatch list. Returns false for NULL
// or when Remove() reports a negative result.
bool
IPService::UnregisterIPSubService(IPSubService *service)
{
	if (service == NULL)
		return false;

	return fServices.Remove(service) >= 0;
}
// Number of registered IP sub-services.
int
IPService::CountSubNetServices() const
{
	int count = fServices.Count();
	return count;
}
// Returns the registered IP sub-service at `index`, viewed as a NetService.
NetService *
IPService::SubNetServiceAt(int index) const
{
	IPSubService *service = fServices.ElementAt(index);
	return service;
}
// Computes the Internet checksum over the IP header, including options
// (header_length is in 32-bit words). For a received header with a valid
// checksum field the result is 0.
uint16
IPService::_Checksum(const ip_header &header)
{
	uint32 headerSize = header.header_length * 4;
	ChainBuffer headerChunk(const_cast<ip_header*>(&header), headerSize);
	return ip_checksum(&headerChunk);
}
// Computes the 16-bit one's complement Internet checksum over all bytes of
// the chain starting at `buffer`, treating the chain as one contiguous byte
// stream. Bytes are paired into big-endian 16-bit words, so a word may span
// two chunks; an odd trailing byte is padded with a zero low byte. The sum is
// folded after every addition so the 32-bit accumulator cannot overflow.
// The result is in host byte order (callers apply htons() when storing it).
uint16
ip_checksum(ChainBuffer *buffer)
{
	uint32 sum = 0;
	uint16 pendingWord = 0;
	bool haveHighByte = false;

	for (ChainBuffer *chunk = buffer; chunk != NULL; chunk = chunk->Next()) {
		const uint8 *bytes = (const uint8*)chunk->Data();
		size_t chunkSize = chunk->Size();

		for (size_t i = 0; i < chunkSize; i++) {
			if (!haveHighByte) {
				pendingWord = (uint16)(bytes[i] << 8);
				haveHighByte = true;
				continue;
			}

			sum += (uint16)(pendingWord | bytes[i]);
			while ((sum >> 16) != 0)
				sum = (sum & 0xffff) + (sum >> 16);
			haveHighByte = false;
		}
	}

	// odd total length: the last byte forms a word with a zero low byte
	if (haveHighByte) {
		sum += pendingWord;
		while ((sum >> 16) != 0)
			sum = (sum & 0xffff) + (sum >> 16);
	}

	return ~sum;
}
// Parses a dotted IPv4 address ("a.b.c.d") into a host-byte-order address.
//
// Components are parsed with strtol() base 0, so decimal, octal ("0...") and
// hex ("0x...") components are accepted. Components fill the address from the
// most significant octet down, so fewer than four components leave the low
// octets zero (e.g. "10.1" -> 10.1.0.0). Anything after the fourth component
// is ignored. There is no error reporting: unparsable components yield 0, and
// a NULL string yields 0.
ip_addr_t
ip_parse_address(const char *string)
{
	if (string == NULL)
		return 0;

	ip_addr_t address = 0;
	for (int component = 0; component < 4; component++) {
		// Keep only the low 8 bits and shift as an unsigned value. Otherwise a
		// component > 255 spills into the neighbouring octets, and a negative
		// component (or one >= 128 in the top octet) is a signed left shift,
		// which is undefined behaviour.
		uint32 octet = (uint32)strtol(string, NULL, 0) & 0xff;
		address |= octet << ((3 - component) * 8);

		const char *dot = strchr(string, '.');
		if (dot == NULL)
			break;
		string = dot + 1;
	}

	return address;
}