google / autobound

AutoBound automatically computes upper and lower bounds on functions.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Find upper and lower bounds for a simple MLP function

zeinebBC opened this issue · comments

How can I use AutoBound to compute upper and lower bounds on an MLP function?
I wanted to run this script:


import jax.numpy as jnp
import jax.nn
from jax import random
import autobound.jax as ab

def initialize_mlp_params(rng_key, input_dim, hidden_dim, output_dim,
                          num_hidden_layers=3):
    """Randomly initialize the parameters of a softplus MLP.

    Weight matrices are drawn from a standard normal distribution (one PRNG
    subkey per matrix); biases are zero-initialized.

    Args:
      rng_key: A `jax.random` PRNG key.
      input_dim: Length of the input vector.
      hidden_dim: Width of every hidden layer.
      output_dim: Length of the output vector.
      num_hidden_layers: Number of hidden layers. The default of 3 reproduces
        the original three-hidden-layer network exactly (same keys, shapes
        and ordering).

    Returns:
      A flat tuple ``(W_1, b_1, ..., W_L, b_L, W_out, b_out)``. ``W_1`` has
      shape ``(input_dim, hidden_dim)``; every later hidden ``W_i`` has shape
      ``(hidden_dim, hidden_dim)``; ``W_out`` has shape
      ``(hidden_dim, output_dim)``; each bias has the output width of its
      layer.

    Raises:
      ValueError: If ``num_hidden_layers`` is negative.
    """
    if num_hidden_layers < 0:
        raise ValueError("num_hidden_layers must be non-negative")
    layer_sizes = [input_dim] + [hidden_dim] * num_hidden_layers + [output_dim]
    # One subkey per weight matrix, split in a single call so the default
    # configuration consumes exactly the keys the 4-way split produced.
    keys = random.split(rng_key, len(layer_sizes) - 1)
    params = []
    for key, fan_in, fan_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        params.append(random.normal(key, (fan_in, fan_out)))
        params.append(jnp.zeros(fan_out))
    return tuple(params)

def mlp(params, x):
    """Evaluate a softplus MLP on ``x``.

    Args:
      params: Flat tuple ``(W_1, b_1, ..., W_out, b_out)``, e.g. as returned
        by ``initialize_mlp_params``. Every (W, b) pair except the last is a
        hidden layer followed by ``jax.nn.softplus``; the last pair is a
        linear output layer.
      x: Input array, contracted against ``W_1`` with ``jnp.dot``.

    Returns:
      ``jnp.dot(h, W_out) + b_out``, where ``h`` is the last hidden
      activation (or ``x`` itself when there are no hidden layers).

    Raises:
      ValueError: If ``params`` is not a non-empty sequence of (W, b) pairs.
    """
    if len(params) < 2 or len(params) % 2:
        raise ValueError(
            "params must be a flat tuple of (weights, biases) pairs")
    weights_output, biases_output = params[-2], params[-1]
    hidden = x
    # Even indices are weight matrices, odd indices the matching biases.
    for weights, biases in zip(params[:-2:2], params[1:-2:2]):
        hidden = jax.nn.softplus(jnp.dot(hidden, weights) + biases)
    return jnp.dot(hidden, weights_output) + biases_output

# Network architecture and the region over which AutoBound bounds the MLP.
input_dim = 2
hidden_dim = 10
output_dim = 1
rng_key = random.PRNGKey(0)
params = initialize_mlp_params(rng_key, input_dim, hidden_dim, output_dim)

# Derive the expansion point and the trust region from input_dim so they
# always match the network's input shape. Use the same float dtype as x0
# rather than integer literals.
x0 = jnp.full(input_dim, 0.5)
trust_region = (jnp.zeros_like(x0), jnp.ones_like(x0))


def mlp_lambda(x):
    """Close over the module-level ``params`` so the MLP is a function of x."""
    return mlp(params, x)


bounds = ab.taylor_bounds(mlp_lambda, max_degree=2)(x0, trust_region)
# A bare `bounds.coefficients` expression is discarded outside a notebook;
# print it so the result is visible when run as a script too.
print(bounds.coefficients)


but I got this error:

TypeError Traceback (most recent call last)
in <cell line: 47>()
45 mlp_lambda = lambda x: mlp(params, x)
46 # Use the mlp function directly in taylor_bounds
---> 47 bounds = ab.taylor_bounds(mlp_lambda, max_degree=2)(x0, trust_region)
48 bounds.coefficients

14 frames
/usr/local/lib/python3.10/dist-packages/autobound/jax/jax_bound.py in bound_fun(x0, x_trust_region)
140 if fun is None:
141 raise NotImplementedError(eqn.primitive)
--> 142 outvar_enclosures = fun(*invar_intermediates, **eqn.params)
143 if len(eqn.outvars) == 1:
144 outvar_enclosures = (outvar_enclosures,)

