#include <new>
#include <errno.h>
#include <netdb.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <ByteOrder.h>
#include "Compatibility.h"
#include "DebugSupport.h"
#include "InsecureChannel.h"
#include "InsecureConnection.h"
#include "NetAddress.h"
#include "NetFSDefs.h"
// Protocol tunables for the insecure (plain TCP) connection.
namespace InsecureConnectionDefs {

// Version sent in every ConnectRequest; the accepting side rejects any
// request whose version differs.
const int32 kProtocolVersion = 1;

// Maximum time the accepting side waits for the next data channel to
// connect (bigtime_t, compared against system_time() differences).
const bigtime_t kAcceptingTimeout = 10 * 1000 * 1000;

// Up stream channel count: bounds enforced by the accepting side and the
// value a client requests by default.
const int32 kMinUpStreamChannels = 1;
const int32 kDefaultUpStreamChannels = 5;
const int32 kMaxUpStreamChannels = 10;

// Down stream channel count: bounds enforced by the accepting side and the
// value a client requests by default.
const int32 kMinDownStreamChannels = 1;
const int32 kDefaultDownStreamChannels = 1;
const int32 kMaxDownStreamChannels = 5;

}	// namespace InsecureConnectionDefs

using namespace InsecureConnectionDefs;
// Scope guard for a socket descriptor: closes it on destruction unless
// ownership has been given up via Detach(). A negative descriptor is
// treated as "nothing to close".
struct SocketCloser {
	explicit SocketCloser(int fd) : fFD(fd) {}

	~SocketCloser()
	{
		if (fFD >= 0)
			closesocket(fFD);
	}

	// Releases ownership without closing. Returns the descriptor that was
	// held (-1 if it had already been detached).
	int Detach()
	{
		int fd = fFD;
		fFD = -1;
		return fd;
	}

private:
	// Not copyable: a copy would close the same descriptor a second time.
	// (Declared but intentionally not defined.)
	SocketCloser(const SocketCloser&);
	SocketCloser& operator=(const SocketCloser&);

	int fFD;
};
// Construction does no work; the connection is set up by one of the Init()
// overloads (and, on the side created from an existing socket, by
// FinishInitialization()).
InsecureConnection::InsecureConnection() {}

