Jie Cheng1
Xiaodong Mei1
Ming Liu1,2
HKUST1 HKUST(GZ)1,2
- A neat yet effective MAE-based pre-training scheme for motion forecasting.
- A simple forecasting model (essentially pure transformer encoders) with relatively good performance.
- This repo also provides an exemplary multi-agent motion forecasting baseline on the Argoverse 2 dataset.
- Getting Started
- Setup Environment
- Preprocess
- Training
- Evaluation
- Results and checkpoints
- Acknowledgements
- Citation
1. Clone this repository:
git clone https://github.com/jchengai/forecast-mae.git
cd forecast-mae
2. Set up the conda environment:
conda create -n forecast_mae python=3.8
conda activate forecast_mae
sh ./scripts/setup.sh
3. Set up the Argoverse 2 Motion Forecasting Dataset. The expected data structure is:
data_root
├── train
│ ├── 0000b0f9-99f9-4a1f-a231-5be9e4c523f7
│ ├── 0000b6ab-e100-4f6b-aee8-b520b57c0530
│ ├── ...
├── val
│ ├── 00010486-9a07-48ae-b493-cf4545855937
│ ├── 00062a32-8d6d-4449-9948-6fedac67bfcd
│ ├── ...
├── test
│ ├── 0000b329-f890-4c2b-93f2-7e2413d4ca5b
│ ├── 0008c251-e9b0-4708-b762-b15cb6effc27
│ ├── ...
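Before preprocessing, it can help to confirm that the download matches the layout above. The following is a minimal sketch, not part of this repository: the helper name `summarize_data_root` is hypothetical, and it assumes each scenario is stored as one sub-directory (named by its scenario ID) under `train`, `val`, and `test`.

```python
from pathlib import Path

# Split names taken from the directory tree above.
EXPECTED_SPLITS = ("train", "val", "test")


def summarize_data_root(data_root):
    """Return {split: number of scenario folders}; raise if a split is missing.

    Hypothetical helper for illustration only.
    """
    root = Path(data_root)
    counts = {}
    for split in EXPECTED_SPLITS:
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # Assumption: one sub-directory per scenario, named by its ID.
        counts[split] = sum(1 for p in split_dir.iterdir() if p.is_dir())
    return counts
```

Calling `summarize_data_root("/path/to/data_root")` returns a dictionary of per-split scenario counts, or raises `FileNotFoundError` if one of the three split folders is absent.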
(Recommended) By default, preprocessing uses Ray with 16 CPU cores and takes about 30 minutes to finish.
python3 preprocess.py --data_root=/path/to/data_root -p
python3 preprocess.py --data_root=/path/to/data_root -m -p
Alternatively, you can disable parallel preprocessing by removing `-p`.
- For single-card training, remove `gpus=4` from the following commands. `batch_size` refers to the batch size of each GPU.
- If you use WandB, you can enable WandB logging by adding the option `wandb=online`.
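Because `batch_size` is per GPU, the number of samples per optimizer step changes when you drop `gpus=4`. The arithmetic can be sketched as follows, assuming plain data-parallel training without gradient accumulation (an assumption, since the training setup is not described here); the function name is hypothetical.

```python
def effective_batch_size(per_gpu_batch_size, num_gpus=1):
    """Total samples per optimizer step under data-parallel training.

    Assumes no gradient accumulation; illustrative helper only.
    """
    if per_gpu_batch_size < 1 or num_gpus < 1:
        raise ValueError("batch size and GPU count must be positive")
    return per_gpu_batch_size * num_gpus


print(effective_batch_size(32, num_gpus=4))  # 128, as in the commands below
print(effective_batch_size(32))              # 32 on a single card
```

Under this assumption, single-card training with `batch_size=32` sees a quarter of the samples per step compared with the 4-GPU commands, so you may want to adjust `batch_size` accordingly.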
phase 1 - pre-training:
python3 train.py data_root=/path/to/data_root model=model_mae gpus=4 batch_size=32
phase 2 - fine-tuning:
(Note that the quotes in `'pretrained_weights="/path/to/pretrain_ckpt"'` are necessary.)
python3 train.py data_root=/path/to/data_root model=model_forecast gpus=4 batch_size=32 monitor=val_minFDE 'pretrained_weights="/path/to/pretrain_ckpt"'
python3 train.py data_root=/path/to/data_root model=model_forecast gpus=4 batch_size=32 monitor=val_minFDE
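The quoting note above is easier to see by checking what the shell actually hands to `train.py`. This sketch uses Python's standard `shlex` module, which follows POSIX shell word-splitting rules, to compare the two forms. Why the script needs the inner double quotes is not stated here; presumably its argument parser treats the quoted value as a literal string.

```python
import shlex

# With outer single quotes: the inner double quotes survive.
quoted = (
    "python3 train.py model=model_forecast "
    "'pretrained_weights=\"/path/to/pretrain_ckpt\"'"
)
# Without outer single quotes: the shell strips the double quotes.
unquoted = (
    "python3 train.py model=model_forecast "
    'pretrained_weights="/path/to/pretrain_ckpt"'
)


def last_argument(command):
    """Return the final argument as a POSIX shell would pass it on."""
    return shlex.split(command)[-1]


print(last_argument(quoted))    # pretrained_weights="/path/to/pretrain_ckpt"
print(last_argument(unquoted))  # pretrained_weights=/path/to/pretrain_ckpt
```

Only the single-quoted form delivers the double quotes to the program, which is why the commands wrap the whole `key="value"` pair in single quotes.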
We also provide a simple multi-agent motion forecasting baseline using Forecast-MAE's backbone model.
python3 train.py data_root=/path/to/data_root model=model_forecast_multiagent gpus=4 batch_size=32 monitor=val_AvgMinFDE
Evaluate the single-agent model on the validation set
python3 eval.py model=model_forecast data_root=/path/to/data_root batch_size=64 'checkpoint="/path/to/checkpoint"'
Generate a submission file for the AV2 single-agent motion forecasting benchmark
python3 eval.py model=model_forecast data_root=/path/to/data_root batch_size=64 'checkpoint="/path/to/checkpoint"' test=true
Evaluate the multi-agent model on the validation set
python3 eval.py model=model_forecast_multiagent data_root=/path/to/data_root batch_size=64 'checkpoint="/path/to/checkpoint"'
Generate a submission file for the AV2 multi-agent motion forecasting benchmark
python3 eval.py model=model_forecast_multiagent data_root=/path/to/data_root batch_size=64 'checkpoint="/path/to/checkpoint"' test=true
MAE-pretrained_weights: download.
A visualization notebook of the MAE reconstruction results can be found here.
For this repository, the expected performance on the Argoverse 2 validation set is:
Models | minADE1 | minFDE1 | minADE6 | minFDE6 | MR6 |
---|---|---|---|---|---|
Forecast-MAE (scratch) | 1.802 | 4.529 | 0.7214 | 1.430 | 0.187 |
Forecast-MAE (fine-tune) | 1.744 | 4.376 | 0.7117 | 1.408 | 0.178 |
Models | AvgMinADE6 | AvgMinFDE6 | ActorMR6 |
---|---|---|---|
Multiagent-Baseline | 0.717 | 1.64 | 0.194 |
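For readers unfamiliar with the column names, the single-agent metrics can be sketched in a few lines. This is an illustrative version, not the evaluation code of this repository. It assumes K predicted trajectories per agent given as lists of `(x, y)` points in meters, takes the minimum independently for ADE and FDE (one common convention; some benchmark implementations instead pick the trajectory with the lowest endpoint error), and uses a 2.0 m miss threshold, which is the usual Argoverse setting but is an assumption here.

```python
import math


def displacement_errors(pred, gt):
    """Per-timestep Euclidean distance between two (x, y) trajectories."""
    return [math.hypot(px - gx, py - gy) for (px, py), (gx, gy) in zip(pred, gt)]


def min_ade_fde(predictions, gt):
    """Return (minADE, minFDE) over the K predicted trajectories."""
    ades, fdes = [], []
    for pred in predictions:
        errors = displacement_errors(pred, gt)
        ades.append(sum(errors) / len(errors))  # average displacement error
        fdes.append(errors[-1])                 # final (endpoint) error
    return min(ades), min(fdes)


def miss_rate(min_fdes, threshold=2.0):
    """Fraction of agents whose best endpoint error exceeds the threshold."""
    return sum(fde > threshold for fde in min_fdes) / len(min_fdes)


gt = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
predictions = [
    [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)],  # drifts away from the ground truth
    [(0.0, 0.0), (1.0, 0.0), (2.0, 0.5)],  # close to the ground truth
]
ade, fde = min_ade_fde(predictions, gt)
print(round(ade, 4), fde)  # 0.1667 0.5
```

In the tables, the subscript (1 or 6) is the number of predicted trajectories K considered per agent, so minADE1 scores a single prediction while minADE6 scores the best of six.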
You can download the checkpoints with the corresponding link.
This repo benefits from MAE, Point-BERT, Point-MAE, NATTEN and HiVT. Thanks for their great work.
If you find this repository useful, please consider citing our work:
@article{cheng2023forecast,
title={{Forecast-MAE}: Self-supervised Pre-training for Motion Forecasting with Masked Autoencoders},
author={Cheng, Jie and Mei, Xiaodong and Liu, Ming},
journal={Proceedings of the IEEE/CVF International Conference on Computer Vision},
year={2023}
}