⛏️ index : haiku.git

/*
 * Copyright 2022 Haiku Inc. All rights reserved.
 * Distributed under the terms of the MIT License.
 *
 * Authors:
 *		Niels Sascha Reedijk, niels.reedijk@gmail.com
 */

#include <HttpRequest.h>

#include <algorithm>
#include <ctype.h>
#include <sstream>
#include <utility>

#include <DataIO.h>
#include <HttpFields.h>
#include <MimeType.h>
#include <NetServicesDefs.h>
#include <Url.h>

#include "HttpBuffer.h"
#include "HttpPrivate.h"

using namespace std::literals;
using namespace BPrivate::Network;


// #pragma mark -- BHttpMethod::InvalidMethod


BHttpMethod::InvalidMethod::InvalidMethod(const char* origin, BString input)
	:
	BError(origin),
	input(std::move(input))
{
}


/*!
	\brief Short explanation of why the method was rejected.

	An empty stored input means no method was given at all. Any other stored input was
	rejected because of its characters.
*/
const char*
BHttpMethod::InvalidMethod::Message() const noexcept
{
	return input.IsEmpty()
		? "The HTTP method cannot be empty"
		: "Unsupported characters in the HTTP method";
}


/*!
	\brief Extended message: the base debug message, followed by the rejected input
		when there is one.
*/
BString
BHttpMethod::InvalidMethod::DebugMessage() const
{
	BString message = BError::DebugMessage();
	if (input.IsEmpty())
		return message;

	message << ":\t " << input << "\n";
	return message;
}


// #pragma mark -- BHttpMethod


BHttpMethod::BHttpMethod(Verb verb) noexcept
	:
	fMethod(verb)
{
}


/*!
	\brief Construct a method from a custom method string.

	The string is copied. It must be non-empty and accepted by
	validate_http_token_string().

	\exception BHttpMethod::InvalidMethod The string is empty or fails the token
		validation. The copied string is handed to the exception.
*/
BHttpMethod::BHttpMethod(const std::string_view& verb)
	:
	fMethod(BString(verb.data(), verb.length()))
{
	const bool isValidToken = !verb.empty() && validate_http_token_string(verb);
	if (!isValidToken) {
		throw BHttpMethod::InvalidMethod(
			__PRETTY_FUNCTION__, std::move(std::get<BString>(fMethod)));
	}
}


// Member-wise copy: the stored verb or custom method string is copied.
BHttpMethod::BHttpMethod(const BHttpMethod& other) = default;


/*!
	\brief Move constructor. The source is left holding the GET verb.
*/
BHttpMethod::BHttpMethod(BHttpMethod&& other) noexcept
	:
	fMethod(std::exchange(other.fMethod, Get))
{
}


// Defaulted: destroying the stored method is all that is needed.
BHttpMethod::~BHttpMethod() = default;


// Member-wise copy assignment of the stored verb or custom method string.
BHttpMethod& BHttpMethod::operator=(const BHttpMethod& other) = default;


/*!
	\brief Move assignment. The source is left holding the GET verb.
*/
BHttpMethod&
BHttpMethod::operator=(BHttpMethod&& other) noexcept
{
	auto& source = other.fMethod;
	fMethod = std::move(source);
	// Reset the source to a well-defined default rather than a moved-from string.
	source = Get;
	return *this;
}


/*!
	\brief Compare this method with one of the predefined verbs.

	A method holding a verb compares the verbs directly. A method holding a custom string
	is equal to \a other only if the string is exactly the name that Method() returns for
	that verb. The comparison is case-sensitive and covers the whole string.
*/
bool
BHttpMethod::operator==(const BHttpMethod::Verb& other) const noexcept
{
	if (std::holds_alternative<Verb>(fMethod))
		return std::get<Verb>(fMethod) == other;

	// Compare complete views. BString::Compare(const char*, int32) only compares the first
	// `length` characters, so it treated e.g. a custom "GETX" method as equal to Get.
	const BString& methodString = std::get<BString>(fMethod);
	const std::string_view method(methodString.String());
	return method == BHttpMethod(other).Method();
}


/*!
	\brief Negation of operator==(const Verb&).
*/
bool
BHttpMethod::operator!=(const BHttpMethod::Verb& other) const noexcept
{
	return !(*this == other);
}


/*!
	\brief Textual form of the method, as written on the request line.

	Predefined verbs map to string literals. Custom methods return a view of the stored
	string, so that view is only valid while this object is alive and unchanged.
*/
const std::string_view
BHttpMethod::Method() const noexcept
{
	if (const auto* customMethod = std::get_if<BString>(&fMethod)) {
		// the following constructor is not noexcept, but we know we pass in valid data
		return std::string_view(customMethod->String());
	}

	switch (std::get<Verb>(fMethod)) {
		case Get:
			return "GET"sv;
		case Head:
			return "HEAD"sv;
		case Post:
			return "POST"sv;
		case Put:
			return "PUT"sv;
		case Delete:
			return "DELETE"sv;
		case Connect:
			return "CONNECT"sv;
		case Options:
			return "OPTIONS"sv;
		case Trace:
			return "TRACE"sv;
		default:
			// should never be reached
			std::abort();
	}
}


