aboutsummaryrefslogtreecommitdiff
path: root/comp3331
diff options
context:
space:
mode:
Diffstat (limited to 'comp3331')
-rw-r--r--comp3331/server/CMakeLists.txt44
-rw-r--r--comp3331/server/src/client/client.cc519
-rw-r--r--comp3331/server/src/client/client.hh24
-rw-r--r--comp3331/server/src/client/main.cc28
-rw-r--r--comp3331/server/src/client/main.hh11
-rw-r--r--comp3331/server/src/server/client.cc1
-rw-r--r--comp3331/server/src/server/client.hh24
-rw-r--r--comp3331/server/src/server/main.cc28
-rw-r--r--comp3331/server/src/server/main.hh9
-rw-r--r--comp3331/server/src/server/server.cc829
-rw-r--r--comp3331/server/src/server/server.hh28
-rw-r--r--comp3331/server/src/shared/connection.cc54
-rw-r--r--comp3331/server/src/shared/connection.hh60
-rw-r--r--comp3331/server/src/shared/net.cc284
-rw-r--r--comp3331/server/src/shared/net.hh87
-rw-r--r--comp3331/server/src/shared/shared.cc25
-rw-r--r--comp3331/server/src/shared/shared.hh28
17 files changed, 2083 insertions, 0 deletions
diff --git a/comp3331/server/CMakeLists.txt b/comp3331/server/CMakeLists.txt
new file mode 100644
index 0000000..60c174e
--- /dev/null
+++ b/comp3331/server/CMakeLists.txt
@@ -0,0 +1,44 @@
+cmake_minimum_required(VERSION 3.10)
+
+project(comp3331_assignment)
+
+set(CMAKE_C_COMPILER "clang")
+set(CMAKE_CXX_COMPILER "clang++")
+
+set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+#set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/bin")
+
+include_directories(
+ "${PROJECT_SOURCE_DIR}/src"
+)
+
+file (GLOB_RECURSE SHARED_SOURCE_FILES CONFIGURE_DEPENDS
+ "${PROJECT_SOURCE_DIR}/src/shared/*.cc"
+)
+
+foreach(BINARY_NAME client;server)
+ file (GLOB_RECURSE BINARY_SOURCE_FILES CONFIGURE_DEPENDS
+ "${PROJECT_SOURCE_DIR}/src/${BINARY_NAME}/*.cc"
+ )
+ add_executable(${BINARY_NAME}
+ ${BINARY_SOURCE_FILES}
+ ${SHARED_SOURCE_FILES}
+ )
+
+ target_compile_options(${BINARY_NAME} PRIVATE
+ -Wall -Wextra -Wshadow -Wdouble-promotion -Wformat=2 -Wundef -fno-common
+ -Wconversion -Wpedantic -std=c++14 -g3 -O2
+ -fstack-protector-strong -fno-omit-frame-pointer -fsanitize=undefined
+ -Wno-exceptions
+ )
+ target_link_libraries(${BINARY_NAME} PRIVATE
+ #pthread
+ )
+ target_link_options(${BINARY_NAME} PRIVATE
+ -fstack-protector-strong -fsanitize=undefined
+ )
+endforeach()
+
diff --git a/comp3331/server/src/client/client.cc b/comp3331/server/src/client/client.cc
new file mode 100644
index 0000000..9ec2ab2
--- /dev/null
+++ b/comp3331/server/src/client/client.cc
@@ -0,0 +1,519 @@
+#include "client/client.hh"
+
+namespace client {
+
+// Args passed to handle functions, the input line split by spacebars. Similar
+// to argv, the first is the program name (or in this case command name).
+
+static std::string read_cin(const char* const msg) noexcept {
+ std::cout << msg << ": ";
+ std::string line;
+ std::getline(std::cin, line);
+ return line;
+}
+
+static std::pair<std::string, std::string>&
+get_address_port(const char* const address = nullptr,
+ const char* const port = nullptr) {
+ static std::pair<std::string, std::string> pair = [&]() {
+ return std::make_pair<std::string, std::string>(address, port);
+ }();
+ return pair;
+}
+
+static shared::contents_t urecv_contents(shared::connection& connection) {
+ const auto ret = [&]() {
+ for (;;) {
+ const auto ret = shared::urecv_packet(connection.get_socket());
+ if (connection.should_discard_packet(ret->packet)) {
+ continue;
+ }
+ return ret;
+ }
+ }();
+ const shared::contents_t contents = shared::packet_to_contents(ret->packet);
+
+ // error checking early out
+ const auto find_it = contents.find("success");
+ if (find_it != std::end(contents)) {
+ if (find_it->second == "error") {
+ throw std::runtime_error(contents.find("message")->second);
+ }
+ }
+
+ return contents;
+}
+
+static shared::socket_t get_new_rsock() {
+ const auto& addr_port = get_address_port();
+ const auto info = [&]() {
+ addrinfo a{};
+ a.ai_flags = AI_PASSIVE;
+ a.ai_family = AF_INET;
+ a.ai_socktype = SOCK_STREAM;
+ return shared::make_addrinfo(addr_port.first.c_str(),
+ addr_port.second.c_str(), std::move(a));
+ }();
+ const shared::socket_t socket = shared::make_socket(info);
+ shared::connect_socket(socket, info);
+ return socket;
+}
+
+using args_t = std::vector<std::string>;
+static void handle_create_thread(const args_t& args,
+ shared::connection& connection) {
+ if (args.size() != 2) {
+ throw std::invalid_argument("invalid syntax");
+ }
+
+ const std::string& thread_title = args[1];
+ { // send request
+ const shared::contents_t contents{
+ {"thread_title", thread_title},
+ };
+ connection.send_packet(shared::contents_to_packet(contents, "CRT"));
+ }
+
+ const shared::contents_t contents = urecv_contents(connection);
+
+ const auto message_it = contents.find("message");
+ const auto success_it = contents.find("success");
+ if (message_it == std::end(contents) || success_it == std::end(contents)) {
+ std::cout << "\treceived bad response from the server, try again\n";
+ return;
+ }
+
+ std::cout << message_it->second;
+}
+
+static void handle_post_message(const args_t& args,
+ shared::connection& connection) {
+ if (args.size() < 3) {
+ throw std::invalid_argument("invalid syntax");
+ }
+
+ const std::string& thread_title = args[1];
+ const std::string message = [&]() {
+ std::string message = args[2];
+ for (auto i = 3ul; i < args.size(); ++i) {
+ message += ' ' + args[i];
+ }
+ return message;
+ }();
+
+ connection.send_packet(shared::contents_to_packet(
+ {{"thread_title", thread_title}, {"message", message}}, "MSG"));
+
+ const shared::contents_t contents = urecv_contents(connection);
+ const auto message_it = contents.find("message");
+ if (message_it == std::end(contents)) {
+ std::cout << "\treceived bad response from the server, try again\n";
+ return;
+ }
+
+ std::cout << message_it->second;
+}
+
+static void handle_delete_message(const args_t& args,
+ shared::connection& connection) {
+ if (args.size() != 3) {
+ throw std::invalid_argument("invalid syntax");
+ }
+
+ const std::string& thread_title = args[1];
+ const std::string& message_number = args[2];
+
+ try {
+ std::stoi(message_number);
+ } catch (...) {
+ throw std::invalid_argument{"failed to parse \"" + message_number +
+ "\" as an integer"};
+ }
+
+ connection.send_packet(shared::contents_to_packet(
+ {{"thread_title", thread_title}, {"message_number", message_number}},
+ "DLT"));
+
+ const shared::contents_t contents = urecv_contents(connection);
+ const auto message_it = contents.find("message");
+ if (message_it == std::end(contents)) {
+ std::cout << "\treceived bad response from the server, try again\n";
+ return;
+ }
+
+ std::cout << message_it->second;
+}
+
+static void handle_edit_message(const args_t& args,
+ shared::connection& connection) {
+ if (args.size() < 4) {
+ throw std::invalid_argument("invalid syntax");
+ }
+
+ const std::string& thread_title = args[1];
+ const std::string& message_number = args[2];
+ const std::string message = [&]() {
+ std::string message = args[3];
+ for (auto i = 4ul; i < args.size(); ++i) {
+ message += ' ' + args[i];
+ }
+ return message;
+ }();
+
+ try {
+ std::stoi(message_number);
+ } catch (...) {
+ throw std::invalid_argument{"failed to parse \"" + message_number +
+ "\" as an integer"};
+ }
+
+ connection.send_packet(
+ shared::contents_to_packet({{"thread_title", thread_title},
+ {"message_number", message_number},
+ {"message", message}},
+ "EDT"));
+
+ const shared::contents_t contents = urecv_contents(connection);
+ const auto message_it = contents.find("message");
+ if (message_it == std::end(contents)) {
+ std::cout << "\treceived bad response from the server, try again\n";
+ return;
+ }
+
+ std::cout << message_it->second;
+}
+
+static void handle_list_threads(const args_t& args,
+ shared::connection& connection) {
+ if (args.size() != 1) {
+ throw std::invalid_argument("invalid syntax");
+ }
+
+ { connection.send_packet(shared::contents_to_packet({}, "LST")); }
+
+ const shared::contents_t contents = urecv_contents(connection);
+ const auto message_it = contents.find("message");
+ if (message_it == std::end(contents)) {
+ std::cout << "\treceived bad response from the server, try again\n";
+ return;
+ }
+
+ std::cout << message_it->second;
+}
+
+static void handle_read_thread(const args_t& args,
+ shared::connection& connection) {
+ if (args.size() != 2) {
+ throw std::invalid_argument("invalid syntax");
+ }
+
+ const std::string& thread_title = args[1];
+ { // send request
+ const shared::contents_t contents{
+ {"thread_title", thread_title},
+ };
+ connection.send_packet(shared::contents_to_packet(contents, "RDT"));
+ }
+
+ const shared::contents_t contents = urecv_contents(connection);
+
+ const auto message_it = contents.find("message");
+ const auto success_it = contents.find("success");
+ if (message_it == std::end(contents) || success_it == std::end(contents)) {
+ std::cout << "\treceived bad response from the server, try again\n";
+ return;
+ }
+
+ std::cout << message_it->second;
+}
+
+static void handle_upload_file(const args_t& args,
+ shared::connection& connection) {
+ if (args.size() != 3) {
+ throw std::invalid_argument("invalid syntax");
+ }
+
+ const std::string& thread_title = args[1];
+ const std::string& filename = args[2];
+ const std::string file_contents = [&]() -> std::string {
+ std::ifstream in{"./" + filename};
+ if (!in.is_open()) {
+ throw std::invalid_argument("file not found");
+ }
+ std::stringstream ss;
+ ss << in.rdbuf();
+ return ss.str();
+ }();
+
+ {
+ const shared::contents_t contents{{"thread_title", thread_title},
+ {"filename", filename}};
+ connection.send_packet(shared::contents_to_packet(contents, "UPD"));
+ }
+
+ {
+ const shared::contents_t contents = urecv_contents(connection);
+
+ const auto message_it = contents.find("message");
+ const auto success_it = contents.find("success");
+ if (message_it == std::end(contents) ||
+ success_it == std::end(contents)) {
+ std::cout << "\treceived bad response from the server, try again\n";
+ return;
+ }
+ std::cout << message_it->second;
+ if (success_it->second != "true") {
+ return;
+ }
+ }
+
+ { // send file
+ const shared::socket_t rsock = get_new_rsock();
+ shared::scoped_function close_rsock{[rsock]() { close(rsock); }};
+ const shared::contents_t contents{{"file_contents", file_contents}};
+ const shared::packet packet =
+ shared::contents_to_packet(contents, "UPD");
+ shared::send_packet(packet, rsock);
+ }
+
+ {
+ const shared::contents_t contents = urecv_contents(connection);
+
+ const auto message_it = contents.find("message");
+ const auto success_it = contents.find("success");
+ if (message_it == std::end(contents) ||
+ success_it == std::end(contents)) {
+ std::cout << "\treceived bad response from the server, try again\n";
+ return;
+ }
+ std::cout << message_it->second;
+ }
+}
+
+static void handle_download_file(const args_t& args,
+ shared::connection& connection) {
+ if (args.size() != 3) {
+ throw std::invalid_argument("invalid syntax");
+ }
+
+ const std::string& thread_title = args[1];
+ const std::string& filename = args[2];
+ {
+ const shared::contents_t contents{{"thread_title", thread_title},
+ {"filename", filename}};
+ connection.send_packet(shared::contents_to_packet(contents, "DWN"));
+ }
+
+ {
+ const shared::contents_t contents = urecv_contents(connection);
+
+ const auto message_it = contents.find("message");
+ const auto success_it = contents.find("success");
+ if (message_it == std::end(contents) ||
+ success_it == std::end(contents)) {
+ std::cout << "\treceived bad response from the server, try again\n";
+ return;
+ }
+ std::cout << message_it->second;
+ if (success_it->second != "true") {
+ return;
+ }
+ }
+
+ const auto file_contents = []() {
+ const shared::socket_t rsock = get_new_rsock();
+ shared::scoped_function close_rsock{[rsock]() { close(rsock); }};
+ const auto packet = shared::rrecv_packet(rsock);
+ return packet_to_contents(*packet);
+ }();
+
+ const auto data_it = file_contents.find("file_contents");
+ const auto message_it = file_contents.find("message");
+ if (data_it == std::end(file_contents) ||
+ message_it == std::end(file_contents)) {
+ return;
+ }
+
+ std::ofstream out{filename, std::ios_base::trunc};
+ if (!out.is_open()) {
+ throw std::runtime_error{"failed to write file " + filename};
+ }
+ out << data_it->second;
+ std::cout << message_it->second;
+}
+
+static void handle_remove_thread(const args_t& args,
+ shared::connection& connection) {
+ if (args.size() != 2) {
+ throw std::invalid_argument("invalid syntax");
+ }
+
+ const auto& thread_title = args[1];
+ {
+ const shared::contents_t contents{
+ {"thread_title", thread_title},
+ };
+ connection.send_packet(shared::contents_to_packet(contents, "RMV"));
+ }
+
+ const shared::contents_t contents = urecv_contents(connection);
+
+ const auto message_it = contents.find("message");
+ const auto success_it = contents.find("success");
+ if (message_it == std::end(contents) || success_it == std::end(contents)) {
+ std::cout << "\treceived bad response from the server, try again\n";
+ return;
+ }
+
+ std::cout << message_it->second;
+}
+
+static void authenticate(shared::connection& connection) {
+ const auto send_auth_var = [&](const char* const varname) {
+ for (;;) { // send username
+ if (shared::should_exit) {
+ throw shared::should_exit_exception{};
+ }
+
+ const std::string var = read_cin(varname);
+ if (var.length() == 0) {
+ continue;
+ }
+
+ { // request server for status on username
+ const shared::contents_t contents{{varname, var}};
+ connection.send_packet(
+ shared::contents_to_packet(contents, "ATH"));
+ }
+
+ const shared::contents_t contents = urecv_contents(connection);
+ const auto success_it = contents.find("success");
+ const auto message_it = contents.find("message");
+ if (message_it == std::end(contents) ||
+ success_it == std::end(contents)) {
+ std::cout
+ << "\treceived bad response from the server, try again\n";
+ continue;
+ }
+ std::cout << message_it->second;
+
+ if (success_it->second == "true") {
+ break;
+ }
+ }
+ };
+
+ send_auth_var("username");
+ send_auth_var("password");
+}
+
+static void interact(shared::connection& connection) {
+ static const std::unordered_map<std::string,
+ decltype(&handle_create_thread)>
+ commands{{"CRT", handle_create_thread}, {"MSG", handle_post_message},
+ {"DLT", handle_delete_message}, {"EDT", handle_edit_message},
+ {"LST", handle_list_threads}, {"RDT", handle_read_thread},
+ {"UPD", handle_upload_file}, {"DWN", handle_download_file},
+ {"RMV", handle_remove_thread}};
+ const auto print_commands = [&]() {
+ std::cout << "\tavailable commands: ";
+ for (const auto& pair : commands) {
+ const auto& key = pair.first;
+ std::cout << key << ' ';
+ }
+ std::cout << "(or \'XIT\' to quit)\n";
+ };
+
+ print_commands();
+ for (;;) {
+ if (shared::should_exit) {
+ throw shared::should_exit_exception{};
+ }
+ const std::string input = read_cin("command");
+ if (input.length() <= 0) {
+ continue;
+ }
+
+ const std::vector<std::string> args = [&]() { // string split args
+ std::vector<std::string> ret;
+
+ std::stringstream ss{std::move(input)};
+ std::string arg;
+ while (std::getline(ss, arg, ' ')) {
+ if (arg.size() <= 0) {
+ continue;
+ }
+ ret.emplace_back(std::move(arg));
+ }
+
+ return ret;
+ }();
+
+ if (args.size() <= 0) {
+ continue;
+ }
+
+ const std::string command = [&]() {
+ std::string command;
+ std::transform(std::begin(args[0]), std::end(args[0]),
+ std::back_inserter(command),
+ [](const char c) { return toupper(c); });
+ return command;
+ }();
+
+ if (command == "XIT") {
+ break;
+ }
+
+ const auto find_it = commands.find(command);
+ if (find_it == std::end(commands)) { // unknown command, print commands
+ print_commands();
+ continue;
+ }
+
+ try {
+ const auto& func = find_it->second;
+ func(args, connection);
+ } catch (const std::invalid_argument& e) {
+ std::cout << "\tbad arguments: " << e.what() << '\n';
+ } catch (const std::runtime_error& e) {
+ std::cout << "\truntime error: " << e.what() << '\n';
+ }
+ }
+}
+
+void do_client(const char* const address, const char* const port) {
+ const shared::socket_t usock = [&]() -> shared::socket_t {
+ get_address_port(address, port);
+ const auto info = [&]() { // no designated initialisers :(
+ addrinfo a{};
+ a.ai_flags = AI_PASSIVE;
+ a.ai_family = AF_INET;
+ a.ai_socktype = SOCK_DGRAM;
+ return shared::make_addrinfo("0.0.0.0", 0, std::move(a));
+ }();
+ const shared::socket_t socket = shared::make_socket(info);
+ shared::bind_socket(socket, info);
+ return socket;
+ }();
+
+ shared::connection connection = [&]() {
+ sockaddr_in dest{};
+ dest.sin_family = AF_INET;
+ dest.sin_port = htons(static_cast<std::uint16_t>(std::stoi(port)));
+ inet_aton(address, &dest.sin_addr);
+ return shared::connection{usock, std::move(dest)};
+ }();
+
+ try {
+ authenticate(connection);
+ interact(connection);
+ } catch (const shared::should_exit_exception&) {
+ // exit gracefully on should_exit flag
+ }
+
+ // upload exit before leaving
+ connection.send_packet(shared::contents_to_packet({}, "XIT"));
+}
+
+} // namespace client
diff --git a/comp3331/server/src/client/client.hh b/comp3331/server/src/client/client.hh
new file mode 100644
index 0000000..8230e46
--- /dev/null
+++ b/comp3331/server/src/client/client.hh
@@ -0,0 +1,24 @@
+#ifndef CLIENT_CLIENT_HH_
+#define CLIENT_CLIENT_HH_
+
+#include <algorithm>
+#include <arpa/inet.h>
+#include <chrono>
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "shared/connection.hh"
+#include "shared/net.hh"
+#include "shared/shared.hh"
+
+namespace client {
+void do_client(const char* const address, const char* const port);
+}
+
+#endif
diff --git a/comp3331/server/src/client/main.cc b/comp3331/server/src/client/main.cc
new file mode 100644
index 0000000..a452c3e
--- /dev/null
+++ b/comp3331/server/src/client/main.cc
@@ -0,0 +1,28 @@
+#include "client/main.hh"
+
+using namespace client;
+
+int main(const int argc, const char* const argv[]) {
+
+ if (argc != 2) {
+ std::cerr << "usage: ./client PORT<int>\n";
+ return EXIT_SUCCESS;
+ }
+
+ const char* const address = "localhost"; // not an argument fsr
+ const char* const port = argv[1];
+
+ try {
+ shared::set_exit_handler();
+ do_client(address, port);
+ } catch (const std::exception& e) {
+ std::cerr << "caught exception from client!\n\twhat(): " << e.what()
+ << '\n';
+ return EXIT_FAILURE;
+ } catch (...) {
+ std::cerr << "unhandled exception from client!\n";
+ return EXIT_FAILURE;
+ }
+
+ return EXIT_SUCCESS;
+}
diff --git a/comp3331/server/src/client/main.hh b/comp3331/server/src/client/main.hh
new file mode 100644
index 0000000..1e04e74
--- /dev/null
+++ b/comp3331/server/src/client/main.hh
@@ -0,0 +1,11 @@
+#ifndef CLIENT_MAIN_HH_
+#define CLIENT_MAIN_HH_
+
+#include <exception>
+#include <iostream>
+#include <string>
+
+#include "client/client.hh"
+#include "shared/shared.hh"
+
+#endif
diff --git a/comp3331/server/src/server/client.cc b/comp3331/server/src/server/client.cc
new file mode 100644
index 0000000..0197f90
--- /dev/null
+++ b/comp3331/server/src/server/client.cc
@@ -0,0 +1 @@
+#include "server/client.hh"
diff --git a/comp3331/server/src/server/client.hh b/comp3331/server/src/server/client.hh
new file mode 100644
index 0000000..8fa57f2
--- /dev/null
+++ b/comp3331/server/src/server/client.hh
@@ -0,0 +1,24 @@
+#ifndef SERVER_CLIENT_HH_
+#define SERVER_CLIENT_HH_
+
+#include <memory>
+#include <string>
+
+#include "shared/connection.hh"
+#include "shared/net.hh"
+
+namespace server {
+struct client {
+ shared::connection connection;
+ std::unique_ptr<std::string> username = nullptr; // non-null = authenticated
+
+public:
+ client(const shared::socket_t& sock, sockaddr_in&& peer)
+ : connection(sock, std::move(peer)) {}
+
+public:
+ bool is_authenticated() const noexcept { return username != nullptr; }
+};
+} // namespace server
+
+#endif
diff --git a/comp3331/server/src/server/main.cc b/comp3331/server/src/server/main.cc
new file mode 100644
index 0000000..764a732
--- /dev/null
+++ b/comp3331/server/src/server/main.cc
@@ -0,0 +1,28 @@
+#include "server/main.hh"
+
+using namespace server;
+
+int main(const int argc, const char* const argv[]) {
+
+ if (argc != 2) {
+ std::cerr << "usage: ./server PORT<int>\n";
+ return EXIT_SUCCESS;
+ }
+
+ const char* const address = "0.0.0.0";
+ const char* const port = argv[1];
+
+ try {
+ shared::set_exit_handler();
+ do_server(address, port);
+ } catch (const std::exception& e) {
+ std::cerr << "caught exception from server!\n\twhat(): " << e.what()
+ << '\n';
+ return EXIT_FAILURE;
+ } catch (...) {
+ std::cerr << "unhandled exception from server!\n";
+ return EXIT_FAILURE;
+ }
+
+ return EXIT_SUCCESS;
+}
diff --git a/comp3331/server/src/server/main.hh b/comp3331/server/src/server/main.hh
new file mode 100644
index 0000000..2714ef8
--- /dev/null
+++ b/comp3331/server/src/server/main.hh
@@ -0,0 +1,9 @@
+#ifndef SERVER_MAIN_HH_
+#define SERVER_MAIN_HH_
+
+#include <iostream>
+
+#include "server/server.hh"
+#include "shared/shared.hh"
+
+#endif
diff --git a/comp3331/server/src/server/server.cc b/comp3331/server/src/server/server.cc
new file mode 100644
index 0000000..a9f2ae9
--- /dev/null
+++ b/comp3331/server/src/server/server.cc
@@ -0,0 +1,829 @@
+#include "server/server.hh"
+
+namespace server {
+
+static shared::socket_t make_socket(const char* const address,
+ const char* const port,
+ const decltype(addrinfo::ai_socktype)& t) {
+ const auto info = [&]() { // no designated initialisers :(
+ addrinfo a{};
+ a.ai_flags = AI_PASSIVE;
+ a.ai_family = AF_INET;
+ a.ai_socktype = t;
+ return shared::make_addrinfo(address, port, std::move(a));
+ }();
+ const shared::socket_t socket = shared::make_socket(info);
+ shared::bind_socket(socket, info);
+ return socket;
+}
+
+// default args only used during initialiser
+static shared::socket_t& get_rsock(const char* const address = nullptr,
+ const char* const port = nullptr) {
+ static shared::socket_t rsock = [&]() {
+ const shared::socket_t rsock = make_socket(address, port, SOCK_STREAM);
+ shared::listen_socket(rsock);
+ return rsock;
+ }();
+ return rsock;
+}
+
+using clients_t = std::vector<std::unique_ptr<client>>;
+static clients_t& get_clients() noexcept {
+ static clients_t clients;
+ return clients;
+}
+
+using users_t = std::unordered_map<std::string, std::string>;
+static users_t& get_known_users() {
+ static users_t users = []() {
+ users_t users;
+
+ std::ifstream in{cred_path};
+ if (!in.is_open()) {
+ throw std::runtime_error{"failed to open credentials"};
+ }
+
+ std::stringstream ss;
+ ss << in.rdbuf();
+
+ for (std::string line; std::getline(ss, line);) {
+ if (line.empty()) {
+ continue;
+ }
+
+ const auto split = static_cast<long>(line.find_first_of(' '));
+ if (static_cast<unsigned long>(split) == std::string::npos) {
+ continue;
+ }
+
+ const std::string username{std::begin(line),
+ std::begin(line) + split};
+ std::string password{std::next(std::begin(line) + split),
+ std::end(line)};
+
+ password.erase(
+ std::remove_if(std::begin(password), std::end(password),
+ [](const char c) { return std::isspace(c); }),
+ std::end(password));
+
+ if (username.length() == 0 || password.length() == 0) {
+ continue; // bad line?
+ }
+ users.emplace(username, password);
+ }
+ return users;
+ }();
+ return users;
+}
+
+struct message {
+ std::string poster;
+ std::string contents; // contents is a filename if is_file is true
+ bool is_file = false;
+};
+struct thread {
+ std::string original_poster;
+ std::vector<message> messages;
+};
+using threads_t = std::unordered_map<std::string, struct thread>;
+static threads_t& get_known_threads() {
+ static threads_t threads = []() {
+ // We have to read all files in the current directory and attempt to
+ // parse them as threads. We will inevitably read cmake and makefiles,
+ // so early out if they are broken. Also, as we are writing this in
+ // C++14, we will have to resort to some C library calls.
+ threads_t threads;
+ std::shared_ptr<DIR> directory{opendir("./"),
+ [](auto& d) { closedir(d); }};
+ if (directory == nullptr) {
+ throw std::runtime_error("failed to open current directory");
+ }
+
+ for (dirent* entry = readdir(directory.get()); entry != nullptr;
+ entry = readdir(directory.get())) {
+ if (entry->d_type != DT_REG) {
+ continue;
+ }
+
+ std::ifstream in{entry->d_name};
+ if (!in.is_open()) {
+ continue; // probs bad permissions
+ }
+
+ std::stringstream ss;
+ ss << in.rdbuf();
+
+ // Read OP whose username is the first line in the file.
+ std::string line;
+ std::getline(ss, line, '\n');
+ if (line.empty() ||
+ std::any_of(std::begin(line), std::end(line),
+ [](const auto& c) { return std::isspace(c); })) {
+ continue; // bad username or bad file
+ }
+
+ thread thread{};
+ thread.original_poster = line;
+
+ const auto maybe_parse_messages = [&]() { // return false if bad
+ while (std::getline(ss, line, '\n')) {
+ if (line.empty()) {
+ continue;
+ }
+
+ // we differentiate between files and messages with this
+ const bool is_file = line[0] == 'F';
+
+ const auto split =
+ static_cast<long>(line.find_first_of(' ', 1));
+ if (static_cast<unsigned long>(split) ==
+ std::string::npos) {
+ return false; // filter out bad files
+ }
+
+ const std::string username{std::next(std::begin(line)),
+ std::begin(line) + split};
+ const std::string contents{
+ std::next(std::begin(line) + split), std::end(line)};
+
+ if (username.length() == 0 || contents.length() == 0) {
+ continue; // bad line?
+ }
+
+ thread.messages.push_back(
+ message{username, contents, is_file});
+ }
+ return true;
+ };
+
+ if (!maybe_parse_messages()) {
+ continue;
+ }
+
+ threads.emplace(entry->d_name, std::move(thread));
+ }
+
+ return threads;
+ }();
+ return threads;
+}
+
+// returns an iterator to the message if it exists, or std::end
+static auto get_message_it(decltype(thread::messages)& messages,
+ const int& message_number) noexcept {
+ return std::find_if(std::begin(messages), std::end(messages),
+ [&, i = 0](const auto& message) mutable {
+ if (message.is_file) {
+ return false;
+ }
+ if (i == message_number) {
+ return true;
+ }
+ ++i;
+ return false;
+ });
+}
+
+static void handle_auth_client(client& client,
+ const shared::contents_t& contents) {
+
+ auto& known_users = get_known_users();
+ const auto& clients = get_clients();
+ const auto response_contents = [&]() -> shared::contents_t {
+ if (client.username == nullptr) {
+ const auto username_it = contents.find("username");
+ if (username_it == std::end(contents)) {
+ return {{"message", "\tno username provided!\n"},
+ {"success", "false"}};
+ }
+ const auto& username = username_it->second;
+
+ if (std::any_of(std::begin(clients), std::end(clients),
+ [&](const auto& client) {
+ if (client->username == nullptr) {
+ return false;
+ }
+ return *client->username == username;
+ })) {
+ std::cout << "\tmultiple login attempt for: " << username
+ << '\n';
+ return {{"message", "\tuser already logged in\n"},
+ {"success", "false"}};
+ }
+
+ client.username =
+ std::make_unique<std::string>(username_it->second);
+ const auto find_it = known_users.find(username);
+ if (find_it == std::end(known_users)) {
+ return {
+ {"message", "\twelcome new user \"" + username + "\"\n"},
+ {"success", "true"}};
+ }
+ return {{"message", "\twelcome back user \"" + username + "\"\n"},
+ {"success", "true"}};
+ }
+
+ const auto password_it = contents.find("password");
+ if (password_it == std::end(contents)) {
+ return {{"message", "\tno password provided!\n"},
+ {"success", "false"}};
+ }
+ const auto& password = password_it->second;
+
+ const auto find_it = known_users.find(*client.username);
+ if (find_it == std::end(known_users)) {
+ known_users.emplace(*client.username, password);
+ return {{"message", "\taccount created successfully\n"},
+ {"success", "true"}};
+ }
+
+ const auto& prev_password = find_it->second;
+ if (prev_password != password) {
+ std::cout << "\tunsuccessful login for user: " << *client.username
+ << '\n';
+ return {{"message", "\tusername password mismatch!\n"},
+ {"success", "false"}};
+ }
+
+ std::cout << "\tsuccessful login for: " << *client.username << '\n';
+ return {{"message",
+ "\tsuccessful login for user \"" + *client.username + "\"\n"},
+ {"success", "true"}};
+ }();
+
+ client.connection.send_packet(
+ shared::contents_to_packet(response_contents, "ATH"));
+}
+
+static void handle_create_thread(client& client,
+ const shared::contents_t& contents) {
+ const auto thread_title_it = contents.find("thread_title");
+ if (thread_title_it == std::end(contents)) {
+ throw std::runtime_error("malformed request");
+ }
+ const std::string& thread_title = thread_title_it->second;
+
+ const auto response_contents = [&]() -> shared::contents_t {
+ auto& threads = get_known_threads();
+ if (threads.find(thread_title) != std::end(threads)) {
+ return {{"message", "\ta thread with title \"" + thread_title +
+ "\" already exists!\n"},
+ {"success", "false"}};
+ }
+
+ thread thread{};
+ thread.original_poster = *client.username;
+ threads.emplace(thread_title, std::move(thread));
+
+ return {{"message", "\tthread \"" + thread_title + "\" created\n"},
+ {"success", "true"}};
+ }();
+ client.connection.send_packet(
+ shared::contents_to_packet(response_contents, "CRT"));
+}
+
+static void handle_post_message(client& client,
+ const shared::contents_t& contents) {
+ const auto message_it = contents.find("message");
+ const auto thread_title_it = contents.find("thread_title");
+
+ if (message_it == std::end(contents) ||
+ thread_title_it == std::end(contents)) {
+ throw std::runtime_error("malformed request");
+ }
+
+ const auto& message = message_it->second;
+ const auto& thread_title = thread_title_it->second;
+
+ const auto response_contents = [&]() -> shared::contents_t {
+ auto& threads = get_known_threads();
+
+ const auto thread_it = threads.find(thread_title);
+ if (thread_it == std::end(threads)) {
+ return {{"message", "\ta thread with title \"" + thread_title +
+ "\" doesn't exist!\n"},
+ {"success", "false"}};
+ }
+
+ auto& thread = thread_it->second;
+
+ struct message new_message {};
+ new_message.poster = *client.username;
+ new_message.contents = message;
+ new_message.is_file = false;
+ thread.messages.emplace_back(std::move(new_message));
+ return {{"message", "\tmessage on thread " + thread_title +
+ " created successfully\n"},
+ {"success", "true"}};
+ }();
+
+ client.connection.send_packet(
+ shared::contents_to_packet(response_contents, "MSG"));
+}
+
+static void handle_delete_message(client& client,
+ const shared::contents_t& contents) {
+ const auto message_number_it = contents.find("message_number");
+ const auto thread_title_it = contents.find("thread_title");
+
+ if (message_number_it == std::end(contents) ||
+ thread_title_it == std::end(contents)) {
+ throw std::runtime_error("malformed request");
+ }
+
+ const auto message_number = std::stoi(message_number_it->second);
+ const auto& thread_title = thread_title_it->second;
+
+ const auto response_contents = [&]() -> shared::contents_t {
+ auto& threads = get_known_threads();
+
+ const auto thread_it = threads.find(thread_title);
+ if (thread_it == std::end(threads)) {
+ return {{"message", "\ta thread with title \"" + thread_title +
+ "\" doesn't exist!\n"},
+ {"success", "false"}};
+ }
+ auto& thread = thread_it->second;
+
+ auto& messages = thread.messages;
+ const auto message_it = get_message_it(messages, message_number);
+ if (message_it == std::end(messages)) {
+ return {
+ {"message", "\ta message with that number doesn't exist!\n"},
+ {"success", "false"}};
+ }
+
+ if (message_it->poster != *client.username) {
+ return {{"message", "\tyou can't delete someone elses post!\n"},
+ {"success", "false"}};
+ }
+
+ messages.erase(message_it);
+ return {{"message", "\tmessage removed successfully\n"},
+ {"success", "true"}};
+ }();
+
+ client.connection.send_packet(
+ shared::contents_to_packet(response_contents, "RDT"));
+}
+
+static void handle_edit_message(client& client,
+ const shared::contents_t& contents) {
+ const auto message_number_it = contents.find("message_number");
+ const auto thread_title_it = contents.find("thread_title");
+ const auto message_it = contents.find("message");
+
+ if (message_number_it == std::end(contents) ||
+ thread_title_it == std::end(contents) ||
+ message_it == std::end(contents)) {
+ throw std::runtime_error("malformed request");
+ }
+
+ const auto& thread_title = thread_title_it->second;
+ const auto& message = message_it->second;
+ const auto& message_number = std::stoi(message_number_it->second);
+
+ const auto response_contents = [&]() -> shared::contents_t {
+ auto& threads = get_known_threads();
+
+ const auto thread_it = threads.find(thread_title);
+ if (thread_it == std::end(threads)) {
+ return {{"message", "\ta thread with title \"" + thread_title +
+ "\" doesn't exist!\n"},
+ {"success", "false"}};
+ }
+ auto& thread = thread_it->second;
+
+ auto& messages = thread.messages;
+ const auto message_it = get_message_it(messages, message_number);
+ if (message_it == std::end(messages)) {
+ return {
+ {"message", "\ta message with that number doesn't exist!\n"},
+ {"success", "false"}};
+ }
+
+ if (message_it->poster != *client.username) {
+ return {{"message", "\tyou can't edit someone elses post!\n"},
+ {"success", "false"}};
+ }
+
+ message_it->contents = message;
+ return {{"message", "\tmessage edited successfully\n"},
+ {"success", "true"}};
+ }();
+
+ client.connection.send_packet(
+ shared::contents_to_packet(response_contents, "EDT"));
+}
+
+static void handle_list_threads(client& client, const shared::contents_t&) {
+
+ const auto response_contents = [&]() -> shared::contents_t {
+ const auto& threads = get_known_threads();
+ if (threads.size() == 0) {
+ return {{"message", "\tthere are no threads to list!\n"},
+ {"success", "true"}};
+ }
+
+ std::string resp;
+ for (const auto& keyvalue : threads) {
+ const auto& threadname = keyvalue.first;
+ resp += '\t' + threadname + '\n';
+ }
+ return {{"message", resp}, {"success", "true"}};
+ }();
+
+ client.connection.send_packet(
+ shared::contents_to_packet(response_contents, "RDT"));
+}
+
+static void handle_read_thread(client& client,
+ const shared::contents_t& contents) {
+ const auto thread_title_it = contents.find("thread_title");
+ if (thread_title_it == std::end(contents)) {
+ throw std::runtime_error("malformed request");
+ }
+ const std::string& thread_title = thread_title_it->second;
+
+ const auto response_contents = [&]() -> shared::contents_t {
+ auto& threads = get_known_threads();
+ const auto thread_it = threads.find(thread_title);
+ if (thread_it == std::end(threads)) {
+ return {{"message", "\ta thread with title \"" + thread_title +
+ "\" doesn't exist!\n"},
+ {"success", "false"}};
+ }
+
+ const auto& messages = thread_it->second.messages;
+ if (messages.size() == 0) {
+ return {{"message", "\tthere are no messages to list!\n"},
+ {"success", "false"}};
+ }
+
+ int message_num = 0;
+ std::string resp{};
+ for (const auto& message : messages) {
+ resp += '\t';
+ if (message.is_file) {
+ resp += message.poster + " uploaded " + message.contents + '\n';
+ continue;
+ }
+
+ resp += std::to_string(message_num) + ' ' + message.poster + ": " +
+ message.contents + '\n';
+ ++message_num;
+ }
+
+ return {{"message", resp}, {"success", "true"}};
+ }();
+
+ client.connection.send_packet(
+ shared::contents_to_packet(response_contents, "RDT"));
+}
+
+static void handle_upload_file(client& client,
+ const shared::contents_t& contents) {
+ const auto thread_title_it = contents.find("thread_title");
+ const auto filename_it = contents.find("filename");
+ if (thread_title_it == std::end(contents) ||
+ filename_it == std::end(contents)) {
+ throw std::runtime_error("malformed request");
+ }
+
+ const auto& thread_title = thread_title_it->second;
+ const auto& filename = filename_it->second;
+ const auto response_contents = [&]() -> shared::contents_t {
+ const auto& threads = get_known_threads();
+
+ const auto thread_it = threads.find(thread_title);
+ if (thread_it == std::end(threads)) {
+ return {{"message", "\ta thread with title \"" + thread_title +
+ "\" doesn't exist!\n"},
+ {"success", "false"}};
+ }
+ const auto& thread = thread_it->second;
+
+ const auto& messages = thread.messages;
+ const auto old_file_it = std::find_if(
+ std::begin(messages), std::end(messages), [&](const auto& message) {
+ if (!message.is_file) {
+ return false;
+ }
+ return message.contents == filename;
+ });
+ if (old_file_it != std::end(messages)) {
+ return {{"message", "\ta file with \"" + filename +
+ "\" already exists in the thread!\n"},
+ {"success", "false"}};
+ }
+
+ return {{"message", "\tbeginning file transfer\n"},
+ {"success", "true"}};
+ }();
+
+ client.connection.send_packet(
+ shared::contents_to_packet(response_contents, "UPD"));
+
+ if (response_contents.find("success")->second != "true") {
+ return;
+ }
+
+ const auto sock = shared::accept_socket(get_rsock());
+ const auto packet = shared::rrecv_packet(sock);
+ const shared::contents_t file_contents = packet_to_contents(*packet);
+ close(sock);
+
+ const auto file_contents_it = file_contents.find("file_contents");
+ if (file_contents_it == std::end(file_contents)) {
+ throw std::runtime_error("malformed request");
+ }
+
+ {
+ const auto& data = file_contents_it->second;
+ std::ofstream out{thread_title + "-" + filename, std::ios_base::trunc};
+ if (!out.is_open()) {
+ throw std::runtime_error{"failed to open file for writing"};
+ }
+ out << data;
+ }
+
+ // append to thread
+ {
+ message message{};
+ message.poster = *client.username;
+ message.is_file = true;
+ message.contents = filename;
+ get_known_threads()
+ .find(thread_title)
+ ->second.messages.emplace_back(std::move(message));
+ }
+
+ {
+ const shared::contents_t final_response{
+ {"message", "\tfile \"" + filename + "\" uploaded successfully!\n"},
+ {"success", "true"}};
+ client.connection.send_packet(
+ shared::contents_to_packet(final_response, "UPD"));
+ }
+}
+
+static void handle_download_file(client& client,
+ const shared::contents_t& contents) {
+ const auto thread_title_it = contents.find("thread_title");
+ const auto filename_it = contents.find("filename");
+ if (thread_title_it == std::end(contents) ||
+ filename_it == std::end(contents)) {
+ throw std::runtime_error("malformed request");
+ }
+
+ const auto& thread_title = thread_title_it->second;
+ const auto& filename = filename_it->second;
+ const auto response_contents = [&]() -> shared::contents_t {
+ const auto& threads = get_known_threads();
+
+ const auto thread_it = threads.find(thread_title);
+ if (thread_it == std::end(threads)) {
+ return {{"message", "\ta thread with title \"" + thread_title +
+ "\" doesn't exist!\n"},
+ {"success", "false"}};
+ }
+ const auto& thread = thread_it->second;
+
+ const auto& messages = thread.messages;
+ const auto old_file_it = std::find_if(
+ std::begin(messages), std::end(messages), [&](const auto& message) {
+ if (!message.is_file) {
+ return false;
+ }
+ return message.contents == filename;
+ });
+ if (old_file_it == std::end(messages)) {
+ return {{"message", "\tno file named \"" + filename +
+ "\" exists in the thread!\n"},
+ {"success", "false"}};
+ }
+
+ if (!std::ifstream{"./" + thread_title + "-" + filename}.is_open()) {
+ throw std::runtime_error("file \"" + filename + "\" not found");
+ }
+
+ return {{"message", "\tbeginning file transfer\n"},
+ {"success", "true"}};
+ }();
+
+ client.connection.send_packet(
+ shared::contents_to_packet(response_contents, "DWN"));
+
+ if (response_contents.find("success")->second != "true") {
+ return;
+ }
+
+ const auto sock = shared::accept_socket(get_rsock());
+ const std::string file_contents = [&]() -> std::string {
+ std::ifstream in{"./" + thread_title + "-" + filename};
+ if (!in.is_open()) {
+ throw std::runtime_error("file \"" + filename + "\" not found");
+ }
+ std::stringstream ss;
+ ss << in.rdbuf();
+ return ss.str();
+ }();
+ const shared::contents_t final_contents = {
+ {"file_contents", file_contents},
+ {"message", "\tfile downloaded successfully\n"},
+ {"success", "true"}};
+ shared::send_packet(shared::contents_to_packet(final_contents, "DWN"),
+ sock);
+ close(sock);
+}
+
+static void handle_remove_thread(client& client,
+ const shared::contents_t& contents) {
+ const auto thread_title_it = contents.find("thread_title");
+ if (thread_title_it == std::end(contents)) {
+ throw std::runtime_error("malformed request");
+ }
+ const std::string& thread_title = thread_title_it->second;
+
+ const auto response_contents = [&]() -> shared::contents_t {
+ auto& threads = get_known_threads();
+
+ const auto thread_it = threads.find(thread_title);
+ if (thread_it == std::end(threads)) {
+ return {{"message", "\ta thread with title \"" + thread_title +
+ "\" doesn't exist!\n"},
+ {"success", "false"}};
+ }
+
+ auto& thread = thread_it->second;
+ if (thread.original_poster != *client.username) {
+ return {{"message", "\tyou can't delete a thread you don't own!\n"},
+ {"success", "false"}};
+ }
+
+ for (const auto& message : thread.messages) {
+ if (!message.is_file) {
+ continue;
+ }
+ const std::string name = thread_title + '-' + message.contents;
+ remove(name.c_str());
+ }
+
+ threads.erase(thread_it);
+ remove(thread_title.c_str()); // remove file
+ return {{"message",
+ "\tthread \"" + thread_title + "\" removed successfully\n"},
+ {"success", "true"}};
+ }();
+ client.connection.send_packet(
+ shared::contents_to_packet(response_contents, "RDT"));
+}
+
+static void handle_exit(client& client, const shared::contents_t&) {
+ auto& clients = get_clients();
+ const auto info_ptr = &client.connection.get_info();
+ const auto find_it = std::find_if(
+ std::begin(clients), std::end(clients), [&](const auto& c) {
+ return std::memcmp(info_ptr, &c->connection.get_info(),
+ sizeof(decltype(*info_ptr))) == 0;
+ });
+ if (find_it == std::end(clients)) {
+ return; // already removed or not associated
+ }
+ clients.erase(find_it);
+}
+
+static void handle_client_packet(client& client, const shared::packet& packet) {
+ const std::string command = [&packet]() {
+ std::string command;
+ std::copy(std::begin(packet.header.command),
+ std::end(packet.header.command), std::back_inserter(command));
+ return command;
+ }();
+
+ static const std::unordered_map<std::string,
+ decltype(&handle_create_thread)>
+ commands{{"CRT", handle_create_thread},
+ {"MSG", handle_post_message},
+ {"DLT", handle_delete_message},
+ {"EDT", handle_edit_message},
+ {"LST", handle_list_threads},
+ {"RDT", handle_read_thread},
+ {"UPD", handle_upload_file},
+ {"DWN", handle_download_file},
+ {"RMV", handle_remove_thread},
+ {"ATH", handle_auth_client},
+ {"XIT", handle_exit}};
+
+ const auto find_it = commands.find(command);
+ if (find_it == std::end(commands)) {
+ std::cout << "got unknown command from client: " << command << '\n';
+ return;
+ }
+
+ std::cout << "got " << command << " from client: "
+ << (client.username != nullptr ? *client.username
+ : "*not logged in*")
+ << '\n';
+
+ const auto& func = find_it->second;
+ const shared::contents_t contents = shared::packet_to_contents(packet);
+
+ try {
+ func(client, contents);
+ } catch (const std::runtime_error& e) {
+ const std::string message =
+ "internal server error: " + std::string{e.what()};
+ const shared::contents_t error_contents{{"message", message},
+ {"success", "error"}};
+ client.connection.send_packet(
+ shared::contents_to_packet(error_contents, command.c_str()));
+ std::cout << message << '\n';
+ }
+}
+
+static void save() {
+ { // write creds
+ std::ofstream out{cred_path, std::ios_base::trunc};
+ if (!out.is_open()) {
+ throw std::runtime_error{"failed to write to credentials"};
+ }
+ for (const auto& keypair : get_known_users()) {
+ out << keypair.first + ' ' + keypair.second << '\n';
+ }
+ }
+ { // write threads
+ for (const auto& keypair : get_known_threads()) {
+ const auto& thread_name = keypair.first;
+ const auto& thread = keypair.second;
+ std::ofstream out{"./" + thread_name, std::ios_base::trunc};
+
+ if (!out.is_open()) {
+ throw std::runtime_error{"failed to write thread " +
+ thread_name + '\n'};
+ }
+
+ out << thread.original_poster << '\n';
+ for (const auto& message : thread.messages) {
+ out << (message.is_file ? 'F' : 'M');
+ out << message.poster << ' ' << message.contents << '\n';
+ }
+ }
+ }
+ get_clients().clear();
+}
+
+void do_server(const char* const address, const char* const port) {
+ const shared::socket_t usock = make_socket(address, port, SOCK_DGRAM);
+
+ { // init static vars before server starts to catch errors early
+ get_known_users();
+ get_known_threads();
+ get_rsock(address, port);
+ }
+
+ try {
+ auto& clients = get_clients();
+ while (!shared::should_exit) {
+ auto packet = shared::maybe_urecv_packet(usock);
+ if (packet == nullptr) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+ continue;
+ }
+
+ auto& origin = packet->origin;
+ // Compare the sockaddr_storage of the packet to existing clients.
+ const auto find_it = std::find_if(
+ std::begin(clients), std::end(clients),
+ [&](const auto& client) {
+ return std::memcmp(&client->connection.get_info(), &origin,
+ sizeof(decltype(origin))) == 0;
+ });
+
+ const auto& contents = packet->packet;
+
+ client& client = [&]() -> struct client& {
+ if (find_it == std::end(clients)) {
+ clients.push_back(std::make_unique<struct client>(
+ usock, std::move(origin)));
+ return *clients.back();
+ }
+ return *find_it->get();
+ }
+ ();
+
+ if (client.connection.should_discard_packet(contents)) {
+ continue;
+ }
+
+ handle_client_packet(client, contents);
+ }
+ } catch (const shared::should_exit_exception&) {
+ // gracefully exit on should_exit flag
+ }
+
+ save(); // save updated contents before exiting
+}
+
+} // namespace server
diff --git a/comp3331/server/src/server/server.hh b/comp3331/server/src/server/server.hh
new file mode 100644
index 0000000..e2554de
--- /dev/null
+++ b/comp3331/server/src/server/server.hh
@@ -0,0 +1,28 @@
+#ifndef SERVER_SERVER_HH_
+#define SERVER_SERVER_HH_
+
+#include <algorithm>
+#include <chrono>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <iterator>
+#include <sstream>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+#include <dirent.h>
+
+#include "server/client.hh"
+#include "shared/connection.hh"
+#include "shared/net.hh"
+#include "shared/shared.hh"
+
+namespace server {
+
+const char* const cred_path = "./credentials.txt";
+
+void do_server(const char* const address, const char* const port);
+} // namespace server
+
+#endif
diff --git a/comp3331/server/src/shared/connection.cc b/comp3331/server/src/shared/connection.cc
new file mode 100644
index 0000000..1063823
--- /dev/null
+++ b/comp3331/server/src/shared/connection.cc
@@ -0,0 +1,54 @@
+#include "shared/connection.hh"
+
+namespace shared {
+
+void connection::send_packet(packet&& packet) noexcept {
+ packet.header.sequence = this->seq_num;
+ shared::send_packet(packet, this->sock, this->info);
+ ++this->seq_num;
+
+ std::lock_guard<std::mutex> guard{*this->lock};
+ this->sent.push_back(packet);
+}
+
+bool connection::should_discard_packet(const packet& packet) noexcept {
+ std::lock_guard<std::mutex> guard{*this->lock};
+
+ // ack case
+ if (packet.header.command[0] == '\0') {
+ this->sent.erase(std::remove_if(std::begin(this->sent),
+ std::end(this->sent),
+ [&](const auto& p) {
+ return p.header.sequence <=
+ packet.header.sequence;
+ }),
+ std::end(this->sent));
+ return true;
+ }
+
+ // Send an ack for the packet if it's not an ack itself
+ auto ack_pkt = shared::contents_to_packet({}, "\0\0\0");
+ ack_pkt.header.sequence = packet.header.sequence;
+ shared::send_packet(ack_pkt, this->sock, this->info);
+
+ if (packet.header.sequence != this->ack_num) {
+ return true;
+ }
+ ++this->ack_num;
+ return false;
+}
+
+// Reliable transport reads our packets on a different thread and resends if
+// necessary.
+void connection::do_reliable_transport() noexcept {
+ while (!*this->should_thread_exit) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+ std::lock_guard<std::mutex> guard{*this->lock};
+ for (const auto& packet : this->sent) {
+ shared::send_packet(packet, this->sock, this->info);
+ }
+ }
+}
+
+} // namespace shared
diff --git a/comp3331/server/src/shared/connection.hh b/comp3331/server/src/shared/connection.hh
new file mode 100644
index 0000000..d7237cb
--- /dev/null
+++ b/comp3331/server/src/shared/connection.hh
@@ -0,0 +1,60 @@
+#ifndef SHARED_CONNECTION_HH_
+#define SHARED_CONNECTION_HH_
+
+#include <atomic>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include "shared/net.hh"
+
+namespace shared {
+
+// The connection class abstracts sending and receiving data, including reliable
+// transmission over UDP.
+class connection {
+private:
+ shared::socket_t sock;
+ sockaddr_in info;
+
+private:
+ std::uint32_t seq_num = 0; // track packet sequence number
+ std::uint32_t ack_num = 0;
+
+ // for reliable transport, spawn a new thread which reads sent/received
+ std::unique_ptr<std::atomic<bool>> should_thread_exit;
+ std::unique_ptr<std::mutex> lock;
+ std::vector<packet> sent;
+ std::vector<packet> received;
+ std::shared_ptr<std::thread> reliable_transport_thread;
+ void do_reliable_transport() noexcept;
+
+public:
+ connection(const socket_t& sock, sockaddr_in&& info)
+ : sock(sock), info(std::move(info)),
+ should_thread_exit(std::make_unique<std::atomic<bool>>(false)),
+ lock(std::make_unique<std::mutex>()),
+ reliable_transport_thread(std::make_shared<std::thread>(
+ &connection::do_reliable_transport, this)) {}
+
+ connection(const connection&) = delete;
+ connection(connection&&) = default;
+ ~connection() noexcept {
+ *this->should_thread_exit = true;
+ this->reliable_transport_thread->join();
+ }
+
+public:
+ const sockaddr_in& get_info() const noexcept { return this->info; }
+ const socket_t& get_socket() const noexcept { return this->sock; }
+
+public:
+ // All unreliable packets should go through these functions so we may track
+ // if our packets have been sent or received, making them reliable.
+ void send_packet(packet&& packet) noexcept;
+ bool should_discard_packet(const packet& packet) noexcept;
+};
+
+} // namespace shared
+
+#endif
diff --git a/comp3331/server/src/shared/net.cc b/comp3331/server/src/shared/net.cc
new file mode 100644
index 0000000..b71170a
--- /dev/null
+++ b/comp3331/server/src/shared/net.cc
@@ -0,0 +1,284 @@
+#include "shared/net.hh"
+
+namespace shared {
+
+static std::string get_errno_str() noexcept { return strerror(errno); }
+
+addrinfo_t make_addrinfo(const char* const address, const char* const port,
+ const addrinfo&& hints) {
+ addrinfo* info = nullptr;
+ if (int status = getaddrinfo(address, port, &hints, &info)) {
+ throw std::runtime_error{gai_strerror(status)};
+ }
+ return addrinfo_t{info, [](const auto& p) { freeaddrinfo(p); }};
+}
+
+socket_t make_socket(const addrinfo_t& info) {
+ const int sock =
+ socket(info->ai_family, info->ai_socktype, info->ai_protocol);
+ if (sock == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+
+ // to avoid annoying binding timeouts
+ const int enable = 1;
+ setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
+ setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
+ return sock;
+}
+
+socket_t accept_socket(const socket_t& lsock) {
+ sockaddr_in in{};
+ socklen_t size = sizeof(in);
+
+ const shared::socket_t sock =
+ accept(lsock, reinterpret_cast<sockaddr*>(&in), &size);
+ if (sock == -1) {
+ throw std::runtime_error(get_errno_str());
+ }
+ return sock;
+}
+
+void bind_socket(const socket_t sock, const addrinfo_t& info) {
+ // bind to first we can
+ for (const addrinfo* i = info.get(); i != nullptr; i = i->ai_next) {
+ if (bind(sock, info->ai_addr, info->ai_addrlen) == -1) {
+ continue;
+ }
+ return;
+ }
+ throw std::runtime_error{get_errno_str()};
+}
+
+void connect_socket(const socket_t sock, const addrinfo_t& info) {
+ if (connect(sock, info->ai_addr, info->ai_addrlen) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+void listen_socket(const socket_t sock) {
+ if (listen(sock, SOMAXCONN) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+void close_socket(const socket_t sock) {
+ if (close(sock) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+packet contents_to_packet(const contents_t& contents,
+ const char* const command) {
+ packet ret{};
+ std::memset(&ret.header, 0, sizeof(ret.header)); // 0init padding bytes
+
+ std::uint32_t size = 0;
+ for (const auto& keyvalue : contents) {
+ const auto& key = keyvalue.first; // no structured bindings :(
+ const auto& value = keyvalue.second;
+
+ const auto contents_size = static_cast<std::uint32_t>(value.size());
+
+ decltype(ret.contents) addition;
+ // Copy the size, name and data as expected, then push it.
+ std::copy(reinterpret_cast<const char* const>(&contents_size),
+ reinterpret_cast<const char* const>(&contents_size) +
+ sizeof(contents_size),
+ std::back_inserter(addition));
+ std::copy(key.c_str(), key.c_str() + key.size() + 1,
+ std::back_inserter(addition));
+ std::copy(value.c_str(), value.c_str() + value.size(),
+ std::back_inserter(addition));
+
+ size += addition.size(); // update size before moving
+ std::copy(std::make_move_iterator(std::begin(addition)),
+ std::make_move_iterator(std::end(addition)),
+ std::back_inserter(ret.contents));
+ }
+
+ // fill in expected header values
+ ret.header.size = size + sizeof(packet::header);
+ std::memcpy(ret.header.command, command, 3);
+
+ return ret;
+}
+
+contents_t packet_to_contents(const packet& packet) {
+ contents_t ret{};
+
+ // Extract the data as described in the header file.
+ for (auto i = 0u; i < packet.contents.size();) {
+ const char* const data = packet.contents.data() + i;
+
+ const std::uint32_t contents_size = [&]() {
+ std::uint32_t contents_size;
+ std::memcpy(&contents_size,
+ reinterpret_cast<const std::uint32_t* const>(data),
+ sizeof(std::uint32_t));
+ return contents_size;
+ }();
+
+ const auto size_size = sizeof(std::uint32_t);
+ const std::string name = data + size_size;
+ const char* const contents = data + size_size + name.length() + 1;
+
+ ret.emplace(name, std::string{contents, contents + contents_size});
+ i += size_size + name.size() + 1 + contents_size;
+ }
+ return ret;
+}
+
+static std::vector<char> packet_to_data(const packet& packet) noexcept {
+ std::vector<char> data{};
+ // data.reserve(sizeof(struct packet) + packet.contents.size());
+ std::copy(reinterpret_cast<const char* const>(&packet.header),
+ reinterpret_cast<const char* const>(&packet.header) +
+ sizeof(packet::header),
+ std::back_inserter(data));
+ std::copy(std::begin(packet.contents), std::end(packet.contents),
+ std::back_inserter(data));
+ return data;
+}
+
+void send_packet(const packet& packet, const socket_t& sock,
+ const sockaddr_in& dest) {
+ const std::vector<char> data = packet_to_data(packet);
+
+ if (sendto(sock, data.data(), data.size(), 0, (sockaddr*)&dest,
+ sizeof(sockaddr_in)) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+void send_packet(const packet& packet, const socket_t& sock) {
+ const std::vector<char> data = packet_to_data(packet);
+
+ const auto total_size = data.size();
+ for (unsigned sent = 0; sent < data.size();) {
+ const ssize_t res =
+ send(sock, data.data() + sent, total_size - sent, 0);
+ if (res != -1) {
+ sent += res;
+ continue;
+ }
+
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ continue;
+ }
+
+ throw std::runtime_error{get_errno_str()};
+ }
+}
+
+static std::size_t get_backlog_size(const socket_t& socket) noexcept {
+ size_t size = 0;
+ if (ioctl(socket, FIONREAD, &size) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+ return size;
+}
+
+std::shared_ptr<recv_packet_ret> maybe_urecv_packet(const socket_t& sock) {
+ const auto packet_size = get_backlog_size(sock);
+ if (packet_size <= 0) {
+ return nullptr;
+ }
+
+ recv_packet_ret ret;
+ unsigned int origin_len = sizeof(ret.origin);
+
+ std::vector<char> buffer;
+ buffer.reserve(packet_size);
+ if (recvfrom(sock, buffer.data(), packet_size, 0, (sockaddr*)&ret.origin,
+ &origin_len) == -1) {
+ throw std::runtime_error(get_errno_str());
+ }
+
+ std::memcpy(&ret.packet.header, buffer.data(), sizeof(packet::header));
+ ret.packet.contents.reserve(ret.packet.header.size -
+ sizeof(packet::header));
+ std::copy(buffer.data() + sizeof(packet::header),
+ buffer.data() + ret.packet.header.size,
+ std::back_inserter(ret.packet.contents));
+
+ return std::make_shared<recv_packet_ret>(std::move(ret));
+}
+
+// true when finished reading as our stream packets may be very large
+static bool maybe_rrecv_packet(const socket_t& sock, packet& packet) {
+
+ auto backlog_size = get_backlog_size(sock);
+ if (backlog_size <= 0) {
+ return false;
+ }
+
+ auto& target_size = packet.header.size;
+ if (!target_size) { // header read required
+ if (backlog_size < sizeof(header)) { // no header, try again later
+ return false;
+ }
+
+ if (read(sock, &packet.header, sizeof(header)) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+ backlog_size -= sizeof(header);
+ }
+
+ const auto read_size =
+ std::min(backlog_size, static_cast<unsigned long>(target_size));
+ std::vector<char> buffer;
+ buffer.reserve(read_size);
+ if (read(sock, buffer.data(), read_size) == -1) {
+ throw std::runtime_error{get_errno_str()};
+ }
+
+ std::copy(buffer.data(), buffer.data() + read_size,
+ std::back_inserter(packet.contents));
+ if (packet.contents.size() < packet.header.size - sizeof(header)) {
+ return false; // more data to read, do again
+ }
+
+ return true;
+}
+
+static const auto ms_timeout = std::chrono::seconds(30);
+std::shared_ptr<recv_packet_ret> urecv_packet(const socket_t& rsock,
+ const bool& timeout) {
+ // poll our non-blocking sockets
+ const auto start = std::chrono::steady_clock::now();
+ while (!shared::should_exit) {
+ const auto pkt = maybe_urecv_packet(rsock);
+ if (pkt != nullptr) {
+ return pkt;
+ }
+ if (timeout && std::chrono::steady_clock::now() > start + ms_timeout) {
+ throw std::runtime_error("urecv timeout elapsed");
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+ throw should_exit_exception();
+}
+
+std::shared_ptr<packet> rrecv_packet(const socket_t& rsock,
+ const bool& timeout) {
+ packet packet{};
+ auto last = std::chrono::steady_clock::now();
+ while (!shared::should_exit) {
+ const auto prev_size = packet.contents.size();
+ if (maybe_rrecv_packet(rsock, packet)) {
+ return std::make_shared<struct packet>(std::move(packet));
+ }
+ if (prev_size != packet.contents.size()) {
+ last = std::chrono::steady_clock::now();
+ }
+ if (timeout && std::chrono::steady_clock::now() > last + ms_timeout) {
+ throw std::runtime_error("rrecv timeout elapsed");
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+ throw should_exit_exception();
+}
+
+} // namespace shared
diff --git a/comp3331/server/src/shared/net.hh b/comp3331/server/src/shared/net.hh
new file mode 100644
index 0000000..b17ca1f
--- /dev/null
+++ b/comp3331/server/src/shared/net.hh
@@ -0,0 +1,87 @@
+#ifndef SHARED_NET_HH_
+#define SHARED_NET_HH_
+
+#include <algorithm>
+#include <arpa/inet.h>
+#include <chrono>
+#include <cstdint>
+#include <cstring>
+#include <errno.h>
+#include <fcntl.h>
+#include <memory>
+#include <netdb.h>
+#include <optional>
+#include <stdexcept>
+#include <string.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <thread>
+#include <type_traits>
+#include <unistd.h>
+#include <unordered_map>
+#include <vector>
+
+#include "shared/shared.hh"
+
+// Functions in this namespace are common network-related wrappers with error
+// checking.
+namespace shared {
+
+using addrinfo_t = std::shared_ptr<addrinfo>; // for automatic freeaddrinfo call
+addrinfo_t make_addrinfo(const char* const address, const char* const port,
+ const addrinfo&& hints);
+
+using socket_t = int;
+int make_socket(const addrinfo_t& info);
+int accept_socket(const socket_t& sock);
+void bind_socket(const socket_t sock, const addrinfo_t& info);
+void connect_socket(const socket_t sock, const addrinfo_t& info);
+void listen_socket(const socket_t sock);
+void close_socket(const socket_t sock);
+
+struct header {
+ std::uint32_t size = 0; // size of packet, including header.
+ std::uint32_t sequence; // sequence number of packet
+ char command[3]; // command type, if \0\0\0 then it's an ack
+};
+static_assert(std::is_trivially_copyable<header>::value,
+ "header must be memcpy-able");
+struct packet {
+ header header;
+ std::vector<char> contents;
+};
+
+// Our packets contents consist of data entries like so:
+// std::uint32_t | char[] | DATA
+// ^ size of data ^ name ^ data, which repeats for size length
+// Any packet may contain multiple, or zero, entries. The name can be any
+// length. This way we don't have to define structs with hardcoded length
+// limits, and we can use this format when sending anything, even files.
+// Numeric values will be encoded as strings and converted when necessary.
+using contents_t = std::unordered_map<std::string, std::string>;
+void send_packet(const packet& packet, const socket_t& sock,
+ const sockaddr_in& dest);
+void send_packet(const packet& packet, const socket_t& sock);
+
+packet contents_to_packet(const contents_t& contents,
+ const char* const command);
+contents_t packet_to_contents(const packet& packet);
+
+// Recv's sockets, might return nullptr if no packet available.
+struct recv_packet_ret {
+ sockaddr_in origin;
+ struct packet packet;
+};
+
+// non-blocking
+std::shared_ptr<recv_packet_ret> maybe_urecv_packet(const socket_t& sock);
+// blocking, will throw if timeout is elapsed
+std::shared_ptr<recv_packet_ret> urecv_packet(const socket_t& rsock,
+ const bool& timeout = true);
+std::shared_ptr<packet> rrecv_packet(const socket_t& rsock,
+ const bool& timeout = true);
+
+} // namespace shared
+
+#endif
diff --git a/comp3331/server/src/shared/shared.cc b/comp3331/server/src/shared/shared.cc
new file mode 100644
index 0000000..3a13b47
--- /dev/null
+++ b/comp3331/server/src/shared/shared.cc
@@ -0,0 +1,25 @@
+#include "shared/shared.hh"
+
+namespace shared {
+
+bool should_exit = false;
+
+static void set_signal(const decltype(SIGINT) signal,
+ void (*const callback)(const int)) {
+ struct sigaction sa {};
+ sa.sa_handler = callback;
+ if (sigaction(signal, &sa, nullptr) == -1) {
+ throw std::runtime_error("failed to set signal handler!");
+ }
+}
+
+void set_exit_handler() {
+ set_signal(SIGPIPE, SIG_IGN);
+ set_signal(SIGINT, [](const int) {
+ std::cout << " interrupt signal received\n";
+ should_exit = true;
+ });
+}
+
+} // namespace shared
+
diff --git a/comp3331/server/src/shared/shared.hh b/comp3331/server/src/shared/shared.hh
new file mode 100644
index 0000000..8e4fa36
--- /dev/null
+++ b/comp3331/server/src/shared/shared.hh
@@ -0,0 +1,28 @@
+#ifndef SHARED_SHARED_HH_
+#define SHARED_SHARED_HH_
+
+#include <functional>
+#include <iostream>
+#include <signal.h>
+#include <stdexcept>
+
+namespace shared {
+extern bool should_exit;
+
+void set_exit_handler();
+
+class should_exit_exception : public std::exception {};
+
+// This won't exist until c++24 lol
+class scoped_function {
+private:
+ using func_t = std::function<void()>;
+ func_t func;
+
+public:
+ scoped_function(const func_t& f) : func(f) {}
+ ~scoped_function() { this->func(); }
+};
+} // namespace shared
+
+#endif