0 前言
經常會有同學問我在學習了python和torch之後,該繼續學習啥,我都是説,不用再去接着學習理論,直接上CNN的代碼開始跑就可以,比如貓狗分類的代碼,或者花朵分類的模型。
但是網上的代碼怎麼説呢......能夠直接拿來用的很少,經常需要自己調試,這次我給大家都調試好了,大家直接下載就可以了,點擊run就能跑!這個代碼包含訓練預測,文末可以直接獲取。
代碼主要用來做花朵分類,類別有雛菊、蒲公英、玫瑰花、向日葵和鬱金香:
1 代碼概況
1.1 文件目錄
下載壓縮包後,我們來看一下文件目錄,主要包含相關代碼和數據集,其中有些文件是在代碼運行的過程中生成的。
flower_data本來並沒有被劃分為訓練集和驗證集,需要採用split_data.py進行劃分,在這裏我已經為大家處理好了,因此split_data.py大家不用管,感興趣的可以研究一下。
1.2 使用方法:
1、直接在相關的IDE中,比如Pycharm,點擊train.py,執行run。
2、如果在服務器中,就執行python train.py。
1.3 訓練結果
通過Loss曲線圖、訓練過程曲線我們可以觀察網絡的收斂過程是否正常。
當然啦,最重要的還是我們的train.py,搭建網絡以及訓練都在這裏面。
訓練完成後,我們可以用自己的圖片predict.png,直接運行predict.py來進行預測:
2. 代碼講解
代碼主要分為6部分:
- 導入必要的庫
- 數據預處理和加載
- 模型搭建
- 網絡訓練
- 曲線繪製
- 模型預測
2.1 導入必要的庫
在程序一開始的時候,我們需要導入必要的庫,比如torch、transforms, datasets, utils等等,被用來處理數據和搭建網絡模型。
import os
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet
2.2 數據預處理和加載
由於圖片的大小和數據質量都是不一樣的,因此在網絡訓練之前我們需要對數據進行預處理,預處理的操作包括:
- 隨機選擇圖像區域裁剪至224×224大小;
- 數據增強:以50%的概率翻轉圖像;
- 將圖像轉化為tensor,並且將像素值從0-255縮放到0-1;
- 對圖像進行標準化;
data_transform = {
"train": transforms.Compose([
transforms.RandomResizedCrop(224), # 隨機裁剪到224x224
transforms.RandomHorizontalFlip(), # 隨機水平翻轉(數據增強)
transforms.ToTensor(), # 轉換為Tensor(0-1範圍)
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) # 標準化到[-1,1]
]),
"val": transforms.Compose([
transforms.Resize((224, 224)), # 直接resize到224x224
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])
}
2.3 模型搭建
模型搭建的代碼主要在model.py中,然後我們直接進行調用,並且傳入最終需要分類的類別為5個。
我們來看AlexNet的網絡模型結構,我們的輸入大小是[3,224,224]。
- 經過第一個卷積層,輸出通道為48,卷積核大小為11×11,步長4,填充2;
- 然後經過ReLu激活函數;
- 再經過第一個池化層,池化核大小3×3,步長為2;
- 第二個卷積層和第一個卷積層類似,但是第三個卷積層和第四個卷積層是沒有池化層的;
- 第五個卷積層又是卷積——>激活函數——>池化的結構;
- 最後通過3個全連接層映射打到類別為5,並且在倒數第二個全連接層之前加入了Dropout隨機失活處理。
class AlexNet(nn.Module):
def __init__(self, num_classes=1000, init_weights=False):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), # output[48, 27, 27]
nn.Conv2d(48, 128, kernel_size=5, padding=2),#步長默認為1,當步長為1時不用設置# output[128, 27, 27]
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), # output[128, 13, 13]
nn.Conv2d(128, 192, kernel_size=3, padding=1), # output[192, 13, 13]
nn.ReLU(inplace=True),
nn.Conv2d(192, 192, kernel_size=3, padding=1), # output[192, 13, 13]
nn.ReLU(inplace=True),
nn.Conv2d(192, 128, kernel_size=3, padding=1), # output[128, 13, 13]
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), # output[128, 6, 6]
)
self.classifier = nn.Sequential(
nn.Dropout(p=0.5),#隨機將一半的節點失活,默認為0.5
nn.Linear(128 * 6 * 6, 2048),#將特徵矩陣展平,128*6*6最後輸出的長*寬*高,2048為全連接層節點個數
nn.ReLU(inplace=True),#Relu激活函數
nn.Dropout(p=0.5),
nn.Linear(2048, 2048),#全連接層2的輸入為全連接層1的輸出2048,全連接層2的節點個數2048
nn.ReLU(inplace=True),
nn.Linear(2048, num_classes),#全連接層3的輸入為全連接層2的輸出2048
)
if init_weights:
self._initialize_weights()
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, start_dim=1)
x = self.classifier(x)
return x#最後的輸出為圖片類別
2.4 網絡訓練
網絡訓練的邏輯是直接調用模型的train方法,該方法繼承自 nn.Module 的類,然後我們將數據輸入至模型中,進行前向傳播,模型在完成前向傳播後有了輸出,再與標籤對比,計算損失函數,再反向傳播。
我們通過optimizer來更新每一個參數。計算出損失值,並且保存,然後再通過net.eval()進行驗證。
for epoch in range(epochs):
net.train() # 使用net.train()方法,該方法中有dropout
running_loss = 0.0 # 使用running_loss方法統計訓練過程中的平均損失
train_bar = tqdm(train_loader)
for step, data in enumerate(train_bar): # 遍歷數據集
images, labels = data # 將數據分為圖像標籤
optimizer.zero_grad() # 清空之前的梯度信息
outputs = net(images.to(device)) # 通過正向傳播的到輸出
loss = loss_function(outputs, labels.to(device)) # 指定設備gpu或者cpu,通過Loss_function函數計算預測值與真實值之間的差距
loss.backward() # 將損失反向傳播到每一個節點
optimizer.step() # 通過optimizer更新每一個參數
running_loss += loss.item() # 累加損失
rate = (step + 1) / len(train_loader)
a = "*" * int(rate * 50)
b = "." * int((1 - rate) * 50)
print("\rtrain loss:{:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
avg_train_loss = running_loss / train_steps
train_losses.append(avg_train_loss)
net.eval() # 預測過程中使用net.eval()函數,該函數會關閉掉dropout
acc = 0.0 # accumulate accurate number / epoch
2.5 曲線繪製
我們再將剛才保存下來的loss值和正確率的值通過plt庫畫出來,並且保存為圖片。
2.6 模型預測
在代碼訓練完成後,我們會將模型保存,此時可以通過predict.py進行預測,我們隨便在網上下載一張鬱金香的圖片運行。
得到結果: