#include "AuthenticationServer.h"
#include <new>
#include <HashMap.h>
#include <HashString.h>
#include <util/KMessage.h>
#include "AuthenticationPanel.h"
#include "AuthenticationServerDefs.h"
#include "DebugSupport.h"
#include "TaskManager.h"
// A user name / password pair. It is considered usable (IsValid()) only
// once a non-empty user name has been stored.
class AuthenticationServer::Authentication {
public:
	Authentication()
		:
		fUser(),
		fPassword()
	{
	}

	Authentication(const char* user, const char* password)
		:
		fUser(user),
		fPassword(password)
	{
	}

	// Replaces both strings. Returns B_NO_MEMORY if either copy fails; the
	// password is not touched when copying the user name already failed.
	status_t SetTo(const char* user, const char* password)
	{
		if (!fUser.SetTo(user) || !fPassword.SetTo(password))
			return B_NO_MEMORY;
		return B_OK;
	}

	bool IsValid() const
	{
		return fUser.GetLength() > 0;
	}

	const char* GetUser() const
	{
		return fUser.GetString();
	}

	const char* GetPassword() const
	{
		return fPassword.GetString();
	}

private:
	HashString	fUser;
	HashString	fPassword;
};
// Hash map key identifying one server within one authentication context.
class AuthenticationServer::ServerKey {
public:
	ServerKey()
		:
		fContext(),
		fServer()
	{
	}

	ServerKey(const char* context, const char* server)
		:
		fContext(context),
		fServer(server)
	{
	}

	ServerKey(const ServerKey& other)
		:
		fContext(other.fContext),
		fServer(other.fServer)
	{
	}

	// Combines the hashes of both components.
	uint32 GetHashCode() const
	{
		return fContext.GetHashCode() * 17 + fServer.GetHashCode();
	}

	ServerKey& operator=(const ServerKey& other)
	{
		fContext = other.fContext;
		fServer = other.fServer;
		return *this;
	}

	// Keys are equal when both the context and the server name match.
	bool operator==(const ServerKey& other) const
	{
		return fContext == other.fContext && fServer == other.fServer;
	}

	bool operator!=(const ServerKey& other) const
	{
		return !operator==(other);
	}

private:
	HashString	fContext;
	HashString	fServer;
};
// Credentials known for a single server. It holds one Authentication per
// share plus a default Authentication for the whole server, and owns the
// per-share Authentication objects (they are deleted in the destructor).
class AuthenticationServer::ServerEntry {
public:
	ServerEntry()
		:
		fDefaultAuthentication(),
		fUseDefaultAuthentication(false)
	{
	}

	~ServerEntry()
	{
		AuthenticationMap::Iterator iterator = fAuthentications.GetIterator();
		while (iterator.HasNext())
			delete iterator.Next().value;
	}

	void SetUseDefaultAuthentication(bool useDefaultAuthentication)
	{
		fUseDefaultAuthentication = useDefaultAuthentication;
	}

	bool UseDefaultAuthentication() const
	{
		return fUseDefaultAuthentication;
	}

	status_t SetDefaultAuthentication(const char* user, const char* password)
	{
		return fDefaultAuthentication.SetTo(user, password);
	}

	const Authentication& GetDefaultAuthentication() const
	{
		return fDefaultAuthentication;
	}

	// Stores credentials for the given share. An existing entry is updated
	// in place; otherwise a new one is created and handed to the map.
	status_t SetAuthentication(const char* share, const char* user,
		const char* password)
	{
		Authentication* existing = fAuthentications.Get(share);
		if (existing != NULL)
			return existing->SetTo(user, password);

		Authentication* created = new(std::nothrow) Authentication;
		if (created == NULL)
			return B_NO_MEMORY;

		status_t status = created->SetTo(user, password);
		if (status == B_OK)
			status = fAuthentications.Put(share, created);
		// On any failure the map did not take ownership, so free it here.
		if (status != B_OK)
			delete created;
		return status;
	}

	// Returns the share-specific credentials, or NULL if none are stored.
	Authentication* GetAuthentication(const char* share) const
	{
		return fAuthentications.Get(share);
	}

private:
	typedef HashMap<HashString, Authentication*> AuthenticationMap;

	Authentication		fDefaultAuthentication;
	bool				fUseDefaultAuthentication;
	AuthenticationMap	fAuthentications;
};
// Hash map from a (context, server) key to the heap-allocated ServerEntry
// holding that server's credentials. The ServerEntry values are deleted in
// ~AuthenticationServer().
struct AuthenticationServer::ServerEntryMap
: HashMap<ServerKey, ServerEntry*> {
};
// Task that asks the user for credentials via an AuthenticationPanel,
// records the entered credentials on success and always sends a reply to
// the requester.
class AuthenticationServer::UserDialogTask : public Task {
public:
	UserDialogTask(AuthenticationServer* authenticationServer,
		const char* context, const char* server, const char* share,
		bool badPassword, port_id replyPort, int32 replyToken)
		:
		Task("user dialog task"),
		fAuthenticationServer(authenticationServer),
		fContext(context),
		fServer(server),
		fShare(share),
		fBadPassword(badPassword),
		fReplyPort(replyPort),
		fReplyToken(replyToken),
		fPanel(NULL)
	{
	}

