⛏️ index : haiku.git

/*
 * Copyright 2006-2023, Haiku, Inc. All Rights Reserved.
 * Distributed under the terms of the MIT License.
 *
 * Authors:
 *		Axel Dörfler, axeld@pinc-software.de
 */


#include "argv.h"
#include "tcp.h"
#include "pcap.h"
#include "utility.h"

#include <AutoDeleter.h>
#include <NetBufferUtilities.h>
#include <Referenceable.h>
#include <net_buffer.h>
#include <net_datalink.h>
#include <net_protocol.h>
#include <net_socket.h>
#include <net_stack.h>
#include <slab/Slab.h>
#include <util/AutoLock.h>
#include <util/DoublyLinkedList.h>

#include <KernelExport.h>
#include <Select.h>
#include <module.h>
#include <Locker.h>

#include <netinet/in.h>
#include <netinet/ip.h>

#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <new>
#include <set>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>


// State of one direction of the simulated link. datalink_send_data() appends
// outgoing net_buffers to "list" and releases "wait_sem" once per buffer;
// receiving_thread() drains the list. setup_context() stores a pointer to
// this structure in route.gateway so the datalink code can find it again.
struct context {
	// protects "list"
	BLocker		lock;
	sem_id		wait_sem;
	struct list list;
	net_route	route;
	// set by setup_context(); receiving_thread() checks it for the
	// simultaneous connect/close test modes
	bool		server;
	// the receiving_thread() serving this context
	thread_id	thread;
};

// One shell command: its name, the handler invoked with the command's
// argc/argv, and a help string. Presumably used by the command table and
// do_help() further down the file (not visible here).
struct cmd_entry {
	const char*	name;
	void	(*func)(int argc, char **argv);
	const char*	help;
};


// Forward declaration so that net_socket_private can hold lists of itself.
struct net_socket_private;
typedef DoublyLinkedList<net_socket_private> SocketList;

// Harness-side socket: a net_socket plus the bookkeeping used by the socket
// module functions below. A listening socket tracks its spawned children in
// "pending_children" (handshake not finished) and "connected_children"
// (ready to be accepted); "child_count" counts both and is bounded by
// "max_backlog". A child points back to its listener through "parent".
struct net_socket_private : net_socket,
		DoublyLinkedListLinkImpl<net_socket_private> {
	struct net_socket		*parent;

	team_id					owner;
	uint32					max_backlog;
	uint32					child_count;
	SocketList				pending_children;
	SocketList				connected_children;

	struct select_sync_pool	*select_pool;
	// guards the child lists and counters above
	mutex					lock;
};


// Reference-counted interface address. The single global instance,
// gInterfaceAddress, is attached to every buffer sent through
// datalink_send_data() and to both contexts' routes.
struct InterfaceAddress : DoublyLinkedListLinkImpl<InterfaceAddress>,
		net_interface_address, BReferenceable {
};


// Hooks into the kernel-land emulation used to register built-in modules.
extern "C" status_t _add_builtin_module(module_info *info);
extern "C" status_t _get_builtin_dependencies(void);
extern bool gDebugOutputEnabled;
	// from libkernelland_emu.so

extern struct net_buffer_module_info gNetBufferModule;
	// from net_buffer.cpp
extern net_address_module_info gIPv4AddressModule;
	// from ipv4_address.cpp
extern module_info *modules[];
	// from tcp.cpp


// Modules and sockets shared by the whole test shell.
extern struct net_protocol_module_info gDomainModule;
struct InterfaceAddress gInterfaceAddress;
extern struct net_socket_module_info gNetSocketModule;
struct net_protocol_module_info *gTCPModule;
struct net_socket *gServerSocket, *gClientSocket;
static struct context sClientContext, sServerContext;

// Network simulation settings, consulted by domain_receive_data() and
// receiving_thread(); the shell commands change them at run time.
static int32 sPacketNumber = 1;
static double sRandomDrop = 0.0;
static std::set<uint32> sDropList;
static bigtime_t sRoundTripTime = 0;
static bool sIncreasingRoundTrip = false;
static bool sRandomRoundTrip = false;
// Per-packet monitor (dump_printf() or dump_pcap()); NULL means only
// dropped packets are reported.
static void (*sPacketMonitor)(net_buffer *, int32, bool) = NULL;
static int sPcapFD = -1;
static bigtime_t sStartTime;
static double sRandomReorder = 0.0;
static std::set<uint32> sReorderList;
static bool sSimultaneousConnect = false;
static bool sSimultaneousClose = false;
static bool sServerActiveClose = false;

// The only domain known to this harness: IPv4, backed by the fake domain
// protocol gDomainModule and the real IPv4 address module.
static struct net_domain sDomain = {
	"ipv4",
	AF_INET,

	&gDomainModule,
	&gIPv4AddressModule
};


static bool
is_server(const sockaddr* addr)
{
	// The test server is bound to port 1024 (see setup_server()); any other
	// port is treated as the client side.
	const sockaddr_in* inetAddress = reinterpret_cast<const sockaddr_in*>(addr);
	const bool matchesServerPort = inetAddress->sin_port == htons(1024);
	return matchesServerPort;
}


/*!	Returns the TCP flags byte of the segment at the start of \a buffer.
	If the TCP header cannot be read, 0 (no flags) is returned.
*/
static uint8
tcp_segment_flags(net_buffer* buffer)
{
	NetBufferHeaderReader<tcp_header> bufferHeader(buffer);
	if (bufferHeader.Status() < B_OK) {
		// Do not return the (negative) status: truncated to uint8 it would
		// look like a segment with nearly every flag set, making is_syn()
		// and is_fin() report true for an unreadable buffer.
		return 0;
	}

	return bufferHeader.Data().flags;
}


static bool
is_syn(net_buffer* buffer)
{
	// true if the segment carries the SYN flag
	const uint8 flags = tcp_segment_flags(buffer);
	return (flags & TCP_FLAG_SYNCHRONIZE) != 0;
}


static bool
is_fin(net_buffer* buffer)
{
	// true if the segment carries the FIN flag
	const uint8 flags = tcp_segment_flags(buffer);
	return (flags & TCP_FLAG_FINISH) != 0;
}


//	#pragma mark - stack


status_t
std_ops(int32 /*op*/, ...)
{
	// The in-process test modules need no init/uninit work; every
	// operation succeeds.
	return B_OK;
}


net_domain *
get_domain(int /*family*/)
{
	// Only one domain exists in this harness, whatever family is asked for.
	net_domain* onlyDomain = &sDomain;
	return onlyDomain;
}


status_t
register_domain_protocols(int /*family*/, int /*type*/, int /*protocol*/, ...)
{
	// Protocol chains are built by hand in socket_create(); registration is
	// accepted and ignored.
	return B_OK;
}


status_t
register_domain_datalink_protocols(int /*family*/, int /*type*/, ...)
{
	// Accepted and ignored: there is no datalink protocol stack here.
	return B_OK;
}


static status_t
register_domain_receiving_protocol(int /*family*/, int /*type*/,
	const char * /*moduleName*/)
{
	// Accepted and ignored: incoming packets are routed by the harness.
	return B_OK;
}


static bool
dummy_is_syscall(void)
{
	// Calls never originate from a syscall in this userland harness.
	const bool inSyscall = false;
	return inSyscall;
}


static bool
dummy_is_restarted_syscall(void)
{
	// No syscall is ever restarted here.
	const bool restarted = false;
	return restarted;
}


static void
dummy_store_syscall_restart_timeout(bigtime_t /*timeout*/)
{
	// Nothing to remember: syscalls are never restarted in this harness.
}


