⛏️ index : haiku.git

// InsecureConnection.cpp

// TODO: Asynchronous connecting on client side?

#include <new>

#include <errno.h>
#include <netdb.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

#include <ByteOrder.h>

#include "Compatibility.h"
#include "DebugSupport.h"
#include "InsecureChannel.h"
#include "InsecureConnection.h"
#include "NetAddress.h"
#include "NetFSDefs.h"

namespace InsecureConnectionDefs {

// version of the connect handshake; the server answers a request carrying
// any other version with an error reply
const int32 kProtocolVersion = 1;

// maximum time the server waits for the next client channel to connect
// (bigtime_t values are microseconds)
const bigtime_t kAcceptingTimeout = 10 * 1000 * 1000;	// 10 s

// Number of client up stream channels: the client proposes the default
// unless its parameters say otherwise; the server clamps the proposal to
// [min, max].
const int32 kDefaultUpStreamChannels	= 5;
const int32 kMinUpStreamChannels		= 1;
const int32 kMaxUpStreamChannels		= 10;

// Number of client down stream channels, handled the same way.
const int32 kDefaultDownStreamChannels	= 1;
const int32 kMinDownStreamChannels		= 1;
const int32 kMaxDownStreamChannels		= 5;

} // namespace InsecureConnectionDefs

using namespace InsecureConnectionDefs;

// SocketCloser
//
// Scope guard for a socket descriptor: closes it via closesocket() when the
// guard goes out of scope, unless ownership was given up with Detach().
// A negative descriptor is never closed. The guard is non-copyable, since a
// copy would close the same descriptor a second time.
struct SocketCloser {
	explicit SocketCloser(int fd) : fFD(fd) {}
	~SocketCloser()
	{
		if (fFD >= 0)
			closesocket(fFD);
	}

	// Gives up ownership: returns the descriptor, which this guard will no
	// longer close.
	int Detach()
	{
		int fd = fFD;
		fFD = -1;
		return fd;
	}

private:
	// not implemented -- copying would lead to a double close
	SocketCloser(const SocketCloser& other);
	SocketCloser& operator=(const SocketCloser& other);

	int	fFD;
};


// #pragma mark -
// #pragma mark ----- InsecureConnection -----

// constructor
//
// Does no work of its own; the connection is set up by Init() (and, on the
// server side, additionally by FinishInitialization()).
InsecureConnection::InsecureConnection()
{
}

// destructor
//
// Releases nothing explicitly. NOTE(review): the channels registered via
// AddUpStreamChannel()/AddDownStreamChannel() are presumably owned and
// freed by the AbstractConnection base class -- confirm.
InsecureConnection::~InsecureConnection()
{
}

// Init (server side)
//
// Sets up the server end of a connection from the socket \a fd, which
// becomes the first down stream channel. The rest of the handshake is done
// by FinishInitialization(). This object takes responsibility for \a fd in
// every case: it is closed directly if no channel could be created for it,
// otherwise it goes away together with the channel object (presumably
// InsecureChannel's destructor closes it -- verify).
// Returns B_OK on success, otherwise an error code.
status_t
InsecureConnection::Init(int fd)
{
	status_t status = AbstractConnection::Init();
	if (status != B_OK) {
		// no channel owns the socket yet, so we have to close it ourselves
		closesocket(fd);
		return status;
	}

	// wrap the socket into the initial channel
	Channel* initialChannel = new(std::nothrow) InsecureChannel(fd);
	if (initialChannel == NULL) {
		closesocket(fd);
		return B_NO_MEMORY;
	}

	// register it; on failure the channel is discarded again
	status = AddDownStreamChannel(initialChannel);
	if (status != B_OK)
		delete initialChannel;
	return status;
}

