ml-jku / hopfield-layers

Hopfield Networks is All You Need

Home Page: https://ml-jku.github.io/hopfield-layers/

Creating a getEnergy() Function

galenoshea opened this issue

We are trying to calculate the energy of the Hopfield network and are wondering if there is a simple way to do this.

Hi galenoshea!

So far, no getEnergy() function is planned. However, equation (2) in our paper (https://arxiv.org/pdf/2008.02217.pdf) should be fairly straightforward to implement.
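For reference, here is the energy of equation (2) as we read it, for a state pattern ξ and N stored patterns X = (x_1, ..., x_N). Please double-check the notation against the arXiv version:

```latex
E = -\operatorname{lse}\!\left(\beta, X^{T}\xi\right) + \tfrac{1}{2}\,\xi^{T}\xi + \beta^{-1}\ln N + \tfrac{1}{2}M^{2},
\qquad
\operatorname{lse}(\beta, z) = \beta^{-1}\ln\sum_{i=1}^{N}\exp(\beta z_i),
\qquad
M = \max_{i}\,\lVert x_i\rVert
```

The last two terms do not depend on ξ. They shift the energy, but they do not move its minima.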

Related to this, what's the best way to get the state pattern from the Hopfield class? I'm having some trouble extracting it.

A little late, but I think I managed to implement a get_energy function:

import math
import torch

def get_energy(R, Y, beta):
    # R: state patterns (batch, S, d); Y: stored patterns (batch, N, d)
    lse = -(1.0 / beta) * torch.logsumexp(beta * torch.bmm(R, Y.transpose(1, 2)), dim=2)  # -lse(beta, Y^T*R), shape (batch, S)
    lnN = (1.0 / beta) * math.log(Y.shape[1])  # beta^-1*ln(N)
    RTR = 0.5 * (R * R).sum(dim=2)  # 0.5*R^T*R per state pattern, shape (batch, S)
    M = 0.5 * torch.linalg.norm(Y, dim=2).max(dim=1).values.unsqueeze(1) ** 2  # 0.5*M^2, shape (batch, 1); constant in R, can be very large
    energy = lse + lnN + RTR + M
    return energy

I published a short notebook with a similar pattern-retrieval task on Simpsons faces. It shows the expected behavior: the energy is minimized at the retrieved pattern.
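The same check can be reproduced with random vectors instead of images. This is a minimal, self-contained sketch and is not code from the notebook or from hopfield-layers. The helper names `energy` and `update` are hypothetical. Stored patterns are the rows of `X`, which matches the layout of `Y` in the function above. The sketch takes one step of the paper's update rule ξ_new = X softmax(β X^T ξ), rewritten for row-stored patterns, and compares the energy before and after. The paper derives this update via the concave-convex procedure, so the energy should not increase.

```python
import math
import torch

def energy(xi: torch.Tensor, X: torch.Tensor, beta: float) -> torch.Tensor:
    """Energy of eq. (2) for one state pattern xi (shape (d,)) and stored patterns as rows of X (shape (N, d))."""
    N = X.shape[0]
    lse = torch.logsumexp(beta * (X @ xi), dim=0) / beta  # lse(beta, X^T xi)
    M = torch.linalg.norm(X, dim=1).max()                 # largest stored-pattern norm
    return -lse + 0.5 * (xi @ xi) + math.log(N) / beta + 0.5 * M ** 2

def update(xi: torch.Tensor, X: torch.Tensor, beta: float) -> torch.Tensor:
    """One retrieval step, xi_new = X^T softmax(beta * X xi), for row-stored patterns."""
    return X.T @ torch.softmax(beta * (X @ xi), dim=0)

torch.manual_seed(0)
X = torch.randn(8, 64, dtype=torch.float64)                # 8 random stored patterns of dimension 64
query = X[3] + 0.3 * torch.randn(64, dtype=torch.float64)  # noisy version of stored pattern 3
beta = 1.0

retrieved = update(query, X, beta)
print("energy before:", energy(query, X, beta).item())
print("energy after: ", energy(retrieved, X, beta).item())
print("closest stored pattern:", torch.argmax(X @ retrieved).item())
```

With β = 1 and well-separated random patterns, one step typically lands almost exactly on the stored pattern. A smaller β yields softer averages over several patterns.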

Closing the issue, since this seems to be resolved by @a-kore.