// Minimal net_stack module for the TCP module under test. It provides domain
// lookup and (no-op) registration, socket notification, checksumming, FIFO
// and timer helpers, and dummy syscall-restart hooks. Every other hook is
// NULL. The initializer is positional and must follow the field order of
// net_stack_module_info.
static net_stack_module_info gNetStackModule = {
	{
		NET_STACK_MODULE_NAME,
		0,
		std_ops
	},
	NULL, // register_domain,
	NULL, // unregister_domain,
	get_domain,

	register_domain_protocols,
	register_domain_datalink_protocols,
	register_domain_receiving_protocol,

	NULL, // get_domain_receiving_protocol,
	NULL, // put_domain_receiving_protocol,

	NULL, // register_device_deframer,
	NULL, // unregister_device_deframer,
	NULL, // register_domain_device_handler,
	NULL, // register_device_handler,
	NULL, // unregister_device_handler,
	NULL, // register_device_monitor,
	NULL, // unregister_device_monitor,
	NULL, // device_link_changed,
	NULL, // device_removed,
	NULL, // device_enqueue_buffer,

	notify_socket,

	checksum,

	init_fifo,
	uninit_fifo,
	fifo_enqueue_buffer,
	fifo_dequeue_buffer,
	clear_fifo,
	fifo_socket_enqueue_buffer,

	init_timer,
	set_timer,
	cancel_timer,
	wait_for_timer,
	is_timer_active,
	is_timer_running,

	dummy_is_syscall,
	dummy_is_restarted_syscall,
	dummy_store_syscall_restart_timeout,
	NULL, // restore_syscall_restart_timeout

	// ancillary data is not used by TCP
};


//	#pragma mark - socket


/*!	Creates a socket whose protocol chain is TCP on top of the fake domain
	protocol (gDomainModule).
	\return B_OK and the new socket in \a _socket, B_NO_MEMORY if an
		allocation failed, or B_ERROR if the TCP protocol could not be
		initialized.
*/
status_t
socket_create(int family, int type, int protocol, net_socket **_socket)
{
	struct net_socket_private *socket = new (std::nothrow) net_socket_private;
	if (socket == NULL)
		return B_NO_MEMORY;

	// "new" leaves plain-data members uninitialized. Clear the net_socket
	// base, and explicitly set the members net_socket_private adds:
	// socket_delete() checks "parent", and the backlog code relies on
	// "child_count" and "max_backlog".
	memset(static_cast<net_socket*>(socket), 0, sizeof(net_socket));
	socket->family = family;
	socket->type = type;
	socket->protocol = protocol;

	socket->parent = NULL;
	socket->owner = -1;
		// no owning team is tracked by this harness
	socket->max_backlog = 0;
	socket->child_count = 0;
	socket->select_pool = NULL;

	mutex_init(&socket->lock, "socket");

	// set defaults (may be overridden by the protocols)
	socket->send.buffer_size = 65535;
	socket->send.low_water_mark = 1;
	socket->send.timeout = B_INFINITE_TIMEOUT;
	socket->receive.buffer_size = 65535;
	socket->receive.low_water_mark = 1;
	socket->receive.timeout = B_INFINITE_TIMEOUT;

	socket->first_protocol = gTCPModule->init_protocol(socket);
	if (socket->first_protocol == NULL) {
		fprintf(stderr, "tcp_tester: cannot create protocol\n");
		mutex_destroy(&socket->lock);
		delete socket;
		return B_ERROR;
	}

	socket->first_info = gTCPModule;

	// The fake domain protocol sits directly below TCP and ends the chain.
	net_protocol* domainProtocol = new (std::nothrow) net_protocol;
	if (domainProtocol == NULL) {
		gTCPModule->uninit_protocol(socket->first_protocol);
		mutex_destroy(&socket->lock);
		delete socket;
		return B_NO_MEMORY;
	}

	domainProtocol->next = NULL;
	domainProtocol->module = &gDomainModule;
	domainProtocol->socket = socket;

	socket->first_protocol->next = domainProtocol;
	socket->first_protocol->module = gTCPModule;
	socket->first_protocol->socket = socket;

	*_socket = socket;
	return B_OK;
}


void
socket_delete(net_socket *_socket)
{
	net_socket_private *doomed = (net_socket_private *)_socket;

	// A child still linked to its listening parent must not be destroyed.
	if (doomed->parent != NULL)
		panic("socket still has a parent!");

	// Release the topmost (TCP) protocol, then the socket itself.
	// NOTE(review): the domain protocol allocated in socket_create() is not
	// freed here.
	net_protocol_module_info *topModule = doomed->first_info;
	topModule->uninit_protocol(doomed->first_protocol);

	mutex_destroy(&doomed->lock);
	delete doomed;
}


int
socket_accept(net_socket *socket, struct sockaddr *address,
	socklen_t *_addressLength, net_socket **_acceptedSocket)
{
	// Let the protocol hand out the next established child connection.
	net_socket *child;
	const status_t acceptStatus = socket->first_info->accept(
		socket->first_protocol, &child);
	if (acceptStatus < B_OK)
		return acceptStatus;

	// Optionally report the peer, truncated to the caller's buffer; the
	// full peer length is always returned.
	const bool wantsPeer = address != NULL && *_addressLength > 0;
	if (wantsPeer) {
		const socklen_t peerLength = child->peer.ss_len;
		memcpy(address, &child->peer, min_c(*_addressLength, peerLength));
		*_addressLength = peerLength;
	}

	*_acceptedSocket = child;
	return B_OK;
}


int
socket_bind(net_socket *socket, const struct sockaddr *address, socklen_t addressLength)
{
	// Without an address, bind to an all-zero one (like INADDR_ANY).
	sockaddr anyAddress;
	if (address == NULL) {
		memset(&anyAddress, 0, sizeof(sockaddr));
		anyAddress.sa_len = sizeof(sockaddr);
		anyAddress.sa_family = socket->family;

		address = &anyAddress;
		addressLength = sizeof(sockaddr);
	}

	// Drop an existing binding first.
	const bool alreadyBound = socket->address.ss_len != 0;
	if (alreadyBound) {
		const status_t unbindStatus = socket->first_info->unbind(
			socket->first_protocol, (sockaddr *)&socket->address);
		if (unbindStatus < B_OK)
			return unbindStatus;
	}

	memcpy(&socket->address, address, sizeof(sockaddr));

	const status_t bindStatus = socket->first_info->bind(
		socket->first_protocol, (sockaddr *)address);
	if (bindStatus < B_OK) {
		// binding failed, so the socket is unbound again
		socket->address.ss_len = 0;
	}

	return bindStatus;
}


int
socket_connect(net_socket *socket, const struct sockaddr *address, socklen_t addressLength)
{
	const bool haveTarget = address != NULL && addressLength != 0;
	if (!haveTarget)
		return ENETUNREACH;

	// An unbound socket is implicitly bound to the empty address first.
	if (socket->address.ss_len == 0) {
		const status_t bindStatus = socket_bind(socket, NULL, 0);
		if (bindStatus < B_OK)
			return bindStatus;
	}

	net_protocol_module_info *module = socket->first_info;
	return module->connect(socket->first_protocol, address);
}


int
socket_listen(net_socket *socket, int backlog)
{
	const status_t listenStatus = socket->first_info->listen(
		socket->first_protocol, backlog);
	if (listenStatus != B_OK)
		return listenStatus;

	// mark the socket as accepting connections only once listen() succeeded
	socket->options |= SO_ACCEPTCONN;
	return B_OK;
}


ssize_t
socket_send(net_socket *socket, const void *data, size_t length, int flags)
{
	// Sending requires a known peer.
	if (socket->peer.ss_len == 0)
		return EDESTADDRREQ;

	// Bind implicitly if nobody did so yet.
	if (socket->address.ss_len == 0) {
		const status_t bindStatus = socket_bind(socket, NULL, 0);
		if (bindStatus < B_OK)
			return bindStatus;
	}

	// TODO: useful, maybe even computed header space!
	net_buffer *packet = gNetBufferModule.create(256);
	if (packet == NULL)
		return ENOBUFS;

	const bool copied = gNetBufferModule.append(packet, data, length) >= B_OK;
	if (!copied) {
		gNetBufferModule.free(packet);
		return ENOBUFS;
	}

	// stamp the flags and both endpoint addresses onto the buffer
	packet->msg_flags = flags;
	memcpy(packet->source, &socket->address, socket->address.ss_len);
	memcpy(packet->destination, &socket->peer, socket->peer.ss_len);

	const status_t sendStatus = socket->first_info->send_data(
		socket->first_protocol, packet);
	if (sendStatus < B_OK) {
		// on failure the buffer is released here
		gNetBufferModule.free(packet);
		return sendStatus;
	}

	return length;
}