// Init (client side)
//
// Connects to the server described by \a parameters, whose format is
// "<server>[:port] [ <up> [ <down> ] ]". The port defaults to
// kDefaultInsecureConnectionPort, the channel counts to
// kDefaultUpStreamChannels and kDefaultDownStreamChannels.
//
// An initial up stream channel is opened to <server>:<port> and a
// ConnectRequest proposing the channel counts is sent over it. The server
// answers with a ConnectReply holding an error code, the channel counts it
// accepted and the port at which the remaining channels are to be opened.
//
// Returns B_OK on success; B_BAD_VALUE for missing or malformed parameters;
// B_BAD_DATA if the reply contains channel counts outside the supported
// ranges; otherwise the error from resolving, connecting,
// sending/receiving, or the error reported by the server.
status_t
InsecureConnection::Init(const char* parameters)
{
PRINT(("InsecureConnection::Init\n"));
	if (!parameters)
		return B_BAD_VALUE;
	status_t error = AbstractConnection::Init();
	if (error != B_OK)
		return error;

	// parse the parameters to get a server name and a port we shall connect to
	// parameter format is "<server>[:port] [ <up> [ <down> ] ]"
	char server[256];
	uint16 port = kDefaultInsecureConnectionPort;
	int upStreamChannels = kDefaultUpStreamChannels;
	int downStreamChannels = kDefaultDownStreamChannels;
	if (strchr(parameters, ':')) {
		int result = sscanf(parameters, "%255[^:]:%hu %d %d", server, &port,
			&upStreamChannels, &downStreamChannels);
		if (result < 2)
			return B_BAD_VALUE;
	} else {
		// No port given. The server name has to end at the first white
		// space: a "[^:]" scan set would consume the optional channel
		// counts as part of the name.
		int result = sscanf(parameters, "%255s %d %d", server,
			&upStreamChannels, &downStreamChannels);
		if (result < 1)
			return B_BAD_VALUE;
	}

	// resolve server address
	NetAddress netAddress;
	error = NetAddressResolver().GetHostAddress(server, &netAddress);
	if (error != B_OK)
		return error;
	in_addr serverAddr = netAddress.GetAddress().sin_addr;

	// open the initial channel
	Channel* channel;
	error = _OpenClientChannel(serverAddr, port, &channel);
	if (error != B_OK)
		return error;
	error = AddUpStreamChannel(channel);
	if (error != B_OK) {
		delete channel;
		return error;
	}

	// send the server a connect request (zeroed first, so that no
	// uninitialized padding bytes are sent over the wire)
	ConnectRequest request;
	memset(&request, 0, sizeof(request));
	request.protocolVersion = B_HOST_TO_BENDIAN_INT32(kProtocolVersion);
	// already in network byte order; the server uses it as is for bind()
	request.serverAddress = serverAddr.s_addr;
	request.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels);
	request.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels);
	error = channel->Send(&request, sizeof(ConnectRequest));
	if (error != B_OK)
		return error;

	// get the server reply
	ConnectReply reply;
	error = channel->Receive(&reply, sizeof(ConnectReply));
	if (error != B_OK)
		return error;
	error = B_BENDIAN_TO_HOST_INT32(reply.error);
	if (error != B_OK)
		return error;
	upStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.upStreamChannels);
	downStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.downStreamChannels);
	port = B_BENDIAN_TO_HOST_INT16(reply.port);

	// The server clamps the channel counts to these ranges, so values
	// outside them mean a malformed reply. Reject it rather than opening an
	// arbitrary number of connections on the peer's say-so.
	if (upStreamChannels < kMinUpStreamChannels
		|| upStreamChannels > kMaxUpStreamChannels
		|| downStreamChannels < kMinDownStreamChannels
		|| downStreamChannels > kMaxDownStreamChannels) {
		return B_BAD_DATA;
	}

	// open the remaining channels; the initial channel already is up stream
	// channel 0
	int32 allChannels = upStreamChannels + downStreamChannels;
	for (int32 i = 1; i < allChannels; i++) {
		PRINT("  creating channel %" B_PRId32 "\n", i);
		// open the channel
		error = _OpenClientChannel(serverAddr, port, &channel);
		if (error != B_OK)
			RETURN_ERROR(error);
		// add it
		if (i < upStreamChannels)
			error = AddUpStreamChannel(channel);
		else
			error = AddDownStreamChannel(channel);
		if (error != B_OK) {
			delete channel;
			return error;
		}
	}
	return B_OK;
}