/usr/local/lib/python3.10/dist-packages/autobound/jax/jax_bound.py in g(intermediate)
415 f = arithmetic.get_elementwise_fun(get_enclosure)
416 def g(intermediate):
--> 417 return f(intermediate.enclosure, intermediate.trust_region)
418 return g
419

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in fun(arg_enclosure, arg_trust_region)
292 self.max_degree,
293 self.np_like)
--> 294 return self.compose_enclosures(elementwise_enclosure, arg_enclosure)
295 return fun
296

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in compose_enclosures(self, elementwise_enclosure, arg_enclosure)
241 term = (coefficient,)
242 else:
--> 243 poly = self.power(arg_diff_enclosure, p)
244 term = tuple(
245 # The special-casing when i < p ensures that the TaylorEnclosure

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in power(self, a, p)
330 np_like=self.np_like)
331 multiplicative_identity = self.np_like.ones_like(self.trust_region[0])
--> 332 result = polynomials.integer_power( # pytype: disable=wrong-arg-types
333 a,
334 p,

/usr/local/lib/python3.10/dist-packages/autobound/polynomials.py in integer_power(a, exponent, add, additive_identity, multiplicative_identity, term_product_coefficient, term_power_coefficient, scalar_product)
222 return c
223 output_degree = (len(a) - 1) * exponent
--> 224 return tuple(get_coeff(i) for i in range(1 + output_degree))
225
226

/usr/local/lib/python3.10/dist-packages/autobound/polynomials.py in (.0)
222 return c
223 output_degree = (len(a) - 1) * exponent
--> 224 return tuple(get_coeff(i) for i in range(1 + output_degree))
225
226

/usr/local/lib/python3.10/dist-packages/autobound/polynomials.py in get_coeff(i)
210 running_product_power = 0
211 for j, p_j in enumerate(p):
--> 212 running_product = term_product_coefficient(
213 running_product,
214 term_power_coefficient(a[j], j, p_j),

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in _elementwise_term_product_coefficient(c0, c1, i, j, x_ndim, np_like)
443 return _pairwise_batched_multiply(u, v, ix_ndim, jx_ndim, np_like)
444 set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like)
--> 445 return set_arithmetic.arbitrary_bilinear(c0, c1, product, assume_product=True)
446
447

/usr/local/lib/python3.10/dist-packages/autobound/interval_arithmetic.py in arbitrary_bilinear(self, a, b, bilinear, assume_product)
74 b_is_interval = isinstance(b, tuple)
75 if not a_is_interval and not b_is_interval:
---> 76 return bilinear(a, b)
77
78 if assume_product:

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in product(u, v)
441 """Returns d such that <c0, zi> * <c1, zj> == <d, z**(i+j)>."""
442 def product(u, v):
--> 443 return _pairwise_batched_multiply(u, v, ix_ndim, jx_ndim, np_like)
444 set_arithmetic = interval_arithmetic.IntervalArithmetic(np_like)
445 return set_arithmetic.arbitrary_bilinear(c0, c1, product, assume_product=True)

/usr/local/lib/python3.10/dist-packages/autobound/enclosure_arithmetic.py in _pairwise_batched_multiply(u, v, p, q, np_like)
472 u = np_like.asarray(u)
473 v = np_like.asarray(v)
--> 474 return expand_multiple_dims(u, q) * expand_multiple_dims(v, p, v.ndim-q)
475
476

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in deferring_binary_op(self, other)
256 args = (other, self) if swap else (self, other)
257 if isinstance(other, _accepted_binop_types):
--> 258 return binary_op(*args)
259 if isinstance(other, rejected_binop_types):
260 raise TypeError(f"unsupported operand type(s) for {opchar}: "
[... skipping hidden 12 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/ufuncs.py in fn(x1, x2)
95 def fn(x1, x2, /):
96 x1, x2 = promote_args(numpy_fn.name, x1, x2)
---> 97 return lax_fn(x1, x2) if x1.dtype != np.bool else bool_lax_fn(x1, x2)
98 fn.qualname = f"jax.numpy.{numpy_fn.name}"
99 fn = jit(fn, inline=True)
[... skipping hidden 7 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py in broadcasting_shape_rule(name, *avals)
1577 result_shape.append(non_1s[0])
1578 else:
-> 1579 raise TypeError(f'{name} got incompatible shapes for broadcasting: '
1580 f'{", ".join(map(str, map(tuple, shapes)))}.')
1581

TypeError: mul got incompatible shapes for broadcasting: (2, 1), (10, 2).

Thanks for reporting this! I've found the problem and will fix it soon.

This example now works when using version 0.1.3 of the package.

Thank you for resolving the issue! I have one more question. Am I restricted to using only jax.nn.sigmoid, jax.nn.softplus, and jax.nn.swish as activation functions for the multi-layer perceptron (MLP)? Are the ELU and sin activation functions not yet implemented?

Yes, only sigmoid, softplus and swish are implemented currently.