ivan-vasilev / neuralnetworks

Java deep learning algorithms and deep neural networks with GPU acceleration

DBN with softmaxlayer on top

NN-Research opened this issue

Hi everyone,

I am trying to build a DBN with two layers: the first one should be trained with CD (contrastive divergence), and the last one with backpropagation (BP). The last layer should act like a softmax layer, so the network can be used for classification (see the short softmax sketch after the code).

This is the code I have written so far:

// NOTE: the package paths below assume the standard com.github.neuralnetworks
// layout; adjust them to the library version you are using.
import java.util.HashMap;
import java.util.Map;

import com.github.neuralnetworks.architecture.NeuralNetwork;
import com.github.neuralnetworks.architecture.types.DBN;
import com.github.neuralnetworks.architecture.types.NNFactory;
import com.github.neuralnetworks.architecture.types.RBM;
import com.github.neuralnetworks.training.DBNTrainer;
import com.github.neuralnetworks.training.MultipleNeuronsOutputError;
import com.github.neuralnetworks.training.OneStepTrainer;
import com.github.neuralnetworks.training.TrainerFactory;
import com.github.neuralnetworks.training.backpropagation.BackPropagationTrainer;
import com.github.neuralnetworks.training.random.MersenneTwisterRandomInitializer;
import com.github.neuralnetworks.training.random.NNRandomInitializer;
import com.github.neuralnetworks.training.rbm.AparapiCDTrainer;
import com.github.neuralnetworks.util.Environment;

public class MyTest {

public static void test() {
    Environment.getInstance().setUseDataSharedMemory(false);
    Environment.getInstance().setUseWeightsSharedMemory(false);

    // set up the network: 4 input units, one hidden layer of 2 units, 2 output units
    DBN dbn = NNFactory.dbn(new int[] {4,2,2}, false);
    dbn.setLayerCalculator(NNFactory.lcWeightedSum(dbn, null));

    // load the training and test data sets (MyInputProvider is my own input provider implementation)
    MyInputProvider trainSet = new MyInputProvider("0.train.data");
    MyInputProvider testSet = new MyInputProvider("0.test.data");

    // random weight initialization
    NNRandomInitializer random = new NNRandomInitializer(new MersenneTwisterRandomInitializer());

    // set up one trainer per layer: CD for the first RBM, BP for the second
    RBM firstRBM = dbn.getFirstNeuralNetwork();
    RBM secondRBM = dbn.getLastNeuralNetwork();
    secondRBM.setLayerCalculator(NNFactory.lcSoftRelu(secondRBM, null));
    AparapiCDTrainer firstTrainer = TrainerFactory.cdSigmoidBinaryTrainer(firstRBM, null, null, null, random, 0.5f, 0f, 0f, 0f, 1, 1, 5, true);
    BackPropagationTrainer secondTrainer = TrainerFactory.backPropagation(secondRBM, null, null, null, null, 0.5f, 0f, 0f, 0f, 1, 1, 1, 5);
    // passing the random initializer instead of null here throws a NullPointerException:
    //BackPropagationTrainer secondTrainer = TrainerFactory.backPropagation(secondRBM, null, null, null, random, 0.5f, 0f, 0f, 0f, 1, 1, 5, true);

    Map<NeuralNetwork, OneStepTrainer<?>> layerTrainers = new HashMap<>();
    layerTrainers.put(firstRBM, firstTrainer);
    layerTrainers.put(secondRBM, secondTrainer);

    DBNTrainer trainer = TrainerFactory.dbnTrainer(dbn, layerTrainers, trainSet, testSet, new MultipleNeuronsOutputError());

    // run layer-wise training, then evaluate on the test set
    trainer.train();
    trainer.test();

    System.out.println(trainer.getOutputError().getTotalNetworkError());
}

}
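
Just to make clear what I am after: by "softmax layer" I mean the usual normalization of the output activations into class probabilities. Here is a minimal standalone sketch in plain Java of what the top layer should compute (this is only an illustration, not part of the library's API):

// Standalone illustration only, NOT the library's API: softmax turns the
// top layer's raw activations into class probabilities that sum to 1.
static float[] softmax(float[] activations) {
    float max = Float.NEGATIVE_INFINITY;
    for (float a : activations) {
        max = Math.max(max, a); // shift by the max for numerical stability
    }
    float sum = 0f;
    float[] probabilities = new float[activations.length];
    for (int i = 0; i < activations.length; i++) {
        probabilities[i] = (float) Math.exp(activations[i] - max);
        sum += probabilities[i];
    }
    for (int i = 0; i < probabilities.length; i++) {
        probabilities[i] /= sum;
    }
    return probabilities;
}

So my question is really whether lcSoftRelu on the second RBM gives me this behaviour, or whether there is a dedicated softmax layer calculator in the library that I should use instead.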

Is this the right way to do it?
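
Alternatively, should the BP step fine-tune the whole stack rather than only the top RBM after the CD pre-training? I mean something like the following, where I simply reuse the backPropagation call from above on the full DBN, so the applicability to a DBN and the parameter order are my assumptions:

// Hypothetical variant: CD pre-training per layer as above, then BP
// fine-tuning over the whole stacked DBN. Parameter order is copied from
// the backPropagation call above; the null random initializer keeps the
// pre-trained weights.
BackPropagationTrainer fineTuner = TrainerFactory.backPropagation(
        dbn, trainSet, testSet, new MultipleNeuronsOutputError(), null,
        0.5f, 0f, 0f, 0f, 1, 1, 1, 5);
fineTuner.train();
fineTuner.test();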