Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add policies #15

Merged
merged 8 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add skeleton stream policy
  • Loading branch information
KredeGC committed Dec 14, 2023
commit 19f92572bb7f79c67c270dab8dfbd7fa93e21c1d
128 changes: 51 additions & 77 deletions include/bitstream/stream/bit_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "byte_buffer.h"
#include "serialize_traits.h"
#include "stream_traits.h"

#include <cstdint>
#include <cstring>
Expand All @@ -25,55 +26,26 @@ namespace bitstream
static constexpr bool reading = false;

/**
* @brief Default construct a writer pointing to a null buffer
* @brief Construct a writer with the parameters passed to the underlying policy
* @param ...args The arguments to pass to the policy
*/
bit_writer() noexcept :
m_Buffer(nullptr),
m_NumBitsWritten(0),
m_TotalBits(0),
m_Scratch(0),
m_ScratchBits(0),
m_WordIndex(0) {}

/**
* @brief Construct a writer pointing to the given byte array with @p num_bytes size
* @param bytes The byte array to write to. Must be 4-byte aligned and the size must be a multiple of 4
* @param num_bytes The number of bytes in the array
*/
explicit bit_writer(void* bytes, uint32_t num_bytes) noexcept :
m_Buffer(static_cast<uint32_t*>(bytes)),
m_NumBitsWritten(0),
m_TotalBits(num_bytes * 8),
m_Scratch(0),
m_ScratchBits(0),
m_WordIndex(0) {}

/**
* @brief Construct a writer pointing to the given @p buffer
* @param buffer The buffer to write to
*/
template<size_t Size>
explicit bit_writer(byte_buffer<Size>& buffer) noexcept :
m_Buffer(reinterpret_cast<uint32_t*>(buffer.Bytes)),
m_NumBitsWritten(0),
m_TotalBits(Size * 8),
template<typename... Ts,
typename = std::enable_if_t<std::is_constructible_v<fixed_policy, Ts...>>>
bit_writer(Ts&&... args)
noexcept(std::is_nothrow_constructible_v<fixed_policy, Ts...>) :
m_Policy(std::forward<Ts>(args) ...),
m_Scratch(0),
m_ScratchBits(0),
m_WordIndex(0) {}

bit_writer(const bit_writer&) = delete;

