* Copyright 2006-2023, Haiku, Inc. All Rights Reserved.
* Distributed under the terms of the MIT License.
*
* Authors:
* Axel Dörfler, axeld@pinc-software.de
*/
#include "argv.h"
#include "tcp.h"
#include "pcap.h"
#include "utility.h"
#include <AutoDeleter.h>
#include <NetBufferUtilities.h>
#include <Referenceable.h>
#include <net_buffer.h>
#include <net_datalink.h>
#include <net_protocol.h>
#include <net_socket.h>
#include <net_stack.h>
#include <slab/Slab.h>
#include <util/AutoLock.h>
#include <util/DoublyLinkedList.h>
#include <KernelExport.h>
#include <Select.h>
#include <module.h>
#include <Locker.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <ctype.h>
#include <errno.h>
#include <new>
#include <set>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// State of one side of the simulated link: buffers "sent" over a route are
// queued on `list` and consumed by that side's receiving thread.
struct context {
// Protects `list` (see datalink_send_data() and receiving_thread()).
BLocker lock;
// Released once per queued buffer to wake up the receiving thread.
sem_id wait_sem;
// Queue of net_buffers waiting to be delivered.
struct list list;
// Route returned by get_route(); its gateway points back at this context.
net_route route;
// True for the server side, false for the client side.
bool server;
// Receiving thread spawned by setup_context().
thread_id thread;
};
// One builtin shell command (see sBuiltinCommands).
struct cmd_entry {
// Command name as typed by the user (prefixes are accepted by main()).
const char* name;
// Handler receiving the parsed command line; NULL for "quit".
void (*func)(int argc, char **argv);
// One-line description printed by do_help().
const char* help;
};
struct net_socket_private;
// List of child sockets of a listening socket.
typedef DoublyLinkedList<net_socket_private> SocketList;
// Tester-side socket: the public net_socket plus the bookkeeping needed to
// emulate listen/accept. Code in this file casts net_socket* to this type,
// so the net_socket base must stay first.
struct net_socket_private : net_socket,
DoublyLinkedListLinkImpl<net_socket_private> {
// Listening socket this one was spawned from (NULL once accepted).
struct net_socket *parent;
// Not used by the code visible in this file.
team_id owner;
// Backlog set by socket_set_max_backlog().
uint32 max_backlog;
// Number of children in pending_children plus connected_children.
uint32 child_count;
// Children still in connection setup (socket_spawn_pending()).
SocketList pending_children;
// Children ready to be accepted (socket_connected()).
SocketList connected_children;
// Not used by the code visible in this file.
struct select_sync_pool *select_pool;
// Protects the child lists and counters.
mutex lock;
};
// Reference counted interface address, attached to every buffer sent through
// datalink_send_data(); the tester only uses a single instance.
struct InterfaceAddress : DoublyLinkedListLinkImpl<InterfaceAddress>,
net_interface_address, BReferenceable {
};
// Builtin module registry, used by main() to register the fake modules.
extern "C" status_t _add_builtin_module(module_info *info);
extern "C" status_t _get_builtin_dependencies(void);
// Toggled by the "dprintf" command.
extern bool gDebugOutputEnabled;
extern struct net_buffer_module_info gNetBufferModule;
extern net_address_module_info gIPv4AddressModule;
// main() registers modules[0] (presumably the TCP module — confirm).
extern module_info *modules[];
extern struct net_protocol_module_info gDomainModule;
// The single interface address used for all routes and buffers.
struct InterfaceAddress gInterfaceAddress;
extern struct net_socket_module_info gNetSocketModule;
// Obtained via get_module() in main().
struct net_protocol_module_info *gTCPModule;
struct net_socket *gServerSocket, *gClientSocket;
// One context (queue + receiving thread) per side of the link.
static struct context sClientContext, sServerContext;
// Incremented by domain_receive_data() for every packet it sees.
static int32 sPacketNumber = 1;
// Random drop setting, set via "drop -r".
static double sRandomDrop = 0.0;
// Packet numbers to drop, set via "drop <n>".
static std::set<uint32> sDropList;
// Simulated round trip time in microseconds ("rtt" takes milliseconds).
static bigtime_t sRoundTripTime = 0;
// Toggled by "rtt -i" and "rtt -r".
static bool sIncreasingRoundTrip = false;
static bool sRandomRoundTrip = false;
// Per-packet monitor: dump_printf() or, with "-w", dump_pcap().
static void (*sPacketMonitor)(net_buffer *, int32, bool) = NULL;
// Output file descriptor for "-w <file>".
static int sPcapFD = -1;
// Set by do_connect(); dump_printf() prints times relative to it.
static bigtime_t sStartTime;
// Random reorder setting and packet list, set via "reorder".
static double sRandomReorder = 0.0;
static std::set<uint32> sReorderList;
// Flags consumed by receiving_thread()/server_thread().
static bool sSimultaneousConnect = false;
static bool sSimultaneousClose = false;
static bool sServerActiveClose = false;
// The only domain known to the tester.
static struct net_domain sDomain = {
"ipv4",
AF_INET,
&gDomainModule,
&gIPv4AddressModule
};
/*!	Tells whether \a addr is the server's address; the server is the only
	endpoint bound to port 1024 (see setup_server()).
*/
static bool
is_server(const sockaddr* addr)
{
	const sockaddr_in* inetAddress = (const sockaddr_in*)addr;
	return ntohs(inetAddress->sin_port) == 1024;
}
/*!	Returns the TCP flags of the segment in \a buffer, or 0 if the TCP header
	could not be read.
*/
static uint8
tcp_segment_flags(net_buffer* buffer)
{
	NetBufferHeaderReader<tcp_header> bufferHeader(buffer);
	if (bufferHeader.Status() < B_OK) {
		// Returning the negative status here would be truncated to uint8 and
		// could look like SYN/FIN bits to is_syn()/is_fin().
		return 0;
	}

	tcp_header &header = bufferHeader.Data();
	return header.flags;
}
/*!	Tells whether the segment in \a buffer has the SYN flag set. */
static bool
is_syn(net_buffer* buffer)
{
	const uint8 flags = tcp_segment_flags(buffer);
	return (flags & TCP_FLAG_SYNCHRONIZE) != 0;
}

/*!	Tells whether the segment in \a buffer has the FIN flag set. */
static bool
is_fin(net_buffer* buffer)
{
	const uint8 flags = tcp_segment_flags(buffer);
	return (flags & TCP_FLAG_FINISH) != 0;
}
/*!	Standard module operations hook shared by all fake modules in this file;
	there is nothing to initialize or tear down.
*/
status_t
std_ops(int32, ...)
{
	return B_OK;
}

/*!	The tester only knows the single IPv4 domain, which is returned for any
	\a family.
*/
net_domain *
get_domain(int family)
{
	(void)family;
	return &sDomain;
}

/*!	Registration is a no-op; socket_create() builds the protocol chain by
	hand.
*/
status_t
register_domain_protocols(int family, int type, int protocol, ...)
{
	(void)family;
	(void)type;
	(void)protocol;
	return B_OK;
}

/*!	Registration is a no-op, see register_domain_protocols(). */
status_t
register_domain_datalink_protocols(int family, int type, ...)
{
	(void)family;
	(void)type;
	return B_OK;
}

/*!	Registration is a no-op, see register_domain_protocols(). */
static status_t
register_domain_receiving_protocol(int family, int type, const char *moduleName)
{
	(void)family;
	(void)type;
	(void)moduleName;
	return B_OK;
}

