#include "shared/net/connection.hh"

namespace shared {
namespace net {

// Builds a connection around an already-connected reliable socket `rsock`.
//
// A second, datagram (SOCK_DGRAM) socket is created, bound to the same host
// address/port that `rsock` reports and connected to the same peer
// address/port. Both sockets are then switched to non-blocking mode and
// stored in `socks` together with the peer address and port.
//
// If any step after the datagram socket has been created throws, that socket
// is closed before the exception propagates so the descriptor is not leaked.
// `rsock` is never closed here; on failure it remains the caller's to manage.
connection::connection(const socket_t& rsock) {
    const std::string peer_address = get_socket_peer_address(rsock);
    const std::string peer_port = get_socket_peer_port(rsock);

    // Open up a connected usock based on the state of the connected rsock.
    const socket_t usock = [&]() -> socket_t {
        constexpr addrinfo hints = {.ai_flags = AI_PASSIVE,
                                    .ai_family = AF_INET,
                                    .ai_socktype = SOCK_DGRAM};
        const std::string host_address = get_socket_host_address(rsock);
        const std::string host_port = get_socket_host_port(rsock);
        const auto host_info = get_addr_info(host_address, host_port, &hints);

        const socket_t sock = make_socket(host_info.get());
        try {
            bind_socket(sock, host_info.get());
            const auto peer_info =
                get_addr_info(peer_address, peer_port, &hints);
            connect_socket(sock, peer_info.get());
        } catch (...) {
            // Nobody else knows about this descriptor yet; release it.
            shared::net::close_socket(sock);
            throw;
        }
        return sock;
    }();

    try {
        nonblock_socket(usock);
        nonblock_socket(rsock);
        this->socks.emplace(sockets{.rsock = rsock,
                                    .usock = usock,
                                    .address = peer_address,
                                    .port = peer_port});
    } catch (...) {
        shared::net::close_socket(usock);
        throw;
    }
}

// Move-constructs from `other`, taking over its sockets, its error state and
// its pending send queues. `other` is left without sockets, so its destructor
// will not close the descriptors that now belong to this object.
connection::connection(connection&& other) noexcept {
    std::swap(this->socks, other.socks);
    std::swap(this->bad_reason, other.bad_reason);
    // Queued packets were addressed to other's peer, so they travel with the
    // sockets. NOTE(review): assumes rpackets/upackets are standard
    // containers (move-assignable, clear()) - confirm against the header.
    this->rpackets = std::move(other.rpackets);
    this->upackets = std::move(other.upackets);
    other.rpackets.clear();
    other.upackets.clear();
    other.socks.reset();
}

// Move-assigns from `other`.
//
// Any sockets currently held by this object are closed first; the previous
// swap-then-reset approach discarded them without closing, leaking the
// descriptors. Self-assignment is a no-op. As in the move constructor, the
// pending send queues follow the sockets.
connection& connection::operator=(connection&& other) noexcept {
    if (this == &other) {
        return *this;
    }

    this->close(); // leaves this->socks empty

    std::swap(this->socks, other.socks);
    std::swap(this->bad_reason, other.bad_reason);
    this->rpackets = std::move(other.rpackets);
    this->upackets = std::move(other.upackets);
    other.rpackets.clear();
    other.upackets.clear();
    other.socks.reset();
    return *this;
}

// Closes both sockets (if still held) on destruction.
connection::~connection() noexcept { this->close(); }

// Queues `packet` for transmission on the reliable socket; nothing is written
// until poll() is called.
void connection::rsend_packet(const rpacket_t& packet) noexcept {
    this->rpackets.push_back(packet);
}

// Queues `packet` for transmission on the datagram socket; nothing is written
// until poll() is called.
void connection::usend_packet(const upacket_t& packet) noexcept {
    this->upackets.push_back(packet);
}

// Convenience overload: moves `packet` into shared storage and queues it on
// the reliable socket.
void connection::rsend_packet(rpacket&& packet) noexcept {
    this->rsend_packet(std::make_shared<rpacket>(std::move(packet)));
}

// Convenience overload: moves `packet` into shared storage and queues it on
// the datagram socket.
void connection::usend_packet(upacket&& packet) noexcept {
    this->usend_packet(std::make_shared<upacket>(std::move(packet)));
}

// Attempts to read one complete packet from the reliable socket.
//
// The stream is framed by a network-byte-order `packet_header_t` holding the
// total packet size (header included). Returns the parsed packet, or
// std::nullopt when the connection is not good(), when a full packet has not
// arrived yet, or when the payload fails to decompress/parse. Framing or
// socket errors are recorded in `bad_reason`.
std::optional<proto::packet> connection::rrecv_packet() noexcept {
    if (!this->good()) {
        return std::nullopt;
    }

    const socket_t sock = this->socks->rsock;

    // Get the size of the backlog, early out if there's nothing there yet.
    const auto backlog_size = shared::net::get_backlog_size(sock);
    if (backlog_size < sizeof(packet_header_t)) {
        return std::nullopt;
    }

    // Peek at the header without consuming it, and verify the peek actually
    // succeeded before trusting its contents.
    packet_header_t raw_header{};
    const auto peeked =
        recv(sock, &raw_header, sizeof(packet_header_t), MSG_PEEK);
    if (peeked == -1) {
        if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
            return std::nullopt; // transient, try again on the next call
        }
        this->bad_reason = shared::net::get_errno_error();
        return std::nullopt;
    }
    if (static_cast<std::size_t>(peeked) != sizeof(packet_header_t)) {
        return std::nullopt; // header not fully readable yet
    }
    const packet_header_t read_packet_size = ntohl(raw_header);

