Torch export key errors
yxie20 opened this issue
Key errors:
```
Traceback (most recent call last):
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 124, in __init__
    arg_ = _memodict[arg]
KeyError: sqrt(Abs(x1) + 2)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 124, in __init__
    arg_ = _memodict[arg]
KeyError: 1/2

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test.py", line 151, in <module>
    batching=True,
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/sr.py", line 453, in pysr
    equations = get_hof(**kwargs)
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/sr.py", line 1002, in get_hof
    module = sympy2torch(eqn, sympy_symbols, selection=selection)
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 190, in sympy2torch
    expression, symbols_in, selection=selection, extra_funcs=extra_torch_mappings
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 161, in __init__
    expr=expression, _memodict=_memodict, _func_lookup=_func_lookup
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 130, in __init__
    **kwargs,
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 130, in __init__
    **kwargs,
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 120, in __init__
    self._torch_func = _func_lookup[expr.func]
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/collections/__init__.py", line 916, in __getitem__
    return self.__missing__(key)  # support subclasses that define __missing__
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/collections/__init__.py", line 908, in __missing__
    raise KeyError(key)
KeyError: <class 'sympy.core.numbers.Half'>
```
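For context on the final `KeyError`: the `1/2` produced by `sqrt` is stored by SymPy as an instance of the `Half` class, so a function-lookup table keyed on `expr.func` that has no entry for `Half` will fail exactly as shown above. A minimal sketch of where that class comes from (pure SymPy, no PySR needed):

```python
import sympy

x1 = sympy.Symbol("x1")
# sqrt is represented internally as Pow(Abs(x1) + 2, 1/2)
expr = sympy.sqrt(sympy.Abs(x1) + 2)

# The exponent 1/2 is sympy.S.Half; its .func is the Half class
# that the torch export's function-lookup table was missing.
base, exponent = expr.args
print(exponent.func)  # <class 'sympy.core.numbers.Half'>
```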
MWE:

```python
import numpy as np
from pysr import pysr, best
import time
from pysr.sr import best_callable
import torch

# Dataset (alternative)
X = 2 * np.random.randn(1152, 32)
y = 2 * np.cos(X[:, 3:11]) + X[:, 0:8] ** 2 - 2 + X[:, 2:10] * X[:, 1:9]

equations = pysr(
    X, y,
    binary_operators=["plus", "sub", "mult", "div", "pow"],
    unary_operators=["exp", "log_abs", "log10_abs", "log2_abs",
                     "cos", "sin", "tan", "sinh", "cosh", "tanh",
                     "atan", "asinh", "acosh_abs", "atanh_clip"],
    # verbosity=0,
    # procs=6,
    temp_equation_file=True,
    progress=False,
    julia_optimization=0,  # Faster startup: turn off the optimizing compiler for Julia code
    output_torch_format=True,
    # Full-scale settings:
    # niterations=2,
    # maxsize=20,
    # populations=2,
    # npop=2000,
    # ncyclesperiteration=200,
    # Quick debug settings:
    niterations=5,   # Iterations per population; best equations are printed and migrated between populations
    maxsize=10,
    populations=5,   # Number of populations running (must be > 1)
    npop=200,        # Individuals per population; more is slower but more likely to hit the target
    ncyclesperiteration=20,  # Total mutations per 10 samples of the population each iteration
    annealing=True,  # With False, simple equations take longer but more complex equations are reachable
    batching=True,
)
```
I have no idea why `sqrt(Abs(x1) + 2)` is appearing as an output of your equation search... PySR should not find integers, only real numbers. But let me try to reproduce this...

By the way: some of the operators you are using are not actually implemented in the torch export, so it won't work anyway. You can add them with, e.g., `extra_torch_mappings={'atanh_clip': ...}`. (I should give a better error for this.) It also looks like this argument was not transferring to the torch export; I'll add that now.

Also, that many operators and features will make the search very, very slow. As a rule of thumb, the search will take O(factorial(M * N)) longer if you increase the number of operators by N and the number of features by M. There are also some redundant operators among the ones you passed; e.g., `pow` is redundant with `exp`. You can do feature pre-selection for all equations via the `select_k_features` argument (in which case you probably want to split this into multiple PySR runs!). Or use the methods described here to break down your problem: https://arxiv.org/abs/2006.11287.

I can't reproduce your example; it outputs fine for me. Presumably this is because it discovers a different equation. What was the output equation that it tried to convert to torch format?
Update: what I said about `atanh_clip` is actually incorrect, sorry; it should have worked without that manual mapping. This is because it is defined as:

```python
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1)
```

and `atanh` and `Mod` are each converted to PyTorch separately.
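As a quick sanity check of that composite definition (my own illustration, plain SymPy): for inputs already in (-1, 1) the `Mod` wrapper is the identity, and for inputs outside that interval it wraps the argument back in before `atanh` is applied, so `atanh` never sees a value with |x| ≥ 1.

```python
import sympy

x = sympy.Symbol("x")
atanh_clip = sympy.atanh(sympy.Mod(x + 1, 2) - 1)

# For x in (-1, 1), Mod(x + 1, 2) - 1 == x, so this is plain atanh.
inside = atanh_clip.subs(x, sympy.Rational(1, 2))   # atanh(1/2)
# Outside the interval, Mod wraps the argument back into (-1, 1):
# Mod(7/2, 2) - 1 == 1/2, so the result is again atanh(1/2).
outside = atanh_clip.subs(x, sympy.Rational(5, 2))

print(inside, outside)
```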
However, it looks like `Mod` was not implemented in the torch mapping. I just added it in 4d5aec3 (0.6.8). I think this should be fixed now, but let me know if it does not work!
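A note for anyone writing similar mappings: SymPy's `Mod` follows floor-division semantics (the result takes the sign of the divisor), which corresponds to `torch.remainder` rather than `torch.fmod`. A pure-Python sketch of that convention, using only the standard library:

```python
import math

def floor_mod(a, b):
    # Floor-division remainder: the convention used by sympy.Mod
    # and torch.remainder (result has the sign of the divisor b).
    return a - b * math.floor(a / b)

print(floor_mod(3.5, 2))   # 1.5
print(floor_mod(-0.5, 2))  # 1.5  (math.fmod(-0.5, 2) would give -0.5)
```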
Also, here's a way to test whether an expression will convert, without needing to run the full PySR pipeline:

```python
from pysr import sympy2torch
import torch
from sympy import *
import numpy as np

x, y, z = symbols("x y z")
expression = x ** 2 + atanh(Mod(y + 1, 2) - 1) * 3.2 * z

module = sympy2torch(expression, [x, y, z])
print(module)
# >> _SingleSymPyModule(expression=x**2 + 3.2*z*atanh(Mod(y + 1, 2) - 1))

X = torch.rand(100, 3).float() * 10
torch_out = module(X)
true_out = X[:, 0] ** 2 + torch.atanh(torch.remainder(X[:, 1] + 1, 2) - 1) * 3.2 * X[:, 2]

# Test it:
np.testing.assert_array_almost_equal(true_out.detach(), torch_out.detach(), decimal=4)
```

So we can see it gives the same answer for `atanh_clip` now.
Cheers,
Miles
Fixed, confirmed! `Mod` does indeed seem to have been the issue.