// File: udp.hpp
// Location: File List > connections > udp.hpp
// (Go to the documentation of this file)
// Copyright 2024 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _GNU_SOURCE
// See `man strerror_r`
#define _GNU_SOURCE
#endif
#include <nebula_common/util/expected.hpp>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <unistd.h>
#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <exception>
#include <functional>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <variant>
#include <vector>
namespace nebula::drivers::connections
{
/// Exception describing a failed socket operation, either from an `errno` value or from a
/// free-form message.
class SocketError : public std::exception
{
  // Size of the scratch buffer handed to strerror_r.
  static constexpr size_t gnu_max_strerror_length = 1024;

public:
  /// Builds the error text from an `errno`-style error number.
  explicit SocketError(int err_no) : what_(describe(err_no)) {}

  /// Builds the error from a caller-supplied message.
  explicit SocketError(const std::string_view & msg) : what_(msg) {}

  /// Returns the stored, null-terminated description.
  const char * what() const noexcept override { return what_.c_str(); }

private:
  /// Translates `err_no` into text via the GNU flavour of strerror_r (the one returning `char *`).
  static std::string describe(int err_no)
  {
    std::array<char, gnu_max_strerror_length> scratch;
    // The returned pointer is used rather than the scratch buffer itself, since the GNU variant
    // hands back the message pointer as its result.
    const char * text = strerror_r(err_no, scratch.data(), scratch.size());
    return std::string{text};
  }

  std::string what_;
};
/// Exception signalling that the socket API was used incorrectly (e.g. an invalid argument).
class UsageError : public std::runtime_error
{
public:
  /// @param msg Human-readable description, forwarded to std::runtime_error.
  explicit UsageError(const std::string & msg)
  : std::runtime_error{msg}
  {
  }
};
class UdpSocket
{
// An IPv4 address / port pair.
struct Endpoint
{
// IPv4 address as produced by inet_aton.
in_addr ip;
// Port number; the host endpoint's port is converted with htons() when the socket is bound.
uint16_t port;
};
/// Move-only owner of a socket file descriptor. The descriptor is closed on destruction unless
/// the object is empty (holds `uninitialized`).
class SockFd
{
  // Sentinel marking "no descriptor owned".
  static constexpr int uninitialized = -1;
  int sock_fd_;

public:
  /// Creates an empty holder that owns nothing.
  SockFd() : sock_fd_{uninitialized} {}

  /// Takes ownership of an already opened descriptor.
  explicit SockFd(int sock_fd) : sock_fd_{sock_fd} {}

  /// Steals the descriptor from `other`, leaving it empty.
  SockFd(SockFd && other) noexcept : sock_fd_{other.sock_fd_} { other.sock_fd_ = uninitialized; }

  SockFd(const SockFd &) = delete;
  SockFd & operator=(const SockFd &) = delete;

  /// Exchanges descriptors with `other`. The descriptor previously held by `*this` is closed
  /// when `other` is destroyed. Swapping plain ints cannot throw, hence `noexcept`.
  SockFd & operator=(SockFd && other) noexcept
  {
    std::swap(sock_fd_, other.sock_fd_);
    return *this;
  }

  ~SockFd()
  {
    if (sock_fd_ == uninitialized) return;
    close(sock_fd_);
  }

  /// Returns the raw descriptor (or -1 if empty) without giving up ownership.
  [[nodiscard]] int get() const { return sock_fd_; }

