diff options
| author | Nicolas James <Eele1Ephe7uZahRie@tutanota.com> | 2025-02-13 18:00:17 +1100 |
|---|---|---|
| committer | Nicolas James <Eele1Ephe7uZahRie@tutanota.com> | 2025-02-13 18:00:17 +1100 |
| commit | 98cef5e9a772602d42acfcf233838c760424db9a (patch) | |
| tree | 5277fa1d7cc0a69a0f166fcbf10fd320f345f049 /comp3331/server | |
initial commit
Diffstat (limited to 'comp3331/server')
| -rw-r--r-- | comp3331/server/CMakeLists.txt | 44 | ||||
| -rw-r--r-- | comp3331/server/src/client/client.cc | 519 | ||||
| -rw-r--r-- | comp3331/server/src/client/client.hh | 24 | ||||
| -rw-r--r-- | comp3331/server/src/client/main.cc | 28 | ||||
| -rw-r--r-- | comp3331/server/src/client/main.hh | 11 | ||||
| -rw-r--r-- | comp3331/server/src/server/client.cc | 1 | ||||
| -rw-r--r-- | comp3331/server/src/server/client.hh | 24 | ||||
| -rw-r--r-- | comp3331/server/src/server/main.cc | 28 | ||||
| -rw-r--r-- | comp3331/server/src/server/main.hh | 9 | ||||
| -rw-r--r-- | comp3331/server/src/server/server.cc | 829 | ||||
| -rw-r--r-- | comp3331/server/src/server/server.hh | 28 | ||||
| -rw-r--r-- | comp3331/server/src/shared/connection.cc | 54 | ||||
| -rw-r--r-- | comp3331/server/src/shared/connection.hh | 60 | ||||
| -rw-r--r-- | comp3331/server/src/shared/net.cc | 284 | ||||
| -rw-r--r-- | comp3331/server/src/shared/net.hh | 87 | ||||
| -rw-r--r-- | comp3331/server/src/shared/shared.cc | 25 | ||||
| -rw-r--r-- | comp3331/server/src/shared/shared.hh | 28 |
17 files changed, 2083 insertions, 0 deletions
diff --git a/comp3331/server/CMakeLists.txt b/comp3331/server/CMakeLists.txt new file mode 100644 index 0000000..60c174e --- /dev/null +++ b/comp3331/server/CMakeLists.txt @@ -0,0 +1,44 @@ +cmake_minimum_required(VERSION 3.10) + +project(comp3331_assignment) + +set(CMAKE_C_COMPILER "clang") +set(CMAKE_CXX_COMPILER "clang++") + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +#set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/bin") + +include_directories( + "${PROJECT_SOURCE_DIR}/src" +) + +file (GLOB_RECURSE SHARED_SOURCE_FILES CONFIGURE_DEPENDS + "${PROJECT_SOURCE_DIR}/src/shared/*.cc" +) + +foreach(BINARY_NAME client;server) + file (GLOB_RECURSE BINARY_SOURCE_FILES CONFIGURE_DEPENDS + "${PROJECT_SOURCE_DIR}/src/${BINARY_NAME}/*.cc" + ) + add_executable(${BINARY_NAME} + ${BINARY_SOURCE_FILES} + ${SHARED_SOURCE_FILES} + ) + + target_compile_options(${BINARY_NAME} PRIVATE + -Wall -Wextra -Wshadow -Wdouble-promotion -Wformat=2 -Wundef -fno-common + -Wconversion -Wpedantic -std=c++14 -g3 -O2 + -fstack-protector-strong -fno-omit-frame-pointer -fsanitize=undefined + -Wno-exceptions + ) + target_link_libraries(${BINARY_NAME} PRIVATE + #pthread + ) + target_link_options(${BINARY_NAME} PRIVATE + -fstack-protector-strong -fsanitize=undefined + ) +endforeach() + diff --git a/comp3331/server/src/client/client.cc b/comp3331/server/src/client/client.cc new file mode 100644 index 0000000..9ec2ab2 --- /dev/null +++ b/comp3331/server/src/client/client.cc @@ -0,0 +1,519 @@ +#include "client/client.hh" + +namespace client { + +// Args passed to handle functions, the input line split by spacebars. Similar +// to argv, the first is the program name (or in this case command name). + +static std::string read_cin(const char* const msg) noexcept { + std::cout << msg << ": "; + std::string line; + std::getline(std::cin, line); + return line; +} + +static std::pair<std::string, std::string>& +get_address_port(const char* const address = nullptr, + const char* const port = nullptr) { + static std::pair<std::string, std::string> pair = [&]() { + return std::make_pair<std::string, std::string>(address, port); + }(); + return pair; +} + +static shared::contents_t urecv_contents(shared::connection& connection) { + const auto ret = [&]() { + for (;;) { + const auto ret = shared::urecv_packet(connection.get_socket()); + if (connection.should_discard_packet(ret->packet)) { + continue; + } + return ret; + } + }(); + const shared::contents_t contents = shared::packet_to_contents(ret->packet); + + // error checking early out + const auto find_it = contents.find("success"); + if (find_it != std::end(contents)) { + if (find_it->second == "error") { + throw std::runtime_error(contents.find("message")->second); + } + } + + return contents; +} + +static shared::socket_t get_new_rsock() { + const auto& addr_port = get_address_port(); + const auto info = [&]() { + addrinfo a{}; + a.ai_flags = AI_PASSIVE; + a.ai_family = AF_INET; + a.ai_socktype = SOCK_STREAM; + return shared::make_addrinfo(addr_port.first.c_str(), + addr_port.second.c_str(), std::move(a)); + }(); + const shared::socket_t socket = shared::make_socket(info); + shared::connect_socket(socket, info); + return socket; +} + +using args_t = std::vector<std::string>; +static void handle_create_thread(const args_t& args, + shared::connection& connection) { + if (args.size() != 2) { + throw std::invalid_argument("invalid syntax"); + } + + const std::string& thread_title = args[1]; + { // send request + const shared::contents_t contents{ + {"thread_title", thread_title}, + }; + connection.send_packet(shared::contents_to_packet(contents, "CRT")); + } + + const shared::contents_t contents = urecv_contents(connection); + + const auto message_it = contents.find("message"); + const auto success_it = contents.find("success"); + if (message_it == std::end(contents) || success_it == std::end(contents)) { + std::cout << "\treceived bad response from the server, try again\n"; + return; + } + + std::cout << message_it->second; +} + +static void handle_post_message(const args_t& args, + shared::connection& connection) { + if (args.size() < 3) { + throw std::invalid_argument("invalid syntax"); + } + + const std::string& thread_title = args[1]; + const std::string message = [&]() { + std::string message = args[2]; + for (auto i = 3ul; i < args.size(); ++i) { + message += ' ' + args[i]; + } + return message; + }(); + + connection.send_packet(shared::contents_to_packet( + {{"thread_title", thread_title}, {"message", message}}, "MSG")); + + const shared::contents_t contents = urecv_contents(connection); + const auto message_it = contents.find("message"); + if (message_it == std::end(contents)) { + std::cout << "\treceived bad response from the server, try again\n"; + return; + } + + std::cout << message_it->second; +} + +static void handle_delete_message(const args_t& args, + shared::connection& connection) { + if (args.size() != 3) { + throw std::invalid_argument("invalid syntax"); + } + + const std::string& thread_title = args[1]; + const std::string& message_number = args[2]; + + try { + std::stoi(message_number); + } catch (...) { + throw std::invalid_argument{"failed to parse \"" + message_number + + "\" as an integer"}; + } + + connection.send_packet(shared::contents_to_packet( + {{"thread_title", thread_title}, {"message_number", message_number}}, + "DLT")); + + const shared::contents_t contents = urecv_contents(connection); + const auto message_it = contents.find("message"); + if (message_it == std::end(contents)) { + std::cout << "\treceived bad response from the server, try again\n"; + return; + } + + std::cout << message_it->second; +} + +static void handle_edit_message(const args_t& args, + shared::connection& connection) { + if (args.size() < 4) { + throw std::invalid_argument("invalid syntax"); + } + + const std::string& thread_title = args[1]; + const std::string& message_number = args[2]; + const std::string message = [&]() { + std::string message = args[3]; + for (auto i = 4ul; i < args.size(); ++i) { + message += ' ' + args[i]; + } + return message; + }(); + + try { + std::stoi(message_number); + } catch (...) { + throw std::invalid_argument{"failed to parse \"" + message_number + + "\" as an integer"}; + } + + connection.send_packet( + shared::contents_to_packet({{"thread_title", thread_title}, + {"message_number", message_number}, + {"message", message}}, + "EDT")); + + const shared::contents_t contents = urecv_contents(connection); + const auto message_it = contents.find("message"); + if (message_it == std::end(contents)) { + std::cout << "\treceived bad response from the server, try again\n"; + return; + } + + std::cout << message_it->second; +} + +static void handle_list_threads(const args_t& args, + shared::connection& connection) { + if (args.size() != 1) { + throw std::invalid_argument("invalid syntax"); + } + + { connection.send_packet(shared::contents_to_packet({}, "LST")); } + + const shared::contents_t contents = urecv_contents(connection); + const auto message_it = contents.find("message"); + if (message_it == std::end(contents)) { + std::cout << "\treceived bad response from the server, try again\n"; + return; + } + + std::cout << message_it->second; +} + +static void handle_read_thread(const args_t& args, + shared::connection& connection) { + if (args.size() != 2) { + throw std::invalid_argument("invalid syntax"); + } + + const std::string& thread_title = args[1]; + { // send request + const shared::contents_t contents{ + {"thread_title", thread_title}, + }; + connection.send_packet(shared::contents_to_packet(contents, "RDT")); + } + + const shared::contents_t contents = urecv_contents(connection); + + const auto message_it = contents.find("message"); + const auto success_it = contents.find("success"); + if (message_it == std::end(contents) || success_it == std::end(contents)) { + std::cout << "\treceived bad response from the server, try again\n"; + return; + } + + std::cout << message_it->second; +} + +static void handle_upload_file(const args_t& args, + shared::connection& connection) { + if (args.size() != 3) { + throw std::invalid_argument("invalid syntax"); + } + + const std::string& thread_title = args[1]; + const std::string& filename = args[2]; + const std::string file_contents = [&]() -> std::string { + std::ifstream in{"./" + filename}; + if (!in.is_open()) { + throw std::invalid_argument("file not found"); + } + std::stringstream ss; + ss << in.rdbuf(); + return ss.str(); + }(); + + { + const shared::contents_t contents{{"thread_title", thread_title}, + {"filename", filename}}; + connection.send_packet(shared::contents_to_packet(contents, "UPD")); + } + + { + const shared::contents_t contents = urecv_contents(connection); + + const auto message_it = contents.find("message"); + const auto success_it = contents.find("success"); + if (message_it == std::end(contents) || + success_it == std::end(contents)) { + std::cout << "\treceived bad response from the server, try again\n"; + return; + } + std::cout << message_it->second; + if (success_it->second != "true") { + return; + } + } + + { // send file + const shared::socket_t rsock = get_new_rsock(); + shared::scoped_function close_rsock{[rsock]() { close(rsock); }}; + const shared::contents_t contents{{"file_contents", file_contents}}; + const shared::packet packet = + shared::contents_to_packet(contents, "UPD"); + shared::send_packet(packet, rsock); + } + + { + const shared::contents_t contents = urecv_contents(connection); + + const auto message_it = contents.find("message"); + const auto success_it = contents.find("success"); + if (message_it == std::end(contents) || + success_it == std::end(contents)) { + std::cout << "\treceived bad response from the server, try again\n"; + return; + } + std::cout << message_it->second; + } +} + +static void handle_download_file(const args_t& args, + shared::connection& connection) { + if (args.size() != 3) { + throw std::invalid_argument("invalid syntax"); + } + + const std::string& thread_title = args[1]; + const std::string& filename = args[2]; + { + const shared::contents_t contents{{"thread_title", thread_title}, + {"filename", filename}}; + connection.send_packet(shared::contents_to_packet(contents, "DWN")); + } + + { + const shared::contents_t contents = urecv_contents(connection); + + const auto message_it = contents.find("message"); + const auto success_it = contents.find("success"); + if (message_it == std::end(contents) || + success_it == std::end(contents)) { + std::cout << "\treceived bad response from the server, try again\n"; + return; + } + std::cout << message_it->second; + if (success_it->second != "true") { + return; + } + } + + const auto file_contents = []() { + const shared::socket_t rsock = get_new_rsock(); + shared::scoped_function close_rsock{[rsock]() { close(rsock); }}; + const auto packet = shared::rrecv_packet(rsock); + return packet_to_contents(*packet); + }(); + + const auto data_it = file_contents.find("file_contents"); + const auto message_it = file_contents.find("message"); + if (data_it == std::end(file_contents) || + message_it == std::end(file_contents)) { + return; + } + + std::ofstream out{filename, std::ios_base::trunc}; + if (!out.is_open()) { + throw std::runtime_error{"failed to write file " + filename}; + } + out << data_it->second; + std::cout << message_it->second; +} + +static void handle_remove_thread(const args_t& args, + shared::connection& connection) { + if (args.size() != 2) { + throw std::invalid_argument("invalid syntax"); + } + + const auto& thread_title = args[1]; + { + const shared::contents_t contents{ + {"thread_title", thread_title}, + }; + connection.send_packet(shared::contents_to_packet(contents, "RMV")); + } + + const shared::contents_t contents = urecv_contents(connection); + + const auto message_it = contents.find("message"); + const auto success_it = contents.find("success"); + if (message_it == std::end(contents) || success_it == std::end(contents)) { + std::cout << "\treceived bad response from the server, try again\n"; + return; + } + + std::cout << message_it->second; +} + +static void authenticate(shared::connection& connection) { + const auto send_auth_var = [&](const char* const varname) { + for (;;) { // send username + if (shared::should_exit) { + throw shared::should_exit_exception{}; + } + + const std::string var = read_cin(varname); + if (var.length() == 0) { + continue; + } + + { // request server for status on username + const shared::contents_t contents{{varname, var}}; + connection.send_packet( + shared::contents_to_packet(contents, "ATH")); + } + + const shared::contents_t contents = urecv_contents(connection); + const auto success_it = contents.find("success"); + const auto message_it = contents.find("message"); + if (message_it == std::end(contents) || + success_it == std::end(contents)) { + std::cout + << "\treceived bad response from the server, try again\n"; + continue; + } + std::cout << message_it->second; + + if (success_it->second == "true") { + break; + } + } + }; + + send_auth_var("username"); + send_auth_var("password"); +} + +static void interact(shared::connection& connection) { + static const std::unordered_map<std::string, + decltype(&handle_create_thread)> + commands{{"CRT", handle_create_thread}, {"MSG", handle_post_message}, + {"DLT", handle_delete_message}, {"EDT", handle_edit_message}, + {"LST", handle_list_threads}, {"RDT", handle_read_thread}, + {"UPD", handle_upload_file}, {"DWN", handle_download_file}, + {"RMV", handle_remove_thread}}; + const auto print_commands = [&]() { + std::cout << "\tavailable commands: "; + for (const auto& pair : commands) { + const auto& key = pair.first; + std::cout << key << ' '; + } + std::cout << "(or \'XIT\' to quit)\n"; + }; + + print_commands(); + for (;;) { + if (shared::should_exit) { + throw shared::should_exit_exception{}; + } + const std::string input = read_cin("command"); + if (input.length() <= 0) { + continue; + } + + const std::vector<std::string> args = [&]() { // string split args + std::vector<std::string> ret; + + std::stringstream ss{std::move(input)}; + std::string arg; + while (std::getline(ss, arg, ' ')) { + if (arg.size() <= 0) { + continue; + } + ret.emplace_back(std::move(arg)); + } + + return ret; + }(); + + if (args.size() <= 0) { + continue; + } + + const std::string command = [&]() { + std::string command; + std::transform(std::begin(args[0]), std::end(args[0]), + std::back_inserter(command), + [](const char c) { return toupper(c); }); + return command; + }(); + + if (command == "XIT") { + break; + } + + const auto find_it = commands.find(command); + if (find_it == std::end(commands)) { // unknown command, print commands + print_commands(); + continue; + } + + try { + const auto& func = find_it->second; + func(args, connection); + } catch (const std::invalid_argument& e) { + std::cout << "\tbad arguments: " << e.what() << '\n'; + } catch (const std::runtime_error& e) { + std::cout << "\truntime error: " << e.what() << '\n'; + } + } +} + +void do_client(const char* const address, const char* const port) { + const shared::socket_t usock = [&]() -> shared::socket_t { + get_address_port(address, port); + const auto info = [&]() { // no designated initialisers :( + addrinfo a{}; + a.ai_flags = AI_PASSIVE; + a.ai_family = AF_INET; + a.ai_socktype = SOCK_DGRAM; + return shared::make_addrinfo("0.0.0.0", 0, std::move(a)); + }(); + const shared::socket_t socket = shared::make_socket(info); + shared::bind_socket(socket, info); + return socket; + }(); + + shared::connection connection = [&]() { + sockaddr_in dest{}; + dest.sin_family = AF_INET; + dest.sin_port = htons(static_cast<std::uint16_t>(std::stoi(port))); + inet_aton(address, &dest.sin_addr); + return shared::connection{usock, std::move(dest)}; + }(); + + try { + authenticate(connection); + interact(connection); + } catch (const shared::should_exit_exception&) { + // exit gracefully on should_exit flag + } + + // upload exit before leaving + connection.send_packet(shared::contents_to_packet({}, "XIT")); +} + +} // namespace client diff --git a/comp3331/server/src/client/client.hh b/comp3331/server/src/client/client.hh new file mode 100644 index 0000000..8230e46 --- /dev/null +++ b/comp3331/server/src/client/client.hh @@ -0,0 +1,24 @@ +#ifndef CLIENT_CLIENT_HH_ +#define CLIENT_CLIENT_HH_ + +#include <algorithm> +#include <arpa/inet.h> +#include <chrono> +#include <fstream> +#include <iostream> +#include <sstream> +#include <string> +#include <thread> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "shared/connection.hh" +#include "shared/net.hh" +#include "shared/shared.hh" + +namespace client { +void do_client(const char* const address, const char* const port); +} + +#endif diff --git a/comp3331/server/src/client/main.cc b/comp3331/server/src/client/main.cc new file mode 100644 index 0000000..a452c3e --- /dev/null +++ b/comp3331/server/src/client/main.cc @@ -0,0 +1,28 @@ +#include "client/main.hh" + +using namespace client; + +int main(const int argc, const char* const argv[]) { + + if (argc != 2) { + std::cerr << "usage: ./client PORT<int>\n"; + return EXIT_SUCCESS; + } + + const char* const address = "localhost"; // not an argument fsr + const char* const port = argv[1]; + + try { + shared::set_exit_handler(); + do_client(address, port); + } catch (const std::exception& e) { + std::cerr << "caught exception from client!\n\twhat(): " << e.what() + << '\n'; + return EXIT_FAILURE; + } catch (...) { + std::cerr << "unhandled exception from client!\n"; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} diff --git a/comp3331/server/src/client/main.hh b/comp3331/server/src/client/main.hh new file mode 100644 index 0000000..1e04e74 --- /dev/null +++ b/comp3331/server/src/client/main.hh @@ -0,0 +1,11 @@ +#ifndef CLIENT_MAIN_HH_ +#define CLIENT_MAIN_HH_ + +#include <exception> +#include <iostream> +#include <string> + +#include "client/client.hh" +#include "shared/shared.hh" + +#endif diff --git a/comp3331/server/src/server/client.cc b/comp3331/server/src/server/client.cc new file mode 100644 index 0000000..0197f90 --- /dev/null +++ b/comp3331/server/src/server/client.cc @@ -0,0 +1 @@ +#include "server/client.hh" diff --git a/comp3331/server/src/server/client.hh b/comp3331/server/src/server/client.hh new file mode 100644 index 0000000..8fa57f2 --- /dev/null +++ b/comp3331/server/src/server/client.hh @@ -0,0 +1,24 @@ +#ifndef SERVER_CLIENT_HH_ +#define SERVER_CLIENT_HH_ + +#include <memory> +#include <string> + +#include "shared/connection.hh" +#include "shared/net.hh" + +namespace server { +struct client { + shared::connection connection; + std::unique_ptr<std::string> username = nullptr; // non-null = authenticated + +public: + client(const shared::socket_t& sock, sockaddr_in&& peer) + : connection(sock, std::move(peer)) {} + +public: + bool is_authenticated() const noexcept { return username != nullptr; } +}; +} // namespace server + +#endif diff --git a/comp3331/server/src/server/main.cc b/comp3331/server/src/server/main.cc new file mode 100644 index 0000000..764a732 --- /dev/null +++ b/comp3331/server/src/server/main.cc @@ -0,0 +1,28 @@ +#include "server/main.hh" + +using namespace server; + +int main(const int argc, const char* const argv[]) { + + if (argc != 2) { + std::cerr << "usage: ./server PORT<int>\n"; + return EXIT_SUCCESS; + } + + const char* const address = "0.0.0.0"; + const char* const port = argv[1]; + + try { + shared::set_exit_handler(); + do_server(address, port); + } catch (const std::exception& e) { + std::cerr << "caught exception from server!\n\twhat(): " << e.what() + << '\n'; + return EXIT_FAILURE; + } catch (...) { + std::cerr << "unhandled exception from server!\n"; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} diff --git a/comp3331/server/src/server/main.hh b/comp3331/server/src/server/main.hh new file mode 100644 index 0000000..2714ef8 --- /dev/null +++ b/comp3331/server/src/server/main.hh @@ -0,0 +1,9 @@ +#ifndef SERVER_MAIN_HH_ +#define SERVER_MAIN_HH_ + +#include <iostream> + +#include "server/server.hh" +#include "shared/shared.hh" + +#endif diff --git a/comp3331/server/src/server/server.cc b/comp3331/server/src/server/server.cc new file mode 100644 index 0000000..a9f2ae9 --- /dev/null +++ b/comp3331/server/src/server/server.cc @@ -0,0 +1,829 @@ +#include "server/server.hh" + +namespace server { + +static shared::socket_t make_socket(const char* const address, + const char* const port, + const decltype(addrinfo::ai_socktype)& t) { + const auto info = [&]() { // no designated initialisers :( + addrinfo a{}; + a.ai_flags = AI_PASSIVE; + a.ai_family = AF_INET; + a.ai_socktype = t; + return shared::make_addrinfo(address, port, std::move(a)); + }(); + const shared::socket_t socket = shared::make_socket(info); + shared::bind_socket(socket, info); + return socket; +} + +// default args only used during initialiser +static shared::socket_t& get_rsock(const char* const address = nullptr, + const char* const port = nullptr) { + static shared::socket_t rsock = [&]() { + const shared::socket_t rsock = make_socket(address, port, SOCK_STREAM); + shared::listen_socket(rsock); + return rsock; + }(); + return rsock; +} + +using clients_t = std::vector<std::unique_ptr<client>>; +static clients_t& get_clients() noexcept { + static clients_t clients; + return clients; +} + +using users_t = std::unordered_map<std::string, std::string>; +static users_t& get_known_users() { + static users_t users = []() { + users_t users; + + std::ifstream in{cred_path}; + if (!in.is_open()) { + throw std::runtime_error{"failed to open credentials"}; + } + + std::stringstream ss; + ss << in.rdbuf(); + + for (std::string line; std::getline(ss, line);) { + if (line.empty()) { + continue; + } + + const auto split = static_cast<long>(line.find_first_of(' ')); + if (static_cast<unsigned long>(split) == std::string::npos) { + continue; + } + + const std::string username{std::begin(line), + std::begin(line) + split}; + std::string password{std::next(std::begin(line) + split), + std::end(line)}; + + password.erase( + std::remove_if(std::begin(password), std::end(password), + [](const char c) { return std::isspace(c); }), + std::end(password)); + + if (username.length() == 0 || password.length() == 0) { + continue; // bad line? + } + users.emplace(username, password); + } + return users; + }(); + return users; +} + +struct message { + std::string poster; + std::string contents; // contents is a filename if is_file is true + bool is_file = false; +}; +struct thread { + std::string original_poster; + std::vector<message> messages; +}; +using threads_t = std::unordered_map<std::string, struct thread>; +static threads_t& get_known_threads() { + static threads_t threads = []() { + // We have to read all files in the current directory and attempt to + // parse them as threads. We will inevitably read cmake and makefiles, + // so early out if they are broken. Also, as we are writing this in + // C++14, we will have to resort to some C library calls. + threads_t threads; + std::shared_ptr<DIR> directory{opendir("./"), + [](auto& d) { closedir(d); }}; + if (directory == nullptr) { + throw std::runtime_error("failed to open current directory"); + } + + for (dirent* entry = readdir(directory.get()); entry != nullptr; + entry = readdir(directory.get())) { + if (entry->d_type != DT_REG) { + continue; + } + + std::ifstream in{entry->d_name}; + if (!in.is_open()) { + continue; // probs bad permissions + } + + std::stringstream ss; + ss << in.rdbuf(); + + // Read OP whose username is the first line in the file. + std::string line; + std::getline(ss, line, '\n'); + if (line.empty() || + std::any_of(std::begin(line), std::end(line), + [](const auto& c) { return std::isspace(c); })) { + continue; // bad username or bad file + } + + thread thread{}; + thread.original_poster = line; + + const auto maybe_parse_messages = [&]() { // return false if bad + while (std::getline(ss, line, '\n')) { + if (line.empty()) { + continue; + } + + // we differentiate between files and messages with this + const bool is_file = line[0] == 'F'; + + const auto split = + static_cast<long>(line.find_first_of(' ', 1)); + if (static_cast<unsigned long>(split) == + std::string::npos) { + return false; // filter out bad files + } + + const std::string username{std::next(std::begin(line)), + std::begin(line) + split}; + const std::string contents{ + std::next(std::begin(line) + split), std::end(line)}; + + if (username.length() == 0 || contents.length() == 0) { + continue; // bad line? + } + + thread.messages.push_back( + message{username, contents, is_file}); + } + return true; + }; + + if (!maybe_parse_messages()) { + continue; + } + + threads.emplace(entry->d_name, std::move(thread)); + } + + return threads; + }(); + return threads; +} + +// returns an iterator to the message if it exists, or std::end +static auto get_message_it(decltype(thread::messages)& messages, + const int& message_number) noexcept { + return std::find_if(std::begin(messages), std::end(messages), + [&, i = 0](const auto& message) mutable { + if (message.is_file) { + return false; + } + if (i == message_number) { + return true; + } + ++i; + return false; + }); +} + +static void handle_auth_client(client& client, + const shared::contents_t& contents) { + + auto& known_users = get_known_users(); + const auto& clients = get_clients(); + const auto response_contents = [&]() -> shared::contents_t { + if (client.username == nullptr) { + const auto username_it = contents.find("username"); + if (username_it == std::end(contents)) { + return {{"message", "\tno username provided!\n"}, + {"success", "false"}}; + } + const auto& username = username_it->second; + + if (std::any_of(std::begin(clients), std::end(clients), + [&](const auto& client) { + if (client->username == nullptr) { + return false; + } + return *client->username == username; + })) { + std::cout << "\tmultiple login attempt for: " << username + << '\n'; + return {{"message", "\tuser already logged in\n"}, + {"success", "false"}}; + } + + client.username = + std::make_unique<std::string>(username_it->second); + const auto find_it = known_users.find(username); + if (find_it == std::end(known_users)) { + return { + {"message", "\twelcome new user \"" + username + "\"\n"}, + {"success", "true"}}; + } + return {{"message", "\twelcome back user \"" + username + "\"\n"}, + {"success", "true"}}; + } + + const auto password_it = contents.find("password"); + if (password_it == std::end(contents)) { + return {{"message", "\tno password provided!\n"}, + {"success", "false"}}; + } + const auto& password = password_it->second; + + const auto find_it = known_users.find(*client.username); + if (find_it == std::end(known_users)) { + known_users.emplace(*client.username, password); + return {{"message", "\taccount created successfully\n"}, + {"success", "true"}}; + } + + const auto& prev_password = find_it->second; + if (prev_password != password) { + std::cout << "\tunsuccessful login for user: " << *client.username + << '\n'; + return {{"message", "\tusername password mismatch!\n"}, + {"success", "false"}}; + } + + std::cout << "\tsuccessful login for: " << *client.username << '\n'; + return {{"message", + "\tsuccessful login for user \"" + *client.username + "\"\n"}, + {"success", "true"}}; + }(); + + client.connection.send_packet( + shared::contents_to_packet(response_contents, "ATH")); +} + +static void handle_create_thread(client& client, + const shared::contents_t& contents) { + const auto thread_title_it = contents.find("thread_title"); + if (thread_title_it == std::end(contents)) { + throw std::runtime_error("malformed request"); + } + const std::string& thread_title = thread_title_it->second; + + const auto response_contents = [&]() -> shared::contents_t { + auto& threads = get_known_threads(); + if (threads.find(thread_title) != std::end(threads)) { + return {{"message", "\ta thread with title \"" + thread_title + + "\" already exists!\n"}, + {"success", "false"}}; + } + + thread thread{}; + thread.original_poster = *client.username; + threads.emplace(thread_title, std::move(thread)); + + return {{"message", "\tthread \"" + thread_title + "\" created\n"}, + {"success", "true"}}; + }(); + client.connection.send_packet( + shared::contents_to_packet(response_contents, "CRT")); +} + +static void handle_post_message(client& client, + const shared::contents_t& contents) { + const auto message_it = contents.find("message"); + const auto thread_title_it = contents.find("thread_title"); + + if (message_it == std::end(contents) || + thread_title_it == std::end(contents)) { + throw std::runtime_error("malformed request"); + } + + const auto& message = message_it->second; + const auto& thread_title = thread_title_it->second; + + const auto response_contents = [&]() -> shared::contents_t { + auto& threads = get_known_threads(); + + const auto thread_it = threads.find(thread_title); + if (thread_it == std::end(threads)) { + return {{"message", "\ta thread with title \"" + thread_title + + "\" doesn't exist!\n"}, + {"success", "false"}}; + } + + auto& thread = thread_it->second; + + struct message new_message {}; + new_message.poster = *client.username; + new_message.contents = message; + new_message.is_file = false; + thread.messages.emplace_back(std::move(new_message)); + return {{"message", "\tmessage on thread " + thread_title + + " created successfully\n"}, + {"success", "true"}}; + }(); + + client.connection.send_packet( + shared::contents_to_packet(response_contents, "MSG")); +} + +static void handle_delete_message(client& client, + const shared::contents_t& contents) { + const auto message_number_it = contents.find("message_number"); + const auto thread_title_it = contents.find("thread_title"); + + if (message_number_it == std::end(contents) || + thread_title_it == std::end(contents)) { + throw std::runtime_error("malformed request"); + } + + const auto message_number = std::stoi(message_number_it->second); + const auto& thread_title = thread_title_it->second; + + const auto response_contents = [&]() -> shared::contents_t { + auto& threads = get_known_threads(); + + const auto thread_it = threads.find(thread_title); + if (thread_it == std::end(threads)) { + return {{"message", "\ta thread with title \"" + thread_title + + "\" doesn't exist!\n"}, + {"success", "false"}}; + } + auto& thread = thread_it->second; + + auto& messages = thread.messages; + const auto message_it = get_message_it(messages, message_number); + if (message_it == std::end(messages)) { + return { + {"message", "\ta message with that number doesn't exist!\n"}, + {"success", "false"}}; + } + + if (message_it->poster != *client.username) { + return {{"message", "\tyou can't delete someone elses post!\n"}, + {"success", "false"}}; + } + + messages.erase(message_it); + return {{"message", "\tmessage removed successfully\n"}, + {"success", "true"}}; + }(); + + client.connection.send_packet( + shared::contents_to_packet(response_contents, "RDT")); +} + +static void handle_edit_message(client& client, + const shared::contents_t& contents) { + const auto message_number_it = contents.find("message_number"); + const auto thread_title_it = contents.find("thread_title"); + const auto message_it = contents.find("message"); + + if (message_number_it == std::end(contents) || + thread_title_it == std::end(contents) || + message_it == std::end(contents)) { + throw std::runtime_error("malformed request"); + } + + const auto& thread_title = thread_title_it->second; + const auto& message = message_it->second; + const auto& message_number = std::stoi(message_number_it->second); + + const auto response_contents = [&]() -> shared::contents_t { + auto& threads = get_known_threads(); + + const auto thread_it = threads.find(thread_title); + if (thread_it == std::end(threads)) { + return {{"message", "\ta thread with title \"" + thread_title + + "\" doesn't exist!\n"}, + {"success", "false"}}; + } + auto& thread = thread_it->second; + + auto& messages = thread.messages; + const auto message_it = get_message_it(messages, message_number); + if (message_it == std::end(messages)) { + return { + {"message", "\ta message with that number doesn't exist!\n"}, + {"success", "false"}}; + } + + if (message_it->poster != *client.username) { + return {{"message", "\tyou can't edit someone elses post!\n"}, + {"success", "false"}}; + } + + message_it->contents = message; + return {{"message", "\tmessage edited successfully\n"}, + {"success", "true"}}; + }(); + + client.connection.send_packet( + shared::contents_to_packet(response_contents, "EDT")); +} + +static void handle_list_threads(client& client, const shared::contents_t&) { + + const auto response_contents = [&]() -> shared::contents_t { + const auto& threads = get_known_threads(); + if (threads.size() == 0) { + return {{"message", "\tthere are no threads to list!\n"}, + {"success", "true"}}; + } + + std::string resp; + for (const auto& keyvalue : threads) { + const auto& threadname = keyvalue.first; + resp += '\t' + threadname + '\n'; + } + return {{"message", resp}, {"success", "true"}}; + }(); + + client.connection.send_packet( + shared::contents_to_packet(response_contents, "RDT")); +} + +static void handle_read_thread(client& client, + const shared::contents_t& contents) { + const auto thread_title_it = contents.find("thread_title"); + if (thread_title_it == std::end(contents)) { + throw std::runtime_error("malformed request"); + } + const std::string& thread_title = thread_title_it->second; + + const auto response_contents = [&]() -> shared::contents_t { + auto& threads = get_known_threads(); + const auto thread_it = threads.find(thread_title); + if (thread_it == std::end(threads)) { + return {{"message", "\ta thread with title \"" + thread_title + + "\" doesn't exist!\n"}, + {"success", "false"}}; + } + + const auto& messages = thread_it->second.messages; + if (messages.size() == 0) { + return {{"message", "\tthere are no messages to list!\n"}, + {"success", "false"}}; + } + + int message_num = 0; + std::string resp{}; + for (const auto& message : messages) { + resp += '\t'; + if (message.is_file) { + resp += message.poster + " uploaded " + message.contents + '\n'; + continue; + } + + resp += std::to_string(message_num) + ' ' + message.poster + ": " + + message.contents + '\n'; + ++message_num; + } + + return {{"message", resp}, {"success", "true"}}; + }(); + + client.connection.send_packet( + shared::contents_to_packet(response_contents, "RDT")); +} + +static void handle_upload_file(client& client, + const shared::contents_t& contents) { + const auto thread_title_it = contents.find("thread_title"); + const auto filename_it = contents.find("filename"); + if (thread_title_it == std::end(contents) || + filename_it == std::end(contents)) { + throw std::runtime_error("malformed request"); + } + + const auto& thread_title = thread_title_it->second; + const auto& filename = filename_it->second; + const auto response_contents = [&]() -> shared::contents_t { + const auto& threads = get_known_threads(); + + const auto thread_it = threads.find(thread_title); + if (thread_it == std::end(threads)) { + return {{"message", "\ta thread with title \"" + thread_title + + "\" doesn't exist!\n"}, + {"success", "false"}}; + } + const auto& thread = thread_it->second; + + const auto& messages = thread.messages; + const auto old_file_it = std::find_if( + std::begin(messages), std::end(messages), [&](const auto& message) { + if (!message.is_file) { + return false; + } + return message.contents == filename; + }); + if (old_file_it != std::end(messages)) { + return {{"message", "\ta file with \"" + filename + + "\" already exists in the thread!\n"}, + {"success", "false"}}; + } + + return {{"message", "\tbeginning file transfer\n"}, + {"success", "true"}}; + }(); + + client.connection.send_packet( + shared::contents_to_packet(response_contents, "UPD")); + + if (response_contents.find("success")->second != "true") { + return; + } + + const auto sock = shared::accept_socket(get_rsock()); + const auto packet = shared::rrecv_packet(sock); + const shared::contents_t file_contents = packet_to_contents(*packet); + close(sock); + + const auto file_contents_it = file_contents.find("file_contents"); + if (file_contents_it == std::end(file_contents)) { + throw std::runtime_error("malformed request"); + } + + { + const auto& data = file_contents_it->second; + std::ofstream out{thread_title + "-" + filename, std::ios_base::trunc}; + if (!out.is_open()) { + throw std::runtime_error{"failed to open file for writing"}; + } + out << data; + } + + // append to thread + { + message message{}; + message.poster = *client.username; + message.is_file = true; + message.contents = filename; + get_known_threads() + .find(thread_title) + ->second.messages.emplace_back(std::move(message)); + } + + { + const shared::contents_t final_response{ + {"message", "\tfile \"" + filename + "\" uploaded successfully!\n"}, + {"success", "true"}}; + client.connection.send_packet( + shared::contents_to_packet(final_response, "UPD")); + } +} + +static void handle_download_file(client& client, + const shared::contents_t& contents) { + const auto thread_title_it = contents.find("thread_title"); + const auto filename_it = contents.find("filename"); + if (thread_title_it == std::end(contents) || + filename_it == std::end(contents)) { + throw std::runtime_error("malformed request"); + } + + const auto& thread_title = thread_title_it->second; + const auto& filename = filename_it->second; + const auto response_contents = [&]() -> shared::contents_t { + const auto& threads = get_known_threads(); + + const auto thread_it = threads.find(thread_title); + if (thread_it == std::end(threads)) { + return {{"message", "\ta thread with title \"" + thread_title + + "\" doesn't exist!\n"}, + {"success", "false"}}; + } + const auto& thread = thread_it->second; + + const auto& messages = thread.messages; + const auto old_file_it = std::find_if( + std::begin(messages), std::end(messages), [&](const auto& message) { + if (!message.is_file) { + return false; + } + return message.contents == filename; + }); + if (old_file_it == std::end(messages)) { + return {{"message", "\tno file named \"" + filename + + "\" exists in the thread!\n"}, + {"success", "false"}}; + } + + if (!std::ifstream{"./" + thread_title + "-" + filename}.is_open()) { + throw std::runtime_error("file \"" + filename + "\" not found"); + } + + return {{"message", "\tbeginning file transfer\n"}, + {"success", "true"}}; + }(); + + client.connection.send_packet( + shared::contents_to_packet(response_contents, "DWN")); + + if (response_contents.find("success")->second != "true") { + return; + } + + const auto sock = shared::accept_socket(get_rsock()); + const std::string file_contents = [&]() -> std::string { + std::ifstream in{"./" + thread_title + "-" + filename}; + if (!in.is_open()) { + throw std::runtime_error("file \"" + filename + "\" not found"); + } + std::stringstream ss; + ss << in.rdbuf(); + return ss.str(); + }(); + const shared::contents_t final_contents = { + {"file_contents", file_contents}, + {"message", "\tfile downloaded successfully\n"}, + {"success", "true"}}; + shared::send_packet(shared::contents_to_packet(final_contents, "DWN"), + sock); + close(sock); +} + +static void handle_remove_thread(client& client, + const shared::contents_t& contents) { + const auto thread_title_it = contents.find("thread_title"); + if (thread_title_it == std::end(contents)) { + throw std::runtime_error("malformed request"); + } + const std::string& thread_title = thread_title_it->second; + + const auto response_contents = [&]() -> shared::contents_t { + auto& threads = get_known_threads(); + + const auto thread_it = threads.find(thread_title); + if (thread_it == std::end(threads)) { + return {{"message", "\ta thread with title \"" + thread_title + + "\" doesn't exist!\n"}, + {"success", "false"}}; + } + + auto& thread = thread_it->second; + if (thread.original_poster != *client.username) { + return {{"message", "\tyou can't delete a thread you don't own!\n"}, + {"success", "false"}}; + } + + for (const auto& message : thread.messages) { + if (!message.is_file) { + continue; + } + const std::string name = thread_title + '-' + message.contents; + remove(name.c_str()); + } + + threads.erase(thread_it); + remove(thread_title.c_str()); // remove file + return {{"message", + "\tthread \"" + thread_title + "\" removed successfully\n"}, + {"success", "true"}}; + }(); + client.connection.send_packet( + shared::contents_to_packet(response_contents, "RDT")); +} + +static void handle_exit(client& client, const shared::contents_t&) { + auto& clients = get_clients(); + const auto info_ptr = &client.connection.get_info(); + const auto find_it = std::find_if( + std::begin(clients), std::end(clients), [&](const auto& c) { + return std::memcmp(info_ptr, &c->connection.get_info(), + sizeof(decltype(*info_ptr))) == 0; + }); + if (find_it == std::end(clients)) { + return; // already removed or not associated + } + clients.erase(find_it); +} + +static void handle_client_packet(client& client, const shared::packet& packet) { + const std::string command = [&packet]() { + std::string command; + std::copy(std::begin(packet.header.command), + std::end(packet.header.command), std::back_inserter(command)); + return command; + }(); + + static const std::unordered_map<std::string, + decltype(&handle_create_thread)> + commands{{"CRT", handle_create_thread}, + {"MSG", handle_post_message}, + {"DLT", handle_delete_message}, + {"EDT", handle_edit_message}, + {"LST", handle_list_threads}, + {"RDT", handle_read_thread}, + {"UPD", handle_upload_file}, + {"DWN", handle_download_file}, + {"RMV", handle_remove_thread}, + {"ATH", handle_auth_client}, + {"XIT", handle_exit}}; + + const auto find_it = commands.find(command); + if (find_it == std::end(commands)) { + std::cout << "got unknown command from client: " << command << '\n'; + return; + } + + std::cout << "got " << command << " from client: " + << (client.username != nullptr ? *client.username + : "*not logged in*") + << '\n'; + + const auto& func = find_it->second; + const shared::contents_t contents = shared::packet_to_contents(packet); + + try { + func(client, contents); + } catch (const std::runtime_error& e) { + const std::string message = + "internal server error: " + std::string{e.what()}; + const shared::contents_t error_contents{{"message", message}, + {"success", "error"}}; + client.connection.send_packet( + shared::contents_to_packet(error_contents, command.c_str())); + std::cout << message << '\n'; + } +} + +static void save() { + { // write creds + std::ofstream out{cred_path, std::ios_base::trunc}; + if (!out.is_open()) { + throw std::runtime_error{"failed to write to credentials"}; + } + for (const auto& keypair : get_known_users()) { + out << keypair.first + ' ' + keypair.second << '\n'; + } + } + { // write threads + for (const auto& keypair : get_known_threads()) { + const auto& thread_name = keypair.first; + const auto& thread = keypair.second; + std::ofstream out{"./" + thread_name, std::ios_base::trunc}; + + if (!out.is_open()) { + throw std::runtime_error{"failed to write thread " + + thread_name + '\n'}; + } + + out << thread.original_poster << '\n'; + for (const auto& message : thread.messages) { + out << (message.is_file ? 'F' : 'M'); + out << message.poster << ' ' << message.contents << '\n'; + } + } + } + get_clients().clear(); +} + +void do_server(const char* const address, const char* const port) { + const shared::socket_t usock = make_socket(address, port, SOCK_DGRAM); + + { // init static vars before server starts to catch errors early + get_known_users(); + get_known_threads(); + get_rsock(address, port); + } + + try { + auto& clients = get_clients(); + while (!shared::should_exit) { + auto packet = shared::maybe_urecv_packet(usock); + if (packet == nullptr) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + continue; + } + + auto& origin = packet->origin; + // Compare the sockaddr_storage of the packet to existing clients. + const auto find_it = std::find_if( + std::begin(clients), std::end(clients), + [&](const auto& client) { + return std::memcmp(&client->connection.get_info(), &origin, + sizeof(decltype(origin))) == 0; + }); + + const auto& contents = packet->packet; + + client& client = [&]() -> struct client& { + if (find_it == std::end(clients)) { + clients.push_back(std::make_unique<struct client>( + usock, std::move(origin))); + return *clients.back(); + } + return *find_it->get(); + } + (); + + if (client.connection.should_discard_packet(contents)) { + continue; + } + + handle_client_packet(client, contents); + } + } catch (const shared::should_exit_exception&) { + // gracefully exit on should_exit flag + } + + save(); // save updated contents before exiting +} + +} // namespace server diff --git a/comp3331/server/src/server/server.hh b/comp3331/server/src/server/server.hh new file mode 100644 index 0000000..e2554de --- /dev/null +++ b/comp3331/server/src/server/server.hh @@ -0,0 +1,28 @@ +#ifndef SERVER_SERVER_HH_ +#define SERVER_SERVER_HH_ + +#include <algorithm> +#include <chrono> +#include <cstring> +#include <fstream> +#include <iostream> +#include <iterator> +#include <sstream> +#include <thread> +#include <unordered_map> +#include <vector> +#include <dirent.h> + +#include "server/client.hh" +#include "shared/connection.hh" +#include "shared/net.hh" +#include "shared/shared.hh" + +namespace server { + +const char* const cred_path = "./credentials.txt"; + +void do_server(const char* const address, const char* const port); +} // namespace server + +#endif diff --git a/comp3331/server/src/shared/connection.cc b/comp3331/server/src/shared/connection.cc new file mode 100644 index 0000000..1063823 --- /dev/null +++ b/comp3331/server/src/shared/connection.cc @@ -0,0 +1,54 @@ +#include "shared/connection.hh" + +namespace shared { + +void connection::send_packet(packet&& packet) noexcept { + packet.header.sequence = this->seq_num; + shared::send_packet(packet, this->sock, this->info); + ++this->seq_num; + + std::lock_guard<std::mutex> guard{*this->lock}; + this->sent.push_back(packet); +} + +bool connection::should_discard_packet(const packet& packet) noexcept { + std::lock_guard<std::mutex> guard{*this->lock}; + + // ack case + if (packet.header.command[0] == '\0') { + this->sent.erase(std::remove_if(std::begin(this->sent), + std::end(this->sent), + [&](const auto& p) { + return p.header.sequence <= + packet.header.sequence; + }), + std::end(this->sent)); + return true; + } + + // Send an ack for the packet if it's not an ack itself + auto ack_pkt = shared::contents_to_packet({}, "\0\0\0"); + ack_pkt.header.sequence = packet.header.sequence; + shared::send_packet(ack_pkt, this->sock, this->info); + + if (packet.header.sequence != this->ack_num) { + return true; + } + ++this->ack_num; + return false; +} + +// Reliable transport reads our packets on a different thread and resends if +// necessary. +void connection::do_reliable_transport() noexcept { + while (!*this->should_thread_exit) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + + std::lock_guard<std::mutex> guard{*this->lock}; + for (const auto& packet : this->sent) { + shared::send_packet(packet, this->sock, this->info); + } + } +} + +} // namespace shared diff --git a/comp3331/server/src/shared/connection.hh b/comp3331/server/src/shared/connection.hh new file mode 100644 index 0000000..d7237cb --- /dev/null +++ b/comp3331/server/src/shared/connection.hh @@ -0,0 +1,60 @@ +#ifndef SHARED_CONNECTION_HH_ +#define SHARED_CONNECTION_HH_ + +#include <atomic> +#include <memory> +#include <mutex> +#include <vector> + +#include "shared/net.hh" + +namespace shared { + +// The connection class abstracts sending and receiving data, including reliable +// transmission over UDP. +class connection { +private: + shared::socket_t sock; + sockaddr_in info; + +private: + std::uint32_t seq_num = 0; // track packet sequence number + std::uint32_t ack_num = 0; + + // for reliable transport, spawn a new thread which reads sent/received + std::unique_ptr<std::atomic<bool>> should_thread_exit; + std::unique_ptr<std::mutex> lock; + std::vector<packet> sent; + std::vector<packet> received; + std::shared_ptr<std::thread> reliable_transport_thread; + void do_reliable_transport() noexcept; + +public: + connection(const socket_t& sock, sockaddr_in&& info) + : sock(sock), info(std::move(info)), + should_thread_exit(std::make_unique<std::atomic<bool>>(false)), + lock(std::make_unique<std::mutex>()), + reliable_transport_thread(std::make_shared<std::thread>( + &connection::do_reliable_transport, this)) {} + + connection(const connection&) = delete; + connection(connection&&) = default; + ~connection() noexcept { + *this->should_thread_exit = true; + this->reliable_transport_thread->join(); + } + +public: + const sockaddr_in& get_info() const noexcept { return this->info; } + const socket_t& get_socket() const noexcept { return this->sock; } + +public: + // All unreliable packets should go through these functions so we may track + // if our packets have been sent or received, making them reliable. + void send_packet(packet&& packet) noexcept; + bool should_discard_packet(const packet& packet) noexcept; +}; + +} // namespace shared + +#endif diff --git a/comp3331/server/src/shared/net.cc b/comp3331/server/src/shared/net.cc new file mode 100644 index 0000000..b71170a --- /dev/null +++ b/comp3331/server/src/shared/net.cc @@ -0,0 +1,284 @@ +#include "shared/net.hh" + +namespace shared { + +static std::string get_errno_str() noexcept { return strerror(errno); } + +addrinfo_t make_addrinfo(const char* const address, const char* const port, + const addrinfo&& hints) { + addrinfo* info = nullptr; + if (int status = getaddrinfo(address, port, &hints, &info)) { + throw std::runtime_error{gai_strerror(status)}; + } + return addrinfo_t{info, [](const auto& p) { freeaddrinfo(p); }}; +} + +socket_t make_socket(const addrinfo_t& info) { + const int sock = + socket(info->ai_family, info->ai_socktype, info->ai_protocol); + if (sock == -1) { + throw std::runtime_error{get_errno_str()}; + } + + // to avoid annoying binding timeouts + const int enable = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int)); + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); + return sock; +} + +socket_t accept_socket(const socket_t& lsock) { + sockaddr_in in{}; + socklen_t size = sizeof(in); + + const shared::socket_t sock = + accept(lsock, reinterpret_cast<sockaddr*>(&in), &size); + if (sock == -1) { + throw std::runtime_error(get_errno_str()); + } + return sock; +} + +void bind_socket(const socket_t sock, const addrinfo_t& info) { + // bind to first we can + for (const addrinfo* i = info.get(); i != nullptr; i = i->ai_next) { + if (bind(sock, info->ai_addr, info->ai_addrlen) == -1) { + continue; + } + return; + } + throw std::runtime_error{get_errno_str()}; +} + +void connect_socket(const socket_t sock, const addrinfo_t& info) { + if (connect(sock, info->ai_addr, info->ai_addrlen) == -1) { + throw std::runtime_error{get_errno_str()}; + } +} + +void listen_socket(const socket_t sock) { + if (listen(sock, SOMAXCONN) == -1) { + throw std::runtime_error{get_errno_str()}; + } +} + +void close_socket(const socket_t sock) { + if (close(sock) == -1) { + throw std::runtime_error{get_errno_str()}; + } +} + +packet contents_to_packet(const contents_t& contents, + const char* const command) { + packet ret{}; + std::memset(&ret.header, 0, sizeof(ret.header)); // 0init padding bytes + + std::uint32_t size = 0; + for (const auto& keyvalue : contents) { + const auto& key = keyvalue.first; // no structured bindings :( + const auto& value = keyvalue.second; + + const auto contents_size = static_cast<std::uint32_t>(value.size()); + + decltype(ret.contents) addition; + // Copy the size, name and data as expected, then push it. + std::copy(reinterpret_cast<const char* const>(&contents_size), + reinterpret_cast<const char* const>(&contents_size) + + sizeof(contents_size), + std::back_inserter(addition)); + std::copy(key.c_str(), key.c_str() + key.size() + 1, + std::back_inserter(addition)); + std::copy(value.c_str(), value.c_str() + value.size(), + std::back_inserter(addition)); + + size += addition.size(); // update size before moving + std::copy(std::make_move_iterator(std::begin(addition)), + std::make_move_iterator(std::end(addition)), + std::back_inserter(ret.contents)); + } + + // fill in expected header values + ret.header.size = size + sizeof(packet::header); + std::memcpy(ret.header.command, command, 3); + + return ret; +} + +contents_t packet_to_contents(const packet& packet) { + contents_t ret{}; + + // Extract the data as described in the header file. + for (auto i = 0u; i < packet.contents.size();) { + const char* const data = packet.contents.data() + i; + + const std::uint32_t contents_size = [&]() { + std::uint32_t contents_size; + std::memcpy(&contents_size, + reinterpret_cast<const std::uint32_t* const>(data), + sizeof(std::uint32_t)); + return contents_size; + }(); + + const auto size_size = sizeof(std::uint32_t); + const std::string name = data + size_size; + const char* const contents = data + size_size + name.length() + 1; + + ret.emplace(name, std::string{contents, contents + contents_size}); + i += size_size + name.size() + 1 + contents_size; + } + return ret; +} + +static std::vector<char> packet_to_data(const packet& packet) noexcept { + std::vector<char> data{}; + // data.reserve(sizeof(struct packet) + packet.contents.size()); + std::copy(reinterpret_cast<const char* const>(&packet.header), + reinterpret_cast<const char* const>(&packet.header) + + sizeof(packet::header), + std::back_inserter(data)); + std::copy(std::begin(packet.contents), std::end(packet.contents), + std::back_inserter(data)); + return data; +} + +void send_packet(const packet& packet, const socket_t& sock, + const sockaddr_in& dest) { + const std::vector<char> data = packet_to_data(packet); + + if (sendto(sock, data.data(), data.size(), 0, (sockaddr*)&dest, + sizeof(sockaddr_in)) == -1) { + throw std::runtime_error{get_errno_str()}; + } +} + +void send_packet(const packet& packet, const socket_t& sock) { + const std::vector<char> data = packet_to_data(packet); + + const auto total_size = data.size(); + for (unsigned sent = 0; sent < data.size();) { + const ssize_t res = + send(sock, data.data() + sent, total_size - sent, 0); + if (res != -1) { + sent += res; + continue; + } + + if (errno == EAGAIN || errno == EWOULDBLOCK) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + + throw std::runtime_error{get_errno_str()}; + } +} + +static std::size_t get_backlog_size(const socket_t& socket) noexcept { + size_t size = 0; + if (ioctl(socket, FIONREAD, &size) == -1) { + throw std::runtime_error{get_errno_str()}; + } + return size; +} + +std::shared_ptr<recv_packet_ret> maybe_urecv_packet(const socket_t& sock) { + const auto packet_size = get_backlog_size(sock); + if (packet_size <= 0) { + return nullptr; + } + + recv_packet_ret ret; + unsigned int origin_len = sizeof(ret.origin); + + std::vector<char> buffer; + buffer.reserve(packet_size); + if (recvfrom(sock, buffer.data(), packet_size, 0, (sockaddr*)&ret.origin, + &origin_len) == -1) { + throw std::runtime_error(get_errno_str()); + } + + std::memcpy(&ret.packet.header, buffer.data(), sizeof(packet::header)); + ret.packet.contents.reserve(ret.packet.header.size - + sizeof(packet::header)); + std::copy(buffer.data() + sizeof(packet::header), + buffer.data() + ret.packet.header.size, + std::back_inserter(ret.packet.contents)); + + return std::make_shared<recv_packet_ret>(std::move(ret)); +} + +// true when finished reading as our stream packets may be very large +static bool maybe_rrecv_packet(const socket_t& sock, packet& packet) { + + auto backlog_size = get_backlog_size(sock); + if (backlog_size <= 0) { + return false; + } + + auto& target_size = packet.header.size; + if (!target_size) { // header read required + if (backlog_size < sizeof(header)) { // no header, try again later + return false; + } + + if (read(sock, &packet.header, sizeof(header)) == -1) { + throw std::runtime_error{get_errno_str()}; + } + backlog_size -= sizeof(header); + } + + const auto read_size = + std::min(backlog_size, static_cast<unsigned long>(target_size)); + std::vector<char> buffer; + buffer.reserve(read_size); + if (read(sock, buffer.data(), read_size) == -1) { + throw std::runtime_error{get_errno_str()}; + } + + std::copy(buffer.data(), buffer.data() + read_size, + std::back_inserter(packet.contents)); + if (packet.contents.size() < packet.header.size - sizeof(header)) { + return false; // more data to read, do again + } + + return true; +} + +static const auto ms_timeout = std::chrono::seconds(30); +std::shared_ptr<recv_packet_ret> urecv_packet(const socket_t& rsock, + const bool& timeout) { + // poll our non-blocking sockets + const auto start = std::chrono::steady_clock::now(); + while (!shared::should_exit) { + const auto pkt = maybe_urecv_packet(rsock); + if (pkt != nullptr) { + return pkt; + } + if (timeout && std::chrono::steady_clock::now() > start + ms_timeout) { + throw std::runtime_error("urecv timeout elapsed"); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + throw should_exit_exception(); +} + +std::shared_ptr<packet> rrecv_packet(const socket_t& rsock, + const bool& timeout) { + packet packet{}; + auto last = std::chrono::steady_clock::now(); + while (!shared::should_exit) { + const auto prev_size = packet.contents.size(); + if (maybe_rrecv_packet(rsock, packet)) { + return std::make_shared<struct packet>(std::move(packet)); + } + if (prev_size != packet.contents.size()) { + last = std::chrono::steady_clock::now(); + } + if (timeout && std::chrono::steady_clock::now() > last + ms_timeout) { + throw std::runtime_error("rrecv timeout elapsed"); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + throw should_exit_exception(); +} + +} // namespace shared diff --git a/comp3331/server/src/shared/net.hh b/comp3331/server/src/shared/net.hh new file mode 100644 index 0000000..b17ca1f --- /dev/null +++ b/comp3331/server/src/shared/net.hh @@ -0,0 +1,87 @@ +#ifndef SHARED_NET_HH_ +#define SHARED_NET_HH_ + +#include <algorithm> +#include <arpa/inet.h> +#include <chrono> +#include <cstdint> +#include <cstring> +#include <errno.h> +#include <fcntl.h> +#include <memory> +#include <netdb.h> +#include <optional> +#include <stdexcept> +#include <string.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <thread> +#include <type_traits> +#include <unistd.h> +#include <unordered_map> +#include <vector> + +#include "shared/shared.hh" + +// Functions in this namespace are common network-related wrappers with error +// checking. +namespace shared { + +using addrinfo_t = std::shared_ptr<addrinfo>; // for automatic freeaddrinfo call +addrinfo_t make_addrinfo(const char* const address, const char* const port, + const addrinfo&& hints); + +using socket_t = int; +int make_socket(const addrinfo_t& info); +int accept_socket(const socket_t& sock); +void bind_socket(const socket_t sock, const addrinfo_t& info); +void connect_socket(const socket_t sock, const addrinfo_t& info); +void listen_socket(const socket_t sock); +void close_socket(const socket_t sock); + +struct header { + std::uint32_t size = 0; // size of packet, including header. + std::uint32_t sequence; // sequence number of packet + char command[3]; // command type, if \0\0\0 then it's an ack +}; +static_assert(std::is_trivially_copyable<header>::value, + "header must be memcpy-able"); +struct packet { + header header; + std::vector<char> contents; +}; + +// Our packets contents consist of data entries like so: +// std::uint32_t | char[] | DATA +// ^ size of data ^ name ^ data, which repeats for size length +// Any packet may contain multiple, or zero, entries. The name can be any +// length. This way we don't have to define structs with hardcoded length +// limits, and we can use this format when sending anything, even files. +// Numeric values will be encoded as strings and converted when necessary. +using contents_t = std::unordered_map<std::string, std::string>; +void send_packet(const packet& packet, const socket_t& sock, + const sockaddr_in& dest); +void send_packet(const packet& packet, const socket_t& sock); + +packet contents_to_packet(const contents_t& contents, + const char* const command); +contents_t packet_to_contents(const packet& packet); + +// Recv's sockets, might return nullptr if no packet available. +struct recv_packet_ret { + sockaddr_in origin; + struct packet packet; +}; + +// non-blocking +std::shared_ptr<recv_packet_ret> maybe_urecv_packet(const socket_t& sock); +// blocking, will throw if timeout is elapsed +std::shared_ptr<recv_packet_ret> urecv_packet(const socket_t& rsock, + const bool& timeout = true); +std::shared_ptr<packet> rrecv_packet(const socket_t& rsock, + const bool& timeout = true); + +} // namespace shared + +#endif diff --git a/comp3331/server/src/shared/shared.cc b/comp3331/server/src/shared/shared.cc new file mode 100644 index 0000000..3a13b47 --- /dev/null +++ b/comp3331/server/src/shared/shared.cc @@ -0,0 +1,25 @@ +#include "shared/shared.hh" + +namespace shared { + +bool should_exit = false; + +static void set_signal(const decltype(SIGINT) signal, + void (*const callback)(const int)) { + struct sigaction sa {}; + sa.sa_handler = callback; + if (sigaction(signal, &sa, nullptr) == -1) { + throw std::runtime_error("failed to set signal handler!"); + } +} + +void set_exit_handler() { + set_signal(SIGPIPE, SIG_IGN); + set_signal(SIGINT, [](const int) { + std::cout << " interrupt signal received\n"; + should_exit = true; + }); +} + +} // namespace shared + diff --git a/comp3331/server/src/shared/shared.hh b/comp3331/server/src/shared/shared.hh new file mode 100644 index 0000000..8e4fa36 --- /dev/null +++ b/comp3331/server/src/shared/shared.hh @@ -0,0 +1,28 @@ +#ifndef SHARED_SHARED_HH_ +#define SHARED_SHARED_HH_ + +#include <functional> +#include <iostream> +#include <signal.h> +#include <stdexcept> + +namespace shared { +extern bool should_exit; + +void set_exit_handler(); + +class should_exit_exception : public std::exception {}; + +// This won't exist until c++24 lol +class scoped_function { +private: + using func_t = std::function<void()>; + func_t func; + +public: + scoped_function(const func_t& f) : func(f) {} + ~scoped_function() { this->func(); } +}; +} // namespace shared + +#endif |
