diff options
| author | Nicolas James <Eele1Ephe7uZahRie@tutanota.com> | 2025-02-13 18:00:17 +1100 |
|---|---|---|
| committer | Nicolas James <Eele1Ephe7uZahRie@tutanota.com> | 2025-02-13 18:00:17 +1100 |
| commit | 98cef5e9a772602d42acfcf233838c760424db9a (patch) | |
| tree | 5277fa1d7cc0a69a0f166fcbf10fd320f345f049 /comp3331/server/src/shared/net.cc | |
initial commit
Diffstat (limited to 'comp3331/server/src/shared/net.cc')
| -rw-r--r-- | comp3331/server/src/shared/net.cc | 284 |
1 files changed, 284 insertions, 0 deletions
diff --git a/comp3331/server/src/shared/net.cc b/comp3331/server/src/shared/net.cc new file mode 100644 index 0000000..b71170a --- /dev/null +++ b/comp3331/server/src/shared/net.cc @@ -0,0 +1,284 @@ +#include "shared/net.hh" + +namespace shared { + +static std::string get_errno_str() noexcept { return strerror(errno); } + +addrinfo_t make_addrinfo(const char* const address, const char* const port, + const addrinfo&& hints) { + addrinfo* info = nullptr; + if (int status = getaddrinfo(address, port, &hints, &info)) { + throw std::runtime_error{gai_strerror(status)}; + } + return addrinfo_t{info, [](const auto& p) { freeaddrinfo(p); }}; +} + +socket_t make_socket(const addrinfo_t& info) { + const int sock = + socket(info->ai_family, info->ai_socktype, info->ai_protocol); + if (sock == -1) { + throw std::runtime_error{get_errno_str()}; + } + + // to avoid annoying binding timeouts + const int enable = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int)); + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); + return sock; +} + +socket_t accept_socket(const socket_t& lsock) { + sockaddr_in in{}; + socklen_t size = sizeof(in); + + const shared::socket_t sock = + accept(lsock, reinterpret_cast<sockaddr*>(&in), &size); + if (sock == -1) { + throw std::runtime_error(get_errno_str()); + } + return sock; +} + +void bind_socket(const socket_t sock, const addrinfo_t& info) { + // bind to first we can + for (const addrinfo* i = info.get(); i != nullptr; i = i->ai_next) { + if (bind(sock, info->ai_addr, info->ai_addrlen) == -1) { + continue; + } + return; + } + throw std::runtime_error{get_errno_str()}; +} + +void connect_socket(const socket_t sock, const addrinfo_t& info) { + if (connect(sock, info->ai_addr, info->ai_addrlen) == -1) { + throw std::runtime_error{get_errno_str()}; + } +} + +void listen_socket(const socket_t sock) { + if (listen(sock, SOMAXCONN) == -1) { + throw std::runtime_error{get_errno_str()}; + } +} + +void close_socket(const socket_t sock) { + if (close(sock) == -1) { + throw std::runtime_error{get_errno_str()}; + } +} + +packet contents_to_packet(const contents_t& contents, + const char* const command) { + packet ret{}; + std::memset(&ret.header, 0, sizeof(ret.header)); // 0init padding bytes + + std::uint32_t size = 0; + for (const auto& keyvalue : contents) { + const auto& key = keyvalue.first; // no structured bindings :( + const auto& value = keyvalue.second; + + const auto contents_size = static_cast<std::uint32_t>(value.size()); + + decltype(ret.contents) addition; + // Copy the size, name and data as expected, then push it. + std::copy(reinterpret_cast<const char* const>(&contents_size), + reinterpret_cast<const char* const>(&contents_size) + + sizeof(contents_size), + std::back_inserter(addition)); + std::copy(key.c_str(), key.c_str() + key.size() + 1, + std::back_inserter(addition)); + std::copy(value.c_str(), value.c_str() + value.size(), + std::back_inserter(addition)); + + size += addition.size(); // update size before moving + std::copy(std::make_move_iterator(std::begin(addition)), + std::make_move_iterator(std::end(addition)), + std::back_inserter(ret.contents)); + } + + // fill in expected header values + ret.header.size = size + sizeof(packet::header); + std::memcpy(ret.header.command, command, 3); + + return ret; +} + +contents_t packet_to_contents(const packet& packet) { + contents_t ret{}; + + // Extract the data as described in the header file. + for (auto i = 0u; i < packet.contents.size();) { + const char* const data = packet.contents.data() + i; + + const std::uint32_t contents_size = [&]() { + std::uint32_t contents_size; + std::memcpy(&contents_size, + reinterpret_cast<const std::uint32_t* const>(data), + sizeof(std::uint32_t)); + return contents_size; + }(); + + const auto size_size = sizeof(std::uint32_t); + const std::string name = data + size_size; + const char* const contents = data + size_size + name.length() + 1; + + ret.emplace(name, std::string{contents, contents + contents_size}); + i += size_size + name.size() + 1 + contents_size; + } + return ret; +} + +static std::vector<char> packet_to_data(const packet& packet) noexcept { + std::vector<char> data{}; + // data.reserve(sizeof(struct packet) + packet.contents.size()); + std::copy(reinterpret_cast<const char* const>(&packet.header), + reinterpret_cast<const char* const>(&packet.header) + + sizeof(packet::header), + std::back_inserter(data)); + std::copy(std::begin(packet.contents), std::end(packet.contents), + std::back_inserter(data)); + return data; +} + +void send_packet(const packet& packet, const socket_t& sock, + const sockaddr_in& dest) { + const std::vector<char> data = packet_to_data(packet); + + if (sendto(sock, data.data(), data.size(), 0, (sockaddr*)&dest, + sizeof(sockaddr_in)) == -1) { + throw std::runtime_error{get_errno_str()}; + } +} + +void send_packet(const packet& packet, const socket_t& sock) { + const std::vector<char> data = packet_to_data(packet); + + const auto total_size = data.size(); + for (unsigned sent = 0; sent < data.size();) { + const ssize_t res = + send(sock, data.data() + sent, total_size - sent, 0); + if (res != -1) { + sent += res; + continue; + } + + if (errno == EAGAIN || errno == EWOULDBLOCK) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + + throw std::runtime_error{get_errno_str()}; + } +} + +static std::size_t get_backlog_size(const socket_t& socket) noexcept { + size_t size = 0; + if (ioctl(socket, FIONREAD, &size) == -1) { + throw std::runtime_error{get_errno_str()}; + } + return size; +} + +std::shared_ptr<recv_packet_ret> maybe_urecv_packet(const socket_t& sock) { + const auto packet_size = get_backlog_size(sock); + if (packet_size <= 0) { + return nullptr; + } + + recv_packet_ret ret; + unsigned int origin_len = sizeof(ret.origin); + + std::vector<char> buffer; + buffer.reserve(packet_size); + if (recvfrom(sock, buffer.data(), packet_size, 0, (sockaddr*)&ret.origin, + &origin_len) == -1) { + throw std::runtime_error(get_errno_str()); + } + + std::memcpy(&ret.packet.header, buffer.data(), sizeof(packet::header)); + ret.packet.contents.reserve(ret.packet.header.size - + sizeof(packet::header)); + std::copy(buffer.data() + sizeof(packet::header), + buffer.data() + ret.packet.header.size, + std::back_inserter(ret.packet.contents)); + + return std::make_shared<recv_packet_ret>(std::move(ret)); +} + +// true when finished reading as our stream packets may be very large +static bool maybe_rrecv_packet(const socket_t& sock, packet& packet) { + + auto backlog_size = get_backlog_size(sock); + if (backlog_size <= 0) { + return false; + } + + auto& target_size = packet.header.size; + if (!target_size) { // header read required + if (backlog_size < sizeof(header)) { // no header, try again later + return false; + } + + if (read(sock, &packet.header, sizeof(header)) == -1) { + throw std::runtime_error{get_errno_str()}; + } + backlog_size -= sizeof(header); + } + + const auto read_size = + std::min(backlog_size, static_cast<unsigned long>(target_size)); + std::vector<char> buffer; + buffer.reserve(read_size); + if (read(sock, buffer.data(), read_size) == -1) { + throw std::runtime_error{get_errno_str()}; + } + + std::copy(buffer.data(), buffer.data() + read_size, + std::back_inserter(packet.contents)); + if (packet.contents.size() < packet.header.size - sizeof(header)) { + return false; // more data to read, do again + } + + return true; +} + +static const auto ms_timeout = std::chrono::seconds(30); +std::shared_ptr<recv_packet_ret> urecv_packet(const socket_t& rsock, + const bool& timeout) { + // poll our non-blocking sockets + const auto start = std::chrono::steady_clock::now(); + while (!shared::should_exit) { + const auto pkt = maybe_urecv_packet(rsock); + if (pkt != nullptr) { + return pkt; + } + if (timeout && std::chrono::steady_clock::now() > start + ms_timeout) { + throw std::runtime_error("urecv timeout elapsed"); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + throw should_exit_exception(); +} + +std::shared_ptr<packet> rrecv_packet(const socket_t& rsock, + const bool& timeout) { + packet packet{}; + auto last = std::chrono::steady_clock::now(); + while (!shared::should_exit) { + const auto prev_size = packet.contents.size(); + if (maybe_rrecv_packet(rsock, packet)) { + return std::make_shared<struct packet>(std::move(packet)); + } + if (prev_size != packet.contents.size()) { + last = std::chrono::steady_clock::now(); + } + if (timeout && std::chrono::steady_clock::now() > last + ms_timeout) { + throw std::runtime_error("rrecv timeout elapsed"); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + throw should_exit_exception(); +} + +} // namespace shared |
