Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion .github/workflows/flax_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.11', '3.12', '3.13']
python-version: ['3.11', '3.12', '3.13', '3.14']
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
Expand Down Expand Up @@ -170,3 +170,31 @@ jobs:
"description": "'$status'",
"context": "github-actions/Build"
}'

# This is a temporary workflow to test flax on Python 3.14 and
# skipping deps like tensorstore, tensorflow etc
tests-python314:
name: Run Tests on Python 3.14
needs: [pre-commit, commit-count]
runs-on: ubuntu-24.04-16core
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Setup uv
uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
with:
version: "0.9.2"
python-version: "3.14"
activate-environment: true
enable-cache: true

- name: Install dependencies
run: |
rm -fr .venv
uv sync --extra testing --extra docs
# temporary: install jax nightly
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
- name: Test with pytest
run: |
export XLA_FLAGS='--xla_force_host_platform_device_count=4'
find tests/ -name "*.py" | grep -vE 'io_test|tensorboard' | xargs pytest -n auto

2 changes: 2 additions & 0 deletions flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,5 @@
while_loop as while_loop,
)
# pylint: enable=g-multiple-import
# For BC
from flax.linen import kw_only_dataclasses as kw_only_dataclasses
221 changes: 6 additions & 215 deletions flax/linen/kw_only_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,230 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Support for keyword-only fields in dataclasses for Python versions <3.10.
"""This module is kept for backward compatibility.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

praxis is using flax/linen/kw_only_dataclasses.py module here: https://github.com/google/praxis/blob/main/praxis/base_hyperparams.py#L36


This module provides wrappers for `dataclasses.dataclass` and
`dataclasses.field` that simulate support for keyword-only fields for Python
versions before 3.10 (which is the version where dataclasses added keyword-only
field support). If this module is imported in Python 3.10+, then
`kw_only_dataclasses.dataclass` and `kw_only_dataclasses.field` will simply be
aliases for `dataclasses.dataclass` and `dataclasses.field`.

For earlier Python versions, when constructing a dataclass, any fields that have
been marked as keyword-only (including inherited fields) will be moved to the
end of the constuctor's argument list. This makes it possible to have a base
class that defines a field with a default, and a subclass that defines a field
without a default. E.g.:

>>> from flax.linen import kw_only_dataclasses
>>> @kw_only_dataclasses.dataclass
... class Parent:
... name: str = kw_only_dataclasses.field(default='', kw_only=True)

>>> @kw_only_dataclasses.dataclass
... class Child(Parent):
... size: float # required.

>>> import inspect
>>> print(inspect.signature(Child.__init__))
(self, size: float, name: str = '') -> None


(If we used `dataclasses` rather than `kw_only_dataclasses` for the above
example, then it would have failed with TypeError "non-default argument
'size' follows default argument.")

WARNING: fields marked as keyword-only will not *actually* be turned into
keyword-only parameters in the constructor; they will only be moved to the
end of the parameter list (after all non-keyword-only parameters).
Previous code targeting Python versions <3.10 is removed and wired to
built-in dataclasses module.
"""

import dataclasses
import functools
import inspect
from types import MappingProxyType
from typing import Any, TypeVar

import typing_extensions as tpe

import flax

M = TypeVar('M', bound='flax.linen.Module')
FieldName = str
Annotation = Any
Default = Any


class _KwOnlyType:
"""Metadata tag used to tag keyword-only fields."""

def __repr__(self):
return 'KW_ONLY'


KW_ONLY = _KwOnlyType()


def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs):
"""Wrapper for dataclassess.field that adds support for kw_only fields.

Args:
metadata: A mapping or None, containing metadata for the field.
kw_only: If true, the field will be moved to the end of `__init__`'s
parameter list.
**kwargs: Keyword arguments forwarded to `dataclasses.field`

Returns:
A `dataclasses.Field` object.
"""
if kw_only is not dataclasses.MISSING and kw_only:
if (
kwargs.get('default', dataclasses.MISSING) is dataclasses.MISSING
and kwargs.get('default_factory', dataclasses.MISSING)
is dataclasses.MISSING
):
raise ValueError('Keyword-only fields with no default are not supported.')
if metadata is None:
metadata = {}
metadata[KW_ONLY] = True
return dataclasses.field(metadata=metadata, **kwargs)


