aboutsummaryrefslogtreecommitdiff
path: root/comp3331/server/src/shared
diff options
context:
space:
mode:
Diffstat (limited to 'comp3331/server/src/shared')
-rw-r--r--comp3331/server/src/shared/connection.cc54
-rw-r--r--comp3331/server/src/shared/connection.hh60
-rw-r--r--comp3331/server/src/shared/net.cc284
-rw-r--r--comp3331/server/src/shared/net.hh87
-rw-r--r--comp3331/server/src/shared/shared.cc25
-rw-r--r--comp3331/server/src/shared/shared.hh28
6 files changed, 538 insertions, 0 deletions
diff --git a/comp3331/server/src/shared/connection.cc b/comp3331/server/src/shared/connection.cc
new file mode 100644
index 0000000..1063823
--- /dev/null
+++ b/comp3331/server/src/shared/connection.cc
@@ -0,0 +1,54 @@
+#include "shared/connection.hh"
+
+namespace shared {
+
+void connection::send_packet(packet&& packet) noexcept {
+ packet.header.sequence = this->seq_num;
+ shared::send_packet(packet, this->sock, this->info);
+ ++this->seq_num;
+
+ std::lock_guard<std::mutex> guard{*this->lock};
+ this->sent.push_back(packet);
+}
+
+bool connection::should_discard_packet(const packet& packet) noexcept {
+ std::lock_guard<std::mutex> guard{*this->lock};
+
+ // ack case
+ if (packet.header.command[0] == '\0') {
+ this->sent.erase(std::remove_if(std::begin(this->sent),
+ std::end(this->sent),
+ [&](const auto& p) {
+ return p.header.sequence <=
+ packet.header.sequence;
+ }),
+ std::end(this->sent));
+ return true;
+ }
+
+ // Send an ack for the packet if it's not an ack itself
+ auto ack_pkt = shared::contents_to_packet({}, "\0\0\0");
+ ack_pkt.header.sequence = packet.header.sequence;
+ shared::send_packet(ack_pkt, this->sock, this->info);
+
+ if (packet.header.sequence != this->ack_num) {
+ return true;
+ }
+ ++this->ack_num;
+ return false;
+}
+
+// Reliable transport reads our packets on a different thread and resends if
+// necessary.
+void connection::do_reliable_transport() noexcept {
+ while (!*this->should_thread_exit) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+ std::lock_guard<std::mutex> guard{*this->lock};
+ for (const auto& packet : this->sent) {
+ shared::send_packet(packet, this->sock, this->info);
+ }
+ }
+}
+
+} // namespace shared
diff --git a/comp3331/server/src/shared/connection.hh b/comp3331/server/src/shared/connection.hh
new file mode 100644
index 0000000..d7237cb
--- /dev/null
+++ b/comp3331/server/src/shared/connection.hh
@@ -0,0 +1,60 @@
+#ifndef SHARED_CONNECTION_HH_
+#define SHARED_CONNECTION_HH_
+
+#include <atomic>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include "shared/net.hh"
+
+namespace shared {
+
+// The connection class abstracts sending and receiving data, including reliable
+// transmission over UDP.
+class connection {
+private:
+ shared::socket_t sock;
+ sockaddr_in info;
+
+private:
+ std::uint32_t seq_num = 0; // track packet sequence number
+ std::uint32_t ack_num = 0;
+
+ // for reliable transport, spawn a new thread which reads sent/received
+ std::unique_ptr<std::atomic<bool>> should_thread_exit;
+ std::unique_ptr<std::mutex> lock;
+ std::vector<packet> sent;
+ std::vector<packet> received;
+ std::shared_ptr<std::thread> reliable_transport_thread;
+ void do_reliable_transport() noexcept;
+
+public:
+ connection(const socket_t& sock, sockaddr_in&& info)
+ : sock(sock), info(std::move(info)),
+ should_thread_exit(std::make_unique<std::atomic<bool>>(false)),
+ lock(std::make_unique<std::mutex>()),
+ reliable_transport_thread(std::make_shared<std::thread>(
+ &connection::do_reliable_transport, this)) {}
+
+ connection(const connection&) = delete;
+ connection(connection&&) = default;
+ ~connection() noexcept {
+ *this->should_thread_exit = true;
+ this->reliable_transport_thread->join();
+ }
+
+public:
+ const sockaddr_in& get_info() const noexcept { return this->info; }
+ const socket_t& get_socket() const noexcept { return this->sock; }
+
+public:
+ // All unreliable packets should go through these functions so we may track
+ // if our packets have been sent or received, making them reliable.
+ void send_packet(packet&& packet) noexcept;
+ bool should_discard_packet(const packet& packet) noexcept;
+};
+
+} // namespace shared
+
+#endif
diff --git a/comp3331/server/src/shared/net.cc b/comp3331/server/src/shared/net.cc
new file mode 100644
index 0000000..b71170a
--- /dev/null
+++ b/comp3331/server/src/shared/net.cc
@@ -0,0 +1,284 @@
+#include "shared/net.hh"
+
+namespace shared {
+
+static std::string get_errno_str() noexcept { return strerror(errno); }
+
+addrinfo_t make_addrinfo(const char* const address, const char* const port,
+ const addrinfo&& hints) {
+ addrinfo* info = nullptr;
+ if (int status = getaddrinfo(address, port, &hints, &info)) {
+ throw std::runtime_error{gai_strerror(status)};
+ }
+ return addrinfo_t{info, [](const auto& p) { freeaddrinfo(p); }};
+}
+
+socket_t make_socket(const addrinfo_t& info) {
+ const int sock =
+ socket(info->ai_family, info->ai_socktype, info->ai_protocol);
+ if (sock == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+
+ // to avoid annoying binding timeouts
+ const int enable = 1;
+ setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
+ setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
+ return sock;
+}
+
+socket_t accept_socket(const socket_t& lsock) {
+ sockaddr_in in{};
+ socklen_t size = sizeof(in);
+
+ const shared::socket_t sock =
+ accept(lsock, reinterpret_cast<sockaddr*>(&in), &size);
+ if (sock == -1) {
+ throw std::runtime_error(get_errno_str());
+ }
+ return sock;
+}
+
+void bind_socket(const socket_t sock, const addrinfo_t& info) {
+ // bind to first we can
+ for (const addrinfo* i = info.get(); i != nullptr; i = i->ai_next) {
+ if (bind(sock, info->ai_addr, info->ai_addrlen) == -1) {
+ continue;
+ }
+ return;
+ }
+ throw std::runtime_error{get_errno_str()};
+}
+
+void connect_socket(const socket_t sock, const addrinfo_t& info) {
+ if (connect(sock, info->ai_addr, info->ai_addrlen) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+void listen_socket(const socket_t sock) {
+ if (listen(sock, SOMAXCONN) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+void close_socket(const socket_t sock) {
+ if (close(sock) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+packet contents_to_packet(const contents_t& contents,
+ const char* const command) {
+ packet ret{};
+ std::memset(&ret.header, 0, sizeof(ret.header)); // 0init padding bytes
+
+ std::uint32_t size = 0;
+ for (const auto& keyvalue : contents) {
+ const auto& key = keyvalue.first; // no structured bindings :(
+ const auto& value = keyvalue.second;
+
+ const auto contents_size = static_cast<std::uint32_t>(value.size());
+
+ decltype(ret.contents) addition;
+ // Copy the size, name and data as expected, then push it.
+ std::copy(reinterpret_cast<const char* const>(&contents_size),
+ reinterpret_cast<const char* const>(&contents_size) +
+ sizeof(contents_size),
+ std::back_inserter(addition));
+ std::copy(key.c_str(), key.c_str() + key.size() + 1,
+ std::back_inserter(addition));
+ std::copy(value.c_str(), value.c_str() + value.size(),
+ std::back_inserter(addition));
+
+ size += addition.size(); // update size before moving
+ std::copy(std::make_move_iterator(std::begin(addition)),
+ std::make_move_iterator(std::end(addition)),
+ std::back_inserter(ret.contents));
+ }
+
+ // fill in expected header values
+ ret.header.size = size + sizeof(packet::header);
+ std::memcpy(ret.header.command, command, 3);
+
+ return ret;
+}
+
+contents_t packet_to_contents(const packet& packet) {
+ contents_t ret{};
+
+ // Extract the data as described in the header file.
+ for (auto i = 0u; i < packet.contents.size();) {
+ const char* const data = packet.contents.data() + i;
+
+ const std::uint32_t contents_size = [&]() {
+ std::uint32_t contents_size;
+ std::memcpy(&contents_size,
+ reinterpret_cast<const std::uint32_t* const>(data),
+ sizeof(std::uint32_t));
+ return contents_size;
+ }();
+
+ const auto size_size = sizeof(std::uint32_t);
+ const std::string name = data + size_size;
+ const char* const contents = data + size_size + name.length() + 1;
+
+ ret.emplace(name, std::string{contents, contents + contents_size});
+ i += size_size + name.size() + 1 + contents_size;
+ }
+ return ret;
+}
+
+static std::vector<char> packet_to_data(const packet& packet) noexcept {
+ std::vector<char> data{};
+ // data.reserve(sizeof(struct packet) + packet.contents.size());
+ std::copy(reinterpret_cast<const char* const>(&packet.header),
+ reinterpret_cast<const char* const>(&packet.header) +
+ sizeof(packet::header),
+ std::back_inserter(data));
+ std::copy(std::begin(packet.contents), std::end(packet.contents),
+ std::back_inserter(data));
+ return data;
+}
+
+void send_packet(const packet& packet, const socket_t& sock,
+ const sockaddr_in& dest) {
+ const std::vector<char> data = packet_to_data(packet);
+
+ if (sendto(sock, data.data(), data.size(), 0, (sockaddr*)&dest,
+ sizeof(sockaddr_in)) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+void send_packet(const packet& packet, const socket_t& sock) {
+ const std::vector<char> data = packet_to_data(packet);
+
+ const auto total_size = data.size();
+ for (unsigned sent = 0; sent < data.size();) {
+ const ssize_t res =
+ send(sock, data.data() + sent, total_size - sent, 0);
+ if (res != -1) {
+ sent += res;
+ continue;
+ }
+
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ continue;
+ }
+
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+static std::size_t get_backlog_size(const socket_t& socket) noexcept {
+ size_t size = 0;
+ if (ioctl(socket, FIONREAD, &size) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+ return size;
+}
+
+std::shared_ptr<recv_packet_ret> maybe_urecv_packet(const socket_t& sock) {
+ const auto packet_size = get_backlog_size(sock);
+ if (packet_size <= 0) {
+ return nullptr;
+ }
+
+ recv_packet_ret ret;
+ unsigned int origin_len = sizeof(ret.origin);
+
+ std::vector<char> buffer;
+ buffer.reserve(packet_size);
+ if (recvfrom(sock, buffer.data(), packet_size, 0, (sockaddr*)&ret.origin,
+ &origin_len) == -1) {
+ throw std::runtime_error(get_errno_str());
+ }
+
+ std::memcpy(&ret.packet.header, buffer.data(), sizeof(packet::header));
+ ret.packet.contents.reserve(ret.packet.header.size -
+ sizeof(packet::header));
+ std::copy(buffer.data() + sizeof(packet::header),
+ buffer.data() + ret.packet.header.size,
+ std::back_inserter(ret.packet.contents));
+
+ return std::make_shared<recv_packet_ret>(std::move(ret));
+}
+
+// true when finished reading as our stream packets may be very large
+static bool maybe_rrecv_packet(const socket_t& sock, packet& packet) {
+
+ auto backlog_size = get_backlog_size(sock);
+ if (backlog_size <= 0) {
+ return false;
+ }
+
+ auto& target_size = packet.header.size;
+ if (!target_size) { // header read required
+ if (backlog_size < sizeof(header)) { // no header, try again later
+ return false;
+ }
+
+ if (read(sock, &packet.header, sizeof(header)) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+ backlog_size -= sizeof(header);
+ }
+
+ const auto read_size =
+ std::min(backlog_size, static_cast<unsigned long>(target_size));
+ std::vector<char> buffer;
+ buffer.reserve(read_size);
+ if (read(sock, buffer.data(), read_size) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+
+ std::copy(buffer.data(), buffer.data() + read_size,
+ std::back_inserter(packet.contents));
+ if (packet.contents.size() < packet.header.size - sizeof(header)) {
+ return false; // more data to read, do again
+ }
+
+ return true;
+}
+
+static const auto ms_timeout = std::chrono::seconds(30);
+std::shared_ptr<recv_packet_ret> urecv_packet(const socket_t& rsock,
+ const bool& timeout) {
+ // poll our non-blocking sockets
+ const auto start = std::chrono::steady_clock::now();
+ while (!shared::should_exit) {
+ const auto pkt = maybe_urecv_packet(rsock);
+ if (pkt != nullptr) {
+ return pkt;
+ }
+ if (timeout && std::chrono::steady_clock::now() > start + ms_timeout) {
+ throw std::runtime_error("urecv timeout elapsed");
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+ throw should_exit_exception();
+}
+
+std::shared_ptr<packet> rrecv_packet(const socket_t& rsock,
+ const bool& timeout) {
+ packet packet{};
+ auto last = std::chrono::steady_clock::now();
+ while (!shared::should_exit) {
+ const auto prev_size = packet.contents.size();
+ if (maybe_rrecv_packet(rsock, packet)) {
+ return std::make_shared<struct packet>(std::move(packet));
+ }
+ if (prev_size != packet.contents.size()) {
+ last = std::chrono::steady_clock::now();
+ }
+ if (timeout && std::chrono::steady_clock::now() > last + ms_timeout) {
+ throw std::runtime_error("rrecv timeout elapsed");
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+ throw should_exit_exception();
+}
+
+} // namespace shared
diff --git a/comp3331/server/src/shared/net.hh b/comp3331/server/src/shared/net.hh
new file mode 100644
index 0000000..b17ca1f
--- /dev/null
+++ b/comp3331/server/src/shared/net.hh
@@ -0,0 +1,87 @@
+#ifndef SHARED_NET_HH_
+#define SHARED_NET_HH_
+
+#include <algorithm>
+#include <arpa/inet.h>
+#include <chrono>
+#include <cstdint>
+#include <cstring>
+#include <errno.h>
+#include <fcntl.h>
+#include <memory>
+#include <netdb.h>
+#include <optional>
+#include <stdexcept>
+#include <string.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <thread>
+#include <type_traits>
+#include <unistd.h>
+#include <unordered_map>
+#include <vector>
+
+#include "shared/shared.hh"
+
+// Functions in this namespace are common network-related wrappers with error
+// checking.
+namespace shared {
+
+using addrinfo_t = std::shared_ptr<addrinfo>; // for automatic freeaddrinfo call
+addrinfo_t make_addrinfo(const char* const address, const char* const port,
+ const addrinfo&& hints);
+
+using socket_t = int;
+int make_socket(const addrinfo_t& info);
+int accept_socket(const socket_t& sock);
+void bind_socket(const socket_t sock, const addrinfo_t& info);
+void connect_socket(const socket_t sock, const addrinfo_t& info);
+void listen_socket(const socket_t sock);
+void close_socket(const socket_t sock);
+
+struct header {
+ std::uint32_t size = 0; // size of packet, including header.
+ std::uint32_t sequence; // sequence number of packet
+ char command[3]; // command type, if \0\0\0 then it's an ack
+};
+static_assert(std::is_trivially_copyable<header>::value,
+ "header must be memcpy-able");
+struct packet {
+ header header;
+ std::vector<char> contents;
+};
+
+// Our packets contents consist of data entries like so:
+// std::uint32_t | char[] | DATA
+// ^ size of data ^ name ^ data, which repeats for size length
+// Any packet may contain multiple, or zero, entries. The name can be any
+// length. This way we don't have to define structs with hardcoded length
+// limits, and we can use this format when sending anything, even files.
+// Numeric values will be encoded as strings and converted when necessary.
+using contents_t = std::unordered_map<std::string, std::string>;
+void send_packet(const packet& packet, const socket_t& sock,
+ const sockaddr_in& dest);
+void send_packet(const packet& packet, const socket_t& sock);
+
+packet contents_to_packet(const contents_t& contents,
+ const char* const command);
+contents_t packet_to_contents(const packet& packet);
+
+// Recv's sockets, might return nullptr if no packet available.
+struct recv_packet_ret {
+ sockaddr_in origin;
+ struct packet packet;
+};
+
+// non-blocking
+std::shared_ptr<recv_packet_ret> maybe_urecv_packet(const socket_t& sock);
+// blocking, will throw if timeout is elapsed
+std::shared_ptr<recv_packet_ret> urecv_packet(const socket_t& rsock,
+ const bool& timeout = true);
+std::shared_ptr<packet> rrecv_packet(const socket_t& rsock,
+ const bool& timeout = true);
+
+} // namespace shared
+
+#endif
diff --git a/comp3331/server/src/shared/shared.cc b/comp3331/server/src/shared/shared.cc
new file mode 100644
index 0000000..3a13b47
--- /dev/null
+++ b/comp3331/server/src/shared/shared.cc
@@ -0,0 +1,25 @@
+#include "shared/shared.hh"
+
+namespace shared {
+
+bool should_exit = false;
+
+static void set_signal(const decltype(SIGINT) signal,
+ void (*const callback)(const int)) {
+ struct sigaction sa {};
+ sa.sa_handler = callback;
+ if (sigaction(signal, &sa, nullptr) == -1) {
+ throw std::runtime_error("failed to set signal handler!");
+ }
+}
+
+void set_exit_handler() {
+ set_signal(SIGPIPE, SIG_IGN);
+ set_signal(SIGINT, [](const int) {
+ std::cout << " interrupt signal received\n";
+ should_exit = true;
+ });
+}
+
+} // namespace shared
+
diff --git a/comp3331/server/src/shared/shared.hh b/comp3331/server/src/shared/shared.hh
new file mode 100644
index 0000000..8e4fa36
--- /dev/null
+++ b/comp3331/server/src/shared/shared.hh
@@ -0,0 +1,28 @@
+#ifndef SHARED_SHARED_HH_
+#define SHARED_SHARED_HH_
+
+#include <functional>
+#include <iostream>
+#include <signal.h>
+#include <stdexcept>
+
+namespace shared {
+extern bool should_exit;
+
+void set_exit_handler();
+
+class should_exit_exception : public std::exception {};
+
+// This won't exist until c++24 lol
+class scoped_function {
+private:
+ using func_t = std::function<void()>;
+ func_t func;
+
+public:
+ scoped_function(const func_t& f) : func(f) {}
+ ~scoped_function() { this->func(); }
+};
+} // namespace shared
+
+#endif