Skip to content

File udp.hpp

File List > connections > udp.hpp

Go to the documentation of this file

// Copyright 2024 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifndef _GNU_SOURCE
// See `man strerror_r`
#define _GNU_SOURCE
#endif

#include <nebula_common/util/expected.hpp>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <exception>
#include <functional>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <variant>
#include <vector>

namespace nebula::drivers::connections
{

/// @brief Error raised for failing socket system calls, or for socket-level conditions described by
/// a custom message.
class SocketError : public std::exception
{
  // Scratch space handed to the GNU `strerror_r`; the returned message may live here or in a
  // static string owned by the C library.
  static constexpr size_t gnu_max_strerror_length = 1024;

public:
  /// @brief Describe the given `errno` value using the C library's error text.
  /// @param err_no An `errno` value, e.g. captured right after a failing system call.
  explicit SocketError(int err_no) : what_{describe_errno(err_no)} {}

  /// @brief Use `msg` verbatim as the error description.
  explicit SocketError(const std::string_view & msg) : what_(msg) {}

  /// @return The stored description; valid for the lifetime of this object.
  const char * what() const noexcept override { return what_.c_str(); }

private:
  /// Translate an errno value to text via the GNU (char*-returning) `strerror_r`.
  static std::string describe_errno(int err_no)
  {
    std::array<char, gnu_max_strerror_length> scratch{};
    const char * text = strerror_r(err_no, scratch.data(), scratch.size());
    return std::string{text};
  }

  std::string what_;
};

/// @brief Thrown when the socket API is used incorrectly or given invalid arguments, such as a
/// malformed IP address or a conflicting configuration.
class UsageError : public std::runtime_error
{
public:
  /// @param msg Human-readable description of the misuse; returned by `what()`.
  explicit UsageError(const std::string & msg) : runtime_error{msg} {}
};

/// @brief A receive-only IPv4 UDP socket that delivers datagrams to a callback on a background
/// thread.
///
/// Create one via `UdpSocket::Builder`, then call `subscribe()` to start receiving. Each accepted
/// datagram is passed to the callback together with `RxMetadata` (kernel receive timestamp, drops
/// since the previous receive, and a truncation flag).
class UdpSocket
{
  /// An IPv4 address/port pair. `ip` is as produced by `inet_aton` (network byte order);
  /// `port` is kept in host byte order (it is passed through `htons` when binding).
  struct Endpoint
  {
    in_addr ip;
    uint16_t port;
  };

  /// Move-only owner of a socket file descriptor; closes it on destruction.
  class SockFd
  {
    static const int uninitialized = -1;
    int sock_fd_;

  public:
    SockFd() : sock_fd_{uninitialized} {}
    explicit SockFd(int sock_fd) : sock_fd_{sock_fd} {}
    SockFd(SockFd && other) noexcept : sock_fd_{other.sock_fd_} { other.sock_fd_ = uninitialized; }

    SockFd(const SockFd &) = delete;
    SockFd & operator=(const SockFd &) = delete;

    // Swapping hands our previous descriptor (if any) to `other`, whose destructor closes it.
    SockFd & operator=(SockFd && other) noexcept
    {
      std::swap(sock_fd_, other.sock_fd_);
      return *this;
    }

    ~SockFd()
    {
      if (sock_fd_ == uninitialized) return;
      close(sock_fd_);
    }

    [[nodiscard]] int get() const { return sock_fd_; }

    /// Set a socket option from a value of type T.
    /// @return `std::monostate` on success, or a `SocketError` built from `errno` on failure.
    template <typename T>
    [[nodiscard]] util::expected<std::monostate, SocketError> setsockopt(
      int level, int optname, const T & optval)
    {
      int result = ::setsockopt(sock_fd_, level, optname, &optval, sizeof(T));
      if (result == -1) return SocketError(errno);
      return std::monostate{};
    }
  };

  struct SocketConfig
  {
    /// Timeout of each `poll()` in the receive loop. The loop only notices `unsubscribe()` between
    /// polls, so this (plus any in-progress callback) bounds how long stopping takes.
    int32_t polling_interval_ms{10};

    /// Receive buffer size in bytes: the largest datagram delivered without truncation.
    size_t buffer_size{1500};
    Endpoint host;
    std::optional<in_addr> multicast_ip;
    std::optional<Endpoint> sender;
  };