@tpe.dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
def dataclass(cls=None, extra_fields=None, **kwargs):
"""Wrapper for dataclasses.dataclass that adds support for kw_only fields.

Args:
cls: The class to transform (or none to return a decorator).
extra_fields: A list of `(name, type, Field)` tuples describing extra fields
that should be added to the dataclass. This is necessary for linen's
use-case of this module, since the base class (linen.Module) is *not* a
dataclass. In particular, linen.Module class is used as the base for both
frozen and non-frozen dataclass subclasses; but the frozen status of a
dataclass must match the frozen status of any base dataclasses.
**kwargs: Additional arguments for `dataclasses.dataclass`.

Returns:
`cls`.
"""

def wrap(cls):
return _process_class(cls, extra_fields=extra_fields, **kwargs)

return wrap if cls is None else wrap(cls)


def _process_class(cls: type[M], extra_fields=None, **kwargs):
"""Transforms `cls` into a dataclass that supports kw_only fields."""
if '__annotations__' not in cls.__dict__:
cls.__annotations__ = {}

# The original __dataclass_fields__ dicts for all base classes. We will
# modify these in-place before turning `cls` into a dataclass, and then
# restore them to their original values.
base_dataclass_fields = {} # dict[cls, cls.__dataclass_fields__.copy()]

# The keyword only fields from `cls` or any of its base classes.
kw_only_fields: dict[FieldName, tuple[Annotation, Default]] = {}

# Scan for KW_ONLY marker.
kw_only_name = None
for name, annotation in cls.__annotations__.items():
if annotation is KW_ONLY:
if kw_only_name is not None:
raise TypeError('Multiple KW_ONLY markers')
kw_only_name = name
elif kw_only_name is not None:
if not hasattr(cls, name):
raise ValueError(
'Keyword-only fields with no default are not supported.'
)
default = getattr(cls, name)
if isinstance(default, dataclasses.Field):
default.metadata = MappingProxyType({**default.metadata, KW_ONLY: True})
else:
default = field(default=default, kw_only=True)
setattr(cls, name, default)
if kw_only_name:
del cls.__annotations__[kw_only_name]

# Inject extra fields.
if extra_fields:
for name, annotation, default in extra_fields:
if not (isinstance(name, str) and isinstance(default, dataclasses.Field)):
raise ValueError(
'Expected extra_fields to a be a list of '
'(name, type, Field) tuples.'
)
setattr(cls, name, default)
cls.__annotations__[name] = annotation

# Extract kw_only fields from base classes' __dataclass_fields__.
for base in reversed(cls.__mro__[1:]):
if not dataclasses.is_dataclass(base):
continue
base_annotations = base.__dict__.get('__annotations__', {})
base_dataclass_fields[base] = dict(
getattr(base, '__dataclass_fields__', {})
)
for base_field in list(dataclasses.fields(base)):
field_name = base_field.name
if base_field.metadata.get(KW_ONLY) or field_name in kw_only_fields:
kw_only_fields[field_name] = (
base_annotations.get(field_name),
base_field,
)
del base.__dataclass_fields__[field_name]

# Remove any keyword-only fields from this class.
cls_annotations = cls.__dict__['__annotations__']
for name, annotation in list(cls_annotations.items()):
value = getattr(cls, name, None)
if (
isinstance(value, dataclasses.Field) and value.metadata.get(KW_ONLY)
) or name in kw_only_fields:
del cls_annotations[name]
kw_only_fields[name] = (annotation, value)

# Add keyword-only fields at the end of __annotations__, in the order they
# were found in the base classes and in this class.
for name, (annotation, default) in kw_only_fields.items():
setattr(cls, name, default)
cls_annotations.pop(name, None)
cls_annotations[name] = annotation

create_init = '__init__' not in vars(cls) and kwargs.get('init', True)

# Apply the dataclass transform.
transformed_cls: type[M] = dataclasses.dataclass(cls, **kwargs)

# Restore the base classes' __dataclass_fields__.
for _cls, fields in base_dataclass_fields.items():
_cls.__dataclass_fields__ = fields

if create_init:
dataclass_init = transformed_cls.__init__
# use sum to count the number of init fields that are not keyword-only
expected_num_args = sum(
f.init and not f.metadata.get(KW_ONLY, False)
for f in dataclasses.fields(transformed_cls)
)

@functools.wraps(dataclass_init)
def init_wrapper(self, *args, **kwargs):
num_args = len(args)
if num_args > expected_num_args:
# we add + 1 to each to account for `self`, matching python's
# default error message
raise TypeError(
f'__init__() takes {expected_num_args + 1} positional '
f'arguments but {num_args + 1} were given'
)

