
Adding Models

This guide walks you through adding a new model to Autoware-ML. You'll implement a model class, create a DataModule, and wire everything together with a config.

The BaseModel Interface

New models should inherit from BaseModel. The minimal contract consists of two abstract methods:

from collections.abc import Mapping, Sequence
from typing import Any

import torch

from autoware_ml.models.base import BaseModel

class MyModel(BaseModel):
    def forward(self, **kwargs: Any) -> torch.Tensor | Sequence[torch.Tensor]:
        ...

    def compute_metrics(
        self, batch_inputs_dict: Mapping[str, Any], outputs: Any
    ) -> dict[str, torch.Tensor]:
        ...

The base class handles the training, validation, test, and predict steps, optimizer configuration, metric logging, prediction output conversion, runtime preprocessing, and deployment export integration. forward() can have any signature, provided that either its parameter names match keys in the batch dictionary (the default signature-based mapping) or the model overrides the relevant hooks.
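As a rough illustration of the default mapping, the sketch below uses Python's inspect module to select the batch entries whose keys match forward()'s parameter names. The helper name select_forward_kwargs and its error behavior for missing keys are assumptions made for illustration. BaseModel's actual implementation may differ.

```python
import inspect
from typing import Any, Callable


def select_forward_kwargs(
    forward: Callable[..., Any], batch: dict[str, Any]
) -> dict[str, Any]:
    """Hypothetical helper: pick the batch entries that forward() can accept."""
    params = inspect.signature(forward).parameters

    # A forward(**kwargs) signature accepts every key, so pass the batch through.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(batch)

    # Required named parameters must be present in the batch.
    missing = [
        name
        for name, p in params.items()
        if p.default is inspect.Parameter.empty
        and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        and name not in batch
    ]
    if missing:
        raise KeyError(f"batch is missing keys required by forward(): {missing}")

    # Keep only keys that match parameter names. Extra keys (e.g. gt_labels)
    # stay in the batch for compute_metrics().
    return {name: batch[name] for name in params if name in batch}


def forward(input_tensor, mask=None):
    return input_tensor


batch = {"input_tensor": [1.0, 2.0], "gt_labels": [0, 1]}
print(select_forward_kwargs(forward, batch))  # {'input_tensor': [1.0, 2.0]}
```

When forward() is a bound method, inspect.signature() already omits self, so the same logic applies to a model's forward().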

Extending BaseModel

Specialized models should still inherit from BaseModel. When the default signature-based path is not enough, override hooks such as set_data_preprocessing(), predict_outputs(), get_log_batch_size(), or build_export_spec() rather than introducing a standalone LightningModule. Output decoding (for example, voxel-to-point scatter for segmentation) belongs inside the model, typically in forward(), compute_metrics(), and predict_outputs(), not in a separate framework pipeline.
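The following sketch makes "output decoding belongs inside the model" concrete. It shows voxel-to-point mapping of the kind a segmentation model might perform in predict_outputs(). It assumes the model kept a point-to-voxel index from voxelization, called point2voxel here (a hypothetical name). The mapping itself is an index gather. NumPy stands in for the equivalent torch indexing to keep the example self-contained.

```python
import numpy as np


def voxel_to_point_logits(voxel_logits: np.ndarray, point2voxel: np.ndarray) -> np.ndarray:
    """Map per-voxel logits back to per-point logits.

    voxel_logits: (num_voxels, num_classes) scores predicted for each voxel.
    point2voxel:  (num_points,) index of the voxel each input point fell into.
    Returns (num_points, num_classes): every point inherits its voxel's logits.
    """
    return voxel_logits[point2voxel]


voxel_logits = np.array([[0.1, 0.9], [0.8, 0.2]])
point2voxel = np.array([0, 0, 1])  # points 0 and 1 share voxel 0
point_labels = voxel_to_point_logits(voxel_logits, point2voxel).argmax(axis=1)
print(point_labels)  # [1 1 0]
```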

Step 1: Implement the Model

Create a new file in autoware_ml/models/:

autoware_ml/models/my_task/my_model.py
from collections.abc import Mapping, Sequence
from typing import Any

import torch
import torch.nn as nn

from autoware_ml.models.base import BaseModel


class MyModel(BaseModel):
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        num_classes: int,
        **kwargs: Any,  # Forwarded to BaseModel (e.g. optimizer, scheduler)
    ):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.num_classes = num_classes
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        features = self.encoder(input_tensor)
        logits = self.decoder(features)
        return logits

    def compute_metrics(
        self,
        batch_inputs_dict: Mapping[str, Any],
        outputs: torch.Tensor | Sequence[torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        gt_labels = batch_inputs_dict["gt_labels"]
        logits = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
        loss = self.loss_fn(logits, gt_labels)

        preds = torch.argmax(logits, dim=1)
        accuracy = (preds == gt_labels).float().mean()

        return {
            "loss": loss,
            "accuracy": accuracy,
        }

Key Points

  1. forward() signature matters - Parameter names must match keys in your batch dictionary. The base class automatically extracts matching keys using signature inspection.

  2. compute_metrics() receives the full batch and outputs - The first argument is batch_inputs_dict (the full batch dictionary after preprocessing), and the second is outputs from forward(). Extract any needed targets (e.g. gt_labels) from batch_inputs_dict.

  3. Return 'loss' - The metrics dict must include a 'loss' key for backpropagation.

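  4. Optimizer and scheduler - These are passed to BaseModel.__init__() as callables rather than instances, so they must be marked with _partial_: true in YAML configs. With that flag, Hydra returns a partially applied constructor. The constructor can then be completed once the model parameters (for the optimizer) and the optimizer (for the scheduler) exist.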

  5. Use hooks when needed - If your model needs custom batch unpacking, prediction formatting, or an explicit deployment wrapper, override the appropriate BaseModel hook instead of bypassing the shared training and deployment flow.
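Key point 4 exists because an optimizer cannot be built before the model's parameters exist. The sketch below uses functools.partial directly to mimic what Hydra produces for a `_partial_: true` entry. ToyOptimizer and ToyScheduler are hypothetical stand-ins for the torch classes used in the Step 4 config. The point at which BaseModel calls the factories (presumably during optimizer configuration) is an assumption.

```python
from functools import partial


class ToyOptimizer:
    """Stand-in for torch.optim.AdamW: parameters first, hyperparameters after."""

    def __init__(self, params, lr: float, weight_decay: float = 0.0):
        self.params = list(params)
        self.lr = lr
        self.weight_decay = weight_decay


class ToyScheduler:
    """Stand-in for CosineAnnealingLR: needs the optimizer it will drive."""

    def __init__(self, optimizer: ToyOptimizer, T_max: int):
        self.optimizer = optimizer
        self.T_max = T_max


# `_partial_: true` makes Hydra's instantiate() return a functools.partial
# with the YAML arguments bound, instead of calling the target immediately.
optimizer_factory = partial(ToyOptimizer, lr=0.001, weight_decay=0.01)
scheduler_factory = partial(ToyScheduler, T_max=50)

# Later, once the model exists, the remaining argument is supplied.
model_parameters = [0.5, -1.2]  # placeholder for model.parameters()
optimizer = optimizer_factory(model_parameters)
scheduler = scheduler_factory(optimizer)
```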

Step 2: Create a DataModule

Create a DataModule that provides data for your model:

autoware_ml/datamodule/my_dataset/my_task.py
import os
import pickle
from typing import Any

from autoware_ml.datamodule.base import DataModule, Dataset
from autoware_ml.transforms.base import TransformsCompose


class MyDataset(Dataset):
    def __init__(
        self,
        ann_file: str,
        data_root: str,
        dataset_transforms: TransformsCompose | None = None,
    ):
        super().__init__(dataset_transforms=dataset_transforms)
        self.data_root = data_root

        # Load annotations
        with open(ann_file, "rb") as f:
            self.annotations = pickle.load(f)

    def __len__(self) -> int:
        return len(self.annotations)

    def get_data_info(self, index: int) -> dict[str, Any]:
        ann = self.annotations[index]

        return {
            "input_path": os.path.join(self.data_root, ann["input_path"]),
            "gt_labels": ann["label"],
        }


class MyDataModule(DataModule):
    def __init__(
        self,
        data_root: str,
        train_ann_file: str,
        val_ann_file: str,
        test_ann_file: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.data_root = data_root
        self.train_ann_file = train_ann_file
        self.val_ann_file = val_ann_file
        self.test_ann_file = test_ann_file or val_ann_file

    def _create_dataset(
        self,
        split: str,
        transforms: TransformsCompose | None = None,
    ) -> Dataset:
        ann_file = {
            "train": self.train_ann_file,
            "val": self.val_ann_file,
            "test": self.test_ann_file,
            "predict": self.test_ann_file,
        }[split]

        return MyDataset(
            ann_file=ann_file,
            data_root=self.data_root,
            dataset_transforms=transforms,
        )
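MyDataset expects ann_file to be a pickle whose contents support len() and integer indexing. Each entry carries input_path (relative to data_root) and label. The sketch below produces such a file, assuming a plain list of dicts. The sample paths and labels are placeholders, and a temporary directory stands in for data_root.

```python
import os
import pickle
import tempfile

# Placeholder annotations: one dict per sample, keys as read by get_data_info().
annotations = [
    {"input_path": "samples/000000.npy", "label": 3},
    {"input_path": "samples/000001.npy", "label": 7},
]

with tempfile.TemporaryDirectory() as data_root:
    info_dir = os.path.join(data_root, "info")
    os.makedirs(info_dir)
    ann_file = os.path.join(info_dir, "train.pkl")

    # Write the annotation list the way MyDataset will read it back.
    with open(ann_file, "wb") as f:
        pickle.dump(annotations, f)

    with open(ann_file, "rb") as f:
        loaded = pickle.load(f)

    # Mirrors get_data_info(): resolve the per-sample path against data_root.
    first_path = os.path.join(data_root, loaded[0]["input_path"])
```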

Data Flow

get_data_info() -> transforms -> collate_fn() -> BaseModel.on_after_batch_transfer() -> forward() -> compute_metrics()/predict_outputs()
  1. get_data_info(): Return raw sample metadata as dict
  2. transforms: Load files and apply per-sample augmentations (in Dataset)
  3. collate_fn(): Batch samples, convert to tensors
  4. BaseModel.on_after_batch_transfer(): model-owned preprocessing
  5. forward(): model inference/training forward pass
  6. compute_metrics() / predict_outputs(): model owns any output shaping (e.g., voxel-to-point scatter for segmentation) directly inside these methods
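The stage ordering can be traced end to end with plain functions. Every name in the sketch below (FAKE_DISK, load_transform, preprocess, and so on) is a hypothetical stand-in. A dict replaces files on disk and NumPy replaces torch. Only the order of the stages follows the list above.

```python
import numpy as np

# A dict stands in for files on disk; annotations use the raw "label" key.
FAKE_DISK = {"a.npy": np.array([1.0, 2.0, 3.0]), "b.npy": np.array([4.0, 5.0, 6.0])}
ANNOTATIONS = [{"input_path": "a.npy", "label": 0}, {"input_path": "b.npy", "label": 1}]


def get_data_info(index: int) -> dict:
    # 1. Raw sample metadata only (paths, labels), no heavy data yet.
    ann = ANNOTATIONS[index]
    return {"input_path": ann["input_path"], "gt_labels": ann["label"]}


def load_transform(sample: dict) -> dict:
    # 2. Per-sample transform: "load" the file referenced by the metadata.
    return {**sample, "input_tensor": FAKE_DISK[sample["input_path"]]}


def collate_fn(samples: list[dict]) -> dict:
    # 3. Batch samples along a new leading axis.
    return {
        "input_tensor": np.stack([s["input_tensor"] for s in samples]),
        "gt_labels": np.array([s["gt_labels"] for s in samples]),
    }


def preprocess(batch: dict) -> dict:
    # 4. Model-owned runtime preprocessing (on device in the real framework).
    return {**batch, "input_tensor": batch["input_tensor"] / 10.0}


def forward(input_tensor: np.ndarray) -> np.ndarray:
    # 5. Toy "model": one logit per sample, the sum of its features.
    return input_tensor.sum(axis=1)


def compute_metrics(batch: dict, logits: np.ndarray) -> dict:
    # 6. Model-owned metrics: threshold the logit and compare with labels.
    preds = (logits > 1.0).astype(int)
    return {"accuracy": float((preds == batch["gt_labels"]).mean())}


batch = preprocess(collate_fn([load_transform(get_data_info(i)) for i in range(2)]))
logits = forward(batch["input_tensor"])
metrics = compute_metrics(batch, logits)
```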

Step 3: Register Components

Add __init__.py exports:

autoware_ml/models/my_task/__init__.py
from autoware_ml.models.my_task.my_model import MyModel

__all__ = ["MyModel"]
autoware_ml/datamodule/my_dataset/__init__.py
from autoware_ml.datamodule.my_dataset.my_task import MyDataModule, MyDataset

__all__ = ["MyDataModule", "MyDataset"]
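The short _target_ paths used in Step 4 (for example autoware_ml.models.my_task.MyModel) rely on these re-exports. Hydra imports the module part of a dotted path and looks up the final name on it. The following is a simplified sketch of that lookup, not Hydra's actual code.

```python
import importlib
from typing import Any


def resolve_target(path: str) -> Any:
    """Return `Name` from 'package.module.Name' by importing 'package.module'."""
    module_path, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_path)
    # Raises AttributeError if the package's __init__.py does not export the name.
    return getattr(module, attr_name)


ordered_dict_cls = resolve_target("collections.OrderedDict")
```

In this sketch, omitting the export from autoware_ml/models/my_task/__init__.py would still let the package import succeed, but the getattr() step for MyModel would fail.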

Step 4: Create Config

Create a task config:

configs/tasks/my_task/my_model/base.yaml
# @package _global_
defaults:
  - /defaults/default_runtime
  - _self_

datamodule:
  _target_: autoware_ml.datamodule.my_dataset.MyDataModule
  collation_map:
    input_tensor: stack
    gt_labels: stack

  train_dataloader_cfg:
    batch_size: 8
    num_workers: 4
    shuffle: true

  val_dataloader_cfg:
    batch_size: 8
    num_workers: 4

model:
  _target_: autoware_ml.models.my_task.MyModel
  num_classes: 10

  encoder:
    _target_: autoware_ml.models.common.backbones.resnet.ResNet18
    in_channels: 3

  decoder:
    _target_: torch.nn.Linear
    in_features: 512
    out_features: ${model.num_classes}

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 0.001
    weight_decay: 0.01

  scheduler:
    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
    _partial_: true
    T_max: ${trainer.max_epochs}

trainer:
  max_epochs: 50

data_preprocessing:
  _target_: autoware_ml.preprocessing.base.DataPreprocessing
  pipeline: []
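collation_map tells the DataModule how to batch each key. The minimal sketch below assumes that "stack" means stacking per-sample arrays along a new leading batch axis, and that keys without an entry are kept as Python lists. Any other strategy names Autoware-ML may support are not covered here, and collate is a hypothetical function name.

```python
from typing import Any

import numpy as np


def collate(samples: list[dict[str, Any]], collation_map: dict[str, str]) -> dict[str, Any]:
    """Batch per-sample dicts according to a key -> strategy map."""
    batch: dict[str, Any] = {}
    for key in samples[0]:
        values = [s[key] for s in samples]
        strategy = collation_map.get(key)
        if strategy == "stack":
            # Same-shaped arrays (or scalars) -> one array with a batch axis.
            batch[key] = np.stack([np.asarray(v) for v in values])
        elif strategy is None:
            # Assumption: unmapped keys (e.g. file paths) stay as plain lists.
            batch[key] = values
        else:
            raise ValueError(f"unsupported collation strategy {strategy!r} for key {key!r}")
    return batch


samples = [
    {"input_tensor": np.zeros((3, 4)), "gt_labels": 3, "input_path": "a.npy"},
    {"input_tensor": np.ones((3, 4)), "gt_labels": 7, "input_path": "b.npy"},
]
batch = collate(samples, {"input_tensor": "stack", "gt_labels": "stack"})
```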

Create a dataset-specific config:

configs/tasks/my_task/my_model/my_config.yaml
# @package _global_
defaults:
  - /tasks/my_task/my_model/base
  - _self_

data_root: /workspace/data/my_dataset

datamodule:
  data_root: ${data_root}
  train_ann_file: ${data_root}/info/train.pkl
  val_ann_file: ${data_root}/info/val.pkl
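Because _self_ comes last in the defaults list, values in my_config.yaml override those inherited from base.yaml, while nested keys that are not mentioned are kept. The sketch below approximates that merge as a recursive dict merge. It is a simplification: interpolations such as ${data_root} are written out as resolved strings, and everything else Hydra and OmegaConf do during composition is ignored.

```python
from copy import deepcopy
from typing import Any


def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of `base` with `override` merged in; override wins on conflicts."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are mappings: merge recursively so sibling keys survive.
            merged[key] = deep_merge(merged[key], value)
        else:
            # Scalars and lists are replaced wholesale.
            merged[key] = deepcopy(value)
    return merged


# Abbreviated versions of base.yaml and my_config.yaml, interpolations resolved by hand.
base = {
    "datamodule": {
        "_target_": "autoware_ml.datamodule.my_dataset.MyDataModule",
        "train_dataloader_cfg": {"batch_size": 8, "num_workers": 4, "shuffle": True},
    },
    "trainer": {"max_epochs": 50},
}
override = {
    "data_root": "/workspace/data/my_dataset",
    "datamodule": {
        "data_root": "/workspace/data/my_dataset",
        "train_ann_file": "/workspace/data/my_dataset/info/train.pkl",
        "val_ann_file": "/workspace/data/my_dataset/info/val.pkl",
    },
}
config = deep_merge(base, override)
```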

Note

Some parameters are inherited from the default runtime config. Take a look at configs/defaults/default_runtime.yaml for more details.

Runtime preprocessing is configured under the top-level data_preprocessing key of the composed config and is attached to the model by the entrypoints.

Step 5: Add Transforms (Optional)

If your task needs custom transforms:

autoware_ml/transforms/my_transforms/my_transform.py
from typing import Any
import numpy as np

from autoware_ml.transforms.base import BaseTransform


class MyAugmentation(BaseTransform):
    def __init__(self, p: float = 0.5, intensity: float = 0.1):
        # BaseTransform handles the application probability through `p`.
        self.p = p
        self.intensity = intensity

    def transform(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        # Example augmentation: additive Gaussian noise scaled by `intensity`
        input_tensor = input_dict["input_tensor"]
        augmented = input_tensor + np.random.randn(*input_tensor.shape) * self.intensity

        return {"input_tensor": augmented}

Add to config:

datamodule:
  train_transforms:
    pipeline:
      - _target_: autoware_ml.transforms.my_transforms.my_transform.MyAugmentation
        p: 0.5
        intensity: 0.1
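The MyAugmentation comment says BaseTransform handles the application probability through p, and transform() returns only the key it changed. The following is a hypothetical sketch of a base class with that behavior. It assumes that the returned keys are merged over the input dict. ProbabilisticTransform and AddOffset are illustrative names, not the real BaseTransform.

```python
from __future__ import annotations

import random
from typing import Any


class ProbabilisticTransform:
    """Hypothetical base class that applies transform() with probability p."""

    def __init__(self, p: float = 1.0, rng: random.Random | None = None):
        self.p = p
        self.rng = rng or random.Random()

    def transform(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError

    def __call__(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        # random() is in [0, 1), so p=1.0 always applies and p=0.0 never does.
        if self.rng.random() >= self.p:
            return input_dict
        # Merge the returned keys over the input so untouched keys survive.
        return {**input_dict, **self.transform(input_dict)}


class AddOffset(ProbabilisticTransform):
    def __init__(self, offset: float, p: float = 0.5, rng: random.Random | None = None):
        super().__init__(p=p, rng=rng)
        self.offset = offset

    def transform(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        return {"input_tensor": [x + self.offset for x in input_dict["input_tensor"]]}


sample = {"input_tensor": [1.0, 2.0], "gt_labels": 3}
always = AddOffset(1.0, p=1.0)(sample)
never = AddOffset(1.0, p=0.0)(sample)
```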

Step 6: Add Runtime Data Preprocessing (Optional)

Runtime preprocessing runs on the target device after batch transfer and before the forward pass. It is configured at the top level and attached to the model by the entrypoint scripts.

If your task needs custom preprocessing:

autoware_ml/preprocessing/my_preprocessing/my_preprocessing.py
from typing import Any


class MyPreprocessingLayer:
    def __init__(self, input_key: str = "input_tensor", scale: float = 1.0):
        self.input_key = input_key
        self.scale = scale

    def __call__(self, batch_inputs_dict: dict[str, Any]) -> dict[str, Any]:
        processed = batch_inputs_dict[self.input_key] * self.scale
        return {self.input_key: processed}

Add to config:

data_preprocessing:
  _target_: autoware_ml.preprocessing.base.DataPreprocessing
  pipeline:
    - _target_: autoware_ml.preprocessing.my_preprocessing.my_preprocessing.MyPreprocessingLayer
      input_key: input_tensor
      scale: 1.0
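MyPreprocessingLayer returns only the key it modifies, so the pipeline has to fold each layer's output back into the batch. The hypothetical sketch below shows a DataPreprocessing-style pipeline that does this. Running the layers in order and merging with dict.update() are assumptions about how the real class behaves. ScaleLayer mirrors MyPreprocessingLayer but works on lists so that it runs without torch.

```python
from typing import Any, Callable

Layer = Callable[[dict[str, Any]], dict[str, Any]]


class SequentialPreprocessing:
    """Hypothetical stand-in for a DataPreprocessing-style pipeline."""

    def __init__(self, pipeline: list[Layer]):
        self.pipeline = pipeline

    def __call__(self, batch_inputs_dict: dict[str, Any]) -> dict[str, Any]:
        batch = dict(batch_inputs_dict)  # shallow copy; the caller's dict is untouched
        for layer in self.pipeline:
            # Assumption: each layer returns only the keys it produced or changed.
            batch.update(layer(batch))
        return batch


class ScaleLayer:
    """Same interface as MyPreprocessingLayer, operating on a list."""

    def __init__(self, input_key: str = "input_tensor", scale: float = 1.0):
        self.input_key = input_key
        self.scale = scale

    def __call__(self, batch_inputs_dict: dict[str, Any]) -> dict[str, Any]:
        return {self.input_key: [x * self.scale for x in batch_inputs_dict[self.input_key]]}


preprocessing = SequentialPreprocessing([ScaleLayer(scale=2.0), ScaleLayer(scale=3.0)])
raw = {"input_tensor": [1.0, 2.0], "gt_labels": [0, 1]}
processed = preprocessing(raw)
```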

Warning

Preprocessing layers must be callable objects that accept dict[str, Any] and return dict[str, Any].

Output-side shaping (logits -> probabilities, decoder scatter, voxel-to-point mapping, etc.) belongs inside the model - in forward(), compute_metrics(), or predict_outputs().

Step 7: Train and Deploy

Config Naming Convention

Task configs should follow:

<task>/<model>/<variant>_<dataset>

Use these rules when creating <variant>:

  • include only choices that distinguish this config from other variants of the same model, such as backbone, modality, voxel size, or range
  • do not encode properties that are inherent to the model family
  • normalize voxel sizes to a zero-padded three-digit form, for example voxel020, voxel005, voxel030
  • encode ranges as human-readable suffixes such as 50m, 90m, 102m, 121m
  • keep dataset names explicit and stable, for example nuscenes and t4dataset_j6gen2
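The rules above can be captured in a small naming helper. In the sketch below, reading the voxel digits as hundredths of a meter (0.05 m becomes voxel005) is an inference from the examples in this guide, not a documented rule. All three function names are hypothetical.

```python
def voxel_token(voxel_size_m: float) -> str:
    # Assumed encoding: hundredths of a meter, zero-padded to three digits.
    return f"voxel{round(voxel_size_m * 100):03d}"


def range_token(range_m: float) -> str:
    # Ranges become whole-meter, human-readable suffixes such as "102m".
    return f"{round(range_m)}m"


def config_name(task: str, model: str, variant_parts: list[str], dataset: str) -> str:
    # <task>/<model>/<variant>_<dataset>, with variant parts joined by underscores.
    suffix = "_".join([*variant_parts, dataset])
    return f"{task}/{model}/{suffix}"


name = config_name("segmentation3d", "ptv3", [voxel_token(0.05), range_token(102)], "nuscenes")
```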

Examples:

segmentation3d/ptv3/voxel005_102m_nuscenes
segmentation3d/ptv3/voxel012_122m_t4dataset_j6gen2
my_task/my_model/my_variant_my_dataset

To train and deploy the config created in Step 4:
# Train
autoware-ml train --config-name my_task/my_model/my_config

# Deploy
autoware-ml deploy \
    --config-name my_task/my_model/my_config \
    --weights mlruns/my_task/my_model/my_config/<run_id>/artifacts/checkpoints/last.ckpt

Common Patterns

Multiple Inputs

def forward(self, image: torch.Tensor, lidar: torch.Tensor) -> torch.Tensor:
    img_features = self.image_encoder(image)
    lidar_features = self.lidar_encoder(lidar)
    fused = torch.cat([img_features, lidar_features], dim=1)
    return self.head(fused)

The batch dictionary must contain image and lidar keys.

Multiple Outputs

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    features = self.backbone(x)
    boxes = self.box_head(features)
    scores = self.score_head(features)
    return boxes, scores

def compute_metrics(
    self,
    batch_inputs_dict: Mapping[str, Any],
    outputs: tuple[torch.Tensor, torch.Tensor],
) -> dict[str, torch.Tensor]:
    boxes, scores = outputs
    gt_boxes = batch_inputs_dict["gt_boxes"]
    gt_scores = batch_inputs_dict["gt_scores"]
    box_loss = self.box_loss(boxes, gt_boxes)
    score_loss = self.score_loss(scores, gt_scores)
    return {"loss": box_loss + score_loss, "box_loss": box_loss, "score_loss": score_loss}