// #pragma mark -- BHttpRequest::Data
// Values returned by the BHttpRequest getters when a request holds no data (fData is null,
// for example after the request has been moved from). kDefaultUrl and kDefaultMethod also
// provide the initial url and method of new request data.
static const BUrl kDefaultUrl = BUrl();
static const BHttpMethod kDefaultMethod = BHttpMethod::Get;
static const BHttpFields kDefaultOptionalFields = BHttpFields();

// Storage for all settings of a BHttpRequest, owned through fData. The constructors
// allocate it, and the setters re-create it when fData is null.
struct BHttpRequest::Data {
	// Request target; SetUrl() only stores valid http/https URLs.
	BUrl url = kDefaultUrl;
	BHttpMethod method = kDefaultMethod;
	// Reported by MaxRedirections(); presumably enforced by the session (not in this file).
	uint8 maxRedirections = 8;
	// User supplied fields, serialized after the generated ones in SerializeHeaderTo().
	BHttpFields optionalFields;
	// When set, SerializeHeaderTo() adds a Basic Authorization field built from it.
	std::optional<BHttpAuthentication> authentication;
	// Reported by StopOnError(); not interpreted in this file.
	bool stopOnError = false;
	// Reported by Timeout(); not interpreted in this file.
	bigtime_t timeout = B_INFINITE_TIMEOUT;
	// Optional body. SerializeHeaderTo() requires its size to be known.
	std::optional<Body> requestBody;
};


// #pragma mark -- BHttpRequest helper functions


/*!
	\brief Build the value of a Basic authorization header.

	The credentials are joined as "username:password", encoded with encode_to_base64()
	and prefixed with "Basic ".
*/
static inline BString
build_basic_http_header(const BString& username, const BString& password)
{
	BString credentials;
	credentials << username << ":" << password;

	BString header("Basic ");
	header << encode_to_base64(credentials);
	return header;
}


// #pragma mark -- BHttpRequest


BHttpRequest::BHttpRequest()
	:
	fData(std::make_unique<Data>())
{
}


/*!
	\brief Create a request with default settings for \a url.

	\exception BInvalidUrl \a url is not valid (see SetUrl()).
	\exception BUnsupportedProtocol \a url is not an http or https URL (see SetUrl()).
*/
BHttpRequest::BHttpRequest(const BUrl& url)
	:
	BHttpRequest()
{
	SetUrl(url);
}


// Takes over the request data. The moved-from request is left without data (fData is
// null), so its getters return the defaults and IsEmpty() is true.
BHttpRequest::BHttpRequest(BHttpRequest&& other) noexcept = default;


// Defaulted: releases the request data, including any request body input it owns.
BHttpRequest::~BHttpRequest() = default;


// Takes over the request data of the source; the source is left without data.
BHttpRequest& BHttpRequest::operator=(BHttpRequest&&) noexcept = default;


/*!
	\brief A request is empty when it holds no data or its URL is not valid.
*/
bool
BHttpRequest::IsEmpty() const noexcept
{
	if (fData == nullptr)
		return true;
	return !fData->url.IsValid();
}


/*!
	\brief The configured credentials, or \c nullptr when none are set.
*/
const BHttpAuthentication*
BHttpRequest::Authentication() const noexcept
{
	if (fData == nullptr || !fData->authentication.has_value())
		return nullptr;
	return std::addressof(*fData->authentication);
}


/*!
	\brief The user supplied fields. When the request holds no data, this is a shared
		empty set.
*/
const BHttpFields&
BHttpRequest::Fields() const noexcept
{
	return fData != nullptr ? fData->optionalFields : kDefaultOptionalFields;
}


/*!
	\brief The configured maximum number of redirections (8 when the request holds no data).
*/
uint8
BHttpRequest::MaxRedirections() const noexcept
{
	constexpr uint8 kFallbackMaxRedirections = 8;
	return fData != nullptr ? fData->maxRedirections : kFallbackMaxRedirections;
}


/*!
	\brief The request method (GET when the request holds no data).
*/
const BHttpMethod&
BHttpRequest::Method() const noexcept
{
	return fData != nullptr ? fData->method : kDefaultMethod;
}


/*!
	\brief The request body, or \c nullptr when no body is set.
*/
const BHttpRequest::Body*
BHttpRequest::RequestBody() const noexcept
{
	if (fData == nullptr || !fData->requestBody.has_value())
		return nullptr;
	return std::addressof(*fData->requestBody);
}


/*!
	\brief The stop-on-error flag (false when the request holds no data).
*/
bool
BHttpRequest::StopOnError() const noexcept
{
	return fData != nullptr && fData->stopOnError;
}


/*!
	\brief The configured timeout (B_INFINITE_TIMEOUT when the request holds no data).
*/
bigtime_t
BHttpRequest::Timeout() const noexcept
{
	return fData != nullptr ? fData->timeout : B_INFINITE_TIMEOUT;
}


/*!
	\brief The request URL (a default-constructed BUrl when the request holds no data).
*/
const BUrl&
BHttpRequest::Url() const noexcept
{
	return fData != nullptr ? fData->url : kDefaultUrl;
}


/*!
	\brief Store a copy of \a authentication. It is sent as a Basic Authorization field.
*/
void
BHttpRequest::SetAuthentication(const BHttpAuthentication& authentication)
{
	if (fData == nullptr)
		fData = std::make_unique<Data>();

	fData->authentication.emplace(authentication);
}


// Fields that BHttpRequest::SerializeHeaderTo() generates itself. Users may therefore not
// set them through SetFields(). The array size must match the number of names; the old
// size of 6 silently added an empty entry.
static constexpr std::array<std::string_view, 5> fReservedOptionalFieldNames
	= {"Host"sv, "Accept-Encoding"sv, "Connection"sv, "Content-Type"sv, "Content-Length"sv};


void
BHttpRequest::SetFields(const BHttpFields& fields)
{
	if (!fData)
		fData = std::make_unique<Data>();

	for (auto& field: fields) {
		if (std::find(fReservedOptionalFieldNames.begin(), fReservedOptionalFieldNames.end(),
				field.Name())
			!= fReservedOptionalFieldNames.end()) {
			std::string_view fieldName = field.Name();
			throw BHttpFields::InvalidInput(
				__PRETTY_FUNCTION__, BString(fieldName.data(), fieldName.size()));
		}
	}
	fData->optionalFields = fields;
}


/*!
	\brief Set the maximum number of redirections.
*/
void
BHttpRequest::SetMaxRedirections(uint8 maxRedirections)
{
	if (fData == nullptr)
		fData = std::make_unique<Data>();

	fData->maxRedirections = maxRedirections;
}


/*!
	\brief Set the request method; a copy of \a method is stored.
*/
void
BHttpRequest::SetMethod(const BHttpMethod& method)
{
	if (fData == nullptr)
		fData = std::make_unique<Data>();

	fData->method = method;
}


void
BHttpRequest::SetRequestBody(
	std::unique_ptr<BDataIO> input, BString mimeType, std::optional<off_t> size)
{
	if (input == nullptr)
		throw std::invalid_argument("input cannot be null");

	// TODO: support optional mimetype arguments like type/subtype;parameter=value
	if (!BMimeType::IsValid(mimeType.String()))
		throw std::invalid_argument("mimeType must be a valid mimetype");

	// TODO: review if there should be complex validation between the method and whether or not
	// there is a request body. The current implementation does the validation at the request
	// generation stage, where GET, HEAD, OPTIONS, CONNECT and TRACE will not submit a body.

	if (!fData)
		fData = std::make_unique<Data>();
	fData->requestBody = {std::move(input), std::move(mimeType), size};

	// Check if the input is a BPositionIO, and if so, store the current position, so that it can
	// be rewinded in case of a redirect.
	auto inputPositionIO = dynamic_cast<BPositionIO*>(fData->requestBody->input.get());
	if (inputPositionIO != nullptr)
		fData->requestBody->startPosition = inputPositionIO->Position();
}


/*!
	\brief Set the stop-on-error flag.
*/
void
BHttpRequest::SetStopOnError(bool stopOnError)
{
	if (fData == nullptr)
		fData = std::make_unique<Data>();

	fData->stopOnError = stopOnError;
}


/*!
	\brief Set the request timeout.
*/
void
BHttpRequest::SetTimeout(bigtime_t timeout)
{
	if (fData == nullptr)
		fData = std::make_unique<Data>();

	fData->timeout = timeout;
}


void
BHttpRequest::SetUrl(const BUrl& url)
{
	if (!fData)
		fData = std::make_unique<Data>();

	if (!url.IsValid())
		throw BInvalidUrl(__PRETTY_FUNCTION__, BUrl(url));
	if (url.Protocol() != "http" && url.Protocol() != "https") {
		// TODO: optimize BStringList with modern language features
		BStringList list;
		list.Add("http");
		list.Add("https");
		throw BUnsupportedProtocol(__PRETTY_FUNCTION__, BUrl(url), list);
	}
	fData->url = url;
}