/*!	Always reports that the caller is not executing a syscall. */
static bool
dummy_is_syscall(void)
{
	const bool inSyscall = false;
	return inSyscall;
}

/*!	Always reports that no syscall is being restarted. */
static bool
dummy_is_restarted_syscall(void)
{
	const bool restarted = false;
	return restarted;
}

/*!	Syscall restarting is not emulated; the timeout is ignored. */
static void
dummy_store_syscall_restart_timeout(bigtime_t timeout)
{
	(void)timeout;
}
// Minimal net_stack module handed to the TCP module. The initializer is
// positional and must follow the field order of net_stack_module_info;
// NULL entries are hooks the tester does not provide.
static net_stack_module_info gNetStackModule = {
{
NET_STACK_MODULE_NAME,
0,
std_ops
},
NULL,
NULL,
// Domain lookup / registration hooks implemented above.
get_domain,
register_domain_protocols,
register_domain_datalink_protocols,
register_domain_receiving_protocol,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
// Notification, checksum, fifo and timer helpers; presumably provided by
// utility.h — confirm.
notify_socket,
checksum,
init_fifo,
uninit_fifo,
fifo_enqueue_buffer,
fifo_dequeue_buffer,
clear_fifo,
fifo_socket_enqueue_buffer,
init_timer,
set_timer,
cancel_timer,
wait_for_timer,
is_timer_active,
is_timer_running,
// Syscall emulation stubs (always "not in a syscall").
dummy_is_syscall,
dummy_is_restarted_syscall,
dummy_store_syscall_restart_timeout,
NULL,
};
/*!	Creates a stand-alone socket whose protocol chain consists of the TCP
	module (gTCPModule) followed by the fake domain protocol (gDomainModule).
	\param family, type, protocol Stored in the socket as given.
	\param _socket On success, receives the new socket.
	\return B_OK on success, B_NO_MEMORY if an allocation failed, or B_ERROR
		if the TCP module could not create its protocol.
*/
status_t
socket_create(int family, int type, int protocol, net_socket **_socket)
{
	net_socket_private *socket = new (std::nothrow) net_socket_private;
	if (socket == NULL)
		return B_NO_MEMORY;

	memset(socket, 0, sizeof(net_socket));
	// The memset above only clears the net_socket base. The private
	// bookkeeping would otherwise be indeterminate, yet it is read by
	// socket_delete(), socket_spawn_pending() and socket_set_max_backlog().
	socket->parent = NULL;
	socket->owner = 0;
	socket->max_backlog = 0;
	socket->child_count = 0;
	socket->select_pool = NULL;

	socket->family = family;
	socket->type = type;
	socket->protocol = protocol;

	mutex_init(&socket->lock, "socket");

	socket->send.buffer_size = 65535;
	socket->send.low_water_mark = 1;
	socket->send.timeout = B_INFINITE_TIMEOUT;
	socket->receive.buffer_size = 65535;
	socket->receive.low_water_mark = 1;
	socket->receive.timeout = B_INFINITE_TIMEOUT;

	socket->first_protocol = gTCPModule->init_protocol(socket);
	if (socket->first_protocol == NULL) {
		fprintf(stderr, "tcp_tester: cannot create protocol\n");
		mutex_destroy(&socket->lock);
		delete socket;
		return B_ERROR;
	}
	socket->first_info = gTCPModule;

	// The domain protocol is the last element of the chain; use a nothrow
	// allocation like the socket itself and terminate the chain explicitly.
	net_protocol *domainProtocol = new (std::nothrow) net_protocol;
	if (domainProtocol == NULL) {
		gTCPModule->uninit_protocol(socket->first_protocol);
		mutex_destroy(&socket->lock);
		delete socket;
		return B_NO_MEMORY;
	}
	domainProtocol->next = NULL;
	domainProtocol->module = &gDomainModule;
	domainProtocol->socket = socket;

	socket->first_protocol->next = domainProtocol;
	socket->first_protocol->module = gTCPModule;
	socket->first_protocol->socket = socket;

	*_socket = socket;
	return B_OK;
}
/*!	Destroys a socket created by socket_create(), including its protocol
	chain. Panics if the socket is still linked to a parent.
*/
void
socket_delete(net_socket *_socket)
{
	net_socket_private *socket = (net_socket_private *)_socket;

	if (socket->parent != NULL)
		panic("socket still has a parent!");

	// socket_create() appended a heap allocated domain protocol to the TCP
	// protocol; grab it before the TCP protocol is torn down so it does not
	// leak.
	net_protocol *domainProtocol = socket->first_protocol->next;

	socket->first_info->uninit_protocol(socket->first_protocol);
	delete domainProtocol;

	mutex_destroy(&socket->lock);
	delete socket;
}
/*!	Accepts a connection on the listening \a socket.
	If both \a address and \a _addressLength are given, the peer address is
	copied (truncated to *\a _addressLength bytes), and *\a _addressLength is
	set to the full peer address length.
	\return B_OK, or the error returned by the protocol's accept hook.
*/
int
socket_accept(net_socket *socket, struct sockaddr *address,
	socklen_t *_addressLength, net_socket **_acceptedSocket)
{
	net_socket *accepted;
	status_t status = socket->first_info->accept(socket->first_protocol,
		&accepted);
	if (status < B_OK)
		return status;

	// Guard the length pointer as well; it used to be dereferenced whenever
	// an address buffer was passed.
	if (address != NULL && _addressLength != NULL && *_addressLength > 0) {
		memcpy(address, &accepted->peer, min_c(*_addressLength,
			accepted->peer.ss_len));
		*_addressLength = accepted->peer.ss_len;
	}

	*_acceptedSocket = accepted;
	return B_OK;
}
/*!	Binds \a socket to \a address, first unbinding any previous address.
	A NULL \a address binds to an all-zero address of the socket's family.
	\return B_OK, or the error returned by the protocol's unbind/bind hooks;
		on bind failure the socket is left unbound.
*/
int
socket_bind(net_socket *socket, const struct sockaddr *address, socklen_t addressLength)
{
	sockaddr empty;
	if (address == NULL) {
		memset(&empty, 0, sizeof(sockaddr));
		empty.sa_len = sizeof(sockaddr);
		empty.sa_family = socket->family;

		address = &empty;
		addressLength = sizeof(sockaddr);
	}

	if (socket->address.ss_len != 0) {
		status_t status = socket->first_info->unbind(socket->first_protocol,
			(sockaddr *)&socket->address);
		if (status < B_OK)
			return status;
	}

	// Copy no more than the caller passed, and no more than the socket's
	// address storage can hold.
	size_t copyLength = min_c((size_t)addressLength, sizeof(socket->address));
	memset(&socket->address, 0, sizeof(socket->address));
	memcpy(&socket->address, address, copyLength);

	status_t status = socket->first_info->bind(socket->first_protocol,
		(sockaddr *)address);
	if (status < B_OK) {
		// Mark the socket as unbound again.
		socket->address.ss_len = 0;
	}

	return status;
}
/*!	Connects \a socket to \a address, binding it to an unspecified local
	address first if it is not bound yet.
	\return ENETUNREACH without an address, otherwise the bind or connect
		result.
*/
int
socket_connect(net_socket *socket, const struct sockaddr *address, socklen_t addressLength)
{
	const bool haveAddress = address != NULL && addressLength != 0;
	if (!haveAddress)
		return ENETUNREACH;

	const bool isBound = socket->address.ss_len != 0;
	if (!isBound) {
		status_t bindStatus = socket_bind(socket, NULL, 0);
		if (bindStatus < B_OK)
			return bindStatus;
	}

	return socket->first_info->connect(socket->first_protocol, address);
}
/*!	Puts \a socket into listening state; on success SO_ACCEPTCONN is set in
	its options.
*/
int
socket_listen(net_socket *socket, int backlog)
{
	status_t status = socket->first_info->listen(socket->first_protocol, backlog);
	if (status != B_OK)
		return status;

	socket->options |= SO_ACCEPTCONN;
	return B_OK;
}
/*!	Sends \a length bytes of \a data through the socket's protocol chain.
	\return \a length on success, or an error code.
*/
ssize_t
socket_send(net_socket *socket, const void *data, size_t length, int flags)
{
	// Refuse to send before a peer address is known.
	if (socket->peer.ss_len == 0)
		return EDESTADDRREQ;

	// Implicitly bind to an unspecified local address when still unbound.
	if (socket->address.ss_len == 0) {
		status_t bindStatus = socket_bind(socket, NULL, 0);
		if (bindStatus < B_OK)
			return bindStatus;
	}

	net_buffer *packet = gNetBufferModule.create(256);
	if (packet == NULL)
		return ENOBUFS;

	status_t appendStatus = gNetBufferModule.append(packet, data, length);
	if (appendStatus < B_OK) {
		gNetBufferModule.free(packet);
		return ENOBUFS;
	}

	packet->msg_flags = flags;
	memcpy(packet->source, &socket->address, socket->address.ss_len);
	memcpy(packet->destination, &socket->peer, socket->peer.ss_len);

	// On success the protocol keeps the buffer; on failure we free it.
	status_t sendStatus = socket->first_info->send_data(socket->first_protocol,
		packet);
	if (sendStatus < B_OK) {
		gNetBufferModule.free(packet);
		return sendStatus;
	}

	return length;
}
/*!	Reads up to \a length bytes from \a socket into \a data.
	\return The number of bytes copied, 0 if the protocol returned no buffer,
		or an error code.
*/
ssize_t
socket_recv(net_socket *socket, void *data, size_t length, int flags)
{
	net_buffer *buffer;
	ssize_t status = socket->first_info->read_data(
		socket->first_protocol, length, flags, &buffer);
	if (status < B_OK)
		return status;

	// No buffer: report 0 bytes (server_thread() treats this as a close).
	if (buffer == NULL)
		return 0;

	// Never copy more than the caller's buffer can hold, even if the
	// protocol handed back more than was requested.
	size_t bytesReceived = min_c((size_t)buffer->size, length);
	status_t readStatus = gNetBufferModule.read(buffer, 0, data, bytesReceived);
	gNetBufferModule.free(buffer);

	if (readStatus < B_OK)
		return readStatus;

	return bytesReceived;
}
/*!	Socket reference counting is not emulated; always succeeds. */
bool
socket_acquire(net_socket* _socket)
{
	(void)_socket;
	const bool acquired = true;
	return acquired;
}

