RBM Toolbox
RBM toolbox is a MATLAB toolbox for training RBMs. It builds on the DeepLearnToolbox by Rasmus Berg Palm.
RBM toolbox supports, among other things:
- Training RBMs with class labels, see [1,2]
- generative training objective
- discriminative training objective
- hybrid training objective
- semi-supervised learning
- CD-k (contrastive divergence with k Gibbs steps)
- PCD (persistent contrastive divergence)
- Various RBM sampling functions (pictures / movies)
- Classification support
- Regularization: L1, L2, maxL2norm, sparsity, early-stopping
- Support for custom error functions
Settings
Training objectives
The RBM toolbox supports four different RBM training objectives. For a detailed description, refer to [2].
- rbmgenerative: -log(p(x)), or -log(p(x,y)) if classRBM is 1
- rbmdiscriminative: -log(p(y|x)) [2]
- rbmhybrid: -(1-alpha) log(p(y|x)) - alpha log(p(x)) [2]
- rbmsemisuplearn: TYPE + unsupervised, where TYPE is one of {generative, discriminative, hybrid} and unsupervised is generative training on unlabeled data [2]
The RBM training objective is set by supplying a function handle to one of the four training functions through opts.train_func, e.g. opts.train_func = @rbmgenerative.
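Written out as losses, the four objectives in this section take roughly the following form. This is a sketch based on the descriptions here and on [2]: the discriminative term is written as p(y | x) as in [2], and weighting the unsupervised term by a factor beta is an assumption inferred from the opts.semisup_beta setting, not something stated in this README.

```latex
\begin{align*}
\mathcal{L}_{\text{gen}}     &= -\log p(x) \quad \text{(or } -\log p(x, y) \text{ if classRBM is 1)} \\
\mathcal{L}_{\text{disc}}    &= -\log p(y \mid x) \\
\mathcal{L}_{\text{hybrid}}  &= -(1 - \alpha)\,\log p(y \mid x) \;-\; \alpha\,\log p(x) \\
\mathcal{L}_{\text{semisup}} &= \mathcal{L}_{\text{TYPE}} \;+\; \beta\,\mathcal{L}_{\text{unsup}},
  \qquad \text{TYPE} \in \{\text{gen}, \text{disc}, \text{hybrid}\}
\end{align*}
```

With alpha = 0 the hybrid loss reduces to the discriminative loss, and with alpha = 1 to the generative one.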
Regularization
RBM toolbox supports L1 and L2 regularization, as well as regularization through a maximum L2 norm of the incoming weights to each neuron [4]. Sparsity is implemented as described in [2]. Dropout of hidden units is implemented as described in [1].
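The maximum L2 norm constraint can be illustrated with a few lines of plain MATLAB. This is a minimal sketch, not the toolbox's implementation: the variable names, the toy sizes and the cap value are hypothetical, and it assumes the weight matrix is stored with one row per hidden unit, so that each row holds the incoming weights of one neuron.

```matlab
rng(0);                               % reproducible toy weights
W = randn(4, 6);                      % hypothetical weights: 4 hidden units x 6 visible units
max_norm = 1.5;                       % hypothetical cap on the L2 norm of each row

row_norms = sqrt(sum(W.^2, 2));       % L2 norm of the incoming weights of each hidden unit
scale = min(1, max_norm ./ row_norms);% shrink only the rows that exceed the cap
W_clipped = bsxfun(@times, W, scale); % rescale each row by its own factor

new_norms = sqrt(sum(W_clipped.^2, 2));
disp([row_norms new_norms]);          % norms before and after the projection
```

Rows whose norm is already below the cap are left unchanged, so the constraint only acts on weights that have grown large.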
When training a classification RBM (opts.classRBM = 1) and a validation set is given through opts.x_val and opts.y_val, early stopping can be used. The patience for early stopping can be specified with opts.patience.
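The idea behind patience-based early stopping can be sketched independently of the toolbox. The validation errors below are made-up numbers and the stopping rule is a common formulation (stop once the best validation error is `patience` epochs old); the exact rule used inside the toolbox may differ.

```matlab
val_error = [0.30 0.25 0.22 0.23 0.24 0.22 0.25 0.26 0.27 0.28]; % hypothetical validation errors per epoch
patience  = 3;                       % epochs to wait for an improvement

best_err   = inf;                    % lowest validation error seen so far
best_epoch = 0;                      % epoch at which it was seen
stop_epoch = numel(val_error);       % default: run all epochs

for epoch = 1:numel(val_error)
    if val_error(epoch) < best_err
        best_err   = val_error(epoch);   % strict improvement: remember it
        best_epoch = epoch;
    elseif epoch - best_epoch >= patience
        stop_epoch = epoch;              % no improvement for `patience` epochs
        break;
    end
end

fprintf('best epoch %d (error %.2f), stopped at epoch %d\n', best_epoch, best_err, stop_epoch);
```

In a real training loop the model parameters from the best epoch would typically be kept, so that stopping late does not cost accuracy.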
Training DBN's
A DBN can be trained by giving multiple hidden sizes to dbnsetup, e.g. sizes = [500 500] for a two-layer DBN with 500 hidden units in each layer.
TODO add wake sleep algorithm?
Settings table
The table shows which fields in the opts struct apply to each training objective.
Setting | @rbmgenerative | @rbmdiscriminative | @rbmhybrid | @rbmsemisuplearn |
---|---|---|---|---|
traintype | x | x | x | |
cdn | x | x | x | |
numepochs | x | x | x | x |
classRBM | x | x | x | x |
err_func ¹ | x | x | x | x |
test_interval ¹ | x | x | x | x |
learningrate | x | x | x | x |
momentum | x | x | x | x |
L1 | x | x | x | x |
L2norm | x | x | x | x |
sparsity | x | x | x | x |
dropout_hidden | x | x | x | x |
early_stopping ¹ | x | x | x | x |
patience ¹ | x | x | x | x |
y_train ² | x | x | x | x |
x_val ² | x | x | x | x |
y_val ² | x | x | x | x |
x_semisup | | | | x |
hybrid_alpha | | | x | |
semisup_type | | | | x |
semisup_beta | | | | x |

1. Applies if classRBM is 1 and x_val and y_val are set
2. Applies if classRBM is 1
Examples
Example 1 - generative learning p(x)
Training RBMs in RBM toolbox is controlled through three functions:
- dbncreateopts creates an opts struct. The opts struct controls the learning rate, number of epochs, regularization, training type etc. The help for dbncreateopts describes all valid fields in the opts struct.
- dbnsetup sets up the DBN network. A DBN with a single layer is equivalent to an RBM.
- dbntrain trains the DBN.
The following example trains a generative RBM with 500 hidden units and visualizes the learned weights. Note that the learning rate is controlled through the opts.learningrate parameter. opts.learningrate is a function which takes the current epoch and the current momentum as arguments and returns the learning rate. Similarly, opts.momentum is a function that takes the current epoch and returns the momentum. When opts.train_func is set to @rbmgenerative, the toolbox outputs the reconstruction error after each epoch. The reconstruction error should not be interpreted as a measure of the goodness of the model, see [3].
rng('default');rng(0);
load mnist_uint8;
train_x = double(train_x) / 255;
sizes = [500]; % hidden layer size
[opts, valid_fields] = dbncreateopts();
opts.numepochs = 50;
opts.traintype = 'CD';
opts.classRBM = 0;
opts.train_func = @rbmgenerative;
%% Set learningrate
eps = 0.05; % initial learning rate
f = 0.97; % learning rate decay
opts.learningrate = @(t,momentum) eps.*f.^t*(1-momentum);
% Set momentum
T = 25; % momentum ramp up
p_f = 0.9; % final momentum
p_i = 0.5; % initial momentum
opts.momentum = @(t) ifelse(t < T, p_i*(1-t/T)+(t/T)*p_f,p_f);
dbncheckopts(opts,valid_fields); % check validity of the opts struct
dbn = dbnsetup(sizes, train_x, opts); % set up the network
dbn = dbntrain(dbn, train_x, opts); % train the network
figure;visualize(dbn.rbm{1}.W(1:144,:)'); % show the weights of the first 144 hidden units
set(gca,'visible','off');
In the example the learning rate (blue) starts at 0.05 and decays with each epoch. The momentum (green) ramps up over 25 epochs, as shown in the figure.
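The two schedules from the example can be evaluated on their own to see their shape. The sketch below is self-contained and makes two substitutions that are not in the example code: the ramp is written with min instead of the ifelse helper, and the initial learning rate is called eps0 so that MATLAB's built-in eps is not shadowed.

```matlab
eps0 = 0.05;                         % initial learning rate (eps in the example)
f    = 0.97;                         % learning rate decay per epoch
T    = 25;                           % number of epochs over which momentum ramps up
p_i  = 0.5;                          % initial momentum
p_f  = 0.9;                          % final momentum

% Linear ramp from p_i to p_f over T epochs, then constant at p_f.
momentum     = @(t) p_i + (p_f - p_i) .* min(t ./ T, 1);
% Same form as the handle in the example: decayed rate scaled by (1 - momentum).
learningrate = @(t, m) eps0 .* f.^t .* (1 - m);

epochs = 1:50;
mom = momentum(epochs);              % momentum per epoch
lr  = learningrate(epochs, mom);     % value returned by the handle per epoch

fprintf('epoch %2d: lr = %.5f, momentum = %.3f\n', [epochs([1 25 50]); lr([1 25 50]); mom([1 25 50])]);
```

Calling plot(epochs, lr, 'b', epochs, mom, 'g') afterwards draws the two curves. Because the handle multiplies by (1 - momentum), the value it returns is smaller than eps0 from the first epoch on.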
Finally the weights can be visualized:
Example 2 - Generative RBM with labels p(x,y)
A classification RBM can be trained by setting opts.classRBM to 1 and setting opts.y_train to the training labels. The training labels must be one-of-K encoded.
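If the labels are available as integers rather than as a one-of-K matrix, the encoding can be built as follows. This is a small sketch with made-up labels; it assumes classes are numbered 1..K, and the variable names are hypothetical.

```matlab
labels = [3; 1; 2; 3];                       % hypothetical integer class labels in 1..K
K = max(labels);                             % number of classes
N = numel(labels);                           % number of samples

Y = zeros(N, K);                             % one row per sample, one column per class
Y(sub2ind(size(Y), (1:N)', labels)) = 1;     % set exactly one entry per row

[~, recovered] = max(Y, [], 2);              % decode back to integer labels
disp(Y);
```

Decoding with max along the second dimension is the same operation the example below applies to test_y.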
When opts.classRBM is 1, RBM toolbox reports the training error. The default error measure is accuracy, but you may supply custom error measures through opts.error_func. If opts.x_val and opts.y_val are given, the validation error is also reported.
In the example the validation error is calculated after each epoch, i.e. opts.test_interval is set to 1. We also enable early stopping with a patience of 5, i.e. training stops if no progress has been made in 5 epochs.
rng('default');rng(0);
load mnist_uint8;
train_x = double(train_x)/255;
test_x = double(test_x)/255;
train_y = double(train_y);
test_y = double(test_y);
sizes = [500]; % hidden layer size
[opts, valid_fields] = dbncreateopts();
opts.early_stopping = 1;
opts.patience = 5;
opts.numepochs = 50;
opts.traintype = 'CD';
opts.classRBM = 1;
opts.y_train = train_y;
opts.x_val = test_x;
opts.y_val = test_y;
opts.test_interval = 1;
opts.train_func = @rbmgenerative;
%% Set learningrate
eps = 0.05; % initial learning rate
f = 0.97; % learning rate decay
opts.learningrate = @(t,momentum) eps.*f.^t*(1-momentum);
% Set momentum
T = 25; % momentum ramp up
p_f = 0.9; % final momentum
p_i = 0.5; % initial momentum
opts.momentum = @(t) ifelse(t < T, p_i*(1-t/T)+(t/T)*p_f,p_f);
dbncheckopts(opts,valid_fields); % check validity of the opts struct
dbn = dbnsetup(sizes, train_x, opts); % set up the network
dbn = dbntrain(dbn, train_x, opts); % train the network
%% Do predictions
pred_val = dbnpredict(dbn,test_x);
[~, labels_val] = max(test_y,[],2);
acc_val = mean(pred_val == labels_val);
%% plot weights
figure;visualize(dbn.rbm{1}.W(1:144,:)');
set(gca,'visible','off');
%% plot errors (in a new figure so the weight plot is kept)
figure;
plot([dbn.rbm{1}.val_error',dbn.rbm{1}.train_error'])
legend({'Validation error','Train error'})
[min_val,min_idx] = min(dbn.rbm{1}.val_error);
hold on; plot(min_idx,min_val,'xr'); hold off;
xlabel('Epoch'); ylabel('Error'); grid on;
For classification RBMs, predictions can be calculated with dbnpredict, which returns a label, or with dbnclassprobs, which returns the predicted class probabilities.
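The relation between class probabilities, predicted labels and accuracy can be sketched without the toolbox. The probability matrix below is made up, and treating the label as the index of the largest probability is an assumption about how the two functions relate, not a statement about their implementation.

```matlab
class_probs = [0.1 0.7 0.2;          % hypothetical class probabilities,
               0.8 0.1 0.1;          % one row per sample, one column per class
               0.3 0.3 0.4];
true_y = [0 1 0;                     % hypothetical one-of-K encoded targets
          1 0 0;
          0 1 0];

[~, pred]  = max(class_probs, [], 2); % most probable class per sample
[~, truth] = max(true_y, [], 2);      % decode the one-of-K targets
acc = mean(pred == truth);            % fraction of correct predictions

fprintf('accuracy: %.3f\n', acc);
```

The last two steps mirror how acc_val is computed from pred_val and test_y in the example.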
The learned weights can be visualized with the visualize
function.
The training error and validation error can be visualized as well:
Note that in this example the validation error is lower than the training error, which is not typical. In the plot, the red x indicates the lowest validation error.
Example 3 - PCD, layers and sampling
In example 3 we use PCD to train a classification DBN using the generative training objective. In the other examples opts.traintype has been CD, which means contrastive divergence [5]. In this example we use PCD, persistent contrastive divergence [6].
In CD the Gibbs chains are initialized at the data points. PCD differs by keeping a number of persistent chains which are used to initialize the Gibbs sampling. Because the Gibbs chains are persistent, they can wander further away from the data than in CD, which typically means that PCD training requires a lower learning rate than CD training.
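The difference between the two schemes comes down to where the negative-phase Gibbs chain starts. The following is a minimal sketch of one weight update for a toy binary RBM, not the toolbox's code: all sizes, names and the learning rate are hypothetical, and bias updates are left out to keep it short.

```matlab
rng(0);                                        % reproducible random numbers
n_vis = 6; n_hid = 4; n_batch = 5;             % hypothetical toy sizes
k = 1;                                         % number of Gibbs steps (the k in CD-k)
use_pcd = true;                                % set to false for plain CD-k
lr = 0.01;                                     % hypothetical learning rate

W = 0.01 * randn(n_hid, n_vis);                % weights, hidden x visible
b = zeros(1, n_vis);                           % visible biases
c = zeros(1, n_hid);                           % hidden biases
v_data = double(rand(n_batch, n_vis) > 0.5);   % made-up binary mini-batch
chains = double(rand(n_batch, n_vis) > 0.5);   % persistent chain states (used by PCD only)

sigm   = @(a) 1 ./ (1 + exp(-a));              % logistic function
sample = @(p) double(rand(size(p)) < p);       % draw binary states from probabilities
hid_p  = @(v) sigm(bsxfun(@plus, v * W', c));  % p(h = 1 | v)
vis_p  = @(h) sigm(bsxfun(@plus, h * W, b));   % p(v = 1 | h)

h_data = hid_p(v_data);                        % positive phase: driven by the data

if use_pcd
    v_neg = chains;                            % PCD: continue the chains from the last update
else
    v_neg = v_data;                            % CD: restart the chains at the data
end
for step = 1:k                                 % k steps of alternating Gibbs sampling
    h_neg = sample(hid_p(v_neg));
    v_neg = sample(vis_p(h_neg));
end
h_neg = hid_p(v_neg);                          % hidden probabilities for the negative phase
chains = v_neg;                                % keep the chain state for the next update

dW = (h_data' * v_data - h_neg' * v_neg) / n_batch;  % positive minus negative statistics
W  = W + lr * dW;                              % gradient step on the weights
```

With use_pcd set to true, the chains are never reset to the data, which is why they can drift away from it over many updates.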
In the example we use a DBN with two layers, each with 500 hidden neurons. In RBM toolbox, non-top-layer RBMs are always trained with the generative training objective.
rng('default');rng(0);
load mnist_uint8;
train_x = double(train_x)/255;
test_x = double(test_x)/255;
train_y = double(train_y);
test_y = double(test_y);
sizes = [500 500]; % hidden layer size
[opts, valid_fields] = dbncreateopts();
opts.numepochs = 100;
opts.traintype = 'PCD';
opts.classRBM = 1;
opts.y_train = train_y;
opts.x_val = test_x;
opts.y_val = test_y;
opts.test_interval = 1;
opts.train_func = @rbmgenerative;
%% Set learningrate
eps = 0.001; % initial learning rate
f = 0.97; % learning rate decay
opts.learningrate = @(t,momentum) eps.*f.^t*(1-momentum);
% Set momentum
T = 25; % momentum ramp up
p_f = 0.9; % final momentum
p_i = 0.5; % initial momentum
opts.momentum = @(t) ifelse(t < T, p_i*(1-t/T)+(t/T)*p_f,p_f);
dbncheckopts(opts,valid_fields); % check validity of the opts struct
dbn = dbnsetup(sizes, train_x, opts); % set up the network
dbn = dbntrain(dbn, train_x, opts); % train the network
%% Sample digits: one-of-K class vectors for 100 samples, 10 per class
class_vec = zeros(100,size(train_y,2));
for i = 1:size(train_y,2)
    class_vec((i-1)*10+1:i*10,i) = 1;
end
digits = dbnsample(dbn,100,10000,class_vec);
Example 4 - Discriminative training
Look in folder mnist_cRBM_discriminative
Example 5 - Hybrid training
Look in folders
mnist_cRBM_PCD
mnist_cRBM_CD
mnist_cRBM_CD_nomomentum
Example 6 - Semi-supervised learning
References
[1] N. Srivastava and G. Hinton, “Dropout: A Simple Way to Prevent Neural Networks from Overfitting,” J. Mach. …, 2014.
[2] H. Larochelle and Y. Bengio, “Classification using discriminative restricted Boltzmann machines,” … 25th Int. Conf. Mach. …, 2008.
[3] G. Hinton, “A practical guide to training restricted Boltzmann machines,” Momentum, 2010.
[4] G. E. Hinton, N. Srivastava, A. Krizhevsky, I. Sutskever, and R. R. Salakhutdinov, “Improving neural networks by preventing co-adaptation of feature detectors,” Jul. 2012.
[5] G. Hinton, “Training products of experts by minimizing contrastive divergence,” Neural Comput., 2002.
[6] T. Tieleman, “Training restricted Boltzmann machines using approximations to the likelihood gradient,” … 25th Int. Conf. Mach. …, 2008.
Copyright (c) 2014, Søren Kaae Sønderby (skaaesonderby@gmail.com) All rights reserved.