#include "shared/net.hh"

#include <netdb.h>
#include <netinet/in.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

namespace shared {

// Human-readable description of the current errno value.
static std::string get_errno_str() noexcept { return strerror(errno); }

// Resolves address/port with getaddrinfo and returns the result list owned by
// an addrinfo_t whose deleter calls freeaddrinfo.
// Throws std::runtime_error with the gai_strerror text on failure.
addrinfo_t make_addrinfo(const char* const address, const char* const port,
                         const addrinfo&& hints) {
    addrinfo* info = nullptr;
    if (int status = getaddrinfo(address, port, &hints, &info)) {
        throw std::runtime_error{gai_strerror(status)};
    }
    return addrinfo_t{info, [](const auto& p) { freeaddrinfo(p); }};
}

// Creates a socket matching the family/type/protocol of the FIRST entry in
// `info`. SO_REUSEPORT and SO_REUSEADDR are enabled on a best-effort basis
// (their return values are deliberately ignored).
// Throws std::runtime_error (errno text) if socket() fails.
socket_t make_socket(const addrinfo_t& info) {
    const int sock =
        socket(info->ai_family, info->ai_socktype, info->ai_protocol);
    if (sock == -1) {
        throw std::runtime_error{get_errno_str()};
    }

    // to avoid annoying binding timeouts
    const int enable = 1;
    setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
    setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
    return sock;
}

// Accepts one pending connection on the listening socket `lsock`; the peer
// address is discarded. Throws std::runtime_error (errno text) on failure.
socket_t accept_socket(const socket_t& lsock) {
    sockaddr_in in{};
    socklen_t size = sizeof(in);
    const shared::socket_t sock =
        accept(lsock, reinterpret_cast<sockaddr*>(&in), &size);
    if (sock == -1) {
        throw std::runtime_error(get_errno_str());
    }
    return sock;
}

// Binds `sock` to the first address in the `info` list that bind() accepts.
// Throws std::runtime_error with the errno text of the last failed attempt if
// none of them can be bound.
void bind_socket(const socket_t sock, const addrinfo_t& info) {
    // bind to first we can
    for (const addrinfo* i = info.get(); i != nullptr; i = i->ai_next) {
        // Use the entry being iterated, not always the list head.
        if (bind(sock, i->ai_addr, i->ai_addrlen) == 0) {
            return;
        }
    }
    throw std::runtime_error{get_errno_str()};
}

// Connects `sock` to the FIRST address in `info`.
// Throws std::runtime_error (errno text) on failure.
void connect_socket(const socket_t sock, const addrinfo_t& info) {
    if (connect(sock, info->ai_addr, info->ai_addrlen) == -1) {
        throw std::runtime_error{get_errno_str()};
    }
}

// Marks `sock` as listening with a SOMAXCONN backlog.
// Throws std::runtime_error (errno text) on failure.
void listen_socket(const socket_t sock) {
    if (listen(sock, SOMAXCONN) == -1) {
        throw std::runtime_error{get_errno_str()};
    }
}

// Closes `sock`. Throws std::runtime_error (errno text) on failure.
void close_socket(const socket_t sock) {
    if (close(sock) == -1) {
        throw std::runtime_error{get_errno_str()};
    }
}

// Serialises a key/value map into a packet. Each entry is encoded as:
//   [std::uint32_t value size, native byte order]
//   [key bytes + terminating NUL]
//   [value bytes, not terminated]
// header.size is set to the header size plus the encoded contents size, and
// the first 3 bytes of `command` are copied into header.command.
// Throws std::runtime_error if a value is too large for its 32-bit size field.
packet contents_to_packet(const contents_t& contents,
                          const char* const command) {
    packet ret{};
    std::memset(&ret.header, 0, sizeof(ret.header)); // 0init padding bytes

    auto& out = ret.contents;
    for (const auto& keyvalue : contents) {
        const auto& key = keyvalue.first; // no structured bindings :(
        const auto& value = keyvalue.second;

        // The wire format stores the size in 32 bits; refuse to silently
        // truncate it into a corrupt packet.
        if (value.size() > std::numeric_limits<std::uint32_t>::max()) {
            throw std::runtime_error{"packet value too large"};
        }
        const auto contents_size = static_cast<std::uint32_t>(value.size());
        const char* const size_bytes =
            reinterpret_cast<const char*>(&contents_size);

        // Append the size, name (with its NUL) and data directly.
        out.insert(out.end(), size_bytes, size_bytes + sizeof(contents_size));
        out.insert(out.end(), key.c_str(), key.c_str() + key.size() + 1);
        out.insert(out.end(), value.data(), value.data() + value.size());
    }

    // fill in expected header values
    ret.header.size = out.size() + sizeof(packet::header);
    std::memcpy(ret.header.command, command, 3);
    return ret;
}

// Inverse of contents_to_packet: decodes the [size][name\0][value] entries in
// packet.contents back into a key/value map. When a name repeats, whatever
// contents_t::emplace does with duplicates applies.
// Because contents usually arrive from the network, every field is
// bounds-checked; throws std::runtime_error on truncated or malformed data
// instead of reading past the end of the buffer.
contents_t packet_to_contents(const packet& packet) {
    contents_t ret{};
    const char* const base = packet.contents.data();
    const std::size_t total = packet.contents.size();
    constexpr std::size_t size_size = sizeof(std::uint32_t);

    std::size_t i = 0;
    while (i < total) {
        if (total - i < size_size) {
            throw std::runtime_error{"malformed packet: truncated size field"};
        }
        std::uint32_t contents_size;
        std::memcpy(&contents_size, base + i, size_size);
        i += size_size;

        // The name runs up to the next NUL, which must lie inside the buffer.
        const char* const name_begin = base + i;
        const auto* const name_end = static_cast<const char*>(
            std::memchr(name_begin, '\0', total - i));
        if (name_end == nullptr) {
            throw std::runtime_error{"malformed packet: unterminated name"};
        }
        std::string name{name_begin, name_end};
        i += name.size() + 1;

        if (total - i < contents_size) {
            throw std::runtime_error{"malformed packet: truncated value"};
        }
        const char* const value_begin = base + i;
        ret.emplace(std::move(name),
                    std::string{value_begin, value_begin + contents_size});
        i += contents_size;
    }
    return ret;
}

// Flattens a packet for the wire: the raw header bytes followed by contents.
static std::vector<char> packet_to_data(const packet& packet) {
    std::vector<char> data{};
    data.reserve(sizeof(packet::header) + packet.contents.size());
    const char* const header_bytes =
        reinterpret_cast<const char*>(&packet.header);
    data.insert(data.end(), header_bytes,
                header_bytes + sizeof(packet::header));
    data.insert(data.end(), packet.contents.begin(), packet.contents.end());
    return data;
}

// Sends `packet` as a single datagram to `dest` via sendto().
// Throws std::runtime_error (errno text) on failure.
void send_packet(const packet& packet, const socket_t& sock,
                 const sockaddr_in& dest) {
    const std::vector<char> data = packet_to_data(packet);
    if (sendto(sock, data.data(), data.size(), 0,
               reinterpret_cast<const sockaddr*>(&dest),
               sizeof(sockaddr_in)) == -1) {
        throw std::runtime_error{get_errno_str()};
    }
}

// Sends `packet` over a connected stream socket, looping until every byte is
// written. EINTR is retried immediately; EAGAIN/EWOULDBLOCK is retried after
// sleeping 1ms. Any other error throws std::runtime_error (errno text).
void send_packet(const packet& packet, const socket_t& sock) {
    const std::vector<char> data = packet_to_data(packet);
    std::size_t sent = 0;
    while (sent < data.size()) {
        const ssize_t res =
            send(sock, data.data() + sent, data.size() - sent, 0);
        if (res != -1) {
            sent += static_cast<std::size_t>(res);
            continue;
        }
        if (errno == EINTR) {
            continue; // interrupted by a signal before sending anything
        }
        if (errno == EAGAIN || errno == EWOULDBLOCK) {
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
            continue;
        }
        throw std::runtime_error{get_errno_str()};
    }
}

// Number of bytes ioctl(FIONREAD) reports as readable on `socket`.
// Throws std::runtime_error (errno text) if the ioctl fails.
static std::size_t get_backlog_size(const socket_t& socket) {
    int size = 0; // FIONREAD stores an int
    if (ioctl(socket, FIONREAD, &size) == -1) {
        throw std::runtime_error{get_errno_str()};
    }
    return size > 0 ? static_cast<std::size_t>(size) : 0;
}

// Non-blocking receive of one datagram packet. Returns nullptr when FIONREAD
// reports nothing readable; otherwise reads up to that many bytes with
// recvfrom(), recording the sender in ret.origin.
// Throws std::runtime_error if recvfrom fails, if fewer bytes than a header
// arrive, or if header.size is smaller than the header or larger than the
// bytes actually received.
std::shared_ptr<recv_packet_ret> maybe_urecv_packet(const socket_t& sock) {
    const std::size_t packet_size = get_backlog_size(sock);
    if (packet_size == 0) {
        return nullptr;
    }

    recv_packet_ret ret;
    socklen_t origin_len = sizeof(ret.origin);
    // Sized (not merely reserved) so recvfrom writes into valid elements.
    std::vector<char> buffer(packet_size);
    const ssize_t received =
        recvfrom(sock, buffer.data(), buffer.size(), 0,
                 reinterpret_cast<sockaddr*>(&ret.origin), &origin_len);
    if (received == -1) {
        throw std::runtime_error(get_errno_str());
    }
    const auto received_size = static_cast<std::size_t>(received);

    if (received_size < sizeof(packet::header)) {
        throw std::runtime_error{"received truncated packet header"};
    }
    std::memcpy(&ret.packet.header, buffer.data(), sizeof(packet::header));

    const auto claimed_size = static_cast<std::size_t>(ret.packet.header.size);
    if (claimed_size < sizeof(packet::header) || claimed_size > received_size) {
        throw std::runtime_error{"received malformed packet header"};
    }
    ret.packet.contents.assign(buffer.data() + sizeof(packet::header),
                               buffer.data() + claimed_size);
    return std::make_shared<recv_packet_ret>(std::move(ret));
}

// Incrementally reads one stream packet into `packet`, using only the bytes
// FIONREAD currently reports. `packet` must start value-initialised
// (header.size == 0) and be passed again on every call for the same packet.
// Returns true once the header and all header.size - sizeof(header) content
// bytes have been read, false if more data is still needed. It never reads
// past the end of the current packet, so bytes of a following packet stay
// queued in the socket.
// Throws std::runtime_error on a read failure, a short header read, or a
// header whose size field is smaller than the header itself.
static bool maybe_rrecv_packet(const socket_t& sock, packet& packet) {
    const std::size_t header_size = sizeof(packet.header);
    std::size_t backlog_size = get_backlog_size(sock);
    if (backlog_size == 0) {
        return false;
    }

    auto& target_size = packet.header.size;
    if (!target_size) {
        // header read required
        if (backlog_size < header_size) {
            // no header, try again later
            return false;
        }
        const ssize_t res = read(sock, &packet.header, header_size);
        if (res == -1) {
            throw std::runtime_error{get_errno_str()};
        }
        if (static_cast<std::size_t>(res) != header_size) {
            throw std::runtime_error{"short read of packet header"};
        }
        if (static_cast<std::size_t>(target_size) < header_size) {
            throw std::runtime_error{"received malformed packet header"};
        }
        backlog_size -= header_size;
    }

    const std::size_t contents_target =
        static_cast<std::size_t>(target_size) - header_size;
    // Only what this packet still lacks, so the next packet is left unread.
    const std::size_t remaining = contents_target - packet.contents.size();
    const std::size_t read_size = std::min(backlog_size, remaining);
    if (read_size > 0) {
        std::vector<char> buffer(read_size);
        const ssize_t res = read(sock, buffer.data(), read_size);
        if (res == -1) {
            throw std::runtime_error{get_errno_str()};
        }
        // Append only the bytes read() actually delivered.
        packet.contents.insert(packet.contents.end(), buffer.data(),
                               buffer.data() + res);
    }

    // false: more data to read, do again
    return packet.contents.size() == contents_target;
}

// How long a receive may go without progress before a timeout is raised.
static constexpr std::chrono::seconds recv_timeout{30};

// Polls the non-blocking datagram socket every 1ms until a packet arrives.
// If `timeout` is set, throws std::runtime_error after recv_timeout without a
// packet. Throws should_exit_exception once shared::should_exit is set.
std::shared_ptr<recv_packet_ret> urecv_packet(const socket_t& rsock,
                                              const bool& timeout) {
    // poll our non-blocking sockets
    const auto start = std::chrono::steady_clock::now();
    while (!shared::should_exit) {
        const auto pkt = maybe_urecv_packet(rsock);
        if (pkt != nullptr) {
            return pkt;
        }
        if (timeout &&
            std::chrono::steady_clock::now() > start + recv_timeout) {
            throw std::runtime_error("urecv timeout elapsed");
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    throw should_exit_exception();
}

// Polls the non-blocking stream socket every 1ms until a full packet has been
// read. If `timeout` is set, throws std::runtime_error when recv_timeout
// passes without any new content bytes arriving (the timer restarts whenever
// the contents size changes). Throws should_exit_exception once
// shared::should_exit is set.
std::shared_ptr<packet> rrecv_packet(const socket_t& rsock,
                                     const bool& timeout) {
    packet packet{};
    auto last = std::chrono::steady_clock::now();
    while (!shared::should_exit) {
        const auto prev_size = packet.contents.size();
        if (maybe_rrecv_packet(rsock, packet)) {
            // `packet` names the local variable here; qualify the type.
            return std::make_shared<shared::packet>(std::move(packet));
        }
        if (prev_size != packet.contents.size()) {
            last = std::chrono::steady_clock::now();
        }
        if (timeout &&
            std::chrono::steady_clock::now() > last + recv_timeout) {
            throw std::runtime_error("rrecv timeout elapsed");
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    throw should_exit_exception();
}

} // namespace shared