diff --git a/jax_cfd/base/advection.py b/jax_cfd/base/advection.py index f4afd5b..60055c5 100644 --- a/jax_cfd/base/advection.py +++ b/jax_cfd/base/advection.py @@ -109,9 +109,9 @@ def advect_general( raise NotImplementedError( 'Non-periodic boundary conditions are not implemented.') target_offsets = grids.control_volume_offsets(c) - aligned_v = tuple(u_interpolation_fn(u, target_offset, v, dt) + aligned_v = tuple(u_interpolation_fn(u, target_offset, v, dt) # pyrefly: ignore[bad-argument-type] for u, target_offset in zip(v, target_offsets)) - aligned_c = tuple(c_interpolation_fn(c, target_offset, aligned_v, dt) + aligned_c = tuple(c_interpolation_fn(c, target_offset, aligned_v, dt) # pyrefly: ignore[bad-argument-type] for target_offset in target_offsets) return _advect_aligned(aligned_c, aligned_v) @@ -147,7 +147,7 @@ def _align_velocities(v: GridVariableVector) -> Tuple[GridVariableVector]: tuple(interpolation.linear(v[i], offsets[i][j]) for j in range(grid.ndim)) for i in range(grid.ndim)) - return aligned_v + return aligned_v # pyrefly: ignore[bad-return] def _velocities_to_flux( @@ -178,7 +178,7 @@ def _velocities_to_flux( aligned_v[j][i].array),) else: flux[i] += (flux[j][i],) - return tuple(flux) + return tuple(flux) # pyrefly: ignore[bad-return] def convect_linear(v: GridVariableVector) -> GridArrayVector: diff --git a/jax_cfd/base/array_utils.py b/jax_cfd/base/array_utils.py index d40d153..92227d9 100644 --- a/jax_cfd/base/array_utils.py +++ b/jax_cfd/base/array_utils.py @@ -272,10 +272,10 @@ def laplacian_matrix_w_boundaries( for i, side in enumerate(['lower', 'upper']): # lower and upper boundary if bc.types[axis][i] == boundaries.BCType.NEUMANN: _laplacian_boundary_neumann_cell_centered( - laplacians, grid, axis, side) + laplacians, grid, axis, side) # pyrefly: ignore[bad-argument-type] elif bc.types[axis][i] == boundaries.BCType.DIRICHLET: _laplacian_boundary_dirichlet_cell_centered( - laplacians, grid, axis, side) + laplacians, grid, axis, side) # pyrefly: ignore[bad-argument-type] if np.isclose(offset[axis] % 1, 0.): if bc.types[axis][0] == boundaries.BCType.DIRICHLET and bc.types[ axis][1] == boundaries.BCType.DIRICHLET: @@ -287,7 +287,7 @@ def laplacian_matrix_w_boundaries( elif boundaries.BCType.NEUMANN in bc.types[axis]: raise NotImplementedError( 'edge-aligned Neumann boundaries are not implemented.') - return laplacians + return laplacians # pyrefly: ignore[bad-return] def unstack(array, axis): @@ -348,7 +348,7 @@ def interp1d( # pytype: disable=annotation-type-mismatch # jnp-type x: Array, y: Array, axis: int = -1, - fill_value: Union[str, Array] = jnp.nan, + fill_value: Union[str, Array] = jnp.nan, # pyrefly: ignore[bad-function-definition] assume_sorted: bool = True, ) -> Callable[[Array], jax.Array]: """Build an interpolation function to approximate `y = f(x)`. @@ -464,4 +464,4 @@ def expand(array): return jnp.reshape( y_new, y_new.shape[:axis] + x_new_shape + y_new.shape[axis + 1:]) - return interp_func + return interp_func # pyrefly: ignore[bad-return] diff --git a/jax_cfd/base/boundaries.py b/jax_cfd/base/boundaries.py index de47b34..a6d8cb6 100644 --- a/jax_cfd/base/boundaries.py +++ b/jax_cfd/base/boundaries.py @@ -254,7 +254,7 @@ def make_padding(width): expanded_data = jnp.pad( data, bc_padding, mode='constant', constant_values=(0, 0)) padding_values = list(self.bc_values) - padding_values[axis] = [pad / 2 for pad in padding_values[axis]] + padding_values[axis] = [pad / 2 for pad in padding_values[axis]] # pyrefly: ignore[unsupported-operation] data = 2 * jnp.pad( data, full_padding, @@ -455,9 +455,9 @@ def values( # pytype: disable=signature-mismatch # overriding-parameter-count- return (None, None) bc_values = tuple( jnp.full(grid.shape[:axis] + - grid.shape[axis + 1:], self.bc_values[axis][-i]) + grid.shape[axis + 1:], self.bc_values[axis][-i]) # pyrefly: ignore[bad-argument-type] for i in [0, 1]) - return bc_values + return bc_values # pyrefly: ignore[bad-return] def trim_boundary(self, u: grids.GridArray) -> grids.GridArray: """Returns GridArray without the grid points on the boundary. @@ -805,7 +805,7 @@ def get_advection_flux_bc_from_velocity_and_scalar( raise NotImplementedError( 'Flux boundary condition is not implemented for scalar' + f' with {c.bc}') - if not np.isclose(c.bc.bc_values[axis][i], 0.0): + if not np.isclose(c.bc.bc_values[axis][i], 0.0): # pyrefly: ignore[no-matching-overload] raise NotImplementedError( 'Flux boundary condition is not implemented for scalar' + f' with {c.bc}') diff --git a/jax_cfd/base/diffusion.py b/jax_cfd/base/diffusion.py index eb28919..533db1a 100644 --- a/jax_cfd/base/diffusion.py +++ b/jax_cfd/base/diffusion.py @@ -124,9 +124,9 @@ def _rhs_transform( for i, _ in enumerate(['lower', 'upper']): # lower and upper boundary if bc.types[axis][i] == boundaries.BCType.DIRICHLET: bc_values = [0., 0.] - bc_values[i] = bc.bc_values[axis][i] - u_data = _subtract_linear_part_dirichlet(u_data, u.grid, axis, u.offset, - bc_values) + bc_values[i] = bc.bc_values[axis][i] # pyrefly: ignore[unsupported-operation] + u_data = _subtract_linear_part_dirichlet(u_data, u.grid, axis, u.offset, # pyrefly: ignore[bad-argument-type] + bc_values) # pyrefly: ignore[bad-argument-type] elif bc.types[axis][i] == boundaries.BCType.NEUMANN: if any(bc.bc_values[axis]): raise NotImplementedError( @@ -149,8 +149,8 @@ def solve_component(u: GridVariable) -> GridArray: def linear_op(u_new: GridArray) -> GridArray: """Linear operator for (1 - ν Δt ∇²) u_{t+1}.""" - u_new = grids.GridVariable(u_new, u.bc) # get boundary condition from u - return u_new.array - dt * nu * fd.laplacian(u_new) + u_new = grids.GridVariable(u_new, u.bc) # get boundary condition from u # pyrefly: ignore[bad-assignment] + return u_new.array - dt * nu * fd.laplacian(u_new) # pyrefly: ignore[bad-argument-type, missing-attribute] def cg(b: GridArray, x0: GridArray) -> GridArray: """Iteratively solves Lx = b. with initial guess x0.""" @@ -197,7 +197,7 @@ def func(x): u.grid, u.offset, u.bc) op = fast_diagonalization.transform( func, - laplacians, + laplacians, # pyrefly: ignore[bad-argument-type] v[0].dtype, hermitian=True, circulant=circulant, diff --git a/jax_cfd/base/equations.py b/jax_cfd/base/equations.py index af0954e..52e441b 100644 --- a/jax_cfd/base/equations.py +++ b/jax_cfd/base/equations.py @@ -67,7 +67,7 @@ def dynamic_time_step(v: GridVariableVector, """Pick a dynamic time-step for Navier-Stokes based on stable advection.""" v_max = jnp.sqrt(jnp.max(sum(u.data ** 2 for u in v))) return stable_time_step( # pytype: disable=wrong-arg-types # jax-types - v_max, max_courant_number, viscosity, grid, implicit_diffusion) + v_max, max_courant_number, viscosity, grid, implicit_diffusion) # pyrefly: ignore[bad-argument-type] def _wrap_term_as_vector(fun, *, name): diff --git a/jax_cfd/base/fast_diagonalization.py b/jax_cfd/base/fast_diagonalization.py index af00a31..c834486 100644 --- a/jax_cfd/base/fast_diagonalization.py +++ b/jax_cfd/base/fast_diagonalization.py @@ -255,7 +255,7 @@ def pseudoinverse( A function that computes the pseudo-inverse of the indicated operator. """ if cutoff is None: - cutoff = 10 * jnp.finfo(dtype).eps + cutoff = 10 * jnp.finfo(dtype).eps # pyrefly: ignore[bad-assignment] def func(v): with np.errstate(divide='ignore', invalid='ignore'): diff --git a/jax_cfd/base/finite_differences.py b/jax_cfd/base/finite_differences.py index e84077a..ac98974 100644 --- a/jax_cfd/base/finite_differences.py +++ b/jax_cfd/base/finite_differences.py @@ -51,7 +51,7 @@ def stencil_sum(*arrays: GridArray) -> GridArray: # Actually passed: (iterable: Generator[Union[jax.interpreters.xla.DeviceArray, numpy.ndarray], Any, None]) result = sum(array.data for array in arrays) # type: ignore grid = grids.consistent_grid(*arrays) - return grids.GridArray(result, offset, grid) + return grids.GridArray(result, offset, grid) # pyrefly: ignore[bad-argument-type] # incompatible with typing.overload @@ -140,7 +140,7 @@ def divergence(v: Sequence[GridVariable]) -> GridArray: raise ValueError('The length of `v` must be equal to `grid.ndim`.' f'Expected length {grid.ndim}; got {len(v)}.') differences = [backward_difference(u, axis) for axis, u in enumerate(v)] - return sum(differences) + return sum(differences) # pyrefly: ignore[bad-return] def centered_divergence(v: Sequence[GridVariable]) -> GridArray: @@ -150,7 +150,7 @@ def centered_divergence(v: Sequence[GridVariable]) -> GridArray: raise ValueError('The length of `v` must be equal to `grid.ndim`.' f'Expected length {grid.ndim}; got {len(v)}.') differences = [central_difference(u, axis) for axis, u in enumerate(v)] - return sum(differences) + return sum(differences) # pyrefly: ignore[bad-return] @typing.overload diff --git a/jax_cfd/base/forcings.py b/jax_cfd/base/forcings.py index 7990a83..00f73ff 100644 --- a/jax_cfd/base/forcings.py +++ b/jax_cfd/base/forcings.py @@ -37,7 +37,7 @@ def taylor_green_forcing( ) -> ForcingFn: """Constant driving forced in the form of Taylor-Green vorcities.""" u, v = validation_problems.TaylorGreen( - shape=grid.shape[:2], kx=k, ky=k).velocity() + shape=grid.shape[:2], kx=k, ky=k).velocity() # pyrefly: ignore[bad-argument-type] # Put force on same offset, grid as velocity components if grid.ndim == 2: u = grids.GridArray(u.data * scale, u.offset, grid) @@ -72,8 +72,8 @@ def kolmogorov_forcing( offsets = grid.cell_faces if swap_xy: - x = grid.mesh(offsets[1])[0] - v = scale * grids.GridArray(jnp.sin(k * x), offsets[1], grid) + x = grid.mesh(offsets[1])[0] # pyrefly: ignore[bad-index] + v = scale * grids.GridArray(jnp.sin(k * x), offsets[1], grid) # pyrefly: ignore[bad-index] if grid.ndim == 2: u = grids.GridArray(jnp.zeros_like(v.data), (1, 1/2), grid) diff --git a/jax_cfd/base/grids.py b/jax_cfd/base/grids.py index 3c87f38..2592b3f 100644 --- a/jax_cfd/base/grids.py +++ b/jax_cfd/base/grids.py @@ -542,11 +542,12 @@ def __init__( step = 1 if isinstance(step, numbers.Number): step = (step,) * self.ndim - elif len(step) != self.ndim: + elif len(step) != self.ndim: # pyrefly: ignore[bad-argument-type] + # pyrefly: ignore[bad-argument-type] raise ValueError('length of step does not match ndim: ' f'{len(step)} != {self.ndim}') domain = tuple( - (0.0, float(step_ * size)) for step_, size in zip(step, shape)) + (0.0, float(step_ * size)) for step_, size in zip(step, shape)) # pyrefly: ignore[bad-argument-type] object.__setattr__(self, 'domain', domain) @@ -569,7 +570,7 @@ def cell_faces(self) -> Tuple[Tuple[float, ...]]: """Returns the offsets at each of the 'forward' cell faces.""" d = self.ndim offsets = (np.eye(d) + np.ones([d, d])) / 2. - return tuple(tuple(float(o) for o in offset) for offset in offsets) + return tuple(tuple(float(o) for o in offset) for offset in offsets) # pyrefly: ignore[bad-return] def stagger(self, v: Tuple[Array, ...]) -> Tuple[GridArray, ...]: """Places the velocity components of `v` on the `Grid`'s cell faces.""" @@ -672,7 +673,7 @@ def eval_on_mesh(self, """ if offset is None: offset = self.cell_center - return GridArray(fn(*self.mesh(offset)), offset, self) + return GridArray(fn(*self.mesh(offset)), offset, self) # pyrefly: ignore[bad-argument-type] def domain_interior_masks(grid: Grid): diff --git a/jax_cfd/base/interpolation.py b/jax_cfd/base/interpolation.py index 8ef5788..80bb1e0 100644 --- a/jax_cfd/base/interpolation.py +++ b/jax_cfd/base/interpolation.py @@ -202,12 +202,12 @@ def lax_wendroff( f'for Lax-Wendroff interpolation `c.offset` and `offset` must differ at' f' most in one entry, but got: {c.offset} and {offset}.') axis, = interpolation_axes - u = v[axis] + u = v[axis] # pyrefly: ignore[unsupported-operation] offset_delta = u.offset[axis] - c.offset[axis] floor = int(np.floor(offset_delta)) # used for positive velocity ceil = int(np.ceil(offset_delta)) # used for negative velocity grid = grids.consistent_grid(c, u) - courant_numbers = (dt / grid.step[axis]) * u.data + courant_numbers = (dt / grid.step[axis]) * u.data # pyrefly: ignore[unsupported-operation] positive_u_case = ( c.shift(floor, axis).data + 0.5 * (1 - courant_numbers) * (c.shift(ceil, axis).data - c.shift(floor, axis).data)) @@ -356,4 +356,4 @@ def point_interpolation( (domain_upper - domain_lower)) return jax.scipy.ndimage.map_coordinates( - c.data, coordinates=index, order=order, mode=mode, cval=cval) + c.data, coordinates=index, order=order, mode=mode, cval=cval) # pyrefly: ignore[bad-argument-type] diff --git a/jax_cfd/base/pressure.py b/jax_cfd/base/pressure.py index dca8160..20d540a 100644 --- a/jax_cfd/base/pressure.py +++ b/jax_cfd/base/pressure.py @@ -149,7 +149,7 @@ def solve_fast_diag( rhs.grid, rhs.offset, pressure_bc) rhs_transformed = _rhs_transform(rhs, pressure_bc) pinv = fast_diagonalization.pseudoinverse( - laplacians, + laplacians, # pyrefly: ignore[bad-argument-type] rhs_transformed.dtype, hermitian=True, circulant=circulant, diff --git a/jax_cfd/base/resize.py b/jax_cfd/base/resize.py index ceb3c3c..85461a4 100644 --- a/jax_cfd/base/resize.py +++ b/jax_cfd/base/resize.py @@ -110,19 +110,19 @@ def top_hat_downsample( dx / dx_source for dx, dx_source in zip(destination_grid.step, source_grid.step)) if filter_size is None: - filter_size = factor + filter_size = factor # pyrefly: ignore[bad-assignment] if isinstance(filter_size, int): filter_size = tuple(filter_size for _ in range(source_grid.ndim)) assert destination_grid.domain == source_grid.domain assert all([round(f) == f for f in factor]) - assert all([round(f) == f for f in filter_size]) # this can be relaxed + assert all([round(f) == f for f in filter_size]) # this can be relaxed # pyrefly: ignore[not-iterable] acceptable_filter = lambda f: f % 2 == 0 or f == 1 assert all(map(acceptable_filter, - filter_size)) # only even filters are implemented + filter_size)) # only even filters are implemented # pyrefly: ignore[bad-argument-type] assert all(list(map(acceptable_filter, factor))) # only even factors are implemented # filter has to be at least as large as the factor. - assert all(filt >= f for f, filt in zip(factor, filter_size)) + assert all(filt >= f for f, filt in zip(factor, filter_size)) # pyrefly: ignore[bad-argument-type] result = [] for c in variables: if c.grid != source_grid: @@ -132,23 +132,23 @@ def top_hat_downsample( bc = c.bc offset = c.offset center_offset = tuple( - 0.5 if f > 1 else o for o, f in zip(offset, filter_size)) + 0.5 if f > 1 else o for o, f in zip(offset, filter_size)) # pyrefly: ignore[bad-argument-type] c_centered = interpolation.linear(c, center_offset).array center_offset = np.array(center_offset) grid_shape = np.array(source_grid.shape) for axis in range(c.grid.ndim): c_centered = bc.pad( c_centered, - round(filter_size[axis]) // 2, + round(filter_size[axis]) // 2, # pyrefly: ignore[unsupported-operation] axis=axis, mode=boundaries.Padding.MIRROR) c_centered = bc.pad( c_centered, - -(round(filter_size[axis]) // 2), + -(round(filter_size[axis]) // 2), # pyrefly: ignore[unsupported-operation] axis=axis, mode=boundaries.Padding.MIRROR) convolution_filter = jnp.ones(round( - filter_size[axis])) / filter_size[axis] + filter_size[axis])) / filter_size[axis] # pyrefly: ignore[unsupported-operation] convolve_1d = lambda arr, convolution_filter=convolution_filter: jnp.convolve( # pylint: disable=g-long-lambda arr, convolution_filter, 'valid') axes = list(range(source_grid.ndim)) @@ -156,7 +156,7 @@ def top_hat_downsample( for ax in axes: convolve_1d = jax.vmap(convolve_1d, in_axes=ax, out_axes=ax) c_centered = convolve_1d(c_centered.data) - if filter_size[axis] > 1: + if filter_size[axis] > 1: # pyrefly: ignore[unsupported-operation] if np.isclose(offset[axis], 0): start = 0 end = c_centered.shape[axis] - 1 @@ -217,7 +217,7 @@ def downsample(u: GridArray, direction: int, factor: int) -> GridArray: return GridArray(array, offset=u.offset, grid=destination_grid) else: downsample = downsample_staggered_velocity_component - result.append(downsample(u, j, round(factor))) + result.append(downsample(u, j, round(factor))) # pyrefly: ignore[bad-argument-type] return tuple(result) @@ -228,7 +228,7 @@ def downsample_spectral(_: grids.Grid, destination_grid: grids.Grid, kx, ky = destination_grid.rfft_axes() (num_x,), (num_y,) = kx.shape, ky.shape - input_num_x, _ = signal_hat.shape + input_num_x, _ = signal_hat.shape # pyrefly: ignore[bad-assignment] downed = jnp.concatenate( [signal_hat[:num_x // 2, :num_y], signal_hat[-num_x // 2:, :num_y]]) diff --git a/jax_cfd/base/time_stepping.py b/jax_cfd/base/time_stepping.py index b9db210..6ccc2ce 100644 --- a/jax_cfd/base/time_stepping.py +++ b/jax_cfd/base/time_stepping.py @@ -94,11 +94,11 @@ def step_fn(u0): k[0] = F(u0) for i in range(1, num_steps): - u_star = u0 + dt * sum(a[i-1][j] * k[j] for j in range(i) if a[i-1][j]) + u_star = u0 + dt * sum(a[i-1][j] * k[j] for j in range(i) if a[i-1][j]) # pyrefly: ignore[unsupported-operation] u[i] = P(u_star) k[i] = F(u[i]) - u_star = u0 + dt * sum(b[j] * k[j] for j in range(num_steps) if b[j]) + u_star = u0 + dt * sum(b[j] * k[j] for j in range(num_steps) if b[j]) # pyrefly: ignore[unsupported-operation] u_final = P(u_star) return u_final diff --git a/jax_cfd/base/validation_problems.py b/jax_cfd/base/validation_problems.py index bb09b2c..cd5b355 100644 --- a/jax_cfd/base/validation_problems.py +++ b/jax_cfd/base/validation_problems.py @@ -91,7 +91,7 @@ def velocity( u = grids.GridVariable( array=grids.GridArray( data=scale * jnp.cos(self._kx * ux) * jnp.sin(self._ky * uy), - offset=offsets[0], + offset=offsets[0], # pyrefly: ignore[bad-argument-type] grid=self.grid), bc=boundaries.periodic_boundary_conditions(self.grid.ndim)) @@ -99,7 +99,7 @@ def velocity( v = grids.GridVariable( array=grids.GridArray( data=-scale * jnp.sin(self._kx * vx) * jnp.cos(self._ky * vy), - offset=offsets[1], + offset=offsets[1], # pyrefly: ignore[bad-argument-type] grid=self.grid), bc=boundaries.periodic_boundary_conditions(self.grid.ndim)) diff --git a/jax_cfd/collocated/advection.py b/jax_cfd/collocated/advection.py index a235032..b2baac0 100644 --- a/jax_cfd/collocated/advection.py +++ b/jax_cfd/collocated/advection.py @@ -64,7 +64,7 @@ def _velocities_to_flux(v: GridVariableVector) -> Tuple[GridVariableVector]: flux[i] += (bc.impose_bc(v[i].array * v[j].array),) else: flux[i] += (flux[j][i],) - return tuple(flux) + return tuple(flux) # pyrefly: ignore[bad-return] def convect_linear(v: GridVariableVector) -> GridArrayVector: diff --git a/jax_cfd/ml/equations.py b/jax_cfd/ml/equations.py index 159197d..a9b9ea8 100644 --- a/jax_cfd/ml/equations.py +++ b/jax_cfd/ml/equations.py @@ -117,7 +117,7 @@ def modular_navier_stokes_model( Returns: A function that performs `steps` steps of the Navier-Stokes time dynamics. """ - active_forcing_fn = physics_specs.forcing_module(grid) + active_forcing_fn = physics_specs.forcing_module(grid) # pyrefly: ignore[not-callable] def navier_stokes_step_fn(state): """Advances Navier-Stokes state forward in time.""" diff --git a/jax_cfd/ml/interpolations.py b/jax_cfd/ml/interpolations.py index f984038..cf4a4de 100644 --- a/jax_cfd/ml/interpolations.py +++ b/jax_cfd/ml/interpolations.py @@ -73,7 +73,7 @@ def __init__( for tag in tags: key = (u.offset, target_offset, tag) derivatives[key] = layers.SpatialDerivativeFromLogits( - stencil_size_fn(*key), + stencil_size_fn(*key), # pyrefly: ignore[bad-argument-type] u.offset, target_offset, derivative_orders=derivative_orders, @@ -114,7 +114,7 @@ def __call__(self, if interpolator is None: raise KeyError(f'No interpolator for key {key}. ' f'Available keys: {list(self._interpolators.keys())}') - result = jnp.squeeze(interpolator(c.data), axis=-1) + result = jnp.squeeze(interpolator(c.data), axis=-1) # pyrefly: ignore[bad-argument-type] return grids.GridVariable( grids.GridArray(result, offset, c.grid), c.bc) diff --git a/jax_cfd/ml/layers.py b/jax_cfd/ml/layers.py index 9c662d7..31c60b8 100644 --- a/jax_cfd/ml/layers.py +++ b/jax_cfd/ml/layers.py @@ -63,7 +63,7 @@ def __init__( def __call__(self, inputs): return tiling.apply_convolution( - self._conv_module, inputs, self._tile_layout, self._padding) + self._conv_module, inputs, self._tile_layout, self._padding) # pyrefly: ignore[bad-argument-type] class PeriodicConv1D(PeriodicConvGeneral): @@ -292,7 +292,7 @@ def __call__(self, inputs): output_slice.append(slice(output_start, output_end)) output_slice.append(slice(None, None)) output = tiling.apply_convolution( - self._conv_module, inputs, self._tile_layout, self._padding) + self._conv_module, inputs, self._tile_layout, self._padding) # pyrefly: ignore[bad-argument-type] sliced_output = output[tuple(output_slice)] return jnp.roll(sliced_output, self._roll_shifts, list(range(ndim))) @@ -674,7 +674,7 @@ def fuse_spatial_derivative_layers( tile_layout, = {deriv.tile_layout for deriv in derivatives.values()} if constrain_with_conv: - ndim = len(tile_layout) + ndim = len(tile_layout) # pyrefly: ignore[bad-argument-type] kernel = jnp.expand_dims( joint_nullspace.astype(np.float32), axis=tuple(range(ndim))) all_coefficients = joint_bias + layers_util.periodic_convolution( diff --git a/jax_cfd/ml/layers_util.py b/jax_cfd/ml/layers_util.py index fe741f0..7a80c96 100644 --- a/jax_cfd/ml/layers_util.py +++ b/jax_cfd/ml/layers_util.py @@ -325,7 +325,7 @@ def periodic_convolution( padding='VALID', dimension_numbers=dimension_numbers, precision=precision) - return tiling.apply_convolution(conv, x, layout=tile_layout, padding=padding) + return tiling.apply_convolution(conv, x, layout=tile_layout, padding=padding) # pyrefly: ignore[bad-argument-type] # Caching the result of _patch_kernel() ensures that only one constant value is @@ -335,11 +335,11 @@ def periodic_convolution( @functools.lru_cache() def _patch_kernel( # pytype: disable=annotation-type-mismatch # numpy-scalars patch_shape: Tuple[int, ...], - dtype: np.dtype = np.float32 + dtype: np.dtype = np.float32 # pyrefly: ignore[bad-function-definition] ) -> np.ndarray: """Returns a convolutional kernel that extracts patches.""" patch_size = np.prod(patch_shape) - kernel_2d = np.eye(patch_size, dtype=dtype) + kernel_2d = np.eye(patch_size, dtype=dtype) # pyrefly: ignore[no-matching-overload] kernel_shape = (patch_size, 1) + patch_shape kernel_nd = kernel_2d.reshape(kernel_shape) return np.moveaxis(kernel_nd, (0, 1), (-1, -2)) diff --git a/jax_cfd/ml/networks.py b/jax_cfd/ml/networks.py index f8b2a0d..2efd3ac 100644 --- a/jax_cfd/ml/networks.py +++ b/jax_cfd/ml/networks.py @@ -43,12 +43,12 @@ def split_to_aligned_field( boundary_conditions = tuple( boundaries.periodic_boundary_conditions(grid.ndim) for _ in range(grid.ndim)) - network_offsets = network_offsets or data_offsets + network_offsets = network_offsets or data_offsets # pyrefly: ignore[bad-assignment] def process(inputs): split_inputs = array_utils.split_axis(inputs, -1) output = tuple( grids.GridVariable(grids.GridArray(x, offset, grid), bc) for x, offset, - bc in zip(split_inputs, network_offsets, boundary_conditions)) + bc in zip(split_inputs, network_offsets, boundary_conditions)) # pyrefly: ignore[bad-argument-type] output = tuple( interpolation.linear(x, offset) for x, offset in zip(output, data_offsets)) @@ -70,13 +70,13 @@ def interpolate_gridvar( data_offsets = physics_specs.combo_offsets() else: data_offsets = grid.cell_faces - final_offsets = final_offsets or data_offsets + final_offsets = final_offsets or data_offsets # pyrefly: ignore[bad-assignment] def process(inputs): - inputs = process_fn(inputs) + inputs = process_fn(inputs) # pyrefly: ignore[not-callable] inputs = tuple( interpolation.linear(x, offset) - for x, offset in zip(inputs, final_offsets)) + for x, offset in zip(inputs, final_offsets)) # pyrefly: ignore[bad-argument-type] return inputs return hk.to_module(process)() @@ -136,7 +136,7 @@ def stack_aligned_field_with_neighbors( del dt, physics_specs # unused. shifts = [i for i in np.arange(-n_neighbors, n_neighbors + 1) if i != 0] shifts_and_axis = list(itertools.product(shifts, np.arange(grid.ndim))) - shifts_and_axis.append([0, 0]) + shifts_and_axis.append([0, 0]) # pyrefly: ignore[bad-argument-type] def process(inputs): inputs = tuple(jnp.expand_dims(x.data, axis=-1) for x in inputs) diff --git a/jax_cfd/ml/physics_specifications.py b/jax_cfd/ml/physics_specifications.py index 97f945f..958ba96 100644 --- a/jax_cfd/ml/physics_specifications.py +++ b/jax_cfd/ml/physics_specifications.py @@ -19,7 +19,7 @@ @gin.configurable def get_physics_specs(physics_specs_cls=gin.REQUIRED): """Returns an instance of `physics_specs_cls`, configured by gin.""" - return physics_specs_cls() + return physics_specs_cls() # pyrefly: ignore[not-callable] @gin.register diff --git a/jax_cfd/ml/tiling.py b/jax_cfd/ml/tiling.py index 1c63803..7c44141 100644 --- a/jax_cfd/ml/tiling.py +++ b/jax_cfd/ml/tiling.py @@ -165,7 +165,7 @@ def apply_convolution( # TODO(shoyer): replace this with some sensible heuristic layout = (1,) * len(padding) tiled = space_to_batch(inputs, layout) - padded = halo_exchange_pad(tiled, layout, padding) + padded = halo_exchange_pad(tiled, layout, padding) # pyrefly: ignore[bad-argument-type] convolved = conv(padded) output = batch_to_space(convolved, layout) return output diff --git a/jax_cfd/ml/towers.py b/jax_cfd/ml/towers.py index d9d1631..8b5676f 100644 --- a/jax_cfd/ml/towers.py +++ b/jax_cfd/ml/towers.py @@ -68,7 +68,7 @@ def mirror_convolution( @gin.register def fixed_scale(inputs: Array, axes: Tuple[int, ...], - rescaled_one: float = gin.REQUIRED) -> Array: + rescaled_one: float = gin.REQUIRED) -> Array: # pyrefly: ignore[bad-function-definition] """Linearly scales `inputs` such that `1` maps to `rescaled_one`.""" del axes # unused. return inputs * rescaled_one @@ -78,7 +78,7 @@ def fixed_scale(inputs: Array, def fixed_scale_gridvar( inputs: Array, axes: Tuple[int, ...], - rescaled_one: float = gin.REQUIRED + rescaled_one: float = gin.REQUIRED # pyrefly: ignore[bad-function-definition] ) ->Array: """Linearly scales `inputs` such that `1` maps to `rescaled_one`.""" del axes # unused. @@ -89,8 +89,8 @@ def fixed_scale_gridvar( def scale_to_range( inputs: Array, axes: Tuple[int, ...], - min_value: float = gin.REQUIRED, - max_value: float = gin.REQUIRED, + min_value: float = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + max_value: float = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] ) -> Array: """Dynamically scales `inputs` to be in `[min_value, max_value]` range. @@ -257,7 +257,7 @@ def forward_pass(inputs): for num_channels, kernel_shape, rate, stride in conv_args: components.append(conv_module(num_channels, kernel_shape, ndim, rate=rate, stride=stride)) - components.append(nonlinearity) + components.append(nonlinearity) # pyrefly: ignore[bad-argument-type] components.append(conv_module(num_output_channels, output_kernel_shape, ndim, rate=output_rate, stride=output_stride)) components.append(functools.partial(output_scale_fn, axes=ndim_axes)) diff --git a/jax_cfd/ml/viscosities.py b/jax_cfd/ml/viscosities.py index 06b7421..c15ed97 100644 --- a/jax_cfd/ml/viscosities.py +++ b/jax_cfd/ml/viscosities.py @@ -63,7 +63,7 @@ def interpolate( bc = boundaries.periodic_boundary_conditions(grid.ndim) c_bc = grids.GridVariable(c, bc) interp_var = interpolate_module(grid, dt, physics_specs)( - c_bc, offset, v, dt) + c_bc, offset, v, dt) # pyrefly: ignore[bad-argument-type] return interp_var.array def viscosity_fn( @@ -122,7 +122,7 @@ def interpolate( bc = boundaries.periodic_boundary_conditions(grid.ndim) c_bc = grids.GridVariable(c, bc) interp_var = interpolate_module(grid, dt, physics_specs)( - c_bc, offset, v, dt) + c_bc, offset, v, dt) # pyrefly: ignore[bad-argument-type] return interp_var.array def viscosity_fn( diff --git a/jax_cfd/spectral/equations.py b/jax_cfd/spectral/equations.py index a8f9b63..85dbe5d 100644 --- a/jax_cfd/spectral/equations.py +++ b/jax_cfd/spectral/equations.py @@ -96,7 +96,7 @@ def explicit_terms(self, state): uhat, t = state dudx = self.two_pi_i_k * uhat - f = self._forcing_fn(t) + f = self._forcing_fn(t) # pyrefly: ignore[not-callable] fhat = jnp.fft.rfft(f) advection = - self.rfft(self.irfft(uhat) * self.irfft(dudx)) @@ -175,7 +175,7 @@ def explicit_terms(self, vorticity_hat): terms = advection_hat if self.forcing_fn is not None: - fx, fy = self._forcing_fn_with_grid((_get_grid_variable(vx, self.grid), + fx, fy = self._forcing_fn_with_grid((_get_grid_variable(vx, self.grid), # pyrefly: ignore[not-callable] _get_grid_variable(vy, self.grid))) fx_hat, fy_hat = jnp.fft.rfft2(fx.data), jnp.fft.rfft2(fy.data) terms += spectral_utils.spectral_curl_2d((self.kx, self.ky), diff --git a/jax_cfd/spectral/time_stepping.py b/jax_cfd/spectral/time_stepping.py index 6a5720b..8772d69 100644 --- a/jax_cfd/spectral/time_stepping.py +++ b/jax_cfd/spectral/time_stepping.py @@ -241,8 +241,8 @@ def step_fn(y0): g[0] = G(y0) for i in range(1, num_steps): - ex_terms = dt * sum(a_ex[i-1][j] * f[j] for j in range(i) if a_ex[i-1][j]) - im_terms = dt * sum(a_im[i-1][j] * g[j] for j in range(i) if a_im[i-1][j]) + ex_terms = dt * sum(a_ex[i-1][j] * f[j] for j in range(i) if a_ex[i-1][j]) # pyrefly: ignore[unsupported-operation] + im_terms = dt * sum(a_im[i-1][j] * g[j] for j in range(i) if a_im[i-1][j]) # pyrefly: ignore[unsupported-operation] Y_star = y0 + ex_terms + im_terms Y = G_inv(Y_star, dt * a_im[i-1][i]) if any(a_ex[j][i] for j in range(i, num_steps - 1)) or b_ex[i]: @@ -250,8 +250,8 @@ def step_fn(y0): if any(a_im[j][i] for j in range(i, num_steps - 1)) or b_im[i]: g[i] = G(Y) - ex_terms = dt * sum(b_ex[j] * f[j] for j in range(num_steps) if b_ex[j]) - im_terms = dt * sum(b_im[j] * g[j] for j in range(num_steps) if b_im[j]) + ex_terms = dt * sum(b_ex[j] * f[j] for j in range(num_steps) if b_ex[j]) # pyrefly: ignore[unsupported-operation] + im_terms = dt * sum(b_im[j] * g[j] for j in range(num_steps) if b_im[j]) # pyrefly: ignore[unsupported-operation] y_next = y0 + ex_terms + im_terms return y_next