* Copyright 2008, Ingo Weinhold, ingo_weinhold@gmx.de.
* Distributed under the terms of the MIT License.
*/
#include <stdio.h>
#include <sys/un.h>
#include <new>
#include <AutoDeleter.h>
#include <StackOrHeapArray.h>
#include <fs/fd.h>
#include <lock.h>
#include <util/AutoLock.h>
#include <vfs.h>
#include <net_buffer.h>
#include <net_protocol.h>
#include <net_socket.h>
#include <net_stack.h>
#include "unix.h"
#include "UnixAddressManager.h"
#include "UnixEndpoint.h"
#define UNIX_MODULE_DEBUG_LEVEL 0
#define UNIX_DEBUG_LEVEL UNIX_MODULE_DEBUG_LEVEL
#include "UnixDebug.h"
// Protocol module descriptor; defined near the end of this file and passed to
// register_domain() in init_unix().
extern net_protocol_module_info gUnixModule;
// Interfaces of the modules this add-on depends on. The addresses of these
// pointers are listed in module_dependencies[] at the end of this file.
net_stack_module_info *gStackModule;
net_socket_module_info *gSocketModule;
net_buffer_module_info *gBufferModule;
// Constructed with placement new in init_unix() and destroyed by an explicit
// destructor call (see init_unix()/uninit_unix()).
UnixAddressManager gAddressManager;
// Filled in by register_domain() in init_unix(); returned by unix_get_domain().
static struct net_domain *sDomain;
void
destroy_scm_rights_descriptors(const ancillary_data_header* header,
void* data)
{
int count = header->len / sizeof(file_descriptor*);
file_descriptor** descriptors = (file_descriptor**)data;
io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
for (int i = 0; i < count; i++) {
if (descriptors[i] != NULL) {
close_fd(ioContext, descriptors[i]);
put_fd(descriptors[i]);
}
}
}
void
clone_scm_rights_descriptors(const ancillary_data_header* header, void* data)
{
int count = header->len / sizeof(file_descriptor*);
file_descriptor** descriptors = (file_descriptor**)data;
for (int i = 0; i < count; i++) {
if (descriptors[i] != NULL) {
inc_fd_ref_count(descriptors[i]);
inc_fd_open_count(descriptors[i]);
}
}
}
/*!	Creates and initializes the UnixEndpoint backing \a socket.

	\return The new endpoint, or \c NULL if UnixEndpoint::Create() or the
		endpoint's Init() failed. An endpoint whose Init() failed is deleted
		before returning.
*/
net_protocol *
unix_init_protocol(net_socket *socket)
{
	TRACE("[%" B_PRId32 "] unix_init_protocol(%p)\n", find_thread(NULL),
		socket);

	UnixEndpoint* newEndpoint;
	if (UnixEndpoint::Create(socket, &newEndpoint) != B_OK)
		return NULL;

	if (newEndpoint->Init() == B_OK)
		return newEndpoint;

	// Initialization failed: do not hand out a half-set-up endpoint.
	delete newEndpoint;
	return NULL;
}
/*!	Uninitializes the endpoint behind \a _protocol by calling its Uninit().
	\return Always \c B_OK.
*/
status_t
unix_uninit_protocol(net_protocol *_protocol)
{
	TRACE("[%" B_PRId32 "] unix_uninit_protocol(%p)\n", find_thread(NULL),
		_protocol);

	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	endpoint->Uninit();

	return B_OK;
}
/*!	Open hook: forwards to UnixEndpoint::Open() and returns its result. */
status_t
unix_open(net_protocol *_protocol)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Open();
}
/*!	Close hook: forwards to UnixEndpoint::Close() and returns its result. */
status_t
unix_close(net_protocol *_protocol)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Close();
}
/*!	Free hook: forwards to UnixEndpoint::Free() and returns its result. */
status_t
unix_free(net_protocol *_protocol)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Free();
}
/*!	Connect hook: forwards \a address to UnixEndpoint::Connect() and returns
	its result.
*/
status_t
unix_connect(net_protocol *_protocol, const struct sockaddr *address)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Connect(address);
}
/*!	Accept hook: forwards to UnixEndpoint::Accept(), which receives
	\a _acceptedSocket as its output parameter, and returns its result.
*/
status_t
unix_accept(net_protocol *_protocol, struct net_socket **_acceptedSocket)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Accept(_acceptedSocket);
}
/*!	Control hook. No control operations are handled here: all arguments are
	ignored and \c B_BAD_VALUE is returned unconditionally.
*/
status_t
unix_control(net_protocol *protocol, int level, int option, void *value,
size_t *_length)
{
return B_BAD_VALUE;
}
/*!	getsockopt hook.

	Only \c SOL_SOCKET / \c SO_PEERCRED is answered here: the caller's buffer
	must hold at least a \c ucred (otherwise \c B_BAD_VALUE), \a _length is set
	to \c sizeof(ucred), and the result of
	UnixEndpoint::GetPeerCredentials() is returned. Every other option is
	passed on to the socket module's get_option().
*/
status_t
unix_getsockopt(net_protocol *protocol, int level, int option, void *value,
	int *_length)
{
	const bool wantsPeerCredentials
		= level == SOL_SOCKET && option == SO_PEERCRED;

	if (!wantsPeerCredentials) {
		return gSocketModule->get_option(protocol->socket, level, option,
			value, _length);
	}

	if (*_length < (int)sizeof(ucred))
		return B_BAD_VALUE;

	*_length = sizeof(ucred);

	UnixEndpoint* endpoint = (UnixEndpoint*)protocol;
	return endpoint->GetPeerCredentials((ucred*)value);
}
/*!	setsockopt hook.

	For \c SOL_SOCKET / \c SO_RCVBUF the value must be exactly one \c int
	(otherwise \c B_BAD_VALUE); it is applied through
	UnixEndpoint::SetReceiveBufferSize() and a failure there is returned
	immediately. In all remaining cases, including a successful
	\c SO_RCVBUF update, the option is passed on to the socket module's
	set_option() and its result is returned.
*/
status_t
unix_setsockopt(net_protocol *protocol, int level, int option,
	const void *_value, int length)
{
	if (level == SOL_SOCKET && option == SO_RCVBUF) {
		if (length != sizeof(int))
			return B_BAD_VALUE;

		UnixEndpoint* endpoint = (UnixEndpoint*)protocol;
		const status_t status
			= endpoint->SetReceiveBufferSize(*(const int*)_value);
		if (status != B_OK)
			return status;
	}

	// SO_SNDBUF gets no endpoint-specific treatment in this function; it goes
	// straight to the socket module like any other option.
	return gSocketModule->set_option(protocol->socket, level, option,
		_value, length);
}
/*!	Bind hook: forwards \a _address to UnixEndpoint::Bind() and returns its
	result.
*/
status_t
unix_bind(net_protocol *_protocol, const struct sockaddr *_address)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Bind(_address);
}
/*!	Unbind hook: forwards to UnixEndpoint::Unbind() and returns its result.
	\a _address is not used.
*/
status_t
unix_unbind(net_protocol *_protocol, struct sockaddr *_address)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Unbind();
}
/*!	Listen hook: forwards \a count to UnixEndpoint::Listen() and returns its
	result.
*/
status_t
unix_listen(net_protocol *_protocol, int count)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Listen(count);
}
/*!	Shutdown hook: forwards \a direction to UnixEndpoint::Shutdown() and
	returns its result.
*/
status_t
unix_shutdown(net_protocol *_protocol, int direction)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Shutdown(direction);
}
/*!	Routed-send hook. Not implemented for this protocol: all arguments are
	ignored and \c B_ERROR is returned unconditionally.
*/
status_t
unix_send_routed_data(net_protocol *_protocol, struct net_route *route,
net_buffer *buffer)
{
return B_ERROR;
}
/*!	Buffer-based send hook. Not implemented: all arguments are ignored and
	\c B_ERROR is returned unconditionally. This module supplies
	unix_send_data_no_buffer() for sending.
*/
status_t
unix_send_data(net_protocol *_protocol, net_buffer *buffer)
{
return B_ERROR;
}
/*!	Send-available hook: returns whatever UnixEndpoint::Sendable() reports. */
ssize_t
unix_send_avail(net_protocol *_protocol)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Sendable();
}
/*!	Buffer-based read hook. Not implemented: all arguments are ignored and
	\c B_ERROR is returned unconditionally. This module supplies
	unix_read_data_no_buffer() for receiving.
*/
status_t
unix_read_data(net_protocol *_protocol, size_t numBytes, uint32 flags,
net_buffer **_buffer)
{
return B_ERROR;
}
/*!	Read-available hook: returns whatever UnixEndpoint::Receivable() reports.
*/
ssize_t
unix_read_avail(net_protocol *_protocol)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Receivable();
}
/*!	Returns the file-scope \c sDomain pointer, which init_unix() has
	register_domain() fill in. \a protocol is not used.
*/
struct net_domain *
unix_get_domain(net_protocol *protocol)
{
return sDomain;
}
/*!	Returns the constant \c UNIX_MAX_TRANSFER_UNIT; neither \a protocol nor
	\a address influences the result.
*/
size_t
unix_get_mtu(net_protocol *protocol, const struct sockaddr *address)
{
return UNIX_MAX_TRANSFER_UNIT;
}
/*!	Receive hook. Not implemented: \a buffer is ignored and \c B_ERROR is
	returned unconditionally.
*/
status_t
unix_receive_data(net_buffer *buffer)
{
return B_ERROR;
}
/*!	Deliver hook. Not implemented: all arguments are ignored and \c B_ERROR is
	returned unconditionally.
*/
status_t
unix_deliver_data(net_protocol *_protocol, net_buffer *buffer)
{
return B_ERROR;
}
/*!	Error-received hook. Not implemented: all arguments are ignored and
	\c B_ERROR is returned unconditionally.
*/
status_t
unix_error_received(net_error error, net_error_data* errorData, net_buffer *data)
{
return B_ERROR;
}
/*!	Error-reply hook. Not implemented: all arguments are ignored and
	\c B_ERROR is returned unconditionally.
*/
status_t
unix_error_reply(net_protocol *protocol, net_buffer *cause, net_error error,
net_error_data *errorData)
{
return B_ERROR;
}
/*!	Converts an \c SCM_RIGHTS control message into descriptor references and
	stores them in \a container.

	Each FD number in the message is resolved with get_open_fd() in the current
	I/O context. The resulting \c file_descriptor pointers are added to the
	container together with destroy_scm_rights_descriptors() and
	clone_scm_rights_descriptors() as callbacks.

	\param self The protocol the message was sent on (only traced).
	\param container The container that receives the descriptor array.
	\param header The control message supplied by the sender.
	\return \c B_OK on success. \c B_BAD_VALUE if the message is not
		\c SOL_SOCKET / \c SCM_RIGHTS, if \c cmsg_len is shorter than an empty
		control message, or if it carries no complete FD. \c ENOBUFS if the
		temporary descriptor array could not be allocated. \c EBADF if one of
		the FDs could not be resolved. Otherwise the error returned by the
		stack's add_ancillary_data().
*/
status_t
unix_add_ancillary_data(net_protocol *self, ancillary_data_container *container,
	const cmsghdr *header)
{
	TRACE("[%" B_PRId32 "] unix_add_ancillary_data(%p, %p, %p (level: %d, type: %d, "
		"len: %" B_PRId32 "))\n", find_thread(NULL), self, container, header,
		header->cmsg_level, header->cmsg_type, header->cmsg_len);

	if (header->cmsg_level != SOL_SOCKET || header->cmsg_type != SCM_RIGHTS)
		return B_BAD_VALUE;

	// The payload size below is computed with unsigned arithmetic; a length
	// shorter than the bare header would wrap around to a huge count.
	if (header->cmsg_len < CMSG_LEN(0))
		return B_BAD_VALUE;

	int* fds = (int*)CMSG_DATA(header);
	int count = (header->cmsg_len - CMSG_LEN(0)) / sizeof(int);
	if (count <= 0)
		return B_BAD_VALUE;

	BStackOrHeapArray<file_descriptor*, 8> descriptors(count);
	if (!descriptors.IsValid())
		return ENOBUFS;

	// Zeroed so the rollback loop can tell acquired entries from untouched
	// ones.
	memset(descriptors, 0, sizeof(file_descriptor*) * count);

	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());

	status_t error = B_OK;
	for (int i = 0; i < count; i++) {
		descriptors[i] = get_open_fd(ioContext, fds[i]);
		if (descriptors[i] == NULL) {
			error = EBADF;
			break;
		}
	}

	if (error == B_OK) {
		// Named differently from the \a header parameter to avoid shadowing.
		ancillary_data_header rightsHeader;
		rightsHeader.level = SOL_SOCKET;
		rightsHeader.type = SCM_RIGHTS;
		rightsHeader.len = count * sizeof(file_descriptor*);

		TRACE("[%" B_PRId32 "] unix_add_ancillary_data(): adding %d FDs to "
			"container\n", find_thread(NULL), count);

		error = gStackModule->add_ancillary_data(container, &rightsHeader,
			descriptors, destroy_scm_rights_descriptors,
			clone_scm_rights_descriptors, NULL);
	}

	if (error != B_OK) {
		// Nothing was handed over to the container: give back every
		// descriptor acquired above.
		for (int i = 0; i < count; i++) {
			if (descriptors[i] != NULL) {
				close_fd(ioContext, descriptors[i]);
				put_fd(descriptors[i]);
			}
		}
	}

	return error;
}
/*!	Writes the descriptors held in \a container into \a buffer as a single
	\c SCM_RIGHTS control message, creating a new FD in the current I/O
	context for each of them.

	A first pass over the container counts the descriptors and rejects any
	entry that is not \c SOL_SOCKET / \c SCM_RIGHTS. A second pass installs
	the FDs. If creating an FD fails, the FDs created so far are closed again
	and the error is returned.

	\param self The receiving protocol (only traced).
	\param container The ancillary data to convert.
	\param buffer Destination for the control message.
	\param bufferSize Size of \a buffer in bytes.
	\param flags \c MSG_CMSG_CLOEXEC / \c MSG_CMSG_CLOFORK request the
		respective flag on every new FD.
	\return The number of buffer bytes used (\c CMSG_SPACE of the FD array) on
		success, \c B_BAD_VALUE for a foreign entry or a too small buffer, or
		the negative value returned by new_fd().
*/
ssize_t
unix_process_ancillary_data(net_protocol *self,
	const ancillary_data_container *container, void *buffer,
	size_t bufferSize, int flags)
{
	TRACE("[%" B_PRId32 "] unix_process_ancillary_data(%p, %p, %p, %p, %lu)\n",
		find_thread(NULL), self, container, buffer, bufferSize);

	ancillary_data_header entryHeader;

	// Pass 1: validate the entries and count the descriptors.
	int fdTotal = 0;
	for (void* entry = gStackModule->next_ancillary_data(container, NULL,
				&entryHeader);
			entry != NULL;
			entry = gStackModule->next_ancillary_data(container, entry,
				&entryHeader)) {
		if (entryHeader.level != SOL_SOCKET || entryHeader.type != SCM_RIGHTS)
			return B_BAD_VALUE;

		fdTotal += entryHeader.len / sizeof(file_descriptor*);
	}

	const size_t spaceNeeded = CMSG_SPACE(sizeof(int) * fdTotal);
	if (spaceNeeded > bufferSize)
		return B_BAD_VALUE;

	cmsghdr* message = (cmsghdr*)buffer;
	message->cmsg_level = SOL_SOCKET;
	message->cmsg_type = SCM_RIGHTS;
	message->cmsg_len = CMSG_LEN(sizeof(int) * fdTotal);
	int* fdOut = (int*)CMSG_DATA(message);

	io_context* context
		= get_current_io_context(!gStackModule->is_syscall());

	// Pass 2: install an FD for every descriptor.
	int written = 0;
	for (void* entry = gStackModule->next_ancillary_data(container, NULL,
				&entryHeader);
			entry != NULL;
			entry = gStackModule->next_ancillary_data(container, entry,
				&entryHeader)) {
		file_descriptor** source = (file_descriptor**)entry;
		const int entryCount = entryHeader.len / sizeof(file_descriptor*);

		for (int slot = 0; slot < entryCount; slot++, written++) {
			// The new FD table entry needs its own reference.
			inc_fd_ref_count(source[slot]);

			const int newFd = new_fd(context, source[slot]);
			fdOut[written] = newFd;
			if (newFd < 0) {
				// Drop the reference just taken and undo the FDs already
				// installed, newest first.
				put_fd(source[slot]);
				for (int undo = written - 1; undo >= 0; undo--)
					close_fd_index(context, fdOut[undo]);
				return newFd;
			}

			WriteLocker locker(context->lock);
			if ((flags & MSG_CMSG_CLOEXEC) != 0)
				fd_set_close_on_exec(context, newFd, true);
			if ((flags & MSG_CMSG_CLOFORK) != 0)
				fd_set_close_on_fork(context, newFd, true);
		}
	}

	return spaceNeeded;
}
/*!	Send hook operating directly on I/O vectors: passes every argument on to
	UnixEndpoint::Send() unchanged and returns its result.
*/
ssize_t
unix_send_data_no_buffer(net_protocol *_protocol, const iovec *vecs,
	size_t vecCount, ancillary_data_container *ancillaryData,
	const struct sockaddr *address, socklen_t addressLength, int flags)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Send(vecs, vecCount, ancillaryData, address,
		addressLength, flags);
}
/*!	Receive hook operating directly on I/O vectors: passes every argument on
	to UnixEndpoint::Receive() unchanged and returns its result.
*/
ssize_t
unix_read_data_no_buffer(net_protocol *_protocol, const iovec *vecs,
	size_t vecCount, ancillary_data_container **_ancillaryData,
	struct sockaddr *_address, socklen_t *_addressLength, int flags)
{
	UnixEndpoint* endpoint = (UnixEndpoint*)_protocol;
	return endpoint->Receive(vecs, vecCount, _ancillaryData, _address,
		_addressLength, flags);
}
/*!	Module initialization.

	Constructs \c gAddressManager in place, initializes it, registers this
	module as the AF_UNIX protocol for \c SOCK_STREAM, \c SOCK_DGRAM and
	\c SOCK_SEQPACKET, and finally registers the AF_UNIX domain, which stores
	the domain in \c sDomain.

	\return \c B_OK on success, otherwise the error of the first step that
		failed. On every failure, including a failed
		UnixAddressManager::Init(), the address manager is destroyed again so
		that no constructed object is left behind.
*/
status_t
init_unix()
{
	// Constructed explicitly with placement new; the matching teardown is an
	// explicit destructor call.
	new(&gAddressManager) UnixAddressManager;

	status_t error = gAddressManager.Init();

	// Same protocol module for all supported socket types; stop at the first
	// failure.
	const int socketTypes[] = {SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET};
	const size_t socketTypeCount = sizeof(socketTypes) / sizeof(socketTypes[0]);
	for (size_t i = 0; error == B_OK && i < socketTypeCount; i++) {
		error = gStackModule->register_domain_protocols(AF_UNIX,
			socketTypes[i], 0, "network/protocols/unix/v1", NULL);
	}

	if (error == B_OK) {
		error = gStackModule->register_domain(AF_UNIX, "unix", &gUnixModule,
			&gAddressModule, &sDomain);
	}

	if (error != B_OK)
		gAddressManager.~UnixAddressManager();

	return error;
}
/*!	Module teardown: unregisters the domain stored in \c sDomain and then
	runs the destructor of \c gAddressManager explicitly, mirroring the
	placement new in init_unix(). The result of unregister_domain() is not
	checked.
	\return Always \c B_OK.
*/
status_t
uninit_unix()
{
gStackModule->unregister_domain(sDomain);
gAddressManager.~UnixAddressManager();
return B_OK;
}
/*!	Standard module operations hook.

	\c B_MODULE_INIT runs init_unix(), \c B_MODULE_UNINIT runs uninit_unix();
	any other \a op yields \c B_ERROR. The variadic arguments are not used.
*/
static status_t
unix_std_ops(int32 op, ...)
{
	if (op == B_MODULE_INIT)
		return init_unix();

	if (op == B_MODULE_UNINIT)
		return uninit_unix();

	return B_ERROR;
}
// Protocol module descriptor for AF_UNIX. The initializer is positional, so
// the entries must stay in the member order of net_protocol_module_info,
// which is declared outside this file.
net_protocol_module_info gUnixModule = {
// Nested initializer for the first member: the module name (the same string
// init_unix() passes to register_domain_protocols()), a zero value
// (presumably flags -- confirm against the module_info declaration) and the
// standard-ops hook.
{
"network/protocols/unix/v1",
0,
unix_std_ops
},
// NOTE(review): presumably the protocol flags field -- confirm against
// net_protocol.h.
0,
// Hooks implemented in this file.
unix_init_protocol,
unix_uninit_protocol,
unix_open,
unix_close,
unix_free,
unix_connect,
unix_accept,
unix_control,
unix_getsockopt,
unix_setsockopt,
unix_bind,
unix_unbind,
unix_listen,
unix_shutdown,
unix_send_data,
unix_send_routed_data,
unix_send_avail,
unix_read_data,
unix_read_avail,
unix_get_domain,
unix_get_mtu,
unix_receive_data,
unix_deliver_data,
unix_error_received,
unix_error_reply,
unix_add_ancillary_data,
unix_process_ancillary_data,
// No hook is supplied for this slot.
NULL,
unix_send_data_no_buffer,
unix_read_data_no_buffer
};
// Pairs the name of each module this add-on depends on with the address of
// the global pointer through which its interface is used in this file. The
// list ends with an empty entry.
// NOTE(review): presumably the pointers are filled in by the module loader
// before B_MODULE_INIT is delivered -- confirm.
module_dependency module_dependencies[] = {
{NET_STACK_MODULE_NAME, (module_info **)&gStackModule},
{NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule},
{NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule},
{}
};
// NULL-terminated list of the modules defined in this file; it contains only
// the AF_UNIX protocol module.
module_info *modules[] = {
(module_info *)&gUnixModule,
NULL
};