Adding Models

This guide walks you through adding a new model to Autoware-ML. You'll implement a model class, create a DataModule, and wire everything together with a config.

The BaseModel Interface

New models should inherit from BaseModel. The minimal contract is two abstract methods:

from collections.abc import Sequence
from typing import Any

import torch

from autoware_ml.models.base import BaseModel

class MyModel(BaseModel):
    def forward(self, **kwargs: Any) -> torch.Tensor | Sequence[torch.Tensor]:
        ...

    def compute_metrics(
        self, outputs: torch.Tensor | Sequence[torch.Tensor], **kwargs: Any
    ) -> dict[str, torch.Tensor]:
        ...

The base class handles training/validation/test steps, optimizer configuration, metric logging, prediction output conversion, and deployment export integration. forward() can use any signature, provided its parameter names match keys in the batch dict (the default batch-to-argument mapping) or the model overrides the relevant hooks.

Extending BaseModel

Specialized models should also inherit from BaseModel. When the default signature-based path is not enough, prefer overriding hooks such as run_model(), prepare_metric_inputs(), get_log_batch_size(), predict_outputs(), or build_export_spec() over introducing a standalone LightningModule.

Step 1: Implement the Model

Create a new file in autoware_ml/models/:

autoware_ml/models/my_task/my_model.py
from collections.abc import Sequence
from typing import Any

import torch
import torch.nn as nn

from autoware_ml.models.base import BaseModel


class MyModel(BaseModel):
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        num_classes: int,
        **kwargs: Any,  # Pass optimizer, scheduler to BaseModel
    ):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.num_classes = num_classes
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        features = self.encoder(input_tensor)
        logits = self.decoder(features)
        return logits

    def compute_metrics(
        self,
        outputs: torch.Tensor | Sequence[torch.Tensor],
        gt_labels: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        logits = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
        loss = self.loss_fn(logits, gt_labels)

        # Optional: compute accuracy
        preds = torch.argmax(logits, dim=1)
        accuracy = (preds == gt_labels).float().mean()

        return {
            "loss": loss,
            "accuracy": accuracy,
        }
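Before wiring the model into the framework, you can sanity-check the forward/metrics pairing in isolation. The sketch below stubs out BaseModel with a plain nn.Module (the encoder/decoder shapes are illustrative, not the real interface):

```python
import torch
import torch.nn as nn


class _StubBase(nn.Module):
    """Stand-in for BaseModel so the check runs outside the framework."""

    def __init__(self, **kwargs):
        super().__init__()


class TinyModel(_StubBase):
    def __init__(self, num_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.encoder = nn.Linear(32, 16)  # toy encoder
        self.decoder = nn.Linear(16, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(input_tensor))

    def compute_metrics(self, outputs, gt_labels):
        loss = self.loss_fn(outputs, gt_labels)
        preds = torch.argmax(outputs, dim=1)
        return {"loss": loss, "accuracy": (preds == gt_labels).float().mean()}


model = TinyModel()
batch = {"input_tensor": torch.randn(4, 32), "gt_labels": torch.randint(0, 10, (4,))}
logits = model(batch["input_tensor"])
metrics = model.compute_metrics(logits, batch["gt_labels"])
assert metrics["loss"].ndim == 0  # scalar loss, ready for backprop
```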

Key Points

  1. forward() signature matters - Parameter names must match keys in your batch dictionary. The base class automatically extracts matching keys using signature inspection.

  2. compute_metrics() receives outputs - The first argument is always outputs from forward() (as a torch.Tensor | Sequence[torch.Tensor]). Additional parameters are matched from the batch.

  3. Return 'loss' - The metrics dict must include a 'loss' key for backpropagation.

  4. Optimizer and scheduler - Passed as callables to BaseModel.__init__(). They must be marked `_partial_: true` in YAML configs.

  5. Use hooks when needed - If your model needs custom batch unpacking, prediction formatting, or an explicit deployment wrapper, override the appropriate BaseModel hook instead of bypassing the shared training and deployment flow.
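The signature-based extraction in point 1 can be illustrated in isolation. This is a sketch of the idea, not BaseModel's actual implementation:

```python
import inspect


def extract_forward_kwargs(forward_fn, batch: dict) -> dict:
    # Keep only the batch keys that match the function's parameter names.
    params = inspect.signature(forward_fn).parameters
    return {name: batch[name] for name in params if name in batch}


def forward(input_tensor, gt_labels=None):
    return input_tensor


batch = {"input_tensor": [1, 2, 3], "gt_labels": [0, 1, 0], "metadata": "ignored"}
kwargs = extract_forward_kwargs(forward, batch)
# "metadata" is dropped because forward() has no parameter with that name.
assert set(kwargs) == {"input_tensor", "gt_labels"}
```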

Step 2: Create a DataModule

Create a DataModule that provides data for your model:

autoware_ml/datamodule/my_dataset/my_task.py
import os
import pickle
from typing import Any

from autoware_ml.datamodule.base import DataModule, Dataset
from autoware_ml.transforms.base import TransformsCompose


class MyDataset(Dataset):
    def __init__(
        self,
        ann_file: str,
        data_root: str,
        dataset_transforms: TransformsCompose | None = None,
    ):
        super().__init__(dataset_transforms=dataset_transforms)
        self.data_root = data_root

        # Load annotations
        with open(ann_file, "rb") as f:
            self.annotations = pickle.load(f)

    def __len__(self) -> int:
        return len(self.annotations)

    def get_data_info(self, index: int) -> dict[str, Any]:
        ann = self.annotations[index]

        return {
            "input_path": os.path.join(self.data_root, ann["input_path"]),
            "label": ann["label"],
        }


class MyDataModule(DataModule):
    def __init__(
        self,
        data_root: str,
        train_ann_file: str,
        val_ann_file: str,
        test_ann_file: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.data_root = data_root
        self.train_ann_file = train_ann_file
        self.val_ann_file = val_ann_file
        self.test_ann_file = test_ann_file or val_ann_file

    def _create_dataset(
        self,
        split: str,
        transforms: TransformsCompose | None = None,
    ) -> Dataset:
        ann_file = {
            "train": self.train_ann_file,
            "val": self.val_ann_file,
            "test": self.test_ann_file,
            "predict": self.test_ann_file,
        }[split]

        return MyDataset(
            ann_file=ann_file,
            data_root=self.data_root,
            dataset_transforms=transforms,
        )
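MyDataset above assumes the annotation file is a pickled list of dicts carrying input_path and label keys. A throwaway script like this (paths are illustrative) produces a compatible file:

```python
import pickle
import tempfile

annotations = [
    {"input_path": "samples/000001.bin", "label": 0},
    {"input_path": "samples/000002.bin", "label": 3},
]

# Write the annotation file in the format MyDataset.__init__ expects.
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
    pickle.dump(annotations, f)
    ann_file = f.name

with open(ann_file, "rb") as f:
    loaded = pickle.load(f)
assert loaded == annotations
```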

Data Flow

get_data_info() → transforms → collate_fn() → on_after_batch_transfer() → model
  1. get_data_info(): Return raw sample metadata as dict
  2. transforms: Load files and apply per-sample augmentations (in Dataset)
  3. collate_fn(): Batch samples, convert to tensors
  4. on_after_batch_transfer(): GPU preprocessing (optional)
  5. Model receives the batch dict
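Step 3 can be sketched as stacking per-sample arrays into batch tensors for the configured stack_keys. The real collate_fn lives in the framework; this is a simplified stand-in:

```python
import torch


def simple_collate(samples: list, stack_keys: list) -> dict:
    batch = {}
    for key in samples[0]:
        values = [s[key] for s in samples]
        # Stack tensor keys into batch tensors; keep everything else as a list.
        batch[key] = torch.stack(values) if key in stack_keys else values
    return batch


samples = [
    {"input_tensor": torch.randn(3, 8, 8), "gt_labels": torch.tensor(1), "token": "a"},
    {"input_tensor": torch.randn(3, 8, 8), "gt_labels": torch.tensor(0), "token": "b"},
]
batch = simple_collate(samples, stack_keys=["input_tensor", "gt_labels"])
assert batch["input_tensor"].shape == (2, 3, 8, 8)
assert batch["token"] == ["a", "b"]  # non-stacked keys stay as plain lists
```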

Step 3: Register Components

Add __init__.py exports:

autoware_ml/models/my_task/__init__.py
from autoware_ml.models.my_task.my_model import MyModel

__all__ = ["MyModel"]
autoware_ml/datamodule/my_dataset/__init__.py
from autoware_ml.datamodule.my_dataset.my_task import MyDataModule, MyDataset

__all__ = ["MyDataModule", "MyDataset"]

Step 4: Create Config

Create a task config:

configs/tasks/my_task/my_model/base.yaml
# @package _global_
defaults:
  - /defaults/default_runtime
  - _self_

datamodule:
  _target_: autoware_ml.datamodule.my_dataset.MyDataModule
  stack_keys: [input_tensor, gt_labels]  # Keys to stack into tensors

  train_dataloader_cfg:
    batch_size: 8
    num_workers: 4
    shuffle: true

  val_dataloader_cfg:
    batch_size: 8
    num_workers: 4

  # GPU preprocessing (optional)
  data_preprocessing:
    _target_: autoware_ml.preprocessing.base.DataPreprocessing
    pipeline: []

model:
  _target_: autoware_ml.models.my_task.MyModel
  num_classes: 10

  encoder:
    _target_: autoware_ml.models.common.backbones.resnet.ResNet18
    in_channels: 3

  decoder:
    _target_: torch.nn.Linear
    in_features: 512
    out_features: ${model.num_classes}

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 0.001
    weight_decay: 0.01

  scheduler:
    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
    _partial_: true
    T_max: ${trainer.max_epochs}

trainer:
  max_epochs: 50
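The `_partial_: true` flags above tell Hydra to instantiate the optimizer and scheduler as callables rather than objects, since the model's parameters are not known at config time. The effect is equivalent to functools.partial:

```python
from functools import partial

import torch
import torch.nn as nn

# What Hydra produces for the optimizer entry with _partial_: true:
optimizer_factory = partial(torch.optim.AdamW, lr=0.001, weight_decay=0.01)

# BaseModel can later bind the model's parameters to finish construction:
model = nn.Linear(4, 2)
optimizer = optimizer_factory(model.parameters())
assert isinstance(optimizer, torch.optim.AdamW)
assert optimizer.defaults["lr"] == 0.001
```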

Create a dataset-specific config:

configs/tasks/my_task/my_model/my_config.yaml
# @package _global_
defaults:
  - /tasks/my_task/my_model/base
  - _self_

data_root: /workspace/data/my_dataset

datamodule:
  data_root: ${data_root}
  train_ann_file: ${data_root}/info/train.pkl
  val_ann_file: ${data_root}/info/val.pkl

Note

Some parameters are inherited from the default runtime config. Take a look at configs/defaults/default_runtime.yaml for more details.

Step 5: Add Transforms (Optional)

If your task needs custom transforms:

autoware_ml/transforms/my_transforms/my_transform.py
from typing import Any
import numpy as np

from autoware_ml.transforms.base import BaseTransform


class MyAugmentation(BaseTransform):
    def __init__(self, p: float = 0.5, intensity: float = 0.1):
        self.p = p
        self.intensity = intensity

    def transform(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        if np.random.random() > self.p:
            return {}  # No changes

        # Your augmentation logic
        input_tensor = input_dict["input_tensor"]
        augmented = input_tensor + np.random.randn(*input_tensor.shape) * self.intensity

        return {"input_tensor": augmented}

Add to config:

datamodule:
  train_transforms:
    pipeline:
      - _target_: autoware_ml.transforms.my_transforms.my_transform.MyAugmentation
        p: 0.5
        intensity: 0.1
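The return-only-changed-keys convention implies the compose step merges each transform's output back into the sample dict. The following is a standalone sketch of that assumed merge behavior, using a simplified noise transform without the framework base class:

```python
import numpy as np


class NoiseAugmentation:
    """Simplified stand-in for MyAugmentation, without BaseTransform."""

    def __init__(self, p: float = 0.5, intensity: float = 0.1):
        self.p = p
        self.intensity = intensity

    def transform(self, input_dict):
        if np.random.random() > self.p:
            return {}  # no changes
        x = input_dict["input_tensor"]
        return {"input_tensor": x + np.random.randn(*x.shape) * self.intensity}


def apply_pipeline(sample: dict, transforms: list) -> dict:
    # Assumed merge semantics: update the sample with each returned dict.
    for t in transforms:
        sample.update(t.transform(sample))
    return sample


sample = {"input_tensor": np.zeros((3, 4)), "label": 1}
out = apply_pipeline(sample, [NoiseAugmentation(p=1.0, intensity=0.1)])
assert out["input_tensor"].shape == (3, 4)
assert out["label"] == 1  # untouched keys survive the merge
```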

Step 6: Add Preprocessing (Optional)

Preprocessing runs on GPU after batch transfer, enabling hardware-accelerated operations. Unlike transforms (CPU-side, per-sample), preprocessing operates on entire batches already on the target device.

If your task needs custom preprocessing:

autoware_ml/preprocessing/my_preprocessing/my_preprocessing.py
from typing import Any

import torch
import torch.nn as nn


class MyPreprocessingLayer(nn.Module):
    def __init__(self, input_key: str = "input_tensor", scale: float = 1.0):
        super().__init__()
        self.input_key = input_key
        self.scale = scale

    def forward(self, batch_inputs_dict: dict[str, Any]) -> dict[str, Any]:
        processed = batch_inputs_dict[self.input_key] * self.scale
        return {self.input_key: processed}

Add to config:

datamodule:
  data_preprocessing:
    _target_: autoware_ml.preprocessing.base.DataPreprocessing
    pipeline:
      - _target_: autoware_ml.preprocessing.my_preprocessing.my_preprocessing.MyPreprocessingLayer
        input_key: input_tensor
        scale: 1.0

Warning

Preprocessing layers must be nn.Module subclasses that accept dict[str, Any] and return dict[str, Any].
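A quick standalone check of that contract (the layer is re-declared here so the snippet runs without the package installed):

```python
import torch
import torch.nn as nn


class MyPreprocessingLayer(nn.Module):
    """Scales one key of the batch dict; mirrors the layer defined above."""

    def __init__(self, input_key: str = "input_tensor", scale: float = 1.0):
        super().__init__()
        self.input_key = input_key
        self.scale = scale

    def forward(self, batch_inputs_dict):
        return {self.input_key: batch_inputs_dict[self.input_key] * self.scale}


layer = MyPreprocessingLayer(scale=2.0)
batch = {"input_tensor": torch.ones(2, 3)}
out = layer(batch)  # dict in, dict out
assert torch.equal(out["input_tensor"], torch.full((2, 3), 2.0))
```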

Step 7: Train and Deploy

Config Naming Convention

Task configs should follow:

<task>/<model>/<variant>_<dataset>

Use these rules when creating <variant>:

  • include only choices that distinguish variants (now or in the future), such as backbone, modality, voxel size, or range
  • do not encode properties that are inherent to the model family
  • normalize voxel sizes as voxel020, voxel005, voxel030
  • encode ranges as human-readable suffixes such as 50m, 90m, 102m, 121m
  • keep dataset names explicit and stable, for example nuscenes and t4dataset_j6gen2

Examples:

calibration_status/calibration_status_classifier/resnet18_nuscenes
calibration_status/calibration_status_classifier/resnet18_t4dataset_j6gen2
my_task/my_model/my_variant_my_dataset

Train and deploy with:

# Train
autoware-ml train --config-name my_task/my_model/my_config

# Deploy
autoware-ml deploy \
    --config-name my_task/my_model/my_config \
    +checkpoint=mlruns/my_task/my_model/my_config/<run_id>/artifacts/checkpoints/last.ckpt

Common Patterns

Multiple Inputs

def forward(self, image: torch.Tensor, lidar: torch.Tensor) -> torch.Tensor:
    img_features = self.image_encoder(image)
    lidar_features = self.lidar_encoder(lidar)
    fused = torch.cat([img_features, lidar_features], dim=1)
    return self.head(fused)

The batch dict must contain image and lidar keys.

Multiple Outputs

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    features = self.backbone(x)
    boxes = self.box_head(features)
    scores = self.score_head(features)
    return boxes, scores

def compute_metrics(
    self,
    outputs: tuple[torch.Tensor, torch.Tensor],
    gt_boxes: torch.Tensor,
    gt_scores: torch.Tensor,
) -> dict[str, torch.Tensor]:
    boxes, scores = outputs
    box_loss = self.box_loss(boxes, gt_boxes)
    score_loss = self.score_loss(scores, gt_scores)
    return {"loss": box_loss + score_loss, "box_loss": box_loss, "score_loss": score_loss}