Skip to content

File tcp.hpp

File List > connections > tcp.hpp

Go to the documentation of this file

// Copyright 2026 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifndef _GNU_SOURCE
// See `man strerror_r`
#define _GNU_SOURCE
#endif

#include <nebula_core_common/util/errno.hpp>
#include <nebula_core_common/util/expected.hpp>
#include <nebula_core_hw_interfaces/nebula_hw_interfaces_common/connections/socket_utils.hpp>

#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <atomic>
#include <cassert>
#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>

namespace nebula::drivers::connections
{

class TcpSocket
{
  struct SocketConfig
  {
    int32_t polling_interval_ms{10};
    int32_t connect_timeout_ms{3000};
    size_t buffer_size{4096};
    Endpoint target{};
  };

  TcpSocket(SockFd sock_fd, SocketConfig config) : sock_fd_(std::move(sock_fd)), config_{config} {}

public:
  class Builder
  {
  public:
    Builder(const std::string & target_ip, uint16_t target_port)
    {
      in_addr target_in_addr = parse_ip(target_ip).value_or_throw();
      config_.target = {target_in_addr, target_port};

      init_socket();
    }

    explicit Builder(const Endpoint & target)
    {
      config_.target = target;
      init_socket();
    }

    Builder && set_socket_buffer_size(size_t bytes)
    {
      if (bytes > static_cast<size_t>(INT32_MAX))
        throw UsageError("The maximum value supported (0x7FFFFFF) has been exceeded");

      auto buf_size = static_cast<int>(bytes);
      sock_fd_.setsockopt(SOL_SOCKET, SO_RCVBUF, buf_size).value_or_throw();
      return std::move(*this);
    }

    Builder && set_polling_interval(int32_t interval_ms)
    {
      config_.polling_interval_ms = interval_ms;
      return std::move(*this);
    }

    Builder && set_connect_timeout(int32_t timeout_ms)
    {
      config_.connect_timeout_ms = timeout_ms;
      return std::move(*this);
    }

    Builder && set_buffer_size(size_t bytes)
    {
      config_.buffer_size = bytes;
      return std::move(*this);
    }

    TcpSocket connect() &&
    {
      sockaddr_in addr = config_.target.to_sockaddr();

      // Set socket to non-blocking
      int flags = fcntl(sock_fd_.get(), F_GETFL, 0);
      if (flags == -1) throw SocketError(errno);
      if (fcntl(sock_fd_.get(), F_SETFL, flags | O_NONBLOCK) == -1) throw SocketError(errno);

      int result{-1};
      do {
        result = ::connect(sock_fd_.get(), (sockaddr *)&addr, sizeof(addr));
      } while (result == -1 && errno == EINTR);

      if (result == -1) {
        if (errno != EINPROGRESS) {
          throw SocketError(errno);
        }

        // Connection is in progress, poll for completion
        pollfd pfd{sock_fd_.get(), POLLOUT, 0};
        int poll_result = poll(&pfd, 1, config_.connect_timeout_ms);

        if (poll_result == -1) {
          throw SocketError(errno);
        }
        if (poll_result == 0) {
          throw SocketError("Connection timeout");
        }

        // Check socket error status
        int so_error = 0;
        socklen_t len = sizeof(so_error);
        if (getsockopt(sock_fd_.get(), SOL_SOCKET, SO_ERROR, &so_error, &len) == -1) {
          throw SocketError(errno);
        }
        if (so_error != 0) {
          throw SocketError(so_error);
        }
      }

      // Restore blocking mode
      if (fcntl(sock_fd_.get(), F_SETFL, flags) == -1) throw SocketError(errno);

      return TcpSocket{std::move(sock_fd_), config_};
    }

  private:
    void init_socket()
    {
      int sock_fd = socket(AF_INET, SOCK_STREAM, 0);
      if (sock_fd == -1) throw SocketError(errno);
      sock_fd_ = SockFd{sock_fd};

      // Enable TCP_NODELAY to reduce latency
      int one = 1;
      sock_fd_.setsockopt(IPPROTO_TCP, TCP_NODELAY, one).value_or_throw();
    }

    SockFd sock_fd_;
    SocketConfig config_;
  };

  void send(const std::vector<uint8_t> & data)
  {
    size_t total_sent = 0;
    while (total_sent < data.size()) {
      ssize_t result{-1};
      do {
        result =
          ::send(sock_fd_.get(), data.data() + total_sent, data.size() - total_sent, MSG_NOSIGNAL);
      } while (result == -1 && errno == EINTR);

      if (result == -1) throw SocketError(errno);
      total_sent += static_cast<size_t>(result);
    }
  }

  std::vector<uint8_t> receive(
    size_t n, std::chrono::milliseconds timeout = std::chrono::milliseconds(0))
  {
    return receive_impl(n, timeout);
  }

  std::vector<uint8_t> receive(std::chrono::milliseconds timeout = std::chrono::milliseconds(0))
  {
    return receive_impl(config_.buffer_size, timeout);
  }

  TcpSocket(const TcpSocket &) = delete;
  TcpSocket(TcpSocket && other) noexcept
  : sock_fd_(std::move(other.sock_fd_)), config_(other.config_)
  {
  }

  TcpSocket & operator=(const TcpSocket &) = delete;
  TcpSocket & operator=(TcpSocket &&) = delete;

private:
  std::vector<uint8_t> receive_impl(size_t n, std::chrono::milliseconds timeout)
  {
    if (n == 0) throw UsageError("Receive size must be greater than zero");

    if (timeout.count() > 0) {
      auto ready = is_socket_ready(sock_fd_.get(), static_cast<int>(timeout.count()));
      if (!ready.has_value()) throw SocketError(ready.error());
      if (!ready.value()) return {};
    }

    std::vector<uint8_t> buffer(n);
    ssize_t result{-1};
    do {
      result = ::recv(sock_fd_.get(), buffer.data(), n, 0);
    } while (result == -1 && errno == EINTR);

    if (result < 0) throw SocketError(errno);
    if (result == 0) throw SocketError("Connection closed");
    buffer.resize(result);
    return buffer;
  }

  SockFd sock_fd_;
  SocketConfig config_;
};

}  // namespace nebula::drivers::connections