Releases: google/flax
Version 0.10.3
What's Changed
- Fix fori_loop and while_loop on multiple modules by @IvyZX in #4390
- Upgrade Flax readme to NNX by @8bitmp3 in #4386
- [nnx] add performance guide notebook by @cgarciae in #4384
- Automated Code Change by @copybara-service in #4393
- [nnx] optimize NodeDef.attributes by @cgarciae in #4399
- Fixed the broken link in haiku_to_flax.rst file by @tilakrayal in #4402
- [nnx] optimize Variable by @cgarciae in #4400
- Update Flax NNX Randomness by @8bitmp3 in #4279
- Remove the repeated methods in flax.nnx.Module documentation by @rajasekharporeddy in #4416
- Fixed the broken link in linen_to_nnx.rst by @tilakrayal in #4415
- [nnx] add FlatState by @cgarciae in #4410
- Update async_checkpointer.py reference by @emmanuel-ferdman in #4385
- Fix multiple links to Orbax documentation by @Matt-Hurd in #4364
- Update links in Why Flax NNX documentation by @rajasekharporeddy in #4425
- [nnx] fix transforms guide by @cgarciae in #4421
- Add benchmark on state traversal, and a readme by @IvyZX in #4428
- Update Flax NNX performance guide by @8bitmp3 in #4401
- Create sharding via Partitioned.get_sharding() by @copybara-service in #4427
- Update Flax NNX vs JAX Transformations guide by @8bitmp3 in #4286
- Upgrade Flax NNX Gemma Sampling Inference doc by @8bitmp3 in #4325
- Update NNX merge docs in graph.py by @8bitmp3 in #4411
- Fix main and add nnx.fori_loop test by @cgarciae in #4472
- Upgrade Flax NNX Filters doc by @8bitmp3 in #4199
- Makes flax Modules more compatible with IPython auto-reload. by @copybara-service in #4420
- [nnx] RNN: add broadcast_rngs and state_axes APIs by @cgarciae in #4407
- Allow nnx.bridge.variables.nnx_attrs_to_linen_vars to take nnx.VariableState as argument. by @copybara-service in #4473
- [nnx] add state summaries for print and display by @cgarciae in #4438
- Copybara import of the project: by @copybara-service in #4475
- [nnx] add state summaries for print and display by @copybara-service in #4477
- CI: add scheduled test against nightly JAX releases by @jakevdp in #4478
- CI: pin actions to specific commits by @jakevdp in #4479
- [nnx] fix MultiMetric typing by @cgarciae in #4485
- [nnx] fix ToNNX linen_attributes update by @cgarciae in #4486
- Remove usages of orbax_utils.save_args_from_target, as this function does nothing (it used to control a checkpointing behavior that has since been optimized away). by @copybara-service in #4482
- [nnx] improve Module docs by @cgarciae in #4499
- Update einsum layer for Gemma example by @copybara-service in #4498
- [nnx] fix fiddle by @cgarciae in #4500
- Don't create params in normalization layers instead of creating None-value params. by @copybara-service in #4501
- Rename variable string mapping utils and move them to variableslib by @IvyZX in #4503
- fix LoRA initialization error in nnx layer by @copybara-service in #4502
- Remove all Param(None) lines from NNX by @IvyZX in #4504
- make gemma FFW LoRA friendly by @copybara-service in #4510
- Add nnx.Module.perturb by @IvyZX in #4515
- [nnx] add tabulate by @cgarciae in #4493
- batch_norm.rst: == should be = by @cool-RR in #4524
- v0.10.3 by @cgarciae in #4525
New Contributors
- @tilakrayal made their first contribution in #4402
- @emmanuel-ferdman made their first contribution in #4385
- @Matt-Hurd made their first contribution in #4364
Full Changelog: v0.10.2...v0.10.3
Version 0.10.2
What's Changed
- Add nnx.fori_loop by @IvyZX in #4353
- Linesearch (and lbfgs) support by @jlperla in #4351
- Upgrade Flax NNX Haiku Linen migration doc by @8bitmp3 in #4200
- Fix PRNG handling in nn.jit under nn.scan. by @copybara-service in #4359
- support passing arguments directly to the struct.dataclass decorator by @copybara-service in #4275
- Avoid assert_array_equal for PRNG keys. by @copybara-service in #4363
- [nnx] support pure dicts by @cgarciae in #4352
- [nnx] add data parallel toy example by @cgarciae in #4354
- Add logical axis global context support for NNX by @IvyZX in #4350
- [nnx] fix ToLinen kwargs by @copybara-service in #4270
- [nnx] use HashableMapping instead of FrozenDict by @cgarciae in #4376
- [nnx] fix while_loop/fori_loop bug when sharing references by @cgarciae in #4379
- Add flax.nnx.eval_shape docstring by @8bitmp3 in #4374
- Setup the flaxlib in C++, using Meson and Nanobind. by @copybara-service in #4380
- Add flax.nnx.remat docstring by @8bitmp3 in #4373
- [nnx] add checkify by @cgarciae in #4381
- Lint flax.nnx.while_loop docstring by @8bitmp3 in #4371
- Lint flax.nnx.fori_loop docstring by @8bitmp3 in #4370
- [nnx] add some optimizations to graph.py by @cgarciae in #4377
- update version to 0.10.2 by @cgarciae in #4387
New Contributors
Full Changelog: v0.10.1...v0.10.2
Version 0.10.1
What's Changed
- Add Flax NNX GraphDef docstring by @8bitmp3 in #4302
- Flesh out the Haiku/Flax guide by @IvyZX in #4305
- [nnx] improve mnist tutorial by @cgarciae in #4316
- Update Flax Evolution from Linen to NNX guide by @8bitmp3 in #4289
- [nnx] try casting integers keys in State.replace_by_pure_dict by @cgarciae in #4317
- Fixed nnx examples bad links in the README.md by @vfdev-5 in #4282
- Fix philosophy link by @jorisSchaller in #4313
- [nnx] add gemma notebook by @cgarciae in #4075
- [nnx] improve init_cache docs by @cgarciae in #4291
- remove markdown from section titles by @cgarciae in #4322
- Avoid depending on JAX internals, which are about to change. by @copybara-service in #4326
- Remove outdated compatibility code. by @jakevdp in #4324
- fix ruff complaints by @levskaya in #4331
- Remove GeGLU activation function and golden tests. by @copybara-service in #4303
- Avoid using float32 in normalization for mean/var and scale/bias parameters when force_float32_reductions=False by @copybara-service in #4314
- Avoid assert_array_equal on PRNG keys. by @jakevdp in #4332
- Fix typos in Flax NNX Migrating from Haiku to Flax by @8bitmp3 in #4337
- Add API reference for flax.nnx.nn and improve landing page by @IvyZX in #4338
- [nnx] improve transforms guide by @cgarciae in #4333
- [nnx] cleanup gemma notebook by @cgarciae in #4334
- Remove non-lazy RNG compat mode and flag from flax. by @copybara-service in #4339
- [nnx] fix custom_vjp by @cgarciae in #4306
- Define model surgery in docs by @8bitmp3 in #4349
- [nnx] update State and variables docstrings by @cgarciae in #4346
- Add NNX transforms nnx.while_loop and nnx.switch by @IvyZX in #4343
- update version to v0.10.1 by @cgarciae in #4345
New Contributors
Full Changelog: v0.10.0...v0.10.1
Version 0.10.0
What's Changed
- [nnx] clear nnx basics pip logs by @cgarciae in #4149
- Support linen <-> nnx metadata box converging in nnx.bridge by @IvyZX in #4145
- Add nnx bridge API reference to site by @IvyZX in #4158
- [nnx] use jax-style transforms API in nnx_basics by @cgarciae in #4155
- [nnx] improve nnx.scan in_axes/out_axes by @cgarciae in #4157
- Support direct quantization for FP8 matmul by @wenscarl in #3922
- Upgrade Flax NNX Model Surgery by @8bitmp3 in #4135
- [nnx] add more Variable proxy methods by @cgarciae in #4170
- [nnx] disallow Array leaves by @copybara-service in #4172
- Internal change by @copybara-service in #4176
- [nnx] improve landing page and nnx_basics messaging by @cgarciae in #4168
- Fixes a small bug in flax.linen.share_scope, where the scopes of children of the module being merged that were created before setup(), were not being updated to point to the new scope, and so they would end up staying under the original tree. by @copybara-service in #4150
- Move all NNX content up a level to be equal with Linen, to make python packaging more consistent. by @copybara-service in #4177
- Add a guide for nnx.bridge by @IvyZX in #4171
- [nnx] improve Optimizer metadata propagation by @cgarciae in #4180
- [nnx] enable sharding transformation on integer prefixes by @cgarciae in #4185
- Support linen.LogicallyPartitioned <-> nnx.Variable by @IvyZX in #4161
- Clean up axis hooks in nnx.Variable by @IvyZX in #4189
- Merge nnx.errors to flax.errors by @IvyZX in #4186
- [nnx] optimize jit by @cgarciae in #4191
- Split documentation for Linen and NNX by @cgarciae in #4192
- Partially revert #4192 which sets back a bunch of previous merged pushes. by @copybara-service in #4201
- Align bridge variable tree structures by @IvyZX in #4194
- [NNX site] Fix landing page and banner phrasing and add examples page by @IvyZX in #4202
- shorten banners by @cgarciae in #4206
- Add trimmed Linen to NNX guide by @IvyZX in #4209
- Minor documentation fixes for AxisMetadata. by @copybara-service in #4178
- fix tests for numpy 2.0 compatibility by @copybara-service in #4215
- Forward all arguments when using nnx.transforms.deprecated.scan as a decorator. by @copybara-service in #4208
- [nnx] add transforms guide by @cgarciae in #4197
- [nnx] fix transforms guide by @cgarciae in #4223
- Flax NNX GSPMD guide by @IvyZX in #4220
- Update libraries to use JAX's limited (and ill-advised) trace-state-querying APIs rather than depending on JAX's deeper internals, which are about to change. by @copybara-service in #4225
- [nnx] add Randomness guide by @cgarciae in #4216
- Add pure dict conversion util functions to nnx.State. by @IvyZX in #4230
- [nnx] Simplify traversal by @cgarciae in #4205
- Fix false positive tracer leaks in flax library. by @copybara-service in #4232
- [nnx] add flaxlib by @copybara-service in #4235
- [nnx] improve docs by @cgarciae in #4236
- point nnx banner to flax-linen by @cgarciae in #4237
- update banners by @cgarciae in #4238
- Fix scale dtype and refactor q_dot_dq by @wenscarl in #4229
- update banners by @cgarciae in #4241
- Add redirects for Linen guide links in the NNX site scope. by @IvyZX in #4242
- Internal change by @copybara-service in #4243
- Copybara import of the project: by @copybara-service in #4245
- Update Flax NNX Scale Up SPMD guide by @8bitmp3 in #4239
- Upgrade Flax NNX basics doc by @8bitmp3 in #4173
- Improve landing page, glossary and misc by @IvyZX in #4244
- Nitting and adding links by @8bitmp3 in #4248
- enable doctest on notebooks by @cgarciae in #4250
- Update index.rst by @ariG23498 in #4251
- Add NNX checkpointing guide by @IvyZX in #4249
- Add checkpointing guide to website index. by @copybara-service in #4263
- Update to Flax NNX Transforms doc by @8bitmp3 in #4264
- Add why nnx by @cgarciae in #4240
- [nnx] add cloudpickle support by @cgarciae in #4253
- Fix typo: impost to import by @Vilin97 in #4256
- [nnx] revive TrainState toy example by @cgarciae in #4226
- [nnx] add custom_vjp to docs by @cgarciae in #4266
- remove flax-nnx urls by @cgarciae in #4267
- Add flatten to nnx.graph autosummary in graph.rst by @8bitmp3 in #4255
- [nnx] add FSDP toy example with custom optimizer by @cgarciae in #4183
- Update Flax NNX Landing Page by @8bitmp3 in #4274
- Update to Flax NNX Model Surgery by @8bitmp3 in #4276
- Update Why Flax NNX guide by @8bitmp3 in #4262
- Update to Flax NNX MNIST tutorial by @8bitmp3 in #4277
- [nnx] improve randomness guide by @cgarciae in #4281
- Remove notebook exceptions in docs_nnx doctest by @IvyZX in #4285
- [nnx] add PrefixMapping by @cgarciae in #4278
- [nnx] state filters by @cgarciae in #4288
- Fix devcontainer setup by @jorisSchaller in #4299
- Upgrade Flax NNX Checkpointing guide by @8bitmp3 in #4294
- Update Flax NNX Scale Up guide by @8bitmp3 in #4296
- Porting RNN from Linen to NNX by @zinccat in #4272
- Update Flax NNX Glossary by @8bitmp3 in #4284
- update version to 0.10.0 by @cgarciae in #4292
New Contributors
- @ariG23498 made their first contribution in #4251
- @Vilin97 made their first contribution in #4256
- @jorisSchaller made their first contribution in #4299
- @zinccat made their first contribution in #4272
Full Changelog: v0.9.0...v0.10.0
v0.9.0
What's Changed
- Add NNX surgery guide by @IvyZX in #4005
- Port gemma/transformer to NNX by @copybara-service in #4019
- upgrade python to 3.10 + use pyupgrade by @cgarciae in #4038
- [nnx] add Using Filters guide by @cgarciae in #4028
- v0.8.6 by @cgarciae in #4040
- allow imagenet training profiling to be disabled in config by @copybara-service in #4043
- [nnx] LoRAParam inherits from Param by @cgarciae in #3988
- [linen] allows multiple compact methods by @cgarciae in #3808
- Added support of NANOO fp8. by @wenchenvincent in #3993
- Add functools.wraps() annotation to flax.nn.jit. by @copybara-service in #4051
- Fix typo in nnx_basics doc by @rajasekharporeddy in #4047
- [nnx] fix Variable overloads and add shape/dtype properties by @cgarciae in #4049
- Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4039
- [nnx] stabilize unsafe_pytree by @cgarciae in #4030
- Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4055
- [NVIDIA] Rename fp8 custom dtype to fp32_max_grad by @kaixih in #3984
- [nnx] fix mnist_tutorial colab link by @cgarciae in #4063
- [nnx] fix Accuracy on eager mode by @cgarciae in #4065
- Update orbax_upgrade_guide.rst for async checkpointing usage examples by @kaushaladiti-2802 in #4036
- Re-enable some tests after Python 3.9 is dropped by @IvyZX in #4067
- Rename nnx.compat to nnx.bridge by @IvyZX in #4066
- [nnx] improve mnist tutorial by @cgarciae in #4070
- Modify Flax checkpointing in preparation for cl/650338576. by @copybara-service in #4072
- Remove some outdated backward-compatibility code. by @copybara-service in #4068
- [NVIDIA] Add a user guide for fp8 by @kaixih in #4076
- [nnx] add extract APIs by @cgarciae in #4078
- [example]: remove lm1b useless parallelism rules by @knightXun in #4077
- [nnx] improve filters guide by @cgarciae in #4059
- [nnx] add call by @cgarciae in #4004
- Ignore Orbax warning in deprecated flax.training.checkpoints.py to unbreak head doctest by @IvyZX in #4092
- fix mypy failures due to numpy update by @cgarciae in #4098
- [linen] generalize transform caching by @copybara-service in #4057
- [linen] fold rngs on jit to improve caching by @copybara-service in #4064
- Add shape-based lazy init to LinenToNNX (prev LinenWrapper) by @IvyZX in #4081
- [nnx] add reseed by @cgarciae in #4099
- [nnx] add split/merge_inputs by @cgarciae in #4084
- Perform shape checks for self.param AFTER unboxing by @danielwatson6 in #4079
- fix restore_checkpoint example in docstring by @copybara-service in #4101
- [numpy] Fix users of NumPy APIs that are removed in NumPy 2.0. by @copybara-service in #4104
- set profile_duration_ms = None as in periodic_actions there's default value for both num_profile_steps and profile_duration_ms, and the profile stopping condition is when both num_profile_steps and profile_duration_ms are satisfied, so setting profile_duration_ms=None so that the passed num_profile_steps value gets used by @copybara-service in #4096
- [linen] add share_scope by @cgarciae in #4102
- Allow metadata pass-through in flax.struct.field by @cool-RR in #4056
- avoid mixing einsum_dot_general and einsum argument by specifying them explicitly in the caller. by @copybara-service in #4115
- Add logging to track deprecated codepaths. by @copybara-service in #4121
- [pmap no rank reduce cleanup]: When flipping the by @copybara-service in #4125
- Add NNXToLinen wrapper to nnx.bridge by @IvyZX in #4126
- Switch NNX to use Treescope instead of Penzai. by @copybara-service in #4132
- Add GroupNorm to NNX normalization layers by @treigerm in #4095
- [nnx] fix initializing propagation by @cgarciae in #4134
- add JAX-style NNX Transforms FLIP by @cgarciae in #4108
- Fix _ParentType annotation by @dcharatan in #4120
- add uv.lock file by @copybara-service in #4139
- use uv package manager by @cgarciae in #4136
- More testing and misc fixes on wrappers by @IvyZX in #4137
- Fix link to orbax documentation by @cool-RR in #4123
- [nnx] experimental transforms by @cgarciae in #3963
- [nnx] improve docs by @cgarciae in #4141
- remove repeated license headers by @cgarciae in #4148
- update Flax to version 0.9.0 by @copybara-service in #4147
New Contributors
- @wenchenvincent made their first contribution in #3993
- @rajasekharporeddy made their first contribution in #4047
- @kaushaladiti-2802 made their first contribution in #4036
- @knightXun made their first contribution in #4077
- @danielwatson6 made their first contribution in #4079
- @cool-RR made their first contribution in #4056
- @treigerm made their first contribution in #4095
- @dcharatan made their first contribution in #4120
Full Changelog: v0.8.5...v0.9.0
v0.8.5
What's Changed
- v0.8.5 by @cgarciae in #3941
- [nnx] improve vmap axis size detection by @cgarciae in #3947
- Add direct penzai.treescope support for NNX objects. by @copybara-service in #3948
- [nnx] fix nnx_basics dependencies by @cgarciae in #3942
- Rename all the NNX tests to internal naming & build conventions. by @copybara-service in #3952
- updated rng guide by @chiamp in #3912
- upgraded haiku guide to include NNX by @chiamp in #3923
- parameterized NNX transforms tests by @chiamp in #3906
- Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes. by @copybara-service in #3957
- fix HEAD by @chiamp in #3960
- Minor grammar fixes to NNX documentation. by @mcsmart76 in #3953
- Make FlatState a Mapping instead of a dict by @NeilGirdhar in #3928
- Adding Welford metric. by @copybara-service in #3959
- Modify Welford metric to return mean value. by @copybara-service in #3970
- [nnx] make State generic by @cgarciae in #3964
- updated NNX nn docstrings by @chiamp in #3972
- make flax work with upcoming JAX change to tree_map (being more careful about by @copybara-service in #3976
- updated nnx.module docstrings by @chiamp in #3966
- updated nnx.Conv and nnx.ConvTranspose by @chiamp in #3974
- updated nnx.graph docstrings by @chiamp in #3958
- Adds pmap and Pmap. static_broadcasted_argnums, donate_argnums, and global_arg_shapes are not yet supported. by @copybara-service in #3978
- Fixes for batch norm docs by @jkarwowski in #3982
- fix deprecation warning by @chiamp in #3981
- updated NNX rnglib docstring by @chiamp in #3980
- updated nnx.training by @chiamp in #3975
- updated nnx.variables docstrings by @chiamp in #3986
- [nnx] vectorize vmap split counts by @cgarciae in #3989
- added wrt option to nnx.Optimizer by @chiamp in #3983
- Added nnx.graph.iter_children by @chiamp in #3991
- [nnx] fix vmap by @copybara-service in #3995
- Fix head pytest breakage by @IvyZX in #4006
- Helper function for loading params from a linen module by @copybara-service in #4012
- Port gemma/layers to NNX by @copybara-service in #4013
- [nnx] fix grad by @cgarciae in #4007
- [nnx] add PathContains Filter by @cgarciae in #4011
- Support Python 3.9 by @copybara-service in #4018
- Port gemma/modules to NNX by @copybara-service in #4014
- Internal change to fix current head CI by @copybara-service in #4017
- Unpin the Orbax pip version. by @copybara-service in #4024
- Fix Gemma test to unbreak head by @IvyZX in #4025
- Fix pickling of exceptions by @sanderland in #4002
- Call user-defined variable transforms before determining axis size in nn.vmap. by @copybara-service in #4026
- CI: add test run against oldest supported jax version by @jakevdp in #3996
- Make force_fp32_for_softmax arg in MultiHeadDotProductAttention useful. by @copybara-service in #4029
New Contributors
- @mcsmart76 made their first contribution in #3953
- @jkarwowski made their first contribution in #3982
- @sanderland made their first contribution in #4002
Full Changelog: v0.8.4...v0.8.5
v0.8.4
What's Changed
- fixed codecov by @chiamp in #3895
- Make FlatState a Mapping instead of a dict by @NeilGirdhar in #3880
- Share nnx node registry between threads by @NeilGirdhar in #3901
- fixed jnp.clip deprecation by @chiamp in #3905
- Added three tab option to sphinx directive codediff and added testing for first tab by @chiamp in #3847
- Add support for jax.sharding.PartitionSpec.UNCONSTRAINED in logical specification by @copybara-service in #3902
- [nnx] fix mypy and pytype by @cgarciae in #3894
- [nnx] fix iter_nodes by @cgarciae in #3889
- [nnx] Sequential uses regular list by @cgarciae in #3909
- [nnx] add ConvTranspose by @cgarciae in #3908
- [nnx] add Module pytree_experimental static test by @cgarciae in #3864
- Added docstring for Module.scope.path by @chiamp in #3913
- [linen] test jit caching with state updates by @cgarciae in #3900
- v0.8.4 by @cgarciae in #3891
- [linen] enable separate initializers for out layer in MultiHeadDotProductAttention by @cgarciae in #3835
- [nnx] cleanup graph by @cgarciae in #3915
- [nnx] fix bugs by @cgarciae in #3925
- Replace deprecated jax.tree_* functions with jax.tree.* by @copybara-service in #3926
- [nnx] Object refactor by @cgarciae in #3910
- [nnx] add iter_graph by @cgarciae in #3919
- [nnx] add compat by @cgarciae in #3921
- [nnx] transforms refactor by @cgarciae in #3927
- added equivalence test for nnx.ConvTranspose by @chiamp in #3934
- added equivalence test for nnx.Sequential by @chiamp in #3935
- [NNX] Add LoRA and LoRALinear to NNX by @IvyZX in #3929
- [nnx] fix substate mutability by @cgarciae in #3932
- [nnx] improve update context by @cgarciae in #3933
- [nnx] move out of experimental by @cgarciae in #3936
Full Changelog: v0.8.3...v0.8.4
v0.8.3
What's Changed
- Add git fetch upstream to contributing doc. by @carlosgmartin in #3757
- removed getattr/setattr unboxing magic from nnx.Pytree by @chiamp in #3743
- added Einsum layer to NNX by @chiamp in #3741
- Make TrainState's step possibly jax.Array. This makes replicate valid for type checking. by @copybara-service in #3763
- v0.8.3 by @cgarciae in #3758
- [nnx] fix demo notebook by @cgarciae in #3744
- added nnx api reference by @chiamp in #3762
- updated rng docstring for init, apply and make_rng by @chiamp in #3765
- use note box in make_rng docstring by @cgarciae in #3767
- [nnx] improved graph update mechanism by @cgarciae in #3759
- use note box in docstrings by @chiamp in #3769
- Add reset_gate flag to MGUCell. by @carlosgmartin in #3760
- Access thread_resources via jax.interpreters.pxla instead of jax.experimental.maps by @copybara-service in #3775
- Minor doc improvements by @canyon289 in #3588
- added MGU reset_gate test by @chiamp in #3773
- [nnx] Pytrees are Trees by @cgarciae in #3768
- Use short-circuiting access to debug_key_reuse by @copybara-service in #3781
- fix tabulate on norm wrappers by @chiamp in #3772
- Add kw_only struct.dataclass test by @chiamp in #3651
- extended PyTreeNode to take dataclass kwargs by @chiamp in #3785
- [nnx] Arrays are state by @cgarciae in #3791
- [nnx] add GraphNode base class by @cgarciae in #3790
- [nnx] jit accepts many Modules by @cgarciae in #3783
- Exposing the experimental _split_transpose JAX scan parameter in Flax. by @copybara-service in #3795
- Expose nnx.GraphNode by @chiamp in #3796
- [nnx] Rngs and RngStream inherit from GraphNode by @cgarciae in #3793
- [nnx] TrainState uses struct by @cgarciae in #3788
- [nnx] split returns graphdef first by @cgarciae in #3794
- Remove the uninitialized field "embedding" in nn.Embed by @copybara-service in #3801
- Add nnx.training by @chiamp in #3782
- [nnx] non-str State keys by @cgarciae in #3802
- [nnx] allow all jit kwargs in nnx.jit by @cgarciae in #3809
- [nnx] simplify readme by @cgarciae in #3805
- [nnx] Fix nnx basics by @cgarciae in #3812
- [nnx] grad accepts argnums by @cgarciae in #3798
- [nnx] improve toy examples by @cgarciae in #3813
- [nnx] expose Sequential by @cgarciae in #3814
- [nnx] Rng Variable tags by @cgarciae in #3807
- [nnx] remove copy in graph unflatten by @cgarciae in #3804
- fixed optax guide links and docstring typos by @chiamp in #3789
- added dropout broadcast test by @chiamp in #3776
- relaxed grads kwarg for Optimizer.update by @chiamp in #3818
- added tree_map deprecation warning filter by @chiamp in #3828
- updated tree_map by @chiamp in #3823
- added NNX vs JAX transformations guide by @chiamp in #3819
- Updated NNX MNIST tutorial by @chiamp in #3810
- [nnx] add Dropout.rngs by @cgarciae in #3815
- removed autosummary from linen docs by @chiamp in #3792
- Fix cloudpickle sentinel cloning by @cgarciae in #3825
- [nnx] remove pytreelib by @cgarciae in #3816
- [nnx] fix nnx_basics by @cgarciae in #3839
- [linen] fix DenseGeneral init by @cgarciae in #3834
- [nnx] jit constrain object state by @cgarciae in #3817
- Copybara import of the project: by @copybara-service in #3857
- Add example of unbox() and replace_boxed() to the jit guide by @IvyZX in #3843
- RNNCellBase refactor FLIP by @cgarciae in #3099
- [nnx] Some small documentation suggestions. by @gnecula in #3861
- updated nnx dropout by @chiamp in #3841
- Fix LogicalRules type annotation. (Tuple[str] is a tuple with single element string, by @copybara-service in #3877
- Add option to skip float32 promotion when computing means and variances for normalization. by @copybara-service in #3873
- added nnx api reference link by @chiamp in #3871
- option of forcing the input of softmax to be fp32 for better numerical stability in mixed-precision training. by @copybara-service in #3874
- allow custom dot_general for einsum. by @copybara-service in #3884
- [NVIDIA] Extend the custom fp8 accumulate dtype in non-jit scenarios by @kaixih in #3827
- updated robots.txt by @chiamp in #3886
- fixed autosummary links by @chiamp in #3887
- Fix jax.tree_util.register_dataclass in older JAX versions. by @copybara-service in #3885
- [nnx] v0.1 by @cgarciae in #3876
Full Changelog: v0.8.2...v0.8.3
v0.8.2
What's Changed
- Add +1 to version after 0.8.1 release by @IvyZX in #3684
- fixed rng guide outputs by @chiamp in #3685
- enforce mask kwarg in norm layers by @chiamp in #3663
- added kwargs to self.param and self.variable by @chiamp in #3675
- added nnx normalization tests by @chiamp in #3689
- added NNX init_cache docstring example by @chiamp in #3688
- added nnx attention equivalence test by @chiamp in #3687
- Fix bug that assumed frozen-dict keys were strings. by @copybara-service in #3692
- added nnx rmsnorm by @chiamp in #3691
- updated nnx compute_stats by @chiamp in #3693
- fixed intercept_methods docstring by @chiamp in #3694
- [nnx] Add Sphinx Docs by @cgarciae in #3678
- Fix pointless docstring example of nn.checkpoint / nn.remat. by @levskaya in #3703
- added default params rng to .apply by @chiamp in #3698
- [nnx] add partial_init by @cgarciae in #3674
- make make_rng default to 'params' by @chiamp in #3699
- Add SimpleCell. by @carlosgmartin in #3697
- fix Module.module_paths docstring by @cgarciae in #3709
- Guarantee the latest JAX version on CI by @cgarciae in #3705
- Replace deprecated API jax.tree_map by @copybara-service in #3715
- Use jax.tree_util.tree_map instead of deprecated jax.tree_map. by @copybara-service in #3714
- [nnx] simplify readme by @cgarciae in #3707
- [nnx] add demo.ipynb by @cgarciae in #3680
- Fix Tabulate's compute_flops by @cgarciae in #3721
- [nnx] simplify TraceState by @cgarciae in #3724
- Add broadcast of strides and kernel_dilation to nn.ConvTranspose by @IvyZX in #3731
- [nnx] Fix State.sub by @cgarciae in #3704
- [nnx] always fold_in on fork + new ForkedKeys return type by @cgarciae in #3722
- [nnx] explicit Variables by @cgarciae in #3720
- Improves fingerprint definition for Modules in nn.jit. by @copybara-service in #3736
- Flax: avoid key reuse in tests by @copybara-service in #3740
- added Einsum layer by @chiamp in #3710
- nn.jit: automatic fingerprint definition for dataclass attributes by @cgarciae in #3737
- [NVIDIA] Use custom grad accumulation for FP8 params by @kaixih in #3623
- removed nnx dataclass by @chiamp in #3742
- [nnx] cleanup graph_utils by @cgarciae in #3728
- Fix doctest and unbreak head by @IvyZX in #3753
- [nnx] add pytree support by @cgarciae in #3732
- fixed intercept_methods docstring by @chiamp in #3752
- Add ConvLSTMCell to docs. by @carlosgmartin in #3712
- [nnx] remove flagslib by @cgarciae in #3733
- Fix tests after applying JAX key-reuse checker. See: by @copybara-service in #3748
Full Changelog: v0.8.1...v0.8.2
Version 0.8.1
What's Changed
- bump version number to 0.8.1 by @chiamp in #3649
- Bump pillow from 10.0.1 to 10.2.0 in /examples/vae by @dependabot in #3641
- fixed docstring by @chiamp in #3643
- Add explicit control over frozen/slots setting in flax.struct.dataclass by @copybara-service in #3645
- make Sequential.call compact by @copybara-service in #3647
- add Module.module_paths by @cgarciae in #3654
- added rng_guide by @chiamp in #3497
- Replacing jax.tree_util.tree_map with mapping over leaves. by @copybara-service in #3658
- Copybara import of the project: by @copybara-service in #3659
- added InstanceNorm by @chiamp in #3652
- add Module.module_paths by @copybara-service in #3660
- added norm equivalence tests by @chiamp in #3662
- updated nowrap docstring by @chiamp in #3661
- Add module_paths method to docs by @cgarciae in #3657
- add default make_rng by @chiamp in #3669
- renamed channel_axes to feature_axes in InstanceNorm by @chiamp in #3667
- added flax.typing by @chiamp in #3624
- changed kwargs to actual key-word args by @chiamp in #3562
- updated docs and docstrings by @chiamp in #3670
- re-added linen_intro by @chiamp in #3672
- add compact_name_scope v3 by @cgarciae in #3646
- Release 0.8.1 by @IvyZX in #3682
Full Changelog: v0.8.0...v0.8.1