dataclass_init(self, *args, **kwargs)

init_wrapper.__signature__ = inspect.signature(dataclass_init) # type: ignore
transformed_cls.__init__ = init_wrapper # type: ignore[method-assign]

# Return the transformed dataclass
return transformed_cls
KW_ONLY = dataclasses.KW_ONLY
field = dataclasses.field
dataclass = dataclasses.dataclass
50 changes: 21 additions & 29 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
union_filters,
)
from flax.ids import FlaxId, uuid
from flax.linen import kw_only_dataclasses
from flax.typing import (
RNGSequences,
PRNGKey,
Expand Down Expand Up @@ -1061,7 +1060,7 @@ def _customized_dataclass_transform(cls, kw_only: bool):
3. Generate a hash function (if not provided by cls).
"""
# Check reserved attributes have expected type annotations.
annotations = dict(cls.__dict__.get('__annotations__', {}))
annotations = inspect.get_annotations(cls)
if annotations.get('parent', _ParentType) != _ParentType:
raise errors.ReservedModuleAttributeError(annotations)
if annotations.get('name', str) not in ('str', str, Optional[str]):
Expand All @@ -1081,42 +1080,35 @@ def _customized_dataclass_transform(cls, kw_only: bool):
(
'parent',
_ParentType,
kw_only_dataclasses.field(
dataclasses.field(
repr=False, default=_unspecified_parent, kw_only=True
),
),
(
'name',
Optional[str],
kw_only_dataclasses.field(default=None, kw_only=True),
dataclasses.field(default=None, kw_only=True),
),
]

if kw_only:
if tuple(sys.version_info)[:3] >= (3, 10, 0):
for (
name,
annotation, # pytype: disable=invalid-annotation
default,
) in extra_fields:
setattr(cls, name, default)
cls.__annotations__[name] = annotation
dataclasses.dataclass( # type: ignore[call-overload]
unsafe_hash='__hash__' not in cls.__dict__,
repr=False,
kw_only=True,
)(cls)
else:
raise TypeError('`kw_only` is not available before Py 3.10.')
else:
# Now apply dataclass transform (which operates in-place).
# Do generate a hash function only if not provided by the class.
kw_only_dataclasses.dataclass(
cls,
unsafe_hash='__hash__' not in cls.__dict__,
repr=False,
extra_fields=extra_fields,
) # pytype: disable=wrong-keyword-args
for (
name,
annotation, # pytype: disable=invalid-annotation
default,
) in extra_fields:
setattr(cls, name, default)
cls.__annotations__[name] = annotation

# TODO: a workaround for the issue:
# https://github.com/google/flax/pull/5087#issuecomment-3536610568
if (sys.version_info.major, sys.version_info.minor) in [(3, 12), (3, 13)]:
setattr(cls, '__annotations__', cls.__annotations__)

dataclasses.dataclass( # type: ignore[call-overload]
unsafe_hash='__hash__' not in cls.__dict__,
repr=False,
kw_only=kw_only,
)(cls)

cls.__hash__ = _wrap_hash(cls.__hash__) # type: ignore[method-assign]

Expand Down
14 changes: 11 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ testing = [
"clu",
"clu<=0.0.9; python_version<'3.10'",
"einops",
"gymnasium[atari]",
"gymnasium[atari]; python_version<'3.14'",
"jaxlib",
"jaxtyping",
"jraph>=0.0.6dev0",
Expand All @@ -61,11 +61,11 @@ testing = [
"tensorflow_text>=2.11.0; platform_system!='Darwin' and python_version < '3.13'",
"tensorflow_datasets",
"tensorflow>=2.12.0; python_version<'3.13'", # to fix Numpy np.bool8 deprecation error
"tensorflow>=2.20.0; python_version>='3.13'",
"tensorflow>=2.20.0; python_version>='3.13' and python_version<'3.14'",
"torch",
"treescope>=0.1.1; python_version>='3.10'",
"cloudpickle>=3.0.0",
"ale-py>=0.10.2",
"ale-py>=0.10.2; python_version<'3.14'",
]
docs = [
"sphinx==6.2.1",
Expand Down Expand Up @@ -237,3 +237,11 @@ quote-style = "single"
[tool.uv]
# Ignore uv.lock and always upgrade the package to the latest
upgrade-package = ["jax", "jaxlib", "orbax-checkpoint"]

[tool.uv.sources]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is from #5086 and can be removed if that PR wont land

torch = { index = "pytorch" }

[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
Loading
Loading