aboutsummaryrefslogtreecommitdiff
path: root/comp3331/server/src/shared/net.cc
diff options
context:
space:
mode:
Diffstat (limited to 'comp3331/server/src/shared/net.cc')
-rw-r--r--comp3331/server/src/shared/net.cc284
1 files changed, 284 insertions, 0 deletions
diff --git a/comp3331/server/src/shared/net.cc b/comp3331/server/src/shared/net.cc
new file mode 100644
index 0000000..b71170a
--- /dev/null
+++ b/comp3331/server/src/shared/net.cc
@@ -0,0 +1,284 @@
+#include "shared/net.hh"
+
+namespace shared {
+
+static std::string get_errno_str() noexcept { return strerror(errno); }
+
+addrinfo_t make_addrinfo(const char* const address, const char* const port,
+ const addrinfo&& hints) {
+ addrinfo* info = nullptr;
+ if (int status = getaddrinfo(address, port, &hints, &info)) {
+ throw std::runtime_error{gai_strerror(status)};
+ }
+ return addrinfo_t{info, [](const auto& p) { freeaddrinfo(p); }};
+}
+
+socket_t make_socket(const addrinfo_t& info) {
+ const int sock =
+ socket(info->ai_family, info->ai_socktype, info->ai_protocol);
+ if (sock == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+
+ // to avoid annoying binding timeouts
+ const int enable = 1;
+ setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
+ setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
+ return sock;
+}
+
+socket_t accept_socket(const socket_t& lsock) {
+ sockaddr_in in{};
+ socklen_t size = sizeof(in);
+
+ const shared::socket_t sock =
+ accept(lsock, reinterpret_cast<sockaddr*>(&in), &size);
+ if (sock == -1) {
+ throw std::runtime_error(get_errno_str());
+ }
+ return sock;
+}
+
+void bind_socket(const socket_t sock, const addrinfo_t& info) {
+ // bind to first we can
+ for (const addrinfo* i = info.get(); i != nullptr; i = i->ai_next) {
+ if (bind(sock, info->ai_addr, info->ai_addrlen) == -1) {
+ continue;
+ }
+ return;
+ }
+ throw std::runtime_error{get_errno_str()};
+}
+
+void connect_socket(const socket_t sock, const addrinfo_t& info) {
+ if (connect(sock, info->ai_addr, info->ai_addrlen) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+void listen_socket(const socket_t sock) {
+ if (listen(sock, SOMAXCONN) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+void close_socket(const socket_t sock) {
+ if (close(sock) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+packet contents_to_packet(const contents_t& contents,
+ const char* const command) {
+ packet ret{};
+ std::memset(&ret.header, 0, sizeof(ret.header)); // 0init padding bytes
+
+ std::uint32_t size = 0;
+ for (const auto& keyvalue : contents) {
+ const auto& key = keyvalue.first; // no structured bindings :(
+ const auto& value = keyvalue.second;
+
+ const auto contents_size = static_cast<std::uint32_t>(value.size());
+
+ decltype(ret.contents) addition;
+ // Copy the size, name and data as expected, then push it.
+ std::copy(reinterpret_cast<const char* const>(&contents_size),
+ reinterpret_cast<const char* const>(&contents_size) +
+ sizeof(contents_size),
+ std::back_inserter(addition));
+ std::copy(key.c_str(), key.c_str() + key.size() + 1,
+ std::back_inserter(addition));
+ std::copy(value.c_str(), value.c_str() + value.size(),
+ std::back_inserter(addition));
+
+ size += addition.size(); // update size before moving
+ std::copy(std::make_move_iterator(std::begin(addition)),
+ std::make_move_iterator(std::end(addition)),
+ std::back_inserter(ret.contents));
+ }
+
+ // fill in expected header values
+ ret.header.size = size + sizeof(packet::header);
+ std::memcpy(ret.header.command, command, 3);
+
+ return ret;
+}
+
+contents_t packet_to_contents(const packet& packet) {
+ contents_t ret{};
+
+ // Extract the data as described in the header file.
+ for (auto i = 0u; i < packet.contents.size();) {
+ const char* const data = packet.contents.data() + i;
+
+ const std::uint32_t contents_size = [&]() {
+ std::uint32_t contents_size;
+ std::memcpy(&contents_size,
+ reinterpret_cast<const std::uint32_t* const>(data),
+ sizeof(std::uint32_t));
+ return contents_size;
+ }();
+
+ const auto size_size = sizeof(std::uint32_t);
+ const std::string name = data + size_size;
+ const char* const contents = data + size_size + name.length() + 1;
+
+ ret.emplace(name, std::string{contents, contents + contents_size});
+ i += size_size + name.size() + 1 + contents_size;
+ }
+ return ret;
+}
+
+static std::vector<char> packet_to_data(const packet& packet) noexcept {
+ std::vector<char> data{};
+ // data.reserve(sizeof(struct packet) + packet.contents.size());
+ std::copy(reinterpret_cast<const char* const>(&packet.header),
+ reinterpret_cast<const char* const>(&packet.header) +
+ sizeof(packet::header),
+ std::back_inserter(data));
+ std::copy(std::begin(packet.contents), std::end(packet.contents),
+ std::back_inserter(data));
+ return data;
+}
+
+void send_packet(const packet& packet, const socket_t& sock,
+ const sockaddr_in& dest) {
+ const std::vector<char> data = packet_to_data(packet);
+
+ if (sendto(sock, data.data(), data.size(), 0, (sockaddr*)&dest,
+ sizeof(sockaddr_in)) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+void send_packet(const packet& packet, const socket_t& sock) {
+ const std::vector<char> data = packet_to_data(packet);
+
+ const auto total_size = data.size();
+ for (unsigned sent = 0; sent < data.size();) {
+ const ssize_t res =
+ send(sock, data.data() + sent, total_size - sent, 0);
+ if (res != -1) {
+ sent += res;
+ continue;
+ }
+
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ continue;
+ }
+
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+static std::size_t get_backlog_size(const socket_t& socket) noexcept {
+ size_t size = 0;
+ if (ioctl(socket, FIONREAD, &size) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+ return size;
+}
+
+std::shared_ptr<recv_packet_ret> maybe_urecv_packet(const socket_t& sock) {
+ const auto packet_size = get_backlog_size(sock);
+ if (packet_size <= 0) {
+ return nullptr;
+ }
+
+ recv_packet_ret ret;
+ unsigned int origin_len = sizeof(ret.origin);
+
+ std::vector<char> buffer;
+ buffer.reserve(packet_size);
+ if (recvfrom(sock, buffer.data(), packet_size, 0, (sockaddr*)&ret.origin,
+ &origin_len) == -1) {
+ throw std::runtime_error(get_errno_str());
+ }
+
+ std::memcpy(&ret.packet.header, buffer.data(), sizeof(packet::header));
+ ret.packet.contents.reserve(ret.packet.header.size -
+ sizeof(packet::header));
+ std::copy(buffer.data() + sizeof(packet::header),
+ buffer.data() + ret.packet.header.size,
+ std::back_inserter(ret.packet.contents));
+
+ return std::make_shared<recv_packet_ret>(std::move(ret));
+}
+
+// true when finished reading as our stream packets may be very large
+static bool maybe_rrecv_packet(const socket_t& sock, packet& packet) {
+
+ auto backlog_size = get_backlog_size(sock);
+ if (backlog_size <= 0) {
+ return false;
+ }
+
+ auto& target_size = packet.header.size;
+ if (!target_size) { // header read required
+ if (backlog_size < sizeof(header)) { // no header, try again later
+ return false;
+ }
+
+ if (read(sock, &packet.header, sizeof(header)) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+ backlog_size -= sizeof(header);
+ }
+
+ const auto read_size =
+ std::min(backlog_size, static_cast<unsigned long>(target_size));
+ std::vector<char> buffer;
+ buffer.reserve(read_size);
+ if (read(sock, buffer.data(), read_size) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+
+ std::copy(buffer.data(), buffer.data() + read_size,
+ std::back_inserter(packet.contents));
+ if (packet.contents.size() < packet.header.size - sizeof(header)) {
+ return false; // more data to read, do again
+ }
+
+ return true;
+}
+
+static const auto ms_timeout = std::chrono::seconds(30);
+std::shared_ptr<recv_packet_ret> urecv_packet(const socket_t& rsock,
+ const bool& timeout) {
+ // poll our non-blocking sockets
+ const auto start = std::chrono::steady_clock::now();
+ while (!shared::should_exit) {
+ const auto pkt = maybe_urecv_packet(rsock);
+ if (pkt != nullptr) {
+ return pkt;
+ }
+ if (timeout && std::chrono::steady_clock::now() > start + ms_timeout) {
+ throw std::runtime_error("urecv timeout elapsed");
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+ throw should_exit_exception();
+}
+
+std::shared_ptr<packet> rrecv_packet(const socket_t& rsock,
+ const bool& timeout) {
+ packet packet{};
+ auto last = std::chrono::steady_clock::now();
+ while (!shared::should_exit) {
+ const auto prev_size = packet.contents.size();
+ if (maybe_rrecv_packet(rsock, packet)) {
+ return std::make_shared<struct packet>(std::move(packet));
+ }
+ if (prev_size != packet.contents.size()) {
+ last = std::chrono::steady_clock::now();
+ }
+ if (timeout && std::chrono::steady_clock::now() > last + ms_timeout) {
+ throw std::runtime_error("rrecv timeout elapsed");
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+ throw should_exit_exception();
+}
+
+} // namespace shared