/*!	Reads up to \a length bytes from \a socket into \a data.
	\return the number of bytes copied (0 if nothing was received), or a
		negative error code.
*/
ssize_t
socket_recv(net_socket *socket, void *data, size_t length, int flags)
{
	// Initialize so that a protocol that reports success without producing
	// a buffer cannot leave us reading an indeterminate pointer.
	net_buffer *buffer = NULL;
	ssize_t status = socket->first_info->read_data(
		socket->first_protocol, length, flags, &buffer);
	if (status < B_OK)
		return status;

	// if 0 bytes were received, no buffer is created
	if (buffer == NULL)
		return 0;

	// Never copy more than the caller's buffer can hold.
	size_t bytesReceived = min_c((size_t)buffer->size, length);
	status = gNetBufferModule.read(buffer, 0, data, bytesReceived);
	gNetBufferModule.free(buffer);
	if (status < B_OK)
		return status;

	return bytesReceived;
}


bool
socket_acquire(net_socket* _socket)
{
	return true;
}


bool
socket_release(net_socket* _socket)
{
	return true;
}


status_t
socket_spawn_pending(net_socket *_parent, net_socket **_socket)
{
	net_socket_private *parent = (net_socket_private *)_parent;
	MutexLocker parentLocker(parent->lock);

	// Allow 50% more children than the backlog, to compensate for
	// connections that never complete. The strict ">" still admits one
	// connection when the limit is 0.
	const uint32 pendingLimit = 3 * parent->max_backlog / 2;
	if (parent->child_count > pendingLimit)
		return ENOBUFS;

	net_socket *created;
	const status_t createStatus = socket_create(parent->family, parent->type,
		parent->protocol, &created);
	if (createStatus < B_OK)
		return createStatus;

	net_socket_private *child = (net_socket_private *)created;

	// the child starts out with the listener's settings and addresses
	child->send = parent->send;
	child->receive = parent->receive;
	child->options = parent->options & ~SO_ACCEPTCONN;
	child->linger = parent->linger;
	memcpy(&child->address, &parent->address, parent->address.ss_len);
	memcpy(&child->peer, &parent->peer, parent->peer.ss_len);

	// queue it as a pending connection of the listener
	parent->pending_children.Add(child);
	child->parent = parent;
	parent->child_count++;

	*_socket = child;
	return B_OK;
}


status_t
socket_dequeue_connected(net_socket *_parent, net_socket **_socket)
{
	net_socket_private *parent = (net_socket_private *)_parent;
	MutexLocker locker(parent->lock);

	// Hand out the oldest fully connected child, detached from its parent.
	net_socket_private *child = parent->connected_children.RemoveHead();
	if (child == NULL)
		return B_ENTRY_NOT_FOUND;

	child->parent = NULL;
	parent->child_count--;
	*_socket = child;
	return B_OK;
}


ssize_t
socket_count_connected(net_socket *_parent)
{
	net_socket_private *listener = (net_socket_private *)_parent;
	MutexLocker locker(listener->lock);

	// number of children ready to be accepted
	const ssize_t readyCount = listener->connected_children.Count();
	return readyCount;
}


status_t
socket_set_max_backlog(net_socket *_socket, uint32 backlog)
{
	net_socket_private *listener = (net_socket_private *)_socket;

	// upper limit on connections waiting to be accepted
	const uint32 kBacklogLimit = 256;
	if (backlog > kBacklogLimit)
		backlog = kBacklogLimit;

	MutexLocker locker(listener->lock);

	// Shrink to the new limit: drop pending children first, then connected
	// ones.
	// NOTE(review): pending children are only detached here, not deleted.
	net_socket_private *child;
	while (listener->child_count > backlog) {
		child = listener->pending_children.RemoveTail();
		if (child == NULL)
			break;

		child->parent = NULL;
		listener->child_count--;
	}
	while (listener->child_count > backlog) {
		child = listener->connected_children.RemoveTail();
		if (child == NULL)
			break;

		child->parent = NULL;
		socket_delete(child);
		listener->child_count--;
	}

	listener->max_backlog = backlog;
	return B_OK;
}


bool
socket_has_parent(net_socket* _socket)
{
	// True while the socket is still queued on a listening socket.
	const net_socket* parent = ((net_socket_private*)_socket)->parent;
	return parent != NULL;
}


status_t
socket_connected(net_socket *socket)
{
	net_socket_private *child = (net_socket_private *)socket;
	net_socket_private *listener = (net_socket_private *)child->parent;
	if (listener == NULL)
		return B_BAD_VALUE;

	// The handshake finished: move the child from the pending list to the
	// list of connections ready to be accepted.
	MutexLocker locker(listener->lock);
	listener->pending_children.Remove(child);
	listener->connected_children.Add(child);
	return B_OK;
}


status_t
socket_notify(net_socket *_socket, uint8 event, int32 value)
{
	net_socket_private *socket = (net_socket_private *)_socket;
	MutexLocker locker(socket->lock);

	// Decide whether waiters would have to be notified. Read/write events
	// below the low water mark are suppressed; errors are recorded.
	bool notify = true;
	if (event == B_SELECT_READ) {
		if (value >= B_OK
			&& (ssize_t)socket->receive.low_water_mark > value)
			notify = false;
	} else if (event == B_SELECT_WRITE) {
		if (value >= B_OK
			&& (ssize_t)socket->send.low_water_mark > value)
			notify = false;
	} else if (event == B_SELECT_ERROR) {
		socket->error = value;
	}

	// TODO: nothing is actually woken up in this harness.
	(void)notify;

	return B_OK;
}


// Socket module seen by the TCP module. This harness implements reference
// counting (as no-ops), the listen/accept child bookkeeping and
// socket_notify(); every other hook is NULL. The initializer is positional
// and must follow the field order of net_socket_module_info.
net_socket_module_info gNetSocketModule = {
	{
		NET_SOCKET_MODULE_NAME,
		0,
		std_ops
	},
	NULL, // open,
	NULL, // close,
	NULL, // free,

	NULL, // control,

	NULL, // read_avail,
	NULL, // send_avail,

	NULL, // send_data,
	NULL, // receive_data,

	NULL, // get_option,
	NULL, // set_option,
	NULL, // get_next_stat,

	socket_acquire,
	socket_release,

	// connections
	socket_spawn_pending,
	socket_dequeue_connected,
	socket_count_connected,
	socket_set_max_backlog,
	socket_has_parent,
	socket_connected,
	NULL, // set_aborted

	// notifications
	NULL, // request_notification,
	NULL, // cancel_notification,
	socket_notify,

	// standard socket API
	NULL, // accept,
	NULL, // bind,
	NULL, // connect,
	NULL, // getpeername,
	NULL, // getsockname,
	NULL, // getsockopt,
	NULL, // listen,
	NULL, // receive,
	NULL, // send,
	NULL, // setsockopt,
	NULL, // shutdown,
	NULL, // socketpair
};


//	#pragma mark - protocol


net_protocol*
init_protocol(net_socket** _socket)
{
	// Create a TCP/IPv4 socket and open its protocol chain.
	net_socket *created;
	if (socket_create(AF_INET, SOCK_STREAM, IPPROTO_TCP, &created) < B_OK)
		return NULL;

	const status_t openStatus = created->first_info->open(
		created->first_protocol);
	if (openStatus < B_OK) {
		fprintf(stderr, "tcp_tester: cannot open client: %s\n",
			strerror(openStatus));
		socket_delete(created);
		return NULL;
	}

	*_socket = created;
	return created->first_protocol;
}


void
close_protocol(net_protocol* protocol)
{
	gTCPModule->close(protocol);

	// Only tear the endpoint down once the TCP module agrees to free it.
	if (gTCPModule->free(protocol) != B_OK)
		return;

	gTCPModule->uninit_protocol(protocol);
	// NOTE(review): the owning net_socket is not released here;
	// socket_delete(protocol->socket) was left disabled.
}


//	#pragma mark - datalink


status_t
datalink_send_data(struct net_route *route, net_buffer *buffer)
{
	// setup_context() stores a back pointer to the owning context in the
	// route's gateway field.
	struct context* target = (struct context*)route->gateway;

	buffer->interface_address = &gInterfaceAddress;
	gInterfaceAddress.AcquireReference();

	// queue the buffer and wake that context's receiving thread
	target->lock.Lock();
	list_add_item(&target->list, buffer);
	target->lock.Unlock();

	release_sem(target->wait_sem);
	return B_OK;
}


