
Framework Design

Autoware-ML is built on a modular architecture that separates concerns into distinct components: configuration, data handling, model definition, training, and deployment, among others. This design makes it easy to add new models and datasets while reusing common infrastructure.

Architecture Overview

flowchart TB
    subgraph Optimization [Hyperparameter Tuning]
        Optuna[Optuna]
    end

    subgraph Configuration [Configuration Layer]
        YAML[YAML Configs]
        Optuna --> Hydra[Hydra]
        YAML --> Hydra
    end

    subgraph TrainingPipeline [Training Pipeline]
        InfoFiles[Info Files]
        InfoFiles --> LightningDataModule[Lightning Data Module]
        LightningDataModule --> Transforms[Transforms]
        Transforms --> Collation[Collation]
        Collation --> BatchTransfer[Batch Transfer]
        BatchTransfer --> Preprocessing[Model Preprocessing]
        Preprocessing --> ForwardPass[Forward Pass]
        ForwardPass --> LossComputation[Loss Computation]
        LossComputation --> BackwardPass[Backward Pass]
    end

    subgraph ModelLayer [Model Definition]
        LightningModule[Lightning Module]
        LightningModule --> Blocks["Blocks"]
        LightningModule --> Optimizers["Optimizers"]
        LightningModule --> Schedulers["Schedulers"]
    end

    subgraph TrainingLoop [Training Orchestration]
        Trainer[Lightning Trainer]
        Trainer --> CustomCallbacks[Custom Callbacks]
        Trainer --> MLflow[MLflow Logger]
        Trainer --> Checkpoints[Checkpoints]
    end

    subgraph Deployment [Deployment Pipeline]
        ModelWeights[Model Weights]
        ModelWeights --> ONNXExport[ONNX Export]
        ONNXExport --> TensorRTEngine[TensorRT Engine]
    end

    Hydra --> LightningDataModule
    Hydra --> LightningModule
    Hydra --> Trainer
    Hydra --> ModelWeights

    style InfoFiles fill:#bbdefb,opacity:0.2,stroke:#1976d2
    style LightningDataModule fill:#bbdefb,opacity:0.2,stroke:#1976d2
    style Transforms fill:#bbdefb,opacity:0.2,stroke:#1976d2
    style Collation fill:#bbdefb,opacity:0.2,stroke:#1976d2
    style ModelWeights fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style ONNXExport fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style TensorRTEngine fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style Blocks fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style BatchTransfer fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style Preprocessing fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style ForwardPass fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style LossComputation fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style BackwardPass fill:#a5d6a7,opacity:0.2,stroke:#05bc23

Legend: blue nodes (fill #bbdefb) are CPU operations; green nodes (fill #a5d6a7) are GPU operations.

Core Components

Configuration (Hydra)

Everything in Autoware-ML is configured through YAML files processed by Hydra. This enables:

  • Hierarchical configs - Inherit from base configs, override specific values
  • Runtime overrides - Change any parameter from the command line
  • Automatic instantiation - _target_ keys specify Python classes to instantiate via hydra.utils.instantiate()

See Configuration Guide for full details on Hydra syntax.
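To illustrate what `_target_` resolution does, here is a simplified, stdlib-only sketch of the idea behind `hydra.utils.instantiate()`. It is not Hydra's implementation: the helper names `toy_instantiate` and `_locate` are hypothetical, and the sketch only handles nested dicts and lists, `_target_`, and `_partial_`. `_partial_` is a real Hydra feature that returns a `functools.partial` instead of calling the target. It is one plausible way a config can supply factories such as the `optimizer` and `scheduler` callables that `BaseModel` accepts.

```python
import functools
import importlib
from typing import Any


def _locate(path: str) -> Any:
    """Resolve a dotted path such as 'datetime.timedelta' to a Python object."""
    module_path, _, attr = path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr)


def toy_instantiate(node: Any) -> Any:
    """Recursively build objects from dicts that carry a '_target_' key.

    Simplified illustration only; hydra.utils.instantiate() supports many
    more features (positional args, conversion modes, interpolation, ...).
    """
    if isinstance(node, list):
        return [toy_instantiate(item) for item in node]
    if not isinstance(node, dict):
        return node
    # Build children first so nested configs are already objects.
    kwargs = {k: toy_instantiate(v) for k, v in node.items() if not k.startswith("_")}
    if "_target_" not in node:
        return kwargs
    target = _locate(node["_target_"])
    if node.get("_partial_", False):
        # Defer the call: the caller supplies the remaining arguments later
        # (for example, model parameters for an optimizer factory).
        return functools.partial(target, **kwargs)
    return target(**kwargs)


print(toy_instantiate({"_target_": "datetime.timedelta", "hours": 1, "minutes": 30}))  # 1:30:00

make_quarter = toy_instantiate(
    {"_target_": "fractions.Fraction", "_partial_": True, "denominator": 4}
)
print(make_quarter(3))  # 3/4
```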

Data Module

The DataModule class (extending LightningDataModule) manages:

  • Dataset creation for each split (train/val/test/predict)
  • DataLoader configuration (batch size, workers, shuffling, pin_memory, etc.)
  • Transforms (CPU-side augmentations per split)
  • Collation (batching samples together via per-key collation_map strategies)

class DataModule(L.LightningDataModule, ABC):
    def __init__(
        self,
        collation_map: Mapping[str, CollationStrategy] | None = None,
        train_transforms: TransformsCompose | None = None,
        val_transforms: TransformsCompose | None = None,
        test_transforms: TransformsCompose | None = None,
        predict_transforms: TransformsCompose | None = None,
        train_dataloader_cfg: DataLoaderConfig | None = None,
        val_dataloader_cfg: DataLoaderConfig | None = None,
        test_dataloader_cfg: DataLoaderConfig | None = None,
        predict_dataloader_cfg: DataLoaderConfig | None = None,
    ):
        ...

    @abstractmethod
    def _create_dataset(
        self, split: str, transforms: TransformsCompose | None = None
    ) -> Dataset:
        ...

    def collate_fn(self, batch_inputs_dicts: Sequence[dict[str, Any]]) -> dict[str, Any]:
        ...
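Per-key collation can be pictured as a mapping from batch key to a strategy function. The sketch below is hypothetical: it assumes a strategy is simply a function from the list of per-sample values to one batched value, and `stack_strategy`, `list_strategy`, the `default` fallback, and the use of nested Python lists in place of tensors are illustrative assumptions. The actual `CollationStrategy` interface is not shown in this section.

```python
from typing import Any, Callable, Mapping, Sequence

# Assumed strategy shape: per-sample values -> batched value.
CollationFn = Callable[[Sequence[Any]], Any]


def stack_strategy(values: Sequence[Any]) -> list[Any]:
    """Stand-in for tensor stacking: every sample must have the same length."""
    lengths = {len(v) for v in values}
    if len(lengths) != 1:
        raise ValueError(f"cannot stack values with lengths {sorted(lengths)}")
    return [list(v) for v in values]


def list_strategy(values: Sequence[Any]) -> list[Any]:
    """Keep variable-sized items (point clouds, metadata) as a plain list."""
    return list(values)


def collate(
    samples: Sequence[dict[str, Any]],
    collation_map: Mapping[str, CollationFn],
    default: CollationFn = list_strategy,
) -> dict[str, Any]:
    """Batch samples key by key, using the strategy registered for each key."""
    keys = samples[0].keys()
    return {
        key: collation_map.get(key, default)([sample[key] for sample in samples])
        for key in keys
    }


samples = [
    {"image": [0.1, 0.2], "points": [[1, 2, 3]], "token": "a"},
    {"image": [0.3, 0.4], "points": [[4, 5, 6], [7, 8, 9]], "token": "b"},
]
batch = collate(samples, {"image": stack_strategy})
print(batch["image"])   # [[0.1, 0.2], [0.3, 0.4]]
print(batch["points"])  # [[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]]
```

Choosing a strategy per key lets fixed-size inputs be stacked while variable-length inputs, such as point clouds with different point counts, stay as lists.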

The Dataset base class handles transforms application:

class Dataset(TorchDataset, ABC):
    def __getitem__(self, index: int) -> dict[str, Any]:
        input_dict = self.get_data_info(index)
        context = PipelineContext(dataset=self, index=index)
        return self.apply_transforms(input_dict, self.dataset_transforms, context)

    @abstractmethod
    def get_data_info(self, index: int) -> dict[str, Any]:
        ...

Datasets are expected to return metadata records. File loading and sample materialization should happen in transforms.
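A minimal sketch of this split, assuming a toy record format: `get_data_info()` returns only a lightweight record (a sample id and a file path), and a hypothetical `LoadJsonFromFile` transform materializes the sample. The class names, the JSON payload, and the simplified `__getitem__` (no `PipelineContext`, transforms as a plain list) are illustrative assumptions, not Autoware-ML classes.

```python
import json
import os
import tempfile
from typing import Any


class LoadJsonFromFile:
    """Hypothetical loading transform: reads the file named in the record."""

    def __call__(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        with open(input_dict["path"], encoding="utf-8") as f:
            return {"values": json.load(f)}  # an update merged back into the sample


class RecordDataset:
    """Holds only metadata records; no file I/O happens in get_data_info()."""

    def __init__(self, records: list[dict[str, Any]], transforms: list[Any]):
        self.records = records
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.records)

    def get_data_info(self, index: int) -> dict[str, Any]:
        # Return a copy so transforms never mutate the stored metadata.
        return dict(self.records[index])

    def __getitem__(self, index: int) -> dict[str, Any]:
        input_dict = self.get_data_info(index)
        for transform in self.transforms:
            input_dict |= transform(input_dict)
        return input_dict


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample_0.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump([1.0, 2.0, 3.0], f)

    dataset = RecordDataset([{"sample_id": "0", "path": path}], [LoadJsonFromFile()])
    sample = dataset[0]
    print(sample["values"])  # [1.0, 2.0, 3.0]
```

Keeping the records lightweight means the index can be built and filtered cheaply, while the expensive loading runs lazily in DataLoader workers.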

Transforms

Transforms are composable per-sample operations, such as file loading and data augmentation, applied on the CPU. They follow a dict-in/dict-out pattern: each transform receives the sample dictionary and returns a dictionary of updates that is merged back into it.

class BaseTransform(ABC):
    def __call__(
        self,
        input_dict: dict[str, Any],
        context: PipelineContext | None = None,
    ) -> dict[str, Any]:
        self._context = context           # accessible via self.context property
        self._validate_required_keys(input_dict)
        self._handle_optional_keys(input_dict)
        if not self._should_apply():
            return self.on_skip(input_dict)
        return self.transform(input_dict)

    @abstractmethod
    def transform(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        ...

class TransformsCompose:
    def __init__(self, pipeline: Sequence[BaseTransform] | None = None):
        self.pipeline = pipeline or []

    def __call__(
        self,
        input_dict: dict[str, Any],
        context: PipelineContext | None = None,
    ) -> dict[str, Any]:
        for transform in self.pipeline:
            input_dict |= transform(input_dict, context=context)
        return input_dict

Transforms are configured per split (train/val/test/predict) in the DataModule and applied during Dataset.__getitem__().
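The excerpt above elides how `_validate_required_keys()`, `_should_apply()`, and `on_skip()` behave. The sketch below fills them in with assumed semantics, purely for illustration: a missing required key raises `KeyError`, application is decided by a probability `p`, and a skipped transform returns an empty update. `MiniTransform`, `FlipX`, and `compose` are hypothetical names, and the `context` argument is omitted for brevity.

```python
from __future__ import annotations

import random
from abc import ABC, abstractmethod
from typing import Any, Sequence


class MiniTransform(ABC):
    """Simplified stand-in for BaseTransform (assumed semantics, see above)."""

    required_keys: tuple[str, ...] = ()

    def __init__(self, p: float = 1.0, seed: int | None = None):
        self.p = p
        self._rng = random.Random(seed)

    def __call__(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        missing = [k for k in self.required_keys if k not in input_dict]
        if missing:
            raise KeyError(f"{type(self).__name__} missing keys: {missing}")
        if self._rng.random() >= self.p:
            return {}  # skipped: no updates to merge
        return self.transform(input_dict)

    @abstractmethod
    def transform(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        ...


class FlipX(MiniTransform):
    """Hypothetical augmentation: mirror points across the y-z plane."""

    required_keys = ("points",)

    def transform(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        flipped = [(-x, y, z) for x, y, z in input_dict["points"]]
        return {"points": flipped, "flipped_x": True}


def compose(pipeline: Sequence[MiniTransform], input_dict: dict[str, Any]) -> dict[str, Any]:
    for t in pipeline:
        input_dict |= t(input_dict)  # merge each transform's updates
    return input_dict


sample = {"points": [(1.0, 2.0, 0.5)], "label": 3}
out = compose([FlipX(p=1.0)], dict(sample))
print(out)  # {'points': [(-1.0, 2.0, 0.5)], 'label': 3, 'flipped_x': True}
```

Because each transform returns only the keys it changes, untouched keys such as `label` pass through the pipeline unchanged.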

Public transform targets should reference the concrete implementation module, for example autoware_ml.transforms.point_cloud.loading.LoadPointsFromFile or autoware_ml.transforms.point_cloud.scene.RandomFlip3D. Avoid package-level re-export layers in __init__.py; imports and Hydra _target_ paths should point at the implementation module directly.

Runtime Data Preprocessing

Runtime preprocessing is a model-owned pipeline attached through BaseModel.set_data_preprocessing(...). It runs on the target device after Lightning has transferred the batch there, and before the model's forward() is called.

class DataPreprocessing:
    def __init__(self, pipeline: Sequence[Any] = ()):
        self.pipeline = list(pipeline)

    def __call__(self, batch_inputs_dict: dict[str, Any]) -> dict[str, Any]:
        for layer in self.pipeline:
            batch_inputs_dict |= layer(batch_inputs_dict)
        return batch_inputs_dict

BaseModel.on_after_batch_transfer() applies the pipeline. Output-side shaping (e.g., logits -> probabilities, voxel-to-point scatter) lives inside the model, not in a framework pipeline: each model handles it in its own forward(), compute_metrics(), and predict_outputs(). Keeping this logic in the model class avoids invisible load-bearing dependencies between config composition and metric correctness.
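A sketch of where the pipeline fits, assuming a hypothetical `NormalizeIntensity` layer and a `ToyModel` that is not a real `LightningModule`. The `DataPreprocessing` class is copied from the excerpt above. The hook name and its `(batch, dataloader_idx)` arguments match Lightning's `on_after_batch_transfer`, but the body is only an assumed approximation of what `BaseModel` does. Plain Python lists stand in for device tensors.

```python
from __future__ import annotations

from typing import Any, Sequence


class DataPreprocessing:
    """Copied from the excerpt above: run layers in order and merge their updates."""

    def __init__(self, pipeline: Sequence[Any] = ()):
        self.pipeline = list(pipeline)

    def __call__(self, batch_inputs_dict: dict[str, Any]) -> dict[str, Any]:
        for layer in self.pipeline:
            batch_inputs_dict |= layer(batch_inputs_dict)
        return batch_inputs_dict


class NormalizeIntensity:
    """Hypothetical layer: scale the 'intensity' values into [0, 1]."""

    def __init__(self, max_value: float = 255.0):
        self.max_value = max_value

    def __call__(self, batch: dict[str, Any]) -> dict[str, Any]:
        # Return only the updated key; DataPreprocessing merges it back.
        return {"intensity": [v / self.max_value for v in batch["intensity"]]}


class ToyModel:
    """Not a LightningModule; only mimics where the pipeline is applied."""

    def __init__(self) -> None:
        self.data_preprocessing: DataPreprocessing | None = None

    def set_data_preprocessing(self, data_preprocessing: DataPreprocessing) -> None:
        self.data_preprocessing = data_preprocessing

    def on_after_batch_transfer(self, batch: dict[str, Any], dataloader_idx: int) -> dict[str, Any]:
        # In Lightning, this hook runs once the batch is already on the target device.
        if self.data_preprocessing is not None:
            batch = self.data_preprocessing(batch)
        return batch


model = ToyModel()
model.set_data_preprocessing(DataPreprocessing([NormalizeIntensity(max_value=255.0)]))
batch = {"intensity": [0.0, 127.5, 255.0], "sample_id": ["a", "b", "c"]}
out = model.on_after_batch_transfer(batch, dataloader_idx=0)
print(out["intensity"])  # [0.0, 0.5, 1.0]
```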

Model

All supported models inherit from BaseModel (extending LightningModule), which provides a standard interface and a set of override hooks for task-specific behavior:

class BaseModel(L.LightningModule, ABC):
    def __init__(
        self,
        optimizer: Callable[..., Optimizer] | None = None,
        scheduler: Callable[[Optimizer], LRScheduler] | None = None,
    ):
        super().__init__()
        self.forward_signature = inspect.signature(self.forward)
        ...

    @abstractmethod
    def forward(self, **kwargs: Any) -> torch.Tensor | Sequence[torch.Tensor]:
        ...

    @abstractmethod
    def compute_metrics(
        self, batch_inputs_dict: Mapping[str, Any], outputs: Any
    ) -> dict[str, torch.Tensor]:
        ...

    def set_data_preprocessing(self, data_preprocessing: DataPreprocessing) -> None:
        ...

    def predict_outputs(self, batch_inputs_dict: Mapping[str, Any], outputs: Any) -> Any:
        ...

    def get_log_batch_size(self, batch_inputs_dict: Mapping[str, Any]) -> int | None:
        ...

    def build_export_spec(self, batch_inputs_dict: Mapping[str, Any]) -> ExportSpec:
        ...

    def configure_optimizers(self) -> Optimizer | dict[str, Any]:
        ...

The base class handles:

  • Unified step logic - All models share the same training, validation, test, and predict execution path
  • Automatic signature inspection - Only passes relevant kwargs to forward() based on the method signature captured at initialization
  • Runtime data preprocessing - Applies the model-owned preprocessing pipeline after batch transfer
  • Metric logging - Logs metrics to Lightning's logger with proper prefixes
  • Predict step - Runs forward and formats predictions via predict_outputs()
  • Export contract - Supports a generic forward-signature-based export path and model-owned explicit export wrappers

Models can have any internal architecture. The default path filters batch inputs to match the forward() signature using inspect.signature(), while specialized models can override hooks such as predict_outputs(), get_log_batch_size(), set_data_preprocessing(), or build_export_spec() without leaving the shared framework contract.
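The signature-based filtering can be illustrated with `inspect.signature()` alone. The sketch below is a hypothetical approximation, not `BaseModel`'s code. The `SignatureFilteringModel` class and `filter_forward_kwargs()` method are illustrative names. Passing the whole batch through when `forward()` accepts `**kwargs`, and raising `KeyError` for missing required inputs, are assumptions.

```python
import inspect
from typing import Any, Mapping


class SignatureFilteringModel:
    """Toy model showing how batch keys can be matched to forward() parameters."""

    def __init__(self) -> None:
        # Captured once at initialization; the bound method's signature excludes self.
        self.forward_signature = inspect.signature(self.forward)

    def forward(self, points: list[float], num_points: int) -> float:
        return sum(points) / num_points

    def filter_forward_kwargs(self, batch: Mapping[str, Any]) -> dict[str, Any]:
        params = self.forward_signature.parameters
        # Assumption: a forward(**kwargs) signature receives the whole batch.
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
            return dict(batch)
        required = [
            name
            for name, p in params.items()
            if p.default is inspect.Parameter.empty
            and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        ]
        missing = [name for name in required if name not in batch]
        if missing:
            raise KeyError(f"batch is missing forward() inputs: {missing}")
        # Drop everything forward() does not accept (labels, metadata, ...).
        return {k: v for k, v in batch.items() if k in params}


model = SignatureFilteringModel()
batch = {"points": [1.0, 2.0, 3.0], "num_points": 3, "gt_labels": [0, 1], "sample_id": "x"}
kwargs = model.filter_forward_kwargs(batch)
print(sorted(kwargs))           # ['num_points', 'points']
print(model.forward(**kwargs))  # 2.0
```

Capturing the signature once at initialization avoids repeating the inspection on every step, and it is why `forward()` argument names must match the batch keys.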

Note

When a model relies on the default signature-based path, forward() argument names must match keys in the batch dictionary after runtime preprocessing has run. Models with more specialized batching or export requirements should override the relevant hooks instead of bypassing BaseModel.

Deployment Pipeline

The deployment pipeline exports trained models to production-ready formats:

flowchart LR
    subgraph ONNXExport [ONNX Export]
        Checkpoint[Checkpoint] --> Load[Load Weights]
        Load --> Model[Model Eval Mode]
        Model --> Trace[Trace with Sample]
        Trace --> ONNX[ONNX File]
    end

    subgraph TensorRTBuild [TensorRT Build]
        ONNX --> Parse[Parse ONNX]
        Parse --> Optimize[Build Engine]
        Optimize --> EngineFile[Engine File]
    end

    style Checkpoint fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style Load fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style Model fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style Trace fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style ONNX fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style Parse fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style Optimize fill:#a5d6a7,opacity:0.2,stroke:#05bc23
    style EngineFile fill:#a5d6a7,opacity:0.2,stroke:#05bc23

The deployment process:

  1. Load checkpoint - Instantiates model from config and loads weights from checkpoint
  2. Get input sample - Uses the predict dataloader to obtain a preprocessed sample for deployment
  3. Resolve export spec - Builds the effective export module and example inputs through the model's build_export_spec() contract
  4. Export to ONNX - Traces the resolved export module, supporting dynamic shapes for variable input sizes
  5. Build TensorRT engine - Optimizes the ONNX model for inference on NVIDIA GPUs using a configurable optimization profile
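Steps 3 to 5 can be sketched as translating an export spec into `torch.onnx.export` arguments and sanity-checking a TensorRT profile. Everything here is hypothetical: `ToyExportSpec` and its fields are assumed stand-ins for the real `ExportSpec`, and example inputs are assumed to be passed to the export module positionally in dictionary order. The only real APIs relied on are the `input_names`, `output_names`, `dynamic_axes`, and `opset_version` parameters of `torch.onnx.export`, and TensorRT's requirement that a profile satisfy min <= opt <= max in every dimension.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Sequence


@dataclass
class ToyExportSpec:
    """Hypothetical stand-in for ExportSpec; the real fields may differ."""

    module: Any                     # the export module, in eval mode
    example_inputs: dict[str, Any]  # input name -> example tensor, in forward() order
    output_names: list[str]
    dynamic_axes: dict[str, dict[int, str]] = field(default_factory=dict)


def onnx_export_args(spec: ToyExportSpec, path: str, opset: int = 17) -> tuple[tuple, dict]:
    """Translate the spec into positional and keyword args for torch.onnx.export."""
    input_names = list(spec.example_inputs)
    unknown = set(spec.dynamic_axes) - set(input_names) - set(spec.output_names)
    if unknown:
        raise ValueError(f"dynamic_axes refers to unknown tensors: {sorted(unknown)}")
    args = (spec.module, tuple(spec.example_inputs.values()), path)
    kwargs = {
        "input_names": input_names,
        "output_names": spec.output_names,
        "dynamic_axes": spec.dynamic_axes,  # marks variable-size dimensions
        "opset_version": opset,
    }
    return args, kwargs


def export_to_onnx(spec: ToyExportSpec, path: str) -> None:
    import torch  # imported lazily so the helpers above work without PyTorch

    args, kwargs = onnx_export_args(spec, path)
    torch.onnx.export(*args, **kwargs)


def check_profile(min_shape: Sequence[int], opt_shape: Sequence[int], max_shape: Sequence[int]) -> None:
    """A TensorRT optimization profile needs min <= opt <= max in every dimension."""
    if not len(min_shape) == len(opt_shape) == len(max_shape):
        raise ValueError("profile shapes must have the same rank")
    for lo, mid, hi in zip(min_shape, opt_shape, max_shape):
        if not lo <= mid <= hi:
            raise ValueError(f"invalid profile: {min_shape} / {opt_shape} / {max_shape}")


spec = ToyExportSpec(
    module=None,  # placeholder; a real export passes the resolved export module
    example_inputs={"points": None, "num_points": None},  # placeholders for tensors
    output_names=["scores"],
    dynamic_axes={"points": {0: "n"}, "scores": {0: "n"}},
)
args, kwargs = onnx_export_args(spec, "model.onnx")
print(kwargs["input_names"])  # ['points', 'num_points']
check_profile((1000, 4), (60000, 4), (200000, 4))
```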

Configuration is done through the deploy section in task configs.

Extending the Framework

| Extension point | How |
|---|---|
| New model | Subclass BaseModel, implement forward() and compute_metrics(), and override hooks as needed |
| New dataset | Subclass DataModule (implement _create_dataset()) and Dataset (implement get_data_info()) |
| New transform | Subclass BaseTransform and implement transform() |
| New task | Create a config in configs/tasks/ |

See Adding Models for a detailed guide.