	virtual status_t Execute()
	{
		char user[B_OS_NAME_LENGTH];
		char password[B_OS_NAME_LENGTH];
		bool keep = true;

		// fPanel is published so that Stop() can cancel the dialog.
		fPanel = new(std::nothrow) AuthenticationPanel();
		status_t error = fPanel != NULL ? B_OK : B_NO_MEMORY;

		// Pre-fill the dialog with the server's default credentials; a NULL
		// share selects the default authentication.
		HashString defaultUser;
		HashString defaultPassword;
		fAuthenticationServer->_GetAuthentication(fContext.GetString(),
			fServer.GetString(), NULL, &defaultUser, &defaultPassword);

		bool cancelled = false;
		if (error == B_OK) {
			cancelled = fPanel->GetAuthentication(fServer.GetString(),
				fShare.GetString(), defaultUser.GetString(),
				defaultPassword.GetString(), false, fBadPassword, user,
				password, &keep);
		}
		// The panel pointer is dropped (not deleted) once the dialog ended.
		fPanel = NULL;

		if (error != B_OK || cancelled) {
			// "cancelled" can only be set when error == B_OK, so this sends
			// the allocation error or B_OK for a user cancel.
			fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken,
				error, true, NULL, NULL);
			return error;
		}

		fAuthenticationServer->_AddAuthentication(fContext.GetString(),
			fServer.GetString(), fShare.GetString(), user, password, keep);
		fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken,
			B_OK, false, user, password);
		return error;
	}

	virtual void Stop()
	{
		if (fPanel != NULL)
			fPanel->Cancel();
	}