/*!	Socket reference counting is not emulated; nothing is ever freed here. */
bool
socket_release(net_socket* _socket)
{
	(void)_socket;
	const bool released = true;
	return released;
}
/*!	Creates a new child socket of the listening socket \a _parent, copying
	its settings and addresses, and queues it as pending.
	\return ENOBUFS if the parent already has more than 1.5 times its backlog
		in children, otherwise the result of socket_create().
*/
status_t
socket_spawn_pending(net_socket *_parent, net_socket **_socket)
{
	net_socket_private *parent = (net_socket_private *)_parent;
	MutexLocker locker(parent->lock);

	const uint32 childLimit = 3 * parent->max_backlog / 2;
	if (parent->child_count > childLimit)
		return ENOBUFS;

	net_socket *created;
	status_t status = socket_create(parent->family, parent->type,
		parent->protocol, &created);
	if (status < B_OK)
		return status;

	net_socket_private *child = (net_socket_private *)created;

	// Inherit the parent's settings, except for being a listener.
	child->send = parent->send;
	child->receive = parent->receive;
	child->options = parent->options & ~SO_ACCEPTCONN;
	child->linger = parent->linger;
	memcpy(&child->address, &parent->address, parent->address.ss_len);
	memcpy(&child->peer, &parent->peer, parent->peer.ss_len);

	parent->pending_children.Add(child);
	child->parent = parent;
	parent->child_count++;

	*_socket = child;
	return B_OK;
}
/*!	Removes the oldest connected child of \a _parent and returns it in
	\a _socket, detached from its parent.
	\return B_OK, or B_ENTRY_NOT_FOUND if there is no connected child.
*/
status_t
socket_dequeue_connected(net_socket *_parent, net_socket **_socket)
{
	net_socket_private *parent = (net_socket_private *)_parent;
	MutexLocker locker(parent->lock);

	net_socket_private *child = parent->connected_children.RemoveHead();
	if (child == NULL)
		return B_ENTRY_NOT_FOUND;

	child->parent = NULL;
	parent->child_count--;
	*_socket = child;
	return B_OK;
}
/*!	Returns the number of connected children waiting to be accepted. */
ssize_t
socket_count_connected(net_socket *_parent)
{
	net_socket_private *parent = (net_socket_private *)_parent;

	mutex_lock(&parent->lock);
	ssize_t count = parent->connected_children.Count();
	mutex_unlock(&parent->lock);

	return count;
}
/*!	Sets the backlog of \a _socket (capped at 256) and trims surplus
	children: pending ones are detached first, then connected ones are
	detached and deleted.
*/
status_t
socket_set_max_backlog(net_socket *_socket, uint32 backlog)
{
	net_socket_private *socket = (net_socket_private *)_socket;

	backlog = min_c(backlog, (uint32)256);

	MutexLocker locker(socket->lock);

	while (socket->child_count > backlog) {
		net_socket_private *child = socket->pending_children.RemoveTail();
		if (child == NULL)
			break;
		child->parent = NULL;
		socket->child_count--;
	}

	while (socket->child_count > backlog) {
		net_socket_private *child = socket->connected_children.RemoveTail();
		if (child == NULL)
			break;
		child->parent = NULL;
		socket_delete(child);
		socket->child_count--;
	}

	socket->max_backlog = backlog;
	return B_OK;
}
/*!	Tells whether \a _socket is still a (pending or connected) child of a
	listening socket.
*/
bool
socket_has_parent(net_socket* _socket)
{
	return ((net_socket_private*)_socket)->parent != NULL;
}

