Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/flax_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ jobs:
- python-version: '3.12'
test-type: mypy
jax-version: 'newest'
- python-version: '3.12'
test-type: pyrefly
jax-version: 'newest'
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Setup uv
Expand Down Expand Up @@ -128,6 +131,8 @@ jobs:
uv run --no-sync tests/run_all_tests.sh --only-pytype
elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
uv run --no-sync tests/run_all_tests.sh --only-mypy
elif [[ "${{ matrix.test-type }}" == "pyrefly" ]]; then
uv run --no-sync tests/run_all_tests.sh --only-pyrefly
else
echo "Unknown test type: ${{ matrix.test-type }}"
exit 1
Expand Down
1 change: 1 addition & 0 deletions flax/linen/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):
inputs, kernel, bias = self.promote_dtype(
inputs, kernel, bias, dtype=self.dtype
)
assert inputs is not None and kernel is not None
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This prevents pyrefly ignore comment on the code below:

out = dot_general(
      inputs,  # pyrefly: ignore [bad-argument-type]
      kernel,  # pyrefly: ignore [bad-argument-type]
      ((axis, contract_ind), (batch_dims, batch_ind)),
      precision=self.precision,
    )

and pyrefly is right on as self.promote_dtype output type hint is Array | None, so they deduce that inputs and kernel can be also None.


if self.dot_general_cls is not None:
dot_general = self.dot_general_cls()
Expand Down
2 changes: 1 addition & 1 deletion flax/traceback_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# Whether to filter flax frames from traceback.
_flax_filter_tracebacks = config.flax_filter_frames
# Flax specific set of paths to exclude from tracebacks.
_flax_exclusions = set()
_flax_exclusions = set() # type: ignore[var-annotated]


# re-import JAX symbol for convenience.
Expand Down
2 changes: 1 addition & 1 deletion flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ def read_chunk(i):
else:
checkpoint_contents = fp.read()

state_dict = serialization.msgpack_restore(checkpoint_contents)
state_dict = serialization.msgpack_restore(bytes(checkpoint_contents))
state_dict = _restore_mpas(
state_dict,
target,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ testing = [
"jraph>=0.0.6dev0",
"ml-collections",
"mypy",
"pyrefly",
"opencv-python",
# Set protobuf version to prevent error in
# examples/mnist/train_test.py::TrainTest::test_train_and_evaluate
Expand Down
26 changes: 26 additions & 0 deletions pyrefly.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Pyrefly configuration - migrated from mypy
# Only type-check flax/linen/linear.py for now; expand as issues are resolved.
project-includes = ["flax/linen/linear.py"]
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now the check runs over single file to show that there is no more errors like # pyrefly: ignore [bad-class-definition] (as in previous attempt: https://github.com/google/flax/pull/5437/changes#diff-cfcdd6089c08daca13a27f79e07b856585c5f9b6910570f9cc3e7be76fd746b1R103)


preset = "legacy"
ignore-missing-imports = [
"tensorflow.*",
"tensorboard.*",
"absl.*",
"jax.*",
"rich.*",
"jaxlib.cuda.*",
"jaxlib.cpu.*",
"msgpack",
"numpy.*",
"optax.*",
"orbax.*",
"opt_einsum.*",
"scipy.*",
"libtpu.*",
"jaxlib.mlir.*",
"yaml",
]

[errors]
missing-attribute = "ignore"
16 changes: 14 additions & 2 deletions tests/run_all_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ PYTEST_OPTS=
RUN_DOCTEST=false
RUN_MYPY=false
RUN_PYTEST=false
RUN_PYREFLY=false
RUN_PYTYPE=false
GH_VENV=false

Expand All @@ -30,6 +31,9 @@ case $flag in
--only-mypy)
RUN_MYPY=true
;;
--only-pyrefly)
RUN_PYREFLY=true
;;
--use-venv)
GH_VENV=true
;;
Expand All @@ -40,12 +44,13 @@ case $flag in
esac
done

# if neither --only-doctest, --only-pytest, --only-pytype, --only-mypy is set, run all tests
if ! $RUN_DOCTEST && ! $RUN_PYTEST && ! $RUN_PYTYPE && ! $RUN_MYPY; then
# if neither --only-doctest, --only-pytest, --only-pytype, --only-mypy, --only-pyrefly is set, run all tests
if ! $RUN_DOCTEST && ! $RUN_PYTEST && ! $RUN_PYTYPE && ! $RUN_MYPY && ! $RUN_PYREFLY; then
RUN_DOCTEST=true
RUN_PYTEST=true
RUN_PYTYPE=true
RUN_MYPY=true
RUN_PYREFLY=true
fi

# Activate cached virtual env for github CI
Expand All @@ -58,6 +63,7 @@ echo "PYTEST_OPTS: $PYTEST_OPTS"
echo "RUN_DOCTEST: $RUN_DOCTEST"
echo "RUN_PYTEST: $RUN_PYTEST"
echo "RUN_MYPY: $RUN_MYPY"
echo "RUN_PYREFLY: $RUN_PYREFLY"
echo "RUN_PYTYPE: $RUN_PYTYPE"
echo "GH_VENV: $GH_VENV"
echo "WHICH PYTHON: $(which python)"
Expand Down Expand Up @@ -155,5 +161,11 @@ if $RUN_MYPY; then
mypy --config pyproject.toml flax/ --show-error-codes
fi

if $RUN_PYREFLY; then
echo "=== RUNNING PYREFLY ==="
# Type-check using pyrefly.toml (currently scoped to flax/linen/linear.py).
pyrefly check
fi

# Return error code 0 if no real failures happened.
echo "finished all tests."
Loading