ml-jku / hopfield-layers

Hopfield Networks is All You Need

Home Page: https://ml-jku.github.io/hopfield-layers/

Retrieval after training

robclouth opened this issue · comments

hopfield = Hopfield(
    scaling=beta,

    # do not project layer input
    state_pattern_as_static=True,
    stored_pattern_as_static=True,
    pattern_projection_as_static=True,

    # do not pre-process layer input
    normalize_stored_pattern=False,
    normalize_stored_pattern_affine=False,
    normalize_state_pattern=False,
    normalize_state_pattern_affine=False,
    normalize_pattern_projection=False,
    normalize_pattern_projection_affine=False,

    # do not post-process layer output
    disable_out_projection=True
)

# and then to retrieve
# Y = stored, x = state
hopfield((Y, x, Y))

I've trained the network on MNIST to test storage capacity.
How do I then do retrieval without training again?
I was expecting a retrieve function like this one:

def retrieve(self, partial_pattern, max_iter=np.inf, thresh=0.5):

Hi @robclouth,

the retrieve functionality you refer to is just a single forward pass of the trained Hopfield network. Because the retrieved patterns of all heads are combined and projected inside

class Hopfield(Module):
    """
    Module with underlying Hopfield association.
    """

you need to fetch the Hopfield association matrix and the projected patterns manually. This can be done with

def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                           stored_pattern_padding_mask: Optional[Tensor] = None,
                           association_mask: Optional[Tensor] = None) -> Tensor:

and

def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                                 stored_pattern_padding_mask: Optional[Tensor] = None,
                                 association_mask: Optional[Tensor] = None) -> Tensor:

To get the retrieved patterns, reshape both the association matrix and the projected pattern matrix, then apply the former to the latter with a batched matrix product. Python-like pseudocode would look like the following:

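import torch

# association matrix; drop the leading dimension (view only succeeds if its size is 1)
xi = hopfield.get_association_matrix(...)
xi = xi.view(*xi.shape[1:])

# projected patterns, reshaped the same way
v = hopfield.get_projected_pattern_matrix(...)
v = v.view(*v.shape[1:])

# apply the association matrix to the projected patterns (batched matrix product)
retrieved_patterns = torch.bmm(xi, v)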
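
To make "apply the former onto the latter" concrete, here is a minimal, library-free sketch of a single retrieval step for the unprojected, single-head setup from the question. It assumes that in this setup the association matrix reduces to softmax(beta * x Y^T), as in the update rule of the paper. The helper names `softmax`, `association_row` and `retrieve_once`, as well as the toy patterns, are hypothetical and not part of hopfield-layers.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def association_row(state, stored, beta):
    """Association weights of one state pattern over all stored patterns.

    Assumption: no projections or normalization, so the scores are just
    beta-scaled dot products between the state and each stored pattern.
    """
    scores = [beta * sum(a * b for a, b in zip(state, y)) for y in stored]
    return softmax(scores)


def retrieve_once(state, stored, beta):
    """One update step: association weights applied to the stored patterns."""
    weights = association_row(state, stored, beta)
    dim = len(stored[0])
    return [sum(w * y[d] for w, y in zip(weights, stored)) for d in range(dim)]


# Toy data (hypothetical): three stored patterns of dimension four.
stored = [
    [1.0, 1.0, -1.0, -1.0],
    [-1.0, 1.0, -1.0, 1.0],
    [1.0, -1.0, 1.0, -1.0],
]

# Partial version of the first stored pattern (second half masked with zeros).
partial = [1.0, 1.0, 0.0, 0.0]

retrieved = retrieve_once(partial, stored, beta=8.0)
print(retrieved)
```

With beta = 8.0 the association weights concentrate on the first stored pattern, so `retrieved` lands close to it. A smaller beta spreads the weights and returns a blend of the stored patterns instead.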

If there are any remaining questions, feel free to reopen this issue.