⛏️ index : haiku.git

/*
 * Copyright 2008, Ingo Weinhold, ingo_weinhold@gmx.de. All rights reserved.
 * Distributed under the terms of the MIT License.
 */

#include <grp.h>

#include <errno.h>
#include <string.h>
#include <unistd.h>

#include <new>

#include <OS.h>

#include <errno_private.h>
#include <libroot_private.h>
#include <RegistrarDefs.h>
#include <user_group.h>

#include <util/KMessage.h>


using BPrivate::UserGroupLocker;
using BPrivate::relocate_pointer;


// Registrar reply holding the whole group database, filled in by
// init_group_db(). Kept alive because sGroupEntries presumably points into
// its "entries" data; endgrent() clears both together.
static KMessage sGroupDBReply;
// Relocated entries of the cached database, or NULL while not loaded.
static group** sGroupEntries = NULL;
// Number of entries sGroupEntries refers to.
static size_t sGroupEntryCount = 0;
// Index of the next entry getgrent_r() hands out; reset by setgrent() and
// endgrent().
static size_t sIterationIndex = 0;

// Result storage shared by the non-reentrant getgrent(), getgrnam() and
// getgrgid(); every call of any of them overwrites the previous result.
static struct group sGroupBuffer;
static char sGroupStringBuffer[MAX_GROUP_BUFFER_SIZE];


/*!	Asks the registrar for a single group entry and copies it into the
	caller-supplied \a group and \a buffer.

	The entry is looked up by \a name if it is not \c NULL, otherwise by
	\a _gid.

	Returns \c B_OK and sets \a *_result to \a group on success. If the
	registrar reports \c ENOENT (no such group), \c B_OK is returned as well,
	but \a *_result stays \c NULL. Any other error is returned unchanged, with
	\a *_result set to \c NULL.
*/
static status_t
query_group_entry(const char* name, gid_t _gid, struct group *group,
	char *buffer, size_t bufferSize, struct group **_result)
{
	*_result = NULL;

	// Build the request: by name if one is given, by gid otherwise. Never send
	// a request that lacks its lookup key (same checking as getgrouplist()).
	KMessage message(BPrivate::B_REG_GET_GROUP);
	status_t error = name != NULL
		? message.AddString("name", name)
		: message.AddInt32("gid", _gid);
	if (error != B_OK)
		return error;

	KMessage reply;
	error = BPrivate::send_authentication_request_to_registrar(message,
		reply);
	if (error != B_OK) {
		// ENOENT means "no such group": report success with a NULL result.
		return error == ENOENT ? B_OK : error;
	}

	// Unpack the entry. Note that name is reused to hold the group name found
	// in the reply.
	int32 gid;
	const char* password;

	if ((error = reply.FindInt32("gid", &gid)) != B_OK
		|| (error = reply.FindString("name", &name)) != B_OK
		|| (error = reply.FindString("password", &password)) != B_OK) {
		return error;
	}

	// collect at most MAX_GROUP_MEMBER_COUNT member names
	const char* members[MAX_GROUP_MEMBER_COUNT];
	int memberCount = 0;
	for (int32 index = 0; memberCount < MAX_GROUP_MEMBER_COUNT; index++) {
		if (reply.FindString("members", index, members + memberCount) != B_OK)
			break;
		memberCount++;
	}

	error = BPrivate::copy_group_to_buffer(name, password, gid, members,
		memberCount, group, buffer, bufferSize);
	if (error == B_OK)
		*_result = group;

	return error;
}


/*!	Fetches the complete group database from the registrar, unless it has
	already been loaded, and prepares it for iteration by getgrent_r().

	The reply is stored in sGroupDBReply. sGroupEntries is set to the
	"entries" data obtained from that reply via FindData() (presumably a
	pointer into the reply's own buffer), which is why endgrent() resets both
	together.

	Returns \c B_OK on success or if the database is already loaded; otherwise
	the error of the registrar request or of unpacking the reply. On failure
	sGroupEntries stays \c NULL, so the next call retries.
*/
static status_t
init_group_db()
{
	// sGroupEntries is only set after a successful load (and reset by
	// endgrent()), so non-NULL means the cache is ready.
	if (sGroupEntries != NULL)
		return B_OK;

	// ask the registrar
	KMessage message(BPrivate::B_REG_GET_GROUP_DB);
	status_t error = BPrivate::send_authentication_request_to_registrar(message,
		sGroupDBReply);
	if (error != B_OK)
		return error;

	// unpack the reply
	int32 count;
	group** entries;
	int32 numBytes;
	if ((error = sGroupDBReply.FindInt32("count", &count)) != B_OK
		|| (error = sGroupDBReply.FindData("entries", B_RAW_TYPE,
				(const void**)&entries, &numBytes)) != B_OK) {
		return error;
	}
	// NOTE(review): numBytes is never checked against count, so a malformed
	// reply could make the loop below read past the "entries" data.

	// relocate the entries
	// relocate_pointer() is called for its in-place effect on the referenced
	// pointer; its return value is used only for the entry itself. Presumably
	// it turns an offset relative to baseAddress into an absolute pointer --
	// TODO confirm against user_group.h.
	addr_t baseAddress = (addr_t)entries;
	for (int32 i = 0; i < count; i++) {
		group* entry = relocate_pointer(baseAddress, entries[i]);
		relocate_pointer(baseAddress, entry->gr_name);
		relocate_pointer(baseAddress, entry->gr_passwd);
		relocate_pointer(baseAddress, entry->gr_mem);
		// The flattened member list ends with a (void*)-1 marker. Relocate
		// every member before it, then replace the marker with the NULL
		// terminator expected in struct group::gr_mem.
		int32 k = 0;
		for (; entry->gr_mem[k] != (void*)-1; k++)
			relocate_pointer(baseAddress, entry->gr_mem[k]);
		entry->gr_mem[k] = NULL;
	}

	sGroupEntries = entries;
	sGroupEntryCount = count;

	return B_OK;
}


