成功解决
RuntimeError: Attempting to deserialize object on CUDA device 1 but torch.cuda.device_count() is 1.

报错内容

程序在这一步报错
checkpoint = torch.load(‘model5_4.pt’)
在这里插入图片描述
以上问题描述是说未获取到当前环境下的 cuda,因为我的模型是在服务器上跑的,下载到本地后环境不同。

解决方法

若你当前在只有 CPU 环境下运行的话,需要加上map_location=torch.device(‘cpu’)。
若你当前在有 CUDA环境下运行的话,需要加上map_location=torch.device(‘cuda’)。
checkpoint = torch.load(‘model5_4.pt’)
即换成:

checkpoint = torch.load('model5_4.pt',map_location='cuda')

运行成功!不报错了!

Logo

欢迎来到FlagOS开发社区,这里是一个汇聚了AI开发者、数据科学家、机器学习爱好者以及业界专家的活力平台。我们致力于成为业内领先的Triton技术交流与应用分享的殿堂,为推动人工智能技术的普及与深化应用贡献力量。

更多推荐