diff options
Diffstat (limited to 'src/shared/net/connection.cc')
| -rw-r--r-- | src/shared/net/connection.cc | 219 |
1 files changed, 219 insertions, 0 deletions
diff --git a/src/shared/net/connection.cc b/src/shared/net/connection.cc index 79a83f5..b368735 100644 --- a/src/shared/net/connection.cc +++ b/src/shared/net/connection.cc @@ -1 +1,220 @@ #include "shared/net/connection.hh" + +namespace shared { +namespace net { + +connection::connection(const socket_t& rsock) { + const std::string peer_address = get_socket_peer_address(rsock); + const std::string peer_port = get_socket_peer_port(rsock); + + // Open up a connected usock based on the state of the connected rsock. + const socket_t usock = [&]() -> socket_t { + constexpr addrinfo hints = {.ai_flags = AI_PASSIVE, + .ai_family = AF_INET, + .ai_socktype = SOCK_DGRAM}; + const std::string host_address = get_socket_host_address(rsock); + const std::string host_port = get_socket_host_port(rsock); + const auto host_info = get_addr_info(host_address, host_port, &hints); + const socket_t usock = make_socket(host_info.get()); + bind_socket(usock, host_info.get()); + + const auto peer_info = get_addr_info(peer_address, peer_port, &hints); + connect_socket(usock, peer_info.get()); + + return usock; + }(); + + nonblock_socket(usock); + nonblock_socket(rsock); + + this->socks.emplace(sockets{.rsock = rsock, + .usock = usock, + .address = peer_address, + .port = peer_port}); +} + +connection::connection(connection&& other) noexcept { + std::swap(this->socks, other.socks); + std::swap(this->bad_reason, other.bad_reason); + other.socks.reset(); +} + +connection& connection::operator=(connection&& other) noexcept { + std::swap(this->socks, other.socks); + std::swap(this->bad_reason, other.bad_reason); + other.socks.reset(); + return *this; +} + +connection::~connection() noexcept { this->close(); } + +void connection::rsend_packet(const rpacket_t& packet) noexcept { + this->rpackets.push_back(packet); +} + +void connection::usend_packet(const upacket_t& packet) noexcept { + this->upackets.push_back(packet); +} + +void connection::rsend_packet(rpacket&& packet) noexcept { + this->rsend_packet(std::make_shared<rpacket>(std::move(packet))); +} + +void connection::usend_packet(upacket&& packet) noexcept { + this->usend_packet(std::make_shared<upacket>(std::move(packet))); +} + +std::optional<proto::packet> connection::rrecv_packet() noexcept { + if (!this->good()) { + return std::nullopt; + } + + const socket_t sock = this->socks->rsock; + + // Get the size of the backlog, early out of there's nothing there yet. + const auto backlog_size = shared::net::get_backlog_size(sock); + if (backlog_size < sizeof(packet_header_t)) { + return std::nullopt; + } + + const packet_header_t read_packet_size = [&]() { + packet_header_t ret; + recv(sock, &ret, sizeof(packet_header_t), MSG_PEEK); + return ntohl(ret); + }(); + if (backlog_size < read_packet_size) { // data not there yet + return std::nullopt; + } + if (read_packet_size < sizeof(packet_header_t)) { + this->bad_reason.emplace("received packet too small"); + return std::nullopt; + } + if (read_packet_size >= MAX_PACKET_SIZE) { + this->bad_reason.emplace("received packet size exceeded limit"); + return std::nullopt; + } + + // Read the actual packet now, based on our claimed size. + std::string data(read_packet_size, '\0'); + if (const auto result = read(sock, std::data(data), read_packet_size); + result == -1) { + this->bad_reason = shared::net::get_errno_error(); + return std::nullopt; + } else if (result != read_packet_size) { + return std::nullopt; // this shouldn't happen? + } + data = {std::begin(data) + sizeof(packet_header_t), std::end(data)}; + + // possible compression bomb :) + if (const auto decompress = shared::maybe_decompress_string(data); + decompress.has_value()) { + data = *decompress; + } else { + return std::nullopt; + } + + // Parse the packet, ignoring the header. + proto::packet packet; + if (!packet.ParseFromString(data)) { + return std::nullopt; + } + + return packet; +} + +std::optional<proto::packet> connection::urecv_packet() noexcept { + const socket_t sock = this->socks->usock; + +restart: + if (!this->good()) { + return std::nullopt; + } + + const auto packet_size = shared::net::get_backlog_size(sock); + if (packet_size == 0) { + return std::nullopt; + } + + std::string data(packet_size, '\0'); + if (const auto result = read(sock, std::data(data), packet_size); + result == -1) { + + if (errno == ECONNREFUSED) { // icmp, ignore it + goto restart; + } + + this->bad_reason = shared::net::get_errno_error(); + return std::nullopt; + } else if (static_cast<decltype(packet_size)>(result) != packet_size) { + return std::nullopt; // shouldn't happen + } + + if (const auto decompress = shared::maybe_decompress_string(data); + decompress.has_value()) { + data = *decompress; + } else { + goto restart; + } + + proto::packet packet; + if (!packet.ParseFromString(data)) { + goto restart; + } + + return packet; +} + +std::optional<proto::packet> connection::recv_packet() noexcept { + if (const auto ret = urecv_packet(); ret.has_value()) { + return ret; + } + return rrecv_packet(); +} + +bool connection::maybe_send(const packet& packet, + const socket_t& sock) noexcept { + if (!this->good()) { + return false; + } + + const auto& data = packet.data; + if (const auto result = write(sock, std::data(data), std::size(data)); + result == -1) { + + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return false; + } + + this->bad_reason.emplace(get_errno_error()); + return false; + } else if (static_cast<std::size_t>(result) != std::size(data)) { + return false; + } + + return true; +} + +void connection::poll() { + const auto erase_send = [this](auto& packets, const socket_t& sock) { + const auto it = + std::find_if(std::begin(packets), std::end(packets), + [&, this](const auto& packet) { + return !this->maybe_send(*packet, sock); + }); + packets.erase(std::begin(packets), it); + }; + + erase_send(this->upackets, this->socks->usock); + erase_send(this->rpackets, this->socks->rsock); +} + +void connection::close() { + if (this->socks.has_value()) { + shared::net::close_socket(socks->rsock); + shared::net::close_socket(socks->usock); + } + this->socks.reset(); +} + +} // namespace net +} // namespace shared |