    // Validate the claimed size BEFORE waiting for the bytes to arrive. A
    // peer announcing an oversized packet would otherwise leave us waiting
    // indefinitely for a backlog that can never fill up.
    if (read_packet_size < sizeof(packet_header_t)) {
        this->bad_reason.emplace("received packet too small");
        return std::nullopt;
    }
    if (read_packet_size >= MAX_PACKET_SIZE) {
        this->bad_reason.emplace("received packet size exceeded limit");
        return std::nullopt;
    }

    if (backlog_size < read_packet_size) { // data not there yet
        return std::nullopt;
    }

    // Read the actual packet now, based on our claimed size.
    std::string data(read_packet_size, '\0');
    const auto result = read(sock, std::data(data), read_packet_size);
    if (result == -1) {
        this->bad_reason = shared::net::get_errno_error();
        return std::nullopt;
    }
    if (static_cast<std::size_t>(result) != read_packet_size) {
        // Not expected given the backlog check above. If it does happen the
        // stream framing is lost, so the connection cannot be used further.
        this->bad_reason.emplace("short read on reliable socket");
        return std::nullopt;
    }

    // Drop the header, leaving only the (possibly compressed) payload.
    data.erase(0, sizeof(packet_header_t));

    // NOTE(review): possible compression bomb - the decompressed size is not
    // bounded here.
    auto decompressed = shared::maybe_decompress_string(data);
    if (!decompressed.has_value()) {
        return std::nullopt;
    }
    data = std::move(*decompressed);

    // Parse the packet, ignoring the header.
    proto::packet packet;
    if (!packet.ParseFromString(data)) {
        return std::nullopt;
    }
    return packet;
}

// Attempts to read one packet from the datagram socket.
//
// Datagrams that fail to decompress or parse are skipped and the next one is
// tried. A read failing with ECONNREFUSED is ignored and the read is retried.
// Returns std::nullopt when the connection is not good(), when nothing is
// pending, or after recording any other read error in `bad_reason`.
std::optional<proto::packet> connection::urecv_packet() noexcept {
    // good() is checked before `socks` is dereferenced, matching
    // rrecv_packet(); previously the socket was read from `socks` first.
    while (this->good()) {
        const socket_t sock = this->socks->usock;

        const auto packet_size = shared::net::get_backlog_size(sock);
        if (packet_size == 0) {
            return std::nullopt;
        }

        std::string data(packet_size, '\0');
        const auto result = read(sock, std::data(data), packet_size);
        if (result == -1) {
            if (errno == ECONNREFUSED) { // icmp, ignore it
                continue;
            }
            this->bad_reason = shared::net::get_errno_error();
            return std::nullopt;
        }
        if (static_cast<std::size_t>(result) != packet_size) {
            return std::nullopt; // shouldn't happen
        }

        auto decompressed = shared::maybe_decompress_string(data);
        if (!decompressed.has_value()) {
            continue; // undecodable datagram, try the next one
        }
        data = std::move(*decompressed);

        proto::packet packet;
        if (!packet.ParseFromString(data)) {
            continue; // unparsable datagram, try the next one
        }
        return packet;
    }
    return std::nullopt;
}

// Returns the next available packet, trying the datagram socket first and
// falling back to the reliable socket.
std::optional<proto::packet> connection::recv_packet() noexcept {
    if (auto ret = this->urecv_packet(); ret.has_value()) {
        return ret;
    }
    return this->rrecv_packet();
}

// Tries to write `packet.data` to `sock` in a single write() call.
//
// Returns true only when the whole buffer was written. Returns false when the
// connection is not good(), when the write would block (EAGAIN/EWOULDBLOCK),
// when the write was partial, or on any other error (which is also recorded
// in `bad_reason`).
//
// NOTE(review): a partial write is reported as a plain failure, so poll()
// keeps the packet queued and later retries it from the beginning. On the
// reliable socket that would resend bytes the peer already received. Fixing
// it needs a per-packet send offset, which this file cannot add on its own.
bool connection::maybe_send(const packet& packet,
                            const socket_t& sock) noexcept {
    if (!this->good()) {
        return false;
    }

    const auto& data = packet.data;
    const auto result = write(sock, std::data(data), std::size(data));
    if (result == -1) {
        if (errno == EAGAIN || errno == EWOULDBLOCK) {
            return false;
        }
        this->bad_reason.emplace(get_errno_error());
        return false;
    }
    return static_cast<std::size_t>(result) == std::size(data);
}

// Flushes as many queued packets as possible: datagram queue first, then the
// reliable queue. Packets are sent in order; sending stops at the first packet
// that maybe_send() rejects, and everything sent before it is removed from the
// queue.
void connection::poll() {
    // Nothing to send on once the sockets are gone (closed or moved-from);
    // dereferencing `socks` below would be invalid.
    if (!this->socks.has_value()) {
        return;
    }

    const auto erase_send = [this](auto& packets, const socket_t& sock) {
        const auto first_unsent = std::find_if(
            std::begin(packets), std::end(packets),
            [this, &sock](const auto& queued) {
                return !this->maybe_send(*queued, sock);
            });
        packets.erase(std::begin(packets), first_unsent);
    };

    erase_send(this->upackets, this->socks->usock);
    erase_send(this->rpackets, this->socks->rsock);
}

// Closes both sockets if they are still held and forgets them. Calling it
// again afterwards does nothing.
void connection::close() {
    if (this->socks.has_value()) {
        shared::net::close_socket(socks->rsock);
        shared::net::close_socket(socks->usock);
    }
    this->socks.reset();
}

} // namespace net
} // namespace shared