  /// Storage for a single `recvmsg` call. `msg` holds raw pointers into the other members of the
  /// same object, so an instance must never be copied or moved once constructed.
  struct MsgBuffers
  {
    msghdr msg{};
    iovec iov{};
    // Control-message buffer; aligned so the kernel-written `cmsghdr`s are properly aligned.
    alignas(cmsghdr) std::array<std::byte, 1024> control{};
    sockaddr_in sender_addr{};

    /// @param receive_buffer Payload destination; its current size is the receive capacity.
    explicit MsgBuffers(std::vector<uint8_t> & receive_buffer)
    {
      iov.iov_base = receive_buffer.data();
      iov.iov_len = receive_buffer.size();

      msg.msg_iov = &iov;
      msg.msg_iovlen = 1;
      msg.msg_control = control.data();
      msg.msg_controllen = control.size();
      msg.msg_name = &sender_addr;
      msg.msg_namelen = sizeof(sender_addr);
    }

    MsgBuffers(const MsgBuffers &) = delete;
    MsgBuffers & operator=(const MsgBuffers &) = delete;
    MsgBuffers(MsgBuffers &&) = delete;
    MsgBuffers & operator=(MsgBuffers &&) = delete;
  };

  /// Turns the running `SO_RXQ_OVFL` drop counter into per-receive deltas.
  class DropMonitor
  {
    uint32_t last_drop_counter_{0};

  public:
    /// @param current_drop_counter Drop counter value reported with the current datagram.
    /// @return Drops since the previous call (or since construction, on the first call).
    uint32_t get_drops_since_last_receive(uint32_t current_drop_counter)
    {
      // Unsigned subtraction is modulo 2^32, which also handles counter wrap-around exactly
      // (e.g. UINT32_MAX -> 0 is one drop).
      auto drops = static_cast<uint32_t>(current_drop_counter - last_drop_counter_);
      last_drop_counter_ = current_drop_counter;
      return drops;
    }
  };

  UdpSocket(SockFd sock_fd, SocketConfig config)
  : sock_fd_(std::move(sock_fd)), poll_fd_{sock_fd_.get(), POLLIN, 0}, config_{std::move(config)}
  {
  }

public:
  /// @brief Configures and creates a `UdpSocket`.
  ///
  /// The constructor creates the socket immediately. Setters either apply socket options right
  /// away or record configuration, and `bind()` consumes the builder to produce the socket.
  /// Setters return `Builder &&`, so they are meant to be chained on a temporary, e.g.
  /// `UdpSocket::Builder(ip, port).set_mtu(9000).bind()`.
  class Builder
  {
  public:
    /// @param host_ip IPv4 address to bind to (e.g. "0.0.0.0" for all interfaces).
    /// @param host_port UDP port to bind to, in host byte order.
    /// @throws UsageError if `host_ip` is not a valid IPv4 address or is the broadcast address.
    /// @throws SocketError if the socket cannot be created or its options cannot be set.
    Builder(const std::string & host_ip, uint16_t host_port)
    {
      in_addr host_in_addr = parse_ip_or_throw(host_ip);
      if (host_in_addr.s_addr == INADDR_BROADCAST)
        throw UsageError("Do not bind to broadcast IP. Bind to 0.0.0.0 or a specific IP instead.");

      config_.host = {host_in_addr, host_port};

      int sock_fd = socket(AF_INET, SOCK_DGRAM, 0);
      if (sock_fd == -1) throw SocketError(errno);
      sock_fd_ = SockFd{sock_fd};

      sock_fd_.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1).value_or_throw();

      // Enable kernel-space receive time measurement
      sock_fd_.setsockopt(SOL_SOCKET, SO_TIMESTAMP, 1).value_or_throw();

      // Enable reporting on packets dropped due to full UDP receive buffer
      sock_fd_.setsockopt(SOL_SOCKET, SO_RXQ_OVFL, 1).value_or_throw();
    }

    /// Only deliver datagrams whose source IP equals `sender_ip`.
    /// NOTE: only the IP is compared on receive; `sender_port` is stored but not checked.
    /// @throws UsageError if `sender_ip` is not a valid IPv4 address.
    Builder && limit_to_sender(const std::string & sender_ip, uint16_t sender_port)
    {
      config_.sender.emplace(Endpoint{parse_ip_or_throw(sender_ip), sender_port});
      return std::move(*this);
    }

    /// Set the receive buffer size in bytes. Longer datagrams are truncated to this size and
    /// flagged via `RxMetadata::truncated`.
    Builder && set_mtu(size_t bytes)
    {
      config_.buffer_size = bytes;
      return std::move(*this);
    }

