Skip to content

File hysteresis_state_machine.hpp

File List > common > diagnostics > hysteresis_state_machine.hpp

Go to the documentation of this file

// Copyright 2025 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Copied from https://github.com/tier4/ros2_v4l2_camera/pull/37

#ifndef HYSTERESIS_STATE_MACHINE_HPP_
#define HYSTERESIS_STATE_MACHINE_HPP_

#include <tracetools/utils.hpp>

#include <diagnostic_msgs/msg/diagnostic_status.hpp>

#include <string>
#include <variant>

namespace custom_diagnostic_tasks
{
using DiagnosticStatus_t = unsigned char;

// Helper struct to express state machine nodes
struct StateBase
{
  explicit StateBase(const DiagnosticStatus_t lv) : level(lv), num_observations(1) {}

  DiagnosticStatus_t level;
  size_t num_observations;
};

struct Stale : public StateBase
{
  Stale() : StateBase(diagnostic_msgs::msg::DiagnosticStatus::STALE) {}
};

struct Ok : public StateBase
{
  Ok() : StateBase(diagnostic_msgs::msg::DiagnosticStatus::OK) {}
};

struct Warn : public StateBase
{
  Warn() : StateBase(diagnostic_msgs::msg::DiagnosticStatus::WARN) {}
};

struct Error : public StateBase
{
  Error() : StateBase(diagnostic_msgs::msg::DiagnosticStatus::ERROR) {}
};

using StateHolder = std::variant<Stale, Ok, Warn, Error>;

static inline StateHolder generate_state(const DiagnosticStatus_t & state)
{
  switch (state) {
    case diagnostic_msgs::msg::DiagnosticStatus::STALE:
      return Stale{};
    case diagnostic_msgs::msg::DiagnosticStatus::OK:
      return Ok{};
    case diagnostic_msgs::msg::DiagnosticStatus::WARN:
      return Warn{};
    case diagnostic_msgs::msg::DiagnosticStatus::ERROR:
      return Error{};
    default:
      throw std::runtime_error("Undefined status");
  }
}

static std::string get_level_string(DiagnosticStatus_t level)
{
  switch (level) {
    case diagnostic_msgs::msg::DiagnosticStatus::OK:
      return "OK";
    case diagnostic_msgs::msg::DiagnosticStatus::WARN:
      return "WARN";
    case diagnostic_msgs::msg::DiagnosticStatus::ERROR:
      return "ERROR";
    case diagnostic_msgs::msg::DiagnosticStatus::STALE:
      return "STALE";
    default:
      return "UNDEFINED";
  }
}

static DiagnosticStatus_t get_level(const StateHolder & state)
{
  return std::visit([](const auto & s) { return s.level; }, state);
}

static size_t get_num_observations(const StateHolder & state)
{
  return std::visit([](const auto & s) { return s.num_observations; }, state);
}

class HysteresisStateMachine
{
public:
  explicit HysteresisStateMachine(
    const size_t num_frame_transition = 1, const bool immediate_error_report = false,
    const bool immediate_relax_state = true)
  : num_frame_transition_(num_frame_transition),
    immediate_error_report_(immediate_error_report),
    immediate_relax_state_(immediate_relax_state),
    current_state_(Stale{})
  {
    if (num_frame_transition < 1) {
      num_frame_transition_ = 1;
    }
  }

  void update_state(const DiagnosticStatus_t & observation)
  {
    // If the classify result is same as previous one and the observation is
    // different from the current one, increment the number of observation
    // Otherwise, update candidate
    auto candidate_level = get_level(candidate_state_);
    auto current_level = get_level(current_state_);
    if (candidate_level == observation && candidate_level != current_level) {
      std::visit([](auto & s) { s.num_observations += 1; }, candidate_state_);
    } else {
      candidate_state_ = generate_state(observation);
    }

    // Update the current state if
    // - immediate error report is required and the observed state is error
    // - Or the same state is observed multiple times
    // - Or the observed state has lower level than the current one (i.e., the state is improved)
    bool is_immediate_error =
      (immediate_error_report_ && std::holds_alternative<Error>(candidate_state_));
    bool observed_over_threshold =
      (get_num_observations(candidate_state_) >= num_frame_transition_);
    bool is_immediate_relax =
      (immediate_relax_state_ && get_level(candidate_state_) < current_level);

    DiagnosticStatus_t updated_level = current_level;
    if (is_immediate_error || observed_over_threshold || is_immediate_relax) {
      updated_level = get_level(candidate_state_);
    }

    current_state_ = generate_state(updated_level);
  }

  DiagnosticStatus_t get_candidate_level() { return get_level(candidate_state_); }

  size_t get_candidate_num_observation() { return get_num_observations(candidate_state_); }

  size_t get_num_frame_transition() { return num_frame_transition_; }

  DiagnosticStatus_t get_current_state_level() { return get_level(current_state_); }

  void set_current_state_level(const DiagnosticStatus_t & state)
  {
    current_state_ = generate_state(state);
  }

  bool get_immediate_error_report_param() { return immediate_error_report_; }

  bool get_immediate_relax_state_param() { return immediate_relax_state_; }

protected:
  size_t num_frame_transition_;
  bool immediate_error_report_;
  bool immediate_relax_state_;
  StateHolder candidate_state_;
  StateHolder current_state_;
};  // class HysteresisStateMachine

}  // namespace custom_diagnostic_tasks

#endif  // HYSTERESIS_STATE_MACHINE_HPP_