/*!	Moves \a socket from its parent's pending list to the connected list.
	\return B_BAD_VALUE if the socket has no parent.
*/
status_t
socket_connected(net_socket *socket)
{
	net_socket_private *child = (net_socket_private *)socket;
	net_socket_private *parent = (net_socket_private *)child->parent;
	if (parent == NULL)
		return B_BAD_VALUE;

	MutexLocker locker(parent->lock);
	parent->pending_children.Remove(child);
	parent->connected_children.Add(child);
	return B_OK;
}
/*!	Protocol notification hook. Select pools are not emulated by the tester,
	so read/write readiness events are ignored; a B_SELECT_ERROR event stores
	\a value as the socket's error.
	\return Always B_OK.
*/
status_t
socket_notify(net_socket *_socket, uint8 event, int32 value)
{
	net_socket_private *socket = (net_socket_private *)_socket;
	MutexLocker locker(socket->lock);

	// The previous low-water-mark checks only fed an empty "if (notify) ;"
	// statement, so they had no effect and have been dropped.
	if (event == B_SELECT_ERROR)
		socket->error = value;

	return B_OK;
}
// Minimal net_socket module handed to the TCP module. The initializer is
// positional and must follow the field order of net_socket_module_info;
// only the reference, child-socket/backlog and notification hooks are
// provided, all other entries are NULL.
net_socket_module_info gNetSocketModule = {
{
NET_SOCKET_MODULE_NAME,
0,
std_ops
},
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
socket_acquire,
socket_release,
socket_spawn_pending,
socket_dequeue_connected,
socket_count_connected,
socket_set_max_backlog,
socket_has_parent,
socket_connected,
NULL,
NULL,
NULL,
socket_notify,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
};
/*!	Creates and opens a TCP socket, returning it in \a _socket.
	\return The socket's TCP protocol, or NULL on failure.
*/
net_protocol*
init_protocol(net_socket** _socket)
{
	net_socket *socket;
	if (socket_create(AF_INET, SOCK_STREAM, IPPROTO_TCP, &socket) < B_OK)
		return NULL;

	status_t openStatus = socket->first_info->open(socket->first_protocol);
	if (openStatus < B_OK) {
		fprintf(stderr, "tcp_tester: cannot open client: %s\n",
			strerror(openStatus));
		socket_delete(socket);
		return NULL;
	}

	*_socket = socket;
	return socket->first_protocol;
}
/*!	Closes \a protocol and tears it down if the TCP module allows freeing it
	right away.
*/
void
close_protocol(net_protocol* protocol)
{
	gTCPModule->close(protocol);

	status_t freeStatus = gTCPModule->free(protocol);
	if (freeStatus != B_OK)
		return;

	gTCPModule->uninit_protocol(protocol);
}
/*!	"Transmits" \a buffer by queueing it on the context the route belongs to
	and waking up that side's receiving thread.
*/
status_t
datalink_send_data(struct net_route *route, net_buffer *buffer)
{
	// setup_context() stores the owning context in the route's gateway.
	struct context* target = (struct context*)route->gateway;

	buffer->interface_address = &gInterfaceAddress;
	gInterfaceAddress.AcquireReference();

	target->lock.Lock();
	list_add_item(&target->list, buffer);
	target->lock.Unlock();

	release_sem(target->wait_sem);
	return B_OK;
}
/*!	Datagram sending is not supported; calling this is a fatal error. */
status_t
datalink_send_datagram(net_protocol *protocol, net_domain *domain,
	net_buffer *buffer)
{
	panic("called");
	return B_ERROR;
}

/*!	Routes to the server's context for the server's address, to the
	client's context otherwise.
*/
struct net_route *
get_route(struct net_domain *_domain, const struct sockaddr *address)
{
	return is_server(address) ? &sServerContext.route : &sClientContext.route;
}

/*!	Routes are static members of the contexts; nothing to release. */
static void
put_route(struct net_domain* _domain, net_route* route)
{
	(void)_domain;
	(void)route;
}
// Minimal net_datalink module. The initializer is positional and must follow
// the field order of net_datalink_module_info; only sending and route
// lookup/release are provided, all other entries are NULL.
net_datalink_module_info gNetDatalinkModule = {
{
NET_DATALINK_MODULE_NAME,
0,
std_ops
},
NULL,
datalink_send_data,
datalink_send_datagram,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
get_route,
NULL,
put_route,
NULL,
NULL,
NULL,
};
/*!	Nothing to set up for the fake domain protocol. */
status_t
domain_open(net_protocol *protocol)
{
	(void)protocol;
	return B_OK;
}

/*!	Nothing to tear down for the fake domain protocol. */
status_t
domain_close(net_protocol *protocol)
{
	(void)protocol;
	return B_OK;
}

/*!	Nothing to free for the fake domain protocol. */
status_t
domain_free(net_protocol *protocol)
{
	(void)protocol;
	return B_OK;
}

/*!	Connecting is not handled below TCP. */
status_t
domain_connect(net_protocol *protocol, const struct sockaddr *address)
{
	(void)protocol;
	(void)address;
	return B_ERROR;
}

/*!	Accepting is not supported by the domain protocol. */
status_t
domain_accept(net_protocol *protocol, struct net_socket **_acceptedSocket)
{
	(void)protocol;
	(void)_acceptedSocket;
	return EOPNOTSUPP;
}

/*!	No socket options are supported. */
status_t
domain_control(net_protocol *protocol, int level, int option, void *value,
	size_t *_length)
{
	(void)protocol;
	(void)level;
	(void)option;
	(void)value;
	(void)_length;
	return B_ERROR;
}

/*!	Records \a address (as an IPv4 address) as the socket's local address. */
status_t
domain_bind(net_protocol *protocol, const struct sockaddr *address)
{
	net_socket *socket = protocol->socket;
	memcpy(&socket->address, address, sizeof(struct sockaddr_in));
	socket->address.ss_len = sizeof(struct sockaddr_in);
	return B_OK;
}

/*!	Unbinding needs no work at this level. */
status_t
domain_unbind(net_protocol *protocol, struct sockaddr *address)
{
	(void)protocol;
	(void)address;
	return B_OK;
}

/*!	Listening is not supported by the domain protocol. */
status_t
domain_listen(net_protocol *protocol, int count)
{
	(void)protocol;
	(void)count;
	return EOPNOTSUPP;
}

/*!	Shutting down is not supported by the domain protocol. */
status_t
domain_shutdown(net_protocol *protocol, int direction)
{
	(void)protocol;
	(void)direction;
	return EOPNOTSUPP;
}
/*!	Sends \a buffer to the side its destination address belongs to.
	\return ENETUNREACH if no route was found, otherwise the datalink result.
*/
status_t
domain_send_data(net_protocol *protocol, net_buffer *buffer)
{
	// buffer->destination already points at the sockaddr (see socket_send()
	// and dump_printf()); taking its address would route on the bytes of the
	// pointer itself.
	struct net_route *route = get_route(&sDomain, buffer->destination);
	if (route == NULL)
		return ENETUNREACH;

	return datalink_send_data(route, buffer);
}

/*!	Sends \a buffer along an already resolved \a route. */
status_t
domain_send_routed_data(net_protocol *protocol, struct net_route *route,
	net_buffer *buffer)
{
	return datalink_send_data(route, buffer);
}
/*!	Not supported by the domain protocol. */
ssize_t
domain_send_avail(net_protocol *protocol)
{
	(void)protocol;
	return B_ERROR;
}

/*!	Not supported by the domain protocol. */
status_t
domain_read_data(net_protocol *protocol, size_t numBytes, uint32 flags,
	net_buffer **_buffer)
{
	(void)protocol;
	(void)numBytes;
	(void)flags;
	(void)_buffer;
	return B_ERROR;
}

/*!	Not supported by the domain protocol. */
ssize_t
domain_read_avail(net_protocol *protocol)
{
	(void)protocol;
	return B_ERROR;
}

/*!	All protocols belong to the single IPv4 domain. */
struct net_domain *
domain_get_domain(net_protocol *protocol)
{
	(void)protocol;
	return &sDomain;
}