private:
	AuthenticationServer*	fAuthenticationServer;
	HashString				fContext;
	HashString				fServer;
	HashString				fShare;
	bool					fBadPassword;
	port_id					fReplyPort;
	int32					fReplyToken;
	AuthenticationPanel*	fPanel;
};
// Only sets up the application object with "not yet created" markers; the
// server map, request port and request thread are created by Init().
AuthenticationServer::AuthenticationServer()
	: BApplication("application/x-vnd.haiku-authentication_server"),
	  fLock(),
	  fRequestPort(-1),
	  fRequestThread(-1),
	  fServerEntries(NULL),
	  fTerminating(false)
{
}
// Stops the request thread and frees all stored server entries and the
// map holding them.
AuthenticationServer::~AuthenticationServer()
{
	// Deleting the port makes the blocking receive in _RequestThread() fail;
	// the thread then sees fTerminating and leaves its loop.
	fTerminating = true;
	if (fRequestPort >= 0)
		delete_port(fRequestPort);

	// Join the request thread (previously the port ID was passed here by
	// mistake, so the thread was never actually waited for).
	if (fRequestThread >= 0) {
		int32 result;
		wait_for_thread(fRequestThread, &result);
	}

	// Init() may have failed before the map was allocated.
	if (fServerEntries != NULL) {
		for (ServerEntryMap::Iterator it = fServerEntries->GetIterator();
				it.HasNext();) {
			delete it.Next().value;
		}
		delete fServerEntries;
		fServerEntries = NULL;
	}
}
// Creates the server map, the public request port and the request thread.
// Returns B_OK on success or the first error encountered.
status_t
AuthenticationServer::Init()
{
	fServerEntries = new(std::nothrow) ServerEntryMap;
	if (fServerEntries == NULL)
		return B_NO_MEMORY;

	status_t status = fServerEntries->InitCheck();
	if (status != B_OK)
		return status;

	// Port clients send their authentication requests to.
	fRequestPort = create_port(10, kAuthenticationServerPortName);
	if (fRequestPort < 0)
		return fRequestPort;

	fRequestThread = spawn_thread(&_RequestThreadEntry, "request thread",
		B_NORMAL_PRIORITY, this);
	if (fRequestThread < 0)
		return fRequestThread;

	resume_thread(fRequestThread);
	return B_OK;
}
// Thread entry trampoline: data is the AuthenticationServer passed to
// spawn_thread() in Init().
int32
AuthenticationServer::_RequestThreadEntry(void* data)
{
	AuthenticationServer* server = static_cast<AuthenticationServer*>(data);
	return server->_RequestThread();
}
// Request loop: receives authentication requests on fRequestPort and either
// answers them directly from the stored credentials or starts a
// UserDialogTask that asks the user. Every request gets a reply, including
// failures, so the requester is never left waiting.
int32
AuthenticationServer::_RequestThread()
{
	TaskManager taskManager;
	while (!fTerminating) {
		// Reap finished dialog tasks before blocking for the next request.
		taskManager.RemoveDoneTasks();

		KMessage request;
		if (request.ReceiveFrom(fRequestPort) != B_OK) {
			// Also happens when the destructor deletes the port; the loop
			// condition then ends the thread.
			continue;
		}

		const char* context = NULL;
		const char* server = NULL;
		const char* share = NULL;
		bool badPassword = true;
		request.FindString("context", &context);
		request.FindString("server", &server);
		request.FindString("share", &share);
		request.FindBool("badPassword", &badPassword);
		if (context == NULL || server == NULL || share == NULL) {
			_SendRequestReply(request.ReplyPort(), request.ReplyToken(),
				B_BAD_VALUE, true, NULL, NULL);
			continue;
		}

		// Known credentials are returned without asking, unless the caller
		// reported them as wrong.
		HashString foundUser;
		HashString foundPassword;
		if (!badPassword && _GetAuthentication(context, server, share,
				&foundUser, &foundPassword)) {
			_SendRequestReply(request.ReplyPort(), request.ReplyToken(),
				B_OK, false, foundUser.GetString(), foundPassword.GetString());
			continue;
		}

		UserDialogTask* task = new(std::nothrow) UserDialogTask(this,
			context, server, share, badPassword, request.ReplyPort(),
			request.ReplyToken());
		if (task == NULL) {
			ERROR("AuthenticationServer::_RequestThread(): ERROR: "
				"failed to allocate user dialog task\n");
			_SendRequestReply(request.ReplyPort(), request.ReplyToken(),
				B_NO_MEMORY, true, NULL, NULL);
			continue;
		}

		// NOTE(review): ownership of task on RunTask() failure depends on
		// TaskManager; it is not deleted here to avoid a double free.
		status_t error = taskManager.RunTask(task);
		if (error != B_OK) {
			ERROR("AuthenticationServer::_RequestThread(): Failed to "
				"start user dialog task: %s\n", strerror(error));
			_SendRequestReply(request.ReplyPort(), request.ReplyToken(),
				error, true, NULL, NULL);
		}
	}
	return 0;
}
/*!	Retrieves the user name and password stored for \a share on \a server
	in \a context.
	If share is NULL, the default authentication for the server is returned.
*/
bool
AuthenticationServer::_GetAuthentication(const char* context,
const char* server, const char* share, HashString* user,
HashString* password)
{
if (!context || !server || !user || !password)
return B_BAD_VALUE;
AutoLocker<BLocker> _(fLock);
ServerKey key(context, server);
ServerEntry* serverEntry = fServerEntries->Get(key);
if (!serverEntry)
return false;
const Authentication* authentication = NULL;
if (share) {
serverEntry->GetAuthentication(share);
if (!authentication && serverEntry->UseDefaultAuthentication())
authentication = &serverEntry->GetDefaultAuthentication();
} else
authentication = &serverEntry->GetDefaultAuthentication();
if (!authentication || !authentication->IsValid())
return false;
return (user->SetTo(authentication->GetUser())
&& password->SetTo(authentication->GetPassword()));
}
// Stores credentials for share on server in context, creating the server
// entry on first use. The credentials also become the server's default
// unless a kept default is already in use; makeDefault always overwrites
// the default and marks it for use.
status_t
AuthenticationServer::_AddAuthentication(const char* context,
	const char* server, const char* share, const char* user,
	const char* password, bool makeDefault)
{
	AutoLocker<BLocker> locker(fLock);

	ServerKey key(context, server);
	ServerEntry* serverEntry = fServerEntries->Get(key);
	if (serverEntry == NULL) {
		serverEntry = new(std::nothrow) ServerEntry;
		if (serverEntry == NULL)
			return B_NO_MEMORY;

		status_t putError = fServerEntries->Put(key, serverEntry);
		if (putError != B_OK) {
			// The map did not take ownership.
			delete serverEntry;
			return putError;
		}
	}

	status_t error = serverEntry->SetAuthentication(share, user, password);
	if (error != B_OK)
		return error;

	if (makeDefault || !serverEntry->UseDefaultAuthentication())
		serverEntry->SetDefaultAuthentication(user, password);
	if (makeDefault)
		serverEntry->SetUseDefaultAuthentication(true);
	return B_OK;
}
// Sends the reply for a request to (port, token). The reply always carries
// "error". It carries "cancelled" only when error is B_OK, and "user" and
// "password" only when the request also was not cancelled.
status_t
AuthenticationServer::_SendRequestReply(port_id port, int32 token,
	status_t error, bool cancelled, const char* user, const char* password)
{
	KMessage reply;
	reply.AddInt32("error", error);
	if (error != B_OK)
		return reply.SendTo(port, token);

	reply.AddBool("cancelled", cancelled);
	if (!cancelled) {
		reply.AddString("user", user);
		reply.AddString("password", password);
	}
	return reply.SendTo(port, token);
}
// Starts the authentication server; exits with 1 if initialization fails.
int
main()
{
	AuthenticationServer app;
	if (app.Init() != B_OK)
		return 1;

	app.Run();
	return 0;
}