    /// Request a kernel receive buffer (`SO_RCVBUF`) of `bytes`.
    /// @throws UsageError if `bytes` exceeds INT32_MAX.
    /// @throws SocketError if the option cannot be set.
    Builder && set_socket_buffer_size(size_t bytes)
    {
      if (bytes > static_cast<size_t>(INT32_MAX))
        throw UsageError("The maximum value supported (0x7FFFFFFF) has been exceeded");

      auto buf_size = static_cast<int>(bytes);
      sock_fd_.setsockopt(SOL_SOCKET, SO_RCVBUF, buf_size).value_or_throw();
      return std::move(*this);
    }

    /// Join the multicast group `group_ip` on the interface identified by the host IP. The socket
    /// is then bound to the group address in `bind()`.
    /// @throws UsageError if a group was already joined or `group_ip` is invalid.
    /// @throws SocketError if joining fails.
    Builder && join_multicast_group(const std::string & group_ip)
    {
      if (config_.multicast_ip)
        throw UsageError("Only one multicast group can be joined by this socket");
      ip_mreq mreq{parse_ip_or_throw(group_ip), config_.host.ip};

      sock_fd_.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq).value_or_throw();
      config_.multicast_ip.emplace(mreq.imr_multiaddr);
      return std::move(*this);
    }

    /// Set the `poll()` timeout (milliseconds) used by the receive loop.
    Builder && set_polling_interval(int32_t interval_ms)
    {
      config_.polling_interval_ms = interval_ms;
      return std::move(*this);
    }

    /// Bind to the host port, using the joined multicast group address if any, otherwise the
    /// host IP. Consumes the builder.
    /// @throws SocketError if `bind` fails.
    UdpSocket bind() &&
    {
      sockaddr_in addr{};
      addr.sin_family = AF_INET;
      addr.sin_port = htons(config_.host.port);
      addr.sin_addr = config_.multicast_ip ? *config_.multicast_ip : config_.host.ip;

      int result = ::bind(sock_fd_.get(), (struct sockaddr *)&addr, sizeof(addr));
      if (result == -1) throw SocketError(errno);

      return UdpSocket{std::move(sock_fd_), config_};
    }

  private:
    SockFd sock_fd_;
    SocketConfig config_;
  };

  /// Per-datagram information passed to the callback alongside the payload.
  struct RxMetadata
  {
    /// Receive time from the `SO_TIMESTAMP` control message (`timeval` converted to ns, so
    /// microsecond resolution); empty if no such control message was received.
    std::optional<uint64_t> timestamp_ns;
    /// Datagrams dropped (per `SO_RXQ_OVFL`) since the previous accepted receive.
    uint64_t drops_since_last_receive{0};
    /// True if the datagram was longer than the configured buffer size and was cut off.
    bool truncated{false};
  };

  /// Called on the receive thread. The payload vector is reused between datagrams, so the
  /// reference is only valid for the duration of the call.
  using callback_t = std::function<void(const std::vector<uint8_t> &, const RxMetadata &)>;

  /// Stop any running receive thread, install `callback` and start receiving.
  UdpSocket & subscribe(callback_t && callback)
  {
    unsubscribe();
    callback_ = std::move(callback);
    launch_receiver();
    return *this;
  }

  bool is_subscribed() const { return running_; }

  /// Signal the receive thread to stop and wait for it to exit. Safe to call when not subscribed.
  UdpSocket & unsubscribe()
  {
    running_ = false;
    if (receive_thread_.joinable()) {
      receive_thread_.join();
    }
    return *this;
  }

  UdpSocket(const UdpSocket &) = delete;
  // Stops `other`'s receive thread before taking its descriptor, then resubscribes with its callback.
  UdpSocket(UdpSocket && other)
  : sock_fd_((other.unsubscribe(), std::move(other.sock_fd_))),
    poll_fd_(std::move(other.poll_fd_)),
    config_(std::move(other.config_)),
    drop_monitor_(std::move(other.drop_monitor_))
  {
    if (other.callback_) subscribe(std::move(other.callback_));
  };

  UdpSocket & operator=(const UdpSocket &) = delete;
  UdpSocket & operator=(UdpSocket &&) = delete;

  ~UdpSocket() { unsubscribe(); }

