diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 508f6771a..17de9f797 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -1400,7 +1400,7 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args): of the cache, and these clones will contain references to the same Variable objects which guarantees that state is propagated correctly back to the original graph nodes. Because of the previous, the final structure of all graph nodes must be the same - after each call to the cached function, otherswise an error will be raised. Temporary + after each call to the cached function, otherwise an error will be raised. Temporary mutations are allowed (e.g. the use of ``Module.sow``) as long as they are cleaned up before the function returns (e.g. via ``nnx.pop``).