Deep Learning in Practice: Brain Tumor MRI Image Recognition with Vision Transformer

Author: @艾派森 (Python learner)


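Contents

1. Project Background

2. Dataset Description

3. Tools

4. Experimental Procedure

4.1 Loading the Data

4.2 Data Preprocessing

4.3 Building the Model

4.4 Training the Model

4.5 Model Evaluation

4.6 Visualizing the Results

Source Code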


1. Project Background

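        In recent years, with the rapid growth of medical imaging data and the fast progress of artificial intelligence, deep learning has shown great potential for computer-aided medical diagnosis. Brain tumors are a common type of intracranial lesion, and detecting and classifying them early and accurately is critical for clinical treatment. Traditional diagnosis relies mainly on radiologists reading magnetic resonance (MRI) images by hand, which is time-consuming, labor-intensive, and easily influenced by individual experience. Convolutional neural networks (CNNs) have made notable progress in medical image analysis, but their local receptive fields remain a limitation for complex lesions that require global context. Against this background, the Vision Transformer (ViT), proposed by Google Research in 2020, uses self-attention to capture global image features and offers a new approach to medical image analysis. This project explores how well a ViT model suits brain tumor detection on MRI images: it builds an end-to-end pipeline that automatically determines whether an MRI scan contains a tumor (such as a glioma or meningioma), with the aim of making diagnosis more efficient and accurate.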
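To make the ViT idea above concrete, here is a minimal NumPy sketch (not the timm implementation) of its first two steps: cutting a 224×224 image into 16×16 patches, as vit_base_patch16_224 does, and letting every token attend to every other token with scaled dot-product attention. The helper names `patchify` and `self_attention`, the toy embedding width of 64, the random weights and the single attention head are illustrative assumptions; the real model uses 768-dimensional embeddings, learned weights, positional embeddings and 12 heads per layer.

```python
import numpy as np

IMG_SIZE, PATCH, DIM = 224, 16, 64  # DIM is a toy width (assumption); ViT-B/16 uses 768


def patchify(img: np.ndarray, patch: int) -> np.ndarray:
    """Split an (H, W, C) image into a sequence of flattened, non-overlapping patches."""
    h, w, c = img.shape
    gh, gw = h // patch, w // patch
    x = img[: gh * patch, : gw * patch].reshape(gh, patch, gw, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)                  # (gh, gw, patch, patch, c)
    return x.reshape(gh * gw, patch * patch * c)    # (num_patches, patch*patch*c)


def self_attention(tokens, wq, wk, wv):
    """Single-head scaled dot-product attention: every token attends to every token."""
    q, k, v = tokens @ wq, tokens @ wk, tokens @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability before exp
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax: each row sums to 1
    return weights @ v, weights


rng = np.random.default_rng(0)
image = rng.random((IMG_SIZE, IMG_SIZE, 3))
patches = patchify(image, PATCH)                    # 14 x 14 = 196 patches of 768 values
proj = rng.normal(scale=0.02, size=(patches.shape[1], DIM))
cls = np.zeros((1, DIM))                            # stand-in for the learned [CLS] token
tokens = np.vstack([cls, patches @ proj])           # 1 + 196 = 197 tokens
w = [rng.normal(scale=0.1, size=(DIM, DIM)) for _ in range(3)]
out, attn = self_attention(tokens, *w)
print(patches.shape, tokens.shape, attn.shape)
```

Because the attention matrix is 197×197, every patch can draw on information from anywhere in the image within a single layer, which is the global context contrasted with CNNs in the paragraph above.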

2. Dataset Description

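Brain tumors are a serious health challenge worldwide, and early detection is essential for effective treatment. MRI (magnetic resonance imaging) is one of the most widely used techniques for diagnosing brain tumors because it provides detailed images of the brain's soft tissue.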

The dataset contains MRI images divided into two classes:

  1. yes: images with a brain tumor
  2. no: images without a brain tumor

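The dataset is designed for binary classification and can be used to train and evaluate machine learning or deep learning models, especially convolutional neural networks (CNNs).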

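The images were collected from public medical imaging repositories and open research publications and may be used for education and research. They are already organized into separate folders, one per class, so beginners and researchers can use them directly for model training.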

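The main purpose of the dataset is to give students, researchers and practitioners who want to apply computer vision to medical imaging an accessible, ready-to-use resource. Possible projects range from simple CNN classification to more advanced applications such as transfer learning and model interpretability (for example, Grad-CAM visualization).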

3. Tools

Python version: 3.9

Code editor: Jupyter Notebook

4. Experimental Procedure

4.1 Loading the Data

First, import the third-party libraries used in this experiment and load the dataset.

import os
import cv2
import warnings
import timm
import torch
import copy
import types
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import albumentations as A
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px

from tqdm import tqdm
from PIL import Image
from collections import defaultdict
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

sns.set_style("darkgrid")
warnings.filterwarnings('ignore')
DATA_DIR = './brain_tumor_dataset'
classes = [folder for folder in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, folder))]

Inspect the data

data = []
for cls in classes:
    class_path = os.path.join(DATA_DIR, cls)
    for filename in os.listdir(class_path):
        filepath = os.path.join(class_path, filename)
        data.append({'filepath': filepath, 'label': cls})

df = pd.DataFrame(data)
df = df.sample(frac=1).reset_index(drop=True)

display(df.head())
print("\n\nDataFrame Information:")
df.info()


View sample tumor and non-tumor images

plt.figure(figsize=(15, 8))
n_samples = 5
yes_samples = df[df['label'] == 'yes'].sample(n_samples)

for i, row in enumerate(yes_samples.iterrows()):
    filepath = row[1]['filepath']
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.subplot(2, n_samples, i + 1)
    plt.imshow(img)
    plt.title(f"Tumor Detected\n(Class: yes)")
    plt.axis('off')


no_samples = df[df['label'] == 'no'].sample(n_samples)
for i, row in enumerate(no_samples.iterrows()):
    filepath = row[1]['filepath']
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.subplot(2, n_samples, n_samples + i + 1)
    plt.imshow(img)
    plt.title(f"No Tumor\n(Class: no)")
    plt.axis('off')
    
plt.tight_layout()
plt.show()


4.2 Data Preprocessing

Check the class distribution

class_counts = df['label'].value_counts()

fig_pie = px.pie(
    names=class_counts.index,
    values=class_counts.values,
    title='Class Distribution of Brain Tumor MRI',
    color_discrete_sequence=px.colors.qualitative.Pastel
)
fig_pie.update_traces(textposition='inside', textinfo='percent+label')
fig_pie.show()

plt.figure(figsize=(8, 6))
sns.countplot(x='label', data=df, palette='viridis')
plt.title('Count of Images per Class')
plt.xlabel('Class Label')
plt.ylabel('Number of Images')
plt.show()

print(class_counts)


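The two classes are imbalanced, which is why the train/validation split below is stratified: it keeps the class ratio the same in both subsets.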
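One way to quantify the imbalance is the majority/minority ratio. The sketch below (a hypothetical helper, `imbalance_summary`, run on made-up labels rather than the real counts) also computes n_negative / n_positive, the value commonly passed as `pos_weight` to `nn.BCEWithLogitsLoss` to re-weight the positive class. The training code in this article uses the unweighted loss, so this is an optional variation, not part of the original pipeline.

```python
from collections import Counter


def imbalance_summary(labels, positive="yes"):
    """Return class counts, the majority/minority ratio and a pos_weight value.

    pos_weight = n_negative / n_positive: values > 1 up-weight the positive
    class in a weighted binary cross-entropy, values < 1 down-weight it.
    """
    counts = Counter(labels)
    n_pos = counts[positive]
    n_neg = sum(n for label, n in counts.items() if label != positive)
    ratio = max(n_pos, n_neg) / min(n_pos, n_neg)
    return dict(counts), ratio, n_neg / n_pos


# Hypothetical toy labels, NOT the real dataset counts.
toy_labels = ["yes"] * 60 + ["no"] * 40
counts, ratio, pos_weight = imbalance_summary(toy_labels)
print(counts, round(ratio, 2), round(pos_weight, 3))
```

In the notebook, `imbalance_summary(df['label'])` would give the real figures, and a weighted loss could then be built as `nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))`.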

Image size analysis (visualizing the distribution of image heights and widths)

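heights = []
widths = []

for filepath in tqdm(df['filepath'], desc="Reading image dimensions"):
    img = cv2.imread(filepath)  # cv2.imread returns None (it does not raise) for unreadable files
    if img is None:
        print(f"Could not read {filepath}")
        # Record NaN so both lists stay aligned with the rows of df
        heights.append(np.nan)
        widths.append(np.nan)
        continue
    h, w = img.shape[:2]
    heights.append(h)
    widths.append(w)

df['height'] = heights
df['width'] = widths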

plt.figure(figsize=(14, 6))
plt.subplot(1, 2, 1)
sns.histplot(df['height'], kde=True, color='skyblue').set_title('Distribution of Image Heights')
plt.subplot(1, 2, 2)
sns.histplot(df['width'], kde=True, color='salmon').set_title('Distribution of Image Widths')
plt.show()


IMG_SIZE = 224
avg_yes = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.float64)
avg_no = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.float64)

count_yes = len(df[df['label'] == 'yes'])
count_no = len(df[df['label'] == 'no'])

for filepath in tqdm(df[df['label'] == 'yes']['filepath'], desc="Averaging 'yes' images"):
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    avg_yes += img / count_yes

for filepath in tqdm(df[df['label'] == 'no']['filepath'], desc="Averaging 'no' images"):
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    avg_no += img / count_no

avg_yes_img = np.array(np.round(avg_yes), dtype=np.uint8)
avg_no_img = np.array(np.round(avg_no), dtype=np.uint8)

print("\nDisplaying the 'average' brain MRI for each class:")
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(avg_no_img)
plt.title('Average Non-Tumorous Brain')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(avg_yes_img)
plt.title('Average Tumorous Brain')
plt.axis('off')

plt.suptitle('Mean Image Representation by Class', fontsize=16)
plt.show()


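The average tumorous brain appears to have a slightly brighter, more diffuse and denser central region, which may reflect a common location or appearance of tumors in this dataset. If so, spatial information could be a strong signal for the model to learn. By contrast, the average non-tumorous brain looks clearer and sharper.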
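The visual impression can be checked numerically. The following sketch, built around a hypothetical helper `center_border_contrast` and synthetic arrays, compares the mean intensity of the central square of an image with that of its border. Applied to `avg_yes - avg_no` (use the float64 accumulators rather than the uint8 copies, because uint8 subtraction wraps around), a clearly positive center value alongside a near-zero border value would support the "brighter center" reading; the 50% center size is an arbitrary choice.

```python
import numpy as np


def center_border_contrast(img, frac=0.5):
    """Mean intensity of the central square vs. the surrounding border.

    `frac` is the side of the central square as a fraction of the image side
    (an arbitrary illustrative choice).
    """
    gray = np.asarray(img, dtype=np.float64)
    if gray.ndim == 3:
        gray = gray.mean(axis=-1)              # average the RGB channels
    h, w = gray.shape
    ch, cw = int(h * frac), int(w * frac)
    top, left = (h - ch) // 2, (w - cw) // 2
    center = np.zeros((h, w), dtype=bool)
    center[top:top + ch, left:left + cw] = True
    return gray[center].mean(), gray[~center].mean()


# Synthetic stand-ins for avg_no / avg_yes (float arrays of shape (224, 224, 3)).
avg_no_demo = np.full((224, 224, 3), 40.0)
avg_yes_demo = avg_no_demo.copy()
avg_yes_demo[84:140, 84:140] += 30.0           # a brighter blob in the middle

diff = avg_yes_demo - avg_no_demo              # where the tumor class is brighter on average
print(center_border_contrast(diff))
```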

IMG_SIZE = 224
BATCH_SIZE = 32
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=666)

for train_index, val_index in splitter.split(df, df['label']):
    train_df = df.iloc[train_index]
    val_df = df.iloc[val_index]

train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

print("Training data info:")
print(train_df['label'].value_counts(normalize=True))
print("\nValidation data info:")
print(val_df['label'].value_counts(normalize=True))
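`StratifiedShuffleSplit` keeps the yes/no ratio identical (up to rounding) in the training and validation sets, which is what the two `value_counts(normalize=True)` printouts check. The pure-Python sketch below shows the idea with a hypothetical `stratified_split` helper and toy labels; it is a simplification and will not reproduce scikit-learn's exact indices for the same seed.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_size=0.2, seed=666):
    """Simplified per-class split: shuffle each class and hold out ~test_size of it."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    train_idx, val_idx = [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)
        n_val = round(len(idxs) * test_size)   # same fraction taken from every class
        val_idx.extend(idxs[:n_val])
        train_idx.extend(idxs[n_val:])
    return sorted(train_idx), sorted(val_idx)


def yes_share(labels, idx):
    """Fraction of 'yes' labels among the given indices."""
    return sum(labels[i] == "yes" for i in idx) / len(idx)


labels = ["yes"] * 60 + ["no"] * 40            # hypothetical toy labels
tr, va = stratified_split(labels)
print(len(tr), len(va), yes_share(labels, tr), yes_share(labels, va))  # → 80 20 0.6 0.6
```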


train_transforms = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=20, p=0.7),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.7),
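    # Note: alpha_affine (here) and max_holes/max_height/max_width/min_holes/fill_value
    # (CoarseDropout, below) are argument names from older albumentations releases;
    # newer versions deprecate or remove them, so adjust them to your installed version.
    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.5),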
    A.CoarseDropout(max_holes=8, max_height=IMG_SIZE//8, max_width=IMG_SIZE//8, 
                    min_holes=1, fill_value=0, p=0.5),
    A.Normalize(mean=MEAN, std=STD),
    ToTensorV2()
])

val_transforms = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=MEAN, std=STD),
    ToTensorV2()
])

Create a custom PyTorch Dataset

class BrainMRIDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms
        self.labels = self.df['label'].map({'yes': 1, 'no': 0}).values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        filepath = self.df.loc[idx, 'filepath']
        label = self.labels[idx]
        image = cv2.imread(filepath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms:
            transformed = self.transforms(image=image)
            image = transformed['image']
            
        return image, torch.tensor(label, dtype=torch.float32)
temp_train_dataset = BrainMRIDataset(train_df, transforms=train_transforms)
print("Visualizing some of the applied augmentations on a single training image:")

n_images = 8
original_image, _ = temp_train_dataset[0]
viz_transforms = A.Compose([t for t in temp_train_dataset.transforms if not isinstance(t, (A.Normalize, ToTensorV2))])

raw_img = cv2.imread(temp_train_dataset.df.loc[0, 'filepath'])
raw_img = cv2.cvtColor(raw_img, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(16, 8))
plt.subplot(2, (n_images + 1) // 2, 1)
plt.imshow(raw_img)
plt.title("Original Image")
plt.axis('off')

for i in range(1, n_images):
    augmented = viz_transforms(image=raw_img)['image']
    plt.subplot(2, (n_images + 1) // 2, i + 1)
    plt.imshow(augmented)
    plt.title(f"Augmented {i}")
    plt.axis('off')
    
plt.tight_layout()
plt.show()


Create the DataLoaders

train_dataset = BrainMRIDataset(train_df, transforms=train_transforms)
val_dataset = BrainMRIDataset(val_df, transforms=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"\nNumber of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")

images, labels = next(iter(train_loader))
print(f"\nShape of one batch of images: {images.shape}")
print(f"Shape of one batch of labels: {labels.shape}")
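With the default `drop_last=False`, a DataLoader yields ceil(N / BATCH_SIZE) batches, and the last one holds the remainder. The sketch below (hypothetical helper `loader_stats`, made-up training-set sizes) computes both numbers plus the total number of optimizer steps, which is exactly the `T_max` later passed to `CosineAnnealingLR` (`len(train_loader) * EPOCHS`). The second case matters for the training code: if the final batch contains a single image, calling `.squeeze()` on the (1, 1) model output yields a 0-dimensional tensor that no longer matches the label shape (1,), whereas `.squeeze(1)` keeps the batch dimension.

```python
import math


def loader_stats(n_samples: int, batch_size: int, epochs: int):
    """Batches per epoch, size of the final batch, and total optimizer steps.

    Mirrors DataLoader's default drop_last=False behaviour: the last batch
    holds whatever samples are left over.
    """
    n_batches = math.ceil(n_samples / batch_size)
    last_batch = n_samples - (n_batches - 1) * batch_size
    return n_batches, last_batch, n_batches * epochs


# Hypothetical training-set sizes; len(train_df) gives the real one.
for n in (200, 193):
    print(n, loader_stats(n, batch_size=32, epochs=15))
```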

4.3 Building the Model

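For the model, I use a Vision Transformer, specifically the vit_base_patch16_224 model from timm (ViT-Base with 16×16 patches and 224×224 input).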

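MODEL_NAME = 'vit_base_patch16_224'
PRETRAINED = True
# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

def create_model(num_classes=1):
    # timm already builds a fresh Linear head with num_classes outputs; replacing it below
    # is redundant but harmless and makes the single-logit binary head explicit.
    model = timm.create_model(MODEL_NAME, pretrained=PRETRAINED, num_classes=num_classes)
    in_features = model.head.in_features
    model.head = nn.Linear(in_features, num_classes)

    return model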


model = create_model()
model.to(device)
try:
    images, labels = next(iter(train_loader))
except NameError:
    print("DataLoaders not found. Creating dummy data for the smoke test.")
    images = torch.randn(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)

images = images.to(device)
with torch.no_grad():
    output = model(images)

assert output.shape == (BATCH_SIZE, 1), f"Output shape is incorrect! Expected {(BATCH_SIZE, 1)}, but got {output.shape}"
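As a sanity check on the model size, the parameter count of a standard ViT-Base/16 can be worked out by hand. The sketch below (hypothetical helper `vit_param_count`) assumes timm's usual layout for this model: a 16×16 patch-embedding convolution, a [CLS] token, 197 learned position embeddings, 12 pre-norm blocks with a biased qkv projection and a 4× MLP, a final LayerNorm and a linear head.

```python
def vit_param_count(img=224, patch=16, chans=3, dim=768, depth=12, mlp_ratio=4, num_classes=1):
    """Count parameters of a standard ViT (biased qkv, pre-norm blocks, linear head)."""
    n_patches = (img // patch) ** 2
    hidden = dim * mlp_ratio
    patch_embed = chans * patch * patch * dim + dim       # Conv2d(kernel=stride=patch) + bias
    cls_and_pos = dim + (n_patches + 1) * dim             # [CLS] token + position embeddings
    per_block = (
        2 * dim                       # norm1 (weight + bias)
        + dim * 3 * dim + 3 * dim     # qkv projection
        + dim * dim + dim             # attention output projection
        + 2 * dim                     # norm2
        + dim * hidden + hidden       # MLP fc1
        + hidden * dim + dim          # MLP fc2
    )
    final_norm = 2 * dim
    head = dim * num_classes + num_classes
    return patch_embed + cls_and_pos + depth * per_block + final_norm + head


print(vit_param_count(num_classes=1))     # binary head used here → 85799425
print(vit_param_count(num_classes=1000))  # ImageNet head → 86567656
```

If the layout assumption holds, `sum(p.numel() for p in model.parameters())` on the model above should print the same number as the first line; the 1,000-class figure is the commonly quoted size of ViT-B/16 (about 86.6M parameters).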

4.4 Training the Model

First, define the training and validation functions.

EPOCHS = 15
LR = 1e-4
WEIGHT_DECAY = 1e-6

loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * EPOCHS, eta_min=1e-6)

def train_one_epoch(model, train_loader, loss_fn, optimizer, scheduler, device):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(train_loader, desc="Training", total=len(train_loader))

    for images, labels in progress_bar:
        images = images.to(device)
        labels = labels.to(device)

        # squeeze(1): (B, 1) -> (B,), keeping the batch dimension even when B == 1
        outputs = model(images).squeeze(1)
        
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item() * images.size(0)
        
        preds = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (preds == labels).sum().item()
        total_samples += labels.size(0)

        progress_bar.set_postfix(loss=loss.item(), acc=correct_predictions/total_samples)

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc


def validate_one_epoch(model, val_loader, loss_fn, device):
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(val_loader, desc="Validation", total=len(val_loader))

    with torch.no_grad():
        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images).squeeze(1)  # (B, 1) -> (B,), safe when B == 1
            loss = loss_fn(outputs, labels)

            running_loss += loss.item() * images.size(0)
            
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct_predictions += (preds == labels).sum().item()
            total_samples += labels.size(0)
            
            progress_bar.set_postfix(loss=loss.item(), acc=correct_predictions/total_samples)

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc
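Both functions feed raw logits to `nn.BCEWithLogitsLoss` and count a prediction as "tumor" when sigmoid(logit) > 0.5, which is the same as logit > 0. The scalar sketch below (hypothetical helpers, plain `math` instead of torch) shows the numerically stable form of binary cross-entropy on a logit; `nn.BCEWithLogitsLoss` with its default `reduction='mean'` averages this quantity over the batch.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Numerically stable binary cross-entropy on a raw logit.

    Equal to -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))], rewritten as
    max(x, 0) - x*y + log(1 + exp(-|x|)) so exp() never overflows.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def sigmoid(x: float) -> float:
    """Overflow-safe logistic function."""
    return 1.0 / (1.0 + math.exp(-x)) if x >= 0 else math.exp(x) / (1.0 + math.exp(x))


for x in (-3.0, 0.0, 2.0):
    pred = float(sigmoid(x) > 0.5)             # same decision rule as in train/validate
    print(x, round(sigmoid(x), 4), pred, round(bce_with_logits(x, 1.0), 4))
```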

Start training

history = defaultdict(list)
best_val_acc = 0.0
best_model_state = None

print("🚀 Starting training!")
for epoch in range(1, EPOCHS + 1):
    print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
    
    train_loss, train_acc = train_one_epoch(model, train_loader, loss_fn, optimizer, scheduler, device)
    val_loss, val_acc = validate_one_epoch(model, val_loader, loss_fn, device)

    print(f"Epoch {epoch} Summary:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Valid Loss: {val_loss:.4f} | Valid Acc: {val_acc:.4f}")
    print(f"  Current LR: {optimizer.param_groups[0]['lr']:.6f}")

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_state = copy.deepcopy(model.state_dict())
        torch.save(best_model_state, 'best_model.pth')
        print(f"🎉 New best model saved with validation accuracy: {best_val_acc:.4f}")

print("\n🏁 Training finished!")
print(f"Best validation accuracy achieved: {best_val_acc:.4f}")
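Because `scheduler.step()` is called after every batch and `T_max = len(train_loader) * EPOCHS`, the learning rate follows one half-cosine from `LR` (1e-4) down to `eta_min` (1e-6) over the whole run, so the "Current LR" printed after each epoch keeps shrinking. The sketch below uses the closed-form expression given in the PyTorch documentation for `CosineAnnealingLR` (PyTorch computes it recursively, which matches this formula when no other scheduler changes the learning rate); the hypothetical helper name `cosine_lr` and the 7 batches per epoch are assumptions for illustration.

```python
import math


def cosine_lr(step: float, base_lr: float = 1e-4, eta_min: float = 1e-6, t_max: int = 105) -> float:
    """Closed-form cosine annealing: eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


# Hypothetical: 7 batches per epoch x 15 epochs = 105 scheduler steps in total.
steps_per_epoch, epochs = 7, 15
t_max = steps_per_epoch * epochs
for epoch in (1, 5, 10, 15):
    step = epoch * steps_per_epoch             # scheduler.step() calls made by the end of this epoch
    print(epoch, f"{cosine_lr(step, t_max=t_max):.6f}")
```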


4.5 Model Evaluation

model = create_model()
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.to(device)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

# Plot Training & Validation Accuracy
ax1.plot(history['train_acc'], label='Train Accuracy', color='royalblue')
ax1.plot(history['val_acc'], label='Validation Accuracy', color='darkorange')
ax1.set_title('Model Accuracy', fontsize=16)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Accuracy', fontsize=12)
ax1.legend()
ax1.grid(True)

# Plot Training & Validation Loss
ax2.plot(history['train_loss'], label='Train Loss', color='royalblue')
ax2.plot(history['val_loss'], label='Validation Loss', color='darkorange')
ax2.set_title('Model Loss', fontsize=16)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Loss', fontsize=12)
ax2.legend()
ax2.grid(True)

plt.suptitle('Training and Validation Learning Curves', fontsize=20)
plt.show()


def get_predictions(model, data_loader, device):
    model.eval() 
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Generating Predictions"):
            images = images.to(device)
            
            outputs = model(images).squeeze(1)  # (B, 1) -> (B,), safe when B == 1
            
            # Get probabilities for ROC curve
            probs = torch.sigmoid(outputs)
            
            # Get final predictions (0 or 1)
            preds = (probs > 0.5).float()

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    return np.array(all_labels), np.array(all_preds), np.array(all_probs)

y_true, y_pred, y_prob = get_predictions(model, val_loader, device)
target_names = ['No Tumor', 'Yes Tumor']
print("\n" + "="*50)
print("           Classification Report")
print("="*50 + "\n")
print(classification_report(y_true, y_pred, target_names=target_names))

cm = confusion_matrix(y_true, y_pred)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.title('Confusion Matrix', fontsize=16)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.show()
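The confusion matrix and the classification report are two views of the same four counts. The sketch below (hypothetical helper `binary_metrics`, made-up counts) recovers the report's numbers from scikit-learn's [[TN, FP], [FN, TP]] layout for labels 0 = No Tumor and 1 = Yes Tumor. Recall for "Yes Tumor" is the sensitivity and recall for "No Tumor" is the specificity; in a screening setting the false negatives (missed tumors) are usually the more costly error.

```python
def binary_metrics(cm):
    """Derive common metrics from a 2x2 confusion matrix in scikit-learn's layout.

    confusion_matrix(y_true, y_pred) with labels {0, 1} returns
    [[TN, FP],
     [FN, TP]]  (rows = true class, columns = predicted class).
    """
    (tn, fp), (fn, tp) = cm
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0          # sensitivity for "Yes Tumor"
    specificity = tn / (tn + fp) if tn + fp else 0.0     # recall for "No Tumor"
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    return dict(accuracy=accuracy, precision=precision, recall=recall,
                specificity=specificity, f1=f1)


# Hypothetical matrix for illustration only, not this model's result.
print(binary_metrics([[18, 2], [3, 28]]))
```

In the notebook, `binary_metrics(cm)` can be called directly on the NumPy array returned by `confusion_matrix`.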


fpr, tpr, thresholds = roc_curve(y_true, y_prob)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(10, 8))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='No-Skill Line')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=16)
plt.legend(loc="lower right")
plt.grid(True)
plt.show()

print(f"\nArea Under the Curve (AUC): {roc_auc:.4f}")


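Overall, despite the limited number of samples in the dataset, the model performs very well, which shows how capable ViT can be: the validation AUC is 0.971, and the confusion matrix and classification report are satisfactory. Keep in mind that these metrics are computed on the same validation split that was used to pick the best checkpoint, so they are probably somewhat optimistic; a separate held-out test set would give a less biased estimate.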
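An AUC of 0.971 has a direct interpretation: if one tumor scan and one non-tumor scan are drawn at random from the validation set, the model assigns the higher tumor probability to the tumor scan about 97% of the time. The sketch below (hypothetical helper `auc_by_ranking`, made-up scores) computes AUC from that definition, counting ties as one half; this gives the same value as integrating the ROC curve with `roc_curve` and `auc`.

```python
def auc_by_ranking(y_true, scores):
    """ROC AUC as the probability that a random positive outranks a random negative.

    Ties count as one half. O(n_pos * n_neg), fine for a small validation set.
    """
    pos = [s for label, s in zip(y_true, scores) if label == 1]
    neg = [s for label, s in zip(y_true, scores) if label == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = sum(1.0 if sp > sn else 0.5 if sp == sn else 0.0 for sp in pos for sn in neg)
    return wins / (len(pos) * len(neg))


# Hypothetical labels and sigmoid outputs.
labels_demo = [0, 0, 1, 1, 0, 1]
probs_demo = [0.10, 0.40, 0.35, 0.80, 0.20, 0.90]
print(auc_by_ranking(labels_demo, probs_demo))  # 8 of 9 positive/negative pairs ranked correctly
```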

4.6 Visualizing the Results

PATCH_SIZE = 16

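def capture_attention_forward(self, x, attn_mask=None, **kwargs):
    # Re-implementation of timm's Attention.forward that also stores the softmax attention map.
    # attn_mask/**kwargs only make the patch tolerate timm versions that pass extra arguments;
    # they are ignored, which is fine here because ViT image classification uses no mask.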
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)

    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)

    self.attention_map = attn
    attn = self.attn_drop(attn)
    
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

last_block = model.blocks[-1]
last_block.attn.forward = types.MethodType(capture_attention_forward, last_block.attn)

def visualize_attention(model, image_tensor, original_image, device, true_label, pred_label):
    model.eval()
    
    image_tensor = image_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)
    
    attn_map = last_block.attn.attention_map
    attn_map = attn_map.mean(dim=1).squeeze(0)

    cls_token_attn = attn_map[0, 1:]

    grid_size = int(np.sqrt(cls_token_attn.shape[0]))
    attn_grid = cls_token_attn.cpu().numpy().reshape(grid_size, grid_size)

    resized_attn = cv2.resize(attn_grid, (original_image.shape[1], original_image.shape[0]))
    resized_attn = (resized_attn - resized_attn.min()) / (resized_attn.max() - resized_attn.min())

    heatmap = cv2.applyColorMap(np.uint8(255 * resized_attn), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    superimposed_img = cv2.addWeighted(original_image, 0.6, heatmap, 0.4, 0)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    class_names = {0: 'No Tumor', 1: 'Yes Tumor'}
    
    ax1.imshow(original_image)
    ax1.set_title("Original Image", fontsize=14)
    ax1.axis('off')

    ax2.imshow(heatmap)
    ax2.set_title("Attention Heatmap", fontsize=14)
    ax2.axis('off')
    
    ax3.imshow(superimposed_img)
    ax3.set_title("Superimposed View", fontsize=14)
    ax3.axis('off')
    
    plt.suptitle(f"True: {class_names[true_label]} | Predicted: {class_names[pred_label]}", fontsize=18)
    plt.show()

correct_yes_idx = np.where((y_true == 1) & (y_pred == 1))[0]
correct_no_idx = np.where((y_true == 0) & (y_pred == 0))[0]
incorrect_idx = np.where(y_true != y_pred)[0]

indices_to_show = []
if len(correct_yes_idx) > 0: indices_to_show.append(correct_yes_idx[0])
if len(correct_no_idx) > 0: indices_to_show.append(correct_no_idx[0])
if len(incorrect_idx) > 0: indices_to_show.append(incorrect_idx[0])

for idx in indices_to_show:
    img_tensor, label = val_dataset[idx]
    
    filepath = val_df.loc[idx, 'filepath']
    original_img = np.array(Image.open(filepath).convert('RGB')) # Read with PIL
    
    true_label = int(y_true[idx])
    pred_label = int(y_pred[idx])

    visualize_attention(model, img_tensor, original_img, device, true_label, pred_label)


The model predicts both of the examples shown above correctly.

Source Code

import os
import cv2
import warnings
import timm
import torch
import copy
import types
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import albumentations as A
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px

from tqdm import tqdm
from PIL import Image
from collections import defaultdict
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

sns.set_style("darkgrid")
warnings.filterwarnings('ignore')
DATA_DIR = './brain_tumor_dataset'
classes = [folder for folder in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, folder))]
data = []
for cls in classes:
    class_path = os.path.join(DATA_DIR, cls)
    for filename in os.listdir(class_path):
        filepath = os.path.join(class_path, filename)
        data.append({'filepath': filepath, 'label': cls})

df = pd.DataFrame(data)
df = df.sample(frac=1).reset_index(drop=True)

print("First 5 rows of the DataFrame:")
display(df.head())
print("\n\nDataFrame Information:")
df.info()
print("\nClass Distribution:")
class_counts = df['label'].value_counts()

fig_pie = px.pie(
    names=class_counts.index,
    values=class_counts.values,
    title='Class Distribution of Brain Tumor MRI',
    color_discrete_sequence=px.colors.qualitative.Pastel
)
fig_pie.update_traces(textposition='inside', textinfo='percent+label')
fig_pie.show()

plt.figure(figsize=(8, 6))
sns.countplot(x='label', data=df, palette='viridis')
plt.title('Count of Images per Class')
plt.xlabel('Class Label')
plt.ylabel('Number of Images')
plt.show()

print(class_counts)
plt.figure(figsize=(15, 8))
n_samples = 5
yes_samples = df[df['label'] == 'yes'].sample(n_samples)

for i, row in enumerate(yes_samples.iterrows()):
    filepath = row[1]['filepath']
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.subplot(2, n_samples, i + 1)
    plt.imshow(img)
    plt.title(f"Tumor Detected\n(Class: yes)")
    plt.axis('off')


no_samples = df[df['label'] == 'no'].sample(n_samples)
for i, row in enumerate(no_samples.iterrows()):
    filepath = row[1]['filepath']
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.subplot(2, n_samples, n_samples + i + 1)
    plt.imshow(img)
    plt.title(f"No Tumor\n(Class: no)")
    plt.axis('off')
    
plt.tight_layout()
plt.show()
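heights = []
widths = []

for filepath in tqdm(df['filepath'], desc="Reading image dimensions"):
    img = cv2.imread(filepath)  # cv2.imread returns None (it does not raise) for unreadable files
    if img is None:
        print(f"Could not read {filepath}")
        # Record NaN so both lists stay aligned with the rows of df
        heights.append(np.nan)
        widths.append(np.nan)
        continue
    h, w = img.shape[:2]
    heights.append(h)
    widths.append(w)

df['height'] = heights
df['width'] = widths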

print("\nStatistics for image dimensions:")
display(df[['height', 'width']].describe())

plt.figure(figsize=(14, 6))
plt.subplot(1, 2, 1)
sns.histplot(df['height'], kde=True, color='skyblue').set_title('Distribution of Image Heights')
plt.subplot(1, 2, 2)
sns.histplot(df['width'], kde=True, color='salmon').set_title('Distribution of Image Widths')
plt.show()
IMG_SIZE = 224
avg_yes = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.float64)
avg_no = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.float64)

count_yes = len(df[df['label'] == 'yes'])
count_no = len(df[df['label'] == 'no'])

for filepath in tqdm(df[df['label'] == 'yes']['filepath'], desc="Averaging 'yes' images"):
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    avg_yes += img / count_yes

for filepath in tqdm(df[df['label'] == 'no']['filepath'], desc="Averaging 'no' images"):
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    avg_no += img / count_no

avg_yes_img = np.array(np.round(avg_yes), dtype=np.uint8)
avg_no_img = np.array(np.round(avg_no), dtype=np.uint8)

print("\nDisplaying the 'average' brain MRI for each class:")
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(avg_no_img)
plt.title('Average Non-Tumorous Brain')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(avg_yes_img)
plt.title('Average Tumorous Brain')
plt.axis('off')

plt.suptitle('Mean Image Representation by Class', fontsize=16)
plt.show()
IMG_SIZE = 224
BATCH_SIZE = 32
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=666)

for train_index, val_index in splitter.split(df, df['label']):
    train_df = df.iloc[train_index]
    val_df = df.iloc[val_index]

train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

print("Training data info:")
print(train_df['label'].value_counts(normalize=True))
print("\nValidation data info:")
print(val_df['label'].value_counts(normalize=True))
train_transforms = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=20, p=0.7),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.7),
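    # Note: alpha_affine (here) and max_holes/max_height/max_width/min_holes/fill_value
    # (CoarseDropout, below) are argument names from older albumentations releases;
    # newer versions deprecate or remove them, so adjust them to your installed version.
    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.5),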
    A.CoarseDropout(max_holes=8, max_height=IMG_SIZE//8, max_width=IMG_SIZE//8, 
                    min_holes=1, fill_value=0, p=0.5),
    A.Normalize(mean=MEAN, std=STD),
    ToTensorV2()
])

val_transforms = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=MEAN, std=STD),
    ToTensorV2()
])
class BrainMRIDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms
        self.labels = self.df['label'].map({'yes': 1, 'no': 0}).values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        filepath = self.df.loc[idx, 'filepath']
        label = self.labels[idx]
        image = cv2.imread(filepath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms:
            transformed = self.transforms(image=image)
            image = transformed['image']
            
        return image, torch.tensor(label, dtype=torch.float32)
temp_train_dataset = BrainMRIDataset(train_df, transforms=train_transforms)
print("Visualizing some of the applied augmentations on a single training image:")

n_images = 8
original_image, _ = temp_train_dataset[0]
viz_transforms = A.Compose([t for t in temp_train_dataset.transforms if not isinstance(t, (A.Normalize, ToTensorV2))])

raw_img = cv2.imread(temp_train_dataset.df.loc[0, 'filepath'])
raw_img = cv2.cvtColor(raw_img, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(16, 8))
plt.subplot(2, (n_images + 1) // 2, 1)
plt.imshow(raw_img)
plt.title("Original Image")
plt.axis('off')

for i in range(1, n_images):
    augmented = viz_transforms(image=raw_img)['image']
    plt.subplot(2, (n_images + 1) // 2, i + 1)
    plt.imshow(augmented)
    plt.title(f"Augmented {i}")
    plt.axis('off')
    
plt.tight_layout()
plt.show()
train_dataset = BrainMRIDataset(train_df, transforms=train_transforms)
val_dataset = BrainMRIDataset(val_df, transforms=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"\nNumber of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")

images, labels = next(iter(train_loader))
print(f"\nShape of one batch of images: {images.shape}")
print(f"Shape of one batch of labels: {labels.shape}")   
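MODEL_NAME = 'vit_base_patch16_224'
PRETRAINED = True
# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

def create_model(num_classes=1):
    # timm already builds a fresh Linear head with num_classes outputs; replacing it below
    # is redundant but harmless and makes the single-logit binary head explicit.
    model = timm.create_model(MODEL_NAME, pretrained=PRETRAINED, num_classes=num_classes)
    in_features = model.head.in_features
    model.head = nn.Linear(in_features, num_classes)

    return model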


model = create_model()
model.to(device)

print(f"\nModel '{MODEL_NAME}' loaded and adapted for binary classification.")
try:
    images, labels = next(iter(train_loader))
except NameError:
    print("DataLoaders not found. Creating dummy data for the smoke test.")
    images = torch.randn(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)

images = images.to(device)
with torch.no_grad():
    output = model(images)

assert output.shape == (BATCH_SIZE, 1), f"Output shape is incorrect! Expected {(BATCH_SIZE, 1)}, but got {output.shape}"
EPOCHS = 15
LR = 1e-4
WEIGHT_DECAY = 1e-6

loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * EPOCHS, eta_min=1e-6)
def train_one_epoch(model, train_loader, loss_fn, optimizer, scheduler, device):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(train_loader, desc="Training", total=len(train_loader))

    for images, labels in progress_bar:
        images = images.to(device)
        labels = labels.to(device)

        # squeeze(1): (B, 1) -> (B,), keeping the batch dimension even when B == 1
        outputs = model(images).squeeze(1)
        
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item() * images.size(0)
        
        preds = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (preds == labels).sum().item()
        total_samples += labels.size(0)

        progress_bar.set_postfix(loss=loss.item(), acc=correct_predictions/total_samples)

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc


def validate_one_epoch(model, val_loader, loss_fn, device):
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(val_loader, desc="Validation", total=len(val_loader))

    with torch.no_grad():
        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images).squeeze(1)  # (B, 1) -> (B,), safe when B == 1
            loss = loss_fn(outputs, labels)

            running_loss += loss.item() * images.size(0)
            
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct_predictions += (preds == labels).sum().item()
            total_samples += labels.size(0)
            
            progress_bar.set_postfix(loss=loss.item(), acc=correct_predictions/total_samples)

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc
history = defaultdict(list)
best_val_acc = 0.0
best_model_state = None

print("🚀 Starting training!")
for epoch in range(1, EPOCHS + 1):
    print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
    
    train_loss, train_acc = train_one_epoch(model, train_loader, loss_fn, optimizer, scheduler, device)
    val_loss, val_acc = validate_one_epoch(model, val_loader, loss_fn, device)

    print(f"Epoch {epoch} Summary:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Valid Loss: {val_loss:.4f} | Valid Acc: {val_acc:.4f}")
    print(f"  Current LR: {optimizer.param_groups[0]['lr']:.6f}")

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_state = copy.deepcopy(model.state_dict())
        torch.save(best_model_state, 'best_model.pth')
        print(f"🎉 New best model saved with validation accuracy: {best_val_acc:.4f}")

print("\n🏁 Training finished!")
print(f"Best validation accuracy achieved: {best_val_acc:.4f}")
model = create_model()
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.to(device)

print("Best model loaded successfully for evaluation.")
# Epoch numbers for the x-axis (the training loop counts epochs from 1)
epochs_range = range(1, len(history['train_acc']) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

# Plot Training & Validation Accuracy
ax1.plot(epochs_range, history['train_acc'], label='Train Accuracy', color='royalblue')
ax1.plot(epochs_range, history['val_acc'], label='Validation Accuracy', color='darkorange')
ax1.set_title('Model Accuracy', fontsize=16)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Accuracy', fontsize=12)
ax1.legend()
ax1.grid(True)

# Plot Training & Validation Loss
ax2.plot(epochs_range, history['train_loss'], label='Train Loss', color='royalblue')
ax2.plot(epochs_range, history['val_loss'], label='Validation Loss', color='darkorange')
ax2.set_title('Model Loss', fontsize=16)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Loss', fontsize=12)
ax2.legend()
ax2.grid(True)

plt.suptitle('Training and Validation Learning Curves', fontsize=20)
plt.show()
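The training loop counts epochs from 1, so entry `i` of each `history` list belongs to epoch `i + 1`. The loop only replaces the checkpoint on a strict improvement (`val_acc > best_val_acc`), so the saved weights come from the *first* epoch that reached the best validation accuracy.

A small helper (the name `best_epoch` is illustrative) that reads that epoch off the curves:

```python
def best_epoch(history, key='val_acc'):
    """Return (epoch_number, value) for the highest entry in history[key].

    Entry i corresponds to epoch i + 1 because the training loop counts from 1.
    max() keeps the first maximum on ties, matching the loop's strict `>` check.
    """
    values = history[key]
    if not values:
        raise ValueError(f"history[{key!r}] is empty")
    best_idx = max(range(len(values)), key=values.__getitem__)
    return best_idx + 1, values[best_idx]

# Hypothetical history with illustrative numbers (not results from this run)
demo_history = {'val_acc': [0.78, 0.86, 0.91, 0.91, 0.89]}
print(best_epoch(demo_history))  # (3, 0.91)
```

On the real run, `best_epoch(history)` should report the epoch whose weights ended up in `best_model.pth`.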
def get_predictions(model, data_loader, device):
    model.eval() 
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Generating Predictions"):
            images = images.to(device)
            
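            # reshape(-1) turns (B, 1) logits into (B,); unlike squeeze(), it stays 1-D when the last batch holds a single image
            outputs = model(images).reshape(-1)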
            
            # Get probabilities for ROC curve
            probs = torch.sigmoid(outputs)
            
            # Get final predictions (0 or 1)
            preds = (probs > 0.5).float()

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    return np.array(all_labels), np.array(all_preds), np.array(all_probs)
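Calling `.squeeze()` without an argument on a `(batch_size, 1)` output removes *every* size-1 axis. If the validation set size leaves a final batch of one image, that batch collapses to a 0-d value, and `list.extend` cannot iterate over it. Flattening with `reshape(-1)` avoids the problem.

The NumPy sketch below reproduces this with hypothetical batch sizes. The helper name `collect_probs` is illustrative.

```python
import numpy as np

def collect_probs(batches, flatten=True):
    """Gather per-image outputs from a sequence of (batch_size, 1) arrays."""
    collected = []
    for outputs in batches:
        # squeeze() drops every size-1 axis, so a (1, 1) batch becomes 0-d;
        # reshape(-1) always yields a 1-D array of length batch_size
        arr = outputs.reshape(-1) if flatten else outputs.squeeze()
        collected.extend(arr)  # iterating a 0-d array raises TypeError
    return np.array(collected)

# Two full batches of 4 plus a final batch holding one image (e.g. 9 images, batch size 4)
rng = np.random.default_rng(0)
batches = [rng.random((4, 1)), rng.random((4, 1)), rng.random((1, 1))]

print(collect_probs(batches).shape)  # (9,)
try:
    collect_probs(batches, flatten=False)
except TypeError:
    print("squeeze() fails on the single-image batch")
```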

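# Note: val_loader was also used to select the best checkpoint, so these scores are likely somewhat optimistic;
# a separate held-out test split would give an unbiased estimate.
y_true, y_pred, y_prob = get_predictions(model, val_loader, device)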
target_names = ['No Tumor', 'Yes Tumor']
print("\n" + "="*50)
print("           Classification Report")
print("="*50 + "\n")
print(classification_report(y_true, y_pred, target_names=target_names))

cm = confusion_matrix(y_true, y_pred)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.title('Confusion Matrix', fontsize=16)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.show()
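The classification report gives per-class precision, recall and F1. For tumor screening, the recall of the tumor class (sensitivity) deserves particular attention, because each false negative is a missed tumor. Specificity, the recall of the no-tumor class, matters alongside it.

With labels 0 = No Tumor and 1 = Yes Tumor, sklearn's `confusion_matrix` is laid out as `[[TN, FP], [FN, TP]]`, so these metrics can be read off `cm` directly. The sketch below uses hypothetical counts (not results from this experiment), and `binary_metrics` is an illustrative name:

```python
import numpy as np

def binary_metrics(cm):
    """Clinical metrics from a 2x2 confusion matrix ordered [[TN, FP], [FN, TP]].

    Assumes every denominator below is non-zero.
    """
    tn, fp, fn, tp = np.asarray(cm).ravel()
    sensitivity = tp / (tp + fn)          # recall of 'Yes Tumor'
    specificity = tn / (tn + fp)          # recall of 'No Tumor'
    precision = tp / (tp + fp)            # positive predictive value
    f1 = 2 * precision * sensitivity / (precision + sensitivity)
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    return {'sensitivity': float(sensitivity), 'specificity': float(specificity),
            'precision': float(precision), 'f1': float(f1), 'accuracy': float(accuracy)}

# Hypothetical counts for illustration only
example_cm = [[18, 2],
              [3, 27]]
for name, value in binary_metrics(example_cm).items():
    print(f"{name:>11}: {value:.3f}")
```

Applied to the real matrix, `binary_metrics(cm)` should reproduce the "Yes Tumor" precision, recall and F1 shown in the report.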
fpr, tpr, thresholds = roc_curve(y_true, y_prob)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(10, 8))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='No-Skill Line')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=16)
plt.legend(loc="lower right")
plt.grid(True)
plt.show()

print(f"\nArea Under the Curve (AUC): {roc_auc:.4f}")
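The AUC has a direct interpretation: it is the probability that a randomly chosen tumor image receives a higher predicted probability than a randomly chosen no-tumor image, with ties counting one half. This is the normalized Mann-Whitney U statistic.

A brute-force NumPy version of that definition follows (`auc_by_ranking` is an illustrative name). On the same inputs it should agree with `auc(fpr, tpr)` computed from `roc_curve`, up to floating-point error.

```python
import numpy as np

def auc_by_ranking(y_true, y_score):
    """AUC as P(score of a random positive > score of a random negative), ties counted as 1/2."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    # Compare every positive score with every negative score via broadcasting: (P, N)
    diff = pos[:, None] - neg[None, :]
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return float(wins / (len(pos) * len(neg)))

demo_labels = [0, 0, 1, 1]
demo_scores = [0.1, 0.4, 0.35, 0.8]
print(auc_by_ranking(demo_labels, demo_scores))  # 0.75
```

The function builds a P×N comparison matrix, so it is meant as a sanity check on small validation sets such as `auc_by_ranking(y_true, y_prob)`, not as a replacement for `roc_curve`.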
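PATCH_SIZE = 16  # ViT patch side length (224 / 16 = 14 patches per side for a 224x224 input)

def capture_attention_forward(self, x, *args, **kwargs):
    # Re-implementation of timm's Attention.forward with an explicit softmax, so the
    # attention weights can be stored on the module. Extra positional/keyword arguments
    # (e.g. an attention mask passed by some timm versions) are accepted but ignored,
    # which assumes a plain ViT classifier that uses no mask. q_norm/k_norm layers, if
    # present, are also skipped; they are identity layers in a standard ViT.
    B, N, C = x.shape
    # Project to queries, keys and values: (3, B, num_heads, N, head_dim)
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)

    # Attention weights: softmax(q @ k^T * scale), with scale = head_dim ** -0.5
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)

    # Keep the (B, num_heads, N, N) weights for visualization (before dropout)
    self.attention_map = attn
    attn = self.attn_drop(attn)

    # Weighted sum of values, merge heads, then output projection
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x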

last_block = model.blocks[-1]
last_block.attn.forward = types.MethodType(capture_attention_forward, last_block.attn)

print("Model has been patched to capture attention maps.")
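The patched `forward` computes the attention weights explicitly as softmax(QKᵀ · scale), where scale = head_dim^(-1/2) (timm's `Attention.scale`). Recent timm versions may otherwise use a fused attention kernel that never exposes these weights, so the explicit computation is what makes `attention_map` available.

The NumPy version below uses shapes that assume a ViT-Base/16 backbone at 224×224 input: 12 heads of width 64, and 196 patch tokens plus one CLS token. The actual backbone comes from `create_model()`, which is not shown in this excerpt.

```python
import numpy as np

def attention_weights(q, k):
    """softmax(q @ k^T * d**-0.5) over the last axis for arrays shaped (B, heads, N, d)."""
    scale = q.shape[-1] ** -0.5
    scores = (q @ np.swapaxes(k, -2, -1)) * scale        # (B, heads, N, N)
    scores -= scores.max(axis=-1, keepdims=True)         # subtract row max for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

# Assumed ViT-Base/16 shapes: 12 heads, head_dim 64, 196 patch tokens + 1 CLS token
rng = np.random.default_rng(0)
q = rng.standard_normal((1, 12, 197, 64))
k = rng.standard_normal((1, 12, 197, 64))
attn = attention_weights(q, k)

print(attn.shape)                            # (1, 12, 197, 197)
print(np.allclose(attn.sum(axis=-1), 1.0))   # True: each token's weights form a distribution
cls_row = attn[0, :, 0, :]                   # CLS token's attention per head, shape (12, 197)
```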
def visualize_attention(model, image_tensor, original_image, device, true_label, pred_label):
    """
    Generates and displays the attention map for a given image.
    """
    model.eval()
    
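    # Add a batch dimension: (C, H, W) -> (1, C, H, W)
    image_tensor = image_tensor.unsqueeze(0).to(device)

    # Forward pass only to populate last_block.attn.attention_map via the patched forward
    with torch.no_grad():
        model(image_tensor)
    
    # (1, heads, N, N) -> average over heads -> (N, N)
    attn_map = last_block.attn.attention_map
    attn_map = attn_map.mean(dim=1).squeeze(0)

    # Row 0 is the CLS token; drop column 0 to keep its attention to the N-1 patch tokens
    cls_token_attn = attn_map[0, 1:]

    # The N-1 patch tokens form a square grid (e.g. 196 = 14 x 14)
    grid_size = int(np.sqrt(cls_token_attn.shape[0]))
    attn_grid = cls_token_attn.cpu().numpy().reshape(grid_size, grid_size)

    # Upsample to the original image size (cv2.resize expects (width, height)) and min-max normalize;
    # the small epsilon avoids division by zero on a perfectly flat map
    resized_attn = cv2.resize(attn_grid, (original_image.shape[1], original_image.shape[0]))
    resized_attn = (resized_attn - resized_attn.min()) / (resized_attn.max() - resized_attn.min() + 1e-8)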

    heatmap = cv2.applyColorMap(np.uint8(255 * resized_attn), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    superimposed_img = cv2.addWeighted(original_image, 0.6, heatmap, 0.4, 0)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    class_names = {0: 'No Tumor', 1: 'Yes Tumor'}
    
    ax1.imshow(original_image)
    ax1.set_title("Original Image", fontsize=14)
    ax1.axis('off')

    ax2.imshow(heatmap)
    ax2.set_title("Attention Heatmap", fontsize=14)
    ax2.axis('off')
    
    ax3.imshow(superimposed_img)
    ax3.set_title("Superimposed View", fontsize=14)
    ax3.axis('off')
    
    plt.suptitle(f"True: {class_names[true_label]} | Predicted: {class_names[pred_label]}", fontsize=18)
    plt.show()
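Inside `visualize_attention`, the CLS row of the head-averaged attention matrix is turned into a spatial map. The token arithmetic behind `grid_size`, assuming 224×224 inputs and a patch size of 16, works out as follows:

- 224 / 16 = 14 patches per side
- 14 × 14 = 196 patch tokens
- 196 + 1 CLS token = 197 tokens

A NumPy sketch of that step follows. The function name `cls_attention_grid` is illustrative. It assumes patch tokens are ordered row by row, as a flattened patch-embedding grid is.

```python
import numpy as np

def cls_attention_grid(attn, img_size=224, patch_size=16):
    """Turn one image's attention weights (heads, N, N), token 0 = CLS,
    into a min-max normalized (grid, grid) map of CLS-to-patch attention."""
    grid = img_size // patch_size                  # patches per side, e.g. 224 // 16 = 14
    cls_to_patches = attn.mean(axis=0)[0, 1:]      # average heads, take CLS row, drop CLS column
    if cls_to_patches.size != grid * grid:
        raise ValueError(f"expected {grid * grid} patch tokens, got {cls_to_patches.size}")
    grid_map = cls_to_patches.reshape(grid, grid)  # row-major patch order (assumed)
    span = grid_map.max() - grid_map.min()
    return (grid_map - grid_map.min()) / (span + 1e-8)   # epsilon guards a flat map

# Random row-stochastic weights stand in for a real attention map
rng = np.random.default_rng(0)
raw = rng.random((12, 197, 197))
demo_attn = raw / raw.sum(axis=-1, keepdims=True)
heat = cls_attention_grid(demo_attn)
print(heat.shape)  # (14, 14)
```

The resulting map is what gets upsampled with `cv2.resize` and colored with `cv2.applyColorMap`.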
correct_yes_idx = np.where((y_true == 1) & (y_pred == 1))[0]
correct_no_idx = np.where((y_true == 0) & (y_pred == 0))[0]
incorrect_idx = np.where(y_true != y_pred)[0]

indices_to_show = []
if len(correct_yes_idx) > 0: indices_to_show.append(correct_yes_idx[0])
if len(correct_no_idx) > 0: indices_to_show.append(correct_no_idx[0])
if len(incorrect_idx) > 0: indices_to_show.append(incorrect_idx[0])

for idx in indices_to_show:
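    img_tensor, _ = val_dataset[idx]  # transformed tensor that is fed to the model
    
    # Positional lookup so the file matches y_true[idx] / y_pred[idx], even if val_df kept its pre-split index
    filepath = val_df['filepath'].iloc[idx]
    original_img = np.array(Image.open(filepath).convert('RGB'))  # original-resolution RGB image read with PIL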
    
    true_label = int(y_true[idx])
    pred_label = int(y_pred[idx])

    visualize_attention(model, img_tensor, original_img, device, true_label, pred_label)
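The plotted file must be the one at the same *position* as `y_true[idx]` and `y_pred[idx]`. That holds only if `val_dataset` reads `val_df` row by row and `val_loader` does not shuffle. Both are set up outside this excerpt, so this is an assumption.

Selecting rows with the positional indices from `StratifiedShuffleSplit` keeps the original row labels unless `reset_index(drop=True)` is called. In that case `.loc` (label-based) and `.iloc` (position-based) return different rows. A toy pandas example, with `demo_df` and `demo_val` as illustrative names:

```python
import pandas as pd

demo_df = pd.DataFrame({'filepath': ['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg'],
                        'label': [0, 1, 0, 1]})
demo_val = demo_df.iloc[[3, 0]]                # a shuffled split keeps the index labels [3, 0]

pos = 0                                        # position 0 in the prediction arrays
by_position = demo_val['filepath'].iloc[pos]   # 'd.jpg': the first image the loader yields
by_label = demo_val.loc[pos, 'filepath']       # 'a.jpg': the row whose index *label* is 0
print(by_position, by_label)

# Resetting the index after splitting makes both lookups agree
demo_val = demo_val.reset_index(drop=True)
assert demo_val.loc[pos, 'filepath'] == demo_val['filepath'].iloc[pos] == 'd.jpg'
```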