// #pragma mark -


struct group*
getgrent(void)
{
	struct group* result = NULL;
	int status = getgrent_r(&sGroupBuffer, sGroupStringBuffer,
		sizeof(sGroupStringBuffer), &result);
	if (status != 0)
		__set_errno(status);
	return result;
}


/*!	Copies the next entry of the cached group database into \a group and
	\a buffer, loading the database from the registrar first if necessary.

	Returns \c B_OK and sets \a *_result to \a group on success. Returns
	\c ENOENT once all entries have been handed out. Otherwise the error of
	loading or copying is returned. On every failure \a *_result is \c NULL.
*/
int
getgrent_r(struct group* group, char* buffer, size_t bufferSize,
	struct group** _result)
{
	UserGroupLocker locker;

	*_result = NULL;

	int status = init_group_db();
	if (status != B_OK)
		return status;

	// iteration exhausted
	if (sIterationIndex >= sGroupEntryCount)
		return ENOENT;

	status = BPrivate::copy_group_to_buffer(sGroupEntries[sIterationIndex],
		group, buffer, bufferSize);
	if (status != B_OK)
		return status;

	// only advance once the entry has actually been delivered
	*_result = group;
	sIterationIndex++;
	return B_OK;
}


/*!	Rewinds getgrent() iteration to the first entry. The cached database
	itself is kept.
*/
void
setgrent(void)
{
	UserGroupLocker locker;
	sIterationIndex = 0;
}


/*!	Ends getgrent() iteration and drops the cached group database. The next
	getgrent() call fetches the database from the registrar again.
*/
void
endgrent(void)
{
	UserGroupLocker locker;

	// forget the iteration state and the entries ...
	sIterationIndex = 0;
	sGroupEntryCount = 0;
	sGroupEntries = NULL;

	// ... and release the registrar reply they referred to
	sGroupDBReply.Unset();
}


struct group *
getgrnam(const char *name)
{
	struct group* result = NULL;
	int status = getgrnam_r(name, &sGroupBuffer, sGroupStringBuffer,
		sizeof(sGroupStringBuffer), &result);
	if (status != 0)
		__set_errno(status);
	return result;
}


int
getgrnam_r(const char *name, struct group *group, char *buffer,
	size_t bufferSize, struct group **_result)
{
	return query_group_entry(name, 0, group, buffer, bufferSize, _result);
}


struct group *
getgrgid(gid_t gid)
{
	struct group* result = NULL;
	int status = getgrgid_r(gid, &sGroupBuffer, sGroupStringBuffer,
		sizeof(sGroupStringBuffer), &result);
	if (status != 0)
		__set_errno(status);
	return result;
}


int
getgrgid_r(gid_t gid, struct group *group, char *buffer,
	size_t bufferSize, struct group **_result)
{
	return query_group_entry(NULL, gid, group, buffer, bufferSize, _result);
}


int
getgrouplist(const char* user, gid_t baseGroup, gid_t* groupList,
	int* groupCount)
{
	int maxGroupCount = *groupCount;
	*groupCount = 0;

	status_t error = B_OK;

	// prepare request
	KMessage message(BPrivate::B_REG_GET_USER_GROUPS);
	if (message.AddString("name", user) != B_OK
		|| message.AddInt32("max count", maxGroupCount) != B_OK) {
		return -1;
	}

	// send request
	KMessage reply;
	error = BPrivate::send_authentication_request_to_registrar(message, reply);
	if (error != B_OK)
		return -1;

	// unpack reply
	int32 count;
	const int32* groups;
	int32 groupsSize;
	if (reply.FindInt32("count", &count) != B_OK
		|| reply.FindData("groups", B_INT32_TYPE, (const void**)&groups,
				&groupsSize) != B_OK) {
		return -1;
	}

	memcpy(groupList, groups, groupsSize);
	*groupCount = count;

	// add the base group
	if (*groupCount < maxGroupCount)
		groupList[*groupCount] = baseGroup;
	++*groupCount;

	return *groupCount <= maxGroupCount ? *groupCount : -1;
}


/*!	Sets the calling process' group list via setgroups() to the groups
	\a user belongs to plus \a baseGroup.

	Returns -1 if the group list could not be retrieved or did not fit.
	Otherwise the result of setgroups() is returned.
*/
int
initgroups(const char* user, gid_t baseGroup)
{
	// one extra slot for the base group getgrouplist() appends
	gid_t groupList[NGROUPS_MAX + 1];
	int count = sizeof(groupList) / sizeof(groupList[0]);

	if (getgrouplist(user, baseGroup, groupList, &count) >= 0)
		return setgroups(count, groupList);

	return -1;
}