Retrieval after training
robclouth opened this issue · comments
hopfield = Hopfield(
scaling=beta,
# do not project layer input
state_pattern_as_static=True,
stored_pattern_as_static=True,
pattern_projection_as_static=True,
# do not pre-process layer input
normalize_stored_pattern=False,
normalize_stored_pattern_affine=False,
normalize_state_pattern=False,
normalize_state_pattern_affine=False,
normalize_pattern_projection=False,
normalize_pattern_projection_affine=False,
# do not post-process layer output
disable_out_projection=True
)
# and then to retrieve
# Y = stored, x = state
hopfield((Y, x, Y))
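For context on what this forward pass computes: with all projections static and all normalization disabled, as configured above, the layer reduces to the modern Hopfield update rule, xi_new = softmax(beta * x @ Y.T) @ Y. A minimal NumPy sketch of that rule (independent of the hflayers API; array names and sizes are illustrative):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def hopfield_retrieve(Y, x, beta):
    """One Hopfield update step: softmax(beta * x @ Y.T) @ Y.

    Y: (num_stored, dim) stored patterns; x: (dim,) state (query) pattern.
    """
    return softmax(beta * (Y @ x)) @ Y

# store three random patterns, then query with a noisy copy of the first
rng = np.random.default_rng(0)
Y = rng.standard_normal((3, 64))
x = Y[0] + 0.1 * rng.standard_normal(64)

retrieved = hopfield_retrieve(Y, x, beta=8.0)
```

With a sufficiently large beta the softmax is close to one-hot, so the noisy query snaps back to the stored pattern it is nearest to; with a small beta the result is a soft mixture of stored patterns instead.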
I've trained the network on MNIST to test storage capacity.
How do I then do retrieval without training again?
I was expecting a retrieve function like the one here:
hopfield-layers/examples/simpsons/models.py, line 226 (commit e0d856a)
Hi @robclouth,
the retrieve functionality you refer to is just a single forward pass of the trained Hopfield network. However, the retrieved patterns of the individual heads are combined and projected within the following lines, so they are not directly exposed:
hopfield-layers/hflayers/__init__.py, lines 12 to 15 (commit d6d88c8)
hopfield-layers/hflayers/__init__.py, lines 240 to 242 (commit d6d88c8)
hopfield-layers/hflayers/__init__.py, lines 257 to 259 (commit d6d88c8)
To get the per-head retrieved patterns, reshape the association matrix as well as the projected pattern matrix, then apply the former to the latter. Python-like pseudocode would look like the following:
# association (attention) matrix, one map per head
xi = hopfield.get_association_matrix(...)
xi = xi.view(*xi.shape[1:])  # drop the batch dimension (assumes batch size 1)
# projected stored-pattern matrix
v = hopfield.get_projected_pattern_matrix(...)
v = v.view(*v.shape[1:])  # drop the batch dimension (assumes batch size 1)
# batched matrix product over heads
retrieved_patterns = torch.bmm(xi, v)
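The reshape-and-multiply step can be illustrated with plain tensors. All shapes here are made up for illustration (batch of 1, 4 heads, 10 state patterns, 100 stored patterns, head dimension 16), and np.matmul stands in for torch.bmm:

```python
import numpy as np

rng = np.random.default_rng(0)

# illustrative shapes: (batch, heads, num_state, num_stored) and
# (batch, heads, num_stored, head_dim)
xi = rng.random((1, 4, 10, 100))  # association (attention) matrix
v = rng.random((1, 4, 100, 16))   # projected stored patterns

xi = xi.reshape(*xi.shape[1:])  # drop batch dimension -> (4, 10, 100)
v = v.reshape(*v.shape[1:])     # drop batch dimension -> (4, 100, 16)

# batched matrix product over the head dimension (NumPy analogue of torch.bmm)
retrieved_patterns = np.matmul(xi, v)  # -> (4, 10, 16)
```

Each of the 4 head slices is then a (10, 16) matrix of retrieved patterns, one row per state (query) pattern.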
If there are any remaining questions, feel free to reopen this issue.