#include <new>
#include <OS.h>
#include "Channel.h"
#include "Connection.h"
#include "RequestChannel.h"
#include "RequestConnection.h"
#include "RequestHandler.h"
/*!	Worker that services one down stream channel of a RequestConnection.

	Init() wraps the given channel in a RequestChannel (owned and deleted by
	this object) and spawns a suspended thread; Run() resumes it. The thread
	repeatedly receives requests, passes them to the request handler and
	deletes them. Any receive or handling error is reported to the owning
	connection via DownStreamChannelError(), whose return value decides
	whether the loop stops.
*/
class RequestConnection::DownStreamThread {
public:
	DownStreamThread()
		: fThread(-1),
		  fConnection(NULL),
		  fChannel(NULL),
		  fHandler(NULL),
		  fTerminating(false)
	{
	}

	~DownStreamThread()
	{
		Terminate();
		delete fChannel;
	}

	/*!	Prepares the thread; it is spawned suspended and only starts running
		when Run() is called.
		\return B_BAD_VALUE if any argument is NULL, B_NO_MEMORY if the
			RequestChannel could not be allocated, the spawn_thread() error
			if spawning failed, B_OK otherwise.
	*/
	status_t Init(RequestConnection* connection, Channel* channel,
		RequestHandler* handler)
	{
		if (!connection || !channel || !handler)
			return B_BAD_VALUE;
		fConnection = connection;
		fHandler = handler;
		fChannel = new(std::nothrow) RequestChannel(channel);
		if (!fChannel)
			return B_NO_MEMORY;
		fThread = spawn_thread(_LoopEntry, "down stream thread",
			B_NORMAL_PRIORITY, this);
		if (fThread < 0)
			return fThread;
		return B_OK;
	}

	void Run()
	{
		resume_thread(fThread);
	}

	/*!	Asks the loop to stop and, unless called from the worker thread
		itself, waits for it to exit. Safe to call more than once: after a
		successful join the thread id is forgotten, so later calls (e.g. from
		the destructor after RequestConnection::Close()) do not wait on a
		stale -- possibly reused -- thread id.
	*/
	void Terminate()
	{
		fTerminating = true;
		if (fThread > 0 && find_thread(NULL) != fThread) {
			int32 result;
			wait_for_thread(fThread, &result);
			fThread = -1;
		}
	}

	RequestChannel* GetRequestChannel()
	{
		return fChannel;
	}

private:
	// Not copyable: the object owns fChannel and a thread whose entry
	// argument is `this`.
	DownStreamThread(const DownStreamThread&);
	DownStreamThread& operator=(const DownStreamThread&);

	static int32 _LoopEntry(void* data)
	{
		return static_cast<DownStreamThread*>(data)->_Loop();
	}

