Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@

### 训练环境支持

Windows/Linux
Windows/Linux/Macos(AppleSilicon)

Macos仅支持cpu训练

## 1、深度学习必备环境配置(非仅本项目要求,而是所有深度学习项目要求,cpu训练除外)

### 开始本教程前请先前往[pytorch](https://pytorch.org/get-started/locally/) 官网查看自己系统与硬件支持的pytorch版本,注意30系列之前的N卡,如2080Ti等请选择cuda11以下的版本(例:CUDA 10.2),如果为30系N卡,仅支持CUDA 11版本,请选择CUDA 11以上版本(例:CUDA 11.3),然后根据选择的条件显示的pytorch安装命令完成pytorch安装,由于pytorch的版本更新速度导致很多pypi源仅缓存了cpu版本,CUDA版本需要自己在官网安装。

### 安装AppleSilicon

需要在conda下安装pytorch

`conda install pytorch torchvision torchaudio -c pytorch-nightly`

### 安装CUDA和CUDNN

根据自己显卡型号与系统选择
Expand Down Expand Up @@ -55,7 +60,7 @@ project_name 为项目名称,尽量不要以特殊符号命名
项目支持两种形式的数据

### A、从文件名导入

图片均在同一个文件夹中,且命名为类似,其中/root/images_set为图片所在目录,可以为任意目录地址

```
Expand All @@ -65,7 +70,7 @@ project_name 为项目名称,尽量不要以特殊符号命名
|---- 酱闷肘子_随机hash值.jpg

```

如下图所示

![image](https://cdn.wenanzhe.com/img/mkGu_000001d00f140741741ed9916240d8d5.jpg)
Expand All @@ -79,7 +84,7 @@ project_name 为项目名称,尽量不要以特殊符号命名
### B、从文件中导入

受限于可能样本组织形式或者特殊字符,本项目支持从txt文档中导入数据,数据集目录必须包含有`labels.txt`文件和`images`文件夹, 其中/root/images_set为图片所在目录,可以为任意目录地址

`labels.txt`文件中包含了所有在`/root/images_set/images`目录下基于`/root/images_set/images`的图片相对路径,`/root/images_set/images`下可以有目录。

#### 当然,在这种模式下,图片的文件名随意,可以有具体label也可以没有,因为咱们不从这里获取图片的label
Expand All @@ -100,7 +105,7 @@ project_name 为项目名称,尽量不要以特殊符号命名
随机hash值.jpg\tabcd
随机hash值.jpg\tsdae
酱闷肘子_随机hash值.jpg\t酱闷肘子
```
```
b.images下有目录的形式
```
/root/images_set/
Expand All @@ -115,7 +120,7 @@ project_name 为项目名称,尽量不要以特殊符号命名
aaaa/随机hash值.jpg\tsdae
酱闷肘子_随机hash值.jpg\t酱闷肘子

```
```

### 为了新手更好的理解本部分的内容,本项目也提供了两套基础数据集提供测试

Expand Down
3 changes: 3 additions & 0 deletions nets/__init__.py

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里直接改成用 if pytoch.cuda.is_available()torch.backends.mps.is_available() 判断呗?毕竟黑苹果支持 cuda 但不支持 mps 不是么?

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .backbone import *
import torch
import sys

torch.set_num_threads(1)

Expand Down Expand Up @@ -191,6 +192,8 @@ def save_model(self, path, net):
def get_device(gpu_id):
if gpu_id == -1:
device = torch.device('cpu'.format(str(gpu_id)))
elif sys.platform == 'darwin':
device = torch.device('mps')
else:
device = torch.device('cuda:{}'.format(str(gpu_id)))
return device
Expand Down