aboutsummaryrefslogtreecommitdiff
path: root/src/shared/net/connection.hh
diff options
context:
space:
mode:
authorNicolas James <Eele1Ephe7uZahRie@tutanota.com>2025-02-12 18:05:18 +1100
committerNicolas James <Eele1Ephe7uZahRie@tutanota.com>2025-02-12 18:05:18 +1100
commit1cc08c51eb4b0f95c30c0a98ad1fc5ad3459b2df (patch)
tree222dfcd07a1e40716127a347bbfd7119ce3d0984 /src/shared/net/connection.hh
initial commit
Diffstat (limited to 'src/shared/net/connection.hh')
-rw-r--r--src/shared/net/connection.hh212
1 files changed, 212 insertions, 0 deletions
diff --git a/src/shared/net/connection.hh b/src/shared/net/connection.hh
new file mode 100644
index 0000000..79146af
--- /dev/null
+++ b/src/shared/net/connection.hh
@@ -0,0 +1,212 @@
+#ifndef SHARED_NET_CONNECTION_HH_
+#define SHARED_NET_CONNECTION_HH_
+
+#include <algorithm>
+#include <optional>
+#include <string>
+
+#include "shared/net/net.hh"
+#include "shared/net/proto.hh"
+
+namespace shared {
+namespace net {
+
+class connection {
+private:
+ struct sockets {
+ int rsock;
+ int usock; // connected in constructor, no need to use sendto or revfrom
+ std::string address;
+ std::string port;
+
+ sockets(const int r, const int u, const std::string& a,
+ const std::string& p)
+ : rsock(r), usock(u), address(a), port(p) {}
+ };
+ std::optional<sockets> socks;
+ std::optional<std::string> bad_reason;
+
+private:
+ struct packet_header {
+ std::uint32_t size;
+ };
+
+ // Data has a non-serialised header, and variable length serialised
+ // protobuf content.
+ static std::string packet_to_data(const proto::packet& packet) {
+
+ std::string data;
+ packet.SerializeToString(&data);
+ shared::compress_string(data);
+
+ packet_header header{.size = htonl(static_cast<std::uint32_t>(
+ std::size(data) + sizeof(packet_header)))};
+
+ return std::string{reinterpret_cast<char*>(&header),
+ reinterpret_cast<char*>(&header) + sizeof(header)} +
+ std::move(data);
+ }
+
+ void send_sock_packet(const int sock, const proto::packet& packet) {
+ if (!this->good()) {
+ return;
+ }
+
+ const std::string data = packet_to_data(packet);
+
+ if (send(sock, std::data(data), std::size(data), 0) == -1) {
+ this->bad_reason = shared::net::get_errno_error();
+ }
+ }
+
+public:
+ connection(const int rsock) {
+ using namespace shared::net;
+
+ const std::string peer_address = get_socket_peer_address(rsock);
+ const std::string peer_port = get_socket_peer_port(rsock);
+
+ // Open up a connected usock based on the state of the connected rsock.
+ const int usock = [&peer_address, &peer_port, &rsock]() -> int {
+ constexpr addrinfo hints = {.ai_flags = AI_PASSIVE,
+ .ai_family = AF_INET,
+ .ai_socktype = SOCK_DGRAM};
+ const std::string host_address = get_socket_host_address(rsock);
+ const std::string host_port = get_socket_host_port(rsock);
+ const auto host_info =
+ get_addr_info(host_address, host_port, &hints);
+ const int usock = make_socket(host_info.get());
+ bind_socket(usock, host_info.get());
+
+ const auto peer_info =
+ get_addr_info(peer_address, peer_port, &hints);
+ connect_socket(usock, peer_info.get());
+
+ return usock;
+ }();
+
+ nonblock_socket(usock);
+ nonblock_socket(rsock);
+
+ this->socks.emplace(rsock, usock, peer_address, peer_port);
+ }
+
+ // We do not want to copy this object!
+ // We use std::nullopt to determine if this object has been moved and if we
+ // should close its sockets.
+ connection(const connection&) = delete;
+ connection(connection&& other) noexcept {
+ std::swap(this->socks, other.socks);
+ std::swap(this->bad_reason, other.bad_reason);
+ other.socks.reset();
+ }
+ connection& operator=(connection&& other) noexcept {
+ std::swap(this->socks, other.socks);
+ std::swap(this->bad_reason, other.bad_reason);
+ other.socks.reset();
+ return *this;
+ }
+ ~connection() noexcept {
+ if (this->socks.has_value()) {
+ shared::net::close_socket(socks->rsock);
+ shared::net::close_socket(socks->usock);
+ }
+ }
+
+ // Getters.
+ bool good() const noexcept { return !bad_reason.has_value(); }
+ std::string get_bad_reason() const noexcept {
+ return this->bad_reason.value();
+ }
+ std::string get_address() const noexcept { return this->socks->address; }
+
+public:
+ // Send does nothing if good() returns false!
+ // Returns whether or not we were able to send our packet.
+ void rsend_packet(const proto::packet& packet) noexcept {
+ return send_sock_packet(this->socks->rsock, packet);
+ }
+ void usend_packet(const proto::packet& packet) noexcept {
+ return send_sock_packet(this->socks->usock, packet);
+ }
+ std::optional<proto::packet> rrecv_packet() noexcept {
+ const int sock = this->socks->rsock;
+
+ // Get the size of the backlog, early out of there's nothing there yet.
+ const auto backlog_size = shared::net::get_backlog_size(sock);
+ if (backlog_size < sizeof(packet_header)) {
+ return std::nullopt;
+ }
+
+ // Read for the packet headers and get the claimed size. Early out if
+ // our stream isn't big enough for that yet.
+ packet_header header = {};
+ recv(sock, &header, sizeof(header), MSG_PEEK);
+ const std::uint32_t read_packet_size = ntohl(header.size);
+ if (backlog_size < read_packet_size) {
+ return std::nullopt;
+ }
+
+ // Read the actual packet now, based on our claimed size.
+ std::string data;
+ data.reserve(read_packet_size);
+ if (read(sock, std::data(data), read_packet_size) == -1) {
+ this->bad_reason = shared::net::get_errno_error();
+ return std::nullopt;
+ }
+
+ data = std::string{
+ reinterpret_cast<char*>(std::data(data)) + sizeof(packet_header),
+ reinterpret_cast<char*>(std::data(data)) + read_packet_size};
+ shared::decompress_string(data);
+
+ // Parse the packet, ignoring the header.
+ proto::packet packet;
+ if (!packet.ParseFromString(data)) {
+ return std::nullopt;
+ }
+
+ return packet;
+ }
+ std::optional<proto::packet> urecv_packet() noexcept {
+ const int sock = this->socks->usock;
+
+ restart:
+ const auto packet_size = shared::net::get_backlog_size(sock);
+
+ if (packet_size == 0) {
+ return std::nullopt;
+ }
+
+ std::string data;
+ data.reserve(packet_size);
+ if (recv(sock, std::data(data), packet_size, 0) == -1) {
+ this->bad_reason = shared::net::get_errno_error();
+ return std::nullopt;
+ }
+
+ data = std::string{
+ reinterpret_cast<char*>(std::data(data)) + sizeof(packet_header),
+ reinterpret_cast<char*>(std::data(data)) + packet_size};
+ shared::decompress_string(data);
+
+ proto::packet packet;
+ if (!packet.ParseFromString(data)) {
+ goto restart;
+ }
+
+ return packet;
+ }
+ // Gets packets from r/u streams, doesn't care which.
+ std::optional<proto::packet> recv_packet() noexcept {
+ if (const auto ret = urecv_packet(); ret.has_value()) {
+ return ret;
+ }
+ return rrecv_packet();
+ }
+};
+
+} // namespace net
+} // namespace shared
+
+#endif