status_t
datalink_send_datagram(net_protocol * /*protocol*/, net_domain * /*domain*/,
	net_buffer * /*buffer*/)
{
	// Datagram sending must never be reached in this TCP-only harness.
	panic("called");
	return B_ERROR;
}


struct net_route *
get_route(struct net_domain *_domain, const struct sockaddr *address)
{
	if (is_server(address)) {
		// to the server
		return &sServerContext.route;
	}

	return &sClientContext.route;
}


static void
put_route(struct net_domain* /*_domain*/, net_route* /*route*/)
{
	// Routes live in the static contexts; there is nothing to release.
}


// Datalink module: datalink_send_data() queues buffers directly for the
// peer's receiving thread; routes are looked up with get_route() and
// released (no-op) with put_route(). datalink_send_datagram() panics, and
// every other hook is NULL. The initializer is positional and must follow
// the field order of net_datalink_module_info.
net_datalink_module_info gNetDatalinkModule = {
	{
		NET_DATALINK_MODULE_NAME,
		0,
		std_ops
	},

	NULL, // control
	datalink_send_data,
	datalink_send_datagram,

	NULL, // is_local_address
	NULL, // is_local_link_address

	NULL, // get_interface
	NULL, // get_interface_with_address
	NULL, // put_interface

	NULL, // get_interface_address
	NULL, // get_next_interface_address,
	NULL, // put_interface_address

	NULL, // join_multicast
	NULL, // leave_multicast

	NULL, // add_route,
	NULL, // remove_route,
	get_route,
	NULL, // get_buffer_route
	put_route,
	NULL, // register_route_info,
	NULL, // unregister_route_info,
	NULL, // update_route_info
};


//	#pragma mark - domain


status_t
domain_open(net_protocol * /*protocol*/)
{
	// Opening the fake domain layer requires no work.
	return B_OK;
}


status_t
domain_close(net_protocol * /*protocol*/)
{
	// Nothing to shut down at the domain layer.
	return B_OK;
}


status_t
domain_free(net_protocol * /*protocol*/)
{
	// The domain layer holds no per-socket resources to free.
	return B_OK;
}


status_t
domain_connect(net_protocol * /*protocol*/,
	const struct sockaddr * /*address*/)
{
	// Connecting is not supported at the domain level of this harness.
	return B_ERROR;
}


status_t
domain_accept(net_protocol * /*protocol*/,
	struct net_socket ** /*_acceptedSocket*/)
{
	// The domain layer cannot accept connections.
	return EOPNOTSUPP;
}


status_t
domain_control(net_protocol * /*protocol*/, int /*level*/, int /*option*/,
	void * /*value*/, size_t * /*_length*/)
{
	// No control operations are implemented.
	return B_ERROR;
}


status_t
domain_bind(net_protocol *protocol, const struct sockaddr *address)
{
	net_socket* socket = protocol->socket;
	memcpy(&socket->address, address, sizeof(struct sockaddr_in));

	// Callers cannot be trusted to always provide the correct length, so
	// set it to that of an IPv4 address explicitly.
	socket->address.ss_len = sizeof(struct sockaddr_in);
	return B_OK;
}


status_t
domain_unbind(net_protocol * /*protocol*/, struct sockaddr * /*address*/)
{
	// Unbinding needs no bookkeeping at the domain level.
	return B_OK;
}


status_t
domain_listen(net_protocol * /*protocol*/, int /*count*/)
{
	// Listening is handled entirely by TCP.
	return EOPNOTSUPP;
}


status_t
domain_shutdown(net_protocol * /*protocol*/, int /*direction*/)
{
	// Shutdown is not supported at the domain level.
	return EOPNOTSUPP;
}


/*!	Routes \a buffer by its destination address and hands it to the
	simulated datalink layer.
*/
status_t
domain_send_data(net_protocol *protocol, net_buffer *buffer)
{
	// buffer->destination already refers to the address storage (see
	// socket_send(), which memcpy()s into it directly). Taking its address
	// would pass the field itself rather than the address it holds.
	struct net_route *route = get_route(&sDomain,
		(sockaddr *)buffer->destination);
	if (route == NULL)
		return ENETUNREACH;

	return datalink_send_data(route, buffer);
}


status_t
domain_send_routed_data(net_protocol * /*protocol*/, struct net_route *route,
	net_buffer *buffer)
{
	// The route is already known: pass the buffer straight to the datalink.
	const status_t sendStatus = datalink_send_data(route, buffer);
	return sendStatus;
}


ssize_t
domain_send_avail(net_protocol * /*protocol*/)
{
	// Not supported by the fake domain layer.
	return B_ERROR;
}


status_t
domain_read_data(net_protocol * /*protocol*/, size_t /*numBytes*/,
	uint32 /*flags*/, net_buffer ** /*_buffer*/)
{
	// Data is never read from the domain layer directly.
	return B_ERROR;
}


ssize_t
domain_read_avail(net_protocol * /*protocol*/)
{
	// Not supported by the fake domain layer.
	return B_ERROR;
}


struct net_domain *
domain_get_domain(net_protocol *protocol)
{
	return &sDomain;
}


size_t
domain_get_mtu(net_protocol *protocol, const struct sockaddr *address)
{
	return 1480;
		// 1500 ethernet - IPv4 header
}


/*!	Entry point for every packet travelling between client and server.
	Applies the simulated network conditions (drop list, random drop,
	round trip delay), reports the packet to the packet monitor, and then
	delivers it to TCP unless it was dropped.
*/
status_t
domain_receive_data(net_buffer *buffer)
{
	uint32 packetNumber = atomic_add(&sPacketNumber, 1);

	// Drop the packet if it is explicitly listed, or randomly with
	// probability sRandomDrop. (The random test used to be inverted,
	// dropping with probability 1 - sRandomDrop.)
	bool drop = sDropList.find(packetNumber) != sDropList.end();
	if (!drop && sRandomDrop > 0.0
		&& (1.0 * rand() / RAND_MAX) < sRandomDrop)
		drop = true;

	if (!drop && (sRoundTripTime > 0 || sRandomRoundTrip || sIncreasingRoundTrip)) {
		bigtime_t jitter = 0;
		if (sRandomRoundTrip)
			jitter = (bigtime_t)(1.0 * rand() / RAND_MAX * 500000) - 250000;
		if (sIncreasingRoundTrip)
			sRoundTripTime += (bigtime_t)(1.0 * rand() / RAND_MAX * 150000);

		// Each direction is delayed by half the round trip time. The random
		// jitter can make the delay negative; deliver immediately then.
		bigtime_t delay = sRoundTripTime / 2 + jitter;
		if (delay > 0)
			snooze(delay);
	}

	if (sPacketMonitor != NULL) {
		sPacketMonitor(buffer, packetNumber, drop);
	} else if (drop)
		printf("<**** DROPPED %" B_PRIu32 " ****>\n", packetNumber);

	if (drop) {
		gNetBufferModule.free(buffer);
		return B_OK;
	}

	return gTCPModule->receive_data(buffer);
}


status_t
domain_error(net_error /*error*/, net_buffer* /*data*/)
{
	// Error propagation is not simulated.
	return B_ERROR;
}


status_t
domain_error_reply(net_protocol* /*self*/, net_buffer* /*cause*/,
	net_error /*error*/, net_error_data* /*errorData*/)
{
	// Error replies are not simulated.
	return B_ERROR;
}


// Fake IPv4 domain protocol placed below TCP in every socket's protocol chain
// (see socket_create()). It has no module name because it is only referenced
// directly by address. The initializer is positional and must follow the
// field order of net_protocol_module_info.
net_protocol_module_info gDomainModule = {
	{
		NULL,
		0,
		std_ops
	},
	NET_PROTOCOL_ATOMIC_MESSAGES,

	NULL, // init
	NULL, // uninit
	domain_open,
	domain_close,
	domain_free,
	domain_connect,
	domain_accept,
	domain_control,
	NULL, // getsockopt
	NULL, // setsockopt
	domain_bind,
	domain_unbind,
	domain_listen,
	domain_shutdown,
	domain_send_data,
	domain_send_routed_data,
	domain_send_avail,
	domain_read_data,
	domain_read_avail,
	domain_get_domain,
	domain_get_mtu,
	domain_receive_data,
	NULL, // deliver_data
	domain_error,
	domain_error_reply,
	NULL, // attach_ancillary_data
	NULL, // process_ancillary_data
};


