Support for PyTorch models
raam93 opened this issue
Thanks for this library! I see you currently support LR and MLP models from sklearn, but could you give some pointers on how to make it work with PyTorch or TensorFlow models?
I checked generateSATExplanations.py and found that you convert a model to a logic formula in modelConversion.py, but only for four model classes from sklearn. How can this be replicated for an arbitrary model from, say, PyTorch? In particular, how should the mlp2formula(model, model_symbols) method be updated?
Indeed, the conversion from sklearn models to pysmt formulas was done
manually, using the available pysmt API, a bit of trial and error, and many
assertions to make sure the behavior is replicated correctly. I imagine the
same should be done for PyTorch models. A good starting point for seeing the
available pysmt API is the factory.py file in their codebase.
An alternative (perhaps easier) route would be to convert a PyTorch model
into an equivalent sklearn model, and then use mlp2formula to go from the
sklearn model to a pysmt formula.
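The main detail in that route is the weight layout. PyTorch's nn.Linear stores `weight` with shape (out_features, in_features) and computes `x @ weight.T + bias`, while a fitted sklearn MLPClassifier keeps one `coefs_[i]` array of shape (n_inputs, n_outputs) per layer, plus `intercepts_[i]`. So each PyTorch weight has to be transposed before it is copied over. The sketch below does that transposition with plain lists and checks, in the spirit of the assertions mentioned above, that both layouts give the same forward pass. All helper names are hypothetical. With the real libraries you would take `layer.weight.detach().numpy().T` and `layer.bias.detach().numpy()` instead, and an estimator built this way also needs some of its other fitted attributes (such as `n_layers_` and `out_activation_`) set, which this sketch does not cover. It only makes sense for a plain stack of nn.Linear layers with the same activation between them, since that is all an sklearn MLP can represent.

```python
def transpose(matrix):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*matrix)]


def torch_layout_to_sklearn(layers):
    """Convert [(weight, bias), ...] in nn.Linear layout (out x in) into
    sklearn-style (coefs, intercepts), where coefs[i] is (in x out)."""
    coefs = [transpose(w) for w, _ in layers]
    intercepts = [list(b) for _, b in layers]
    return coefs, intercepts


def relu(v):
    return [max(0.0, a) for a in v]


def forward_torch_layout(layers, x):
    """Per layer: x @ W.T + b, with ReLU on every layer but the last."""
    for i, (w, b) in enumerate(layers):
        x = [sum(wj * xi for wj, xi in zip(row, x)) + bj for row, bj in zip(w, b)]
        if i < len(layers) - 1:
            x = relu(x)
    return x


def forward_sklearn_layout(coefs, intercepts, x):
    """Per layer: a @ coefs[i] + intercepts[i], ReLU on every layer but the last."""
    for i, (c, b) in enumerate(zip(coefs, intercepts)):
        x = [sum(xk * c[k][j] for k, xk in enumerate(x)) + b[j] for j in range(len(b))]
        if i < len(coefs) - 1:
            x = relu(x)
    return x


# Toy 2-3-1 network in nn.Linear layout; the numbers are made up.
net = [
    ([[1.0, -1.0], [0.5, 2.0], [-1.5, 0.25]], [0.0, -1.0, 0.5]),
    ([[1.0, -3.0, 2.0]], [0.25]),
]
coefs, intercepts = torch_layout_to_sklearn(net)
for x in ([0.0, 0.0], [1.0, 2.0], [-2.0, 0.5]):
    # Same products summed in the same order, so the outputs match exactly.
    assert forward_torch_layout(net, x) == forward_sklearn_layout(coefs, intercepts, x)
```

Checking equivalence on a handful of inputs like this is cheap and catches the most common mistake, a forgotten transpose, before the model ever reaches mlp2formula.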
Hope this helps!