  /// Thin wrapper around ::setsockopt for an option value of type `T`.
  /// @param level Protocol level, e.g. SOL_SOCKET
  /// @param optname Option name, e.g. SO_RCVBUF
  /// @param optval Option value; passed by address with length `sizeof(T)`
  /// @return std::monostate on success, or a SocketError built from `errno` on failure
  template <typename T>
  [[nodiscard]] util::expected<std::monostate, SocketError> setsockopt(
    int level, int optname, const T & optval)
  {
    int result = ::setsockopt(sock_fd_, level, optname, &optval, sizeof(T));
    if (result == -1) return SocketError(errno);
    return std::monostate{};
  }
};
// Settings collected by the Builder and handed to the UdpSocket on bind().
struct SocketConfig
{
// Timeout, in milliseconds, passed to poll() while waiting for incoming data.
int32_t polling_interval_ms{10};
// Size in bytes of the receive buffer; longer datagrams are flagged as truncated.
size_t buffer_size{1500};
// Local address/port; also used as the interface address when joining a multicast group.
Endpoint host;
// Multicast group joined by the socket, if any. When set, bind() binds to this address.
std::optional<in_addr> multicast_ip;
// Optional sender filter set via limit_to_sender().
std::optional<Endpoint> sender;
};
// Bundle of the structures needed for one recvmsg() call.
// NOTE(review): `msg` stores raw pointers (msg_iov, msg_control, msg_name). Copying this struct
// copies those pointer values verbatim; they are not re-targeted at the copy's own fields.
struct MsgBuffers
{
// Message header passed to recvmsg().
msghdr msg{};
// Scatter/gather entry describing the payload buffer.
iovec iov{};
// Storage for ancillary (control) messages.
std::array<std::byte, 1024> control;
// Source address of the datagram.
sockaddr_in sender_addr;
};
/// Converts an absolute, wrapping 32-bit drop counter into "drops since the previous call".
class DropMonitor
{
  // Counter value seen on the previous call (0 before the first call).
  uint32_t last_drop_counter_{0};

public:
  /// @param current_drop_counter Latest absolute value of the drop counter
  /// @return Number of drops since the previous call, correctly accounting for a single
  /// wrap-around of the 32-bit counter
  uint32_t get_drops_since_last_receive(uint32_t current_drop_counter)
  {
    const uint32_t previous = last_drop_counter_;
    last_drop_counter_ = current_drop_counter;

    // Unsigned subtraction is performed modulo 2^32, so this is correct whether or not the
    // counter wrapped (e.g. UINT32_MAX -> 0 yields 1). The previous explicit wrap branch
    // computed `UINT32_MAX - last + current`, which under-reported by one.
    return static_cast<uint32_t>(current_drop_counter - previous);
  }
};
/// Private constructor used by Builder::bind(): adopts an already configured and bound
/// descriptor together with its configuration.
UdpSocket(SockFd sock_fd, SocketConfig config)
: sock_fd_{std::move(sock_fd)}, poll_fd_{}, config_{std::move(config)}
{
  // Watch the adopted descriptor for readable data.
  poll_fd_.fd = sock_fd_.get();
  poll_fd_.events = POLLIN;
  poll_fd_.revents = 0;
}
public:
/// Step-by-step configuration of a UDP socket. Every setter returns the builder as an rvalue so
/// calls can be chained and finished with `bind()`.
class Builder
{
public:
  /// Creates an IPv4 UDP socket and enables SO_REUSEADDR, SO_TIMESTAMP and SO_RXQ_OVFL on it.
  /// @param host_ip Local address to bind to later (must not be the broadcast address)
  /// @param host_port Local port to bind to later
  /// @throws UsageError if `host_ip` cannot be parsed or is the broadcast address
  /// @throws SocketError if the socket cannot be created (option failures are surfaced through
  /// `value_or_throw()`)
  Builder(const std::string & host_ip, uint16_t host_port)
  {
    in_addr host_in_addr = parse_ip_or_throw(host_ip);
    if (host_in_addr.s_addr == INADDR_BROADCAST)
      throw UsageError("Do not bind to broadcast IP. Bind to 0.0.0.0 or a specific IP instead.");

    config_.host = {host_in_addr, host_port};

    int sock_fd = socket(AF_INET, SOCK_DGRAM, 0);
    if (sock_fd == -1) throw SocketError(errno);
    sock_fd_ = SockFd{sock_fd};

    sock_fd_.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1).value_or_throw();

    // Enable kernel-space receive time measurement
    sock_fd_.setsockopt(SOL_SOCKET, SO_TIMESTAMP, 1).value_or_throw();

    // Enable reporting on packets dropped due to full UDP receive buffer
    sock_fd_.setsockopt(SOL_SOCKET, SO_RXQ_OVFL, 1).value_or_throw();
  }

  /// Stores a sender filter for the socket.
  /// @param sender_ip Address datagrams are accepted from
  /// @param sender_port Port stored alongside the address
  /// @throws UsageError if `sender_ip` cannot be parsed
  Builder && limit_to_sender(const std::string & sender_ip, uint16_t sender_port)
  {
    config_.sender.emplace(Endpoint{parse_ip_or_throw(sender_ip), sender_port});
    return std::move(*this);
  }