	int32 _Loop()
	{
		while (!fTerminating) {
			Request* request;
			status_t error = fChannel->ReceiveRequest(&request);
			if (error == B_OK) {
				error = fHandler->HandleRequest(request, fChannel);
				delete request;
			}
			// Both receive and handler errors are forwarded; the connection
			// decides whether this thread should stop.
			if (error != B_OK)
				fTerminating = fConnection->DownStreamChannelError(this, error);
		}
		return 0;
	}

private:
	thread_id			fThread;
	RequestConnection*	fConnection;
	RequestChannel*		fChannel;
	RequestHandler*		fHandler;
	volatile bool		fTerminating;
};
/*!	Creates a request connection on top of \a connection.

	The connection object is deleted by the destructor; \a requestHandler is
	deleted there only if \a ownsRequestHandler is \c true. No threads exist
	until Init() is called.
*/
RequestConnection::RequestConnection(Connection* connection,
	RequestHandler* requestHandler, bool ownsRequestHandler)
	:
	fConnection(connection),
	fRequestHandler(requestHandler),
	fOwnsRequestHandler(ownsRequestHandler),
	fThreads(NULL),
	fThreadCount(0),
	fTerminationCount(0)
{
	// everything else happens in Init()
}
/*!	Shuts the connection down (see Close()), then frees the down stream
	threads, the underlying connection and -- if owned -- the request handler.
*/
RequestConnection::~RequestConnection()
{
	// Only delete the handler if it was handed over to us.
	RequestHandler* ownedHandler = fOwnsRequestHandler ? fRequestHandler : NULL;

	Close();

	// The threads go first: they reference channels of fConnection.
	delete[] fThreads;
	delete fConnection;
	delete ownedHandler;
}
/*!	Creates one DownStreamThread per down stream channel of the connection,
	initializes all of them and only then starts them.

	\return B_BAD_VALUE if the connection or request handler is missing,
		B_ERROR if the connection has no down stream channel, B_NO_MEMORY if
		the thread array could not be allocated, the first thread Init()
		error, or B_OK.
*/
status_t
RequestConnection::Init()
{
	if (fConnection == NULL || fRequestHandler == NULL)
		return B_BAD_VALUE;

	int32 channelCount = fConnection->CountDownStreamChannels();
	if (channelCount < 1)
		return B_ERROR;

	fThreadCount = channelCount;
	fThreads = new(std::nothrow) DownStreamThread[fThreadCount];
	if (fThreads == NULL)
		return B_NO_MEMORY;

	// Set up every thread before any of them starts running.
	for (int32 index = 0; index < fThreadCount; index++) {
		Channel* channel = fConnection->DownStreamChannelAt(index);
		status_t status = fThreads[index].Init(this, channel, fRequestHandler);
		if (status != B_OK)
			return status;
	}

	for (int32 index = 0; index < fThreadCount; index++)
		fThreads[index].Run();

	return B_OK;
}
void
RequestConnection::Close()
{
atomic_add(&fTerminationCount, 1);
if (fConnection)
fConnection->Close();
if (fThreads) {
for (int32 i = 0; i < fThreadCount; i++)
fThreads[i].Terminate();
}
}
/*!	Sends \a request over an up stream channel. If \a reply is not NULL, a
	reply is received and returned through it (the caller then owns it).
*/
status_t
RequestConnection::SendRequest(Request* request, Request** reply)
{
	RequestHandler* const noReplyHandler = NULL;
	return _SendRequest(request, reply, noReplyHandler);
}
/*!	Sends \a request and passes the received reply to \a replyHandler.
	\return B_BAD_VALUE if \a replyHandler is NULL, otherwise the result of
		the send/receive/handle sequence.
*/
status_t
RequestConnection::SendRequest(Request* request, RequestHandler* replyHandler)
{
	return replyHandler != NULL
		? _SendRequest(request, NULL, replyHandler)
		: B_BAD_VALUE;
}
/*!	Called by a down stream thread when receiving or handling a request
	failed.

	Every call bumps the termination counter. Only the very first
	termination event (whether from here or from Close()) is forwarded to
	the request handler, as a ConnectionBrokenRequest carrying \a error.
	\return always \c true, i.e. the calling thread should stop.
*/
bool
RequestConnection::DownStreamChannelError(DownStreamThread* thread,
	status_t error)
{
	bool isFirstTermination = atomic_add(&fTerminationCount, 1) == 0;
	if (isFirstTermination && fRequestHandler != NULL) {
		ConnectionBrokenRequest brokenRequest;
		brokenRequest.error = error;
		RequestChannel* channel = thread->GetRequestChannel();
		fRequestHandler->HandleRequest(&brokenRequest, channel);
	}

	return true;
}
/*!	Sends \a request over an up stream channel borrowed from the connection
	and optionally waits for a reply.

	If either \a _reply or \a replyHandler is given, a reply is received.
	\a replyHandler (if any) is invoked on it first. On success the reply is
	returned via \a _reply (the caller takes ownership); otherwise it is
	deleted. The channel is handed back via PutUpStreamChannel() and deleted
	if the connection refuses to take it back.

	\return B_BAD_VALUE for a NULL request, B_NO_INIT if there is no
		connection, or the first error from acquiring the channel, sending,
		receiving or handling the reply.
*/
status_t
RequestConnection::_SendRequest(Request* request, Request** _reply,
	RequestHandler* replyHandler)
{
	if (request == NULL)
		return B_BAD_VALUE;
	// Close() and the destructor tolerate a NULL connection; don't crash here.
	if (fConnection == NULL)
		return B_NO_INIT;

	Channel* channel = NULL;
	status_t error = fConnection->GetUpStreamChannel(&channel);
	if (error != B_OK)
		return error;

	{
		// Scoped so the wrapper is destroyed before the channel is handed
		// back to the connection (or deleted) below.
		RequestChannel requestChannel(channel);
		error = requestChannel.SendRequest(request);
		if (error == B_OK && (_reply != NULL || replyHandler != NULL)) {
			Request* reply = NULL;
			error = requestChannel.ReceiveRequest(&reply);
			if (error == B_OK) {
				if (replyHandler != NULL)
					error = replyHandler->HandleRequest(reply, &requestChannel);
				if (error == B_OK && _reply != NULL)
					*_reply = reply;
				else
					delete reply;
			}
		}
	}

	if (fConnection->PutUpStreamChannel(channel) != B_OK)
		delete channel;

	return error;
}