bit_writer(bit_writer&& other) noexcept :
m_Buffer(other.m_Buffer),
m_NumBitsWritten(other.m_NumBitsWritten),
m_TotalBits(other.m_TotalBits),
m_Policy(std::move(other.m_Policy)),
m_Scratch(other.m_Scratch),
m_ScratchBits(other.m_ScratchBits),
m_WordIndex(other.m_WordIndex)
{
other.m_Buffer = nullptr;
other.m_NumBitsWritten = 0;
other.m_TotalBits = 0;
other.m_Scratch = 0;
other.m_ScratchBits = 0;
other.m_WordIndex = 0;
Expand All @@ -83,16 +55,11 @@ namespace bitstream

bit_writer& operator=(bit_writer&& rhs) noexcept
{
m_Buffer = rhs.m_Buffer;
m_NumBitsWritten = rhs.m_NumBitsWritten;
m_TotalBits = rhs.m_TotalBits;
m_Policy = std::move(rhs.m_Policy);
m_Scratch = rhs.m_Scratch;
m_ScratchBits = rhs.m_ScratchBits;
m_WordIndex = rhs.m_WordIndex;

rhs.m_Buffer = nullptr;
rhs.m_NumBitsWritten = 0;
rhs.m_TotalBits = 0;
rhs.m_Scratch = 0;
rhs.m_ScratchBits = 0;
rhs.m_WordIndex = 0;
Expand All @@ -104,39 +71,40 @@ namespace bitstream
* @brief Returns the buffer that this writer is currently serializing into
* @return The buffer
*/
[[nodiscard]] uint8_t* get_buffer() const noexcept { return reinterpret_cast<uint8_t*>(m_Buffer); }
[[nodiscard]] uint8_t* get_buffer() const noexcept { return reinterpret_cast<uint8_t*>(m_Policy.get_buffer()); }

/**
* @brief Returns the number of bits which have been written to the buffer
* @return The number of bits which have been written
*/
[[nodiscard]] uint32_t get_num_bits_serialized() const noexcept { return m_NumBitsWritten; }
[[nodiscard]] uint32_t get_num_bits_serialized() const noexcept { return m_Policy.get_num_bits_serialized(); }

/**
* @brief Returns the number of bytes which have been written to the buffer
* @return The number of bytes which have been written
*/
[[nodiscard]] uint32_t get_num_bytes_serialized() const noexcept { return m_NumBitsWritten > 0U ? ((m_NumBitsWritten - 1U) / 8U + 1U) : 0U; }
[[nodiscard]] uint32_t get_num_bytes_serialized() const noexcept { return get_num_bits_serialized() > 0U ? ((get_num_bits_serialized() - 1U) / 8U + 1U) : 0U; }

/**
* @brief Returns whether the @p num_bits can fit in the buffer
* @param num_bits The number of bits to test
* @return Whether the number of bits can fit in the buffer
*/
[[nodiscard]] bool can_serialize_bits(uint32_t num_bits) const noexcept { return m_NumBitsWritten + num_bits <= m_TotalBits; }
[[nodiscard]] bool can_serialize_bits(uint32_t num_bits) const noexcept { return m_Policy.can_serialize_bits(num_bits); }

/**
* @brief Returns the number of bits which have not been written yet
* @note The same as get_total_bits() - get_num_bits_serialized()
* @return The remaining space in the buffer
*/
[[nodiscard]] uint32_t get_remaining_bits() const noexcept { return m_TotalBits - m_NumBitsWritten; }
//[[nodiscard]] uint32_t get_remaining_bits() const noexcept { return m_TotalBits - m_NumBitsWritten; }

// TODO: Use SFINAE to deduce whether m_Policy has get_total_bits()
/**
* @brief Returns the size of the buffer, in bits
* @return The size of the buffer, in bits
*/
[[nodiscard]] uint32_t get_total_bits() const noexcept { return m_TotalBits; }
//[[nodiscard]] uint32_t get_total_bits() const noexcept { return m_TotalBits; }

/**
* @brief Flushes any remaining bits into the buffer. Use this when you no longer intend to write anything to the buffer.
Expand All @@ -146,7 +114,7 @@ namespace bitstream
{
if (m_ScratchBits > 0U)
{
uint32_t* ptr = m_Buffer + m_WordIndex;
uint32_t* ptr = m_Policy.get_buffer() + m_WordIndex;
uint32_t ptr_value = static_cast<uint32_t>(m_Scratch >> 32U);
*ptr = utility::to_big_endian32(ptr_value);

Expand All @@ -155,7 +123,7 @@ namespace bitstream
m_WordIndex++;
}

return m_NumBitsWritten;
return get_num_bits_serialized();
}

/**
Expand All @@ -164,13 +132,12 @@ namespace bitstream
*/
[[nodiscard]] bool prepend_checksum() noexcept
{
BS_ASSERT(m_NumBitsWritten == 0);
BS_ASSERT(get_num_bits_serialized() == 0);

BS_ASSERT(can_serialize_bits(32U));
BS_ASSERT(m_Policy.extend(32U));

// Advance the reader by the size of the checksum (32 bits / 1 word)
m_WordIndex++;
m_NumBitsWritten += 32U;

return true;
}
Expand All @@ -184,14 +151,17 @@ namespace bitstream
{
uint32_t num_bits = flush();

BS_ASSERT(num_bits > 32U);

// Copy protocol version to buffer
*m_Buffer = protocol_version;
uint32_t* buffer = m_Policy.get_buffer();
*buffer = protocol_version;

// Generate checksum of version + data
uint32_t checksum = utility::crc_uint32(reinterpret_cast<uint8_t*>(m_Buffer), get_num_bytes_serialized());
uint32_t checksum = utility::crc_uint32(reinterpret_cast<uint8_t*>(buffer), get_num_bytes_serialized());

// Put checksum at beginning
*m_Buffer = checksum;
*buffer = checksum;

return num_bits;
}
Expand All @@ -203,30 +173,33 @@ namespace bitstream
*/
[[nodiscard]] bool pad_to_size(uint32_t num_bytes) noexcept
{
BS_ASSERT(num_bytes * 8U <= m_TotalBits);
uint32_t num_bits_written = get_num_bits_serialized();

BS_ASSERT(num_bytes * 8U >= m_NumBitsWritten);

if (m_NumBitsWritten == 0)
BS_ASSERT(num_bytes * 8U >= num_bits_written);

BS_ASSERT(can_serialize_bits(num_bytes * 8U - num_bits_written));

if (num_bits_written == 0)
{
std::memset(m_Buffer, 0, num_bytes);
BS_ASSERT(m_Policy.extend(num_bytes * 8U - num_bits_written));

std::memset(m_Policy.get_buffer(), 0, num_bytes);

m_NumBitsWritten = num_bytes * 8;
m_Scratch = 0;
m_ScratchBits = 0;
m_WordIndex = num_bytes / 4;

return true;
}

uint32_t remainder = (num_bytes * 8U - m_NumBitsWritten) % 32U;
uint32_t remainder = (num_bytes * 8U - num_bits_written) % 32U;
uint32_t zero = 0;

// Align to byte
if (remainder != 0U)
BS_ASSERT(serialize_bits(zero, remainder));

uint32_t offset = m_NumBitsWritten / 32;
uint32_t offset = get_num_bits_serialized() / 32;
uint32_t max = num_bytes / 4;

// Serialize words
Expand Down Expand Up @@ -258,7 +231,7 @@ namespace bitstream
uint32_t zero = 0U;
bool status = serialize_bits(zero, 8U - remainder);

BS_ASSERT(status && m_NumBitsWritten % 8U == 0U);
BS_ASSERT(status && get_num_bits_serialized() % 8U == 0U);
}
return true;
}
Expand All @@ -273,16 +246,15 @@ namespace bitstream
{
BS_ASSERT(num_bits > 0U && num_bits <= 32U);

BS_ASSERT(can_serialize_bits(num_bits));
BS_ASSERT(m_Policy.extend(num_bits));

// Fast path
if (num_bits == 32U && m_ScratchBits == 0U)
{
uint32_t* ptr = m_Buffer + m_WordIndex;
uint32_t* ptr = m_Policy.get_buffer() + m_WordIndex;

*ptr = utility::to_big_endian32(value);

m_NumBitsWritten += num_bits;
m_WordIndex++;

return true;
Expand All @@ -293,11 +265,10 @@ namespace bitstream

m_Scratch |= ls_value;
m_ScratchBits += num_bits;
m_NumBitsWritten += num_bits;

if (m_ScratchBits >= 32U)
{
uint32_t* ptr = m_Buffer + m_WordIndex;
uint32_t* ptr = m_Policy.get_buffer() + m_WordIndex;
uint32_t ptr_value = static_cast<uint32_t>(m_Scratch >> 32U);
*ptr = utility::to_big_endian32(ptr_value);
m_Scratch <<= 32ULL;
Expand Down Expand Up @@ -326,10 +297,11 @@ namespace bitstream

if (m_ScratchBits % 32U == 0U && num_words > 0U)
{
BS_ASSERT(m_Policy.extend(num_words * 32U));

// If the written buffer is word-aligned, just memcpy it
std::memcpy(m_Buffer + m_WordIndex, word_buffer, num_words * 4U);
std::memcpy(m_Policy.get_buffer() + m_WordIndex, word_buffer, num_words * 4U);

m_NumBitsWritten += num_words * 32U;
m_WordIndex += num_words;
}
else
Expand Down Expand Up @@ -367,7 +339,7 @@ namespace bitstream
*/
[[nodiscard]] bool serialize_into(bit_writer& writer) const noexcept
{
uint8_t* buffer = reinterpret_cast<uint8_t*>(m_Buffer);
uint8_t* buffer = reinterpret_cast<uint8_t*>(m_Policy.get_buffer());
uint32_t num_bits = get_num_bits_serialized();
uint32_t remainder_bits = num_bits % 8U;

Expand Down Expand Up @@ -413,9 +385,11 @@ namespace bitstream
}

private:
uint32_t* m_Buffer;
uint32_t m_NumBitsWritten;
uint32_t m_TotalBits;
fixed_policy m_Policy;

//uint32_t* m_Buffer;
//uint32_t m_NumBitsWritten;
//uint32_t m_TotalBits;

uint64_t m_Scratch;
uint32_t m_ScratchBits;
Expand Down
75 changes: 75 additions & 0 deletions include/bitstream/stream/stream_traits.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#pragma once

#include "byte_buffer.h"

#include <cstddef>
#include <cstdint>
#include <type_traits>

namespace bitstream
{
struct fixed_policy
{
using void_ptr = std::conditional_t<true, void*, const void*>;
using buffer_ptr = std::conditional_t<true, uint32_t*, const uint32_t*>;

/**
* @brief Construct a stream pointing to the given byte array with @p num_bytes size
* @param bytes The byte array to serialize to/from. Must be 4-byte aligned and the size must be a multiple of 4
* @param num_bytes The number of bytes in the array
*/
fixed_policy(void_ptr buffer, uint32_t num_bits) noexcept :
m_Buffer(static_cast<buffer_ptr>(buffer)),
m_NumBitsSerialized(0),
m_TotalBits(num_bits) {}

/**
* @brief Construct a stream pointing to the given @p buffer
* @param buffer The buffer to serialize to/from
* @param num_bits The maximum number of bits that we can read
*/
template<size_t Size>
fixed_policy(byte_buffer<Size>& buffer, uint32_t num_bits) noexcept :
m_Buffer(reinterpret_cast<buffer_ptr>(buffer.Bytes)),
m_NumBitsSerialized(0),
m_TotalBits(num_bits) {}

/**
* @brief Construct a stream pointing to the given @p buffer
* @param buffer The buffer to serialize to/from
*/
template<size_t Size>
fixed_policy(byte_buffer<Size>& buffer) noexcept :
m_Buffer(reinterpret_cast<buffer_ptr>(buffer.Bytes)),
m_NumBitsSerialized(0),
m_TotalBits(Size * 8) {}

buffer_ptr get_buffer() const noexcept { return m_Buffer; }

// TODO: Transition to size_t
uint32_t get_num_bits_serialized() const noexcept { return m_NumBitsSerialized; }

bool can_serialize_bits(uint32_t num_bits) const noexcept { return m_NumBitsSerialized + num_bits <= m_TotalBits; }

bool extend(uint32_t num_bits)
{
bool status = can_serialize_bits(num_bits);
m_NumBitsSerialized += num_bits;
return status;
}

buffer_ptr m_Buffer;
uint32_t m_NumBitsSerialized;
uint32_t m_TotalBits;
};

template<typename T>
struct growing_policy
{
bool can_serialize_bits(uint32_t bits_written, uint32_t num_bits) const noexcept { return true; }

bool extend(uint32_t bits_written, uint32_t num_bits) { return ; }

T Buffer;
};
}