//	#pragma mark - packet capture


/*!	Packet monitor that prints a one-line, tcpdump-like summary of each TCP
	segment: packet number, timing, direction, flags, sequence range,
	acknowledgement, window and the recognized TCP options.
*/
static void
dump_printf(net_buffer* buffer, int32 packetNumber, bool willBeDropped)
{
	static bigtime_t lastTime = 0;

	NetBufferHeaderReader<tcp_header> bufferHeader(buffer);
	if (bufferHeader.Status() < B_OK)
		return;

	tcp_header &header = bufferHeader.Data();

	bigtime_t now = system_time();
	if (lastTime == 0)
		lastTime = now;

	// packet number, time since do_connect(), time since the previous packet
	printf("\33[0m% 3" B_PRId32 " %8.6f (%8.6f) ", packetNumber,
		(now - sStartTime) / 1000000.0, (now - lastTime) / 1000000.0);
	lastTime = now;

	// segments sent by the server are printed in red
	if (is_server((sockaddr *)buffer->source))
		printf("\33[31mserver > client: ");
	else
		printf("client > server: ");

	int32 payloadLength = buffer->size - header.HeaderLength();

	if ((header.flags & TCP_FLAG_PUSH) != 0)
		putchar('P');
	if ((header.flags & TCP_FLAG_SYNCHRONIZE) != 0)
		putchar('S');
	if ((header.flags & TCP_FLAG_FINISH) != 0)
		putchar('F');
	if ((header.flags & TCP_FLAG_RESET) != 0)
		putchar('R');
	if ((header.flags
		& (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH | TCP_FLAG_PUSH | TCP_FLAG_RESET)) == 0)
		putchar('.');

	printf(" %" B_PRIu32 ":%" B_PRIu32 " (%" B_PRId32 ")",
		(uint32)header.Sequence(), (uint32)(header.Sequence() + payloadLength),
		payloadLength);
	if ((header.flags & TCP_FLAG_ACKNOWLEDGE) != 0)
		printf(" ack %" B_PRIu32, (uint32)header.Acknowledge());

	printf(" win %u", header.AdvertisedWindow());

	if (header.HeaderLength() > sizeof(tcp_header)) {
		int32 size = header.HeaderLength() - sizeof(tcp_header);

		tcp_option *option;
		uint8 optionsBuffer[1024];
		if (gBufferModule->direct_access(buffer, sizeof(tcp_header),
				size, (void **)&option) != B_OK) {
			if (size > (int32)sizeof(optionsBuffer)) {
				printf("options too large to take into account (%" B_PRId32
					" bytes)\n", size);
				size = sizeof(optionsBuffer);
			}

			gBufferModule->read(buffer, sizeof(tcp_header), optionsBuffer, size);
			option = (tcp_option *)optionsBuffer;
		}

		// Only print an option if all of its bytes are within the remaining
		// option area.
		while (size > 0) {
			int32 optionLength;
			switch (option->kind) {
				case TCP_OPTION_END:
				case TCP_OPTION_NOP:
					optionLength = 1;
					break;
				case TCP_OPTION_MAX_SEGMENT_SIZE:
					optionLength = 4;
					if (size >= optionLength)
						printf(" <mss %u>", ntohs(option->max_segment_size));
					break;
				case TCP_OPTION_WINDOW_SHIFT:
					optionLength = 3;
					if (size >= optionLength)
						printf(" <ws %u>", option->window_shift);
					break;
				case TCP_OPTION_TIMESTAMP:
					optionLength = 10;
					if (size >= optionLength) {
						// timestamps are in network byte order on the wire
						printf(" <ts %" B_PRIu32 ":%" B_PRIu32 ">",
							(uint32)ntohl(option->timestamp.value),
							(uint32)ntohl(option->timestamp.reply));
					}
					break;

				default:
					// unknown option: its second byte holds its total length
					optionLength = size >= 2 ? option->length : 0;
					break;
			}

			// a zero length would loop forever; stop parsing instead
			if (optionLength <= 0)
				break;

			size -= optionLength;
			option = (tcp_option *)((uint8 *)option + optionLength);
		}
	}

	if (willBeDropped)
		printf(" <DROPPED>");
	printf("\33[0m\n");
}


/*!	Packet monitor that appends each delivered TCP segment to the pcap file
	opened by setup_dump_pcap(), prefixed with a synthesized IPv4 header.
	Dropped packets are not written. Exits the program on a write error.
*/
static void
dump_pcap(net_buffer* buffer, int32 packetNumber, bool willBeDropped)
{
	if (willBeDropped) {
		printf(" <DROPPED>\n");
		return;
	}

	const bigtime_t time = real_time_clock_usecs();

	struct pcap_packet_header pcap_header;
	pcap_header.ts_sec = time / 1000000;
	pcap_header.ts_usec = time % 1000000;
	pcap_header.included_len = sizeof(struct ip) + buffer->size;
	pcap_header.original_len = pcap_header.included_len;

	struct ip ip_header;
	ip_header.ip_v = IPVERSION;
	ip_header.ip_hl = sizeof(struct ip) >> 2;
	ip_header.ip_tos = 0;
	ip_header.ip_len = htons(sizeof(struct ip) + buffer->size);
	ip_header.ip_id = htons(packetNumber);
	ip_header.ip_off = 0;
	ip_header.ip_ttl = 254;
	ip_header.ip_p = IPPROTO_TCP;
	ip_header.ip_sum = 0;
	ip_header.ip_src.s_addr = ((sockaddr_in*)buffer->source)->sin_addr.s_addr;
	ip_header.ip_dst.s_addr = ((sockaddr_in*)buffer->destination)->sin_addr.s_addr;

	// A compile-time array size; the previous non-const count made this a
	// (non-standard) variable length array.
	const uint32 kMaxVecs = 16;
	iovec vecs[kMaxVecs];
	uint32 used = 0;

	vecs[used].iov_base = &pcap_header;
	vecs[used].iov_len = sizeof(pcap_header);
	used++;

	vecs[used].iov_base = &ip_header;
	vecs[used].iov_len = sizeof(ip_header);
	used++;

	const uint32 headerVecs = used;
	used += gNetBufferModule.get_iovecs(buffer, vecs + used, kMaxVecs - used);

	// If the vectors did not cover the whole payload (e.g. the buffer has
	// more fragments than free vectors), copy the payload out instead of
	// writing a truncated record.
	size_t payloadBytes = 0;
	for (uint32 i = headerVecs; i < used; i++)
		payloadBytes += vecs[i].iov_len;

	void* payloadCopy = NULL;
	if (payloadBytes != buffer->size) {
		payloadCopy = malloc(buffer->size);
		if (payloadCopy == NULL
			|| gNetBufferModule.read(buffer, 0, payloadCopy, buffer->size)
				!= B_OK) {
			free(payloadCopy);
			fprintf(stderr, "writing to pcap file failed\n");
			exit(1);
		}

		vecs[headerVecs].iov_base = payloadCopy;
		vecs[headerVecs].iov_len = buffer->size;
		used = headerVecs + 1;
	}
	MemoryDeleter payloadDeleter(payloadCopy);

	const size_t expectedBytes = sizeof(pcap_header) + pcap_header.included_len;

	static mutex writesLock = MUTEX_INITIALIZER("pcap writes");
	MutexLocker _(writesLock);
	ssize_t written = writev(sPcapFD, vecs, used);
	if (written < 0 || (size_t)written != expectedBytes) {
		fprintf(stderr, "writing to pcap file failed\n");
		exit(1);
	}
}


