File downsample_mask.hpp
File List > include > nebula_decoders > nebula_decoders_common > point_filters > downsample_mask.hpp
Go to the documentation of this file
// Copyright 2024 TIER IV, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "nebula_decoders/nebula_decoders_common/angles.hpp"
#include <nebula_common/loggers/logger.hpp>
#include <nebula_common/nebula_common.hpp>
#include <nebula_common/point_types.hpp>
#include <nebula_common/util/string_conversions.hpp>
#include <png++/error.hpp>
#include <png++/gray_pixel.hpp>
#include <png++/image.hpp>
#include <Eigen/src/Core/Matrix.h>
#include <sys/types.h>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
namespace nebula::drivers::point_filters
{
namespace impl
{
/// @brief Convert a grayscale keep-factor image into a binary (0 / 255) mask by ordered dithering.
///
/// Each input pixel value `v` in [0, 255] is quantized to
/// `numerator = v * quantization_levels / 255` (integer division). Output pixels are assigned a
/// position `(x + y) % quantization_levels`; for a given `numerator`, that many positions out of
/// `quantization_levels` (spread as evenly as rounding allows) are kept (255), the rest are
/// dropped (0). Roughly `numerator / quantization_levels` of the pixels in a region are kept.
///
/// @param in Keep-factor image; must have exactly the same size as @p out.
/// @param out Destination image, already sized to the desired mask dimensions.
/// @param quantization_levels Number of discrete keep ratios; must be non-zero.
/// @throws std::runtime_error if @p in and @p out differ in size.
/// @throws std::invalid_argument if @p quantization_levels is zero.
inline void dither(
  const png::image<png::gray_pixel> & in, png::image<png::gray_pixel> & out,
  uint8_t quantization_levels)
{
  if (in.get_width() != out.get_width() || in.get_height() != out.get_height()) {
    std::stringstream ss;
    ss << "Expected downsample mask of size "
       << "(" << out.get_width() << ", " << out.get_height() << ")";
    ss << ", got "
       << "(" << in.get_width() << ", " << in.get_height() << ")";
    throw std::runtime_error(ss.str());
  }

  if (quantization_levels == 0) {
    // The position modulus below would otherwise divide by zero (undefined behavior).
    throw std::invalid_argument("quantization_levels must be positive");
  }

  const uint32_t denominator = quantization_levels;

  // True if `pos` is one of the `numerator` evenly spaced positions in [0, denominator).
  // numerator == 0 keeps nothing; numerator == denominator keeps every position.
  auto should_keep = [denominator](uint32_t numerator, uint32_t pos) {
    for (uint32_t i = 0; i < numerator; ++i) {
      auto dithered_pos =
        static_cast<size_t>(std::round(denominator / static_cast<double>(numerator) * i));
      if (dithered_pos == pos) return true;
    }
    return false;
  };

  for (size_t y = 0; y < out.get_height(); ++y) {
    for (size_t x = 0; x < out.get_width(); ++x) {
      const auto & pixel = in.get_pixel(x, y);
      // Scale the 8-bit keep factor down to [0, denominator].
      const uint32_t numerator = static_cast<uint32_t>(pixel) * denominator / 255;
      // Diagonal ordering spreads kept pixels across both axes.
      const size_t pos = (x + y) % denominator;
      const bool keep = should_keep(numerator, static_cast<uint32_t>(pos));
      out.set_pixel(x, y, keep ? 255 : 0);
    }
  }
}
} // namespace impl
class DownsampleMaskFilter
{
static const uint8_t g_quantization_levels = 10;
public:
DownsampleMaskFilter(
const std::string & filename, AngleRange<int32_t, MilliDegrees> azimuth_range_mdeg,
uint32_t azimuth_peak_resolution_mdeg, size_t n_channels,
const std::shared_ptr<loggers::Logger> & logger, bool export_dithered_mask = false)
: azimuth_range_{
deg2rad(azimuth_range_mdeg.start / 1000.), deg2rad(azimuth_range_mdeg.end / 1000.)}
{
if (azimuth_peak_resolution_mdeg == 0) {
throw std::invalid_argument("azimuth_peak_resolution_mdeg must be positive");
}
if (azimuth_range_.extent() <= 0) {
throw std::invalid_argument("azimuth range extent must be positive");
}
if (n_channels == 0) {
throw std::invalid_argument("n_channels must be positive");
}
png::image<png::gray_pixel> factors(filename);
size_t mask_cols = azimuth_range_mdeg.extent() / azimuth_peak_resolution_mdeg;
size_t mask_rows = n_channels;
png::image<png::gray_pixel> dithered(mask_cols, mask_rows);
impl::dither(factors, dithered, g_quantization_levels);
mask_ = Eigen::MatrixX<uint8_t>(mask_rows, mask_cols);
for (size_t y = 0; y < dithered.get_height(); ++y) {
for (size_t x = 0; x < dithered.get_width(); ++x) {
mask_.coeffRef(static_cast<int32_t>(y), static_cast<int32_t>(x)) = dithered.get_pixel(x, y);
}
}
if (export_dithered_mask) {
std::filesystem::path out_path{filename};
out_path = out_path.replace_filename(
out_path.stem().string() + "_dithered" + out_path.extension().string());
try {
dithered.write(out_path);
logger->info("Wrote dithered mask to " + out_path.native());
} catch (const png::std_error & e) {
logger->warn("Could not write " + out_path.native() + ": " + e.what());
}
}
}
bool excluded(const NebulaPoint & point)
{
double azi_normalized = (point.azimuth - azimuth_range_.start) / azimuth_range_.extent();
auto x = static_cast<ssize_t>(std::round(azi_normalized * static_cast<double>(mask_.cols())));
auto y = point.channel;
bool x_out_of_bounds = x < 0 || x >= mask_.cols();
bool y_out_of_bounds = y >= mask_.rows();
return x_out_of_bounds || y_out_of_bounds || !mask_.coeff(y, x);
}
private:
AngleRange<double, Radians> azimuth_range_;
Eigen::MatrixX<uint8_t> mask_;
};
} // namespace nebula::drivers::point_filters