재형이의 성장통 일지
  • 인공 신경망 코드로 구현해서 다중 분류해보기 (2)
    2024년 03월 17일 12시 02분 55초에 업로드 된 글입니다.
    작성자: 재형이
    반응형
     

     

    • 오늘 저녁 약속이 있는데 잘 하고 올게요~

     

     

     

     

     

     


     

     

     

     

     

     

     

     

     

    • 오늘은 지난번에 했던 데이터 세트와 다른 데이터로 복습할겸 인공신경망을 구성해서 다시 학습을 해볼 것이다
    • 추가로 학습이 완료된 모델을 저장하는 방법도 실습해볼 것이다
      • 모델을 저장하는 방법에는 두가지가 있다
        1. 모델의 파라미터만 저장하는 방법
          • 이 방법은 모델의 계층 구조를 알고 있어야만 나중에 다시 해당 모델을 사용할 수 있음
        2. 모델 전체를 저장하는 방법
    • 사용할 데이터 세트 : Fashion MNIST
     

    torchvision.datasets — Torchvision master documentation

    torchvision.datasets All datasets are subclasses of torch.utils.data.Dataset i.e, they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader which can load multiple samples in parallel using torch.

    pytorch.org

    복습

    1. TensorDataset과 DataLoader
      • 입력 데이터를 쉽게 처리하고, 배치 단위로 잘러서 학습할 수 있게 도와주는 모듈
      • Dataset : 학습시 사용하는 feature와 target의 pair로 이루어짐
      • DataLoader: 학습 시 각 인스턴스에 쉽게 접근할 수 있도록 순회 가능한 객체(iterable)를 생성
      • DataLoader가 하는 역할
        • shuffling
        • batch ...
    2. Device 설정
      • 일반적으로 인공신경망의 학습은 (가능하다면) GPU를 사용하는 것이 바람직함
      • GPU를 사용하여 학습을 진행하도록 명시적으로 작성 필요
      • 연산 유형에 따라 GPU에서 수행이 불가능한 경우도 존재하는데, 그럴 경우도 마찬가지로 명시적으로 어떤 프로세서에서 연산을 수행해야하는지 코드로 작성해야함
    3. 신경망 생성
      • torch.nn 패키지는 신경망 생성 및 학습 시 설정해야하는 다양한 기능을 제공
      • 신경망을 nn.Module을 상속받아 정의
        • __ init __(): 신경망에서 사용할 layer를 초기화하는 부분
        • forward(): feed foward 연산 수행 시, 각 layer의 입출력이 어떻게 연결되는지를 지정
    4. Model compile
      • 학습 시 필요한 정보들(loss function, optimizer)을 선언
      • 일반적으로 loss와 optimizer는 아래와 같이 변수로 선언하고, 변수를 train/test 시 참고할 수 있도록 매개변수로 지정해줌
    5. Train
      • 신경망의 학습과정을 별도의 함수로 구성하는 것이 일반적
        • feed forward → loss → error back propagation → print(진행상황) 또는 로깅 → (반복)
    6. Test
      • 학습과정과 비슷하나 error back propagate하는 부분이 없음
        • feed forward → loss → print(진행상황) 또는 로깅 → (반복)
    7. Iteration
      • 신경망 학습은 여러 epochs을 반복해서 수행하면서 모델을 구성하는 최적의 파라미터를 찾음
      • 지정한 epochs 수만큼 학습과정과 평가과정을 반복하면서, 모델의 성능(loss, accuracy 등)을 체크함

    Fashion MNIST Classifier

    • Fashion MNIST 데이터셋을 사용하여 옷의 품목을 구분하는 분류기를 신경망을 사용하여 구현

    [Step1] Load libraries & Datasets

    • torch.nn : 신경망을 생성하기 위한 기본 재료들을 제공(Modules, Sequential, Layer, Activations, Loss, Dropout...)
    • torchvision.datasets : torchvision.transforms를 사용해 변형이 가능한 형태, feature와 label을 반환
    • torchvision.transforms
      • ToTensor() : ndarray를 FloatTensor로 변환하고 이미지 픽셀 크기를 [0., 1.]범위로 조정(scale)
    import numpy as np
    import matplotlib.pyplot as plt
    
    import torch
    from torch.utils.data import DataLoader
    from torch import nn
    
    from torchvision import datasets
    from torchvision.transforms import ToTensor
    
    # FashionMNIST 데이터 불러오기
    training_data = datasets.FashionMNIST(
        root = 'data',
        train = True,
        download = True,
        transform = ToTensor(),
    )
    
    test_data = datasets.FashionMNIST(
        root = 'data',
        train = False,
        download = True,
        transform = ToTensor(),
    )

    [Step2] Create DataLoader

    train_dataloader = DataLoader(training_data, batch_size = 64, shuffle = True)
    test_dataloader = DataLoader(test_data, batch_size = 64, shuffle = False)
    
    # Device 설정
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'device = {device}')
    • device = cuda

    EDA(Exploratory Data Analysis, 탐색적 데이터 분석)

    print(training_data, '\n------------------\n', test_data)
    
    training_data[0]
    
    train_features, train_labels = next(iter(train_dataloader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")
    
    img, label = training_data[0]
    plt.imshow(img.squeeze(), cmap='gray')
    print(f'label={label}')
    
    labels_map = {
        0: "T-Shirt",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle Boot",
    }
    
    figure = plt.figure(figsize = (20, 8))
    cols, rows = 5, 2
    
    for i in range(1, cols * rows +1):
        sample_idx = torch.randint(len(training_data), size=(1,)).item()
        img, label = training_data[sample_idx]
        figure.add_subplot(rows, cols, i)
        plt.title(labels_map[label])
        print(labels_map[label])
        plt.axis('off')
        plt.imshow(img.squeeze(), cmap='gray')
    plt.show()
    • Dataset FashionMNIST
          Number of datapoints: 60000
          Root location: data
          Split: Train
          StandardTransform
      Transform: ToTensor() 
      ------------------
       Dataset FashionMNIST
          Number of datapoints: 10000
          Root location: data
          Split: Test
          StandardTransform
      Transform: ToTensor()
    • (tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,
                 0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0039, 0.0039, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,
                 0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,
                 0.0157, 0.0000, 0.0000, 0.0118],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,
                 0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0471, 0.0392, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,
                 0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902,
                 0.3020, 0.5098, 0.2824, 0.0588],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,
                 0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,
                 0.5529, 0.3451, 0.6745, 0.2588],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,
                 0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,
                 0.4824, 0.7686, 0.8980, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471,
                 0.8745, 0.8941, 0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667,
                 0.8745, 0.9608, 0.6784, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549,
                 0.8353, 0.7765, 0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745,
                 0.8627, 0.9529, 0.7922, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314,
                 0.8549, 0.7529, 0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314,
                 0.8863, 0.7725, 0.8196, 0.2039],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627,
                 0.8549, 0.7961, 0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627,
                 0.9608, 0.4667, 0.6549, 0.2196],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020,
                 0.8941, 0.9412, 0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510,
                 0.8510, 0.8196, 0.3608, 0.0000],
                [0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745,
                 0.8706, 0.8588, 0.8706, 0.8667, 0.8471, 0.8745, 0.8980, 0.8431,
                 0.8549, 1.0000, 0.3020, 0.0000],
                [0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667,
                 0.8549, 0.8157, 0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431,
                 0.8784, 0.9569, 0.6235, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196,
                 0.7412, 0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039,
                 0.8275, 0.9020, 0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725,
                 0.9137, 0.9333, 0.8431, 0.0000],
                [0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157,
                 0.8000, 0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569,
                 0.8078, 0.8745, 1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275,
                 0.8627, 0.9098, 0.9647, 0.0000],
                [0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392,
                 0.8039, 0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000,
                 0.8980, 0.8667, 0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196,
                 0.8706, 0.8941, 0.8824, 0.0000],
                [0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176,
                 0.9765, 0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863,
                 0.4157, 0.4588, 0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745,
                 0.8745, 0.8784, 0.8980, 0.1137],
                [0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824,
                 0.8471, 0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647,
                 0.8902, 0.9608, 0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706,
                 0.8627, 0.8667, 0.9020, 0.2627],
                [0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451,
                 0.7608, 0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255,
                 0.8824, 0.8471, 0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745,
                 0.7098, 0.8039, 0.8078, 0.4510],
                [0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686,
                 0.8000, 0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686,
                 0.7608, 0.7490, 0.7647, 0.7490, 0.7765, 0.7529, 0.6902, 0.6118,
                 0.6549, 0.6941, 0.8235, 0.3608],
                [0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745,
                 0.6863, 0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765,
                 0.8000, 0.8196, 0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608,
                 0.7529, 0.8471, 0.6667, 0.0000],
                [0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294,
                 0.9373, 0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569,
                 0.7490, 0.7020, 0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588,
                 0.3882, 0.2275, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569,
                 0.2392, 0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                 0.0000, 0.0000, 0.0000, 0.0000]]]),
       9)
    • Feature batch shape: torch.Size([64, 1, 28, 28])
      Labels batch shape: torch.Size([64])

    [Step3] Set Network Structure

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
            self.flatten = nn.Flatten()
            self.classifier = nn.Sequential(
                nn.Linear(28*28, 128),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(128, 10),
            )
    
        def forward(self, x):
            x = self.flatten(x)
            output = self.classifier(x)
            return output

    [Step4] Create Model instance

    model = NeuralNetwork().to(device)
    print(model)
    • NeuralNetwork(
        (flatten): Flatten(start_dim=1, end_dim=-1)
        (classifier): Sequential(
          (0): Linear(in_features=784, out_features=128, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.2, inplace=False)
          (3): Linear(in_features=128, out_features=10, bias=True)
        )
      )

    Model 테스트

    X = torch.rand(1, 28, 28, device = device)
    output = model(X)
    print(f'모델 출력 결과: {output}\n')
    pred_probab = nn.Softmax(dim=1)(output)
    print(f'Softmax 결과: {pred_probab}\n')
    y_pred = pred_probab.argmax()
    print(y_pred)
    • 모델 출력 결과: tensor([[ 0.3374, -0.2491,  0.2523, -0.0858, -0.0171, -0.0398, -0.0953, -0.1535,
               -0.0450,  0.0389]], device='cuda:0', grad_fn=<AddmmBackward0>)

      Softmax 결과: tensor([[0.1388, 0.0772, 0.1275, 0.0909, 0.0974, 0.0952, 0.0901, 0.0850, 0.0947,
               0.1030]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

      tensor(0, device='cuda:0')

    [Step5] Model compile

    # Loss
    loss = nn.CrossEntropyLoss()
    # Optimizer
    learning_rate = 1e-3  #0.001
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

    [Step6] Set train loop

    def train_loop(train_loader, model, loss_fn, optimizer):
        size = len(train_loader.dataset)
    
        for batch, (X, y) in enumerate(train_loader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
    
            # 손실 계산
            loss = loss_fn(pred, y)
    
            # 역전파
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f'loss: {loss:>7f}  [{current:>5d}]/{size:5d}')

    [Step7] Set test loop

    def test_loop(test_loader, model, loss_fn):
        size = len(test_loader.dataset)
        num_batches = len(test_loader)
        test_loss, correct = 0, 0
    
        with torch.no_grad():
            for X, y in test_loader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:8f}\n")

    [Step8] Run model

    epochs = 10
    
    for i in range(epochs) :
        print(f"Epoch {i+1} \n------------------------")
        train_loop(train_dataloader, model, loss, optimizer)
        test_loop(test_dataloader, model, loss)
    print("Done!")
    • Epoch 1 
      ------------------------
      loss: 2.324765  [    0]/60000
      loss: 0.691231  [ 6400]/60000
      loss: 0.471458  [12800]/60000
      loss: 0.613083  [19200]/60000
      loss: 0.444306  [25600]/60000
      loss: 0.549569  [32000]/60000
      loss: 0.454276  [38400]/60000
      loss: 0.423303  [44800]/60000
      loss: 0.461990  [51200]/60000
      loss: 0.667872  [57600]/60000
      Test Error: 
       Accuracy: 81.8%, Avg loss: 0.497612

      Epoch 2 
      ------------------------
      loss: 0.371661  [    0]/60000
      loss: 0.409686  [ 6400]/60000
      loss: 0.383233  [12800]/60000
      loss: 0.420215  [19200]/60000
      loss: 0.388781  [25600]/60000
      loss: 0.447195  [32000]/60000
      loss: 0.476783  [38400]/60000
      loss: 0.300698  [44800]/60000
      loss: 0.307176  [51200]/60000
      loss: 0.498841  [57600]/60000
      Test Error: 
       Accuracy: 84.8%, Avg loss: 0.427579

      • • •

      Epoch 9 
      ------------------------
      loss: 0.353967  [    0]/60000
      loss: 0.383568  [ 6400]/60000
      loss: 0.510100  [12800]/60000
      loss: 0.307564  [19200]/60000
      loss: 0.241812  [25600]/60000
      loss: 0.327063  [32000]/60000
      loss: 0.312592  [38400]/60000
      loss: 0.401675  [44800]/60000
      loss: 0.229188  [51200]/60000
      loss: 0.273675  [57600]/60000
      Test Error: 
       Accuracy: 87.4%, Avg loss: 0.362324

      Epoch 10 
      ------------------------
      loss: 0.293197  [    0]/60000
      loss: 0.239528  [ 6400]/60000
      loss: 0.118810  [12800]/60000
      loss: 0.316047  [19200]/60000
      loss: 0.220012  [25600]/60000
      loss: 0.263426  [32000]/60000
      loss: 0.263754  [38400]/60000
      loss: 0.349957  [44800]/60000
      loss: 0.210180  [51200]/60000
      loss: 0.347154  [57600]/60000
      Test Error: 
       Accuracy: 86.8%, Avg loss: 0.367648

      Done!

    [Step9] Save & load model

    parameter만 저장하고 불러오기

    torch.save(model.state_dict(), 'model_weights.pth')
    
    model2 = NeuralNetwork().to(device)
    print(model2)
    
    model2.load_state_dict(torch.load('model_weights.pth'))
    
    model2.eval()
    test_loop(test_dataloader, model2, loss)
    • NeuralNetwork(
        (flatten): Flatten(start_dim=1, end_dim=-1)
        (classifier): Sequential(
          (0): Linear(in_features=784, out_features=128, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.2, inplace=False)
          (3): Linear(in_features=128, out_features=10, bias=True)
        )
      )
    • <All keys matched successfully>
    • Test Error: 
       Accuracy: 88.0%, Avg loss: 0.338681

    Model 전체를 저장하고 불러오기

    torch.save(model, 'model.pth')
    
    model3 = torch.load('model.pth')
    
    model3.eval()
    test_loop(test_dataloader, model3, loss)
    • Test Error: 
       Accuracy: 88.0%, Avg loss: 0.338681

     

     

     

     

     

     

     


     

     

     

     

     

     

     

     

    반응형
    댓글