Adding Models
This guide walks you through adding a new model to Autoware-ML. You'll implement a model class, create a DataModule, and wire everything together with a config.
The BaseModel Interface
New models should inherit from BaseModel. The minimal contract consists of
two abstract methods:
from collections.abc import Mapping, Sequence
from typing import Any

import torch

from autoware_ml.models.base import BaseModel


class MyModel(BaseModel):
    def forward(self, **kwargs: Any) -> torch.Tensor | Sequence[torch.Tensor]:
        ...

    def compute_metrics(
        self, batch_inputs_dict: Mapping[str, Any], outputs: Any
    ) -> dict[str, torch.Tensor]:
        ...
The base class handles the training, validation, test, and predict steps,
optimizer configuration, metric logging, prediction output conversion,
runtime preprocessing, and deployment export integration. forward() can
have any signature, as long as its parameter names match keys in the batch
dictionary (the default batch-to-argument mapping) or the model overrides
the relevant hooks.
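To make the default mapping concrete, here is a minimal sketch of signature-based key extraction using only `inspect` from the standard library. `select_forward_kwargs` and `ToyModel` are hypothetical names for illustration, not Autoware-ML APIs, and the real BaseModel logic may handle more cases.

```python
import inspect
from collections.abc import Callable, Mapping
from typing import Any


def select_forward_kwargs(forward: Callable[..., Any], batch: Mapping[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: keep only batch entries whose keys match forward() parameters."""
    params = inspect.signature(forward).parameters
    # A **kwargs parameter accepts any key, so the whole batch is passed through.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(batch)
    named = {
        name: p
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    # Parameters without defaults must be present in the batch.
    missing = sorted(
        name for name, p in named.items() if name not in batch and p.default is inspect.Parameter.empty
    )
    if missing:
        raise KeyError(f"batch is missing keys required by forward(): {missing}")
    return {key: value for key, value in batch.items() if key in named}


class ToyModel:
    def forward(self, input_tensor: list[float]) -> list[float]:
        return [2.0 * x for x in input_tensor]


batch = {"input_tensor": [1.0, 2.0], "gt_labels": [0, 1]}
model = ToyModel()
kwargs = select_forward_kwargs(model.forward, batch)  # gt_labels is dropped
outputs = model.forward(**kwargs)
```

Keys that forward() does not name (here `gt_labels`) stay in the batch dictionary, which is why compute_metrics() receives the full batch rather than only the forward arguments.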
Extending BaseModel
Specialized models should still use BaseModel. When the default
signature-based path is not enough, prefer overriding hooks such as
set_data_preprocessing(), predict_outputs(), get_log_batch_size(),
or build_export_spec() instead of introducing a standalone
LightningModule. Output decoding (for example, voxel-to-point scatter
for segmentation) belongs inside the model, typically in forward(),
compute_metrics(), and predict_outputs(), rather than in a separate
framework pipeline.
Step 1: Implement the Model
Create a new file in autoware_ml/models/, for example autoware_ml/models/my_task/my_model.py:
from collections.abc import Mapping, Sequence
from typing import Any

import torch
import torch.nn as nn

from autoware_ml.models.base import BaseModel


class MyModel(BaseModel):
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        num_classes: int,
        **kwargs: Any,  # Forwarded to BaseModel (e.g. optimizer, scheduler)
    ) -> None:
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.num_classes = num_classes
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        # The parameter name must match a key in the batch dictionary.
        features = self.encoder(input_tensor)
        logits = self.decoder(features)
        return logits

    def compute_metrics(
        self,
        batch_inputs_dict: Mapping[str, Any],
        outputs: torch.Tensor | Sequence[torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        gt_labels = batch_inputs_dict["gt_labels"]
        # Accept either a bare tensor or a sequence whose first element is the logits.
        logits = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
        loss = self.loss_fn(logits, gt_labels)
        preds = torch.argmax(logits, dim=1)
        accuracy = (preds == gt_labels).float().mean()
        return {
            "loss": loss,  # Required: used for backpropagation
            "accuracy": accuracy,
        }
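The loss and accuracy in compute_metrics() can be sanity-checked without building a model. The sketch below re-implements in NumPy what `nn.CrossEntropyLoss` computes with its default mean reduction on (N, C) logits and integer targets, plus the argmax accuracy; `cross_entropy_and_accuracy` is a hypothetical helper for checking numbers by hand, not part of Autoware-ML.

```python
import numpy as np


def cross_entropy_and_accuracy(logits: np.ndarray, labels: np.ndarray) -> tuple[float, float]:
    """Mean softmax cross-entropy and argmax accuracy for (N, C) logits and (N,) labels."""
    # Subtract the row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-likelihood of each target class, averaged over the batch.
    loss = -log_probs[np.arange(len(labels)), labels].mean()
    accuracy = (logits.argmax(axis=1) == labels).mean()
    return float(loss), float(accuracy)


logits = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 3.0]])
labels = np.array([0, 1, 0])
loss, accuracy = cross_entropy_and_accuracy(logits, labels)
```

Two of the three rows are predicted correctly, so the accuracy is 2/3, and the misclassified third row dominates the loss.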
Key Points
- forward() signature matters: parameter names must match keys in your batch dictionary. The base class automatically extracts matching keys using signature inspection.
- compute_metrics() receives the full batch and the outputs: the first argument is batch_inputs_dict (the full batch dictionary after preprocessing), and the second is the outputs of forward(). Extract any targets you need (e.g. gt_labels) from batch_inputs_dict.
- Return 'loss': the metrics dict must include a 'loss' key, which is used for backpropagation.
- Optimizer and scheduler: passed as callables to BaseModel.__init__() (through **kwargs) and must be marked with _partial_: true in YAML configs.
- Use hooks when needed: if your model needs custom batch unpacking, prediction formatting, or an explicit deployment wrapper, override the appropriate BaseModel hook instead of bypassing the shared training and deployment flow.
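With `_partial_: true`, Hydra's `instantiate` returns a `functools.partial` instead of calling the target, so the model can supply its parameters later. The stdlib sketch below mimics that handoff; `ToyOptimizer` and `build_optimizer` are hypothetical stand-ins, not the actual BaseModel internals.

```python
from collections.abc import Callable
from functools import partial
from typing import Any


class ToyOptimizer:
    """Hypothetical stand-in for torch.optim.AdamW: records params and hyperparameters."""

    def __init__(self, params: list[str], lr: float, weight_decay: float = 0.0) -> None:
        self.params = list(params)
        self.lr = lr
        self.weight_decay = weight_decay


# Roughly what `{_target_: ..., _partial_: true, lr: 0.001, weight_decay: 0.01}` becomes:
# a callable with hyperparameters bound but the model parameters still missing.
optimizer_factory: Callable[..., ToyOptimizer] = partial(ToyOptimizer, lr=0.001, weight_decay=0.01)


def build_optimizer(factory: Callable[..., Any], params: list[str]) -> Any:
    # The model calls the factory once its parameters exist.
    return factory(params)


optimizer = build_optimizer(optimizer_factory, ["encoder.weight", "decoder.weight"])
```

The scheduler entry follows the same idea, except that its partial is later called with the optimizer (`CosineAnnealingLR` takes the optimizer as its first argument).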
Step 2: Create a DataModule
Create a DataModule that provides data for your model, for example in autoware_ml/datamodule/my_dataset/my_task.py:
import os
import pickle
from typing import Any

from autoware_ml.datamodule.base import DataModule, Dataset
from autoware_ml.transforms.base import TransformsCompose


class MyDataset(Dataset):
    def __init__(
        self,
        ann_file: str,
        data_root: str,
        dataset_transforms: TransformsCompose | None = None,
    ):
        super().__init__(dataset_transforms=dataset_transforms)
        self.data_root = data_root
        # Load the per-sample annotations
        with open(ann_file, "rb") as f:
            self.annotations = pickle.load(f)

    def __len__(self) -> int:
        return len(self.annotations)

    def get_data_info(self, index: int) -> dict[str, Any]:
        ann = self.annotations[index]
        # Expose the target under the key that compute_metrics() and collation_map use.
        return {
            "input_path": os.path.join(self.data_root, ann["input_path"]),
            "gt_labels": ann["label"],
        }


class MyDataModule(DataModule):
    def __init__(
        self,
        data_root: str,
        train_ann_file: str,
        val_ann_file: str,
        test_ann_file: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.data_root = data_root
        self.train_ann_file = train_ann_file
        self.val_ann_file = val_ann_file
        # Fall back to the validation split when no test split is given.
        self.test_ann_file = test_ann_file or val_ann_file

    def _create_dataset(
        self,
        split: str,
        transforms: TransformsCompose | None = None,
    ) -> Dataset:
        # The predict split reuses the test annotations.
        ann_file = {
            "train": self.train_ann_file,
            "val": self.val_ann_file,
            "test": self.test_ann_file,
            "predict": self.test_ann_file,
        }[split]
        return MyDataset(
            ann_file=ann_file,
            data_root=self.data_root,
            dataset_transforms=transforms,
        )
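MyDataset expects `ann_file` to be a pickle holding an indexable sequence of per-sample dicts with at least `input_path` (relative to `data_root`) and `label`. The exact schema is whatever your own conversion script writes; the sketch below assumes a plain list of dicts and uses a hypothetical helper, `write_annotations`, to produce a file in the `info/train.pkl` location used by the configs.

```python
import os
import pickle
import tempfile
from typing import Any


def write_annotations(path: str, samples: list[dict[str, Any]]) -> None:
    """Hypothetical helper: store per-sample annotations in the layout MyDataset reads."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(samples, f)


data_root = tempfile.mkdtemp()
train_ann_file = os.path.join(data_root, "info", "train.pkl")
write_annotations(
    train_ann_file,
    [
        {"input_path": "samples/000000.npy", "label": 3},
        {"input_path": "samples/000001.npy", "label": 7},
    ],
)

# Reading it back mirrors MyDataset.__init__() and get_data_info().
with open(train_ann_file, "rb") as f:
    annotations = pickle.load(f)
first_path = os.path.join(data_root, annotations[0]["input_path"])
```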
Data Flow
get_data_info() -> transforms -> collate_fn() -> BaseModel.on_after_batch_transfer() -> forward() -> compute_metrics()/predict_outputs()
- get_data_info(): return raw sample metadata as a dict
- transforms: load files and apply per-sample augmentations (in the Dataset)
- collate_fn(): batch samples and convert them to tensors
- BaseModel.on_after_batch_transfer(): model-owned runtime preprocessing
- forward(): the model's forward pass for training and inference
- compute_metrics()/predict_outputs(): the model owns any output shaping (e.g., voxel-to-point scatter for segmentation) directly inside these methods
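The first stages can be traced with a toy, framework-free sketch. Every function here is a hypothetical stand-in: the `"stack"` collation is modeled as building one list per key, whereas the real collate_fn produces tensors. The sketch only illustrates the order in which data passes through the stages and which stage owns what.

```python
from collections.abc import Mapping, Sequence
from typing import Any


def get_data_info(index: int) -> dict[str, Any]:
    # Raw metadata only: no file loading yet.
    return {"input_path": f"sample_{index}.bin", "gt_labels": index % 2}


def load_transform(sample: dict[str, Any]) -> dict[str, Any]:
    # Stand-in for a loading transform that turns the path into input data.
    return {**sample, "input_tensor": [float(len(sample["input_path"]))]}


def collate(samples: Sequence[Mapping[str, Any]], collation_map: Mapping[str, str]) -> dict[str, Any]:
    # In this sketch only keys listed in collation_map are batched, in sample order.
    batch: dict[str, Any] = {}
    for key, mode in collation_map.items():
        if mode != "stack":
            raise ValueError(f"unsupported collation mode: {mode}")
        batch[key] = [sample[key] for sample in samples]
    return batch


def preprocess(batch: dict[str, Any], scale: float) -> dict[str, Any]:
    # Model-owned runtime preprocessing (on the target device in the real framework).
    batch["input_tensor"] = [[x * scale for x in row] for row in batch["input_tensor"]]
    return batch


samples = [load_transform(get_data_info(i)) for i in range(3)]
batch = collate(samples, {"input_tensor": "stack", "gt_labels": "stack"})
batch = preprocess(batch, scale=0.5)
```

At this point the batch would be handed to forward() and then compute_metrics().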
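Step 3: Register Components
Export the new classes from their packages' __init__.py files so that the _target_ paths in the config resolve:
# autoware_ml/models/my_task/__init__.py
from autoware_ml.models.my_task.my_model import MyModel

__all__ = ["MyModel"]

# autoware_ml/datamodule/my_dataset/__init__.py
from autoware_ml.datamodule.my_dataset.my_task import MyDataModule, MyDataset

__all__ = ["MyDataModule", "MyDataset"]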
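Hydra resolves each `_target_` string by importing a module and looking up the named attribute, so the exports above must match the paths used in the config (for example `autoware_ml.models.my_task.MyModel`). The stdlib sketch below performs a simpler version of that lookup, handy for checking a path before launching training; `resolve_target` is hypothetical and does not handle every case Hydra's own resolver does (such as nested attributes).

```python
import importlib
from typing import Any


def resolve_target(path: str) -> Any:
    """Hypothetical helper: import 'package.module.Name' and return Name."""
    module_path, _, attr_name = path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected a dotted path, got {path!r}")
    module = importlib.import_module(module_path)
    try:
        return getattr(module, attr_name)
    except AttributeError as exc:
        raise ImportError(f"{module_path} does not export {attr_name}") from exc


# Stdlib example; in your project, try "autoware_ml.models.my_task.MyModel".
ordered_dict_cls = resolve_target("collections.OrderedDict")
```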
Step 4: Create Config
Create a base task config for the model (the dataset-specific config below inherits it as /tasks/my_task/my_model/base):
# @package _global_
defaults:
  - /defaults/default_runtime
  - _self_

datamodule:
  _target_: autoware_ml.datamodule.my_dataset.MyDataModule
  collation_map:
    input_tensor: stack
    gt_labels: stack
  train_dataloader_cfg:
    batch_size: 8
    num_workers: 4
    shuffle: true
  val_dataloader_cfg:
    batch_size: 8
    num_workers: 4

model:
  _target_: autoware_ml.models.my_task.MyModel
  num_classes: 10
  encoder:
    _target_: autoware_ml.models.common.backbones.resnet.ResNet18
    in_channels: 3
  decoder:
    _target_: torch.nn.Linear
    in_features: 512
    out_features: ${model.num_classes}
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 0.001
    weight_decay: 0.01
  scheduler:
    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
    _partial_: true
    T_max: ${trainer.max_epochs}

trainer:
  max_epochs: 50

data_preprocessing:
  _target_: autoware_ml.preprocessing.base.DataPreprocessing
  pipeline: []
Create a dataset-specific config:
# @package _global_
defaults:
  - /tasks/my_task/my_model/base
  - _self_

data_root: /workspace/data/my_dataset

datamodule:
  data_root: ${data_root}
  train_ann_file: ${data_root}/info/train.pkl
  val_ann_file: ${data_root}/info/val.pkl
Note
Some parameters are inherited from the default runtime config. Take a look at configs/defaults/default_runtime.yaml for more details.
Runtime preprocessing lives at the top level of the composed config and is attached to the model by the entrypoints.
Step 5: Add Transforms (Optional)
If your task needs custom per-sample transforms:
from typing import Any

import numpy as np

from autoware_ml.transforms.base import BaseTransform


class MyAugmentation(BaseTransform):
    def __init__(self, p: float = 0.5, intensity: float = 0.1):
        # BaseTransform handles the application probability through `p`.
        self.p = p
        self.intensity = intensity

    def transform(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        # Add Gaussian noise, then cast back so the sample keeps its original dtype.
        input_tensor = input_dict["input_tensor"]
        noise = np.random.randn(*input_tensor.shape) * self.intensity
        augmented = (input_tensor + noise).astype(input_tensor.dtype)
        return {"input_tensor": augmented}
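MyAugmentation leaves the probability handling to BaseTransform. The standalone NumPy sketch below shows one common form of such gating (apply with probability `p`, otherwise return the input unchanged) combined with the additive Gaussian noise above. `maybe_add_noise` is hypothetical, and BaseTransform's actual mechanism may differ; a seeded `np.random.Generator` keeps the example reproducible.

```python
import numpy as np


def maybe_add_noise(
    input_tensor: np.ndarray, p: float, intensity: float, rng: np.random.Generator
) -> np.ndarray:
    """Hypothetical helper: add Gaussian noise with probability p."""
    if rng.random() >= p:
        # Skipped: return the sample untouched.
        return input_tensor
    noise = rng.standard_normal(input_tensor.shape) * intensity
    # Keep the original dtype so float32 inputs stay float32.
    return (input_tensor + noise).astype(input_tensor.dtype)


rng = np.random.default_rng(seed=0)
sample = np.zeros((2, 3), dtype=np.float32)
never = maybe_add_noise(sample, p=0.0, intensity=0.1, rng=rng)
always = maybe_add_noise(sample, p=1.0, intensity=0.1, rng=rng)
```

Because `rng.random()` draws from [0, 1), `p=0.0` always skips and `p=1.0` always applies the noise.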
Add to config:
datamodule:
  train_transforms:
    pipeline:
      - _target_: autoware_ml.transforms.my_transforms.my_transform.MyAugmentation
        p: 0.5
        intensity: 0.1
Step 6: Add Runtime Data Preprocessing (Optional)
Runtime preprocessing runs on the target device after batch transfer and before the forward pass. It is configured at the top level and attached to the model by the entrypoint scripts.
If your task needs custom preprocessing:
from typing import Any


class MyPreprocessingLayer:
    def __init__(self, input_key: str = "input_tensor", scale: float = 1.0):
        self.input_key = input_key
        self.scale = scale

    def __call__(self, batch_inputs_dict: dict[str, Any]) -> dict[str, Any]:
        # Scale one batch entry and return only the key that changed.
        processed = batch_inputs_dict[self.input_key] * self.scale
        return {self.input_key: processed}
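Assuming the pipeline applies its layers in order and merges each returned dict back into the batch (an assumption; check DataPreprocessing for the actual semantics), the flow looks like the sketch below. `ScaleLayer` follows the same dict-in, partial-dict-out contract as MyPreprocessingLayer but works on plain lists, and `run_pipeline` is a hypothetical stand-in for DataPreprocessing.

```python
from collections.abc import Callable, Sequence
from typing import Any

Layer = Callable[[dict[str, Any]], dict[str, Any]]


class ScaleLayer:
    """Hypothetical layer with the same contract as MyPreprocessingLayer."""

    def __init__(self, input_key: str = "input_tensor", scale: float = 1.0) -> None:
        self.input_key = input_key
        self.scale = scale

    def __call__(self, batch_inputs_dict: dict[str, Any]) -> dict[str, Any]:
        return {self.input_key: [x * self.scale for x in batch_inputs_dict[self.input_key]]}


def run_pipeline(layers: Sequence[Layer], batch_inputs_dict: dict[str, Any]) -> dict[str, Any]:
    # Assumed semantics: apply layers in order, merging each partial result into a copy of the batch.
    batch = dict(batch_inputs_dict)
    for layer in layers:
        batch.update(layer(batch))
    return batch


batch = {"input_tensor": [1.0, 2.0], "gt_labels": [0, 1]}
result = run_pipeline([ScaleLayer(scale=2.0), ScaleLayer(scale=3.0)], batch)
```

Keys a layer does not return (here `gt_labels`) pass through untouched, which is what makes the partial-dict return convention convenient.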
Add to config:
data_preprocessing:
  _target_: autoware_ml.preprocessing.base.DataPreprocessing
  pipeline:
    - _target_: autoware_ml.preprocessing.my_preprocessing.my_preprocessing.MyPreprocessingLayer
      input_key: input_tensor
      scale: 1.0
Warning
Preprocessing layers must be callable objects that accept dict[str, Any] and return dict[str, Any].
Output-side shaping (logits -> probabilities, decoder scatter, voxel-to-point mapping, etc.) belongs
inside the model - in forward(), compute_metrics(), or predict_outputs().
Step 7: Train and Deploy
Config Naming Convention
Task configs should follow the pattern <task>/<model>/<variant>_<dataset>.
Use these rules when creating <variant>:
- include only choices that distinguish this config from other current or future variants, such as backbone, modality, voxel size, or range
- do not encode properties that are inherent to the model family
- normalize voxel sizes as voxel020, voxel005, voxel030
- encode ranges as human-readable suffixes such as 50m, 90m, 102m, 121m
- keep dataset names explicit and stable, for example nuscenes and t4dataset_j6gen2
Examples:
- segmentation3d/ptv3/voxel005_102m_nuscenes
- segmentation3d/ptv3/voxel012_122m_t4dataset_j6gen2
- my_task/my_model/my_variant_my_dataset
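The voxel and range rules can be captured in a small helper. Reading the examples, voxel020 appears to correspond to 0.20 m and voxel005 to 0.05 m, that is, the voxel size in centimeters zero-padded to three digits. That interpretation is an assumption, as are the helper names `format_voxel`, `format_range`, and `build_config_name`.

```python
def format_voxel(voxel_size_m: float) -> str:
    # Assumed rule: voxel size in centimeters, zero-padded to three digits (0.05 m -> voxel005).
    return f"voxel{round(voxel_size_m * 100):03d}"


def format_range(range_m: float) -> str:
    # Human-readable range suffix, e.g. 102 -> "102m".
    return f"{round(range_m)}m"


def build_config_name(task: str, model: str, dataset: str, *variant_parts: str) -> str:
    """Hypothetical helper mirroring the examples: <task>/<model>/<variant>_<dataset>."""
    variant = "_".join(variant_parts)
    return f"{task}/{model}/{variant}_{dataset}"


name = build_config_name("segmentation3d", "ptv3", "nuscenes", format_voxel(0.05), format_range(102))
```

Rounding before formatting guards against floating-point products such as `0.3 * 100` landing slightly above or below an integer.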
With the config in place, train the model and then deploy it:
# Train
autoware-ml train --config-name my_task/my_model/my_config
# Deploy
autoware-ml deploy \
--config-name my_task/my_model/my_config \
--weights mlruns/my_task/my_model/my_config/<run_id>/artifacts/checkpoints/last.ckpt
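The `--weights` path follows the `mlruns/<task>/<model>/<config>/<run_id>/artifacts/checkpoints/` layout shown in the deploy command. Assuming that layout, a small stdlib helper (hypothetical, not part of the `autoware-ml` CLI) can pick the most recently written `last.ckpt`, so you do not have to copy the run ID by hand.

```python
import os
import tempfile
from pathlib import Path


def latest_checkpoint(config_dir: Path, name: str = "last.ckpt") -> Path:
    """Hypothetical helper: newest <config_dir>/<run_id>/artifacts/checkpoints/<name>."""
    candidates = list(Path(config_dir).glob(f"*/artifacts/checkpoints/{name}"))
    if not candidates:
        raise FileNotFoundError(f"no {name} found under {config_dir}")
    # Pick the checkpoint with the most recent modification time.
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Example: build a fake mlruns tree with two runs and select the newer one.
root = Path(tempfile.mkdtemp()) / "mlruns" / "my_task" / "my_model" / "my_config"
for run_id, mtime in (("run_a", 1_000), ("run_b", 2_000)):
    ckpt = root / run_id / "artifacts" / "checkpoints" / "last.ckpt"
    ckpt.parent.mkdir(parents=True)
    ckpt.touch()
    os.utime(ckpt, (mtime, mtime))

weights = latest_checkpoint(root)
```

Modification time is a pragmatic choice here; if several runs are written concurrently, selecting by run ID or by a metric logged in MLflow may be more reliable.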
Common Patterns
Multiple Inputs
def forward(self, image: torch.Tensor, lidar: torch.Tensor) -> torch.Tensor:
    img_features = self.image_encoder(image)
    lidar_features = self.lidar_encoder(lidar)
    # Concatenate along dim 1 (the feature/channel dimension) before the shared head.
    fused = torch.cat([img_features, lidar_features], dim=1)
    return self.head(fused)
The batch dictionary must contain image and lidar keys, because forward() parameters are matched to batch keys by name.
Multiple Outputs
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    features = self.backbone(x)
    boxes = self.box_head(features)
    scores = self.score_head(features)
    return boxes, scores

def compute_metrics(
    self,
    batch_inputs_dict: Mapping[str, Any],
    outputs: tuple[torch.Tensor, torch.Tensor],
) -> dict[str, torch.Tensor]:
    # outputs is exactly the tuple returned by forward().
    boxes, scores = outputs
    gt_boxes = batch_inputs_dict["gt_boxes"]
    gt_scores = batch_inputs_dict["gt_scores"]
    box_loss = self.box_loss(boxes, gt_boxes)
    score_loss = self.score_loss(scores, gt_scores)
    # 'loss' drives backpropagation; the individual terms are returned for logging.
    return {"loss": box_loss + score_loss, "box_loss": box_loss, "score_loss": score_loss}
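As the number of heads grows, summing terms by hand becomes error-prone. One option, a design choice of this sketch rather than something BaseModel requires, is to total every entry whose key ends in `_loss` into the required `loss` key, with optional per-term weights. Plain floats stand in for tensors here, and `combine_losses` is hypothetical.

```python
from collections.abc import Mapping
from typing import Optional


def combine_losses(
    terms: Mapping[str, float], weights: Optional[Mapping[str, float]] = None
) -> dict[str, float]:
    """Hypothetical helper: total all '*_loss' terms into 'loss', keeping each term for logging."""
    weights = weights or {}
    loss_terms = {key: value for key, value in terms.items() if key.endswith("_loss")}
    if not loss_terms:
        raise ValueError("expected at least one '*_loss' term")
    # Terms without an explicit weight default to 1.0.
    total = sum(weights.get(key, 1.0) * value for key, value in loss_terms.items())
    # 'loss' is written last so it cannot be overwritten by an input entry.
    return {**terms, "loss": total}


metrics = combine_losses(
    {"box_loss": 0.5, "score_loss": 0.25, "accuracy": 0.9}, weights={"score_loss": 2.0}
)
```

The same code works unchanged with scalar tensors, since `sum()` starts from `0` and tensor addition and scalar multiplication return tensors.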