
Demo on how to save and load models

Authors: Eric Roberts and Petrus Zwart

E-mail: PHZwart@lbl.gov, EJRoberts@lbl.gov

This notebook highlights some basic functionality of the dlsia package.

Using the dlsia framework, we initialize several convolutional neural networks, train each on a small dataset on the CPU, and show how to save them to file and load them back.

[1]:
import numpy as np
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from dlsia.core import helpers, custom_losses, train_scripts
from dlsia.core.networks import msdnet, tunet, tunet3plus, smsnet

import matplotlib.pyplot as plt

Create & load data

Generate random data

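Let’s build some random data: 40 single-channel, 36-by-36 images. The ground-truth images are random images filtered by an 11-by-11 convolution with randomly initialized weights; the observed images are the ground truth plus uniform noise scaled by 0.5. Both are cropped to 32-by-32, and the first 20 images are used for training and the last 20 for testing.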

[2]:
n_imgs = 40
n_channels = 1
n_xy = 36

random_data1 = torch.rand((n_imgs, n_channels, n_xy, n_xy))  # clean source images
# Randomly initialized convolution used to generate the ground truth
k = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, stride=1, padding=5)
random_data2 = torch.rand((n_imgs, n_channels, n_xy, n_xy))  # noise

with torch.no_grad():  # data generation only; no gradients needed
    random_data_gt = k(random_data1)
    random_data_obs = random_data_gt + random_data2 * 0.50

# Keep a 32x32 window starting at offset K in each 36x36 image
K = 3
random_data_obs = random_data_obs[:, :, K:32+K, K:32+K]
random_data_gt = random_data_gt[:, :, K:32+K, K:32+K]

# First 20 images for training, last 20 for testing
train_x = random_data_obs[:20, ...]
train_y = random_data_gt[:20, ...]
test_x = random_data_obs[20:, ...]
test_y = random_data_gt[20:, ...]
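The image sizes in the cell above follow from standard convolution arithmetic: with stride 1, an 11-by-11 kernel and padding 5 keep the 36-by-36 size, and the slice [K:32+K] with K = 3 keeps 32 rows and columns (indices 3 to 34). A small sketch of that arithmetic, using a hypothetical helper conv2d_out_size that implements the output-size formula from the PyTorch Conv2d documentation:

```python
def conv2d_out_size(n, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis of a 2D convolution (PyTorch formula)."""
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

n_xy = 36
K = 3
after_conv = conv2d_out_size(n_xy, kernel_size=11, stride=1, padding=5)
crop = range(K, 32 + K)  # indices kept by the slice [K:32+K]

print(after_conv, len(crop), crop[0], crop[-1])  # 36 32 3 34
```

The crop therefore trims 3 pixels from the top and left and 1 pixel from the bottom and right of each 36-by-36 image.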

Prep data

We wrap the tensors in PyTorch TensorDatasets and DataLoaders for ingestion into the dlsia pipeline. DataLoaders make it easy to iterate over the data in batches (optionally shuffled) when feeding it to the networks.

[3]:
train_set = TensorDataset(train_x, train_y)
test_set = TensorDataset(test_x, test_y)

# Specify batch sizes
batch_size_train = 20
batch_size_test  = 20

# Set Dataloader parameters (Note: we randomly shuffle the training set upon each pass)
train_loader_params = {'batch_size': batch_size_train,'shuffle': True}
test_loader_params  = {'batch_size': batch_size_test, 'shuffle': False}

# Build Dataloaders
train_loader = DataLoader(train_set, **train_loader_params)
test_loader  = DataLoader(test_set, **test_loader_params)
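With 20 images per set and a batch size of 20, each loader yields a single batch per epoch. The sketch below uses random placeholder tensors of the same shapes (not the notebook's data) to show how the number of batches follows from the batch size; with the default drop_last=False, a smaller final batch is kept rather than discarded.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Placeholder tensors with the same shapes as train_x / train_y above
x = torch.rand(20, 1, 32, 32)
y = torch.rand(20, 1, 32, 32)
dataset = TensorDataset(x, y)

one_batch = DataLoader(dataset, batch_size=20, shuffle=True)
three_batches = DataLoader(dataset, batch_size=8, shuffle=False)  # drop_last=False by default

print(len(one_batch))                              # 1
print([xb.shape[0] for xb, _ in three_batches])    # [8, 8, 4]
```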

Construct Networks

dlsia offers several convolutional neural network architectures.

MSDNet

Mixed-scale Dense networks that probe different length scales using dilated convolutions.

[4]:
msdnet_model = msdnet.MixedScaleDenseNetwork(in_channels = 1,
                                             out_channels = 1,
                                             num_layers=40)
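Dilation is what lets a mixed-scale network cover large regions without large kernels: a k-by-k kernel with dilation d spans d(k - 1) + 1 pixels along each axis, and a chain of stride-1 layers adds up those extents. A small sketch of that arithmetic; the 3-by-3 kernel and the dilation sequence are illustrative assumptions (the SMSNet cell below uses dilation_choices=[1,2,3,4]), not MSDNet's internal defaults.

```python
def dilated_extent(kernel_size, dilation):
    """Pixels spanned by one dilated kernel along one axis."""
    return dilation * (kernel_size - 1) + 1

def receptive_field(kernel_size, dilations):
    """Receptive field of a chain of stride-1 dilated convolutions."""
    rf = 1
    for d in dilations:
        rf += dilated_extent(kernel_size, d) - 1  # each layer widens the field
    return rf

print(dilated_extent(3, 1), dilated_extent(3, 4))  # 3 9
print(receptive_field(3, [1, 2, 3, 4]))            # 21
```

Four 3-by-3 layers with dilations 1 to 4 thus see a 21-by-21 region while each kernel still has only 9 weights.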

TUNet

Tuneable U-Nets with a variety of user-customizable parameters.

[5]:
tunet_model = tunet.TUNet(image_shape=(32,32),
                          in_channels=1,
                          out_channels=1,
                          depth=3,
                          base_channels=10)
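A U-Net needs the image shape to be compatible with its downsampling steps. The sketch below assumes the classic U-Net layout, in which each level below the first halves the spatial size and multiplies the channel count by a growth factor (2 here); the helper unet_levels and the growth_rate value are hypothetical, so check the dlsia TUNet documentation for its exact pooling and channel settings.

```python
def unet_levels(image_shape, depth, base_channels, growth_rate=2):
    """Per-level (height, width, channels) of a classic U-Net encoder.

    Assumes (hypothetically) that every level below the first halves the
    spatial size and multiplies the channel count by growth_rate.
    """
    h, w = image_shape
    levels = []
    for level in range(depth):
        factor = 2 ** level
        if h % factor or w % factor:
            raise ValueError(f"{image_shape} is not divisible by {factor}")
        levels.append((h // factor, w // factor,
                       int(base_channels * growth_rate ** level)))
    return levels

print(unet_levels((32, 32), depth=3, base_channels=10))
# [(32, 32, 10), (16, 16, 20), (8, 8, 40)]
```

Under these assumptions, a 30-by-30 image would fail at the third level because 30 is not divisible by 4.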

TUNet3+

A newer U-Net variant that connects all encoder and decoder layers via carefully crafted upsampling/downsampling/convolution/concatenation bundles.

[6]:
tunet3plus_model = tunet3plus.TUNet3Plus(image_shape=(32,32),
                                         in_channels=1,
                                         out_channels=1,
                                         depth=3,
                                         base_channels=10)
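Connecting all encoder and decoder layers matches the full-scale skip connections of UNet 3+: every decoder stage receives feature maps from every scale, resampled to its own resolution. The sketch below only enumerates which maps feed a decoder level and by what factor they are resized, assuming adjacent levels differ by a factor of 2 in resolution; the function full_scale_sources is a hypothetical illustration of the general scheme, not dlsia's TUNet3Plus code.

```python
def full_scale_sources(level, depth):
    """Feature maps feeding decoder `level` under UNet 3+ style full-scale skips.

    Levels run from 0 (full resolution) to depth - 1 (the bottleneck); each
    level is assumed to have half the resolution of the one above it.
    Returns (source, operation, factor) tuples.
    """
    if not 0 <= level < depth - 1:
        raise ValueError("decoder levels run from 0 to depth - 2")
    sources = []
    for j in range(depth):
        if j < level:     # shallower encoder map: higher resolution, shrink it
            sources.append((f"encoder {j}", "downsample", 2 ** (level - j)))
        elif j == level:  # encoder map at the same scale
            sources.append((f"encoder {j}", "same scale", 1))
        else:             # deeper decoder map (or bottleneck): enlarge it
            name = "bottleneck" if j == depth - 1 else f"decoder {j}"
            sources.append((name, "upsample", 2 ** (j - level)))
    return sources

for lvl in range(2):
    print(lvl, full_scale_sources(lvl, depth=3))
```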

SMSNets

Sparse Mixed-Scale Networks: lean, randomly and sparsely connected variants of MSDNets.

[7]:
smsnet_model = smsnet.random_SMS_network(in_channels=1,
                                         out_channels=1,
                                         hidden_out_channels=[1],
                                         layers=40,
                                         dilation_choices=[1,2,3,4],
                                         #layer_probabilities=layer_probabilities,
                                         network_type="Regression")
[8]:
# View number of learnable parameters in each network
print("MSDNet :     ", helpers.count_parameters(msdnet_model), "parameters")
print("TUNet :      ", helpers.count_parameters(tunet_model), "parameters")
print("TUNet3plus : ", helpers.count_parameters(tunet3plus_model), "parameters")
print("SMSNet :     ", helpers.count_parameters(smsnet_model), "parameters")
MSDNet :      9182 parameters
TUNet :       46131 parameters
TUNet3plus :  49421 parameters
SMSNet :      735 parameters
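A common way to count learnable parameters in PyTorch is to sum numel() over the parameter tensors that require gradients. We assume helpers.count_parameters does something equivalent but have not checked its source; count_trainable below is a hypothetical stand-in. For a single Conv2d layer, the count is in_channels x out_channels x k x k weights plus out_channels biases.

```python
import torch.nn as nn

def count_trainable(model: nn.Module) -> int:
    """Total number of elements in parameters that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, padding=5)
print(count_trainable(conv))   # 1*1*11*11 + 1 = 122

wider = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=3)
print(count_trainable(wider))  # 10*20*3*3 + 20 = 1820
```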

Training

Training parameters

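We choose the training hyperparameters: the number of epochs, the loss function (mean squared error) and the learning rate. Each network gets its own Adam optimizer.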

[9]:
epochs = 30
criterion = nn.MSELoss()
learning_rate = 1e-2

# Define optimizers, one per network
optimizer_msd        = optim.Adam(msdnet_model.parameters(), lr=learning_rate)
optimizer_tunet      = optim.Adam(tunet_model.parameters(), lr=learning_rate)
optimizer_tunet3plus = optim.Adam(tunet3plus_model.parameters(), lr=learning_rate)
optimizer_smsnet     = optim.Adam(smsnet_model.parameters(), lr=learning_rate)

device = "cpu"
#device = helpers.get_device()  # Uncomment to get detected GPU

print('Device we will compute on: ', device)   # cuda:0 for GPU. Else, CPU
Device we will compute on:  cpu
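train_scripts.train_regression runs the training loop for us. For readers who want to see what one epoch of such a loop typically involves, here is a minimal generic sketch (zero the gradients, forward pass, MSE loss, backward pass, Adam step). It is not dlsia's implementation and omits its logging and correlation metrics; the one-layer model, the identity-mapping target and the helper train_one_epoch are placeholders.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Placeholder data: learn to reproduce the input image (identity mapping)
x = torch.rand(20, 1, 32, 32)
loader = DataLoader(TensorDataset(x, x.clone()), batch_size=20, shuffle=True)

model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

def train_one_epoch(model, loader, criterion, optimizer, device="cpu"):
    """One pass over the loader; returns the sample-weighted mean loss."""
    model.train()
    running = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()              # clear gradients from the last step
        loss = criterion(model(xb), yb)    # forward pass and loss
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        running += loss.item() * xb.shape[0]
    return running / len(loader.dataset)

losses = [train_one_epoch(model, loader, criterion, optimizer) for _ in range(50)]
print(f"first epoch {losses[0]:.4f}, last epoch {losses[-1]:.4f}")
```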

Training loops

[10]:
msdnet_model.to(device)
msdnet_model, results = train_scripts.train_regression(msdnet_model,
                                                       train_loader,
                                                       test_loader,
                                                       epochs,
                                                       criterion,
                                                       optimizer_msd,
                                                       device,
                                                       show=10)
msdnet_model = msdnet_model.cpu()

# clear out unnecessary variables from device (GPU) memory
#torch.cuda.empty_cache()
Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 3.0587e-02 | Validation Loss: 2.9307e-02
Training CC: 0.2502   Validation CC  : 0.2717
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 2.5751e-02 | Validation Loss: 2.5685e-02
Training CC: 0.3672   Validation CC  : 0.3708
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 2.4834e-02 | Validation Loss: 2.4734e-02
Training CC: 0.4068   Validation CC  : 0.4111
[11]:
tunet_model.to(device)
tunet_model, results = train_scripts.train_regression(tunet_model,
                                                      train_loader,
                                                      test_loader,
                                                      epochs,
                                                      criterion,
                                                      optimizer_tunet,
                                                      device,
                                                      show=10)
tunet_model = tunet_model.cpu()
Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 2.1023e-02 | Validation Loss: 2.0195e-02
Training CC: 0.5513   Validation CC  : 0.5673
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.6087e-02 | Validation Loss: 1.5932e-02
Training CC: 0.6799   Validation CC  : 0.6858
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.4304e-02 | Validation Loss: 1.4126e-02
Training CC: 0.7219   Validation CC  : 0.7252
[12]:
tunet3plus_model.to(device)
tunet3plus_model, results = train_scripts.train_regression(tunet3plus_model,
                                                           train_loader,
                                                           test_loader,
                                                           epochs,
                                                           criterion,
                                                           optimizer_tunet3plus,
                                                           device,
                                                           show=10)
tunet3plus_model = tunet3plus_model.cpu()
Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 2.9115e-02 | Validation Loss: 2.7833e-02
Training CC: 0.4941   Validation CC  : 0.5236
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.7936e-02 | Validation Loss: 1.7812e-02
Training CC: 0.6395   Validation CC  : 0.6462
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.5100e-02 | Validation Loss: 1.4969e-02
Training CC: 0.7016   Validation CC  : 0.7050
[13]:
smsnet_model.to(device)
smsnet_model, results = train_scripts.train_regression(smsnet_model,
                                                       train_loader,
                                                       test_loader,
                                                       epochs,
                                                       criterion,
                                                       optimizer_smsnet,
                                                       device,
                                                       show=10)
smsnet_model = smsnet_model.cpu()
Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 1.9555e-02 | Validation Loss: 1.8536e-02
Training CC: 0.5966   Validation CC  : 0.6203
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.5011e-02 | Validation Loss: 1.4870e-02
Training CC: 0.7039   Validation CC  : 0.7076
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.3773e-02 | Validation Loss: 1.3678e-02
Training CC: 0.7334   Validation CC  : 0.7355
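The logs report a CC value next to each loss. Assuming CC is the Pearson correlation coefficient between predicted and true pixel values (our reading; check dlsia's train_scripts to confirm), it can be computed as in the sketch below, where pearson_cc is a hypothetical helper. A value of 1 means a perfect positive linear relationship and 0 means no linear relationship; unlike the MSE, it ignores overall scale and offset.

```python
import math

def pearson_cc(pred, target):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(pred)
    if n != len(target) or n < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mp = sum(pred) / n
    mt = sum(target) / n
    cov = sum((p - mp) * (t - mt) for p, t in zip(pred, target))
    var_p = sum((p - mp) ** 2 for p in pred)
    var_t = sum((t - mt) ** 2 for t in target)
    return cov / math.sqrt(var_p * var_t)

print(pearson_cc([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))  # 1.0
print(pearson_cc([1.0, 2.0, 3.0], [3.0, 2.0, 1.0]))  # -1.0
```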

Save Networks

Each network class provides a save_network_parameters method for saving a trained network. It writes the following to a .pt file:

  • model’s state_dict: the network parameters learned through optimization/minimization during training,

  • model’s topo_dict: the network hyperparameters needed to re-initialize the same architecture.

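This follows standard PyTorch practice: rather than pickling the entire trained model object, only the learned weights and the architecture hyperparameters are stored, and at load time the weights are loaded into a freshly created network of the same architecture.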
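The save/load pattern described above can be sketched in plain PyTorch: store the constructor arguments (a topology dictionary) next to the state_dict, then rebuild the model from those arguments and load the weights into it. The class TinyNet, the functions save_with_topology and load_with_topology, and the dictionary keys "topo_dict" and "state_dict" are hypothetical stand-ins; dlsia's own file layout may differ.

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical small network whose architecture is set by its constructor."""
    def __init__(self, in_channels, out_channels, hidden_channels):
        super().__init__()
        self.topo_dict = {"in_channels": in_channels,
                          "out_channels": out_channels,
                          "hidden_channels": hidden_channels}
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, 3, padding=1),
        )

    def forward(self, x):
        return self.body(x)

def save_with_topology(model, path):
    # Architecture hyperparameters plus learned weights; no pickled model object
    torch.save({"topo_dict": model.topo_dict,
                "state_dict": model.state_dict()}, path)

def load_with_topology(path):
    checkpoint = torch.load(path, map_location="cpu")
    model = TinyNet(**checkpoint["topo_dict"])       # rebuild the architecture
    model.load_state_dict(checkpoint["state_dict"])  # then fill in the weights
    return model

model = TinyNet(in_channels=1, out_channels=1, hidden_channels=4)
path = os.path.join(tempfile.mkdtemp(), "tiny.pt")
save_with_topology(model, path)
restored = load_with_topology(path)

x = torch.rand(2, 1, 32, 32)
with torch.no_grad():
    print(torch.equal(model(x), restored(x)))  # True
```

Because the file holds only plain Python values and tensors, it can be read back without the original model object, as long as the class definition is available.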

[14]:
msdnet_model.save_network_parameters("this_msdnet.pt")
smsnet_model.save_network_parameters("this_smsnet.pt")
tunet_model.save_network_parameters("this_tunet.pt")

Load networks from file

Each network module provides a function that reads such a .pt file, builds a new network from the stored hyperparameters, and loads the learned weights into it.

[15]:
copy_msdnet = msdnet.MSDNetwork_from_file("this_msdnet.pt")
copy_smsnet = smsnet.SMSNetwork_from_file("this_smsnet.pt")
copy_tunet = tunet.TUNetwork_from_file("this_tunet.pt")

Verify loaded networks

To verify the round trip, we run each original network and its loaded copy on the test images and check that the outputs agree to within 1e-8.

[16]:
with torch.no_grad():
    r1 = msdnet_model(test_x)
    r2 = copy_msdnet(test_x)
delta = r1-r2
assert torch.max(torch.abs(delta)) < 1e-8
[17]:
with torch.no_grad():
    r1 = smsnet_model(test_x)
    r2 = copy_smsnet(test_x)
delta = r1-r2
assert torch.max(torch.abs(delta)) < 1e-8
[18]:
with torch.no_grad():
    r1 = tunet_model(test_x)
    r2 = copy_tunet(test_x)
delta = r1-r2
assert torch.max(torch.abs(delta)) < 1e-8
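The three checks above share one pattern, which could be factored into a helper. The sketch below (a hypothetical outputs_match function) also switches both networks to evaluation mode, so that layers such as dropout or batch normalization, if a network has them, behave deterministically, and it uses torch.allclose with an explicit absolute tolerance. It is demonstrated on a copied nn.Sequential rather than on the dlsia networks.

```python
import copy

import torch
import torch.nn as nn

def outputs_match(model_a, model_b, x, atol=1e-8):
    """Return True if both models give the same output for x (within atol).

    Note: puts both models into eval mode as a side effect.
    """
    model_a.eval()
    model_b.eval()
    with torch.no_grad():
        return torch.allclose(model_a(x), model_b(x), rtol=0.0, atol=atol)

net = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.BatchNorm2d(4),
                    nn.ReLU(), nn.Conv2d(4, 1, 3, padding=1))
twin = copy.deepcopy(net)
x = torch.rand(20, 1, 32, 32)
print(outputs_match(net, twin, x))  # True

with torch.no_grad():
    twin[0].bias.add_(1.0)          # perturb the copy
print(outputs_match(net, twin, x))  # False
```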