// File udp.hpp
// File List > connections > udp.hpp
// Go to the documentation of this file
// Copyright 2024 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _GNU_SOURCE
// See `man strerror_r`
#define _GNU_SOURCE
#endif
#include <nebula_core_common/util/errno.hpp>
#include <nebula_core_common/util/expected.hpp>
#include <nebula_core_hw_interfaces/nebula_hw_interfaces_common/connections/socket_utils.hpp>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <optional>
#include <string>
#include <thread>
#include <utility>
#include <vector>
namespace nebula::drivers::connections
{
class UdpSocket
{
struct SocketConfig
{
int32_t polling_interval_ms{10};
size_t buffer_size{1500};
Endpoint host{};
std::optional<in_addr> multicast_ip;
std::optional<Endpoint> sender_filter;
std::optional<Endpoint> send_to;
};
/// Bundles the `msghdr` handed to `recvmsg()` with the storage that header points into.
///
/// `msg` stores raw pointers to this object's own members (`iov`, `control`, `sender_addr`) and
/// to the caller's receive buffer. A copied or moved instance would therefore still reference the
/// source object, which is why copying and moving are disabled.
struct MsgBuffers
{
  /// @param receive_buffer Buffer the datagram payload is written to. It must outlive this object
  ///        and must not be resized while this object is in use, since only its current
  ///        `data()`/`size()` are captured here.
  explicit MsgBuffers(std::vector<uint8_t> & receive_buffer)
  {
    iov.iov_base = receive_buffer.data();
    iov.iov_len = receive_buffer.size();

    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    msg.msg_control = control.data();
    msg.msg_controllen = control.size();
    msg.msg_name = &sender_addr;
    msg.msg_namelen = sizeof(sender_addr);
  }

  MsgBuffers(const MsgBuffers &) = delete;
  MsgBuffers(MsgBuffers &&) = delete;
  MsgBuffers & operator=(const MsgBuffers &) = delete;
  MsgBuffers & operator=(MsgBuffers &&) = delete;

  /// Scatter/gather entry describing the payload buffer.
  iovec iov{};
  /// Storage for ancillary (control) messages. The CMSG_* macros access it as `cmsghdr`
  /// objects, so it is aligned accordingly.
  alignas(cmsghdr) std::array<std::byte, 1024> control{};
  /// Filled by `recvmsg()` with the address of the datagram's sender.
  sockaddr_in sender_addr{};
  /// The message header tying the members above together.
  msghdr msg{};
};
/// Converts the monotonically increasing drop counter reported with each received datagram into
/// "drops since the previous call".
class DropMonitor
{
  /// Counter value seen by the previous call (0 before the first call).
  uint32_t last_drop_counter_{0};

public:
  /// @brief Record the current counter value and return how far it advanced since the last call.
  /// @param current_drop_counter Counter value reported with the datagram just received.
  /// @return Number of drops since the previous call.
  ///
  /// The counter is a 32-bit value that may wrap around. Unsigned subtraction is performed
  /// modulo 2^32, so the difference is correct across a single wrap without special-casing it
  /// (e.g. UINT32_MAX -> 0 yields 1).
  uint32_t get_drops_since_last_receive(uint32_t current_drop_counter)
  {
    const uint32_t previous = last_drop_counter_;
    last_drop_counter_ = current_drop_counter;
    // The cast keeps the result modulo 2^32 even where `int` is wider than 32 bits.
    return static_cast<uint32_t>(current_drop_counter - previous);
  }
};
/// Private constructor used by `Builder::bind()`: adopts the configured socket and its settings.
UdpSocket(SockFd sock_fd, SocketConfig config)
: sock_fd_(std::move(sock_fd)), config_(std::move(config))
{
}
public:
/// Stops and joins the receive thread (if one is running) before the members are destroyed.
~UdpSocket()
{
  unsubscribe();
}
/// Fluent builder that creates a UDP socket, configures it and finally binds it.
///
/// Every setter returns `Builder &&`, so calls are meant to be chained on a temporary and
/// finished with `bind()`, which consumes the builder.
class Builder
{
public:
  /// @brief Create the socket and enable the options the receive path relies on.
  /// @param host_ip IPv4 address to bind to. Must not be the broadcast address.
  /// @param host_port Port to bind to.
  /// @throws UsageError if `host_ip` is the broadcast address.
  /// @throws SocketError if the socket cannot be created.
  /// @note Failures of `parse_ip()` and `setsockopt()` are surfaced through their
  ///       `value_or_throw()`.
  Builder(const std::string & host_ip, uint16_t host_port)
  {
    const in_addr host_in_addr = parse_ip(host_ip).value_or_throw();
    if (host_in_addr.s_addr == INADDR_BROADCAST) {
      throw UsageError("Do not bind to broadcast IP. Bind to 0.0.0.0 or a specific IP instead.");
    }
    config_.host = {host_in_addr, host_port};

    const int sock_fd = socket(AF_INET, SOCK_DGRAM, 0);
    if (sock_fd == -1) throw SocketError(errno);
    sock_fd_ = SockFd{sock_fd};

    sock_fd_.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1).value_or_throw();
    // Enable kernel-space receive time measurement
    sock_fd_.setsockopt(SOL_SOCKET, SO_TIMESTAMP, 1).value_or_throw();
    // Enable reporting on packets dropped due to full UDP receive buffer
    sock_fd_.setsockopt(SOL_SOCKET, SO_RXQ_OVFL, 1).value_or_throw();
  }