/*!
	\brief Remove any configured credentials. Does nothing if the request holds no data.
*/
void
BHttpRequest::ClearAuthentication() noexcept
{
	if (fData == nullptr)
		return;
	fData->authentication.reset();
}


/*!
	\brief Remove the request body and hand its input back to the caller.

	\return The input of the removed body, or \c nullptr if there was no body.
*/
std::unique_ptr<BDataIO>
BHttpRequest::ClearRequestBody() noexcept
{
	if (!fData || !fData->requestBody)
		return nullptr;

	std::unique_ptr<BDataIO> releasedInput = std::move(fData->requestBody->input);
	fData->requestBody.reset();
	return releasedInput;
}


/*!
	\brief Render the request header, as SerializeHeaderTo() produces it, into a BString.
*/
BString
BHttpRequest::HeaderToString() const
{
	HttpBuffer headerBuffer;
	SerializeHeaderTo(headerBuffer);

	// Copy out the bytes that were written and not yet consumed.
	return BString(
		static_cast<const char*>(headerBuffer.Data().data()), headerBuffer.RemainingBytes());
}


/*!
	\brief Private method used by BHttpSession::Request to rewind the content in case of redirect

	Seeks the body input back to the position recorded by SetRequestBody(). A position is
	only recorded for inputs that are a BPositionIO.

	\retval true The content was rewound, or there was nothing recorded to rewind to
		(no body, or a body whose input is not a BPositionIO).
	\retval false Seeking back to the recorded position failed.
*/
bool
BHttpRequest::RewindBody() noexcept
{
	if (!fData || !fData->requestBody || !fData->requestBody->startPosition)
		return true;

	auto& body = *fData->requestBody;
	const auto startPosition = *body.startPosition;
	// A start position is only recorded when this cast succeeded in SetRequestBody().
	auto seekableInput = dynamic_cast<BPositionIO*>(body.input.get());
	return seekableInput->Seek(startPosition, SEEK_SET) == startPosition;
}


/*!
	\brief Private method used by HttpSerializer::SetTo() to serialize the header data into a
		buffer.

	Writes the request line, then the generated fields, then the user supplied fields,
	then the terminating empty line. All request data is read through the public
	accessors. They fall back to the defaults when the request holds no data (e.g. after
	it was moved from); dereferencing fData directly was undefined behavior in that case,
	reachable through HeaderToString().

	\exception BRuntimeError The request has a body whose size is unknown; chunked transfer
		encoding is not supported.
*/
void
BHttpRequest::SerializeHeaderTo(HttpBuffer& buffer) const
{
	const BUrl& url = Url();

	// Method & URL
	//	TODO: proxy
	buffer << Method().Method() << " "sv;
	if (url.HasPath() && url.Path().Length() > 0)
		buffer << std::string_view(url.Path().String());
	else
		buffer << "/"sv;

	if (url.HasRequest())
		buffer << "?"sv << std::string_view(url.Request().String());

	// TODO: switch between HTTP 1.0 and 1.1 based on configuration
	buffer << " HTTP/1.1\r\n"sv;

	BHttpFields outputFields;
	if (true /* http == 1.1 */) {
		BString host = url.Host();
		int defaultPort = url.Protocol() == "http" ? 80 : 443;
		if (url.HasPort() && url.Port() != defaultPort)
			host << ':' << url.Port();

		outputFields.AddFields({
			{"Host"sv, std::string_view(host.String())}, {"Accept-Encoding"sv, "gzip"sv},
			// Allows the server to compress data using the "gzip" format.
			// "deflate" is not supported, because there are two interpretations
			// of what it means (the RFC and Microsoft products), and we don't
			// want to handle this. Very few websites support only deflate,
			// and most of them will send gzip, or at worst, uncompressed data.
			{"Connection"sv, "close"sv}
			// Let the remote server close the connection after response since
			// we don't handle multiple request on a single connection
		});
	}

	if (const BHttpAuthentication* authentication = Authentication()) {
		// This request will add a Basic authorization header
		BString authorization
			= build_basic_http_header(authentication->username, authentication->password);
		outputFields.AddField("Authorization"sv, std::string_view(authorization.String()));
	}

	if (const Body* body = RequestBody()) {
		outputFields.AddField("Content-Type"sv, std::string_view(body->mimeType.String()));
		if (body->size)
			outputFields.AddField("Content-Length"sv, std::to_string(*body->size));
		else
			throw BRuntimeError(__PRETTY_FUNCTION__,
				"Transfer body with unknown content length; chunked transfer not supported");
	}

	// Generated fields first, then the user supplied ones.
	for (const auto& field: outputFields)
		buffer << field.RawField() << "\r\n"sv;

	for (const auto& field: Fields())
		buffer << field.RawField() << "\r\n"sv;

	// An empty line terminates the header.
	buffer << "\r\n"sv;
}