CenterPoint-KITTI training error: TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu()
The full error message is:
Traceback (most recent call last):
File "train.py", line 198, in <module>
main()
File "train.py", line 153, in main
train_model(
File "/home/wang/CenterPoint-KITTI/tools/train_utils/train_utils.py", line 86, in train_model
accumulated_iter = train_one_epoch(
File "/home/wang/CenterPoint-KITTI/tools/train_utils/train_utils.py", line 38, in train_one_epoch
loss, tb_dict, disp_dict = model_func(model, batch)
File "/home/wang/CenterPoint-KITTI/pcdet/models/__init__.py", line 30, in model_func
ret_dict, tb_dict, disp_dict = model(batch_dict)
File "/home/wang/miniconda3/envs/cen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wang/CenterPoint-KITTI/pcdet/models/detectors/centerpoint.py", line 11, in forward
batch_dict = cur_module(batch_dict)
File "/home/wang/miniconda3/envs/cen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wang/CenterPoint-KITTI/pcdet/models/dense_heads/centerpoint_head_single.py", line 77, in forward
targets_dict = self.assign_targets(
File "/home/wang/CenterPoint-KITTI/pcdet/models/dense_heads/centerpoint_head_single.py", line 142, in assign_targets
heatmaps = np.array(heatmaps).transpose(1, 0).tolist()
File "/home/wang/miniconda3/envs/cen/lib/python3.8/site-packages/torch/_tensor.py", line 678, in __array__
return self.numpy()
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
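Solution:
Cause: np.array(heatmaps) tries to convert a list of CUDA tensors into a NumPy array. This calls Tensor.__array__, which calls Tensor.numpy(), and NumPy conversion only works for tensors in host (CPU) memory. The fix is to transpose the nested lists in plain Python with zip instead of NumPy, so the tensors never need to leave the GPU.
Change into the dense_heads directory of your CenterPoint-KITTI checkout: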
cd /your_dir/CenterPoint-KITTI/pcdet/models/dense_heads
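Open centerpoint_head_single.py in that directory and go to the end of the assign_targets method (line 142 in the traceback above). Replace the code that transposes the target lists with NumPy (highlighted in red in the original screenshot; it starts with heatmaps = np.array(heatmaps).transpose(1, 0).tolist()) with the following: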
# swap the two nesting levels, as .transpose(1, 0) did, without going through NumPy
heatmaps = list(map(list, zip(*heatmaps)))
anno_boxes = list(map(list, zip(*anno_boxes)))
inds = list(map(list, zip(*inds)))
masks = list(map(list, zip(*masks)))
# stack each inner list of tensors into a single batched tensor (stays on the GPU)
heatmaps = [torch.stack(hms) for hms in heatmaps]
anno_boxes = [torch.stack(boxs) for boxs in anno_boxes]
inds = [torch.stack(idx) for idx in inds]
masks = [torch.stack(msk) for msk in masks]
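Why the zip-based replacement avoids the error can be checked without a GPU. The sketch below is a minimal illustration under stated assumptions: GpuOnlyTensor is a hypothetical stand-in that, like a cuda:0 tensor, raises TypeError when converted to a NumPy array, and transpose_nested is a hypothetical helper equivalent to list(map(list, zip(*rows))). The 2-sample by 3-target layout is made up for the example. Because zip only iterates over the lists, __array__ is never called.

```python
def transpose_nested(rows):
    """Swap the two levels of a list of equal-length lists, like
    np.array(rows).transpose(1, 0).tolist(), but without NumPy."""
    return [list(col) for col in zip(*rows)]


class GpuOnlyTensor:
    """Hypothetical stand-in for a CUDA tensor: converting it to a NumPy
    array raises TypeError, mirroring the error in the traceback."""

    def __init__(self, name):
        self.name = name

    def __array__(self, dtype=None):
        raise TypeError("can't convert cuda:0 device type tensor to numpy")

    def __repr__(self):
        return f"GpuOnlyTensor({self.name!r})"


# Hypothetical batch of 2 samples, each holding 3 target tensors.
per_sample = [
    [GpuOnlyTensor(f"hm_b{b}_t{t}") for t in range(3)]
    for b in range(2)
]

# zip(*...) only iterates over the lists, so __array__ is never called.
per_group = transpose_nested(per_sample)

for group in per_group:
    print([t.name for t in group])
# ['hm_b0_t0', 'hm_b1_t0']
# ['hm_b0_t1', 'hm_b1_t1']
# ['hm_b0_t2', 'hm_b1_t2']
```

In the real head, each inner list produced this way holds one tensor per sample, which is exactly what torch.stack expects in the replacement code.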
[Screenshot: the modified section of centerpoint_head_single.py after the replacement]