// FinishInitialization
//
// Server side second stage of the handshake, following Init(int fd).
// Receives the client's ConnectRequest over the initial down stream
// channel, checks the protocol version and clamps the requested channel
// counts to the supported ranges. It then opens a non-blocking listener
// socket on a stack-chosen port, bound to the server address given in the
// request, and replies with the accepted counts and that port. Finally it
// accepts the remaining (allChannels - 1) connections. Each one has to
// arrive within kAcceptingTimeout of the reply or of the previous accept.
//
// A protocol version mismatch and failures while setting up the listener
// are also reported to the client via an error reply.
// Returns B_OK on success, otherwise an error code.
status_t
InsecureConnection::FinishInitialization()
{
PRINT(("InsecureConnection::FinishInitialization()\n"));
	// get the down stream channel
	InsecureChannel* channel
		= dynamic_cast<InsecureChannel*>(DownStreamChannelAt(0));
	if (!channel)
		return B_BAD_VALUE;

	// receive the connect request
	ConnectRequest request;
	status_t error = channel->Receive(&request, sizeof(ConnectRequest));
	if (error != B_OK)
		return error;

	// check the protocol version
	int32 protocolVersion = B_BENDIAN_TO_HOST_INT32(request.protocolVersion);
	if (protocolVersion != kProtocolVersion) {
		_SendErrorReply(channel, B_ERROR);
		return B_ERROR;
	}

	// get our address (we need it for binding); it is transferred in
	// network byte order
	in_addr serverAddr;
	serverAddr.s_addr = request.serverAddress;

	// check number of up and down stream channels
	int32 upStreamChannels = B_BENDIAN_TO_HOST_INT32(request.upStreamChannels);
	int32 downStreamChannels = B_BENDIAN_TO_HOST_INT32(
		request.downStreamChannels);
	if (upStreamChannels < kMinUpStreamChannels)
		upStreamChannels = kMinUpStreamChannels;
	else if (upStreamChannels > kMaxUpStreamChannels)
		upStreamChannels = kMaxUpStreamChannels;
	if (downStreamChannels < kMinDownStreamChannels)
		downStreamChannels = kMinDownStreamChannels;
	else if (downStreamChannels > kMaxDownStreamChannels)
		downStreamChannels = kMaxDownStreamChannels;

	// due to a bug on BONE we have a maximum of 2 working connections
	// accepted on one listener socket.
	NetAddress peerAddress;
	if (channel->GetPeerAddress(&peerAddress) == B_OK
		&& peerAddress.IsLocal()) {
		upStreamChannels = 1;
		if (downStreamChannels > 2)
			downStreamChannels = 2;
	}
	int32 allChannels = upStreamChannels + downStreamChannels;

	// create a listener socket
	int fd = socket(AF_INET, SOCK_STREAM, 0);
	if (fd < 0) {
		error = errno;
		_SendErrorReply(channel, error);
		return error;
	}
	SocketCloser _(fd);

	// bind it to some port; zero the whole address first, so that fields
	// not set explicitly (sin_zero, and sin_len where it exists) are clean
	sockaddr_in addr;
	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_port = 0;	// let the stack pick a free port
	addr.sin_addr = serverAddr;
	if (bind(fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
		error = errno;
		_SendErrorReply(channel, error);
		return error;
	}

	// get the port
	socklen_t addrSize = sizeof(addr);
	if (getsockname(fd, (sockaddr*)&addr, &addrSize) < 0) {
		error = errno;
		_SendErrorReply(channel, error);
		return error;
	}

	// set socket to non-blocking
	int dontBlock = 1;
	if (setsockopt(fd, SOL_SOCKET, SO_NONBLOCK, &dontBlock, sizeof(int)) < 0) {
		error = errno;
		_SendErrorReply(channel, error);
		return error;
	}

	// start listening
	if (listen(fd, allChannels - 1) < 0) {
		error = errno;
		_SendErrorReply(channel, error);
		return error;
	}

	// send the reply (zeroed first, so that no uninitialized padding bytes
	// are sent over the wire)
	ConnectReply reply;
	memset(&reply, 0, sizeof(reply));
	reply.error = B_HOST_TO_BENDIAN_INT32(B_OK);
	reply.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels);
	reply.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels);
	reply.port = addr.sin_port;	// already in network byte order
	error = channel->Send(&reply, sizeof(ConnectReply));
	if (error != B_OK)
		return error;

	// start accepting
	// NOTE(review): connections are accepted from any peer, not only from
	// the client that sent the request.
	bigtime_t startAccepting = system_time();
	for (int32 i = 1; i < allChannels; ) {
		// accept a connection
		int channelFD = accept(fd, NULL, NULL);
		if (channelFD < 0) {
			error = errno;
			if (error == B_INTERRUPTED) {
				error = B_OK;
				continue;
			}
			if (error == B_WOULD_BLOCK) {
				// nothing pending yet -- poll until the timeout expires
				bigtime_t now = system_time();
				if (now - startAccepting > kAcceptingTimeout)
					RETURN_ERROR(B_TIMED_OUT);
				snooze(10000);
				continue;
			}
			RETURN_ERROR(error);
		}
		PRINT("  accepting channel %" B_PRId32 "\n", i);

		// create a channel
		channel = new(std::nothrow) InsecureChannel(channelFD);
		if (!channel) {
			closesocket(channelFD);
			return B_NO_MEMORY;
		}

		// add it
		if (i < upStreamChannels)	// inverse, since we are on server side
			error = AddDownStreamChannel(channel);
		else
			error = AddUpStreamChannel(channel);
		if (error != B_OK) {
			delete channel;
			return error;
		}
		i++;
		startAccepting = system_time();
	}
	return B_OK;
}

