Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions folx/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def jvp_fun(s):
return jax.jvp(f, primals, unravel(s))[1]

eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype)
if hasattr(jax.lax, 'pvary'):
eye = jax.lax.pvary(eye, tuple(jax.typeof(flat_primals).vma))
J = jax.vmap(jvp_fun, out_axes=-1)(eye)
return J

Expand Down
29 changes: 29 additions & 0 deletions test/test_shard_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from functools import partial

import jax
import jax.numpy as jnp
import pytest
from packaging.version import Version

from folx import forward_laplacian


@pytest.mark.skipif(
Version(jax.__version__) < Version('0.7.1'), reason='jax version too old'
)
def test_shard_map_bug_integer_pow():
# see https://github.com/microsoft/folx/issues/38

def f(w, x):
return jax.lax.integer_pow(x @ w, 1)

@jax.smap(out_axes=0, in_axes=(None, 0), axis_name='i')
@partial(jax.vmap, in_axes=(None, 0))
def test(w, x):
return forward_laplacian(partial(f, w))(x)

x = jnp.ones((1, 16))
w = jnp.ones((16, 16))

with jax.set_mesh(jax.sharding.Mesh(jax.devices()[:1], 'i')):
test(w, x)