
LCA CIFAR Tutorial Part 3: Learning Weights


Introduction

This page is a continuation of the tutorial for using PetaVision and the locally competitive algorithm (LCA) for sparse coding images. The previous page, on using LCA to sparse code several images in parallel, is here. On this page, we will see how to make the weights learn as the system processes images, in order to improve the sparse coding. In what follows, make sure that you are in the LCA-CIFAR-Tutorial directory, inside the build directory, that you used in the single-image tutorial. Also make sure that the environment variables PV_SOURCEDIR and PYTHONPATH are defined as before:

$ PV_SOURCEDIR=/path/to/source/OpenPV # change to the correct path on your system
$ export PV_SOURCEDIR
$ export PYTHONPATH="$PYTHONPATH":$PV_SOURCEDIR/python

1. Displaying weights.

Recall that in Part 1 of the tutorial, we saw how to display the weights. There are 128 weight patches, of size 8x8x3, which we arranged in a tableau 12 patches wide by 11 patches tall (with 4 blank spaces).

$ python
>>> import pvtools
>>> import matplotlib.pyplot
>>> import numpy
>>> matplotlib.interactive(True)
>>> weights_data = pvtools.readpvpfile("output_full_parallelism/LeakyIntegratorToInputError.pvp")
>>> weights_values = weights_data['values']
>>> weights = weights_values[0, 0]
>>> wgts_display = pvtools.arrangedictionary(weights)
>>> wgts_display_8bit = numpy.uint8(wgts_display * 127.5 + 127.500001)
>>> matplotlib.pyplot.clf()
>>> matplotlib.pyplot.imshow(wgts_display_8bit)

The weights are mostly zero, shown by medium gray, with a few random weights initialized to nonzero values. We have seen that it is possible to create a sparse coding of an image with these weights, but that the reconstruction error is noticeable. With a better choice of weights, we could achieve both a sparser and more accurate reconstruction. Rather than try to choose the best weights a priori, we can have the system learn better weights by performing sparse coding on an image (or a batch of images) and then nudging the weights in a direction that improves the sparse coding. By doing this over enough images to form a representative sample, we can learn weights that are adapted to the image space.
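PetaVision performs this update internally when weight plasticity is enabled, as we will do below. As a rough illustration of the kind of update involved, here is a minimal numpy sketch, with hypothetical names and a simplified flat-vector layout, of a gradient-style dictionary update driven by the reconstruction residual (this is a conceptual sketch, not PetaVision's actual code):

>>> import numpy
>>> def update_dictionary(W, inputs, activations, learning_rate=0.01):
...     # W: (num_features, num_pixels) dictionary; inputs: (batch, num_pixels);
...     # activations: (batch, num_features) sparse codes from the LCA dynamics.
...     recon = activations @ W    # reconstruct each input from its sparse code
...     residual = inputs - recon  # reconstruction error, per image
...     # Nudge each feature in the direction that reduces the residual,
...     # in proportion to how active that feature was (a Hebbian-style update).
...     return W + learning_rate * activations.T @ residual
...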

2. Learning weights.

In the previous run, we processed a batch of 16 images for 400 timesteps, and then stopped. To learn weights, we will process the first batch as usual, then update the weights, then process an additional 16 images, then update the weights again, and so on. We will need to increase the number of images so that multiple batches will be processed, and turn on weight updates so that the weights can learn. Create a new .lua file and edit it:

$ cp input/LCA_CIFAR_Full_Parallelism.lua input/LCA_CIFAR_Learning_Weights.lua

In the input/LCA_CIFAR_Learning_Weights.lua file, change the line

local numImages            = 16;  --Number of images to process.

to

local numImages            = 256; --Number of images to process.

The run will now last for 16 display periods (256 images divided among batches of 16) of 400 timesteps each, for a total of 6400 timesteps.
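As a quick check of the arithmetic (these values simply restate the .lua settings):

>>> numImages = 256; nbatch = 16; displayPeriod = 400
>>> (numImages // nbatch) * displayPeriod
6400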

We have been writing the layer output every timestep. However, for this run we only need to look at the end result of each sparse coding. Accordingly, we only want to write the layer output at times 400, 800, etc. We also want to write the updated weights at these times, as well as write the initial weights. To make these changes, edit the lines:

local layerWriteStep       = 1.0;
local layerInitialWrite    = 0.0;
local connInitialWrite     = 0;
local connWriteStep        = infinity;

to

local layerWriteStep       = displayPeriod;
local layerInitialWrite    = displayPeriod;
local connInitialWrite     = 0;
local connWriteStep        = displayPeriod;

To turn on weight updates, change the line

local plasticityFlag       = false; --Determines if we are learning our dictionary or holding it constant

to

local plasticityFlag       = true;  --Determines if we are learning our dictionary or holding it constant

Finally, change the outputPath line, as usual, from

local outputPath           = "output_full_parallelism";

to

local outputPath           = "output_learning_weights";

We once again generate the .params file and run OpenPV.

$ lua input/LCA_CIFAR_Learning_Weights.lua > input/LCA_CIFAR_Learning_Weights.params
$ mpiexec -np 8 ../tests/BasicSystemTest/Release/BasicSystemTest -p input/LCA_CIFAR_Learning_Weights.params -batchwidth 8 -l LCA-CIFAR-run.log -t 2
$ python
>>> import pvtools
>>> import matplotlib.pyplot
>>> matplotlib.interactive(True)
>>> weights_data = pvtools.readpvpfile("output_learning_weights/LeakyIntegratorToInputError.pvp")
>>> weights_data['time']
array([   0.,  400.,  800., 1200., 1600., 2000., 2400., 2800., 3200.,
       3600., 4000., 4400., 4800., 5200., 5600., 6000., 6400.])

We wrote the weights at times 0, 400, 800, ..., 6400, that is, at 17 distinct times. This is reflected in the shape of the 'values' field:

>>> weights_data['values'].shape
(17, 1, 128, 8, 8, 3)

At each of the 17 write times, the weights form a data structure of size 128-by-8-by-8-by-3, which we refer to as a "frame." To see how the weights evolve, we rearrange the values into a 17-by-88-by-96-by-3 data structure, using the arrangedictionary function in the pvtools package, which we used in Part 1 of the tutorial:

>>> import numpy
>>> wgts_display = numpy.empty((17,88,96,3))
>>> for t in range(17):
...   weights1 = weights_data['values'][t,0,:,:,:,:]
...   wgts_display[t,:,:,:] = pvtools.arrangedictionary(weights1)
...
>>> wgts_display_8bit = numpy.uint8(wgts_display * 127.5 + 127.500001)

We can now compare the initial weights and the weights after learning.

>>> matplotlib.pyplot.figure()
>>> matplotlib.pyplot.imshow(wgts_display_8bit[0,:,:,:])
>>> matplotlib.pyplot.title('Initial Weights')

>>> matplotlib.pyplot.figure()
>>> matplotlib.pyplot.imshow(wgts_display_8bit[16,:,:,:])
>>> matplotlib.pyplot.title('Weights at t=6400')

Note that the speckled initial data has been replaced by more coherent image patches. Some of the patches are more or less solid patches of a bright color, and others show transitions from one shade to another across an edge. Although the weights are still not particularly sharp, it is easier to see how one might assemble these weight patches into a recognizable image.

You can look at other stages of the weights by taking a different slice of the wgts_display_8bit variable in the code above. We can also create an animated .gif file showing the evolution, using the imageio package.

>>> import imageio
>>> gifFrames = []
>>> for frame in range(17):
...   gifFrames.append(wgts_display_8bit[frame,:,:,:])
...
>>> imageio.mimsave('weights.gif', gifFrames)
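For example, to view the weights midway through training (frame 8, corresponding to time 3200), take a different slice:

>>> matplotlib.pyplot.figure()
>>> matplotlib.pyplot.imshow(wgts_display_8bit[8,:,:,:])
>>> matplotlib.pyplot.title('Weights at t=3200')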

Recall that in Part 1, we sparse coded a single image using unlearned weights, and compared the reconstruction to the original image. The reconstruction was recognizably similar, but there were noticeable differences. We can rerun the single-image run but with the new weights, to see what the reconstruction looks like.

$ cp input/LCA_CIFAR.lua input/LCA_CIFAR_Use_Learned_Weights.lua

In the input/LCA_CIFAR_Use_Learned_Weights.lua file, change the line

local dictionaryFile       = nil;   --nil for initial weights, otherwise, specifies the weights file to load.

to

local dictionaryFile       = "output_learning_weights/Checkpoints/Checkpoint6400/LeakyIntegratorToInputError_W.pvp";

and the line

local outputPath           = "output";

to

local outputPath           = "output_use_learned_weights";

Then generate the params file and run the executable. Note that we have returned to a .lua file where nbatch is still one, so we cannot use the -batchwidth option. We could still use -rows and -columns, although for a run this short there is no advantage to doing so.

$ lua input/LCA_CIFAR_Use_Learned_Weights.lua > input/LCA_CIFAR_Use_Learned_Weights.params
$ ../tests/BasicSystemTest/Release/BasicSystemTest -p input/LCA_CIFAR_Use_Learned_Weights.params -l LCA-CIFAR-run.log -t 2
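If you did want to distribute a run like this across MPI processes, a hypothetical invocation using -rows and -columns might look like the following (with nbatch equal to one, the number of rows times the number of columns must equal the number of processes; this command is an illustration, not required here):

$ mpiexec -np 4 ../tests/BasicSystemTest/Release/BasicSystemTest -p input/LCA_CIFAR_Use_Learned_Weights.params -rows 2 -columns 2 -l LCA-CIFAR-run.log -t 2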

Let's compare the total energy graph between the original run and the run with the learned weights.

$ python
>>> import pvtools
>>> import matplotlib.pyplot
>>> import numpy
>>> matplotlib.interactive(True)

>>> energy_initweights = pvtools.readenergyprobe(
...   probe_name='TotalEnergyProbe',
...   directory='output',
...   batch_element=0)
>>> matplotlib.pyplot.figure(1)
>>> matplotlib.pyplot.plot(energy_initweights['time'], energy_initweights['values'])
>>> matplotlib.pyplot.title('Total Energy, Initial Weights')

>>> energy_learnedweights = pvtools.readenergyprobe(
...   probe_name='TotalEnergyProbe',
...   directory='output_use_learned_weights',
...   batch_element=0)
>>> matplotlib.pyplot.figure(2)
>>> matplotlib.pyplot.plot(energy_learnedweights['time'], energy_learnedweights['values'])
>>> matplotlib.pyplot.title('Total Energy, Learned Weights')

With the learned weights, the cost function decreases much more rapidly at the beginning, and converges to a lower value (approx. 235 vs. approx. 778).
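To see the two curves on common axes, you can also overlay them in a single figure, reusing the variables from above:

>>> matplotlib.pyplot.figure(3)
>>> matplotlib.pyplot.plot(energy_initweights['time'], energy_initweights['values'], label='initial weights')
>>> matplotlib.pyplot.plot(energy_learnedweights['time'], energy_learnedweights['values'], label='learned weights')
>>> matplotlib.pyplot.legend()
>>> matplotlib.pyplot.title('Total Energy, Initial vs. Learned Weights')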

We can also compare the reconstructed images.

>>> recon_from_initwgts = pvtools.readpvpfile('output/InputRecon.pvp')
>>> matplotlib.pyplot.figure(1)
>>> matplotlib.pyplot.clf()
>>> recon_from_initwgts = recon_from_initwgts['values'][400]
>>> recon_min = numpy.min(recon_from_initwgts)
>>> recon_max = numpy.max(recon_from_initwgts)
>>> recon_from_initwgts_normalized = (recon_from_initwgts - recon_min) / (recon_max - recon_min)
>>> recon_from_initwgts_8bit = numpy.uint8(recon_from_initwgts_normalized * 255)
>>> matplotlib.pyplot.imshow(recon_from_initwgts_8bit)

>>> recon_from_learned = pvtools.readpvpfile('output_use_learned_weights/InputRecon.pvp')
>>> matplotlib.pyplot.figure(2)
>>> matplotlib.pyplot.clf()
>>> recon_from_learned = recon_from_learned['values'][400]
>>> recon_min = numpy.min(recon_from_learned)
>>> recon_max = numpy.max(recon_from_learned)
>>> recon_from_learned_normalized = (recon_from_learned - recon_min) / (recon_max - recon_min)
>>> recon_from_learned_8bit = numpy.uint8(recon_from_learned_normalized * 255)
>>> matplotlib.pyplot.imshow(recon_from_learned_8bit)

The reconstruction from the learned weights is much less noisy.
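To quantify this, you could compare each reconstruction to the original input image. Assuming the input layer's activity was written to Input.pvp, as in the earlier parts of the tutorial (the filename here is an assumption), a mean-squared-error check might look like:

>>> input_data = pvtools.readpvpfile('output/Input.pvp')  # filename assumed
>>> original = input_data['values'][400]
>>> numpy.mean((recon_from_initwgts - original)**2)   # MSE, initial weights
>>> numpy.mean((recon_from_learned - original)**2)    # MSE, learned weights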

Conclusion

In this page, we saw how to use weight plasticity to adapt the weights of a connection to the input dataset. In the next part of the tutorial (currently under development), we will look at the params file in more detail than we have so far.