/*!	Creates (or truncates) the pcap file \a file, writes the pcap global
	header, and installs dump_pcap() as the packet monitor.
	\return true on success; on failure the file descriptor is closed again.
*/
static bool
setup_dump_pcap(const char* file)
{
	// O_CREAT requires an explicit permission mode; without it open() uses
	// whatever happens to be in the missing argument.
	sPcapFD = open(file, O_CREAT | O_WRONLY | O_TRUNC, 0644);
	if (sPcapFD < 0) {
		fprintf(stderr, "tcp_shell: Failed to open output pcap file: %d\n",
			errno);
		return false;
	}

	struct pcap_header header;
	header.magic = PCAP_MAGIC;
	header.version_major = 2;
	header.version_minor = 4;
	header.timezone = 0;
	header.timestamp_accuracy = 0;
	header.max_packet_length = 65535;
	header.linktype = PCAP_LINKTYPE_IPV4;
	if (write(sPcapFD, &header, sizeof(header)) != (ssize_t)sizeof(header)) {
		fprintf(stderr, "tcp_shell: Failed to write pcap file header: %d\n",
			errno);
		// do not leave a half-initialized capture file open
		close(sPcapFD);
		sPcapFD = -1;
		return false;
	}

	sPacketMonitor = dump_pcap;
	return true;
}


//	#pragma mark - test


/*!	Receiver for one direction of the simulated link (\a _data is its
	struct context). Waits for buffers queued by datalink_send_data(),
	optionally performs the simultaneous connect/close tricks and packet
	reordering, and delivers the buffers to the domain module. Returns when
	the context's semaphore is deleted (see cleanup_context()).
*/
int32
receiving_thread(void* _data)
{
	struct context* context = (struct context*)_data;
	// a packet held back to be delivered after the next one
	struct net_buffer* reorderBuffer = NULL;

	while (true) {
		status_t status;
		do {
			status = acquire_sem(context->wait_sem);
		} while (status == B_INTERRUPTED);

		if (status < B_OK)
			break;

		while (true) {
			context->lock.Lock();
			net_buffer* buffer = (net_buffer*)list_remove_head_item(
				&context->list);
			context->lock.Unlock();

			if (buffer == NULL)
				break;

			if (sSimultaneousConnect && context->server && is_syn(buffer)) {
				// delay getting the SYN request, and connect as well
				sockaddr_in address;
				memset(&address, 0, sizeof(address));
				address.sin_len = sizeof(sockaddr_in);
				address.sin_family = AF_INET;
				address.sin_port = htons(1023);
				address.sin_addr.s_addr = htonl(0xc0a80001);

				status_t connectStatus = socket_connect(gServerSocket,
					(struct sockaddr *)&address, sizeof(struct sockaddr));
				if (connectStatus < B_OK) {
					fprintf(stderr, "tcp_tester: simultaneous connect failed: %s\n",
						strerror(connectStatus));
				}

				sSimultaneousConnect = false;
			}
			if (sSimultaneousClose && !context->server && is_fin(buffer)) {
				close_protocol(gClientSocket->first_protocol);
				sSimultaneousClose = false;
			}

			// Hold this packet back if it is on the reorder list, or randomly
			// with probability sRandomReorder. Only one packet is held at a
			// time. (The random test used to be inverted and also applied to
			// listed packets.)
			bool holdBack = false;
			if (reorderBuffer == NULL) {
				holdBack = sReorderList.find(sPacketNumber) != sReorderList.end()
					|| (sRandomReorder > 0.0
						&& (1.0 * rand() / RAND_MAX) < sRandomReorder);
			}

			if (holdBack) {
				reorderBuffer = buffer;
				continue;
			}

			if (sDomain.module->receive_data(buffer) < B_OK)
				gNetBufferModule.free(buffer);

			if (reorderBuffer != NULL) {
				if (sDomain.module->receive_data(reorderBuffer) < B_OK)
					gNetBufferModule.free(reorderBuffer);
				reorderBuffer = NULL;
			}
		}
	}

	// A packet still held back can no longer be delivered; do not leak it.
	if (reorderBuffer != NULL)
		gNetBufferModule.free(reorderBuffer);

	return 0;
}


int32
server_thread(void*)
{
	for (;;) {
		// wait for the next incoming connection
		net_socket* connection;
		sockaddr_in peer;
		socklen_t peerLength = sizeof(struct sockaddr_in);
		const status_t acceptStatus = socket_accept(gServerSocket,
			(struct sockaddr *)&peer, &peerLength, &connection);
		if (acceptStatus < B_OK) {
			fprintf(stderr, "SERVER: accepting failed: %s\n",
				strerror(acceptStatus));
			break;
		}

		printf("server: got connection from %08x\n", peer.sin_addr.s_addr);

		// read until the peer closes, an error occurs, or an active close
		// was requested
		char data[1024];
		ssize_t received;
		for (;;) {
			received = socket_recv(connection, data, sizeof(data), 0);
			if (received <= 0)
				break;

			printf("server: received %ld bytes\n", received);

			if (sServerActiveClose) {
				printf("server: active close\n");
				close_protocol(connection->first_protocol);
				return 0;
			}
		}

		if (received < 0)
			printf("server: receiving failed: %s\n", strerror(received));
		else
			printf("server: peer closed connection.\n");

		// wait one second before closing our side
		snooze(1000000);
		close_protocol(connection->first_protocol);
	}

	return 0;
}


void
setup_server()
{
	sockaddr_in address;
	memset(&address, 0, sizeof(address));
	address.sin_len = sizeof(sockaddr_in);
	address.sin_family = AF_INET;
	address.sin_port = htons(1024);
	address.sin_addr.s_addr = INADDR_ANY;

	status_t status = socket_bind(gServerSocket, (struct sockaddr*)&address,
		sizeof(struct sockaddr));
	if (status < B_OK) {
		fprintf(stderr, "tcp_tester: cannot bind server: %s\n", strerror(status));
		exit(1);
	}
	status = socket_listen(gServerSocket, 40);
	if (status < B_OK) {
		fprintf(stderr, "tcp_tester: server cannot listen: %s\n",
			strerror(status));
		exit(1);
	}

	thread_id serverThread = spawn_thread(server_thread, "server",
		B_NORMAL_PRIORITY, NULL);
	if (serverThread < B_OK) {
		fprintf(stderr, "tcp_tester: cannot start server: %s\n",
			strerror(serverThread));
		exit(1);
	}

	resume_thread(serverThread);
}


void
setup_context(struct context& context, bool server)
{
	list_init(&context.list);
	context.route.interface_address = &gInterfaceAddress;
	context.route.gateway = (sockaddr *)&context;
		// backpointer to the context
	context.route.mtu = 1500;
	context.server = server;
	context.wait_sem = create_sem(0, "receive wait");

	context.thread = spawn_thread(receiving_thread,
		server ? "server receiver" : "client receiver", B_NORMAL_PRIORITY,
		&context);
	resume_thread(context.thread);
}


void
cleanup_context(struct context& context)
{
	delete_sem(context.wait_sem);

	status_t status;
	wait_for_thread(context.thread, &status);
}


//  #pragma mark -


// Forward declaration: the help command is defined further down the file.
static void do_help(int argc, char** argv);


static void
do_connect(int argc, char** argv)
{
	sSimultaneousConnect = false;

	int port = 1024;
	if (argc > 1) {
		if (!strcmp(argv[1], "-s"))
			sSimultaneousConnect = true;
		else if (isdigit(argv[1][0]))
			port = atoi(argv[1]);
		else {
			fprintf(stderr, "usage: connect [-s|<port-number>]\n");
			return;
		}
	}

	if (sSimultaneousConnect) {
		// bind to a port first, so the other end can find us
		sockaddr_in address;
		memset(&address, 0, sizeof(address));
		address.sin_family = AF_INET;
		address.sin_port = htons(1023);
		address.sin_addr.s_addr = htonl(0xc0a80001);

		status_t status = socket_bind(gClientSocket, (struct sockaddr *)&address,
			sizeof(struct sockaddr));
		if (status < B_OK) {
			fprintf(stderr, "Could not bind to port 1023: %s\n", strerror(status));
			sSimultaneousConnect = false;
			return;
		}
	}

	sStartTime = system_time();

	sockaddr_in address;
	memset(&address, 0, sizeof(address));
	address.sin_family = AF_INET;
	address.sin_port = htons(port);
	address.sin_addr.s_addr = htonl(0xc0a80001);

	status_t status = socket_connect(gClientSocket, (struct sockaddr *)&address,
		sizeof(struct sockaddr));
	if (status < B_OK)
		fprintf(stderr, "tcp_tester: could not connect: %s\n", strerror(status));
}


