The full error message is as follows:

Traceback (most recent call last):
  File "train.py", line 198, in <module>
    main()
  File "train.py", line 153, in main
    train_model(
  File "/home/wang/CenterPoint-KITTI/tools/train_utils/train_utils.py", line 86, in train_model
    accumulated_iter = train_one_epoch(
  File "/home/wang/CenterPoint-KITTI/tools/train_utils/train_utils.py", line 38, in train_one_epoch
    loss, tb_dict, disp_dict = model_func(model, batch)
  File "/home/wang/CenterPoint-KITTI/pcdet/models/__init__.py", line 30, in model_func
    ret_dict, tb_dict, disp_dict = model(batch_dict)
  File "/home/wang/miniconda3/envs/cen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wang/CenterPoint-KITTI/pcdet/models/detectors/centerpoint.py", line 11, in forward
    batch_dict = cur_module(batch_dict)
  File "/home/wang/miniconda3/envs/cen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wang/CenterPoint-KITTI/pcdet/models/dense_heads/centerpoint_head_single.py", line 77, in forward
    targets_dict = self.assign_targets(
  File "/home/wang/CenterPoint-KITTI/pcdet/models/dense_heads/centerpoint_head_single.py", line 142, in assign_targets
    heatmaps = np.array(heatmaps).transpose(1, 0).tolist()
  File "/home/wang/miniconda3/envs/cen/lib/python3.8/site-packages/torch/_tensor.py", line 678, in __array__
    return self.numpy()
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Solution:

cd /your_dir/CenterPoint-KITTI/pcdet/models/dense_heads

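Open centerpoint_head_single.py in that directory and go to the assign_targets method, around line 142 (the line reported in the traceback):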

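Replace the block highlighted in the red box (the heatmaps line from the traceback and the matching lines for the other target lists) with the code below. The error occurs because calling np.array() on a list of CUDA tensors invokes Tensor.__array__, which calls .numpy() and fails, as the last two frames of the traceback show. The replacement transposes the nested lists in pure Python with zip(*...), so the tensors are never converted and stay on the GPU.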

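        # Swap the two outer list levels (the job np.array(...).transpose(1, 0) did),
        # in pure Python: zip() only regroups the tensors, so nothing is copied to the CPU.
        heatmaps = list(map(list, zip(*heatmaps)))
        anno_boxes = list(map(list, zip(*anno_boxes)))
        inds = list(map(list, zip(*inds)))
        masks = list(map(list, zip(*masks)))

        # Stack each group of per-sample tensors into one batched tensor.
        heatmaps = [torch.stack(hms) for hms in heatmaps]
        anno_boxes = [torch.stack(boxs) for boxs in anno_boxes]
        inds = [torch.stack(idx) for idx in inds]
        masks = [torch.stack(msk) for msk in masks]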
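To see what the list(map(list, zip(*...))) lines do, here is a small pure-Python sketch. It assumes, as the transpose(1, 0) call in the original code suggests, that each list is nested as [sample][detection head]. The helper name transpose_nested and the string placeholders are hypothetical; in the real head the elements are CUDA tensors.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def transpose_nested(rows: Sequence[Sequence[T]]) -> List[List[T]]:
    """Swap the two outer levels of a nested list, e.g. [sample][task] -> [task][sample].

    zip(*rows) groups the i-th element of every inner list together. The
    elements themselves are only re-referenced, never converted or copied,
    which is why this works for CUDA tensors where np.array() does not.
    """
    return [list(group) for group in zip(*rows)]


# Hypothetical data: 2 samples x 3 heads; strings stand in for heatmap tensors.
per_sample = [
    ["s0_task0", "s0_task1", "s0_task2"],
    ["s1_task0", "s1_task1", "s1_task2"],
]
per_task = transpose_nested(per_sample)
print(per_task)
# [['s0_task0', 's1_task0'], ['s0_task1', 's1_task1'], ['s0_task2', 's1_task2']]
```

In the patched assign_targets, each inner list is then passed to torch.stack, which yields one tensor per head with the batch as its leading dimension. Note that zip stops at the shortest inner list, so this assumes every sample contributes the same number of entries; a ragged list would be silently truncated rather than raising an error.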

After the replacement, the code looks like this:
