daiki-kimura / TensorLNN

Scalable Training of Propositional Logical Neural Networks.


Scalable training and inference of Propositional Logical Neural Networks, with applications to TextWorld Commonsense games and Word Sense Disambiguation.

The repository contains the code base for TensorLNN, a scalable implementation of training and inference for Propositional Logical Neural Networks based on Łukasiewicz logic.
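For readers new to Łukasiewicz logic, here is a minimal plain-Python sketch of the Łukasiewicz conjunction and of the weighted conjunction from the LNN literature, which computes clamp(beta - sum_i w_i * (1 - x_i)) on truth values in [0, 1]. The function names are illustrative only and are not part of the tensorlnn API; the repository itself operates on PyTorch tensors.

```python
from typing import Sequence


def clamp01(v: float) -> float:
    """Clamp a real value to the truth-value interval [0, 1]."""
    return max(0.0, min(1.0, v))


def lukasiewicz_and(xs: Sequence[float]) -> float:
    """Unweighted Łukasiewicz t-norm over n inputs: max(0, sum(x) - (n - 1))."""
    return clamp01(sum(xs) - (len(xs) - 1))


def weighted_and(xs: Sequence[float], ws: Sequence[float], beta: float) -> float:
    """Weighted LNN-style conjunction: clamp(beta - sum_i w_i * (1 - x_i))."""
    if len(xs) != len(ws):
        raise ValueError("inputs and weights must have the same length")
    return clamp01(beta - sum(w * (1.0 - x) for x, w in zip(xs, ws)))


# With beta = 1 and unit weights, the weighted form reduces to the plain t-norm.
print(lukasiewicz_and([1.0, 0.75, 0.5]))                          # 0.25
print(weighted_and([1.0, 0.75, 0.5], [1.0, 1.0, 1.0], beta=1.0))  # 0.25
# A zero weight removes an input from the conjunction entirely.
print(weighted_and([1.0, 0.0], [1.0, 0.0], beta=1.0))             # 1.0
```

Setting an input's weight to 0 removes it from the conjunction, which gives some intuition for how training can determine which inputs belong to a learned AND.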

Setting up the environment

conda create -n tensorlnn python=3.9 numpy ipython matplotlib tqdm scipy
conda activate tensorlnn
conda install pytorch=1.10.0 torchvision torchaudio -c pytorch
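After activating the environment, a short stdlib-only check can confirm that the packages installed above are importable. This helper is a convenience sketch, not part of the repository; it checks import names, so conda's pytorch package is checked as torch and ipython as IPython.

```python
import importlib.util
import sys
from typing import List

# Import names of the packages installed by the conda commands above.
REQUIRED = ["numpy", "IPython", "matplotlib", "tqdm", "scipy",
            "torch", "torchvision", "torchaudio"]


def missing_packages(names: List[str]) -> List[str]:
    """Return the module names that cannot be found by this interpreter."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    missing = missing_packages(REQUIRED)
    print("missing:", missing if missing else "none")
```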

LOA benchmark using Tensor Logical Neural Networks

The LOA benchmark trains TensorLNN in a supervised setting: training samples and their ground truths are generated, and the underlying formula is determined by training the TensorLNN on them.

The benchmark can be run by simply executing the following commands:

cd examples/loa
python atloc_lnn.py

The top-level program is atloc_lnn.py. The TensorLNN model is defined with the statement tensorlnn.NeuralNet(num_inputs, gpu_device, nepochs, lr, optimizer), where:

{
        "num_inputs" : "number/of/nodes/whose/AND/is/to/be/determined",
        "gpu_device" : "set/to/true/or/false/depending/on/gpu/or/cpu/run",
        "nepochs" : "number/of/epochs/to/train",
        "lr" : "learning/rate",
        "optimizer" : "SGD/or/AdamW",
}
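The constructor takes these five values positionally. The sketch below collects and sanity-checks them in that order before they would be passed to tensorlnn.NeuralNet; the helper neuralnet_args and the example values (10 inputs, 100 epochs, lr 0.01) are hypothetical illustrations, not defaults taken from atloc_lnn.py.

```python
from typing import Tuple

ALLOWED_OPTIMIZERS = {"SGD", "AdamW"}


def neuralnet_args(num_inputs: int, gpu_device: bool, nepochs: int,
                   lr: float, optimizer: str) -> Tuple[int, bool, int, float, str]:
    """Hypothetical helper: validate arguments and return them in the positional
    order of tensorlnn.NeuralNet(num_inputs, gpu_device, nepochs, lr, optimizer)."""
    if num_inputs < 1:
        raise ValueError("num_inputs must be a positive number of nodes")
    if nepochs < 1:
        raise ValueError("nepochs must be at least 1")
    if lr <= 0:
        raise ValueError("lr must be positive")
    if optimizer not in ALLOWED_OPTIMIZERS:
        raise ValueError(f"optimizer must be one of {sorted(ALLOWED_OPTIMIZERS)}")
    return (num_inputs, bool(gpu_device), nepochs, lr, optimizer)


# Illustrative values only (CPU run, AdamW).
args = neuralnet_args(num_inputs=10, gpu_device=False, nepochs=100, lr=0.01, optimizer="AdamW")
print(args)  # (10, False, 100, 0.01, 'AdamW')
# With the repository's package importable, the model would then be built as:
# model = tensorlnn.NeuralNet(*args)
```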

A simpler program, basic_lnn.py, tests LOA on randomly generated samples. It can be run with the following commands:

cd examples/loa
python basic_lnn.py <num_samples> <num_inputs>

This generates num_samples random positive training samples, each a vector of length num_inputs, and trains the propositional LNN on them.
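A minimal sketch of the command-line interface described above, using argparse for the two positional arguments and Python's random module for the samples. The helper names are hypothetical, and the assumption that each entry is a truth value drawn uniformly from [0, 1] is ours; basic_lnn.py may generate and label its positive samples differently.

```python
import argparse
import random
from typing import List, Optional, Sequence


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the positional arguments of: basic_lnn.py <num_samples> <num_inputs>."""
    parser = argparse.ArgumentParser(description="Random-sample LOA test (sketch)")
    parser.add_argument("num_samples", type=int, help="number of training samples to generate")
    parser.add_argument("num_inputs", type=int, help="length of each sample vector")
    return parser.parse_args(argv)


def random_samples(num_samples: int, num_inputs: int, seed: int = 0) -> List[List[float]]:
    """Generate num_samples vectors of length num_inputs.
    ASSUMPTION: entries are truth values drawn uniformly from [0, 1)."""
    rng = random.Random(seed)  # seeded for reproducibility
    return [[rng.random() for _ in range(num_inputs)] for _ in range(num_samples)]


args = parse_args(["4", "3"])  # equivalent to: python basic_lnn.py 4 3
samples = random_samples(args.num_samples, args.num_inputs)
print(len(samples), len(samples[0]))  # 4 3
```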

Scaling of Logical Neural Networks for Word Sense Disambiguation (WSD)

TensorLNN for WSD performs unsupervised training starting from the initial bounds of the nodes, with senses defined by universes. It can be run with the following commands:

cd examples/wsd
python wsd_main.py

The top-level program is wsd_main.py. Model construction and training parameters are specified in config.json, which typically has the form:

{
        "univ_dir" : "path/to/universe/data/folder",
        "group_size" : "number/of/universes/merged/in/megauniverse",
        "nepochs" : "number/of/epochs/to/train",
        "lr" : "learning/rate",
        "inf_steps" : "number/of/inference/steps/in/each/epoch",
        "fwd_method" : "baseline/or/checkpoint",
        "checkpoint_phases" : "number/of/inference/steps/between/checkpoints",
        "smooth_alpha" : "exponentiation/parameter/for/smooth/aggregation",
        "clamp_thr_lb" : "lower/threshold/for/clamping/weights/and/bias",
        "clamp_thr_ub" : "upper/threshold/for/clamping/weights/and/bias",
        "eps" : "some/low/positive/epsilon",
        "optimizer" : "AdamW",
        "gap_slope" : \gamma,
        "contra_slope" : \zeta,
        "vacuity_mult" : \nu,
        "logical_mult" : \lambda,
        "bound_mult" : \beta
}
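The group_size parameter merges several universes into one megauniverse. Assuming universes are grouped in consecutive chunks, with a smaller final group when the count does not divide evenly, the partition could look like the sketch below; group_universes and the ids u1 to u5 are hypothetical, and wsd_main.py may group universes differently.

```python
from typing import List


def group_universes(universe_ids: List[str], group_size: int) -> List[List[str]]:
    """Split universe ids into megauniverse groups of at most group_size ids.
    ASSUMPTION: consecutive grouping with a smaller final group."""
    if group_size < 1:
        raise ValueError("group_size must be at least 1")
    return [universe_ids[i:i + group_size]
            for i in range(0, len(universe_ids), group_size)]


print(group_universes(["u1", "u2", "u3", "u4", "u5"], group_size=2))
# [['u1', 'u2'], ['u3', 'u4'], ['u5']]
```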

Notes:

--- "univ_dir" should contain two subdirectories, "global" and "local", and a file "universes.txt" listing the universe ids. The "global" subdirectory should contain an npz file with the adjacency matrix of the global LNN. The "local" subdirectory should contain one subdirectory per universe id, named after that id, holding (i) an npz file with the adjacency matrix of the AND net, (ii) an npz file with the adjacency matrix of the NOT net, and (iii) a bounds file.
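The univ_dir layout described above can be checked before training with a small pathlib sketch. check_univ_dir is a hypothetical helper; it assumes universes.txt holds one id per line and only counts .npz files, because the exact file names (including that of the bounds file) are not specified here.

```python
import tempfile
from pathlib import Path
from typing import List


def check_univ_dir(univ_dir: str) -> List[str]:
    """Return problems found in the expected univ_dir layout (empty list if none)."""
    root = Path(univ_dir)
    problems: List[str] = []
    for sub in ("global", "local"):
        if not (root / sub).is_dir():
            problems.append(f"missing subdirectory: {sub}")
    ids_file = root / "universes.txt"
    if not ids_file.is_file():
        problems.append("missing universes.txt")
        return problems
    # ASSUMPTION: one universe id per line; blank lines are ignored.
    ids = [line.strip() for line in ids_file.read_text().splitlines() if line.strip()]
    if (root / "global").is_dir() and not any((root / "global").glob("*.npz")):
        problems.append("global/ contains no .npz file")
    for uid in ids:
        local = root / "local" / uid
        if not local.is_dir():
            problems.append(f"missing local/{uid}")
        elif len(list(local.glob("*.npz"))) < 2:
            # Expected: one npz for the AND net and one for the NOT net (plus a bounds file).
            problems.append(f"local/{uid} has fewer than two .npz files")
    return problems


with tempfile.TemporaryDirectory() as tmp:
    print(check_univ_dir(tmp))
# ['missing subdirectory: global', 'missing subdirectory: local', 'missing universes.txt']
```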

--- "checkpoint_phases" is used only when "fwd_method" is set to "checkpoint", and it should be a divisor of "inf_steps".
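The relationship between fwd_method, inf_steps and checkpoint_phases can be made concrete by listing the inference steps covered by each segment between checkpoints. This is a sketch under our own reading of the config notes: checkpoint_segments is hypothetical, and it does not model the actual recomputation, which in PyTorch would typically be done with torch.utils.checkpoint.

```python
from typing import List


def checkpoint_segments(inf_steps: int, fwd_method: str,
                        checkpoint_phases: int = 0) -> List[range]:
    """Split inf_steps inference steps into segments between checkpoints.

    "baseline": a single segment covering all steps (checkpoint_phases unused).
    "checkpoint": segments of checkpoint_phases steps; it must divide inf_steps.
    """
    if fwd_method == "baseline":
        return [range(0, inf_steps)]
    if fwd_method != "checkpoint":
        raise ValueError('fwd_method must be "baseline" or "checkpoint"')
    if checkpoint_phases < 1 or inf_steps % checkpoint_phases != 0:
        raise ValueError("checkpoint_phases must be a positive divisor of inf_steps")
    return [range(s, s + checkpoint_phases)
            for s in range(0, inf_steps, checkpoint_phases)]


print([list(r) for r in checkpoint_segments(6, "checkpoint", 2)])
# [[0, 1], [2, 3], [4, 5]]
```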

--- Given the bounds on each node and the weights and bias, the total loss after inference involves four terms: the Gap Loss, the Contradiction Loss, the Logical Loss and the Vacuity Loss.
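As a rough illustration of the bound quantities that the gap and contradiction terms are named after, the sketch below computes, for per-node bounds [L, U], a contradiction amount max(0, L - U) (a lower bound exceeding its upper bound, which LNNs treat as a contradiction) and a gap amount max(0, U - L) (the width of a consistent interval). These formulas are assumptions for illustration only: they are not TensorLNN's Gap or Contradiction Loss, and they ignore gap_slope, contra_slope and the multipliers in config.json.

```python
from typing import List, Tuple

Bounds = List[Tuple[float, float]]  # (lower, upper) truth-value bounds per node


def contradiction_amount(bounds: Bounds) -> float:
    """ASSUMED measure: sum over nodes of max(0, L - U)."""
    return sum(max(0.0, lo - up) for lo, up in bounds)


def gap_amount(bounds: Bounds) -> float:
    """ASSUMED measure: sum over nodes of max(0, U - L), the total interval width."""
    return sum(max(0.0, up - lo) for lo, up in bounds)


# Three hypothetical nodes: uncertain, contradictory, and exactly known.
nodes = [(0.25, 0.75), (1.0, 0.5), (0.5, 0.5)]
print(contradiction_amount(nodes))  # 0.5
print(gap_amount(nodes))            # 0.5
```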


License: Apache License 2.0