  /// Sets the size of the receive buffer; datagrams longer than this are reported as truncated.
  /// @param bytes Buffer size in bytes
  Builder && set_mtu(size_t bytes)
  {
    config_.buffer_size = bytes;
    return std::move(*this);
  }

  /// Requests a kernel receive buffer size (SO_RCVBUF).
  /// @param bytes Requested size in bytes; must fit in a signed 32-bit integer
  /// @throws UsageError if `bytes` exceeds INT32_MAX
  Builder && set_socket_buffer_size(size_t bytes)
  {
    if (bytes > static_cast<size_t>(INT32_MAX))
      throw UsageError("The maximum value supported (0x7FFFFFFF) has been exceeded");

    auto buf_size = static_cast<int>(bytes);
    sock_fd_.setsockopt(SOL_SOCKET, SO_RCVBUF, buf_size).value_or_throw();
    return std::move(*this);
  }

  /// Joins an IPv4 multicast group on the interface identified by the host address.
  /// @param group_ip Multicast group address
  /// @throws UsageError if a group was already joined or `group_ip` cannot be parsed
  Builder && join_multicast_group(const std::string & group_ip)
  {
    if (config_.multicast_ip)
      throw UsageError("Only one multicast group can be joined by this socket");

    ip_mreq mreq{parse_ip_or_throw(group_ip), config_.host.ip};
    sock_fd_.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq).value_or_throw();
    config_.multicast_ip.emplace(mreq.imr_multiaddr);
    return std::move(*this);
  }

  /// Sets the poll() timeout used by the receive loop.
  /// NOTE(review): the value is not validated. Per poll(2) a negative timeout waits
  /// indefinitely, which would delay unsubscribe() until data arrives — confirm whether
  /// negative values should be rejected.
  /// @param interval_ms Timeout in milliseconds
  Builder && set_polling_interval(int32_t interval_ms)
  {
    config_.polling_interval_ms = interval_ms;
    return std::move(*this);
  }