private:
  /// Start the receive thread: poll, receive, filter by sender, collect metadata, invoke callback.
  /// NOTE: exceptions thrown on this thread (a `SocketError`, or one escaping the callback) are
  /// not caught, so they terminate the process via `std::terminate`.
  void launch_receiver()
  {
    assert(callback_);

    running_ = true;
    receive_thread_ = std::thread([this]() {
      std::vector<uint8_t> buffer;
      while (running_) {
        auto data_available = is_data_available();
        if (!data_available.has_value()) throw SocketError(data_available.error());
        if (!data_available.value()) continue;

        buffer.resize(config_.buffer_size);
        // Fresh per datagram: recvmsg overwrites msg_namelen / msg_controllen on each call.
        MsgBuffers msg_buffers{buffer};

        // As per `man recvmsg`, zero-length datagrams are permitted and valid. Since the socket is
        // blocking, a recv_result of 0 means we received a valid 0-length datagram.
        // MSG_TRUNC makes recvmsg return the full datagram length even if it did not fit.
        ssize_t recv_result = recvmsg(sock_fd_.get(), &msg_buffers.msg, MSG_TRUNC);
        if (recv_result < 0) {
          // Interrupted by a signal before data was received: go back to polling.
          if (errno == EINTR) continue;
          throw SocketError(errno);
        }
        auto untruncated_packet_length = static_cast<size_t>(recv_result);

        if (!is_accepted_sender(msg_buffers.sender_addr)) continue;

        RxMetadata metadata;
        get_receive_metadata(msg_buffers.msg, metadata);
        metadata.truncated = untruncated_packet_length > config_.buffer_size;

        buffer.resize(std::min(config_.buffer_size, untruncated_packet_length));
        callback_(buffer, metadata);
      }
    });
  }

  /// Extract the `SO_TIMESTAMP` and `SO_RXQ_OVFL` control messages from `msg` into
  /// `inout_metadata`; other control messages are ignored.
  void get_receive_metadata(msghdr & msg, RxMetadata & inout_metadata)
  {
    for (cmsghdr * cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
      if (cmsg->cmsg_level != SOL_SOCKET) continue;

      switch (cmsg->cmsg_type) {
        case SO_TIMESTAMP: {
          // CMSG_DATA may not be suitably aligned for the payload type; copy it out.
          timeval tv{};
          std::memcpy(&tv, CMSG_DATA(cmsg), sizeof(tv));
          // Widen before multiplying so a 32-bit `time_t` cannot overflow.
          uint64_t timestamp_ns = static_cast<uint64_t>(tv.tv_sec) * 1'000'000'000ULL +
                                  static_cast<uint64_t>(tv.tv_usec) * 1'000ULL;
          inout_metadata.timestamp_ns.emplace(timestamp_ns);
          break;
        }
        case SO_RXQ_OVFL: {
          uint32_t drop_counter = 0;
          std::memcpy(&drop_counter, CMSG_DATA(cmsg), sizeof(drop_counter));
          inout_metadata.drops_since_last_receive =
            drop_monitor_.get_drops_since_last_receive(drop_counter);
          break;
        }
        default:
          continue;
      }
    }
  }

  /// Wait up to the polling interval for readable data.
  /// @return true if data can be read, false on timeout or signal interruption, or the `errno`
  /// of a failed `poll`.
  util::expected<bool, int> is_data_available()
  {
    int status = poll(&poll_fd_, 1, config_.polling_interval_ms);
    if (status == -1) {
      int error = errno;
      // A signal interrupted the wait: report "no data" so the loop simply polls again.
      if (error == EINTR) return false;
      return error;
    }
    return (poll_fd_.revents & POLLIN) && (status > 0);
  }

  /// True if no sender filter is configured or the source IP matches it (the port is not checked).
  bool is_accepted_sender(const sockaddr_in & sender_addr) const
  {
    if (!config_.sender) return true;
    return sender_addr.sin_addr.s_addr == config_.sender->ip.s_addr;
  }

  /// Parse an IPv4 address in any form accepted by `inet_aton`.
  /// @throws UsageError if `ip` cannot be parsed.
  static in_addr parse_ip_or_throw(const std::string & ip)
  {
    in_addr parsed_addr;
    bool valid = inet_aton(ip.c_str(), &parsed_addr);
    if (!valid) throw UsageError("Invalid IP address given");
    return parsed_addr;
  }

  SockFd sock_fd_;
  pollfd poll_fd_;

  SocketConfig config_;

  std::atomic_bool running_{false};
  std::thread receive_thread_;
  callback_t callback_;

  DropMonitor drop_monitor_;
};

}  // namespace nebula::drivers::connections