// _OpenClientChannel
//
// Opens a stream (AF_INET/SOCK_STREAM) connection to serverAddr:port, with
// \a port given in host byte order, and wraps it into a new
// InsecureChannel. On success the channel is returned via \a _channel and
// the caller becomes responsible for it. On failure the socket is closed,
// an error code is returned and \a *_channel is left untouched.
status_t
InsecureConnection::_OpenClientChannel(in_addr serverAddr, uint16 port,
	Channel** _channel)
{
	// create a socket
	int fd = socket(AF_INET, SOCK_STREAM, 0);
	if (fd < 0)
		return errno;

	// connect; zero the whole address first, so that fields not set
	// explicitly (sin_zero, and sin_len where it exists) are clean
	sockaddr_in addr;
	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_port = htons(port);
	addr.sin_addr = serverAddr;
	if (connect(fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
		status_t error = errno;
		closesocket(fd);
		RETURN_ERROR(error);
	}

	// create the channel
	Channel* channel = new(std::nothrow) InsecureChannel(fd);
	if (!channel) {
		closesocket(fd);
		return B_NO_MEMORY;
	}
	*_channel = channel;
	return B_OK;
}

// _SendErrorReply
//
// Sends a ConnectReply carrying only \a error to the peer over \a channel.
// The reply is zeroed first. Only the error field is meaningful (the client
// stops evaluating the reply when it is not B_OK), and the remaining fields
// and padding must not carry uninitialized stack memory over the network.
// Returns the result of Channel::Send().
status_t
InsecureConnection::_SendErrorReply(Channel* channel, status_t error)
{
	ConnectReply reply;
	memset(&reply, 0, sizeof(reply));
	reply.error = B_HOST_TO_BENDIAN_INT32(error);
	return channel->Send(&reply, sizeof(ConnectReply));
}