/*!	Fixed MTU for every destination. */
size_t
domain_get_mtu(net_protocol *protocol, const struct sockaddr *address)
{
	(void)protocol;
	(void)address;
	const size_t kMTU = 1480;
	return kMTU;
}
/*!	Simulated IP input: numbers the packet, applies the drop list, the drop
	probability and the simulated round trip delay, reports the packet to
	sPacketMonitor, and then hands it to the TCP module (or frees it if it
	was dropped).
*/
status_t
domain_receive_data(net_buffer *buffer)
{
	uint32 packetNumber = atomic_add(&sPacketNumber, 1);

	// sRandomDrop is the probability of dropping a packet ("drop -r"), so a
	// packet is dropped when the random sample falls below it; the previous
	// ">" dropped packets with probability 1 - p.
	bool drop = sDropList.find(packetNumber) != sDropList.end()
		|| (sRandomDrop > 0.0 && (1.0 * rand() / RAND_MAX) < sRandomDrop);

	if (!drop && (sRoundTripTime > 0 || sRandomRoundTrip || sIncreasingRoundTrip)) {
		bigtime_t add = 0;
		if (sRandomRoundTrip)
			add = (bigtime_t)(1.0 * rand() / RAND_MAX * 500000) - 250000;
		if (sIncreasingRoundTrip)
			sRoundTripTime += (bigtime_t)(1.0 * rand() / RAND_MAX * 150000);

		// The random jitter can make the delay negative; only wait for
		// positive delays.
		bigtime_t delay = sRoundTripTime / 2 + add;
		if (delay > 0)
			snooze(delay);
	}

	if (sPacketMonitor != NULL) {
		sPacketMonitor(buffer, packetNumber, drop);
	} else if (drop)
		printf("<**** DROPPED %" B_PRIu32 " ****>\n", packetNumber);

	if (drop) {
		gNetBufferModule.free(buffer);
		return B_OK;
	}

	return gTCPModule->receive_data(buffer);
}
/*!	ICMP style error handling is not supported by the tester. */
status_t
domain_error(net_error error, net_buffer* data)
{
	(void)error;
	(void)data;
	return B_ERROR;
}

