Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions jax_cfd/base/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def advect_general(
raise NotImplementedError(
'Non-periodic boundary conditions are not implemented.')
target_offsets = grids.control_volume_offsets(c)
aligned_v = tuple(u_interpolation_fn(u, target_offset, v, dt)
aligned_v = tuple(u_interpolation_fn(u, target_offset, v, dt) # pyrefly: ignore[bad-argument-type]
for u, target_offset in zip(v, target_offsets))
aligned_c = tuple(c_interpolation_fn(c, target_offset, aligned_v, dt)
aligned_c = tuple(c_interpolation_fn(c, target_offset, aligned_v, dt) # pyrefly: ignore[bad-argument-type]
for target_offset in target_offsets)
return _advect_aligned(aligned_c, aligned_v)

Expand Down Expand Up @@ -147,7 +147,7 @@ def _align_velocities(v: GridVariableVector) -> Tuple[GridVariableVector]:
tuple(interpolation.linear(v[i], offsets[i][j])
for j in range(grid.ndim))
for i in range(grid.ndim))
return aligned_v
return aligned_v # pyrefly: ignore[bad-return]


def _velocities_to_flux(
Expand Down Expand Up @@ -178,7 +178,7 @@ def _velocities_to_flux(
aligned_v[j][i].array),)
else:
flux[i] += (flux[j][i],)
return tuple(flux)
return tuple(flux) # pyrefly: ignore[bad-return]


def convect_linear(v: GridVariableVector) -> GridArrayVector:
Expand Down
10 changes: 5 additions & 5 deletions jax_cfd/base/array_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,10 @@ def laplacian_matrix_w_boundaries(
for i, side in enumerate(['lower', 'upper']): # lower and upper boundary
if bc.types[axis][i] == boundaries.BCType.NEUMANN:
_laplacian_boundary_neumann_cell_centered(
laplacians, grid, axis, side)
laplacians, grid, axis, side) # pyrefly: ignore[bad-argument-type]
elif bc.types[axis][i] == boundaries.BCType.DIRICHLET:
_laplacian_boundary_dirichlet_cell_centered(
laplacians, grid, axis, side)
laplacians, grid, axis, side) # pyrefly: ignore[bad-argument-type]
if np.isclose(offset[axis] % 1, 0.):
if bc.types[axis][0] == boundaries.BCType.DIRICHLET and bc.types[
axis][1] == boundaries.BCType.DIRICHLET:
Expand All @@ -287,7 +287,7 @@ def laplacian_matrix_w_boundaries(
elif boundaries.BCType.NEUMANN in bc.types[axis]:
raise NotImplementedError(
'edge-aligned Neumann boundaries are not implemented.')
return laplacians
return laplacians # pyrefly: ignore[bad-return]


def unstack(array, axis):
Expand Down Expand Up @@ -348,7 +348,7 @@ def interp1d( # pytype: disable=annotation-type-mismatch # jnp-type
x: Array,
y: Array,
axis: int = -1,
fill_value: Union[str, Array] = jnp.nan,
fill_value: Union[str, Array] = jnp.nan, # pyrefly: ignore[bad-function-definition]
assume_sorted: bool = True,
) -> Callable[[Array], jax.Array]:
"""Build an interpolation function to approximate `y = f(x)`.
Expand Down Expand Up @@ -464,4 +464,4 @@ def expand(array):
return jnp.reshape(
y_new, y_new.shape[:axis] + x_new_shape + y_new.shape[axis + 1:])

return interp_func
return interp_func # pyrefly: ignore[bad-return]
8 changes: 4 additions & 4 deletions jax_cfd/base/boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def make_padding(width):
expanded_data = jnp.pad(
data, bc_padding, mode='constant', constant_values=(0, 0))
padding_values = list(self.bc_values)
padding_values[axis] = [pad / 2 for pad in padding_values[axis]]
padding_values[axis] = [pad / 2 for pad in padding_values[axis]] # pyrefly: ignore[unsupported-operation]
data = 2 * jnp.pad(
data,
full_padding,
Expand Down Expand Up @@ -455,9 +455,9 @@ def values( # pytype: disable=signature-mismatch # overriding-parameter-count-
return (None, None)
bc_values = tuple(
jnp.full(grid.shape[:axis] +
grid.shape[axis + 1:], self.bc_values[axis][-i])
grid.shape[axis + 1:], self.bc_values[axis][-i]) # pyrefly: ignore[bad-argument-type]
for i in [0, 1])
return bc_values
return bc_values # pyrefly: ignore[bad-return]

def trim_boundary(self, u: grids.GridArray) -> grids.GridArray:
"""Returns GridArray without the grid points on the boundary.
Expand Down Expand Up @@ -805,7 +805,7 @@ def get_advection_flux_bc_from_velocity_and_scalar(
raise NotImplementedError(
'Flux boundary condition is not implemented for scalar' +
f' with {c.bc}')
if not np.isclose(c.bc.bc_values[axis][i], 0.0):
if not np.isclose(c.bc.bc_values[axis][i], 0.0): # pyrefly: ignore[no-matching-overload]
raise NotImplementedError(
'Flux boundary condition is not implemented for scalar' +
f' with {c.bc}')
Expand Down
12 changes: 6 additions & 6 deletions jax_cfd/base/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def _rhs_transform(
for i, _ in enumerate(['lower', 'upper']): # lower and upper boundary
if bc.types[axis][i] == boundaries.BCType.DIRICHLET:
bc_values = [0., 0.]
bc_values[i] = bc.bc_values[axis][i]
u_data = _subtract_linear_part_dirichlet(u_data, u.grid, axis, u.offset,
bc_values)
bc_values[i] = bc.bc_values[axis][i] # pyrefly: ignore[unsupported-operation]
u_data = _subtract_linear_part_dirichlet(u_data, u.grid, axis, u.offset, # pyrefly: ignore[bad-argument-type]
bc_values) # pyrefly: ignore[bad-argument-type]
elif bc.types[axis][i] == boundaries.BCType.NEUMANN:
if any(bc.bc_values[axis]):
raise NotImplementedError(
Expand All @@ -149,8 +149,8 @@ def solve_component(u: GridVariable) -> GridArray:

def linear_op(u_new: GridArray) -> GridArray:
"""Linear operator for (1 - ν Δt ∇²) u_{t+1}."""
u_new = grids.GridVariable(u_new, u.bc) # get boundary condition from u
return u_new.array - dt * nu * fd.laplacian(u_new)
u_new = grids.GridVariable(u_new, u.bc) # get boundary condition from u # pyrefly: ignore[bad-assignment]
return u_new.array - dt * nu * fd.laplacian(u_new) # pyrefly: ignore[bad-argument-type, missing-attribute]

def cg(b: GridArray, x0: GridArray) -> GridArray:
"""Iteratively solves Lx = b. with initial guess x0."""
Expand Down Expand Up @@ -197,7 +197,7 @@ def func(x):
u.grid, u.offset, u.bc)
op = fast_diagonalization.transform(
func,
laplacians,
laplacians, # pyrefly: ignore[bad-argument-type]
v[0].dtype,
hermitian=True,
circulant=circulant,
Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/base/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def dynamic_time_step(v: GridVariableVector,
"""Pick a dynamic time-step for Navier-Stokes based on stable advection."""
v_max = jnp.sqrt(jnp.max(sum(u.data ** 2 for u in v)))
return stable_time_step( # pytype: disable=wrong-arg-types # jax-types
v_max, max_courant_number, viscosity, grid, implicit_diffusion)
v_max, max_courant_number, viscosity, grid, implicit_diffusion) # pyrefly: ignore[bad-argument-type]


def _wrap_term_as_vector(fun, *, name):
Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/base/fast_diagonalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def pseudoinverse(
A function that computes the pseudo-inverse of the indicated operator.
"""
if cutoff is None:
cutoff = 10 * jnp.finfo(dtype).eps
cutoff = 10 * jnp.finfo(dtype).eps # pyrefly: ignore[bad-assignment]

def func(v):
with np.errstate(divide='ignore', invalid='ignore'):
Expand Down
6 changes: 3 additions & 3 deletions jax_cfd/base/finite_differences.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def stencil_sum(*arrays: GridArray) -> GridArray:
# Actually passed: (iterable: Generator[Union[jax.interpreters.xla.DeviceArray, numpy.ndarray], Any, None])
result = sum(array.data for array in arrays) # type: ignore
grid = grids.consistent_grid(*arrays)
return grids.GridArray(result, offset, grid)
return grids.GridArray(result, offset, grid) # pyrefly: ignore[bad-argument-type]


# incompatible with typing.overload
Expand Down Expand Up @@ -140,7 +140,7 @@ def divergence(v: Sequence[GridVariable]) -> GridArray:
raise ValueError('The length of `v` must be equal to `grid.ndim`.'
f'Expected length {grid.ndim}; got {len(v)}.')
differences = [backward_difference(u, axis) for axis, u in enumerate(v)]
return sum(differences)
return sum(differences) # pyrefly: ignore[bad-return]


def centered_divergence(v: Sequence[GridVariable]) -> GridArray:
Expand All @@ -150,7 +150,7 @@ def centered_divergence(v: Sequence[GridVariable]) -> GridArray:
raise ValueError('The length of `v` must be equal to `grid.ndim`.'
f'Expected length {grid.ndim}; got {len(v)}.')
differences = [central_difference(u, axis) for axis, u in enumerate(v)]
return sum(differences)
return sum(differences) # pyrefly: ignore[bad-return]


@typing.overload
Expand Down
6 changes: 3 additions & 3 deletions jax_cfd/base/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def taylor_green_forcing(
) -> ForcingFn:
"""Constant driving forced in the form of Taylor-Green vorcities."""
u, v = validation_problems.TaylorGreen(
shape=grid.shape[:2], kx=k, ky=k).velocity()
shape=grid.shape[:2], kx=k, ky=k).velocity() # pyrefly: ignore[bad-argument-type]
# Put force on same offset, grid as velocity components
if grid.ndim == 2:
u = grids.GridArray(u.data * scale, u.offset, grid)
Expand Down Expand Up @@ -72,8 +72,8 @@ def kolmogorov_forcing(
offsets = grid.cell_faces

if swap_xy:
x = grid.mesh(offsets[1])[0]
v = scale * grids.GridArray(jnp.sin(k * x), offsets[1], grid)
x = grid.mesh(offsets[1])[0] # pyrefly: ignore[bad-index]
v = scale * grids.GridArray(jnp.sin(k * x), offsets[1], grid) # pyrefly: ignore[bad-index]

if grid.ndim == 2:
u = grids.GridArray(jnp.zeros_like(v.data), (1, 1/2), grid)
Expand Down
9 changes: 5 additions & 4 deletions jax_cfd/base/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,11 +542,12 @@ def __init__(
step = 1
if isinstance(step, numbers.Number):
step = (step,) * self.ndim
elif len(step) != self.ndim:
elif len(step) != self.ndim: # pyrefly: ignore[bad-argument-type]
# pyrefly: ignore[bad-argument-type]
raise ValueError('length of step does not match ndim: '
f'{len(step)} != {self.ndim}')
domain = tuple(
(0.0, float(step_ * size)) for step_, size in zip(step, shape))
(0.0, float(step_ * size)) for step_, size in zip(step, shape)) # pyrefly: ignore[bad-argument-type]

object.__setattr__(self, 'domain', domain)

Expand All @@ -569,7 +570,7 @@ def cell_faces(self) -> Tuple[Tuple[float, ...]]:
"""Returns the offsets at each of the 'forward' cell faces."""
d = self.ndim
offsets = (np.eye(d) + np.ones([d, d])) / 2.
return tuple(tuple(float(o) for o in offset) for offset in offsets)
return tuple(tuple(float(o) for o in offset) for offset in offsets) # pyrefly: ignore[bad-return]

def stagger(self, v: Tuple[Array, ...]) -> Tuple[GridArray, ...]:
"""Places the velocity components of `v` on the `Grid`'s cell faces."""
Expand Down Expand Up @@ -672,7 +673,7 @@ def eval_on_mesh(self,
"""
if offset is None:
offset = self.cell_center
return GridArray(fn(*self.mesh(offset)), offset, self)
return GridArray(fn(*self.mesh(offset)), offset, self) # pyrefly: ignore[bad-argument-type]


def domain_interior_masks(grid: Grid):
Expand Down
6 changes: 3 additions & 3 deletions jax_cfd/base/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,12 @@ def lax_wendroff(
f'for Lax-Wendroff interpolation `c.offset` and `offset` must differ at'
f' most in one entry, but got: {c.offset} and {offset}.')
axis, = interpolation_axes
u = v[axis]
u = v[axis] # pyrefly: ignore[unsupported-operation]
offset_delta = u.offset[axis] - c.offset[axis]
floor = int(np.floor(offset_delta)) # used for positive velocity
ceil = int(np.ceil(offset_delta)) # used for negative velocity
grid = grids.consistent_grid(c, u)
courant_numbers = (dt / grid.step[axis]) * u.data
courant_numbers = (dt / grid.step[axis]) * u.data # pyrefly: ignore[unsupported-operation]
positive_u_case = (
c.shift(floor, axis).data + 0.5 * (1 - courant_numbers) *
(c.shift(ceil, axis).data - c.shift(floor, axis).data))
Expand Down Expand Up @@ -356,4 +356,4 @@ def point_interpolation(
(domain_upper - domain_lower))

return jax.scipy.ndimage.map_coordinates(
c.data, coordinates=index, order=order, mode=mode, cval=cval)
c.data, coordinates=index, order=order, mode=mode, cval=cval) # pyrefly: ignore[bad-argument-type]
2 changes: 1 addition & 1 deletion jax_cfd/base/pressure.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def solve_fast_diag(
rhs.grid, rhs.offset, pressure_bc)
rhs_transformed = _rhs_transform(rhs, pressure_bc)
pinv = fast_diagonalization.pseudoinverse(
laplacians,
laplacians, # pyrefly: ignore[bad-argument-type]
rhs_transformed.dtype,
hermitian=True,
circulant=circulant,
Expand Down
22 changes: 11 additions & 11 deletions jax_cfd/base/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,19 @@ def top_hat_downsample(
dx / dx_source
for dx, dx_source in zip(destination_grid.step, source_grid.step))
if filter_size is None:
filter_size = factor
filter_size = factor # pyrefly: ignore[bad-assignment]
if isinstance(filter_size, int):
filter_size = tuple(filter_size for _ in range(source_grid.ndim))
assert destination_grid.domain == source_grid.domain
assert all([round(f) == f for f in factor])
assert all([round(f) == f for f in filter_size]) # this can be relaxed
assert all([round(f) == f for f in filter_size]) # this can be relaxed # pyrefly: ignore[not-iterable]
acceptable_filter = lambda f: f % 2 == 0 or f == 1
assert all(map(acceptable_filter,
filter_size)) # only even filters are implemented
filter_size)) # only even filters are implemented # pyrefly: ignore[bad-argument-type]
assert all(list(map(acceptable_filter,
factor))) # only even factors are implemented
# filter has to be at least as large as the factor.
assert all(filt >= f for f, filt in zip(factor, filter_size))
assert all(filt >= f for f, filt in zip(factor, filter_size)) # pyrefly: ignore[bad-argument-type]
result = []
for c in variables:
if c.grid != source_grid:
Expand All @@ -132,31 +132,31 @@ def top_hat_downsample(
bc = c.bc
offset = c.offset
center_offset = tuple(
0.5 if f > 1 else o for o, f in zip(offset, filter_size))
0.5 if f > 1 else o for o, f in zip(offset, filter_size)) # pyrefly: ignore[bad-argument-type]
c_centered = interpolation.linear(c, center_offset).array
center_offset = np.array(center_offset)
grid_shape = np.array(source_grid.shape)
for axis in range(c.grid.ndim):
c_centered = bc.pad(
c_centered,
round(filter_size[axis]) // 2,
round(filter_size[axis]) // 2, # pyrefly: ignore[unsupported-operation]
axis=axis,
mode=boundaries.Padding.MIRROR)
c_centered = bc.pad(
c_centered,
-(round(filter_size[axis]) // 2),
-(round(filter_size[axis]) // 2), # pyrefly: ignore[unsupported-operation]
axis=axis,
mode=boundaries.Padding.MIRROR)
convolution_filter = jnp.ones(round(
filter_size[axis])) / filter_size[axis]
filter_size[axis])) / filter_size[axis] # pyrefly: ignore[unsupported-operation]
convolve_1d = lambda arr, convolution_filter=convolution_filter: jnp.convolve( # pylint: disable=g-long-lambda
arr, convolution_filter, 'valid')
axes = list(range(source_grid.ndim))
axes.remove(axis)
for ax in axes:
convolve_1d = jax.vmap(convolve_1d, in_axes=ax, out_axes=ax)
c_centered = convolve_1d(c_centered.data)
if filter_size[axis] > 1:
if filter_size[axis] > 1: # pyrefly: ignore[unsupported-operation]
if np.isclose(offset[axis], 0):
start = 0
end = c_centered.shape[axis] - 1
Expand Down Expand Up @@ -217,7 +217,7 @@ def downsample(u: GridArray, direction: int, factor: int) -> GridArray:
return GridArray(array, offset=u.offset, grid=destination_grid)
else:
downsample = downsample_staggered_velocity_component
result.append(downsample(u, j, round(factor)))
result.append(downsample(u, j, round(factor))) # pyrefly: ignore[bad-argument-type]
return tuple(result)


Expand All @@ -228,7 +228,7 @@ def downsample_spectral(_: grids.Grid, destination_grid: grids.Grid,
kx, ky = destination_grid.rfft_axes()
(num_x,), (num_y,) = kx.shape, ky.shape

input_num_x, _ = signal_hat.shape
input_num_x, _ = signal_hat.shape # pyrefly: ignore[bad-assignment]

downed = jnp.concatenate(
[signal_hat[:num_x // 2, :num_y], signal_hat[-num_x // 2:, :num_y]])
Expand Down
4 changes: 2 additions & 2 deletions jax_cfd/base/time_stepping.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def step_fn(u0):
k[0] = F(u0)

for i in range(1, num_steps):
u_star = u0 + dt * sum(a[i-1][j] * k[j] for j in range(i) if a[i-1][j])
u_star = u0 + dt * sum(a[i-1][j] * k[j] for j in range(i) if a[i-1][j]) # pyrefly: ignore[unsupported-operation]
u[i] = P(u_star)
k[i] = F(u[i])

u_star = u0 + dt * sum(b[j] * k[j] for j in range(num_steps) if b[j])
u_star = u0 + dt * sum(b[j] * k[j] for j in range(num_steps) if b[j]) # pyrefly: ignore[unsupported-operation]
u_final = P(u_star)

return u_final
Expand Down
4 changes: 2 additions & 2 deletions jax_cfd/base/validation_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ def velocity(
u = grids.GridVariable(
array=grids.GridArray(
data=scale * jnp.cos(self._kx * ux) * jnp.sin(self._ky * uy),
offset=offsets[0],
offset=offsets[0], # pyrefly: ignore[bad-argument-type]
grid=self.grid),
bc=boundaries.periodic_boundary_conditions(self.grid.ndim))

vx, vy = self.grid.mesh(offsets[1])
v = grids.GridVariable(
array=grids.GridArray(
data=-scale * jnp.sin(self._kx * vx) * jnp.cos(self._ky * vy),
offset=offsets[1],
offset=offsets[1], # pyrefly: ignore[bad-argument-type]
grid=self.grid),
bc=boundaries.periodic_boundary_conditions(self.grid.ndim))

Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/collocated/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _velocities_to_flux(v: GridVariableVector) -> Tuple[GridVariableVector]:
flux[i] += (bc.impose_bc(v[i].array * v[j].array),)
else:
flux[i] += (flux[j][i],)
return tuple(flux)
return tuple(flux) # pyrefly: ignore[bad-return]


def convect_linear(v: GridVariableVector) -> GridArrayVector:
Expand Down
2 changes: 1 addition & 1 deletion jax_cfd/ml/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def modular_navier_stokes_model(
Returns:
A function that performs `steps` steps of the Navier-Stokes time dynamics.
"""
active_forcing_fn = physics_specs.forcing_module(grid)
active_forcing_fn = physics_specs.forcing_module(grid) # pyrefly: ignore[not-callable]

def navier_stokes_step_fn(state):
"""Advances Navier-Stokes state forward in time."""
Expand Down
Loading
Loading