24 changes: 24 additions & 0 deletions .github/workflows/tests.yml
@@ -60,3 +60,27 @@ jobs:
        with:
          file: ./coverage.xml
          fail_ci_if_error: false

  type-check:
    runs-on: "ubuntu-22.04"
    continue-on-error: true
Collaborator:
Is there a particular reason this wasn't added to the pre-commit hooks and requirements/dev.txt? Was it to allow continuing when there's an error? That's where most of the code checkers live, and it would also mean anyone who uses the pre-commit hooks can run the type checker as well.

Collaborator (Author):
No, we can definitely add them there. This was originally written to get testing bootstrapped on a machine other than my local one, but we should make sure it's easily reproducible everywhere.

The continue-on-error was to avoid blocking PRs just because of a typing issue. I'm indifferent on whether we want to keep that. What I wanted to avoid was people adding # type: ignore comments or making the types potentially worse just to get around the type checker.

    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"
          cache: 'pip'
          cache-dependency-path: setup.cfg

      - name: install
        run: |
          pip install --upgrade pip wheel
          pip install -r "requirements/latest.txt"
          pip install pyright

      - name: pyright
        run: pyright adlfs
6 changes: 6 additions & 0 deletions adlfs/_version.pyi
@@ -0,0 +1,6 @@
version: str
__version__: str
version_tuple: tuple[int | str, ...]
__version_tuple__: tuple[int | str, ...]
commit_id: str | None
__commit_id__: str | None
Comment on lines +1 to +6
Collaborator (Author):
This covers the _version.py generated by setuptools_scm at installation time. That file doesn't exist in the repository, so CI complains that it can't resolve the typing used in the utils module. Black has a similar setup to handle this.

119 changes: 69 additions & 50 deletions adlfs/spec.py
@@ -9,13 +9,13 @@
import logging
import os
import re
-import typing
import warnings
import weakref
from collections import defaultdict
+from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from glob import has_magic
-from typing import Optional, Tuple
+from typing import Any, Literal
from uuid import uuid4

from azure.core.exceptions import (
@@ -105,7 +105,7 @@ def get_running_loop():
return loop


-def _coalesce_version_id(*args) -> Optional[str]:
+def _coalesce_version_id(*args) -> str | None:
"""Helper to coalesce a list of version_ids down to one"""
version_ids = set(args)
if None in version_ids:
@@ -123,10 +123,10 @@ def _coalesce_version_id(*args) -> Optional[str]:

def _create_aio_blob_service_client(
account_url: str,
-    location_mode: Optional[str] = None,
-    credential: Optional[str] = None,
+    location_mode: str | None = None,
+    credential: str | None = None,
) -> AIOBlobServiceClient:
-    service_client_kwargs = {
+    service_client_kwargs: dict[str, Any] = {
"account_url": account_url,
"user_agent": _USER_AGENT,
}
Comment on lines +129 to 132
Collaborator (Author):
This needs to be explicitly typed because values added later don't match the inferred type of dict[str, str]. This is the most sensible future-proof option since it's kwargs. If we really want to be strict, we can enumerate all of the current value types, but that may need to be updated in the future.

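For illustration, a minimal standalone sketch of the inference behavior (not adlfs code; the names are made up):

```python
from typing import Any

# pyright infers dict[str, str] from a homogeneous string literal...
inferred = {"account_url": "https://example.invalid"}
# ...so a later assignment of a non-str value is flagged:
# inferred["connection_timeout"] = 300  # error: int not assignable to str

# Declaring dict[str, Any] up front keeps later heterogeneous mutations legal.
explicit: dict[str, Any] = {"account_url": "https://example.invalid"}
explicit["connection_timeout"] = 300  # fine
```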
@@ -264,30 +264,30 @@ class AzureBlobFileSystem(AsyncFileSystem):

def __init__(
self,
-        account_name: str = None,
-        account_key: str = None,
-        connection_string: str = None,
-        credential: str = None,
-        sas_token: str = None,
+        account_name: str | None = None,
+        account_key: str | None = None,
+        connection_string: str | None = None,
+        credential: str | None = None,
+        sas_token: str | None = None,
request_session=None,
socket_timeout=_SOCKET_TIMEOUT_DEFAULT,
blocksize: int = _DEFAULT_BLOCK_SIZE,
-        client_id: str = None,
-        client_secret: str = None,
-        tenant_id: str = None,
-        anon: bool = None,
+        client_id: str | None = None,
+        client_secret: str | None = None,
+        tenant_id: str | None = None,
+        anon: bool | None = None,
location_mode: str = "primary",
loop=None,
asynchronous: bool = False,
default_fill_cache: bool = True,
default_cache_type: str = "bytes",
version_aware: bool = False,
-        assume_container_exists: Optional[bool] = None,
-        max_concurrency: Optional[int] = None,
-        timeout: Optional[int] = None,
-        connection_timeout: Optional[int] = None,
-        read_timeout: Optional[int] = None,
-        account_host: str = None,
+        assume_container_exists: bool | None = None,
+        max_concurrency: int | None = None,
+        timeout: int | None = None,
+        connection_timeout: int | None = None,
+        read_timeout: int | None = None,
+        account_host: str | None = None,
**kwargs,
):
self.kwargs = kwargs.copy()
@@ -386,13 +386,15 @@ def __init__(
weakref.finalize(self, sync, self.loop, close_credential, self)

if max_concurrency is None:
-            batch_size = _get_batch_size()
+            batch_size: int = _get_batch_size()  # type: ignore[assignment]
Collaborator (Author):
We know this is a number in all cases (we do a numerical comparison on the next line), but _get_batch_size() can technically return None at the fsspec layer. This declares the type we know it to be and that the code is built around.

We could add special handling for None if we want; I just wasn't able to find a way to actually trigger it for adlfs.

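If we did add that handling, a defensive sketch (hypothetical; it assumes only that the fsspec helper returns an int or None) might look like:

```python
def _resolve_max_concurrency(batch_size: int | None) -> int:
    # Treat a missing or non-positive batch size as "use the SDK default".
    if batch_size is None or batch_size <= 0:
        return 1
    return batch_size
```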
            if batch_size > 0:
                max_concurrency = batch_size
+            else:
+                max_concurrency = 1
Comment on lines +392 to +393
Collaborator (Author):
This resolves a configuration edge case. self.max_concurrency falls back to the computed batch size when max_concurrency is None. That will be an int in all cases except when the operating system returns a batch_size < 1; at that point we never initialize the value correctly, and the SDK errors out because it didn't receive an int.

This change moves us to single-request throughput (also the SDK default) when we can't determine a batch_size from the OS, which is what None (use the default) would have meant before.

Collaborator:
I'm a little hesitant to change this fallback value as part of supporting type checking, especially if type checking can pass without this change. Mainly I would want to understand the intention of a return value of zero or negative from _get_batch_size() and compare with other fsspec implementations before committing to how it is interpreted.

Collaborator (Author):
I have a PR open with the Azure SDK that should help fix this issue. The APIs are inconsistently typed: some allow None, others don't, and a different subset actually works with None (this is one of them).

Once we get that aligned, we should be fine to remove this piece.

Collaborator (Author):
We got three PRs merged upstream. Once the next Python SDK release goes out, we should be able to remove this.

self.max_concurrency = max_concurrency

@classmethod
-    def _strip_protocol(cls, path: str):
+    def _strip_protocol(cls, path: str) -> str:
"""
Remove the protocol from the input path

@@ -407,7 +409,7 @@ def _strip_protocol(cls, path: str):
Returns a path without the protocol
"""
if isinstance(path, list):
-            return [cls._strip_protocol(p) for p in path]
+            return [cls._strip_protocol(p) for p in path]  # type: ignore[return-value]
Collaborator (Author):
We already explicitly require path to be a str. The upstream fsspec version of this also only uses a str but keeps this code path in the event we somehow get a list. This block keeps parity with upstream but isn't reachable through the current public adlfs API.

We could introduce a change that declares a return type of str | list[str], but it significantly complicates the typing story without adding accuracy.

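For reference, a rough sketch of what that stricter signature could look like with typing.overload (hypothetical, simplified to a module-level function; the protocol-stripping body is elided and the prefix below is an assumption):

```python
from typing import overload

@overload
def _strip_protocol(path: str) -> str: ...
@overload
def _strip_protocol(path: list[str]) -> list[str]: ...
def _strip_protocol(path: str | list[str]) -> str | list[str]:
    if isinstance(path, list):
        return [_strip_protocol(p) for p in path]
    return path.removeprefix("abfs://")  # the real logic handles more prefixes
```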

STORE_SUFFIX = ".dfs.core.windows.net"
logger.debug(f"_strip_protocol for {path}")
@@ -473,6 +475,16 @@ def _get_credential_from_service_principal(self):
-------
Tuple of (Async Credential, Sync Credential).
"""
+        if (
+            self.tenant_id is None
+            or self.client_id is None
+            or self.client_secret is None
+        ):
+            raise ValueError(
+                "tenant_id, client_id, and client_secret must all be provided "
+                "when authenticating with a service principal."
+            )
Comment on lines +478 to +486
Collaborator (Author):
This is an existing failure mode: you can pass just client_id to trigger this credential path, but it will always fail unless all three values are set. This enforces that requirement on our side rather than surfacing assorted, harder-to-debug errors from the SDK.


from azure.identity import ClientSecretCredential
from azure.identity.aio import (
ClientSecretCredential as AIOClientSecretCredential,
@@ -573,7 +585,7 @@ def do_connect(self):

def split_path(
self, path, delimiter="/", return_container: bool = False, **kwargs
-    ) -> Tuple[str, str, Optional[str]]:
+    ) -> tuple[str, str, str | None]:
"""
Normalize ABFS path string into bucket and key.

@@ -708,7 +720,7 @@ async def _ls_blobs(
path: str,
delimiter: str = "/",
return_glob: bool = False,
-        version_id: Optional[str] = None,
+        version_id: str | None = None,
versions: bool = False,
**kwargs,
):
@@ -799,7 +811,7 @@ async def _ls(
invalidate_cache: bool = False,
delimiter: str = "/",
return_glob: bool = False,
-        version_id: Optional[str] = None,
+        version_id: str | None = None,
versions: bool = False,
**kwargs,
):
@@ -867,7 +879,7 @@ async def _details(
delimiter="/",
return_glob: bool = False,
target_path="",
-        version_id: Optional[str] = None,
+        version_id: str | None = None,
versions: bool = False,
**kwargs,
):
@@ -1195,9 +1207,9 @@ def makedir(self, path, exist_ok=False):

async def _rm(
self,
-        path: typing.Union[str, typing.List[str]],
+        path: str | list[str],
        recursive: bool = False,
-        maxdepth: typing.Optional[int] = None,
+        maxdepth: int | None = None,
delimiter: str = "/",
expand_path: bool = True,
**kwargs,
@@ -1256,9 +1268,7 @@ async def _rm(

rm = sync_wrapper(_rm)

-    async def _rm_files(
-        self, container_name: str, file_paths: typing.Iterable[str], **kwargs
-    ):
+    async def _rm_files(self, container_name: str, file_paths: Iterable[str], **kwargs):
"""
Delete the given file(s)

@@ -1322,8 +1332,8 @@ async def _rm_file(self, path: str, **kwargs):
self.invalidate_cache(self._parent(path))

    async def _separate_directory_markers_for_non_empty_directories(
-        self, file_paths: typing.Iterable[str]
-    ) -> typing.Tuple[typing.List[str], typing.List[str]]:
+        self, file_paths: Iterable[str]
+    ) -> tuple[list[str], list[str]]:
"""
Distinguish directory markers of non-empty directories from files and directory markers for empty directories.
A directory marker is an empty blob whose name is the path of the directory.
@@ -1635,6 +1645,12 @@ async def _url(
account_name = self.account_name
account_key = self.account_key

+        if account_name is None:
+            raise ValueError(
+                "account_name is required to generate a SAS URL. "
+                "Provide account_name or include AccountName in the connection string."
+            )
Comment on lines +1648 to +1652
Collaborator (Author):
Same deal with downstream requirements: if account_name is None, the string concatenation blows up because you can't add None to a string. This raises a clearer error instead.

Collaborator:
Mostly for my own curiosity: were you able to raise this new exception purely by setting values in the initializer? I was only able to reach it by mutating the account_name property after filesystem instantiation.


sas_token = generate_blob_sas(
account_name=account_name,
container_name=container_name,
@@ -1653,8 +1669,10 @@ async def _url(
url = f"{bc.url}?{sas_token}"
return url

-    def expand_path(self, path, recursive=False, maxdepth=None, skip_noexist=True):
-        return sync(
+    def expand_path(
+        self, path, recursive=False, maxdepth=None, skip_noexist=True
+    ) -> list[str]:
+        return sync(  # type: ignore[return-value]
Collaborator (Author):
sync is defined in fsspec without a return type. It's effectively acting as a proxy from async to sync, so we could likely do something with generics here. The only issue is that the nicest ways to do that require Python 3.12+.

For the time being, we know this is the same return type, so it's ignored.

self.loop, self._expand_path, path, recursive, maxdepth, skip_noexist
)

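As a sketch of the generics idea mentioned above (a hypothetical typed_sync, not fsspec's actual API; ParamSpec already works on 3.10, and the nicer PEP 695 syntax is what needs 3.12+):

```python
import asyncio
from collections.abc import Callable, Coroutine
from typing import Any, ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")

def typed_sync(
    loop: asyncio.AbstractEventLoop,
    func: Callable[P, Coroutine[Any, Any, T]],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    # Run the coroutine on the loop (running in another thread) and block
    # until it finishes; the return type tracks the wrapped coroutine.
    future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), loop)
    return future.result()
```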
@@ -1887,12 +1905,12 @@ def _open(
self,
path: str,
mode: str = "rb",
-        block_size: int = None,
+        block_size: int | None = None,
autocommit: bool = True,
cache_options: dict = {},
cache_type="readahead",
metadata=None,
-        version_id: Optional[str] = None,
+        version_id: str | None = None,
**kwargs,
):
"""Open a file on the datalake, or a block blob
@@ -1954,12 +1972,12 @@ def __init__(
fs: AzureBlobFileSystem,
path: str,
mode: str = "rb",
block_size="default",
block_size: int | Literal["default"] | None = "default",
autocommit: bool = True,
cache_type: str = "bytes",
cache_options: dict = {},
metadata=None,
-        version_id: Optional[str] = None,
+        version_id: str | None = None,
**kwargs,
):
"""
@@ -2017,9 +2035,10 @@ def __init__(

self.loop = self._get_loop()
self.container_client = self._get_container_client()
-        self.blocksize = (
-            self.DEFAULT_BLOCK_SIZE if block_size in ["default", None] else block_size
-        )
+        if block_size == "default" or block_size is None:
+            self.blocksize: int = self.DEFAULT_BLOCK_SIZE
+        else:
+            self.blocksize = block_size
Comment on lines +2038 to +2041
Collaborator (Author):
This is a limitation of pyright. The code doesn't have any notable performance change; it just lets the type checker reason about the branches. pyright currently can't narrow types from the `in` membership check the way it does from the equivalent equality comparisons.

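A distilled repro of the narrowing difference the comment describes (illustrative only; behavior as of the pyright version discussed here):

```python
from typing import Literal

DEFAULT_BLOCK_SIZE = 4 * 1024 * 1024

def via_in(block_size: int | Literal["default"] | None) -> int:
    if block_size in ("default", None):
        return DEFAULT_BLOCK_SIZE
    return block_size  # not narrowed, so pyright flags this return

def via_equality(block_size: int | Literal["default"] | None) -> int:
    if block_size == "default" or block_size is None:
        return DEFAULT_BLOCK_SIZE
    return block_size  # narrowed to int, no error
```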
self.loc = 0
self.autocommit = autocommit
self.end = None
@@ -2127,9 +2146,9 @@ def connect_client(self):
"""
try:
if hasattr(self.fs, "account_host"):
self.fs.account_url: str = f"https://{self.fs.account_host}"
self.fs.account_url = f"https://{self.fs.account_host}"
Collaborator (Author):
These aren't valid type declarations because they annotate an attribute on self.fs, not a member of this class directly. The correct typing already lives on the fs implementation.

else:
-            self.fs.account_url: str = (
+            self.fs.account_url = (
f"https://{self.fs.account_name}.blob.core.windows.net"
)

@@ -2164,7 +2183,7 @@ def connect_client(self):
f"Unable to fetch container_client with provided params for {e}!!"
) from e

-    async def _async_fetch_range(self, start: int, end: int = None, **kwargs):
+    async def _async_fetch_range(self, start: int, end: int | None = None, **kwargs):
"""
Download a chunk of data specified by start and end

@@ -2221,7 +2240,7 @@ async def _stage_block(self, data, start, end, block_id, semaphore):
async with self.container_client.get_blob_client(blob=self.blob) as bc:
await bc.stage_block(
block_id=block_id,
-                data=data[start:end],
+                data=data[start:end],  # type: ignore[arg-type]
Collaborator (Author):
I talked with the Azure Python SDK folks and am going to open a PR about this. The current definition doesn't support memoryview or bytearray, which are both valid inputs. Once that's fixed there, this ignore is no longer needed.

length=end - start,
)
return block_id
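To make the ignore above concrete, here is the runtime behavior in isolation: slicing a memoryview or bytearray yields objects the current annotation doesn't cover, even though they're valid inputs (a standalone sketch, not adlfs code):

```python
buf = bytearray(b"0123456789")

chunk = memoryview(buf)[2:5]            # slicing a memoryview yields a memoryview
assert bytes(chunk) == b"234"
assert isinstance(buf[2:5], bytearray)  # bytearray slices stay bytearray
```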
@@ -2301,7 +2320,7 @@ async def _async_upload_chunk(self, final: bool = False, **kwargs):
await bc.upload_blob(
data=data,
length=length,
-                    blob_type=BlobType.AppendBlob,
+                    blob_type=BlobType.APPENDBLOB,
Collaborator (Author):
AppendBlob is an alias that resolves to this enum member but isn't explicitly declared as a member, which leads the type checker to infer it as Unknown. We'll use the enum member directly instead of the alias; it's semantically identical.

metadata=self.metadata,
)
else:
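Roughly the shape of the issue (a simplified sketch of the SDK's case-insensitive enum pattern, not its actual definition):

```python
from enum import Enum, EnumMeta

class CaseInsensitiveEnumMeta(EnumMeta):
    # Resolves attribute access like BlobType.AppendBlob at runtime
    # by normalizing the name's case.
    def __getattr__(cls, name: str):
        try:
            return cls[name.upper()]
        except KeyError:
            raise AttributeError(name)

class BlobType(str, Enum, metaclass=CaseInsensitiveEnumMeta):
    APPENDBLOB = "AppendBlob"
    BLOCKBLOB = "BlockBlob"

# Works at runtime, but statically there is no declared member named
# "AppendBlob", so the checker falls back to the metaclass hook and
# loses the precise enum type.
assert BlobType.AppendBlob is BlobType.APPENDBLOB
```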
@@ -2329,6 +2348,6 @@ def __getstate__(self):
return state

def __setstate__(self, state):
-        self.__dict__.update(state)
+        self.__dict__.update(state)  # type: ignore[reportAttributeAccessIssue]
self.loop = self._get_loop()
self.container_client = self._get_container_client()
8 changes: 3 additions & 5 deletions adlfs/utils.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
try:
from ._version import version as __version__ # type: ignore[import]
from ._version import version_tuple # type: ignore[import]
@@ -8,7 +6,7 @@
version_tuple = (0, 0, __version__) # type: ignore[assignment]


-def match_blob_version(blob, version_id: Optional[str]):
+def match_blob_version(blob, version_id: str | None):
blob_version_id = blob.get("version_id")
return (
version_id is None
@@ -20,7 +18,7 @@ async def filter_blobs(
blobs,
target_path,
delimiter="/",
-    version_id: Optional[str] = None,
+    version_id: str | None = None,
versions: bool = False,
):
"""
@@ -50,7 +48,7 @@ async def filter_blobs(
return finalblobs


-async def get_blob_metadata(container_client, path, version_id: Optional[str] = None):
+async def get_blob_metadata(container_client, path, version_id: str | None = None):
async with container_client.get_blob_client(path) as bc:
properties = await bc.get_blob_properties(version_id=version_id)
if "metadata" in properties.keys():
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -49,5 +49,11 @@ include = ["adlfs*"]
exclude = ["tests"]
namespaces = false

[tool.pyright]
include = ["adlfs"]
exclude = ["adlfs/tests", "adlfs/gen1.py", "build"]
pythonVersion = "3.10"
typeCheckingMode = "basic"

[tool.isort]
profile = "black"