MilesCranmer / PySR

High-Performance Symbolic Regression in Python and Julia

Home Page:https://astroautomata.com/PySR

Torch export key errors

yxie20 opened this issue

Key errors:

Traceback (most recent call last):
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 124, in __init__
    arg_ = _memodict[arg]
KeyError: sqrt(Abs(x1) + 2)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 124, in __init__
    arg_ = _memodict[arg]
KeyError: 1/2

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test.py", line 151, in <module>
    batching=True,
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/sr.py", line 453, in pysr
    equations = get_hof(**kwargs)
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/sr.py", line 1002, in get_hof
    module = sympy2torch(eqn, sympy_symbols, selection=selection)
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 190, in sympy2torch
    expression, symbols_in, selection=selection, extra_funcs=extra_torch_mappings
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 161, in __init__
    expr=expression, _memodict=_memodict, _func_lookup=_func_lookup
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 130, in __init__
    **kwargs,
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 130, in __init__
    **kwargs,
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/site-packages/pysr/export_torch.py", line 120, in __init__
    self._torch_func = _func_lookup[expr.func]
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/collections/__init__.py", line 916, in __getitem__
    return self.__missing__(key)            # support subclasses that define __missing__
  File "/home/yxie20/anaconda3/envs/rw/lib/python3.7/collections/__init__.py", line 908, in __missing__
    raise KeyError(key)
KeyError: <class 'sympy.core.numbers.Half'>

MWE:

import numpy as np
from pysr import pysr, best
import time

from pysr.sr import best_callable
import torch

# Dataset (alternative)
X = 2*np.random.randn(1152, 32)
y = 2*np.cos(X[:, 3:11]) + X[:, 0:8]**2 - 2 + X[:,2:10]*X[:,1:9]

equations = pysr(X, y, 
    binary_operators=["plus", "sub", "mult", "div", "pow"],
    unary_operators=["exp", "log_abs", "log10_abs", "log2_abs", 
        "cos", "sin", "tan", "sinh", "cosh", "tanh", 
        "atan", "asinh", "acosh_abs", "atanh_clip"],
    # verbosity=0,
    # procs=6,
    temp_equation_file=True,
    progress=False,
    julia_optimization=0,       # Faster startup time. Turn off optimizing compiler for Julia code
    output_torch_format=True,
    #
    # niterations=2,              # Iterations per population of the entire algorithm. Best equations are printed and migrated between populations.  populations * niterations = progress bar. This doesnt really matter
    # maxsize=20,
    # populations=2,              # Number of populations running. (must > 1)
    # npop=2000,                  # Number of individuals per population. More population, slower, more chance of hitting the correct.
    # ncyclesperiteration=200,    # Number of total mutations per 10 samples of population each iteration. Also like npop.
    # Quick debug
    niterations=5,              # Iterations per population of the entire algorithm. Best equations are printed and migrated between populations.  populations * niterations = progress bar. This doesnt really matter
    maxsize=10,
    populations=5,              # Number of populations running. (must > 1)
    npop=200,                  # Number of individuals per population. More population, slower, more chance of hitting the correct.
    ncyclesperiteration=20,    # Number of total mutations per 10 samples of population each iteration. Also like npop.
    annealing=True,            # With False, simple equations take longer but more complex equations are achievable
    batching=True,
)

I have no idea why sqrt(Abs(x1) + 2) is appearing as an output of your equation search. PySR should not find integer constants, only real-valued ones. Let me try to reproduce this.
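For anyone reading the traceback: SymPy represents sqrt(...) as a power with exponent 1/2, and that exponent is an instance of the special class sympy.core.numbers.Half, which matches the class named in the final KeyError. The check below is a minimal sketch that assumes only that SymPy is installed; the symbol name x1 is taken from the traceback, and nothing here touches PySR itself.

```python
import sympy

x1 = sympy.Symbol("x1")
expr = sympy.sqrt(sympy.Abs(x1) + 2)

# sqrt is not its own node type: the expression is a Pow node.
print(expr.func)

# The second argument of that Pow is the exponent 1/2.
base, exponent = expr.args
print(base, exponent)

# The exponent is a singleton of class Half, so a converter that looks up
# type(node) or node.func in a table needs to handle this class (or numbers
# in general) for the expression to export.
print(type(exponent))
```

This only shows where the Half class comes from; it does not explain why the search produced that expression in the first place.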

By the way, some of the operators you are using are not actually implemented in the torch export, so it won't work anyway. You can add them with, e.g., extra_torch_mappings={'atanh_clip': ...}. (I should give a better error for this.) It also looks like this argument was not being passed through to the torch export; I'll add this now.
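On the error message: the traceback goes through `__missing__` in `collections`, which suggests the lookup table is a ChainMap of user-supplied and default mappings. The following is a rough sketch of how such a lookup could fail with a more helpful message. It is not PySR's code: `DEFAULT_FUNCS`, `build_lookup`, and `resolve` are hypothetical names, and plain Python callables stand in for torch functions.

```python
import math
from collections import ChainMap

# Hypothetical defaults: operator name -> callable.
# A real exporter would map to torch functions instead.
DEFAULT_FUNCS = {"cos": math.cos, "exp": math.exp}


def build_lookup(extra_funcs=None):
    # User-supplied entries come first, so they shadow the defaults.
    return ChainMap(extra_funcs or {}, DEFAULT_FUNCS)


def resolve(lookup, name):
    # Turn the bare KeyError into one that names the missing operator
    # and lists what is available.
    try:
        return lookup[name]
    except KeyError:
        raise KeyError(
            f"No mapping for operator {name!r}; known operators: "
            f"{sorted(lookup)}. Pass one in through the extra mappings."
        ) from None


# Adding a missing operator through the extra mappings makes it resolvable.
lookup = build_lookup({"mod": lambda a, b: a % b})
print(resolve(lookup, "mod")(5, 2))
```

With the default table alone, `resolve(build_lookup(), "mod")` raises a KeyError whose message names `mod` and the known operators, rather than only the missing key.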

Also, that many operators and features will make the search very slow. As a rule of thumb, the search becomes slower by a factor of roughly O(factorial(M * N)) if you increase the number of operators by N and the number of features by M. There are also some redundant operators among the ones you passed; e.g., pow is redundant with exp. You can do feature pre-selection for all equations via the select_k_features argument (in which case you probably want to split this into multiple PySR runs!). Or use the methods described here to break down your problem: https://arxiv.org/abs/2006.11287.

I can't reproduce your example; it outputs fine for me. Presumably this is because it discovers a different equation. What was the output equation that it tried to convert to torch format?

Update: what I said about 'atanh_clip' was actually incorrect, sorry. It should have worked without that manual mapping, because the operator is defined as:

"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1)

and atanh and Mod are each converted to PyTorch separately.
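To make the definition concrete: Mod(x + 1, 2) - 1 wraps any real input into [-1, 1) before atanh is applied. Here is a small pure-Python sketch of the same arithmetic, using only the standard library. The helper names are made up for illustration, and returning -inf at the boundary is an assumption of this sketch, not something the definition above specifies.

```python
import math


def wrap_to_unit_interval(x):
    # Python's % with a positive modulus returns a value in [0, 2),
    # so subtracting 1 lands the result in [-1, 1).
    return (x + 1) % 2 - 1


def atanh_clip(x):
    wrapped = wrap_to_unit_interval(x)
    if wrapped == -1.0:
        # atanh diverges at -1 and math.atanh would raise ValueError;
        # this sketch returns -inf there (an assumption).
        return -math.inf
    return math.atanh(wrapped)


for value in (0.5, 1.5, -0.25, 2.25):
    print(value, wrap_to_unit_interval(value), atanh_clip(value))
```

Inputs already inside (-1, 1) pass through unchanged, while 1.5 wraps to -0.5 and 2.25 wraps to 0.25.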

However, it looks like mod was not implemented in the torch mapping. I just added it in 4d5aec3 (0.6.8). I think this should be fixed now, but let me know if it does not work!

Also, here's a way to test if an expression will work, without needing to run the full PySR pipeline:

from pysr import sympy2torch
import torch
from sympy import symbols, atanh, Mod
import numpy as np

x, y, z = symbols("x y z")
expression = x ** 2 + atanh(Mod(y + 1, 2) - 1) * 3.2 * z

module = sympy2torch(expression, [x, y, z])

print(module)
# >> _SingleSymPyModule(expression=x**2 + 3.2*z*atanh(Mod(y + 1, 2) - 1))

X = torch.rand(100, 3).float() * 10

torch_out = module(X)
true_out = X[:, 0] ** 2 + torch.atanh(torch.remainder(X[:, 1] + 1, 2) - 1) * 3.2 * X[:, 2]

Test it:

np.testing.assert_array_almost_equal(true_out.detach(), torch_out.detach(), decimal=4)

So we can see it now gives the same answer for atanh_clip.

Cheers,
Miles

Fixed, confirmed! Mod seems to be the issue indeed.