// Nothing of its own to release. The channels handed to
// AddUpStreamChannel()/AddDownStreamChannel() are presumably owned and freed
// by the AbstractConnection base class -- TODO confirm.
InsecureConnection::~InsecureConnection() {}
// Initializes the connection from an already connected socket \a fd, which
// becomes the first down stream channel.
//
// Takes ownership of \a fd: if anything fails before the channel object
// exists, the socket is closed here. Once the channel exists it is not
// closed separately (presumably the channel owns it).
status_t
InsecureConnection::Init(int fd)
{
	status_t status = AbstractConnection::Init();
	if (status == B_OK) {
		Channel* downStreamChannel = new(std::nothrow) InsecureChannel(fd);
		if (downStreamChannel == NULL) {
			closesocket(fd);
			return B_NO_MEMORY;
		}

		status = AddDownStreamChannel(downStreamChannel);
		if (status != B_OK)
			delete downStreamChannel;
		return status;
	}

	// base class initialization failed -- we still own the socket
	closesocket(fd);
	return status;
}
// Client side initialization.
//
// \a parameters has the form
//   "server[:port] [upStreamChannels [downStreamChannels]]".
// Opens a first (up stream) channel to the server and sends a ConnectRequest
// over it. If the server accepts, the remaining channels are opened on the
// port named in the ConnectReply.
//
// Returns B_BAD_VALUE for missing or malformed parameters and for a reply
// whose channel counts are out of range. Otherwise it returns the first
// error encountered (resolving, connecting, sending/receiving, or the
// error contained in the server's reply).
status_t
InsecureConnection::Init(const char* parameters)
{
	PRINT(("InsecureConnection::Init\n"));
	if (!parameters)
		return B_BAD_VALUE;
	status_t error = AbstractConnection::Init();
	if (error != B_OK)
		return error;

	// parse the parameters
	char server[256];
	uint16 port = kDefaultInsecureConnectionPort;
	int upStreamChannels = kDefaultUpStreamChannels;
	int downStreamChannels = kDefaultDownStreamChannels;
	if (strchr(parameters, ':')) {
		int result = sscanf(parameters, "%255[^:]:%hu %d %d", server, &port,
			&upStreamChannels, &downStreamChannels);
		if (result < 2)
			return B_BAD_VALUE;
	} else {
		// No port given. "%[^:]" would also consume the blanks and the
		// channel counts here, so read the host name only up to whitespace.
		int result = sscanf(parameters, "%255s %d %d", server,
			&upStreamChannels, &downStreamChannels);
		if (result < 1)
			return B_BAD_VALUE;
	}

	// resolve the server address
	NetAddress netAddress;
	error = NetAddressResolver().GetHostAddress(server, &netAddress);
	if (error != B_OK)
		return error;
	in_addr serverAddr = netAddress.GetAddress().sin_addr;

	// open the initial channel; it is the first up stream channel
	Channel* channel;
	error = _OpenClientChannel(serverAddr, port, &channel);
	if (error != B_OK)
		return error;
	error = AddUpStreamChannel(channel);
	if (error != B_OK) {
		delete channel;
		return error;
	}

	// send the connect request -- zeroed first, so that no stray stack
	// bytes (e.g. padding) end up on the wire
	ConnectRequest request;
	memset(&request, 0, sizeof(request));
	request.protocolVersion = B_HOST_TO_BENDIAN_INT32(kProtocolVersion);
	request.serverAddress = serverAddr.s_addr;	// already network order
	request.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels);
	request.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels);
	error = channel->Send(&request, sizeof(ConnectRequest));
	if (error != B_OK)
		return error;

	// receive the reply
	ConnectReply reply;
	error = channel->Receive(&reply, sizeof(ConnectReply));
	if (error != B_OK)
		return error;
	error = B_BENDIAN_TO_HOST_INT32(reply.error);
	if (error != B_OK)
		return error;
	upStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.upStreamChannels);
	downStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.downStreamChannels);
	port = B_BENDIAN_TO_HOST_INT16(reply.port);

	// The counts come from the network. The accepting side always clamps
	// them into these ranges, so anything else is a broken or hostile peer
	// (which could otherwise make us open an arbitrary number of sockets).
	if (upStreamChannels < kMinUpStreamChannels
		|| upStreamChannels > kMaxUpStreamChannels
		|| downStreamChannels < kMinDownStreamChannels
		|| downStreamChannels > kMaxDownStreamChannels) {
		return B_BAD_VALUE;
	}

	// Open the remaining channels. Channel 0 is already open; channels
	// 1 .. upStreamChannels - 1 are up stream, the rest down stream.
	int32 allChannels = upStreamChannels + downStreamChannels;
	for (int32 i = 1; i < allChannels; i++) {
		PRINT(" creating channel %" B_PRId32 "\n", i);
		error = _OpenClientChannel(serverAddr, port, &channel);
		if (error != B_OK)
			RETURN_ERROR(error);
		if (i < upStreamChannels)
			error = AddUpStreamChannel(channel);
		else
			error = AddDownStreamChannel(channel);
		if (error != B_OK) {
			delete channel;
			return error;
		}
	}
	return B_OK;
}
// Accepting side: completes the handshake on a connection set up with
// Init(int fd).
//
// Steps:
// 1. Read the client's ConnectRequest from down stream channel 0.
// 2. Check the protocol version and clamp the requested channel counts.
// 3. Open a listening socket on an ephemeral port, bound to the address the
//    client sent.
// 4. Send that port back in a ConnectReply.
// 5. Accept the remaining channels.
//
// Each wait for the next channel is bounded by kAcceptingTimeout. Failures
// before the reply is sent are reported to the client via _SendErrorReply().
status_t
InsecureConnection::FinishInitialization()
{
	PRINT(("InsecureConnection::FinishInitialization()\n"));
	InsecureChannel* channel
		= dynamic_cast<InsecureChannel*>(DownStreamChannelAt(0));
	if (!channel)
		return B_BAD_VALUE;

	// receive the connect request
	ConnectRequest request;
	status_t error = channel->Receive(&request, sizeof(ConnectRequest));
	if (error != B_OK)
		return error;

	// check the protocol version
	int32 protocolVersion = B_BENDIAN_TO_HOST_INT32(request.protocolVersion);
	if (protocolVersion != kProtocolVersion) {
		_SendErrorReply(channel, B_ERROR);
		return B_ERROR;
	}

	// Get the remaining parameters. The channel counts are from the client's
	// point of view: its up stream channels are our down stream channels.
	in_addr serverAddr;
	serverAddr.s_addr = request.serverAddress;
	int32 upStreamChannels = B_BENDIAN_TO_HOST_INT32(request.upStreamChannels);
	int32 downStreamChannels = B_BENDIAN_TO_HOST_INT32(
		request.downStreamChannels);
	if (upStreamChannels < kMinUpStreamChannels)
		upStreamChannels = kMinUpStreamChannels;
	else if (upStreamChannels > kMaxUpStreamChannels)
		upStreamChannels = kMaxUpStreamChannels;
	if (downStreamChannels < kMinDownStreamChannels)
		downStreamChannels = kMinDownStreamChannels;
	else if (downStreamChannels > kMaxDownStreamChannels)
		downStreamChannels = kMaxDownStreamChannels;

	// for a local peer limit the counts to 1 up / at most 2 down
	NetAddress peerAddress;
	if (channel->GetPeerAddress(&peerAddress) == B_OK
		&& peerAddress.IsLocal()) {
		upStreamChannels = 1;
		if (downStreamChannels > 2)
			downStreamChannels = 2;
	}
	int32 allChannels = upStreamChannels + downStreamChannels;

	// create the listener socket (closed automatically when we return)
	int fd = socket(AF_INET, SOCK_STREAM, 0);
	if (fd < 0) {
		error = errno;
		_SendErrorReply(channel, error);
		return error;
	}
	SocketCloser _(fd);

	// Bind it to an ephemeral port on the given address. Zero the
	// structure first so sin_zero (and sin_len, where present) are clean.
	sockaddr_in addr;
	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_port = 0;
	addr.sin_addr = serverAddr;
	if (bind(fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
		error = errno;
		_SendErrorReply(channel, error);
		return error;
	}

	// find out which port we got
	socklen_t addrSize = sizeof(addr);
	if (getsockname(fd, (sockaddr*)&addr, &addrSize) < 0) {
		error = errno;
		_SendErrorReply(channel, error);
		return error;
	}

	// non-blocking, so that the accept loop below can poll with a timeout
	int dontBlock = 1;
	if (setsockopt(fd, SOL_SOCKET, SO_NONBLOCK, &dontBlock, sizeof(int)) < 0) {
		error = errno;
		_SendErrorReply(channel, error);
		return error;
	}

	// listen
	if (listen(fd, allChannels - 1) < 0) {
		error = errno;
		_SendErrorReply(channel, error);
		return error;
	}

	// send the reply -- zeroed first, so no stray stack bytes (e.g. trailing
	// padding) go over the wire
	ConnectReply reply;
	memset(&reply, 0, sizeof(reply));
	reply.error = B_HOST_TO_BENDIAN_INT32(B_OK);
	reply.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels);
	reply.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels);
	reply.port = addr.sin_port;	// already network byte order
	error = channel->Send(&reply, sizeof(ConnectReply));
	if (error != B_OK)
		return error;

	// Accept the remaining channels. The first upStreamChannels - 1 become
	// our down stream channels, the rest our up stream channels (mirroring
	// the order in which the client opens them).
	bigtime_t startAccepting = system_time();
	for (int32 i = 1; i < allChannels; ) {
		int channelFD = accept(fd, NULL, 0);
		if (channelFD < 0) {
			error = errno;
			if (error == B_INTERRUPTED) {
				error = B_OK;
				continue;
			}
			if (error == B_WOULD_BLOCK) {
				bigtime_t now = system_time();
				if (now - startAccepting > kAcceptingTimeout)
					RETURN_ERROR(B_TIMED_OUT);
				snooze(10000);
				continue;
			}
			RETURN_ERROR(error);
		}
		PRINT(" accepting channel %" B_PRId32 "\n", i);
		channel = new(std::nothrow) InsecureChannel(channelFD);
		if (!channel) {
			closesocket(channelFD);
			return B_NO_MEMORY;
		}
		if (i < upStreamChannels)
			error = AddDownStreamChannel(channel);
		else
			error = AddUpStreamChannel(channel);
		if (error != B_OK) {
			delete channel;
			return error;
		}
		i++;
		// the timeout applies per channel
		startAccepting = system_time();
	}
	return B_OK;
}
// Opens a TCP connection to \a serverAddr : \a port (\a port in host byte
// order) and wraps it in a new InsecureChannel, returned via \a _channel.
// On failure the socket is closed again and \a _channel is left untouched.
status_t
InsecureConnection::_OpenClientChannel(in_addr serverAddr, uint16 port,
	Channel** _channel)
{
	int fd = socket(AF_INET, SOCK_STREAM, 0);
	if (fd < 0)
		return errno;
	SocketCloser closer(fd);

	// zero the address first so sin_zero (and sin_len, where present) are
	// not left holding stack garbage
	sockaddr_in addr;
	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_port = htons(port);
	addr.sin_addr = serverAddr;
	if (connect(fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
		// capture errno before anything else can clobber it
		status_t error = errno;
		RETURN_ERROR(error);
	}

	Channel* channel = new(std::nothrow) InsecureChannel(fd);
	if (!channel)
		return B_NO_MEMORY;

	// the channel owns the socket from now on
	closer.Detach();

	*_channel = channel;
	return B_OK;
}
// Sends a ConnectReply carrying only \a error to the peer over \a channel.
// The reply is zeroed first: only the error field is meaningful, and the
// remaining fields (and any padding) must not leak uninitialized stack
// memory to the peer.
// Returns the result of the Send() call.
status_t
InsecureConnection::_SendErrorReply(Channel* channel, status_t error)
{
	ConnectReply reply;
	memset(&reply, 0, sizeof(reply));
	reply.error = B_HOST_TO_BENDIAN_INT32(error);
	return channel->Send(&reply, sizeof(ConnectReply));
}