  /// @brief Restrict reception to datagrams from the given sender.
  /// @param sender_ip IPv4 address of the accepted sender.
  /// @param sender_port Port stored alongside the address in the sender filter.
  Builder && limit_to_sender(const std::string & sender_ip, uint16_t sender_port)
  {
    config_.sender_filter.emplace(Endpoint{parse_ip(sender_ip).value_or_throw(), sender_port});
    return std::move(*this);
  }

  /// @brief Set the endpoint `UdpSocket::send()` transmits to.
  /// @param dest_ip IPv4 address of the destination.
  /// @param dest_port Port of the destination.
  Builder && set_send_destination(const std::string & dest_ip, uint16_t dest_port)
  {
    config_.send_to.emplace(Endpoint{parse_ip(dest_ip).value_or_throw(), dest_port});
    return std::move(*this);
  }

  /// @brief Set the size of the user-space receive buffer.
  /// @param bytes Buffer size in bytes. Longer datagrams are cut to this size and reported with
  ///        `RxMetadata::truncated` set.
  Builder && set_mtu(size_t bytes)
  {
    config_.buffer_size = bytes;
    return std::move(*this);
  }

  /// @brief Request a kernel receive buffer size (`SO_RCVBUF`) for the socket.
  /// @param bytes Requested size in bytes; must fit into an `int`.
  /// @throws UsageError if `bytes` exceeds `INT32_MAX`.
  Builder && set_socket_buffer_size(size_t bytes)
  {
    if (bytes > static_cast<size_t>(INT32_MAX)) {
      throw UsageError("The maximum value supported (0x7FFFFFFF) has been exceeded");
    }
    const auto buf_size = static_cast<int>(bytes);
    sock_fd_.setsockopt(SOL_SOCKET, SO_RCVBUF, buf_size).value_or_throw();
    return std::move(*this);
  }

  /// @brief Join a multicast group, using the host IP as the local interface address.
  /// @param group_ip IPv4 address of the multicast group.
  /// @throws UsageError if a group has already been joined through this builder.
  Builder && join_multicast_group(const std::string & group_ip)
  {
    if (config_.multicast_ip) {
      throw UsageError("Only one multicast group can be joined by this socket");
    }
    ip_mreq mreq{parse_ip(group_ip).value_or_throw(), config_.host.ip};
    sock_fd_.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq).value_or_throw();
    config_.multicast_ip.emplace(mreq.imr_multiaddr);
    return std::move(*this);
  }

