aboutsummaryrefslogtreecommitdiff
path: root/src/shared/net/connection.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/shared/net/connection.cc')
-rw-r--r--src/shared/net/connection.cc219
1 files changed, 219 insertions, 0 deletions
diff --git a/src/shared/net/connection.cc b/src/shared/net/connection.cc
index 79a83f5..b368735 100644
--- a/src/shared/net/connection.cc
+++ b/src/shared/net/connection.cc
@@ -1 +1,220 @@
#include "shared/net/connection.hh"
+
+namespace shared {
+namespace net {
+
+connection::connection(const socket_t& rsock) {
+ const std::string peer_address = get_socket_peer_address(rsock);
+ const std::string peer_port = get_socket_peer_port(rsock);
+
+ // Open up a connected usock based on the state of the connected rsock.
+ const socket_t usock = [&]() -> socket_t {
+ constexpr addrinfo hints = {.ai_flags = AI_PASSIVE,
+ .ai_family = AF_INET,
+ .ai_socktype = SOCK_DGRAM};
+ const std::string host_address = get_socket_host_address(rsock);
+ const std::string host_port = get_socket_host_port(rsock);
+ const auto host_info = get_addr_info(host_address, host_port, &hints);
+ const socket_t usock = make_socket(host_info.get());
+ bind_socket(usock, host_info.get());
+
+ const auto peer_info = get_addr_info(peer_address, peer_port, &hints);
+ connect_socket(usock, peer_info.get());
+
+ return usock;
+ }();
+
+ nonblock_socket(usock);
+ nonblock_socket(rsock);
+
+ this->socks.emplace(sockets{.rsock = rsock,
+ .usock = usock,
+ .address = peer_address,
+ .port = peer_port});
+}
+
+connection::connection(connection&& other) noexcept {
+ std::swap(this->socks, other.socks);
+ std::swap(this->bad_reason, other.bad_reason);
+ other.socks.reset();
+}
+
+connection& connection::operator=(connection&& other) noexcept {
+ std::swap(this->socks, other.socks);
+ std::swap(this->bad_reason, other.bad_reason);
+ other.socks.reset();
+ return *this;
+}
+
+connection::~connection() noexcept { this->close(); }
+
+void connection::rsend_packet(const rpacket_t& packet) noexcept {
+ this->rpackets.push_back(packet);
+}
+
+void connection::usend_packet(const upacket_t& packet) noexcept {
+ this->upackets.push_back(packet);
+}
+
+void connection::rsend_packet(rpacket&& packet) noexcept {
+ this->rsend_packet(std::make_shared<rpacket>(std::move(packet)));
+}
+
+void connection::usend_packet(upacket&& packet) noexcept {
+ this->usend_packet(std::make_shared<upacket>(std::move(packet)));
+}
+
+std::optional<proto::packet> connection::rrecv_packet() noexcept {
+ if (!this->good()) {
+ return std::nullopt;
+ }
+
+ const socket_t sock = this->socks->rsock;
+
+ // Get the size of the backlog, early out of there's nothing there yet.
+ const auto backlog_size = shared::net::get_backlog_size(sock);
+ if (backlog_size < sizeof(packet_header_t)) {
+ return std::nullopt;
+ }
+
+ const packet_header_t read_packet_size = [&]() {
+ packet_header_t ret;
+ recv(sock, &ret, sizeof(packet_header_t), MSG_PEEK);
+ return ntohl(ret);
+ }();
+ if (backlog_size < read_packet_size) { // data not there yet
+ return std::nullopt;
+ }
+ if (read_packet_size < sizeof(packet_header_t)) {
+ this->bad_reason.emplace("received packet too small");
+ return std::nullopt;
+ }
+ if (read_packet_size >= MAX_PACKET_SIZE) {
+ this->bad_reason.emplace("received packet size exceeded limit");
+ return std::nullopt;
+ }
+
+ // Read the actual packet now, based on our claimed size.
+ std::string data(read_packet_size, '\0');
+ if (const auto result = read(sock, std::data(data), read_packet_size);
+ result == -1) {
+ this->bad_reason = shared::net::get_errno_error();
+ return std::nullopt;
+ } else if (result != read_packet_size) {
+ return std::nullopt; // this shouldn't happen?
+ }
+ data = {std::begin(data) + sizeof(packet_header_t), std::end(data)};
+
+ // possible compression bomb :)
+ if (const auto decompress = shared::maybe_decompress_string(data);
+ decompress.has_value()) {
+ data = *decompress;
+ } else {
+ return std::nullopt;
+ }
+
+ // Parse the packet, ignoring the header.
+ proto::packet packet;
+ if (!packet.ParseFromString(data)) {
+ return std::nullopt;
+ }
+
+ return packet;
+}
+
+std::optional<proto::packet> connection::urecv_packet() noexcept {
+ const socket_t sock = this->socks->usock;
+
+restart:
+ if (!this->good()) {
+ return std::nullopt;
+ }
+
+ const auto packet_size = shared::net::get_backlog_size(sock);
+ if (packet_size == 0) {
+ return std::nullopt;
+ }
+
+ std::string data(packet_size, '\0');
+ if (const auto result = read(sock, std::data(data), packet_size);
+ result == -1) {
+
+ if (errno == ECONNREFUSED) { // icmp, ignore it
+ goto restart;
+ }
+
+ this->bad_reason = shared::net::get_errno_error();
+ return std::nullopt;
+ } else if (static_cast<decltype(packet_size)>(result) != packet_size) {
+ return std::nullopt; // shouldn't happen
+ }
+
+ if (const auto decompress = shared::maybe_decompress_string(data);
+ decompress.has_value()) {
+ data = *decompress;
+ } else {
+ goto restart;
+ }
+
+ proto::packet packet;
+ if (!packet.ParseFromString(data)) {
+ goto restart;
+ }
+
+ return packet;
+}
+
+std::optional<proto::packet> connection::recv_packet() noexcept {
+ if (const auto ret = urecv_packet(); ret.has_value()) {
+ return ret;
+ }
+ return rrecv_packet();
+}
+
+bool connection::maybe_send(const packet& packet,
+ const socket_t& sock) noexcept {
+ if (!this->good()) {
+ return false;
+ }
+
+ const auto& data = packet.data;
+ if (const auto result = write(sock, std::data(data), std::size(data));
+ result == -1) {
+
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {
+ return false;
+ }
+
+ this->bad_reason.emplace(get_errno_error());
+ return false;
+ } else if (static_cast<std::size_t>(result) != std::size(data)) {
+ return false;
+ }
+
+ return true;
+}
+
+void connection::poll() {
+ const auto erase_send = [this](auto& packets, const socket_t& sock) {
+ const auto it =
+ std::find_if(std::begin(packets), std::end(packets),
+ [&, this](const auto& packet) {
+ return !this->maybe_send(*packet, sock);
+ });
+ packets.erase(std::begin(packets), it);
+ };
+
+ erase_send(this->upackets, this->socks->usock);
+ erase_send(this->rpackets, this->socks->rsock);
+}
+
+void connection::close() {
+ if (this->socks.has_value()) {
+ shared::net::close_socket(socks->rsock);
+ shared::net::close_socket(socks->usock);
+ }
+ this->socks.reset();
+}
+
+} // namespace net
+} // namespace shared