#include "RequestChannel.h"
#include <new>
#include <typeinfo>
#include <stdlib.h>
#include <string.h>
#include <AutoDeleter.h>
#include <ByteOrder.h>
#include "Channel.h"
#include "Compatibility.h"
#include "DebugSupport.h"
#include "Request.h"
#include "RequestDumper.h"
#include "RequestFactory.h"
#include "RequestFlattener.h"
#include "RequestUnflattener.h"
// Upper bound, in bytes, accepted for a flattened request body. SendRequest()
// rejects larger requests before writing anything, and ReceiveRequest()
// rejects headers announcing a larger (or negative) size before reading the
// body.
static const int32 kMaxSaneRequestSize = 128 * 1024;
// Size, in bytes, of the send buffer a RequestChannel allocates in its
// constructor and hands to ChannelWriter.
static const int32 kDefaultBufferSize = 4096;
// Writer that coalesces small writes into a caller-provided buffer and pushes
// them to a Channel. The buffer is borrowed, not owned: this class never
// frees it. Data still in the buffer is only sent by an explicit Flush().
class RequestChannel::ChannelWriter : public Writer {
public:
	// channel    - destination of all data written.
	// buffer     - staging area of bufferSize bytes (may be NULL when
	//              bufferSize is 0; then every non-empty write goes straight
	//              to the channel).
	ChannelWriter(Channel* channel, void* buffer, int32 bufferSize)
		: fChannel(channel),
		  fBuffer(buffer),
		  fBufferSize(bufferSize),
		  fBytesWritten(0)
	{
	}

	// Appends size bytes from buffer. If they do not fit into the remaining
	// buffer space, the buffered data is flushed first. Writes larger than the
	// whole buffer are sent directly to the channel.
	// Returns B_BAD_VALUE for a negative size or a NULL buffer with a
	// non-zero size, otherwise B_OK or the error reported by the channel.
	virtual status_t Write(const void* buffer, int32 size)
	{
		if (size < 0 || (size > 0 && buffer == NULL))
			return B_BAD_VALUE;
		if (size == 0)
			return B_OK;

		// Compare against the remaining space instead of computing
		// fBytesWritten + size, which could overflow int32 for large sizes.
		if (size > fBufferSize - fBytesWritten) {
			status_t error = Flush();
			if (error != B_OK)
				return error;
		}

		if (size > fBufferSize) {
			// Does not even fit into an empty buffer: bypass it.
			return fChannel->Send(buffer, size);
		}

		memcpy((uint8*)fBuffer + fBytesWritten, buffer, size);
		fBytesWritten += size;
		return B_OK;
	}

	// Sends all buffered bytes to the channel. On failure the buffered bytes
	// are kept (fBytesWritten is not reset) and the channel error is returned.
	status_t Flush()
	{
		if (fBytesWritten == 0)
			return B_OK;
		status_t error = fChannel->Send(fBuffer, fBytesWritten);
		if (error != B_OK)
			return error;
		fBytesWritten = 0;
		return B_OK;
	}

private:
	Channel*	fChannel;
	void*		fBuffer;
	int32		fBufferSize;
	int32		fBytesWritten;
};
// Reader over a fixed in-memory buffer (not owned). Reads consume the buffer
// sequentially; reading past its end fails with B_BAD_VALUE.
class RequestChannel::MemoryReader : public Reader {
public:
	MemoryReader(void* buffer, int32 bufferSize)
		: Reader(),
		  fBuffer(buffer),
		  fBufferSize(bufferSize),
		  fBytesRead(0)
	{
	}

	// Copies the next size bytes into buffer.
	// Returns B_BAD_VALUE for a NULL buffer, a negative size, or when fewer
	// than size bytes remain; B_OK otherwise.
	virtual status_t Read(void* buffer, int32 size)
	{
		if (!buffer || size < 0)
			return B_BAD_VALUE;
		if (size == 0)
			return B_OK;
		void* localBuffer;
		bool mustFree;
		status_t error = Read(size, &localBuffer, &mustFree);
		if (error != B_OK)
			return error;
		memcpy(buffer, localBuffer, size);
		return B_OK;
	}

	// Returns a pointer into the internal buffer for the next size bytes
	// without copying. *mustFree is always set to false because the memory
	// belongs to the underlying buffer.
	virtual status_t Read(int32 size, void** buffer, bool* mustFree)
	{
		if (size < 0 || !buffer || !mustFree)
			return B_BAD_VALUE;
		// size may come from untrusted received data. Compare it with the
		// remaining byte count so that a huge value cannot overflow
		// fBytesRead + size (signed overflow is undefined behavior).
		if (size > fBufferSize - fBytesRead)
			return B_BAD_VALUE;
		*buffer = (uint8*)fBuffer + fBytesRead;
		*mustFree = false;
		fBytesRead += size;
		return B_OK;
	}

	// True when the whole buffer has been consumed.
	bool AllBytesRead() const
	{
		return (fBytesRead == fBufferSize);
	}

private:
	void*		fBuffer;
	int32		fBufferSize;
	int32		fBytesRead;
};
// Fixed-size header preceding every request on the channel. It is sent and
// received as raw bytes (sizeof(RequestHeader)), so this struct's layout is
// the wire format. Both fields are stored in big-endian byte order on the
// wire.
struct RequestChannel::RequestHeader {
// request type, as returned by Request::GetType() and passed to
// RequestFactory::CreateRequest() on the receiving side
uint32 type;
// number of flattened body bytes that follow the header
int32 size;
};
// Creates a request channel on top of the given Channel and tries to allocate
// a kDefaultBufferSize send buffer. If that allocation fails, the channel is
// still usable: fBuffer stays NULL, fBufferSize stays 0, and ChannelWriter
// then sends every write directly.
RequestChannel::RequestChannel(Channel* channel)
	: fChannel(channel),
	  fBuffer(NULL),
	  fBufferSize(0)
{
	void* sendBuffer = malloc(kDefaultBufferSize);
	if (sendBuffer == NULL)
		return;

	fBuffer = sendBuffer;
	fBufferSize = kDefaultBufferSize;
}
// Releases the send buffer allocated by the constructor. The Channel itself
// is not deleted here.
RequestChannel::~RequestChannel()
{
	if (fBuffer != NULL)
		free(fBuffer);
}
// Serializes request and sends it over the channel. The wire format is a
// RequestHeader (type and body size, big-endian) followed by the flattened
// body. The body size is computed beforehand by _GetRequestSize() and must
// lie within [0, kMaxSaneRequestSize].
// Returns B_BAD_VALUE for a NULL request, B_BAD_DATA for an out-of-range
// size, or the first error from sizing, writing, flattening or flushing.
status_t
RequestChannel::SendRequest(Request* request)
{
	if (request == NULL)
		RETURN_ERROR(B_BAD_VALUE);

	PRINT("%p->RequestChannel::SendRequest(): request: %p, type: %s\n",
		this, request, typeid(*request).name());

	// Size pass: the header must announce the body length up front.
	int32 requestSize;
	status_t status = _GetRequestSize(request, &requestSize);
	if (status != B_OK)
		RETURN_ERROR(status);

	const bool sizeInRange
		= requestSize >= 0 && requestSize <= kMaxSaneRequestSize;
	if (!sizeInRange) {
		ERROR("RequestChannel::SendRequest(): ERROR: Invalid request size: "
			"%" B_PRId32 "\n", requestSize);
		RETURN_ERROR(B_BAD_DATA);
	}

	// Both header fields go out in big-endian order.
	RequestHeader header;
	header.type = B_HOST_TO_BENDIAN_INT32(request->GetType());
	header.size = B_HOST_TO_BENDIAN_INT32(requestSize);

	ChannelWriter writer(fChannel, fBuffer, fBufferSize);

	status = writer.Write(&header, sizeof(RequestHeader));
	if (status != B_OK)
		RETURN_ERROR(status);

	// Write pass: flatten the body through the same buffered writer.
	RequestFlattener flattener(&writer);
	request->Flatten(&flattener);
	status = flattener.GetStatus();
	if (status != B_OK)
		RETURN_ERROR(status);

	// Push out whatever is still buffered.
	status = writer.Flush();
	RETURN_ERROR(status);
}
// Receives one request from the channel. It reads the RequestHeader, checks
// the announced body size against [0, kMaxSaneRequestSize], and creates a
// request of the announced type via RequestFactory. It then reads the body
// into a RequestBuffer attached to the request and unflattens it. The body
// must be consumed completely, otherwise B_BAD_DATA is returned.
// On success *_request receives the new request and the caller takes
// ownership. On any failure requestDeleter deletes the partially built
// request.
status_t
RequestChannel::ReceiveRequest(Request** _request)
{
	if (!_request)
		RETURN_ERROR(B_BAD_VALUE);

	RequestHeader header;
	status_t error = fChannel->Receive(&header, sizeof(RequestHeader));
	if (error != B_OK)
		RETURN_ERROR(error);

	// The header travels in big-endian order; decode it to host order.
	header.type = B_BENDIAN_TO_HOST_INT32(header.type);
	header.size = B_BENDIAN_TO_HOST_INT32(header.size);

	// Reject bogus sizes before allocating anything for the body.
	if (header.size < 0 || header.size > kMaxSaneRequestSize) {
		ERROR("RequestChannel::ReceiveRequest(): ERROR: Invalid request size: "
			"%" B_PRId32 "\n", header.size);
		RETURN_ERROR(B_BAD_DATA);
	}

	Request* request;
	error = RequestFactory::CreateRequest(header.type, &request);
	if (error != B_OK)
		RETURN_ERROR(error);
	ObjectDeleter<Request> requestDeleter(request);

	if (header.size > 0) {
		RequestBuffer* requestBuffer = RequestBuffer::Create(header.size);
		if (!requestBuffer)
			RETURN_ERROR(B_NO_MEMORY);
		// Attach before receiving so the buffer's lifetime is tied to the
		// request (presumably the request releases it; confirm in Request).
		request->AttachBuffer(requestBuffer);

		error = fChannel->Receive(requestBuffer->GetData(), header.size);
		if (error != B_OK)
			RETURN_ERROR(error);

		MemoryReader reader(requestBuffer->GetData(), header.size);
		RequestUnflattener unflattener(&reader);
		request->Unflatten(&unflattener);
		error = unflattener.GetStatus();
		if (error != B_OK)
			RETURN_ERROR(error);

		// Trailing, unconsumed bytes mean sender and receiver disagree on
		// the request layout.
		if (!reader.AllBytesRead())
			RETURN_ERROR(B_BAD_DATA);
	}

	requestDeleter.Detach();
	*_request = request;
	PRINT("%p->RequestChannel::ReceiveRequest(): request: %p, type: %s\n", this, request, typeid(*request).name());
	return B_OK;
}
// Determines how many bytes flattening request would produce. It does this by
// running a RequestFlattener over a DummyWriter via Request::ShowAround()
// (assumed to visit the same fields Flatten() writes; verify in Request).
// On success *size receives the byte count. On failure *size is left
// untouched and the flattener's error is returned.
status_t
RequestChannel::_GetRequestSize(Request* request, int32* size)
{
	DummyWriter sink;
	RequestFlattener sizer(&sink);
	request->ShowAround(&sizer);

	const status_t status = sizer.GetStatus();
	if (status == B_OK)
		*size = sizer.GetBytesWritten();
	return status;
}