  /// @brief Set the timeout the receive loop passes to `is_socket_ready()`.
  /// @param interval_ms Timeout in milliseconds. The receive loop re-checks its stop flag only
  ///        after each wait returns, so this bounds how quickly `unsubscribe()` can complete.
  // NOTE(review): the value is not validated here; confirm how `is_socket_ready()` treats
  // negative timeouts.
  Builder && set_polling_interval(int32_t interval_ms)
  {
    config_.polling_interval_ms = interval_ms;
    return std::move(*this);
  }

  /// @brief Bind the socket and hand it over as a ready-to-use `UdpSocket`.
  ///
  /// The socket is bound to the host port and, if a multicast group was joined, to the group
  /// address instead of the host IP.
  /// @throws SocketError if `bind(2)` fails.
  UdpSocket bind() &&
  {
    sockaddr_in addr = config_.host.to_sockaddr();
    if (config_.multicast_ip) {
      addr.sin_addr = *config_.multicast_ip;
    }

    const int result =
      ::bind(sock_fd_.get(), reinterpret_cast<const sockaddr *>(&addr), sizeof(addr));
    if (result == -1) throw SocketError(errno);

    return UdpSocket{std::move(sock_fd_), config_};
  }

private:
  /// Socket being configured; moved into the `UdpSocket` by `bind()`.
  SockFd sock_fd_;
  /// Settings accumulated by the setters; copied into the `UdpSocket` by `bind()`.
  SocketConfig config_;
};
/// Counters the receive loop accumulates between two delivered packets. They are attached to the
/// next delivered packet's metadata and then start again from zero.
struct PerfCounters
{
  /// Time spent in the receive loop after the readiness wait returned, summed over all wake-ups
  /// since the last delivered packet.
  uint64_t receive_duration_ns = 0;
  /// Number of wake-ups where the readiness check reported no data.
  uint64_t n_woken_without_data = 0;
  /// Number of datagrams received but discarded because the sender was not accepted.
  uint64_t n_woken_by_wrong_sender = 0;
};
/// Per-packet information passed to the receive callback alongside the payload.
struct RxMetadata
{
  /// Receive timestamp in nanoseconds, taken from the `SO_TIMESTAMP` control message. Empty if
  /// no such control message accompanied the datagram.
  std::optional<uint64_t> timestamp_ns = std::nullopt;
  /// Drops reported via `SO_RXQ_OVFL` since the previously delivered drop count.
  uint64_t n_packets_dropped_since_last_receive = 0;
  /// Receive-loop counters accumulated since the previously delivered packet.
  PerfCounters packet_perf_counters{};
  /// True if the datagram was longer than the receive buffer and was cut to the buffer size.
  bool truncated = false;
};
/// Signature of the receive callback. `data` is the receive loop's own buffer, resized to the
/// received (possibly truncated) payload and passed by non-const reference; `metadata` describes
/// the datagram. The callback is invoked on the receive thread.
using callback_t = std::function<void(std::vector<uint8_t> & data, const RxMetadata & metadata)>;
/// @brief Register `callback` and start the receive thread that invokes it for every accepted
/// datagram. A previously running receiver is stopped first.
/// @param callback Callable taken over by this socket; it is invoked on the receive thread.
/// @return Reference to this socket, to allow call chaining.
UdpSocket & subscribe(callback_t && callback)
{
// Stop and join any running receiver before replacing the callback it uses.
unsubscribe();
callback_ = std::move(callback);
launch_receiver();
return *this;
}
/// @return True while the receive loop is flagged to run, i.e. between `subscribe()` and the
/// next `unsubscribe()`.
bool is_subscribed()
{
  return running_.load();
}
/// @brief Ask the receive loop to stop and wait for its thread to finish. Does nothing beyond
/// clearing the flag if no thread is running.
/// @return Reference to this socket, to allow call chaining.
UdpSocket & unsubscribe()
{
  // The receive loop checks this flag once per iteration.
  running_.store(false);

  if (!receive_thread_.joinable()) {
    return *this;
  }

  receive_thread_.join();
  return *this;
}
/// @brief Send `data` as one datagram to the destination configured through
/// `Builder::set_send_destination()`.
/// @param data Payload to transmit.
/// @throws UsageError if no destination has been configured.
/// @throws SocketError if `sendto(2)` fails for a reason other than `EINTR`.
void send(const std::vector<uint8_t> & data)
{
  if (!config_.send_to) throw UsageError("No destination set");

  sockaddr_in destination = config_.send_to->to_sockaddr();
  auto * const destination_ptr = reinterpret_cast<sockaddr *>(&destination);

  // Retry for as long as the call is merely interrupted by a signal.
  for (;;) {
    const ssize_t n_sent = sendto(
      sock_fd_.get(), data.data(), data.size(), 0, destination_ptr, sizeof(destination));
    if (n_sent != -1) return;
    if (errno != EINTR) throw SocketError(errno);
  }
}
UdpSocket(const UdpSocket &) = delete;
UdpSocket(UdpSocket && other) noexcept
: sock_fd_((other.unsubscribe(), std::move(other.sock_fd_))), config_(other.config_)
{
if (other.callback_) subscribe(std::move(other.callback_));
};
UdpSocket & operator=(const UdpSocket &) = delete;
UdpSocket & operator=(UdpSocket &&) = delete;
private:
/// @brief Start the receive thread.
///
/// The thread loops until `running_` is cleared. Each iteration waits for the socket to become
/// readable (bounded by the polling interval), reads one datagram, discards it if the sender is
/// not accepted, and otherwise passes payload and metadata to `callback_`.
///
/// Precondition: `callback_` is set (asserted).
// NOTE(review): the `SocketError`s thrown below escape the thread function uncaught, which
// terminates the program. Confirm that this is the intended failure mode.
void launch_receiver()
{
  assert(callback_);
  running_ = true;
  receive_thread_ = std::thread([this]() {
    using clock = std::chrono::steady_clock;
    // The clock's native tick period is implementation-defined, so convert explicitly before
    // adding to a counter that is documented to be in nanoseconds.
    const auto elapsed_ns_since = [](clock::time_point start) {
      return static_cast<uint64_t>(
        std::chrono::duration_cast<std::chrono::nanoseconds>(clock::now() - start).count());
    };

    std::vector<uint8_t> buffer;
    DropMonitor drop_monitor{};
    // Accumulates over wake-ups that deliver nothing; reset after each delivered packet.
    PerfCounters current_packet_perf_counters{};

    while (running_) {
      auto data_available = is_socket_ready(sock_fd_.get(), config_.polling_interval_ms);
      // Timing starts after the wait, so the time spent waiting for data is not counted.
      const auto t_start = clock::now();
      if (!data_available.has_value()) throw SocketError(data_available.error());
      if (!data_available.value()) {
        current_packet_perf_counters.n_woken_without_data++;
        current_packet_perf_counters.receive_duration_ns += elapsed_ns_since(t_start);
        continue;
      }

      // The previous iteration shrank the buffer to the received size, and the callback gets it
      // by non-const reference, so restore the full size before every read.
      buffer.resize(config_.buffer_size);
      MsgBuffers msg_header{buffer};

      ssize_t recv_result{-1};
      do {
        // As per `man recvmsg`, zero-length datagrams are permitted and valid. Since the socket
        // is blocking, a recv_result of 0 means we received a valid 0-length datagram.
        recv_result = recvmsg(sock_fd_.get(), &msg_header.msg, MSG_TRUNC);
      } while (recv_result == -1 && errno == EINTR);
      if (recv_result < 0) throw SocketError(errno);

      // With MSG_TRUNC, the returned length is that of the whole datagram, even if it did not
      // fit into the buffer.
      const auto untruncated_packet_length = static_cast<size_t>(recv_result);

      if (!is_accepted_sender(msg_header.sender_addr)) {
        current_packet_perf_counters.n_woken_by_wrong_sender++;
        current_packet_perf_counters.receive_duration_ns += elapsed_ns_since(t_start);
        continue;
      }

      RxMetadata metadata;
      get_receive_metadata(msg_header.msg, metadata, drop_monitor);
      metadata.truncated = untruncated_packet_length > config_.buffer_size;

      // Resize down to match received data so callback sees correct size
      const auto valid_bytes = std::min(config_.buffer_size, untruncated_packet_length);
      buffer.resize(valid_bytes);

      current_packet_perf_counters.receive_duration_ns += elapsed_ns_since(t_start);
      metadata.packet_perf_counters = current_packet_perf_counters;
      current_packet_perf_counters = {};

      callback_(buffer, metadata);
    }
  });
}
/// @brief Extract receive timestamp and drop count from the control messages of `msg`.
/// @param msg Message header filled by `recvmsg()`.
/// @param metadata Receives `timestamp_ns` (from `SO_TIMESTAMP`) and
///        `n_packets_dropped_since_last_receive` (from `SO_RXQ_OVFL`) when the corresponding
///        control message is present; otherwise the fields are left untouched.
/// @param drop_monitor Turns the absolute drop counter into a per-receive difference.
static void get_receive_metadata(msghdr & msg, RxMetadata & metadata, DropMonitor & drop_monitor)
{
  for (cmsghdr * cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
    if (cmsg->cmsg_level != SOL_SOCKET) continue;

    // The pointer returned by CMSG_DATA() carries no alignment guarantee for the payload type,
    // so payloads are copied out with memcpy instead of being read through a casted pointer.
    switch (cmsg->cmsg_type) {
      case SO_TIMESTAMP: {
        timeval tv{};
        // Ignore a control message too short to hold the expected payload.
        if (cmsg->cmsg_len < CMSG_LEN(sizeof(tv))) break;
        std::memcpy(&tv, CMSG_DATA(cmsg), sizeof(tv));
        // Widen before multiplying so the arithmetic is done in 64-bit unsigned regardless of
        // the platform's `time_t`/`suseconds_t` width.
        const uint64_t timestamp_ns = static_cast<uint64_t>(tv.tv_sec) * 1'000'000'000ULL +
                                      static_cast<uint64_t>(tv.tv_usec) * 1'000ULL;
        metadata.timestamp_ns.emplace(timestamp_ns);
        break;
      }
      case SO_RXQ_OVFL: {
        uint32_t drops{0};
        if (cmsg->cmsg_len < CMSG_LEN(sizeof(drops))) break;
        std::memcpy(&drops, CMSG_DATA(cmsg), sizeof(drops));
        metadata.n_packets_dropped_since_last_receive =
          drop_monitor.get_drops_since_last_receive(drops);
        break;
      }
      default:
        break;
    }
  }
}
/// @brief Decide whether a datagram from `sender_addr` should be delivered to the callback.
/// @param sender_addr Sender address reported by `recvmsg()`.
/// @return True if no sender filter is configured, or if the sender's IP equals the filter's IP.
// NOTE(review): only the IP address is compared; the port passed to
// `Builder::limit_to_sender()` is not checked here. Confirm that this is intended.
bool is_accepted_sender(const sockaddr_in & sender_addr)
{
  const bool filter_active = config_.sender_filter.has_value();
  return !filter_active || config_.sender_filter->ip.s_addr == sender_addr.sin_addr.s_addr;
}
/// Socket used for both receiving and sending.
SockFd sock_fd_;
/// Settings taken over from the `Builder`.
SocketConfig config_;
/// True while the receive loop should keep running; set by `launch_receiver()`, cleared by
/// `unsubscribe()` and polled by the receive thread.
std::atomic_bool running_{false};
/// Thread running the receive loop started by `launch_receiver()`.
std::thread receive_thread_;
/// Callback invoked by the receive thread for every accepted datagram.
callback_t callback_;
};
} // namespace nebula::drivers::connections