
Using a "pretrained" AlexNet for image classification, the accuracy doesn't match the numbers reported online; what could be the cause?

Richard14 · 2021-11-21 01:48:54 +08:00 · 613 views

By "pretrained" I mean using the alexnet that ships with torchvision (with the last layer modified), not loading already-trained weights. I tried training on cifar10 with the code from the quickstart, but the experimental numbers commonly found online put the accuracy at roughly 78-80%, while I only get about 70% even after iterating until convergence. What could be causing this gap?
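
To make the distinction concrete, here is a minimal sketch (assuming torchvision's 2021-era API, where `pretrained=True` loads ImageNet weights); variant (b) is shown only for contrast and is not what the code below does:

    from torch import nn
    from torchvision import models

    # (a) architecture only, randomly initialized weights -- what "pretrained model" means here
    net_scratch = models.alexnet()
    net_scratch.classifier[6] = nn.Linear(4096, 10)  # 10 CIFAR-10 classes instead of 1000

    # (b) architecture plus ImageNet weights -- NOT what the code below does, shown only for contrast
    net_finetune = models.alexnet(pretrained=True)
    net_finetune.classifier[6] = nn.Linear(4096, 10)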

Full code:

    # from utils import *   # project-local helpers, not needed to run this snippet
    # from pipeit import *
    import os,sys,time,pickle,random
    import matplotlib.pyplot as plt
    import numpy as np 
    import torch
    from torch import nn
    from torchvision import datasets, models
    from torch.utils.data import Dataset, DataLoader, TensorDataset
    from torchvision.transforms import ToTensor, Lambda, Resize, Compose, InterpolationMode
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))
    torch.backends.cudnn.benchmark=True
    
    # Download training data from open datasets.
    training_data = datasets.CIFAR10(
        root=".\\data\\cifar10",
        train=True,
        download=True,
        transform=Compose([
            Resize((64, 64), InterpolationMode.BICUBIC),
            ToTensor()
        ])
    )
    
    # Download test data from open datasets.
    test_data = datasets.CIFAR10(
        root=".\\data\\cifar10",
        train=False,
        download=True,
        transform=Compose([
            Resize((64, 64), InterpolationMode.BICUBIC),
            ToTensor()
        ])
    )
    
    def imshow(training_data):
        labels_map = {
            0: "plane",
            1: "car",
            2: "bird",
            3: "cat",
            4: "deer",
            5: "dog",
            6: "frog",
            7: "horse",
            8: "ship",
            9: "truck",
        }
        cols, rows = 3, 3
        figure = plt.figure(figsize=(8,8))
        for i in range(1, cols * rows + 1):
            sample_idx = torch.randint(len(training_data), size=(1,)).item()
            img, label = training_data[sample_idx]
            img = img.swapaxes(0,1)
            img = img.swapaxes(1,2)
            figure.add_subplot(rows, cols, i)
            plt.title(labels_map[label])
            plt.axis("off")
            plt.imshow(img)
        plt.show()
    
    # imshow(training_data)
    
    def train_loop(dataloader, net, loss_fn, optimizer):
        size = len(dataloader)  # number of batches, used to average the loss below
        train_loss = 0
        for batch_idx, (X, tag) in enumerate(dataloader):
            X, tag = X.to(device), tag.to(device)
            pred = net(X)
            loss = loss_fn(pred, tag)
            train_loss += loss.item()
    
            # Back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        train_loss /= size 
        return train_loss
    
    def test_loop(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        test_loss, correct = 0, 0
    
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
        test_loss /= num_batches
        correct /= size
        return test_loss, correct
    
    # AlexNet from torchvision, randomly initialized (no pretrained weights are loaded);
    # replace the final 1000-class ImageNet head with a 10-class CIFAR-10 head.
    net = models.alexnet().to(device)
    net.classifier[6] = nn.Linear(4096, 10).to(device)
    
    learning_rate = 0.01
    batch_size = 128
    weight_decay = 0
    
    train_dataloader = DataLoader(training_data, batch_size = batch_size)  # note: shuffle defaults to False
    test_dataloader = DataLoader(test_data, batch_size = batch_size)
    
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr = learning_rate)
    
    epochs = 50
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        st_time = time.time()
        train_loss = train_loop(train_dataloader, net, loss_fn, optimizer)
        test_loss, correct = test_loop(test_dataloader, net, loss_fn)
        print(f"Train loss: {train_loss:>8f}, Test loss: {test_loss:>8f}, Accuracy: {(100*correct):>0.1f}%, Epoch time: {time.time() - st_time:.2f}s\n")
    print("Done!")
    torch.save(net.state_dict(), 'alexnet-pre1.model')
    

The numbers at final convergence look like this:

    Epoch 52
    -------------------------------
    Train loss: 0.399347, Test loss: 0.970927, Accuracy: 70.3%, Epoch time: 17.20s
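
The gap between the train loss (0.40) and the test loss (0.97) suggests overfitting. One visible difference from many published CIFAR-10 recipes in the ~78-80% range is that the pipeline above uses only Resize + ToTensor, with no augmentation or input normalization. A hedged sketch of such a training transform, which the run above did not use (the normalization constants are the commonly quoted CIFAR-10 statistics, not values I computed), might look like this:

    from torchvision.transforms import (Compose, Resize, RandomCrop, RandomHorizontalFlip,
                                        ToTensor, Normalize, InterpolationMode)

    # Training-time transform with the augmentation/normalization common in CIFAR-10 recipes.
    train_transform = Compose([
        Resize((64, 64), InterpolationMode.BICUBIC),
        RandomCrop(64, padding=4),       # pad by 4 pixels, then crop back to 64x64 at a random offset
        RandomHorizontalFlip(),          # flip left-right with probability 0.5
        ToTensor(),
        Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    # The test-time transform would keep only Resize, ToTensor and the same Normalize.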
    
1 reply · 2021-11-21 23:53:55 +08:00

KangolHsu · 2021-11-21 23:53:55 +08:00 via iPhone · #1
The input images are 64×64? Isn't that a bit small?
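
For reference, torchvision's AlexNet is designed around 224×224 ImageNet crops, so one thing to try might be resizing CIFAR-10 up to that size instead of 64×64. A minimal sketch, changing only the target size in the post's Compose pipeline:

    from torchvision.transforms import Compose, Resize, ToTensor, InterpolationMode

    # Same pipeline as in the post, but resize to AlexNet's canonical 224x224 input.
    transform_224 = Compose([
        Resize((224, 224), InterpolationMode.BICUBIC),
        ToTensor()
    ])

Each image becomes roughly 12× larger than at 64×64, so epochs will be noticeably slower.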