  /// Binds the socket and hands it over to a UdpSocket. The builder must not be reused
  /// afterwards, since its descriptor and configuration are moved out.
  /// @return The bound socket
  /// @throws SocketError if ::bind fails
  UdpSocket bind() &&
  {
    sockaddr_in addr{};
    addr.sin_family = AF_INET;
    addr.sin_port = htons(config_.host.port);
    // When a multicast group was joined, bind to the group address instead of the host address.
    addr.sin_addr = config_.multicast_ip ? *config_.multicast_ip : config_.host.ip;

    int result = ::bind(sock_fd_.get(), reinterpret_cast<const sockaddr *>(&addr), sizeof(addr));
    if (result == -1) throw SocketError(errno);

    return UdpSocket{std::move(sock_fd_), std::move(config_)};
  }

private:
  SockFd sock_fd_;
  SocketConfig config_;
};
/// Per-datagram information handed to the receive callback alongside the payload.
struct RxMetadata
{
  /// Kernel receive time in nanoseconds, if a timestamp control message was present.
  std::optional<uint64_t> timestamp_ns;
  /// Datagrams dropped by the kernel since the previously delivered datagram.
  uint64_t drops_since_last_receive{0};
  /// True if the datagram was longer than the receive buffer and was cut off.
  /// Initialised so that a default-constructed object never holds an indeterminate value.
  bool truncated{false};
};
// Signature of the receive callback: the datagram payload (at most `buffer_size` bytes) and its
// metadata. The callback is invoked from the receive thread.
using callback_t = std::function<void(const std::vector<uint8_t> &, const RxMetadata &)>;
// Registers `callback` and starts the receive thread. Any previous subscription is stopped
// first, so the order of the three steps below matters.
// `callback` must be non-empty (asserted when the receiver is launched).
// Returns a reference to this socket to allow chaining.
UdpSocket & subscribe(callback_t && callback)
{
// Stop and join a running receive thread before replacing the callback it uses.
unsubscribe();
callback_ = std::move(callback);
launch_receiver();
return *this;
}
/// Reports whether the receive loop is currently flagged as running.
bool is_subscribed()
{
  return running_.load();
}
/// Asks the receive loop to stop and waits for the receive thread to finish, if one exists.
/// Safe to call when no subscription is active. Returns this socket for chaining.
UdpSocket & unsubscribe()
{
  running_.store(false);

  // Nothing to wait for if no thread was ever started (or it was already joined).
  if (!receive_thread_.joinable()) return *this;

  receive_thread_.join();
  return *this;
}
// Sockets own a file descriptor and a thread, so they cannot be copied.
UdpSocket(const UdpSocket &) = delete;
// Move constructor: stops `other`'s receive thread, takes over its descriptor, poll entry,
// configuration and drop counter, then restarts receiving on this object if `other` holds a
// callback.
// NOTE(review): unsubscribe() does not clear callback_, so a socket that was subscribed and later
// unsubscribed still holds a callback and is re-subscribed by this move — confirm this is
// intended.
UdpSocket(UdpSocket && other)
// The comma expression makes sure other's thread has been joined before its descriptor is taken.
: sock_fd_((other.unsubscribe(), std::move(other.sock_fd_))),
poll_fd_(std::move(other.poll_fd_)),
config_(std::move(other.config_)),
drop_monitor_(std::move(other.drop_monitor_))
{
// The new receive thread captures `this` of the new object.
if (other.callback_) subscribe(std::move(other.callback_));
};
// Assignment in either form is disabled.
UdpSocket & operator=(const UdpSocket &) = delete;
UdpSocket & operator=(UdpSocket &&) = delete;
// Stops and joins the receive thread before the members are destroyed.
~UdpSocket() { unsubscribe(); }
private:
/// Starts the background thread that polls the socket, receives datagrams and forwards them to
/// `callback_`. Requires `callback_` to be set. The loop runs until `running_` becomes false.
///
/// NOTE(review): SocketError exceptions thrown inside the thread function are not caught here;
/// an exception escaping a std::thread function ends the process via std::terminate.
void launch_receiver()
{
  assert(callback_);

  running_ = true;
  receive_thread_ = std::thread([this]() {
    std::vector<uint8_t> buffer;
    while (running_) {
      auto data_available = is_data_available();
      if (!data_available.has_value()) throw SocketError(data_available.error());
      if (!data_available.value()) continue;

      buffer.resize(config_.buffer_size);
      auto msg_header = make_msg_header(buffer);

      // The msghdr stores raw pointers to its iovec, control buffer and sender address. Point
      // them at the fields of *this* MsgBuffers object so that recvmsg never writes through
      // pointers into storage that belonged to make_msg_header's stack frame, and so that
      // `msg_header.sender_addr` / `msg_header.control` really receive the kernel's output.
      msg_header.msg.msg_iov = &msg_header.iov;
      msg_header.msg.msg_control = msg_header.control.data();
      msg_header.msg.msg_name = &msg_header.sender_addr;

      // As per `man recvmsg`, zero-length datagrams are permitted and valid. Since the socket is
      // blocking, a recv_result of 0 means we received a valid 0-length datagram.
      // With MSG_TRUNC the return value is the real datagram length, even if it did not fit.
      ssize_t recv_result = recvmsg(sock_fd_.get(), &msg_header.msg, MSG_TRUNC);
      if (recv_result < 0) {
        // An interrupting signal is not a socket failure; try again on the next iteration.
        if (errno == EINTR) continue;
        throw SocketError(errno);
      }
      size_t untruncated_packet_length = recv_result;

      if (!is_accepted_sender(msg_header.sender_addr)) continue;

      RxMetadata metadata;
      get_receive_metadata(msg_header.msg, metadata);
      metadata.truncated = untruncated_packet_length > config_.buffer_size;

      // Hand over only the bytes that were actually stored in the buffer.
      buffer.resize(std::min(config_.buffer_size, untruncated_packet_length));
      callback_(buffer, metadata);
    }
  });
}
/// Walks the control messages attached to a received datagram and fills in the receive
/// timestamp and the drop count.
/// @param msg Message header as filled in by recvmsg()
/// @param inout_metadata Metadata object to update; fields without a matching control message
/// are left untouched
void get_receive_metadata(msghdr & msg, RxMetadata & inout_metadata)
{
  for (cmsghdr * cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
    if (cmsg->cmsg_level != SOL_SOCKET) continue;

    switch (cmsg->cmsg_type) {
      case SO_TIMESTAMP: {
        // cmsg(3): the CMSG_DATA pointer may not be suitably aligned for the payload type, so
        // copy the payload out instead of casting the pointer.
        timeval tv{};
        std::memcpy(&tv, CMSG_DATA(cmsg), sizeof(tv));
        // Do the arithmetic in 64 bits regardless of the platform's time_t width.
        const uint64_t timestamp_ns = static_cast<uint64_t>(tv.tv_sec) * 1'000'000'000ULL +
                                      static_cast<uint64_t>(tv.tv_usec) * 1'000ULL;
        inout_metadata.timestamp_ns.emplace(timestamp_ns);
        break;
      }
      case SO_RXQ_OVFL: {
        // Absolute drop counter; converted to a per-datagram delta by the drop monitor.
        uint32_t drops = 0;
        std::memcpy(&drops, CMSG_DATA(cmsg), sizeof(drops));
        inout_metadata.drops_since_last_receive =
          drop_monitor_.get_drops_since_last_receive(drops);
        break;
      }
      default:
        continue;
    }
  }
}
/// Waits up to `polling_interval_ms` for the socket to become readable.
/// @return `true` if data can be read, `false` on timeout or if the wait was interrupted by a
/// signal, or the `errno` value if poll() failed for any other reason
util::expected<bool, int> is_data_available()
{
  int status = poll(&poll_fd_, 1, config_.polling_interval_ms);
  if (status == -1) {
    // A signal arriving during poll() is not a socket failure. Report "no data" so the caller
    // simply polls again instead of treating it as a fatal error.
    if (errno == EINTR) return false;
    return errno;
  }
  return (poll_fd_.revents & POLLIN) && (status > 0);
}
/// Decides whether a datagram from `sender_addr` passes the configured sender filter.
/// Without a filter, every sender is accepted.
/// NOTE(review): only the IP address is compared; the port stored by limit_to_sender() is not
/// checked — confirm whether that is intended.
bool is_accepted_sender(const sockaddr_in & sender_addr)
{
  const auto & allowed = config_.sender;
  return !allowed || sender_addr.sin_addr.s_addr == allowed->ip.s_addr;
}
/// Prepares the structures for one recvmsg() call that stores the payload in `receive_buffer`.
///
/// The pointers inside the returned `msg` refer to the `iov`, `control` and `sender_addr` fields
/// of the object built here (previously they referred to separate function-local variables that
/// no longer exist after the return). `control` and `sender_addr` are zero-initialised.
///
/// NOTE(review): the pointers are only valid in the caller's object if the return value is
/// constructed in place (copy elision of a named local is permitted but not mandated by the
/// standard). A caller that copies or moves the result must re-point `msg.msg_iov`,
/// `msg.msg_control` and `msg.msg_name` at its own fields before calling recvmsg().
///
/// @param receive_buffer Destination for the datagram payload; its current size is the maximum
/// number of bytes that will be stored
static MsgBuffers make_msg_header(std::vector<uint8_t> & receive_buffer)
{
  MsgBuffers buffers{};

  buffers.iov.iov_base = receive_buffer.data();
  buffers.iov.iov_len = receive_buffer.size();

  buffers.msg.msg_iov = &buffers.iov;
  buffers.msg.msg_iovlen = 1;
  buffers.msg.msg_control = buffers.control.data();
  buffers.msg.msg_controllen = buffers.control.size();
  buffers.msg.msg_name = &buffers.sender_addr;
  buffers.msg.msg_namelen = sizeof(buffers.sender_addr);

  return buffers;
}
/// Converts a textual IPv4 address into an `in_addr` using inet_aton.
/// @param ip Address string
/// @return The parsed address
/// @throws UsageError if inet_aton rejects the string
static in_addr parse_ip_or_throw(const std::string & ip)
{
  in_addr result;
  // inet_aton signals an invalid address with a zero return value.
  if (inet_aton(ip.c_str(), &result) == 0) {
    throw UsageError("Invalid IP address given");
  }
  return result;
}
// Owned socket descriptor; closed when this object is destroyed.
SockFd sock_fd_;
// poll() entry for sock_fd_, watching for POLLIN.
pollfd poll_fd_;
// Settings supplied by the Builder.
SocketConfig config_;
// Loop flag for the receive thread; cleared by unsubscribe().
std::atomic_bool running_{false};
// Background thread running the receive loop while subscribed.
std::thread receive_thread_;
// User callback invoked for every accepted datagram.
callback_t callback_;
// Turns the kernel's absolute drop counter into per-datagram deltas.
DropMonitor drop_monitor_;
};
} // namespace nebula::drivers::connections