/*
 * Copyright 2006-2010, Haiku, Inc. All Rights Reserved.
* Distributed under the terms of the MIT License.
*
* Authors:
* Axel Dörfler, axeld@pinc-software.de
*/
#include "BufferQueue.h"
#include <KernelExport.h>
#include <arpa/inet.h>
// Debug tracing: defining TRACE_BUFFER_QUEUE routes TRACE((format, ...))
// through dprintf(); otherwise TRACE() expands to nothing.
#if defined(TRACE_BUFFER_QUEUE)
#	define TRACE(x) dprintf x
#else
#	define TRACE(x)
#endif

// Consistency checking: when DEBUG_TCP_BUFFER_QUEUE is non-zero, VERIFY()
// runs Verify() on the queue; otherwise it is an empty statement.
#if DEBUG_TCP_BUFFER_QUEUE
#	define VERIFY() Verify();
#else
#	define VERIFY() ;
#endif
/*	Creates an empty queue. \a maxBytes is stored as the configured maximum
	size (it is not consulted in this file); byte counters and all sequence
	numbers start at 0.
*/
BufferQueue::BufferQueue(size_t maxBytes)
	: fMaxBytes(maxBytes), fNumBytes(0), fContiguousBytes(0),
	fFirstSequence(0), fLastSequence(0), fPushPointer(0)
{
	// All members are set up by the initializer list above.
}
/*	Frees every segment that is still queued.
*/
BufferQueue::~BufferQueue()
{
	for (net_buffer *segment = fList.RemoveHead(); segment != NULL;
			segment = fList.RemoveHead())
		gBufferModule->free(segment);
}
/*	Replaces the configured maximum queue size with \a maxBytes. Data that is
	already queued is left untouched.
*/
void
BufferQueue::SetMaxBytes(size_t maxBytes)
{
	fMaxBytes = maxBytes;
}
/*	Makes \a sequence both the first and the last sequence number of the
	queue, i.e. the queue covers an empty range starting at \a sequence.
*/
void
BufferQueue::SetInitialSequence(tcp_sequence sequence)
{
	TRACE(("BufferQueue@%p::SetInitialSequence(%" B_PRIu32 ")\n", this,
		sequence.Number()));

	fLastSequence = sequence;
	fFirstSequence = sequence;
}
/*	Queues \a buffer at fLastSequence, the end of the sequence range covered
	so far. See Add(net_buffer*, tcp_sequence) for the ownership rules.
*/
void
BufferQueue::Add(net_buffer *buffer)
{
	const tcp_sequence endOfQueue = fLastSequence;
	Add(buffer, endOfQueue);
}
/*	Inserts \a buffer, whose first byte has sequence number \a sequence, into
	the queue. The queue takes ownership of \a buffer: it is either linked
	into the segment list (possibly trimmed at its head and/or tail) or freed.

	Data before fFirstSequence, or data that is already queued, is discarded.
	Queued segments that the new one covers completely are freed and replaced.
	fNumBytes, fLastSequence and fContiguousBytes are kept up to date.
*/
void
BufferQueue::Add(net_buffer *buffer, tcp_sequence sequence)
{
	TRACE(("BufferQueue@%p::Add(buffer %p, size %" B_PRIu32 ", sequence %"
		B_PRIu32 ")\n", this, buffer, buffer->size, sequence.Number()));
	TRACE((" in: first: %" B_PRIu32 ", last: %" B_PRIu32 ", num: %"
		B_PRIuSIZE ", cont: %" B_PRIuSIZE "\n", fFirstSequence.Number(),
		fLastSequence.Number(), fNumBytes, fContiguousBytes));
	VERIFY();

	if (tcp_sequence(sequence + buffer->size) <= fFirstSequence
		|| buffer->size == 0) {
		// Nothing in this buffer is of interest anymore.
		gBufferModule->free(buffer);
		return;
	}

	if (sequence < fFirstSequence) {
		// The head of this buffer has already been consumed - trim it.
		gBufferModule->remove_header(buffer,
			(fFirstSequence - sequence).Number());
		sequence = fFirstSequence;
	}

	if (fList.IsEmpty() || sequence >= fLastSequence) {
		// Common case: the segment goes to the end of the queue.
		fList.Add(buffer);
		buffer->sequence = sequence.Number();

		if (sequence == fLastSequence
			&& fLastSequence - fFirstSequence == fNumBytes) {
			// The queue has no holes and this segment directly follows the
			// last one, so all of it becomes contiguous.
			fContiguousBytes += buffer->size;
		}

		fLastSequence = sequence + buffer->size;
		fNumBytes += buffer->size;

		TRACE((" out0: first: %" B_PRIu32 ", last: %" B_PRIu32 ", num: %"
			B_PRIuSIZE ", cont: %" B_PRIuSIZE "\n", fFirstSequence.Number(),
			fLastSequence.Number(), fNumBytes, fContiguousBytes));
		VERIFY();
		return;
	}

	if (fLastSequence < sequence + buffer->size)
		fLastSequence = sequence + buffer->size;

	// Find the insertion point: "previous" is the last segment starting at
	// or before "sequence", "next" the first segment starting after it.
	SegmentList::ReverseIterator iterator = fList.GetReverseIterator();
	net_buffer *previous = NULL;
	net_buffer *next = NULL;
	while ((previous = iterator.Next()) != NULL) {
		if (sequence >= previous->sequence)
			break;

		next = previous;
	}

	// Drop or trim data that "previous" already holds.
	if (previous != NULL) {
		if (sequence == previous->sequence) {
			// Data starting here is already queued: keep whichever of the
			// two segments is larger.
			if (previous->size >= buffer->size) {
				gBufferModule->free(buffer);
				buffer = NULL;
			} else {
				fList.Remove(previous);
				fNumBytes -= previous->size;
				gBufferModule->free(previous);
			}
		} else if (tcp_sequence(previous->sequence + previous->size)
				>= sequence + buffer->size) {
			// The new segment lies completely within "previous".
			gBufferModule->free(buffer);
			buffer = NULL;
		} else if (tcp_sequence(previous->sequence + previous->size)
				> sequence) {
			// The head of the new segment overlaps "previous" - trim it.
			gBufferModule->remove_header(buffer,
				(previous->sequence + previous->size - sequence).Number());
			sequence = previous->sequence + previous->size;
		}
	}

	// "next" always starts at or after the new segment.
	ASSERT(next == NULL || buffer == NULL || next->sequence >= sequence);

	// Drop or trim data that the following segments already hold.
	while (buffer != NULL && next != NULL
		&& tcp_sequence(sequence + buffer->size) > next->sequence) {
		if (tcp_sequence(next->sequence + next->size)
				<= sequence + buffer->size) {
			// "next" lies completely within the new segment - replace it.
			net_buffer *remove = next;
			next = (net_buffer *)next->link.next;

			fList.Remove(remove);
			fNumBytes -= remove->size;
			gBufferModule->free(remove);
		} else if (tcp_sequence(next->sequence) > sequence) {
			// The tail of the new segment is already queued - trim it.
			gBufferModule->remove_trailer(buffer,
				(sequence + buffer->size - next->sequence).Number());
		} else {
			// All of the new data is already queued.
			gBufferModule->free(buffer);
			buffer = NULL;
		}
	}

	if (buffer == NULL) {
		TRACE((" out1: first: %" B_PRIu32 ", last: %" B_PRIu32 ", num: %"
			B_PRIuSIZE ", cont: %" B_PRIuSIZE "\n", fFirstSequence.Number(),
			fLastSequence.Number(), fNumBytes, fContiguousBytes));
		VERIFY();
		return;
	}

	fList.InsertBefore(next, buffer);
	buffer->sequence = sequence.Number();
	fNumBytes += buffer->size;

	// Update the number of contiguous bytes at the front of the queue.
	if (fLastSequence - fFirstSequence == fNumBytes)
		fContiguousBytes = fNumBytes;
	else if (sequence <= fFirstSequence + fContiguousBytes) {
		// The new segment starts inside or right at the end of the
		// contiguous range. Everything before it is unchanged, and any
		// contiguous segments it replaced are covered by it, so the range
		// now reaches at least its end. Testing only for "starts exactly at
		// the end" would leave fContiguousBytes stale when the segment
		// replaced or trimmed against contiguous data.
		fContiguousBytes = (sequence + buffer->size - fFirstSequence).Number();

		// Then extend the range over segments that directly follow.
		for (net_buffer *following = (net_buffer *)buffer->link.next;
				following != NULL
					&& fFirstSequence + fContiguousBytes == following->sequence;
				following = (net_buffer *)following->link.next) {
			fContiguousBytes += following->size;
		}
	}

	TRACE((" out2: first: %" B_PRIu32 ", last: %" B_PRIu32 ", num: %"
		B_PRIuSIZE ", cont: %" B_PRIuSIZE "\n", fFirstSequence.Number(),
		fLastSequence.Number(), fNumBytes, fContiguousBytes));
	VERIFY();
}
/*!	Removes all data in the queue up to the \a sequence number as specified.

	NOTE: If there are missing segments in the buffers to be removed,
	fContiguousBytes is not maintained correctly!
*/
// Drops all queued data before \a sequence and moves fFirstSequence to the
// first remaining segment (or to fLastSequence when nothing remains).
// The segments being dropped are expected to follow each other without gaps
// (see the ASSERT); otherwise fContiguousBytes is decremented for bytes it
// never counted. Always returns B_OK.
status_t
BufferQueue::RemoveUntil(tcp_sequence sequence)
{
TRACE(("BufferQueue@%p::RemoveUntil(sequence %" B_PRIu32 ")\n", this,
sequence.Number()));
VERIFY();
// Everything before fFirstSequence is already gone - nothing to do.
if (sequence < fFirstSequence)
return B_OK;
SegmentList::Iterator iterator = fList.GetIterator();
// End of the data removed so far; only used to check for gaps.
tcp_sequence lastRemoved = fFirstSequence;
net_buffer *buffer = NULL;
while ((buffer = iterator.Next()) != NULL && buffer->sequence < sequence) {
ASSERT(lastRemoved == buffer->sequence);
if (sequence >= buffer->sequence + buffer->size) {
// The whole segment lies before "sequence" - drop it.
iterator.Remove();
fNumBytes -= buffer->size;
fContiguousBytes -= buffer->size;
lastRemoved = buffer->sequence + buffer->size;
gBufferModule->free(buffer);
} else {
// Only the head of this segment is obsolete - trim it and stop.
size_t size = (sequence - buffer->sequence).Number();
gBufferModule->remove_header(buffer, size);
buffer->sequence += size;
fNumBytes -= size;
fContiguousBytes -= size;
break;
}
}
// The queue now starts at its first remaining segment.
if (fList.IsEmpty())
fFirstSequence = fLastSequence;
else
fFirstSequence = fList.Head()->sequence;
VERIFY();
return B_OK;
}
/*!	Clones the requested data in the buffer queue into the provided \a buffer.
*/
/*	Appends clones of up to \a bytes queued bytes, starting at \a sequence,
	to \a buffer. The request is clipped at fLastSequence.

	Returns B_BAD_VALUE if \a sequence lies outside
	[fFirstSequence, fLastSequence), the error of a failing append_cloned()
	call, or B_OK.
*/
status_t
BufferQueue::Get(net_buffer *buffer, tcp_sequence sequence, size_t bytes)
{
	TRACE(("BufferQueue@%p::Get(sequence %" B_PRIu32 ", bytes %lu)\n", this,
		sequence.Number(), bytes));
	VERIFY();

	if (bytes == 0)
		return B_OK;

	// Refuse requests for data the queue does not cover.
	if (sequence < fFirstSequence || sequence >= fLastSequence)
		return B_BAD_VALUE;

	// Never read past the end of the queued data.
	if (tcp_sequence(sequence + bytes) > fLastSequence)
		bytes = (fLastSequence - sequence).Number();

	// Locate the first segment whose data extends past "sequence".
	SegmentList::Iterator iterator = fList.GetIterator();
	net_buffer *segment;
	for (segment = iterator.Next(); segment != NULL;
			segment = iterator.Next()) {
		if (sequence < segment->sequence + segment->size)
			break;
	}

	if (segment == NULL)
		panic("we should have had that data...");
	if (tcp_sequence(segment->sequence) > sequence) {
		panic("source %p, sequence = %" B_PRIu32 " (%" B_PRIu32 ")\n",
			segment, segment->sequence, sequence.Number());
	}

	// Clone from the located segment onwards; only the first one is
	// entered at an offset.
	size_t remaining = bytes;
	uint32 offset = (sequence - segment->sequence).Number();
	while (remaining > 0 && segment != NULL) {
		size_t chunk = min_c(segment->size - offset, remaining);
		status_t status = gBufferModule->append_cloned(buffer, segment,
			offset, chunk);
		if (status < B_OK)
			return status;

		remaining -= chunk;
		offset = 0;
		segment = iterator.Next();
	}

	VERIFY();
	return B_OK;
}
/*!	Returns a buffer containing up to \a bytes bytes from the start of the
	buffer queue. If \a remove is \c true, the data is removed from the
	queue, if not, the data is cloned from the queue.
*/
// Hands out up to \a bytes bytes from the front of the queue in *_buffer
// (NULL and B_OK when nothing is available). With \a remove the data is
// taken out of the queue, otherwise it is cloned. If an append fails after
// some data was gathered, the partial buffer is still returned with B_OK.
status_t
BufferQueue::Get(size_t bytes, bool remove, net_buffer **_buffer)
{
// Clamp the request to Available() (declared elsewhere; presumably the
// contiguous byte count - confirm in the header).
if (bytes > Available())
bytes = Available();
if (bytes == 0) {
*_buffer = NULL;
return B_OK;
}
net_buffer *buffer = fList.First();
size_t bytesLeft = bytes;
ASSERT(buffer != NULL);
if (!remove || buffer->size > bytes) {
// A fresh result buffer is needed: the data is only cloned, or the
// first segment holds more than was requested. (256 is passed to
// create(); presumably header space - verify.)
buffer = gBufferModule->create(256);
if (buffer == NULL)
return B_NO_MEMORY;
} else {
// The whole first segment is wanted and removed anyway - hand it
// out directly instead of cloning its data.
bytesLeft -= buffer->size;
fFirstSequence += buffer->size;
fList.Remove(buffer);
}
SegmentList::Iterator iterator = fList.GetIterator();
net_buffer *source = NULL;
status_t status = B_OK;
while (bytesLeft > 0 && (source = iterator.Next()) != NULL) {
size_t size = min_c(source->size, bytesLeft);
status = gBufferModule->append_cloned(buffer, source, 0, size);
if (status < B_OK)
break;
bytesLeft -= size;
if (!remove)
continue;
// Consume what was just copied from the queue.
fFirstSequence += size;
if (size == source->size) {
iterator.Remove();
gBufferModule->free(source);
} else {
gBufferModule->remove_header(source, size);
source->sequence += size;
}
}
// When removing, buffer->size is exactly the amount taken out of the
// queue (reused first segment plus everything appended).
if (remove && buffer->size) {
fNumBytes -= buffer->size;
fContiguousBytes -= buffer->size;
}
// Fail only if nothing at all could be gathered.
if (status < B_OK && buffer->size == 0) {
gBufferModule->free(buffer);
VERIFY();
return status;
}
*_buffer = buffer;
VERIFY();
return B_OK;
}
/*	Returns how many contiguous bytes are queued from \a sequence up to
	fFirstSequence + fContiguousBytes, or 0 if \a sequence lies beyond that
	point.
*/
size_t
BufferQueue::Available(tcp_sequence sequence) const
{
	tcp_sequence contiguousEnd = fFirstSequence + fContiguousBytes;
	if (sequence > contiguousEnd.Number())
		return 0;

	return (contiguousEnd - sequence).Number();
}
void
BufferQueue::SetPushPointer()
{
if (fList.IsEmpty())
fPushPointer = 0;
else
fPushPointer = fList.Tail()->sequence + fList.Tail()->size;
}
/*	Fills \a sacks with up to \a maxSackCount selective acknowledgment blocks
	describing queued data beyond \a sequence, starting with the highest
	sequence numbers. Segments that directly adjoin each other are merged
	into one block. The edges are stored in network byte order.

	Returns the number of blocks filled in. Only that many entries of
	\a sacks are written; the caller does not need to pre-clear the array.

	Adjacency is decided by exact equality of sequence numbers, which is
	wraparound safe. Raw "<" comparisons, or 0 as an "unset" marker, would
	break when sequence numbers wrap past 2^32.
*/
int
BufferQueue::PopulateSackInfo(tcp_sequence sequence, int maxSackCount,
	tcp_sack* sacks)
{
	TRACE(("BufferQueue::PopulateSackInfo() %" B_PRIu32 "\n",
		sequence.Number()));

	SegmentList::ReverseIterator iterator = fList.GetReverseIterator();
	int sackCount = 0;

	net_buffer* buffer;
	while ((buffer = iterator.Next()) != NULL
		&& tcp_sequence(buffer->sequence) > sequence) {
		uint32 end = buffer->sequence + buffer->size;

		if (sackCount > 0 && end == sacks[sackCount - 1].left_edge) {
			// This segment directly precedes the current block - extend it.
			sacks[sackCount - 1].left_edge = buffer->sequence;
			continue;
		}

		// First segment, or a gap before the current block: start a new
		// block if there is still room for one.
		if (sackCount >= maxSackCount)
			break;

		sacks[sackCount].left_edge = buffer->sequence;
		sacks[sackCount].right_edge = end;
		sackCount++;
	}

	for (int i = 0; i < sackCount; i++) {
		sacks[i].left_edge = htonl(sacks[i].left_edge);
		sacks[i].right_edge = htonl(sacks[i].right_edge);
	}

	return sackCount;
}
#if DEBUG_TCP_BUFFER_QUEUE
/*!	Performs a sanity check of the queue's internal bookkeeping.
*/
void
BufferQueue::Verify() const
{
ASSERT(Available() == 0 || fList.First() != NULL);
if (fList.First() == NULL) {
ASSERT(fNumBytes == 0);
return;
}
SegmentList::ConstIterator iterator = fList.GetIterator();
size_t numBytes = 0;
size_t contiguousBytes = 0;
bool contiguous = true;
tcp_sequence last = fFirstSequence;
while (net_buffer* buffer = iterator.Next()) {
if (contiguous && buffer->sequence == last)
contiguousBytes += buffer->size;
else
contiguous = false;
ASSERT(last <= buffer->sequence);
ASSERT(buffer->size > 0);
numBytes += buffer->size;
last = buffer->sequence + buffer->size;
}
ASSERT(last == fLastSequence);
ASSERT(contiguousBytes == fContiguousBytes);
ASSERT(numBytes == fNumBytes);
}
void
BufferQueue::Dump() const
{
SegmentList::ConstIterator iterator = fList.GetIterator();
int32 number = 0;
while (net_buffer* buffer = iterator.Next()) {
kprintf(" %" B_PRId32 ". buffer %p, sequence %" B_PRIu32
", size %" B_PRIu32 "\n", ++number, buffer, buffer->sequence,
buffer->size);
}
}
#endif