/*!	Parses a size such as "4096", "0x1000", "64k" or "2M" (only the first
	character after the number is examined as unit: k/K = KiB, m/M = MiB).
	\return the size in bytes, or -1 (after printing a message) for an
		unknown unit or a value that does not fit into ssize_t.
*/
static ssize_t
parse_size(const char* arg)
{
	char *unit;
	errno = 0;
	const unsigned long value = strtoul(arg, &unit, 0);
	if (errno == ERANGE || value > (unsigned long)SSIZE_MAX) {
		fprintf(stderr, "size too large!\n");
		return -1;
	}

	ssize_t multiplier = 1;
	if (unit[0] != '\0') {
		if (unit[0] == 'k' || unit[0] == 'K')
			multiplier = 1024;
		else if (unit[0] == 'm' || unit[0] == 'M')
			multiplier = 1024 * 1024;
		else {
			fprintf(stderr, "unknown unit specified!\n");
			return -1;
		}
	}

	// Check before multiplying: signed overflow is undefined behavior.
	if ((ssize_t)value > SSIZE_MAX / multiplier) {
		fprintf(stderr, "size too large!\n");
		return -1;
	}

	return (ssize_t)value * multiplier;
}


static void
do_send(int argc, char** argv)
{
	// "send [<size>]": sends one buffer of the given size (default 1024).
	ssize_t size = 1024;
	if (argc > 1) {
		if (!isdigit(argv[1][0])) {
			fprintf(stderr, "invalid args!\n");
			return;
		}

		size = parse_size(argv[1]);
		if (size < 0)
			return;
	}

	const ssize_t kMaxSendSize = 4 * 1024 * 1024;
	if (size > kMaxSendSize) {
		printf("amount to send will be limited to 4 MB\n");
		size = kMaxSendSize;
	}

	char *payload = (char *)malloc(size);
	if (payload == NULL) {
		fprintf(stderr, "not enough memory!\n");
		return;
	}
	MemoryDeleter payloadDeleter(payload);

	// fill with a repeating 0x00..0xff byte pattern
	for (ssize_t offset = 0; offset < size; offset++)
		payload[offset] = (char)(offset & 0xff);

	const ssize_t bytesWritten = socket_send(gClientSocket, payload, size, 0);
	if (bytesWritten < B_OK)
		fprintf(stderr, "failed sending buffer: %s\n", strerror(bytesWritten));
}


/*!	"send_loop <size>": sends exactly <size> bytes (suffixes as understood by
	parse_size()) through the client socket, using a 4096 byte buffer
	repeatedly. Stream byte n has the value (n & 0xff).
*/
static void
do_send_loop(int argc, char** argv)
{
	// isdigit() is only defined for values representable as unsigned char
	if (argc != 2 || !isdigit((unsigned char)argv[1][0])) {
		fprintf(stderr, "invalid args!\n");
		return;
	}
	ssize_t size = parse_size(argv[1]);
	if (size < 0)
		return;

	const size_t bufferSize = 4096;
		// a multiple of 256, so the byte pattern repeats seamlessly
	char *buffer = (char *)malloc(bufferSize);
	if (buffer == NULL) {
		fprintf(stderr, "not enough memory!\n");
		return;
	}
	MemoryDeleter bufferDeleter(buffer);

	// initialize buffer with some not so random data
	for (size_t i = 0; i < bufferSize; i++)
		buffer[i] = (char)(i & 0xff);

	for (ssize_t total = 0; total < size; ) {
		// Continue at the buffer position matching the stream position, so a
		// partial send does not break the pattern; never exceed the request.
		size_t offset = (size_t)total % bufferSize;
		size_t chunk = bufferSize - offset;
		if ((size_t)(size - total) < chunk)
			chunk = (size_t)(size - total);

		ssize_t bytesWritten = socket_send(gClientSocket, buffer + offset,
			chunk, 0);
		if (bytesWritten < B_OK) {
			fprintf(stderr, "failed sending buffer (after %" B_PRIdSSIZE "): %s\n",
				total, strerror(bytesWritten));
			return;
		}
		if (bytesWritten == 0) {
			// avoid spinning forever if the stack accepts nothing
			fprintf(stderr, "send made no progress (after %" B_PRIdSSIZE ")\n",
				total);
			return;
		}

		// count what was actually accepted, not what was offered
		total += bytesWritten;
	}
}


/*!	"close [-s]": selects the close mode and sends a final buffer.
	sServerActiveClose is set to true; sSimultaneousClose is set only when
	"-s" is given. Any other argument prints the usage. The client's send
	timeout is then set to 0 and a 32767 byte buffer is sent whose first
	byte is 'q' and whose remaining bytes are zero.
	NOTE(review): presumably the peer reacts to the 'q' — confirm.
*/
static void
do_close(int argc, char** argv)
{
	sSimultaneousClose = false;
	sServerActiveClose = true;

	// the only accepted option is "-s" (simultaneous close)
	bool hasOption = argc > 1;
	if (hasOption && strcmp(argv[1], "-s") != 0) {
		fprintf(stderr, "usage: close [-s]\n");
		return;
	}
	sSimultaneousClose = hasOption;

	gClientSocket->send.timeout = 0;

	char payload[32767] = {'q'};
	ssize_t sent = socket_send(gClientSocket, payload, sizeof(payload), 0);
	if (sent < B_OK)
		fprintf(stderr, "failed sending buffer: %s\n", strerror(sent));
}


/*!	"drop" command: controls which packets are dropped during transfer.
	- no arguments: prints the random drop probability (if > 0) and the list
	- "-f": flushes the drop list
	- "-r <probability>": sets the random drop probability, clamped to [0, 1]
	- "<packet-number> [...]": adds packet numbers to the list; a value that
	  parses as 0 is reported and stops processing of further arguments
*/
static void
do_drop(int argc, char** argv)
{
	if (argc == 1) {
		// show list of dropped packets
		if (sRandomDrop > 0.0)
			printf("Drop probability is %f\n", sRandomDrop);

		printf("Drop packets:\n");

		std::set<uint32>::iterator iterator = sDropList.begin();
		uint32 count = 0;
		for (; iterator != sDropList.end(); iterator++) {
			// uint32 is not necessarily unsigned long; cast to match %lu
			printf("%4lu\n", (unsigned long)*iterator);
			count++;
		}

		if (count == 0)
			printf("<empty>\n");
	} else if (!strcmp(argv[1], "-f")) {
		// flush drop list
		sDropList.clear();
		puts("drop list cleared.");
	} else if (!strcmp(argv[1], "-r")) {
		if (argc < 3) {
			fprintf(stderr, "No drop probability specified.\n");
			return;
		}

		sRandomDrop = atof(argv[2]);
		if (sRandomDrop < 0.0)
			sRandomDrop = 0;
		else if (sRandomDrop > 1.0)
			sRandomDrop = 1.0;
	} else if (isdigit((unsigned char)argv[1][0])) {
		// add to drop list (isdigit() needs an unsigned char value)
		for (int i = 1; i < argc; i++) {
			uint32 packet = strtoul(argv[i], NULL, 0);
			if (packet == 0) {
				fprintf(stderr, "invalid packet number: %s\n", argv[i]);
				break;
			}

			sDropList.insert(packet);
		}
	} else {
		// print usage
		puts("usage: drop <packet-number> [...]\n"
			"   or: drop -r <probability>\n"
			"   or: drop [-f]\n\n"
			"Specifying -f flushes the drop list, -r sets the probability a packet\n"
			"is dropped; if you called drop without any arguments, the current\n"
			"drop list is dumped.");
	}
}