/*!	Error replies are not supported by the tester. */
status_t
domain_error_reply(net_protocol* self, net_buffer* cause,
	net_error error, net_error_data* errorData)
{
	(void)self;
	(void)cause;
	(void)error;
	(void)errorData;
	return B_ERROR;
}
// Fake IPv4 "domain" protocol placed below TCP in every socket's chain (see
// socket_create()). The initializer is positional and must follow the field
// order of net_protocol_module_info.
net_protocol_module_info gDomainModule = {
{
NULL,
0,
std_ops
},
NET_PROTOCOL_ATOMIC_MESSAGES,
NULL,
NULL,
domain_open,
domain_close,
domain_free,
domain_connect,
domain_accept,
domain_control,
NULL,
NULL,
domain_bind,
domain_unbind,
domain_listen,
domain_shutdown,
domain_send_data,
domain_send_routed_data,
domain_send_avail,
domain_read_data,
domain_read_avail,
domain_get_domain,
domain_get_mtu,
domain_receive_data,
NULL,
domain_error,
domain_error_reply,
NULL,
NULL,
};
static void
dump_printf(net_buffer* buffer, int32 packetNumber, bool willBeDropped)
{
static bigtime_t lastTime = 0;
NetBufferHeaderReader<tcp_header> bufferHeader(buffer);
if (bufferHeader.Status() < B_OK)
return;
tcp_header &header = bufferHeader.Data();
bigtime_t now = system_time();
if (lastTime == 0)
lastTime = now;
printf("\33[0m% 3ld %8.6f (%8.6f) ", packetNumber, (now - sStartTime) / 1000000.0,
(now - lastTime) / 1000000.0);
lastTime = now;
if (is_server((sockaddr *)buffer->source))
printf("\33[31mserver > client: ");
else
printf("client > server: ");
int32 length = buffer->size - header.HeaderLength();
if ((header.flags & TCP_FLAG_PUSH) != 0)
putchar('P');
if ((header.flags & TCP_FLAG_SYNCHRONIZE) != 0)
putchar('S');
if ((header.flags & TCP_FLAG_FINISH) != 0)
putchar('F');
if ((header.flags & TCP_FLAG_RESET) != 0)
putchar('R');
if ((header.flags
& (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH | TCP_FLAG_PUSH | TCP_FLAG_RESET)) == 0)
putchar('.');
printf(" %lu:%lu (%lu)", header.Sequence(), header.Sequence() + length, length);
if ((header.flags & TCP_FLAG_ACKNOWLEDGE) != 0)
printf(" ack %lu", header.Acknowledge());
printf(" win %u", header.AdvertisedWindow());
if (header.HeaderLength() > sizeof(tcp_header)) {
int32 size = header.HeaderLength() - sizeof(tcp_header);
tcp_option *option;
uint8 optionsBuffer[1024];
if (gBufferModule->direct_access(buffer, sizeof(tcp_header),
size, (void **)&option) != B_OK) {
if (size > 1024) {
printf("options too large to take into account (%ld bytes)\n", size);
size = 1024;
}
gBufferModule->read(buffer, sizeof(tcp_header), optionsBuffer, size);
option = (tcp_option *)optionsBuffer;
}
while (size > 0) {
uint32 length = 1;
switch (option->kind) {
case TCP_OPTION_END:
case TCP_OPTION_NOP:
break;
case TCP_OPTION_MAX_SEGMENT_SIZE:
printf(" <mss %u>", ntohs(option->max_segment_size));
length = 4;
break;
case TCP_OPTION_WINDOW_SHIFT:
printf(" <ws %u>", option->window_shift);
length = 3;
break;
case TCP_OPTION_TIMESTAMP:
printf(" <ts %lu:%lu>", option->timestamp.value, option->timestamp.reply);
length = 10;
break;
default:
length = option->length;
if (length == 0)
size = 0;
break;
}
size -= length;
option = (tcp_option *)((uint8 *)option + length);
}
}
if (willBeDropped)
printf(" <DROPPED>");
printf("\33[0m\n");
}
/*!	Packet monitor used with "-w": appends the segment, wrapped in a
	synthesized IPv4 header, as a pcap record to sPcapFD. Packets that will
	be dropped are only reported on stdout. Exits the process if the write
	fails.
*/
static void
dump_pcap(net_buffer* buffer, int32 packetNumber, bool willBeDropped)
{
	if (willBeDropped) {
		printf(" <DROPPED>\n");
		return;
	}

	const bigtime_t time = real_time_clock_usecs();

	struct pcap_packet_header pcap_header;
	pcap_header.ts_sec = time / 1000000;
	pcap_header.ts_usec = time % 1000000;
	pcap_header.included_len = sizeof(struct ip) + buffer->size;
	pcap_header.original_len = pcap_header.included_len;

	struct ip ip_header;
	ip_header.ip_v = IPVERSION;
	ip_header.ip_hl = sizeof(struct ip) >> 2;
	ip_header.ip_tos = 0;
	ip_header.ip_len = htons(sizeof(struct ip) + buffer->size);
	ip_header.ip_id = htons(packetNumber);
	ip_header.ip_off = 0;
	ip_header.ip_ttl = 254;
	ip_header.ip_p = IPPROTO_TCP;
	ip_header.ip_sum = 0;
	ip_header.ip_src.s_addr = ((sockaddr_in*)buffer->source)->sin_addr.s_addr;
	ip_header.ip_dst.s_addr = ((sockaddr_in*)buffer->destination)->sin_addr.s_addr;

	// A compile-time constant size: "iovec vecs[count]" with a non-const
	// count is a variable length array, which is not standard C++.
	enum { kMaxVecs = 16 };
	iovec vecs[kMaxVecs];
	size_t used = 0;

	vecs[used].iov_base = &pcap_header;
	vecs[used].iov_len = sizeof(pcap_header);
	used++;

	vecs[used].iov_base = &ip_header;
	vecs[used].iov_len = sizeof(ip_header);
	used++;

	used += gNetBufferModule.get_iovecs(buffer, vecs + used, kMaxVecs - used);

	static mutex writesLock = MUTEX_INITIALIZER("pcap writes");
	MutexLocker _(writesLock);

	const size_t expected = pcap_header.included_len + sizeof(pcap_packet_header);
	ssize_t written = writev(sPcapFD, vecs, used);
	if (written < 0 || (size_t)written != expected) {
		fprintf(stderr, "writing to pcap file failed\n");
		exit(1);
	}
}
/*!	Creates/truncates \a file, writes the pcap file header to it and installs
	dump_pcap() as the packet monitor.
	\return false (after printing an error) if the file could not be opened
		or the header could not be written.
*/
static bool
setup_dump_pcap(const char* file)
{
	// O_CREAT requires the permission bits as third argument; without it the
	// new file gets whatever happens to be in that argument slot.
	sPcapFD = open(file, O_CREAT | O_WRONLY | O_TRUNC, 0644);
	if (sPcapFD < 0) {
		fprintf(stderr, "tcp_shell: Failed to open output pcap file: %d\n",
			errno);
		return false;
	}

	struct pcap_header header;
	header.magic = PCAP_MAGIC;
	header.version_major = 2;
	header.version_minor = 4;
	header.timezone = 0;
	header.timestamp_accuracy = 0;
	header.max_packet_length = 65535;
	header.linktype = PCAP_LINKTYPE_IPV4;

	if (write(sPcapFD, &header, sizeof(header)) != (ssize_t)sizeof(header)) {
		fprintf(stderr, "tcp_shell: Failed to write pcap file header: %d\n",
			errno);
		// Do not leak the descriptor or leave it looking usable.
		close(sPcapFD);
		sPcapFD = -1;
		return false;
	}

	sPacketMonitor = dump_pcap;
	return true;
}
int32
receiving_thread(void* _data)
{
struct context* context = (struct context*)_data;
struct net_buffer* reorderBuffer = NULL;
while (true) {
status_t status;
do {
status = acquire_sem(context->wait_sem);
} while (status == B_INTERRUPTED);
if (status < B_OK)
break;
while (true) {
context->lock.Lock();
net_buffer* buffer = (net_buffer*)list_remove_head_item(
&context->list);
context->lock.Unlock();
if (buffer == NULL)
break;
if (sSimultaneousConnect && context->server && is_syn(buffer)) {
sockaddr_in address;
memset(&address, 0, sizeof(address));
address.sin_family = AF_INET;
address.sin_port = htons(1023);
address.sin_addr.s_addr = htonl(0xc0a80001);
status_t status = socket_connect(gServerSocket,
(struct sockaddr *)&address, sizeof(struct sockaddr));
if (status < B_OK)
fprintf(stderr, "tcp_tester: simultaneous connect failed: %s\n", strerror(status));
sSimultaneousConnect = false;
}
if (sSimultaneousClose && !context->server && is_fin(buffer)) {
close_protocol(gClientSocket->first_protocol);
sSimultaneousClose = false;
}
if ((sRandomReorder > 0.0
|| sReorderList.find(sPacketNumber) != sReorderList.end())
&& reorderBuffer == NULL
&& (1.0 * rand() / RAND_MAX) > sRandomReorder) {
reorderBuffer = buffer;
} else {
if (sDomain.module->receive_data(buffer) < B_OK)
gNetBufferModule.free(buffer);
if (reorderBuffer != NULL) {
if (sDomain.module->receive_data(reorderBuffer) < B_OK)
gNetBufferModule.free(reorderBuffer);
reorderBuffer = NULL;
}
}
}
}
return 0;
}
/*!	Server loop: accepts connections on gServerSocket and reads from each
	until the peer closes it. If sServerActiveClose is set, the server
	closes the connection after the first received data and the thread
	exits.
*/
int32
server_thread(void*)
{
	while (true) {
		net_socket* connectionSocket;
		sockaddr_in address;
		socklen_t size = sizeof(struct sockaddr_in);
		status_t status = socket_accept(gServerSocket,
			(struct sockaddr *)&address, &size, &connectionSocket);
		if (status < B_OK) {
			fprintf(stderr, "SERVER: accepting failed: %s\n", strerror(status));
			break;
		}

		// Use the Haiku format macros; ssize_t/uint32 are not guaranteed to
		// match "%ld"/"%x".
		printf("server: got connection from %08" B_PRIx32 "\n",
			(uint32)address.sin_addr.s_addr);

		char buffer[1024];
		ssize_t bytesRead;
		while ((bytesRead = socket_recv(connectionSocket, buffer,
				sizeof(buffer), 0)) > 0) {
			printf("server: received %" B_PRIdSSIZE " bytes\n", bytesRead);

			if (sServerActiveClose) {
				printf("server: active close\n");
				close_protocol(connectionSocket->first_protocol);
				return 0;
			}
		}

		if (bytesRead < 0)
			printf("server: receiving failed: %s\n", strerror((int)bytesRead));
		else
			printf("server: peer closed connection.\n");

		snooze(1000000);
		close_protocol(connectionSocket->first_protocol);
	}

	return 0;
}
void
setup_server()
{
sockaddr_in address;
memset(&address, 0, sizeof(address));
address.sin_len = sizeof(sockaddr_in);
address.sin_family = AF_INET;
address.sin_port = htons(1024);
address.sin_addr.s_addr = INADDR_ANY;
status_t status = socket_bind(gServerSocket, (struct sockaddr*)&address,
sizeof(struct sockaddr));
if (status < B_OK) {
fprintf(stderr, "tcp_tester: cannot bind server: %s\n", strerror(status));
exit(1);
}
status = socket_listen(gServerSocket, 40);
if (status < B_OK) {
fprintf(stderr, "tcp_tester: server cannot listen: %s\n",
strerror(status));
exit(1);
}
thread_id serverThread = spawn_thread(server_thread, "server",
B_NORMAL_PRIORITY, NULL);
if (serverThread < B_OK) {
fprintf(stderr, "tcp_tester: cannot start server: %s\n",
strerror(serverThread));
exit(1);
}
resume_thread(serverThread);
}
void
setup_context(struct context& context, bool server)
{
list_init(&context.list);
context.route.interface_address = &gInterfaceAddress;
context.route.gateway = (sockaddr *)&context;
context.route.mtu = 1500;
context.server = server;
context.wait_sem = create_sem(0, "receive wait");
context.thread = spawn_thread(receiving_thread,
server ? "server receiver" : "client receiver", B_NORMAL_PRIORITY,
&context);
resume_thread(context.thread);
}
void
cleanup_context(struct context& context)
{
delete_sem(context.wait_sem);
status_t status;
wait_for_thread(context.thread, &status);
}
// Forward declaration: sBuiltinCommands refers to do_help(), which is
// defined after the table.
static void do_help(int argc, char** argv);
/*!	"connect" command: connects the client socket to 192.168.0.1 at the
	given port (default 1024, the server's port).
	With "-s", the client is first bound to port 1023 and
	sSimultaneousConnect is set, so receiving_thread() lets the server
	connect back when it sees the client's SYN (simultaneous open).
*/
static void
do_connect(int argc, char** argv)
{
	sSimultaneousConnect = false;

	int port = 1024;
	if (argc > 1) {
		if (!strcmp(argv[1], "-s"))
			sSimultaneousConnect = true;
		else if (isdigit((unsigned char)argv[1][0])) {
			// Unlike atoi(), reject trailing garbage and values that do not
			// fit into a 16 bit port number.
			char* end;
			unsigned long value = strtoul(argv[1], &end, 10);
			if (*end != '\0' || value == 0 || value > 65535) {
				fprintf(stderr, "invalid port number: %s\n", argv[1]);
				return;
			}
			port = (int)value;
		} else {
			fprintf(stderr, "usage: connect [-s|<port-number>]\n");
			return;
		}
	}

	if (sSimultaneousConnect) {
		sockaddr_in address;
		memset(&address, 0, sizeof(address));
		address.sin_family = AF_INET;
		address.sin_port = htons(1023);
		address.sin_addr.s_addr = htonl(0xc0a80001);

		status_t status = socket_bind(gClientSocket, (struct sockaddr *)&address,
			sizeof(struct sockaddr));
		if (status < B_OK) {
			fprintf(stderr, "Could not bind to port 1023: %s\n", strerror(status));
			sSimultaneousConnect = false;
			return;
		}
	}

	sStartTime = system_time();

	sockaddr_in address;
	memset(&address, 0, sizeof(address));
	address.sin_family = AF_INET;
	address.sin_port = htons(port);
	address.sin_addr.s_addr = htonl(0xc0a80001);

	status_t status = socket_connect(gClientSocket, (struct sockaddr *)&address,
		sizeof(struct sockaddr));
	if (status < B_OK)
		fprintf(stderr, "tcp_tester: could not connect: %s\n", strerror(status));
}
/*!	Parses a size argument such as "4096", "0x1000", "16k" or "2M".
	A unit suffix starting with k/K multiplies by 1024, one starting with
	m/M by 1024 * 1024; only the first character after the number is
	examined.
	\return The size in bytes, or -1 (after printing an error) if the unit is
		unknown or the result does not fit into a ssize_t.
*/
static ssize_t
parse_size(const char* arg)
{
	// Largest value representable by ssize_t (SSIZE_MAX on POSIX systems).
	const size_t maxSize = (size_t)-1 >> 1;

	char *unit;
	errno = 0;
	unsigned long value = strtoul(arg, &unit, 0);
	if (errno == ERANGE || value > maxSize) {
		fprintf(stderr, "size too large!\n");
		return -1;
	}

	size_t multiplier = 1;
	if (unit != NULL && unit[0]) {
		if (unit[0] == 'k' || unit[0] == 'K')
			multiplier = 1024;
		else if (unit[0] == 'm' || unit[0] == 'M')
			multiplier = 1024 * 1024;
		else {
			fprintf(stderr, "unknown unit specified!\n");
			return -1;
		}
	}

	// Check before scaling: "size *= 1024" on a signed value used to
	// overflow (undefined behavior) for large inputs.
	if (value > maxSize / multiplier) {
		fprintf(stderr, "size too large!\n");
		return -1;
	}

	return (ssize_t)(value * multiplier);
}
/*!	"send [size]" command: sends \a size bytes (default 1024, at most 4 MB)
	of a 0x00..0xff byte pattern from the client in a single socket_send().
*/
static void
do_send(int argc, char** argv)
{
	ssize_t size = 1024;
	if (argc > 1) {
		if (!isdigit(argv[1][0])) {
			fprintf(stderr, "invalid args!\n");
			return;
		}
		size = parse_size(argv[1]);
		if (size < 0)
			return;
	}

	const ssize_t kMaxSize = 4 * 1024 * 1024;
	if (size > kMaxSize) {
		printf("amount to send will be limited to 4 MB\n");
		size = kMaxSize;
	}

	char *data = (char *)malloc(size);
	if (data == NULL) {
		fprintf(stderr, "not enough memory!\n");
		return;
	}
	MemoryDeleter dataDeleter(data);

	for (ssize_t offset = 0; offset < size; offset++)
		data[offset] = (char)(offset & 0xff);

	ssize_t bytesWritten = socket_send(gClientSocket, data, size, 0);
	if (bytesWritten < B_OK)
		fprintf(stderr, "failed sending buffer: %s\n", strerror(bytesWritten));
}
static void
do_send_loop(int argc, char** argv)
{
if (argc != 2 || !isdigit(argv[1][0])) {
fprintf(stderr, "invalid args!\n");
return;
}
ssize_t size = parse_size(argv[1]);
if (size < 0)
return;
const size_t bufferSize = 4096;
char *buffer = (char *)malloc(bufferSize);
if (buffer == NULL) {
fprintf(stderr, "not enough memory!\n");
return;
}
MemoryDeleter bufferDeleter(buffer);
for (uint32 i = 0; i < bufferSize; i++) {
buffer[i] = (char)(i & 0xff);
}
for (ssize_t total = 0; total < size; ) {
ssize_t bytesWritten = socket_send(gClientSocket, buffer, bufferSize, 0);
if (bytesWritten < B_OK) {
fprintf(stderr, "failed sending buffer (after %" B_PRIdSSIZE "): %s\n",
total, strerror(bytesWritten));
return;
}
total += bufferSize;
}
}
/*!	"close [-s]" command: makes the server close actively after it receives
	data and sends a 32767 byte burst from the client with a zero send
	timeout. With "-s", the client additionally closes when it sees the
	server's FIN (simultaneous close, see receiving_thread()).
*/
static void
do_close(int argc, char** argv)
{
	sSimultaneousClose = false;
	sServerActiveClose = true;

	if (argc > 1) {
		if (strcmp(argv[1], "-s") != 0) {
			fprintf(stderr, "usage: close [-s]\n");
			return;
		}
		sSimultaneousClose = true;
	}

	gClientSocket->send.timeout = 0;

	char burst[32767] = {'q'};
	ssize_t bytesWritten = socket_send(gClientSocket, burst, sizeof(burst), 0);
	if (bytesWritten < B_OK)
		fprintf(stderr, "failed sending buffer: %s\n", strerror(bytesWritten));
}
static void
do_drop(int argc, char** argv)
{
if (argc == 1) {
if (sRandomDrop > 0.0)
printf("Drop probability is %f\n", sRandomDrop);
printf("Drop pakets:\n");
std::set<uint32>::iterator iterator = sDropList.begin();
uint32 count = 0;
for (; iterator != sDropList.end(); iterator++) {
printf("%4lu\n", *iterator);
count++;
}
if (count == 0)
printf("<empty>\n");
} else if (!strcmp(argv[1], "-f")) {
sDropList.clear();
puts("drop list cleared.");
} else if (!strcmp(argv[1], "-r")) {
if (argc < 3) {
fprintf(stderr, "No drop probability specified.\n");
return;
}
sRandomDrop = atof(argv[2]);
if (sRandomDrop < 0.0)
sRandomDrop = 0;
else if (sRandomDrop > 1.0)
sRandomDrop = 1.0;
} else if (isdigit(argv[1][0])) {
for (int i = 1; i < argc; i++) {
uint32 packet = strtoul(argv[i], NULL, 0);
if (packet == 0) {
fprintf(stderr, "invalid packet number: %s\n", argv[i]);
break;
}
sDropList.insert(packet);
}
} else {
puts("usage: drop <packet-number> [...]\n"
" or: drop -r <probability>\n\n"
" or: drop [-f]\n\n"
"Specifiying -f flushes the drop list, -r sets the probability a packet\n"
"is dropped; if you called drop without any arguments, the current\n"
"drop list is dumped.");
}
}
static void
do_reorder(int argc, char** argv)
{
if (argc == 1) {
if (sRandomReorder > 0.0)
printf("Reorder probability is %f\n", sRandomReorder);
printf("Reorder packets:\n");
std::set<uint32>::iterator iterator = sReorderList.begin();
uint32 count = 0;
for (; iterator != sReorderList.end(); iterator++) {
printf("%4lu\n", *iterator);
count++;
}
if (count == 0)
printf("<empty>\n");
} else if (!strcmp(argv[1], "-f")) {
sReorderList.clear();
puts("reorder list cleared.");
} else if (!strcmp(argv[1], "-r")) {
if (argc < 3) {
fprintf(stderr, "No reorder probability specified.\n");
return;
}
sRandomReorder = atof(argv[2]);
if (sRandomReorder < 0.0)
sRandomReorder = 0;
else if (sRandomReorder > 1.0)
sRandomReorder = 1.0;
} else if (isdigit(argv[1][0])) {
for (int i = 1; i < argc; i++) {
uint32 packet = strtoul(argv[i], NULL, 0);
if (packet == 0) {
fprintf(stderr, "invalid packet number: %s\n", argv[i]);
break;
}
sReorderList.insert(packet);
}
} else {
puts("usage: reorder <packet-number> [...]\n"
" or: reorder -r <probability>\n\n"
" or: reorder [-f]\n\n"
"Specifiying -f flushes the reorder list, -r sets the probability a packet\n"
"is reordered; if you called reorder without any arguments, the current\n"
"reorder list is dumped.");
}
}
/*!	"rtt" command: prints the round trip time, toggles random ("-r") or
	increasing ("-i") round trip times, or sets a fixed time in milliseconds.
*/
static void
do_round_trip_time(int argc, char** argv)
{
	if (argc == 1) {
		printf("Current round trip time: %g ms\n", sRoundTripTime / 1000.0);
		return;
	}

	const char* option = argv[1];
	if (strcmp(option, "-r") == 0) {
		sRandomRoundTrip = !sRandomRoundTrip;
		printf("Round trip time is now %s.\n", sRandomRoundTrip ? "random" : "fixed");
		return;
	}
	if (strcmp(option, "-i") == 0) {
		sIncreasingRoundTrip = !sIncreasingRoundTrip;
		printf("Round trip time is now %s.\n", sIncreasingRoundTrip ? "increasing" : "fixed");
		return;
	}
	if (isdigit(option[0])) {
		// The argument is in milliseconds; sRoundTripTime is in microseconds.
		sRoundTripTime = 1000LL * strtoul(option, NULL, 0);
		return;
	}

	puts("usage: rtt <time in ms>\n"
		" or: rtt [-r|-i]\n\n"
		"Specifiying -r sets random time, -i causes the times to increase over time;\n"
		"witout any arguments, the current time is printed.");
}
/*!	"dprintf [on|...]" command: sets debug output to "on" when given "on",
	off for any other argument, or toggles it without argument.
*/
static void
do_dprintf(int argc, char** argv)
{
	const bool enable = argc > 1
		? strcmp(argv[1], "on") == 0
		: !gDebugOutputEnabled;
	gDebugOutputEnabled = enable;

	printf("debug output turned %s.\n", gDebugOutputEnabled ? "on" : "off");
}
// Builtin shell commands. main() matches the typed word as a prefix of
// `name`, in table order, so earlier entries win for ambiguous prefixes.
// "quit" has no handler; main() checks for quit/exit/q before consulting
// this table. The table is terminated by an all-NULL entry.
static cmd_entry sBuiltinCommands[] = {
{"connect", do_connect, "Connects the client"},
{"send", do_send, "Sends data from the client to the server"},
{"send_loop", do_send_loop, "Sends data in a loop"},
{"close", do_close, "Performs an active or simultaneous close"},
{"dprintf", do_dprintf, "Toggles debug output"},
{"drop", do_drop, "Lets you drop packets during transfer"},
{"reorder", do_reorder, "Lets you reorder packets during transfer"},
{"help", do_help, "prints this help text"},
{"rtt", do_round_trip_time, "Specifies the round trip time"},
{"quit", NULL, "exits the application"},
{NULL, NULL, NULL},
};
static void
do_help(int argc, char** argv)
{
printf("Available commands:\n");
for (cmd_entry* command = sBuiltinCommands; command->name != NULL; command++) {
printf("%8s - %s\n", command->name, command->help);
}
}
int
main(int argc, char* argv[])
{
for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "-w") == 0 && (i + 1) < argc) {
if (!setup_dump_pcap(argv[++i]))
return 1;
}
}
if (sPacketMonitor == NULL)
sPacketMonitor = dump_printf;
status_t status = init_timers();
if (status < B_OK) {
fprintf(stderr, "tcp_tester: Could not initialize timers: %s\n",
strerror(status));
return 1;
}
_add_builtin_module((module_info*)&gNetStackModule);
_add_builtin_module((module_info*)&gNetBufferModule);
_add_builtin_module((module_info*)&gNetSocketModule);
_add_builtin_module((module_info*)&gNetDatalinkModule);
_add_builtin_module(modules[0]);
if (_get_builtin_dependencies() < B_OK) {
fprintf(stderr, "tcp_tester: Could not initialize modules: %s\n",
strerror(status));
return 1;
}
sockaddr_in interfaceAddress;
interfaceAddress.sin_len = sizeof(sockaddr_in);
interfaceAddress.sin_family = AF_INET;
interfaceAddress.sin_addr.s_addr = htonl(0xc0a80001);
gInterfaceAddress.local = (sockaddr*)&interfaceAddress;
gInterfaceAddress.domain = &sDomain;
status = get_module("network/protocols/tcp/v1", (module_info **)&gTCPModule);
if (status < B_OK) {
fprintf(stderr, "tcp_tester: Could not open TCP module: %s\n",
strerror(status));
return 1;
}
net_protocol* client = init_protocol(&gClientSocket);
if (client == NULL)
return 1;
net_protocol* server = init_protocol(&gServerSocket);
if (server == NULL)
return 1;
setup_context(sClientContext, false);
setup_context(sServerContext, true);
printf("*** Server: %p (%ld), Client: %p (%ld)\n", server,
sServerContext.thread, client, sClientContext.thread);
setup_server();
while (true) {
printf("> ");
fflush(stdout);
char line[1024];
if (fgets(line, sizeof(line), stdin) == NULL)
break;
argc = 0;
argv = build_argv(line, &argc);
if (argv == NULL || argc == 0)
continue;
int length = strlen(argv[0]);
#if 0
char *newLine = strchr(line, '\n');
if (newLine != NULL)
newLine[0] = '\0';
#endif
if (!strcmp(argv[0], "quit")
|| !strcmp(argv[0], "exit")
|| !strcmp(argv[0], "q"))
break;
bool found = false;
for (cmd_entry* command = sBuiltinCommands; command->name != NULL; command++) {
if (!strncmp(command->name, argv[0], length)) {
command->func(argc, argv);
found = true;
break;
}
}
if (!found)
fprintf(stderr, "Unknown command \"%s\". Type \"help\" for a list of commands.\n", argv[0]);
free(argv);
}
close_protocol(client);
close_protocol(server);
snooze(2000000);
cleanup_context(sClientContext);
cleanup_context(sServerContext);
put_module("network/protocols/tcp/v1");
uninit_timers();
return 0;
}