/*!	"reorder" command: controls which packets are reordered during transfer.
	- no arguments: prints the random reorder probability (if > 0) and the list
	- "-f": flushes the reorder list
	- "-r <probability>": sets the random reorder probability, clamped to [0, 1]
	- "<packet-number> [...]": adds packet numbers to the list; a value that
	  parses as 0 is reported and stops processing of further arguments
*/
static void
do_reorder(int argc, char** argv)
{
	if (argc == 1) {
		// show list of reordered packets
		if (sRandomReorder > 0.0)
			printf("Reorder probability is %f\n", sRandomReorder);

		printf("Reorder packets:\n");

		std::set<uint32>::iterator iterator = sReorderList.begin();
		uint32 count = 0;
		for (; iterator != sReorderList.end(); iterator++) {
			// uint32 is not necessarily unsigned long; cast to match %lu
			printf("%4lu\n", (unsigned long)*iterator);
			count++;
		}

		if (count == 0)
			printf("<empty>\n");
	} else if (!strcmp(argv[1], "-f")) {
		// flush reorder list
		sReorderList.clear();
		puts("reorder list cleared.");
	} else if (!strcmp(argv[1], "-r")) {
		if (argc < 3) {
			fprintf(stderr, "No reorder probability specified.\n");
			return;
		}

		sRandomReorder = atof(argv[2]);
		if (sRandomReorder < 0.0)
			sRandomReorder = 0;
		else if (sRandomReorder > 1.0)
			sRandomReorder = 1.0;
	} else if (isdigit((unsigned char)argv[1][0])) {
		// add to reorder list (isdigit() needs an unsigned char value)
		for (int i = 1; i < argc; i++) {
			uint32 packet = strtoul(argv[i], NULL, 0);
			if (packet == 0) {
				fprintf(stderr, "invalid packet number: %s\n", argv[i]);
				break;
			}

			sReorderList.insert(packet);
		}
	} else {
		// print usage
		puts("usage: reorder <packet-number> [...]\n"
			"   or: reorder -r <probability>\n"
			"   or: reorder [-f]\n\n"
			"Specifying -f flushes the reorder list, -r sets the probability a packet\n"
			"is reordered; if you called reorder without any arguments, the current\n"
			"reorder list is dumped.");
	}
}


/*!	"rtt" command: shows or changes the simulated round trip time.
	- no arguments: prints the current time in milliseconds
	- "-r": toggles random round trip times
	- "-i": toggles increasing round trip times
	- "<ms>": sets the time (stored internally multiplied by 1000)
*/
static void
do_round_trip_time(int argc, char** argv)
{
	if (argc == 1) {
		// show current time
		printf("Current round trip time: %g ms\n", sRoundTripTime / 1000.0);
	} else if (!strcmp(argv[1], "-r")) {
		// toggle random time
		sRandomRoundTrip = !sRandomRoundTrip;
		printf("Round trip time is now %s.\n", sRandomRoundTrip ? "random" : "fixed");
	} else if (!strcmp(argv[1], "-i")) {
		// toggle increasing time
		sIncreasingRoundTrip = !sIncreasingRoundTrip;
		printf("Round trip time is now %s.\n", sIncreasingRoundTrip ? "increasing" : "fixed");
	} else if (isdigit((unsigned char)argv[1][0])) {
		// set time (isdigit() needs an unsigned char value)
		sRoundTripTime = 1000LL * strtoul(argv[1], NULL, 0);
	} else {
		// print usage
		puts("usage: rtt <time in ms>\n"
			"   or: rtt [-r|-i]\n\n"
			"Specifying -r sets random time, -i causes the times to increase over time;\n"
			"without any arguments, the current time is printed.");
	}
}


/*!	"dprintf [on|...]": controls gDebugOutputEnabled.
	With an argument, output is enabled only if the argument is exactly "on"
	and disabled otherwise; without one, the current state is toggled.
	The resulting state is printed.
*/
static void
do_dprintf(int argc, char** argv)
{
	bool enable = !gDebugOutputEnabled;
	if (argc > 1)
		enable = strcmp(argv[1], "on") == 0;
	gDebugOutputEnabled = enable;

	const char* state = gDebugOutputEnabled ? "on" : "off";
	printf("debug output turned %s.\n", state);
}


// Commands of the interactive shell: name, handler, and help text.
// The loop in main() matches entries by name prefix in table order.
// "quit" has no handler; the table is terminated by an all-NULL entry.
static cmd_entry sBuiltinCommands[] = {
	{"connect",		do_connect,			"Connects the client"},
	{"send",		do_send,			"Sends data from the client to the server"},
	{"send_loop",	do_send_loop,		"Sends data in a loop"},
	{"close",		do_close,			"Performs an active or simultaneous close"},
	{"dprintf",		do_dprintf,			"Toggles debug output"},
	{"drop",		do_drop,			"Lets you drop packets during transfer"},
	{"reorder",		do_reorder,			"Lets you reorder packets during transfer"},
	{"help",		do_help,			"prints this help text"},
	{"rtt",			do_round_trip_time,	"Specifies the round trip time"},
	{"quit",		NULL,				"exits the application"},
	{NULL,			NULL,				NULL},
};


static void
do_help(int argc, char** argv)
{
	printf("Available commands:\n");

	for (cmd_entry* command = sBuiltinCommands; command->name != NULL; command++) {
		printf("%8s - %s\n", command->name, command->help);
	}
}


//	#pragma mark -


int
main(int argc, char* argv[])
{
	for (int i = 1; i < argc; i++) {
		if (strcmp(argv[i], "-w") == 0 && (i + 1) < argc) {
			if (!setup_dump_pcap(argv[++i]))
				return 1;
		}
	}

	if (sPacketMonitor == NULL)
		sPacketMonitor = dump_printf;

	status_t status = init_timers();
	if (status < B_OK) {
		fprintf(stderr, "tcp_tester: Could not initialize timers: %s\n",
			strerror(status));
		return 1;
	}

	_add_builtin_module((module_info*)&gNetStackModule);
	_add_builtin_module((module_info*)&gNetBufferModule);
	_add_builtin_module((module_info*)&gNetSocketModule);
	_add_builtin_module((module_info*)&gNetDatalinkModule);
	_add_builtin_module(modules[0]);
	if (_get_builtin_dependencies() < B_OK) {
		fprintf(stderr, "tcp_tester: Could not initialize modules: %s\n",
			strerror(status));
		return 1;
	}

	sockaddr_in interfaceAddress;
	interfaceAddress.sin_len = sizeof(sockaddr_in);
	interfaceAddress.sin_family = AF_INET;
	interfaceAddress.sin_addr.s_addr = htonl(0xc0a80001);
	gInterfaceAddress.local = (sockaddr*)&interfaceAddress;
	gInterfaceAddress.domain = &sDomain;

	status = get_module("network/protocols/tcp/v1", (module_info **)&gTCPModule);
	if (status < B_OK) {
		fprintf(stderr, "tcp_tester: Could not open TCP module: %s\n",
			strerror(status));
		return 1;
	}

	net_protocol* client = init_protocol(&gClientSocket);
	if (client == NULL)
		return 1;
	net_protocol* server = init_protocol(&gServerSocket);
	if (server == NULL)
		return 1;

	setup_context(sClientContext, false);
	setup_context(sServerContext, true);

	printf("*** Server: %p (%ld), Client: %p (%ld)\n", server,
		sServerContext.thread, client, sClientContext.thread);

	setup_server();

	while (true) {
		printf("> ");
		fflush(stdout);

		char line[1024];
		if (fgets(line, sizeof(line), stdin) == NULL)
			break;

        argc = 0;
        argv = build_argv(line, &argc);
        if (argv == NULL || argc == 0)
            continue;

        int length = strlen(argv[0]);

#if 0
		char *newLine = strchr(line, '\n');
		if (newLine != NULL)
			newLine[0] = '\0';
#endif

		if (!strcmp(argv[0], "quit")
			|| !strcmp(argv[0], "exit")
			|| !strcmp(argv[0], "q"))
			break;

		bool found = false;

		for (cmd_entry* command = sBuiltinCommands; command->name != NULL; command++) {
			if (!strncmp(command->name, argv[0], length)) {
				command->func(argc, argv);
				found = true;
				break;
			}
		}

		if (!found)
			fprintf(stderr, "Unknown command \"%s\". Type \"help\" for a list of commands.\n", argv[0]);

		free(argv);
	}

	close_protocol(client);
	close_protocol(server);

	snooze(2000000);

	cleanup_context(sClientContext);
	cleanup_context(sServerContext);

	put_module("network/protocols/tcp/v1");
	uninit_timers();
	return 0;
}