diff --git a/boot.py b/boot.py index 4978a0342..50fb3cf21 100644 --- a/boot.py +++ b/boot.py @@ -17,9 +17,11 @@ from .plugin.configuration import LspDisableLanguageServerInProjectCommand from .plugin.configuration import LspEnableLanguageServerGloballyCommand from .plugin.configuration import LspEnableLanguageServerInProjectCommand +from .plugin.core.aio import run_coroutine_threadsafe from .plugin.core.constants import ST_VERSION from .plugin.core.css import load as load_css from .plugin.core.open import g_opening_files +from .plugin.core.open import get_opening_files_lock from .plugin.core.panels import PanelName from .plugin.core.registry import LspCheckApplicableCommand from .plugin.core.registry import LspNextDiagnosticCommand @@ -88,10 +90,18 @@ from .plugin.tooling import LspParseVscodePackageJson from .plugin.tooling import LspTroubleshootServerCommand from typing import Any +from typing import TYPE_CHECKING import os import sublime +import sublime_aio import sublime_plugin +if TYPE_CHECKING: + import asyncio + +# Uncomment to see all invocations that are marked @deprecated in the Console. +# warnings.simplefilter('always', DeprecationWarning) + __all__ = ( "DocumentSyncListener", "Listener", @@ -219,14 +229,14 @@ def show_warning() -> None: def plugin_unloaded() -> None: _unregister_all_plugins() - windows.disable() + run_coroutine_threadsafe(windows.disable()) unload_settings() -class Listener(sublime_plugin.EventListener): +class Listener(sublime_aio.EventListener): - def on_exit(self) -> None: - kill_all_subprocesses() + async def on_exit(self) -> None: + await kill_all_subprocesses() def on_load_project_async(self, window: sublime.Window) -> None: if manager := windows.lookup(window): @@ -255,27 +265,27 @@ def on_pre_move(self, view: sublime.View) -> None: sublime.set_timeout_async(listener.on_post_move_window_async, 1) return - def on_load(self, view: sublime.View) -> None: + async def on_load(self, view: sublime.View) -> None: file_name = view.file_name() if not file_name: return - for fn in g_opening_files: - if fn == file_name or os.path.samefile(fn, file_name): - # Remove it from the pending opening files, and resolve the promise. - g_opening_files.pop(fn)[1](view) - break + if future := await self._find_opening_file_future(file_name): + future.set_result(view) - def on_pre_close(self, view: sublime.View) -> None: + async def on_pre_close(self, view: sublime.View) -> None: file_name = view.file_name() if not file_name: return - for fn in g_opening_files: - if fn == file_name or os.path.samefile(fn, file_name): - tup = g_opening_files.pop(fn, None) # noqa: B909 - if tup: - # The view got closed before it finished loading. This can happen. - tup[1](None) - break + if future := await self._find_opening_file_future(file_name): + # The view got closed before it finished loading. This can happen. + future.set_result(None) + + async def _find_opening_file_future(self, file_name: str) -> asyncio.Future[sublime.View | None] | None: + async with get_opening_files_lock(): + for fn in g_opening_files: + if fn == file_name or os.path.samefile(fn, file_name): # noqa: ASYNC240 + return g_opening_files.pop(fn, None) + return None def on_post_window_command(self, window: sublime.Window, command_name: str, args: dict[str, Any] | None) -> None: if command_name == "show_panel": diff --git a/dependencies.json b/dependencies.json index 0b7b7aae1..c59fc8206 100644 --- a/dependencies.json +++ b/dependencies.json @@ -4,6 +4,7 @@ "bracex", "mdpopups", "orjson", + "sublime_aio", "typing_extensions", "wcmatch" ] diff --git a/plugin/api.py b/plugin/api.py index 0c941007c..a9c984ccc 100644 --- a/plugin/api.py +++ b/plugin/api.py @@ -1,8 +1,13 @@ from __future__ import annotations +from ..protocol import ConfigurationItem +from ..protocol import DocumentUri +from ..protocol import ExecuteCommandParams from ..protocol import LSPAny from .core.constants import ST_STORAGE_PATH from .core.logging import exception_log +from .core.protocol import Notification +from .core.protocol import Request from .core.protocol import Response from .core.settings import client_configs from .core.types import method2attr @@ -14,6 +19,7 @@ from functools import wraps from pathlib import Path from typing import Any +from typing import Awaitable from typing import Callable from typing import Final from typing import final @@ -221,31 +227,33 @@ def decorator(func: Callable[[Any, P], None]) -> Callable[[Any, P], None]: def request_handler( method: str -) -> Callable[[Callable[[Any, P], Promise[R]]], Callable[[Any, P, int], Promise[Response[R]]]]: +) -> Callable[[Callable[[Any, P], Awaitable[R]]], Callable[[Any, P, int], Awaitable[Response[R]]]]: """ - Decorator to mark a method as a handler for a specific LSP request. + Decorator to mark a coroutine method as a handler for a specific LSP request. Usage: ```py @request_handler('eslint/openDoc') - def on_open_doc(self, params: TextDocumentIdentifier) -> Promise[bool]: + async def on_open_doc(self, params: TextDocumentIdentifier) -> bool: ... ``` - The decorated method will be called with the request parameters whenever the specified - request is received from the language server. The method must return a Promise that resolves - to the response value. The framework will automatically send it back to the server. + The decorated coroutine method will be called with the request parameters whenever the specified + request is received from the language server. The coroutine method must return a response value. + The framework will automatically send it back to the server. + + An older, but backwards-compatible way to define a request handler is by defining a function that returns a Promise. + While that works, the advice is to define a coroutine function. :param method: The LSP request method name (e.g., 'eslint/openDoc'). - :returns: A decorator that registers the function as a request handler. + :returns: A decorator that registers the coroutine function as a request handler. """ - def decorator(func: Callable[[Any, P], Promise[R]]) -> Callable[[Any, P, int], Promise[Response[R]]]: + def decorator(func: Callable[[Any, P], Awaitable[R]]) -> Callable[[Any, P, int], Awaitable[Response[R]]]: @wraps(func) - def wrapper(self: Any, params: P, request_id: int) -> Promise[Response[Any]]: - promise = func(self, params) - return promise.then(lambda result: Response(request_id, result)) + async def wrapper(self: Any, params: P, request_id: int) -> Response[Any]: + return Response(request_id, await func(self, params)) setattr(wrapper, HANDLER_MARKER, method) return wrapper @@ -390,6 +398,9 @@ def plugin_unloaded() -> None: Use this as your directory to install server files. Its path is `$DATA/Package Storage/`. """ + use_asyncio: bool = False + """Set to `true` to make LSP use `async def` variants.""" + @classmethod @final def register(cls) -> None: @@ -464,6 +475,19 @@ def on_pre_start_async(cls, context: OnPreStartContext) -> None: """ pass + @classmethod + async def on_pre_start(cls, context: OnPreStartContext) -> None: + """ + Async version of on_pre_start_async. + + Attempt to use non-blocking functionality for downloading binaries and running subprocesses in order to not + block the asyncio thread. + + :param context: The startup context. `context.configuration`, `context.variables` and + `context.working_directory` can be mutated to influence how the server is launched. + """ + pass + def __init__(self, weaksession: ref[Session]) -> None: """ Constructs a new instance. @@ -491,6 +515,10 @@ def on_initialized_async(self) -> None: """ pass + async def on_initialized(self) -> None: + """Async version of `on_initialize_async`.""" + pass + def on_pre_send_request_async(self, request: ClientRequest, view: sublime.View | None) -> None: """ Notifies about a request that is about to be sent to the language server. diff --git a/plugin/code_actions.py b/plugin/code_actions.py index f69bbfb21..33226ac65 100644 --- a/plugin/code_actions.py +++ b/plugin/code_actions.py @@ -5,6 +5,10 @@ from ..protocol import CodeActionParams from ..protocol import Command from ..protocol import Diagnostic +from ..protocol import LSPAny +from .core.aio import call_soon_threadsafe +from .core.aio import run_coroutine_threadsafe +from .core.logging import trace from .core.promise import Promise from .core.protocol import Error from .core.protocol import Request @@ -23,12 +27,14 @@ from functools import partial from typing import Any from typing import cast +from typing import Coroutine from typing import final from typing import List from typing import Tuple from typing import TYPE_CHECKING from typing import Union from typing_extensions import override +import asyncio import sublime if TYPE_CHECKING: @@ -36,7 +42,6 @@ from .core.sessions import SessionBufferProtocol from collections.abc import Callable from collections.abc import Generator - from collections.abc import Iterator from typing_extensions import TypeGuard @@ -65,9 +70,9 @@ def is_quickfix(action: Command | CodeAction) -> bool: def filter_quickfix_actions( - only_with_diagnostics: bool, response: list[Command | CodeAction] | Error | None + only_with_diagnostics: bool, response: list[Command | CodeAction] | BaseException | None ) -> list[Command | CodeAction]: - if isinstance(response, Error) or not response: + if isinstance(response, BaseException) or not response: return [] if only_with_diagnostics: # If there are multiple diagnostics for the region, in the hover popup we can only use those code actions which @@ -202,6 +207,7 @@ def on_response( sb: SessionBufferProtocol, response: Error | list[CodeActionOrCommand] | None ) -> CodeActionsByConfigName: actions = [] + trace(sb=sb, response=response) if response and not isinstance(response, Error): # Filter actions returned from the session so that only matching kinds are collected. # Since older servers don't support the "context.only" property, those will return all @@ -209,10 +215,12 @@ def on_response( session_kinds = get_session_kinds(sb) matching_kinds = get_matching_kinds(code_actions, session_kinds) actions = [a for a in response if a.get('kind') in matching_kinds and not a.get('disabled')] + trace(session_kinds=session_kinds, matching_kinds=matching_kinds, actions=actions) return (sb.session.config.name, actions) for sb in listener.session_buffers_async('codeActionProvider'): matching_kinds = get_matching_kinds(code_actions, get_session_kinds(sb)) + trace(code_actions=code_actions, matching_kinds=matching_kinds) for kind in matching_kinds: listener.purge_changes_async() # Pull for diagnostics to ensure that server computes them before receiving code action request. @@ -281,35 +289,21 @@ def get_code_action_kinds(cls, view: sublime.View) -> dict[str, bool]: } @override - def run_async(self) -> None: - super().run_async() - view = self._task_runner.view + async def run(self) -> None: + await super().run() + trace() + view = self._text_command.view code_action_kinds = self.get_code_action_kinds(view) - request_iterator = actions_manager.request_on_save_or_format_async(view, code_action_kinds) - self._process_next_request(request_iterator) - - def _process_next_request(self, request_iterator: Iterator[Promise[CodeActionsByConfigName]]) -> None: - if self._cancelled: - return - if request := next(request_iterator, None): - request.then(lambda response: self._handle_response_async(response, request_iterator)) - else: - self._on_complete() - - def _handle_response_async( - self, response: CodeActionsByConfigName, request_iterator: Iterator[Promise[CodeActionsByConfigName]] - ) -> None: - if self._cancelled: - return - view = self._task_runner.view - tasks: list[Promise[None]] = [] - config_name, code_actions = response - session = self._task_runner.session_by_name(config_name, 'codeActionProvider') - if session and code_actions: - tasks.extend([ - session.run_code_action_async(action, progress=False, view=view) for action in code_actions - ]) - Promise.all(tasks).then(lambda _: self._process_next_request(request_iterator)) + tasks: list[Coroutine[None, None, LSPAny]] = [] + for request in actions_manager.request_on_save_or_format_async(view, code_action_kinds): + config_name, code_actions = await request + trace(code_actions=code_actions) + if code_actions and (session := self._text_command.session_by_name(config_name, 'codeActionProvider')): + tasks.extend( + session.run_code_action(action, progress=False, view=self._text_command.view) + for action in code_actions + ) + await asyncio.gather(*tasks) @final @@ -386,9 +380,9 @@ def run( if code_actions_by_config: self._handle_code_actions(code_actions_by_config, run_first=True) return - self._run_async(only_kinds) + run_coroutine_threadsafe(self._run(only_kinds)) - def _run_async(self, only_kinds: list[str | CodeActionKind] | None = None) -> None: + async def _run(self, only_kinds: list[str | CodeActionKind] | None = None) -> None: view = self.view region = first_selection_region(view) if region is None: @@ -427,13 +421,13 @@ def _handle_select(self, index: int, actions: list[tuple[ConfigName, CodeActionO if index == -1: return - def run_async() -> None: + async def run() -> None: config_name, action = actions[index] if session := self.session_by_name(config_name): - session.run_code_action_async(action, progress=True, view=self.view) \ - .then(lambda response: self._handle_response_async(config_name, response)) + response = await session.run_code_action(action, progress=True, view=self.view) + self._handle_response_async(config_name, response) - sublime.set_timeout_async(run_async) + run_coroutine_threadsafe(run()) def _handle_response_async(self, session_name: str, response: Any) -> None: if isinstance(response, Error): @@ -463,7 +457,7 @@ def is_enabled(self, index: int, event: dict | None = None) -> bool: def is_visible(self, index: int, event: dict | None = None) -> bool: if index == -1: if self._has_session(event): - sublime.set_timeout_async(partial(self._request_menu_actions_async, event)) + call_soon_threadsafe(partial(self._request_menu_actions_async, event)) return False return index < len(self.actions_cache) and self._is_cache_valid(event) @@ -488,14 +482,14 @@ def want_event(self) -> bool: return True def run(self, index: int, event: dict | None = None) -> None: - sublime.set_timeout_async(partial(self.run_async, index, event)) + run_coroutine_threadsafe(self._run(index, event)) - def run_async(self, index: int, event: dict | None) -> None: + async def _run(self, index: int, event: dict | None) -> None: if self._is_cache_valid(event): config_name, action = self.actions_cache[index] if session := self.session_by_name(config_name): - session.run_code_action_async(action, progress=True, view=self.view) \ - .then(lambda response: self._handle_response_async(config_name, response)) + response = await session.run_code_action(action, progress=True, view=self.view) + self._handle_response_async(config_name, response) def _handle_response_async(self, session_name: str, response: Any) -> None: if isinstance(response, Error): diff --git a/plugin/code_lens.py b/plugin/code_lens.py index 2d609609a..56fa59083 100644 --- a/plugin/code_lens.py +++ b/plugin/code_lens.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .core.aio import call_soon_threadsafe from .core.constants import CODE_LENS_ENABLED_KEY from .core.protocol import Error from .core.protocol import ResolvedCodeLens @@ -7,7 +8,6 @@ from .core.registry import LspWindowCommand from .core.registry import windows from .core.views import range_to_region -from functools import partial from typing import cast from typing import TYPE_CHECKING from typing_extensions import TypeGuard @@ -128,7 +128,7 @@ def is_checked(self) -> bool: def run(self) -> None: enable = not self.is_checked() self.window.settings().set(CODE_LENS_ENABLED_KEY, enable) - sublime.set_timeout_async(partial(self._update_views_async, enable)) + call_soon_threadsafe(self._update_views_async, enable) def _update_views_async(self, enable: bool) -> None: window_manager = windows.lookup(self.window) diff --git a/plugin/color.py b/plugin/color.py index 4f861bf08..fa9b04ecc 100644 --- a/plugin/color.py +++ b/plugin/color.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .core.aio import run_coroutine_threadsafe from .core.edit import apply_text_edits from .core.protocol import Request from .core.registry import LspTextCommand @@ -27,7 +28,7 @@ def run(self, edit: sublime.Edit, color_information: ColorInformation) -> None: 'color': color_information['color'], 'range': self._range } - session.send_request_async(Request.colorPresentation(params, self.view), self._handle_response_async) + session.send_request(Request.colorPresentation(params, self.view), self._handle_response_async) def want_event(self) -> bool: return False @@ -60,4 +61,8 @@ def _on_select(self, index: int) -> None: if index > -1: color_pres = self._filtered_response[index] text_edit = color_pres.get('textEdit') or {'range': self._range, 'newText': color_pres['label']} - apply_text_edits(self.view, [text_edit], label="Change Color Format", required_view_version=self._version) + run_coroutine_threadsafe( + apply_text_edits( + self.view, [text_edit], label="Change Color Format", required_view_version=self._version + ) + ) diff --git a/plugin/completion.py b/plugin/completion.py index 24f2122c8..a0bbbd623 100644 --- a/plugin/completion.py +++ b/plugin/completion.py @@ -14,6 +14,7 @@ from ..protocol import MarkupKind from ..protocol import Range from ..protocol import TextEdit +from .core.aio import run_coroutine_threadsafe from .core.constants import COMPLETION_KINDS from .core.constants import MarkdownLangMap from .core.edit import apply_text_edits @@ -40,7 +41,6 @@ from typing import Union from typing_extensions import TypeAlias from typing_extensions import TypeGuard -import functools import html import sublime import weakref @@ -292,19 +292,18 @@ def _get_userpref_flags(self) -> sublime.AutoCompleteFlags: class LspResolveDocsCommand(LspTextCommand): def run(self, edit: sublime.Edit, index: int, session_name: str, event: dict | None = None) -> None: + run_coroutine_threadsafe(self._run(index, session_name, event)) - def run_async() -> None: - items, item_defaults = LspSelectCompletionCommand.completions[session_name] - item = completion_with_defaults(items[index], item_defaults) - if session := self.session_by_name(session_name, 'completionProvider.resolveProvider'): - request = Request.resolveCompletionItem(item, self.view) - language_map = session.markdown_language_id_to_st_syntax_map() - handler = functools.partial(self._handle_resolve_response_async, language_map) - session.send_request_async(request, handler) - else: - self._handle_resolve_response_async(None, item) - - sublime.set_timeout_async(run_async) + async def _run(self, index: int, session_name: str, event: dict | None = None) -> None: + items, item_defaults = LspSelectCompletionCommand.completions[session_name] + item = completion_with_defaults(items[index], item_defaults) + if session := self.session_by_name(session_name, 'completionProvider.resolveProvider'): + language_map = session.markdown_language_id_to_st_syntax_map() + item = await session.request(Request.resolveCompletionItem(item, self.view)) + # TODO: why do we only pass the language_map when the langserver is a resolveProvider? + self._handle_resolve_response_async(language_map, item) + else: + self._handle_resolve_response_async(None, item) def _handle_resolve_response_async(self, language_map: MarkdownLangMap | None, item: CompletionItem) -> None: detail = "" @@ -385,33 +384,27 @@ def run(self, edit: sublime.Edit, index: int, session_name: str) -> None: self.view.run_command("insert_snippet", {"contents": new_text}) else: self.view.run_command("insert", {"characters": new_text}) - # TODO: this should all run from the worker thread - session = self.session_by_name(session_name, 'completionProvider.resolveProvider') - additional_text_edits = item.get('additionalTextEdits') - if session and not additional_text_edits: - session.send_request_async( - Request.resolveCompletionItem(item, self.view), - functools.partial(self._on_resolved_async, session_name)) - else: - self._on_resolved(session_name, item) + run_coroutine_threadsafe(self._run(session_name, item)) - def want_event(self) -> bool: - return False - - def _on_resolved_async(self, session_name: str, item: CompletionItem) -> None: - sublime.set_timeout(functools.partial(self._on_resolved, session_name, item)) - - def _on_resolved(self, session_name: str, item: CompletionItem) -> None: - if additional_edits := item.get('additionalTextEdits', []): - apply_text_edits(self.view, additional_edits) + async def _run(self, session_name: str, item: CompletionItem) -> None: + session = self.session_by_name(session_name, 'completionProvider.resolveProvider') + if session and not item.get('additionalTextEdits'): + try: + item = await session.request(Request.resolveCompletionItem(item, self.view)) + except Error as error: + debug("Error resolving completion item:", error) + if additional_edits := item.get('additionalTextEdits'): + await apply_text_edits(self.view, additional_edits) if command := item.get("command"): debug(f'Running server command "{command}" for view {self.view.id()}') - args = { + self.view.run_command("lsp_execute", { "command_name": command["command"], "command_args": command.get("arguments"), "session_name": session_name - } - self.view.run_command("lsp_execute", args) + }) + + def want_event(self) -> bool: + return False def _translated_regions(self, edit_region: sublime.Region) -> Generator[sublime.Region, None, None]: selection = self.view.sel() diff --git a/plugin/configuration.py b/plugin/configuration.py index 13017c5b9..d02631fc4 100644 --- a/plugin/configuration.py +++ b/plugin/configuration.py @@ -1,10 +1,10 @@ from __future__ import annotations +from .core.aio import call_soon_threadsafe from .core.registry import windows from .core.settings import client_configs from functools import partial from typing import TYPE_CHECKING -import sublime import sublime_plugin if TYPE_CHECKING: @@ -42,7 +42,7 @@ def _on_done(self, wm: WindowManager, index: int) -> None: if index == -1: return config_name = self._items[index] - sublime.set_timeout_async(lambda: wm.enable_config_async(config_name)) + call_soon_threadsafe(wm.enable_config_async, config_name) class LspDisableLanguageServerGloballyCommand(sublime_plugin.WindowCommand): @@ -80,4 +80,4 @@ def _on_done(self, wm: WindowManager, index: int) -> None: if index == -1: return config_name = self._items[index] - sublime.set_timeout_async(lambda: wm.disable_config_async(config_name)) + call_soon_threadsafe(wm.disable_config_async, config_name) diff --git a/plugin/core/aio.py b/plugin/core/aio.py new file mode 100644 index 000000000..fc0adae6f --- /dev/null +++ b/plugin/core/aio.py @@ -0,0 +1,235 @@ +"""Functionality wrapping asyncio, sublime_aio, and interaction nuances with Sublime Text.""" + +from __future__ import annotations + +from .logging import debug +from .logging import exception_log +from typing import Any +from typing import AsyncIterator +from typing import Callable +from typing import Coroutine +from typing import Protocol +from typing import TYPE_CHECKING +from typing import TypeVar +import asyncio +import concurrent.futures +import contextlib +import sublime +import sublime_aio +import sys +import threading + +if TYPE_CHECKING: + from contextvars import Context + + +class SupportsAclose(Protocol): + async def aclose(self) -> None: ... + + +T = TypeVar("T") +S = TypeVar("S", bound="SupportsAclose") + + +# `async with aclosing(stream(...))`. This function in the contextlib module is available since python 3.10, but we also +# need to support python 3.8. +# See: https://docs.python.org/3/library/contextlib.html#contextlib.aclosing +if sys.version_info >= (3, 10, 0): + aclosing = contextlib.aclosing +else: + + @contextlib.asynccontextmanager + async def aclosing(thing: S) -> AsyncIterator[S]: + try: + yield thing + finally: + await thing.aclose() + + +_futures: set[concurrent.futures.Future] = set() + + +def run_coroutine_threadsafe(coroutine: Coroutine[object, object, T]) -> concurrent.futures.Future[T]: + """ + Start the execution of a coroutine in the asyncio thread, from any thread. + + When you are certain you are already in the asyncio thread, then there are better ways to start a coroutine from a + "blocking" ("non-async") function. One way is to use + [asyncio.create_task](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_task). However, + asyncio.create_task has the caveat that the returned [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task) + object must be kept alive somewhere. If you don't care about keeping tasks associated to coroutines alive, then + inherit from the `TaskContainer` mixin class and use its `create_task` method. + + A big caveat: coroutines started this way do not print their exceptions when an exception occurs in the coroutine. + To handle this, call `.add_done_callback` on the returned `Future` object. + """ + future: concurrent.futures.Future[T] = sublime_aio.run_coroutine(coroutine) # type: ignore + + def on_done(fut: concurrent.futures.Future[T]) -> None: + _futures.discard(fut) + if not fut.cancelled() and (ex := fut.exception()): + exception_log("coroutine finished with exception", ex) + + future.add_done_callback(on_done) + _futures.add(future) + return future + + +def call_soon_threadsafe(f: Callable[..., Any], *args: Any, context: Context | None = None) -> asyncio.Handle: + """Invoke a function in the asyncio thread, from any thread.""" + return sublime_aio.call_soon_threadsafe(f, *args, context=context) + + +class _Executor(concurrent.futures.Executor): + """ + An Executor that wraps sublime.set_timeout(_async). + + Use in combination with an asyncio loop: + + ```python + from LSP.core.aio import executor_main, executor_async + + + def some_blocking_function_that_interacts_with_gui() -> int: + window = sublime.current_window() + return 42 + + + async def foo() -> int: + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(executor_main, some_blocking_function_that_interacts_with_gui) + return result + + + def some_cpu_heavy_function() -> int: + time.sleep(1) + return 42 + + + async def bar() -> int: + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(executor_async, some_cpu_heavy_function) + return result + ``` + """ + + def __init__(self, dispatch_func: Callable[[Callable[..., Any]], Any]) -> None: + self._dispatch_func = dispatch_func + self._running = 0 + self._shuttingdown = False + self._lock = threading.Lock() + self._cv = threading.Condition(self._lock) + + def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> concurrent.futures.Future: + if self._shuttingdown: + raise RuntimeError("Executor is shutting down") + future: concurrent.futures.Future = concurrent.futures.Future() + with self._cv: + self._running += 1 + + def run() -> None: + try: + future.set_result(fn(*args, **kwargs)) + except BaseException as ex: + future.set_exception(ex) + with self._cv: + self._running -= 1 + if self._running == 0: + self._cv.notify() + + self._dispatch_func(run) + return future + + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: + self._shuttingdown = True + if wait: + with self._cv: + self._cv.wait_for(lambda: self._running == 0) + + +executor_main = _Executor(sublime.set_timeout) +"""Executor instance that runs functions on the Sublime Text main (GUI) thread.""" + +executor_async = _Executor(sublime.set_timeout_async) +"""Executor instance that runs functions on the Sublime Text "async" thread.""" + + +async def next_frame() -> None: + """Wait until (at least one) UI frame has passed.""" + + def noop() -> None: + pass + + await asyncio.get_running_loop().run_in_executor(executor_main, noop) + + +async def gather_and_flatten_exceptions(*coros: Coroutine[Any, Any, list[Exception]]) -> list[Exception]: + """ + Takes a list of coroutines, runs them concurrently using asyncio.gather, collects all exceptions, and returns a + flattened list of Exceptions that occurred for each coroutine. BaseExceptions are filtered out. + """ + exceptions: list[Exception] = [] + items: list[BaseException | list[Exception]] = await asyncio.gather(*coros, return_exceptions=True) + for item in items: + # Only keep exceptions derived from Exception. Exceptions derived from BaseException, but not derived from + # Exception are things like asyncio.CancelledError or SystemExit and should be ignored. + if isinstance(item, Exception): + exceptions.append(item) + elif isinstance(item, list): + exceptions.extend(item) + return exceptions + + +class TaskContainer: + """ + A [mixin class](https://en.wikipedia.org/wiki/Mixin) for adding "fire-and-forget" functionality to a class for + starting coroutines. + + Note: don't forget to call `super().__init__()` when using this class. + + Ensure the `cancel_all_tasks` async function is ran before this class is destroyed. + """ + + def __init__(self) -> None: + self._tasks: set[asyncio.Task] = set() + + def __del__(self) -> None: + if self._tasks: + debug("WARNING: TaskContainer is destroyed but there are still tasks running!") + + async def cancel_all_tasks(self) -> list[Exception]: + return [x for x in await asyncio.gather(*self._tasks, return_exceptions=True) if isinstance(x, Exception)] + + def create_task(self, coro: Coroutine[object, object, object], /, **kwargs: Any) -> asyncio.Task: + """ + Spawn a new coroutine, to be run in the background. Not thread-safe. Must be invoked from the asyncio thread. + + First argument is the coroutine object, the named arguments are exactly the ones from asyncio.create_task. + + This method saves a strong reference to the spawned task, unlike asyncio. + Moreover, this method will print any exception that occured during the exception of the coroutine, if any. + """ + task = asyncio.create_task(coro, **kwargs) + tasks = self._tasks + tasks.add(task) + + def on_done(t: asyncio.Task) -> None: + tasks.discard(t) + if t.cancelled(): + return + if ex := t.exception(): + exception_log(f"Task {t.get_name()} finished with exception", ex) + + task.add_done_callback(on_done) + return task + + def create_task_threadsafe(self, coro: Coroutine[object, object, object], /, **kwargs: Any) -> None: + """ + Spawn a new coroutine, to be run in the background. Thread-safe. + + First argument is the coroutine object, the named arguments are exactly the ones from asyncio.create_task. + + This method saves a strong reference to the spawned task, unlike asyncio. + Moreover, this method will print any exception that occured during the exception of the coroutine, if any. + """ + call_soon_threadsafe(lambda: self.create_task(coro, **kwargs)) diff --git a/plugin/core/edit.py b/plugin/core/edit.py index 894371071..5c8888412 100644 --- a/plugin/core/edit.py +++ b/plugin/core/edit.py @@ -10,8 +10,8 @@ from ...protocol import TextDocumentEdit from ...protocol import TextEdit from ...protocol import WorkspaceEdit +from .aio import next_frame from .logging import debug -from .promise import Promise from .protocol import UINT_MAX from typing import Dict from typing import List @@ -89,22 +89,23 @@ def parse_lsp_position(position: Position) -> tuple[int, int]: return position['line'], min(UINT_MAX, position['character']) -def apply_text_edits( +async def apply_text_edits( view: sublime.View, edits: Sequence[TextEdit | AnnotatedTextEdit | SnippetTextEdit], *, label: str | None = None, process_placeholders: bool = False, required_view_version: int | None = None -) -> Promise[sublime.View | None]: +) -> sublime.View | None: if not edits: - return Promise.resolve(view) + return view if not view.is_valid(): print('LSP: ignoring edits due to view not being open') - return Promise.resolve(None) + return None if process_placeholders: # TODO: remove rust-analyzer specific handling for placeholders in TextEdit, because SnippetTextEdit is now part # of the LSP specs. + # TODO: Communicate results back. view.run_command( 'lsp_apply_document_edit', { @@ -115,16 +116,20 @@ def apply_text_edits( } ) elif required_view_version is None or required_view_version == view.change_count(): + # TODO: Communicate results back. view.run_command('lsp_apply_text_document_edit', {'edits': edits, 'label': label}) # Resolving from the next message loop iteration guarantees that the edits have already been applied in the main # thread, and that we've received view changes in the asynchronous thread. - return Promise(lambda resolve: sublime.set_timeout_async(lambda: resolve(view if view.is_valid() else None))) + await next_frame() + return view if view.is_valid() else None def show_summary_message( - window: sublime.Window, result: ApplyWorkspaceEditResult, summary: WorkspaceEditSummary + window: sublime.Window, result: ApplyWorkspaceEditResult, summary: WorkspaceEditSummary | BaseException ) -> None: - if result['applied']: + if isinstance(summary, BaseException): + message = f"Error: {summary}" + elif result['applied']: message = f"Applied {summary['total_changes']} changes in {summary['edited_files']} files" else: message = "Error while applying WorkspaceEdit" diff --git a/plugin/core/logging.py b/plugin/core/logging.py index b8d0c4147..c16932946 100644 --- a/plugin/core/logging.py +++ b/plugin/core/logging.py @@ -3,6 +3,7 @@ from .constants import ST_PACKAGES_PATH from typing import Any import inspect +import threading import traceback log_debug = False @@ -19,7 +20,7 @@ def debug(*args: Any) -> None: printf(*args) -def trace() -> None: +def trace(print_callstack: bool = False, **values: Any) -> None: current_frame = inspect.currentframe() if current_frame is None: debug("TRACE (unknown frame)") @@ -27,15 +28,25 @@ def trace() -> None: previous_frame = current_frame.f_back file_name, line_number, function_name, _, _ = inspect.getframeinfo(previous_frame) # type: ignore file_name = file_name[len(ST_PACKAGES_PATH) + len("/LSP/"):] - debug(f"TRACE {function_name:<32} {file_name}:{line_number}") + debug(f"TRACE {threading.current_thread().name:<16} {function_name:<32} {file_name}:{line_number}") + if print_callstack: + for frame in traceback.extract_stack(): + debug(f"TRACE {frame.filename}:{frame.lineno} in {frame.name}") + for k, v in values.items(): + debug(f"TRACE {k}={v}") -def exception_log(message: str, ex: Exception) -> None: +def exception_log(message: str, ex: BaseException) -> None: print(message) ex_traceback = ex.__traceback__ print(''.join(traceback.format_exception(ex.__class__, ex, ex_traceback))) +def exceptions_log(message: str, exs: list[Exception]) -> None: + for ex in exs: + exception_log(message, ex) + + def printf(*args: Any, prefix: str = 'LSP') -> None: """Print args to the console, prefixed by the plugin name.""" print(prefix + ":", *args) diff --git a/plugin/core/open.py b/plugin/core/open.py index ae4c9d312..d62ce4193 100644 --- a/plugin/core/open.py +++ b/plugin/core/open.py @@ -1,17 +1,17 @@ from __future__ import annotations +from .aio import executor_main from .constants import ST_PACKAGES_PATH from .constants import ST_PLATFORM from .constants import ST_VERSION from .logging import exception_log -from .promise import Promise -from .promise import ResolveFunc from .protocol import UINT_MAX from .url import parse_uri from .views import range_to_region from typing import TYPE_CHECKING from urllib.parse import unquote from urllib.parse import urlparse +import asyncio import os import re import sublime @@ -23,10 +23,18 @@ from ...protocol import DocumentUri from ...protocol import Range -g_opening_files: dict[str, tuple[Promise[sublime.View | None], ResolveFunc[sublime.View | None]]] = {} +g_opening_files: dict[str, asyncio.Future[sublime.View | None]] = {} +g_opening_files_lock: asyncio.Lock | None = None FRAGMENT_PATTERN = re.compile(r'^L?(\d+)(?:,(\d+))?(?:-L?(\d+)(?:,(\d+))?)?') +def get_opening_files_lock() -> asyncio.Lock: + global g_opening_files_lock + if not g_opening_files_lock: + g_opening_files_lock = asyncio.Lock() + return g_opening_files_lock + + def lsp_range_from_uri_fragment(fragment: str) -> Range | None: if match := FRAGMENT_PATTERN.match(fragment): selection: Range = {'start': {'line': 0, 'character': 0}, 'end': {'line': 0, 'character': 0}} @@ -48,21 +56,16 @@ def lsp_range_from_uri_fragment(fragment: str) -> Range | None: return None -def open_file_uri( +async def open_file_uri( window: sublime.Window, uri: DocumentUri, flags: sublime.NewFileFlags = sublime.NewFileFlags.NONE, group: int = -1 -) -> Promise[sublime.View | None]: +) -> sublime.View | None: decoded_uri = unquote(uri) # decode percent-encoded characters - open_promise = open_file(window, decoded_uri, flags, group) - if fragment := urlparse(decoded_uri).fragment: - if selection := lsp_range_from_uri_fragment(fragment): - return open_promise.then(lambda view: _select_and_center(view, selection)) - return open_promise - - -def _select_and_center(view: sublime.View | None, r: Range) -> sublime.View | None: + view = await open_file(window, decoded_uri, flags, group) if view: - return center_selection(view, r) - return None + if fragment := urlparse(decoded_uri).fragment: + if selection := lsp_range_from_uri_fragment(fragment): + center_selection(view, selection) + return view def _return_existing_view(flags: int, existing_view_group: int, active_group: int, specified_group: int) -> bool: @@ -84,52 +87,62 @@ def _find_open_file(window: sublime.Window, fname: str, group: int = -1) -> subl return window.find_open_file(fname, group) if ST_VERSION >= 4136 else window.find_open_file(fname) -def open_file( +async def open_file( window: sublime.Window, uri: DocumentUri, flags: sublime.NewFileFlags = sublime.NewFileFlags.NONE, group: int = -1 -) -> Promise[sublime.View | None]: +) -> sublime.View | None: """ - Open a file asynchronously. - It is only safe to call this function from the UI thread. + Open a file and wait for it to be done loading. The provided uri MUST be a file URI. """ + future: asyncio.Future[sublime.View | None] | None = None file = parse_uri(uri)[1] - # window.open_file brings the file to focus if it's already opened, which we don't want (unless it's supposed - # to open as a separate view). - view = _find_open_file(window, file) - if view and _return_existing_view(flags, window.get_view_index(view)[0], window.active_group(), group): - return Promise.resolve(view) - - was_already_open = view is not None - if not was_already_open and not os.path.isfile(file): - # window.open_file creates a new view with empty content if the path from the given URI doesn't exist as a file - # on disk, but we don't want that here. If the language server wants to create a new file for a given URI, it - # must use the CreateFile resource operation in a WorkspaceEdit. - return Promise.resolve(None) - view = window.open_file(file, flags, group) - if not view.is_loading(): - if was_already_open and (flags & sublime.NewFileFlags.SEMI_TRANSIENT): - # workaround bug https://github.com/sublimehq/sublime_text/issues/2411 where transient view might not get - # its view listeners initialized. - sublime_plugin.check_view_event_listeners(view) # type: ignore - # It's already loaded. Possibly already open in a tab. - return Promise.resolve(view) - - # Is the view opening right now? Then return the associated unresolved promise - for fn, value in g_opening_files.items(): - if fn == file or os.path.samefile(fn, file): - # Return the unresolved promise. A future on_load event will resolve the promise. - return value[0] - - # Prepare a new promise to be resolved by a future on_load event (see the event listener in main.py) - def fullfill(resolve: ResolveFunc[sublime.View | None]) -> None: - # Save the promise in the first element of the tuple -- except we cannot yet do that here - g_opening_files[file] = (None, resolve) # type: ignore - - promise = Promise(fullfill) - tup = g_opening_files[file] - # Save the promise in the first element of the tuple so that the for-loop above can return it - g_opening_files[file] = (promise, tup[1]) - return promise + async with get_opening_files_lock(): + # Is the view opening right now? Then return the associated unresolved future + for fn, fut in g_opening_files.items(): + if fn == file or os.path.samefile(fn, file): # noqa ASYNC240 + # Return the unresolved future. A future on_load event will resolve the future. + future = fut + break + if future is None: + loop = asyncio.get_running_loop() + future = loop.create_future() + + def resolve_right_now(view: sublime.View | None) -> None: + future.set_result(view) + + def resolve_later() -> None: + g_opening_files[file] = future + + def on_main_thread() -> None: + + # window.open_file brings the file to focus if it's already opened, which we don't want (unless it's + # supposed to open as a separate view). + view = _find_open_file(window, file) + if view and _return_existing_view(flags, window.get_view_index(view)[0], window.active_group(), group): + loop.call_soon_threadsafe(lambda: resolve_right_now(view)) + return + + was_already_open = view is not None + if not was_already_open and not os.path.isfile(file): + # window.open_file creates a new view with empty content if the path from the given URI doesn't + # exist as a file on disk, but we don't want that here. If the language server wants to create a new + # file for a given URI, it must use the CreateFile resource operation in a WorkspaceEdit. + loop.call_soon_threadsafe(lambda: resolve_right_now(view)) + return + + view = window.open_file(file, flags, group) + if not view.is_loading(): + if was_already_open and (flags & sublime.NewFileFlags.SEMI_TRANSIENT): + # workaround bug https://github.com/sublimehq/sublime_text/issues/2411 where transient view + # might not get its view listeners initialized. + sublime_plugin.check_view_event_listeners(view) # type: ignore + # It's already loaded. Possibly already open in a tab. + loop.call_soon_threadsafe(lambda: resolve_right_now(view)) + + loop.call_soon_threadsafe(resolve_later) + + await loop.run_in_executor(executor_main, on_main_thread) + return await future def open_resource(window: sublime.Window, uri: DocumentUri, group: int = -1) -> sublime.View | None: diff --git a/plugin/core/promise.py b/plugin/core/promise.py index ac5d1574d..7b2fa5acd 100644 --- a/plugin/core/promise.py +++ b/plugin/core/promise.py @@ -1,14 +1,23 @@ from __future__ import annotations +from typing import Any from typing import Callable +from typing import Generator from typing import Generic from typing import Protocol from typing import Tuple +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +import asyncio import functools +import inspect import threading +if TYPE_CHECKING: + from collections.abc import Coroutine + + T = TypeVar('T') S = TypeVar('S') TExecutor = TypeVar('TExecutor') @@ -106,6 +115,28 @@ def __call__(self, resolver: ResolveFunc[TExecutor]) -> None: assert callable(executor.resolver) return promise, executor.resolver + @staticmethod + def wrap_task(task: asyncio.Task[T]) -> Promise[T | BaseException]: + """Wrap a task in a Promise. The Promise resolves when the task is done.""" + + def executor(resolve: ResolveFunc[T | BaseException]) -> None: + + def on_done(t: asyncio.Task[T]) -> None: + if ex := t.exception(): + resolve(ex) + else: + resolve(t.result()) + + setattr(on_done, "_strong_task_ref", task) + task.add_done_callback(on_done) + + return Promise(executor) + + @staticmethod + def wrap_coroutine(coro: Coroutine[None, None, T]) -> Promise[T | BaseException]: + """Wrap a coroutine object in a Promise. The Promise resolves when the coroutine is done.""" + return Promise.wrap_task(asyncio.create_task(coro)) + # Could also support passing plain S. @staticmethod def all(promises: list[Promise[S]]) -> Promise[list[S]]: @@ -207,6 +238,22 @@ def async_wrapper(resolve_fn: ResolveFunc[TResult]) -> None: return Promise(sync_wrapper) return Promise(async_wrapper) + def __await__(self) -> Generator[Any, None, T]: + """You can `await` a Promise.""" + loop = asyncio.get_running_loop() + future = loop.create_future() + with self.mutex: + if self.resolved: + future.set_result(self.value) + else: + + def resolve_callback(resolve_value: T) -> None: + # We don't know from which thread we are resolving, so use call_soon_threadsafe. + loop.call_soon_threadsafe(functools.partial(future.set_result, resolve_value)) + + self.callbacks.append(resolve_callback) + return future.__await__() + def _do_resolve(self, new_value: T) -> None: # No need to block as we can't change from resolved to unresolved. if self.resolved: @@ -215,6 +262,8 @@ def _do_resolve(self, new_value: T) -> None: self.resolved = True self.value = new_value for callback in self.callbacks: + if inspect.iscoroutine(callback) or inspect.iscoroutinefunction(callback): + raise RuntimeError("Cannot await a coroutine in a Promise.then") callback(new_value) def _add_callback(self, callback: ResolveFunc[T]) -> None: diff --git a/plugin/core/protocol.py b/plugin/core/protocol.py index a4f6b9829..92f94a5a1 100644 --- a/plugin/core/protocol.py +++ b/plugin/core/protocol.py @@ -270,9 +270,9 @@ def documentDiagnostic( @classmethod def workspaceDiagnostic( - cls, params: WorkspaceDiagnosticParams, on_partial_result: Callable[[WorkspaceDiagnosticReport], None] + cls, params: WorkspaceDiagnosticParams ) -> Request[WorkspaceDiagnosticParams, WorkspaceDiagnosticReport]: - return Request('workspace/diagnostic', params, on_partial_result=on_partial_result) + return Request('workspace/diagnostic', params) @classmethod def shutdown(cls) -> Request[None, None]: diff --git a/plugin/core/registry.py b/plugin/core/registry.py index d27e66e5a..48bf9d84f 100644 --- a/plugin/core/registry.py +++ b/plugin/core/registry.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .aio import run_coroutine_threadsafe from .views import first_selection_region from .views import get_uri_and_position_from_location from .views import MissingUriError @@ -197,23 +198,19 @@ def run( flags |= sublime.NewFileFlags.ADD_TO_SELECTION | sublime.NewFileFlags.SEMI_TRANSIENT | sublime.NewFileFlags.CLEAR_TO_RIGHT # noqa: E501 elif 'shift' in modifier_keys: flags |= sublime.NewFileFlags.ADD_TO_SELECTION | sublime.NewFileFlags.SEMI_TRANSIENT - sublime.set_timeout_async(lambda: self._run_async(location, session_name, flags, group)) + run_coroutine_threadsafe(self._run(location, session_name, flags, group)) def want_event(self) -> bool: return True - def _run_async( + async def _run( self, location: Location | LocationLink, session_name: str | None, flags: sublime.NewFileFlags, group: int ) -> None: if session := self.session_by_name(session_name) if session_name else self.session(): - session.open_location_async(location, flags, group) \ - .then(lambda view: self._handle_continuation(location, view is not None)) - - def _handle_continuation(self, location: Location | LocationLink, success: bool) -> None: - if not success: - uri, _ = get_uri_and_position_from_location(location) - message = f"Failed to open {uri}" - sublime.status_message(message) + if not await session.open_location(location, flags, group): + uri, _ = get_uri_and_position_from_location(location) + message = f"Failed to open {uri}" + sublime.status_message(message) class LspRestartServerCommand(LspTextCommand): @@ -236,21 +233,18 @@ def want_event(self) -> bool: def restart_server(self, wm: WindowManager, index: int) -> None: if index == -1: return - - def run_async() -> None: - wm.restart_sessions_async([self._config_names[index]]) - - sublime.set_timeout_async(run_async) + # TODO: handle exception list? + run_coroutine_threadsafe(wm.restart_sessions([self._config_names[index]])) class LspCheckApplicableCommand(sublime_plugin.TextCommand): def run(self, edit: sublime.Edit, session_name: str) -> None: - sublime.set_timeout_async(lambda: self._run_async(session_name)) + run_coroutine_threadsafe(self._run(session_name)) - def _run_async(self, session_name: str) -> None: + async def _run(self, session_name: str) -> None: if wm := windows.lookup(self.view.window()): - wm.recheck_is_applicable_async(self.view, session_name) + await wm.recheck_is_applicable(self.view, session_name) def navigate_diagnostics(view: sublime.View, point: int | None, forward: bool = True) -> None: diff --git a/plugin/core/sessions.py b/plugin/core/sessions.py index f1c3a8f86..32519a9d2 100644 --- a/plugin/core/sessions.py +++ b/plugin/core/sessions.py @@ -77,7 +77,6 @@ from ...protocol import WorkDoneProgressReport from ...protocol import WorkspaceClientCapabilities from ...protocol import WorkspaceDiagnosticParams -from ...protocol import WorkspaceDiagnosticReport from ...protocol import WorkspaceDocumentDiagnosticReport from ...protocol import WorkspaceEdit from ...protocol import WorkspaceFolder as LspWorkspaceFolder @@ -91,6 +90,11 @@ from ..diagnostics import DiagnosticsStorage from ..diagnostics import WORKSPACE_DIAGNOSTICS_RETRIGGER_DELAY from ..locationpicker import LocationPicker +from .aio import aclosing +from .aio import call_soon_threadsafe +from .aio import executor_main +from .aio import gather_and_flatten_exceptions +from .aio import TaskContainer from .constants import ChangeEventAction from .constants import MarkdownLangMap from .constants import MARKO_MD_PARSER_VERSION @@ -110,6 +114,7 @@ from .file_watcher import lsp_watch_kind_to_file_watcher_event_types from .logging import debug from .logging import exception_log +from .logging import exceptions_log from .logging import printf from .open import center_selection from .open import open_externally @@ -138,7 +143,6 @@ from .types import Capabilities from .types import ClientConfig from .types import ClientStates -from .types import debounced from .types import diff from .types import DocumentSelectorMatcher from .types import method2attr @@ -159,21 +163,22 @@ from enum import IntFlag from functools import lru_cache from functools import partial -from operator import itemgetter from typing import Any from typing import Callable from typing import cast from typing import Generator +from typing import Generic from typing import Literal from typing import overload from typing import Protocol from typing import TYPE_CHECKING from typing import TypeVar -from typing import Union +from typing_extensions import deprecated from typing_extensions import TypeAlias from typing_extensions import TypeGuard from urllib.parse import urldefrag from weakref import WeakSet +import asyncio import itertools import mdpopups import os @@ -274,13 +279,12 @@ def should_ignore_diagnostics(self, uri: DocumentUri, configuration: ClientConfi # Mutators @abstractmethod - def start_async(self, configuration: ClientConfig, initiating_view: sublime.View) -> None: + async def start(self, config: ClientConfig, listener: AbstractViewListener) -> Session | None: """ - Start a new Session with the given configuration. The initiating view is the view that caused this method to + Start a new Session with the given configuration. The listener is the listener that caused this method to be called. - A normal flow of calls would be start -> on_post_initialize -> do language server things -> on_post_exit. - However, it is possible that the subprocess cannot start, in which case on_post_initialize will never be called. + Returns the initialized Session object, or None if nothing was started. """ raise NotImplementedError @@ -291,20 +295,18 @@ def on_diagnostics_updated(self) -> None: # Event callbacks @abstractmethod - def on_post_exit_async(self, session: Session, exit_code: int, exception: Exception | None) -> None: + async def on_post_exit(self, session: Session, exit_code: int, exception: Exception | None) -> None: """The given Session has stopped with the given exit code.""" raise NotImplementedError @abstractmethod - def handle_message_request( + async def handle_message_request( self, config_name: str, params: ShowMessageRequestParams - ) -> Promise[MessageActionItem | None]: + ) -> MessageActionItem | None: ... @abstractmethod - def handle_show_message( - self, config_name: str, params: ShowMessageParams - ) -> Promise[MessageActionItem | None]: + def handle_show_message(self, config_name: str, params: ShowMessageParams) -> None: ... @abstractmethod @@ -672,7 +674,7 @@ def on_capability_removed_async(self, registration_id: str, discarded_capabiliti def has_capability_async(self, capability_path: str) -> bool: ... - def shutdown_async(self) -> None: + async def shutdown(self) -> list[Exception]: ... def present_diagnostics_async(self, is_view_visible: bool) -> None: @@ -802,7 +804,17 @@ def request_code_actions_async( diagnostics: list[Diagnostic], kinds: list[str | CodeActionKind] | None = ..., trigger_kind: CodeActionTriggerKind = ... - ) -> Promise[list[Command | CodeAction] | Error | None]: + ) -> Promise[list[Command | CodeAction] | BaseException | None]: + ... + + async def request_code_actions( + self, + view: sublime.View, + region: sublime.Region, + diagnostics: list[Diagnostic], + kinds: list[str | CodeActionKind] | None = ..., + trigger_kind: CodeActionTriggerKind = ... + ) -> list[Command | CodeAction] | Error | None: ... def do_code_lenses_async(self, view: sublime.View) -> None: @@ -849,7 +861,7 @@ def on_session_initialized_async(self, session: Session) -> None: raise NotImplementedError @abstractmethod - def on_session_shutdown_async(self, session: Session) -> None: + async def on_session_shutdown(self, session: Session) -> list[Exception]: raise NotImplementedError @abstractmethod @@ -948,6 +960,85 @@ def incoming_notification(self, method: str, params: Any, unhandled: bool) -> No pass +class CancellableRequest(Generic[R]): + """A request that is cancellable.""" + + _id: int + _weaksession: weakref.ref[Session] + + def __init__(self, req_id: int, session: Session) -> None: + self._id = req_id + self._weaksession = weakref.ref(session) + + def cancel(self) -> None: + """Cancel this request.""" + if self._id != 0: + if session := self._weaksession(): + session.cancel_request_async(self._id) + self._id = 0 + + @property + def id(self) -> int: + """Get the request ID.""" + return self._id + + +class CancellableInflightRequest(CancellableRequest[R]): + """A request that is in flight. The result can be awaited.""" + + _future: asyncio.Future[R] + + def __init__(self, future: asyncio.Future[R], req_id: int, session: Session) -> None: + super().__init__(req_id, session) + self._future = future + + def __await__(self) -> Generator[Any, None, R]: + """ + You can `await` the response of an in-flight request. + However, note that immediately awaiting this object prevents you from ever canceling it. + When the language server replies with an error, an exception of type protocol.Error is raised. + """ + return self._future.__await__() + + +class CancellableInflightStreamingRequest(CancellableRequest[R]): + """ + A streaming request that is in flight. + Use `async for` syntax to asynchronously stream the partial results. + Only requests for which the partial results are of type list[...] can work with this class. + An empty list signals the end of the stream. So the class knows when to signal the end of the `async for` loop. + """ + + def __init__(self, req_id: int, session: Session) -> None: + super().__init__(req_id, session) + self._queue: asyncio.Queue[R | Error | None] = asyncio.Queue() + + def on_partial_result(self, response: R) -> None: + # Note: R should have type list[...]. If an empty list is returned, then this signals the end of the partial + # result stream. In that case we put `None` in the queue. + self._queue.put_nowait(response or None) + + def on_error(self, error: ResponseError) -> None: + self._queue.put_nowait(Error.from_lsp(error)) + + def __aiter__(self) -> CancellableInflightStreamingRequest: + """Stream partial results using the `async for` syntax.""" + return self + + async def __anext__(self) -> R: + """Get the next partial result.""" + item = await self._queue.get() + if item is None: + raise StopAsyncIteration + if isinstance(item, Error): + raise item + return item + + async def aclose(self) -> None: + # See: https://docs.python.org/3/library/asyncio-dev.html#close-asynchronous-generators-explicitly + self._queue.put_nowait(None) + + def print_to_status_bar(error: ResponseError) -> None: sublime.status_message(error["message"]) @@ -991,7 +1082,7 @@ def check_applicable(self, sb: SessionBufferProtocol, *, suppress_requests: bool _PARTIAL_RESULT_PROGRESS_PREFIX = "$ublime-partial-result-progress-" -class Session(APIHandler, TransportCallbacks): +class Session(APIHandler, TransportCallbacks, TaskContainer): def __init__(self, manager: Manager, logger: Logger, workspace_folders: list[WorkspaceFolder], config: ClientConfig, plugin_class: type[AbstractPlugin | LspPlugin] | None, @@ -1009,11 +1100,9 @@ def __init__(self, manager: Manager, logger: Logger, workspace_folders: list[Wor self.capabilities = Capabilities() self.diagnostics = DiagnosticsStorage() self.diagnostics_result_ids: dict[tuple[DocumentUri, DiagnosticsIdentifier], str | None] = {} - self.workspace_diagnostics_pending_responses: dict[DiagnosticsIdentifier, int | None] = {} + self.workspace_diagnostics_pending_responses: dict[DiagnosticsIdentifier, CancellableInflightStreamingRequest | None] = {} # noqa: E501 self.exiting = False self._registrations: dict[str, _RegistrationData] = {} - self._init_callback: InitCallback | None = None - self._initialize_error: tuple[int, Exception | None] | None = None self._views_opened = 0 self._variables: dict[str, str] = {} self._workspace_folders = workspace_folders @@ -1029,6 +1118,7 @@ def __init__(self, manager: Manager, logger: Logger, workspace_folders: list[Wor self._semantic_tokens_map = get_semantic_tokens_map(config.semantic_tokens) self._is_executing_refactoring_command = False self._logged_unsupported_commands: set[str] = set() + self._maybe_end_task: asyncio.Task | None = None super().__init__() # TODO: Create an assurance that the API doesn't change here as it can be used by plugins. @@ -1047,11 +1137,33 @@ def register_session_view_async(self, sv: SessionViewProtocol) -> None: for status_key, message in self._status_messages.items(): sv.view.set_status(status_key, message) - def unregister_session_view_async(self, sv: SessionViewProtocol) -> None: + async def unregister_session_view(self, sv: SessionViewProtocol) -> None: self._session_views.discard(sv) - if not self._session_views: - current_count = self._views_opened - debounced(self.end_async, 3000, lambda: self._views_opened == current_count, async_thread=True) + if self._session_views: + return + current_count = self._views_opened + + async def maybe_end() -> None: + await asyncio.sleep(3) + if self._views_opened == current_count: + exceptions_log(f"Exception while stopping {self.config.name}", await self.end()) + self._maybe_end_task = None + + if self.exiting: + # This means we're really ending, just return and let this object shutdown. + return + # If we're at this point, then we are certain the `end()` method isn't running. Maybe there is an existing + # `maybe_end` task running, in which case, at this point, we are certain it's sleeping. + if self._maybe_end_task: + # Cancel the sleep. + if self._maybe_end_task.cancel(): + try: + await self._maybe_end_task + except asyncio.CancelledError: + pass + # The maybe_end task is special. Inside of the `end()` method, we call `cancel_all_tasks()`. If this + # maybe_end task is part of that task list, then it will itself also be cancelled, which we don't want. + self._maybe_end_task = asyncio.create_task(maybe_end()) def session_views_async(self) -> Generator[SessionViewProtocol, None, None]: """It is only safe to iterate over this in the async thread.""" @@ -1094,7 +1206,7 @@ def unregister_session_buffer_async(self, sb: SessionBufferProtocol) -> None: self._session_buffers.discard(sb) def session_buffers_async(self) -> Generator[SessionBufferProtocol, None, None]: - """It is only safe to iterate over this in the async thread.""" + """It is only safe to iterate over this in the asyncio thread.""" yield from self._session_buffers def get_session_buffer_for_uri_async(self, uri: DocumentUri) -> SessionBufferProtocol | None: @@ -1244,24 +1356,23 @@ def update_folders(self, folders: list[WorkspaceFolder]) -> None: else: self._workspace_folders = folders[:1] - def initialize_async( + async def initialize( self, variables: dict[str, str], working_directory: str | None, - transport: TransportWrapper, - init_callback: InitCallback - ) -> None: + transport: TransportWrapper + ) -> InitializeResult: + loop = asyncio.get_running_loop() if self._plugin_class and issubclass(self._plugin_class, LspPlugin): self._plugin = self._plugin_class(weakref.ref(self)) self.transport = transport self.working_directory = working_directory - self._variables = variables - params = get_initialize_params(self._variables, self._workspace_folders, self.config) - self._init_callback = init_callback - self.send_request_async( - Request.initialize(params), self._handle_initialize_success, self._handle_initialize_error) - - def _handle_initialize_success(self, result: InitializeResult) -> None: + params = get_initialize_params(variables, self._workspace_folders, self.config) + try: + result = await self.request(Request.initialize(params)) + except: + await self.end() # ignore exceptions + raise capabilities = result['capabilities'] self.capabilities.assign(capabilities) if self._workspace_folders and not self._supports_workspace_folders(): @@ -1274,15 +1385,18 @@ def _handle_initialize_success(self, result: InitializeResult) -> None: # Handle it now and use fake request ID since it shouldn't matter. if issubclass(self._plugin_class, AbstractPlugin): self._plugin = self._plugin_class(weakref.ref(self)) - self._plugin.on_server_response_async('initialize', Response[InitializeResult](-1, result)) - self.send_notification(Notification.initialized()) + self._plugin.on_server_response_async('initialize', Response(-1, result)) + await self.notify(Notification.initialized()) if self._plugin and isinstance(self._plugin, LspPlugin): - self._plugin.on_initialized_async() + if self._plugin.use_asyncio: + await self._plugin.on_initialized() + else: + self._plugin.on_initialized_async() self._maybe_send_did_change_configuration() if execute_commands := self.get_capability('executeCommandProvider.commands'): debug(f"{self.config.name}: Supported execute commands: {execute_commands}") if code_action_kinds := self.get_capability('codeActionProvider.codeActionKinds'): - debug(f'{self.config.name}: supported code action kinds: {code_action_kinds}') + debug(f'{self.config.name}: Supported code action kinds: {code_action_kinds}') if semantic_token_types := self.get_capability('semanticTokensProvider.legend.tokenTypes'): debug(f'{self.config.name}: Supported semantic token types: {semantic_token_types}') if semantic_token_modifiers := self.get_capability('semanticTokensProvider.legend.tokenModifiers'): @@ -1295,15 +1409,8 @@ def _handle_initialize_success(self, result: InitializeResult) -> None: ignores = config.get('ignores') or self._get_global_ignore_globs(folder.path) watcher = self._watcher_impl.create(folder.path, patterns, events, ignores, self) self._static_file_watchers.append(watcher) - if self._init_callback: - self._init_callback(self, False) - self._init_callback = None - self.do_workspace_diagnostics_async() - - def _handle_initialize_error(self, result: ResponseError) -> None: - self._initialize_error = (result.get('code', -1), Exception(result.get('message', 'Error initializing server'))) - # Init callback called after transport is closed to avoid pre-mature GC of Session. - self.end_async() + loop.call_soon(self.do_workspace_diagnostics_async) + return result def _get_global_ignore_globs(self, root_path: str) -> list[str]: folder_exclude_patterns = cast('list[str]', globalprefs().get('folder_exclude_patterns')) @@ -1337,46 +1444,41 @@ def _get_resolved_settings(self) -> dict[str, Any]: self._plugin.on_settings_changed(self.config.settings) return self.config.settings.get_resolved(self._variables) - def execute_command( + async def execute_command( self, command: ExecuteCommandParams, *, progress: bool = False, view: sublime.View | None = None, is_refactoring: bool = False, - ) -> Promise[R | Error | None]: # pyright: ignore[reportInvalidTypeVarUse] - """Run a command from any thread. Your .then() continuations will run in Sublime's worker thread.""" + ) -> LSPAny: + """Run a command from the asyncio thread.""" command_name = command['command'] if self._plugin: if isinstance(self._plugin, LspPlugin): if command_handler := self._plugin.get_command_handler(command_name): - return command_handler(command.get('arguments')) + return await command_handler(command.get('arguments')) else: - task: PackagedTask[R | Error | None] = Promise.packaged_task() + task: PackagedTask[LSPAny | Error | None] = Promise.packaged_task() promise, resolve = task if self._plugin.on_pre_server_command(command, lambda: resolve(None)): - return promise - resolve(None) + return cast("LSPAny", await promise) # Handle VSCode-specific command for triggering AC/sighelp if command_name == "editor.action.triggerSuggest" and view: # Triggered from set_timeout as suggestions popup doesn't trigger otherwise. sublime.set_timeout(lambda: view.run_command("auto_complete")) - return Promise.resolve(None) + return None if command_name == "editor.action.triggerParameterHints" and view: - - def run_async() -> None: - session_view = self.session_view_for_view_async(view) - if not session_view: - return - listener = session_view.listener() - if not listener: - return - listener.do_signature_help_async(SignatureHelpTriggerKind.Invoked) - - sublime.set_timeout_async(run_async) - return Promise.resolve(None) + session_view = self.session_view_for_view_async(view) + if not session_view: + return None + listener = session_view.listener() + if not listener: + return None + listener.do_signature_help_async(SignatureHelpTriggerKind.Invoked) + return None # Handle VSCode-specific command which is often used for "References" code lenses if command_name == "editor.action.showReferences" and view: if (arguments := command.get('arguments')) and len(arguments) == 3: if references := cast('list[Location]', arguments[2]): if len(references) == 1: - self.open_location_async(references[0]) + await self.open_location(references[0]) else: view_uri = uri_from_view(view) locations = sorted( @@ -1388,25 +1490,37 @@ def run_async() -> None: ) ) LocationPicker(view, self, locations, side_by_side=False) - return Promise.resolve(None) - request = Request[ExecuteCommandParams, Union[R, None]].executeCommand(command, progress=progress) - execute_command_promise = self.send_request_task(request) + return None + future = self.request(Request.executeCommand(command, progress=progress)) if is_refactoring: self._is_executing_refactoring_command = True - execute_command_promise.then(lambda _: self._reset_is_executing_refactoring_command()) - return execute_command_promise + try: + return await future + finally: + self._is_executing_refactoring_command = False + return await future - def _reset_is_executing_refactoring_command(self) -> None: - self._is_executing_refactoring_command = False + @deprecated("use Session.execute_command instead") + def execute_command_async( + self, + command: ExecuteCommandParams, + *, + progress: bool = False, + view: sublime.View | None = None, + is_refactoring: bool = False, + ) -> Promise[LSPAny | BaseException]: + return Promise.wrap_task( + self.create_task(self.execute_command(command, progress=progress, view=view, is_refactoring=is_refactoring)) + ) def check_log_unsupported_command(self, command: str) -> None: if userprefs().log_debug and command not in self._logged_unsupported_commands: self._logged_unsupported_commands.add(command) debug(f'{self.config.name}: unsupported command: {command}') - def run_code_action_async( + async def run_code_action( self, code_action: Command | CodeAction, progress: bool, view: sublime.View | None = None - ) -> Promise[None]: + ) -> LSPAny: command = code_action.get("command") if isinstance(command, str): code_action = cast('Command', code_action) @@ -1416,134 +1530,153 @@ def run_code_action_async( if isinstance(arguments, list): command_params['arguments'] = arguments is_refactoring = kind_contains_other_kind(CodeActionKind.Refactor, code_action.get('kind', '')) - return self.execute_command(command_params, progress=progress, view=view, is_refactoring=is_refactoring) \ - .then(lambda _: None) + return await self.execute_command( + command_params, progress=progress, view=view, is_refactoring=is_refactoring + ) # At this point it cannot be a command anymore, it has to be a proper code action. # A code action can have an edit and/or command. Note that it can have *both*. In case both are present, we # must apply the edits before running the command. code_action = cast('CodeAction', code_action) - return self._maybe_resolve_code_action(code_action, view) \ - .then(lambda code_action: self._apply_code_action_async(code_action, view)) + code_action = await self._maybe_resolve_code_action(code_action, view) + return await self._apply_code_action(code_action, view) - def try_open_uri_async( + @deprecated("use Session.run_code_action instead") + def run_code_action_async( + self, code_action: Command | CodeAction, progress: bool, view: sublime.View | None = None + ) -> Promise[BaseException | None]: + return Promise.wrap_coroutine(self.run_code_action(code_action, progress, view)) + + async def try_open_uri( self, uri: DocumentUri, r: Range | None = None, flags: sublime.NewFileFlags = sublime.NewFileFlags.NONE, - group: int = -1 - ) -> Promise[sublime.View | None] | None: + group: int = -1, + ) -> sublime.View | Literal[False] | None: + """ + Try to open an URI. + + If the URI has the file: scheme, opens the file in a tab. + If the URI has the res: scheme, opens the Sublime resource file in a tab. + If the URI has the untitled: scheme, opens a scratch tab. + Otherwise, if there's a plugin attached, delegates to the plugin. + If the plugin does not handle the URI scheme, returns the constant boolean False. + If the URI can be opened, returns an optional sublime.View. + """ if uri.startswith("file:"): - return self._open_file_uri_async(uri, r, flags, group) + return await self._open_file_uri(uri, r, flags, group) # Try to find a pre-existing session-buffer if sb := self.get_session_buffer_for_uri_async(uri): view = sb.get_view_in_group(group) self.window.focus_view(view) if r: center_selection(view, r) - return Promise.resolve(view) + return view if uri.startswith('res:'): - return self._open_res_uri_async(uri, r, group) + return await self._open_res_uri(uri, r, group) + loop = asyncio.get_running_loop() if uri.startswith('untitled:'): # VSCode specific URI scheme for unsaved buffers - flags &= sublime.NewFileFlags.TRANSIENT | sublime.NewFileFlags.ADD_TO_SELECTION - if name := uri[len('untitled:'):]: - # Check if there is a pre-existing unsaved buffer with the given name - for view in self.window.views(): - if view.file_name() is None and view.name() == name: - self.window.focus_view(view) - return Promise.resolve(view) + + def open_untitled_buffer(flags: sublime.NewFileFlags = sublime.NewFileFlags.NONE) -> sublime.View: + flags &= sublime.NewFileFlags.TRANSIENT | sublime.NewFileFlags.ADD_TO_SELECTION + if name := uri[len('untitled:') :]: + # Check if there is a pre-existing unsaved buffer with the given name + for view in self.window.views(): + if view.file_name() is None and view.name() == name: + self.window.focus_view(view) + return view + view = self.window.new_file(flags) + view.set_scratch(True) + view.set_name(name) + return view view = self.window.new_file(flags) view.set_scratch(True) - view.set_name(name) - return Promise.resolve(view) - view = self.window.new_file(flags) - view.set_scratch(True) - return Promise.resolve(view) - # There is no pre-existing session-buffer, so we have to go through the plugin's URI handler. + return view + + return await loop.run_in_executor(executor_main, open_untitled_buffer, flags) + # There is no pre-existing session-buffer, so we have to go through AbstractPlugin.on_open_uri_async. if self._plugin: if isinstance(self._plugin, LspPlugin): scheme, _ = parse_uri(uri) if handler := self._plugin.get_uri_handler(scheme): - return handler(uri, flags).then(lambda sheet: self._on_sheet_opened(sheet, uri, r)) + sheet = await handler(uri, flags) + return self._on_sheet_opened(sheet, uri, r) else: - return self._open_uri_with_plugin_async(self._plugin, uri, r, flags, group) - return None + return await self._open_uri_with_plugin(self._plugin, uri, r, flags, group) + return False - def open_uri_async( + async def open_uri( self, uri: DocumentUri, r: Range | None = None, flags: sublime.NewFileFlags = sublime.NewFileFlags.NONE, group: int = -1 - ) -> Promise[sublime.View | None]: - promise = self.try_open_uri_async(uri, r, flags, group) - return Promise.resolve(None) if promise is None else promise - - def _open_file_uri_async( + ) -> sublime.View | None: + """Open a URI. If the URI can't be opened, raises RuntimeError.""" + result = await self.try_open_uri(uri, r, flags, group) + if result is False: + raise RuntimeError(f"unable to open URI {uri}") + return result + + async def _open_file_uri( self, uri: DocumentUri, r: Range | None = None, flags: sublime.NewFileFlags = sublime.NewFileFlags.NONE, group: int = -1 - ) -> Promise[sublime.View | None]: - result: PackagedTask[sublime.View | None] = Promise.packaged_task() + ) -> sublime.View | None: + view = await open_file(self.window, uri, flags, group) + if view and r: + center_selection(view, r) + return view - def handle_continuation(view: sublime.View | None) -> None: - if view and r: - center_selection(view, r) - sublime.set_timeout_async(lambda: result[1](view)) - - sublime.set_timeout(lambda: open_file(self.window, uri, flags, group).then(handle_continuation)) - return result[0] - - def _open_res_uri_async( + async def _open_res_uri( self, uri: DocumentUri, r: Range | None = None, group: int = -1 - ) -> Promise[sublime.View | None]: + ) -> sublime.View | None: - def continue_on_main_thread() -> None: + def continue_on_main_thread() -> sublime.View | None: view = open_resource(self.window, uri, group) if view and r: sublime.set_timeout(partial(center_selection, view, r)) - sublime.set_timeout_async(lambda: result[1](view)) + return view - result: PackagedTask[sublime.View | None] = Promise.packaged_task() - sublime.set_timeout(continue_on_main_thread) - return result[0] + return await asyncio.get_running_loop().run_in_executor(executor_main, continue_on_main_thread) - def _open_uri_with_plugin_async( + async def _open_uri_with_plugin( self, plugin: AbstractPlugin, uri: DocumentUri, r: Range | None, flags: sublime.NewFileFlags, group: int, - ) -> Promise[sublime.View | None] | None: + ) -> sublime.View | Literal[False] | None: # I cannot type-hint an unpacked tuple pair: PackagedTask[tuple[str, str, str]] = Promise.packaged_task() promise, resolve = pair # It'd be nice to have automatic tuple unpacking continuations callback = lambda a, b, c: resolve((a or 'untitled', b, c)) # noqa: E731 if plugin.on_open_uri_async(uri, callback): - return promise.then(lambda tup: self.open_scratch_buffer(*tup, flags, group)) \ - .then(lambda view: self._on_sheet_opened(view.sheet(), uri, r)) + title, content, syntax = await promise + if view := await self.open_scratch_buffer(title, content, syntax, flags, group): + return self._on_sheet_opened(view.sheet(), uri, r) + return None # resolve unused promise resolve(('', '', '')) - return None + return False - def open_scratch_buffer( + async def open_scratch_buffer( self, title: str, content: str, syntax: str, flags: sublime.NewFileFlags = sublime.NewFileFlags.NONE, group: int = -1, - ) -> Promise[sublime.View]: - task: PackagedTask[sublime.View] = Promise.packaged_task() - promise, resolve = task + ) -> sublime.View | None: - def continue_on_main_thread() -> None: + def continue_on_main_thread() -> sublime.View | None: if group > -1: self.window.focus_group(group) view = self.window.new_file(syntax=syntax, flags=flags) @@ -1553,10 +1686,9 @@ def continue_on_main_thread() -> None: view.set_name(title) view.run_command("append", {"characters": content}) view.set_read_only(True) - resolve(view) + return view - sublime.set_timeout(continue_on_main_thread) - return promise + return await asyncio.get_running_loop().run_in_executor(executor_main, continue_on_main_thread) def _on_sheet_opened(self, sheet: sublime.Sheet | None, uri: DocumentUri, r: Range | None) -> sublime.View | None: if sheet and (view := sheet.view()): @@ -1567,16 +1699,16 @@ def _on_sheet_opened(self, sheet: sublime.Sheet | None, uri: DocumentUri, r: Ran return view return None - def open_location_async( + async def open_location( self, location: Location | LocationLink, flags: sublime.NewFileFlags = sublime.NewFileFlags.NONE, group: int = -1 - ) -> Promise[sublime.View | None]: + ) -> sublime.View | None: uri, r = get_uri_and_range_from_location(location) - return self.open_uri_async(uri, r, flags, group) + return await self.open_uri(uri, r, flags, group) - def notify_plugin_on_session_buffer_change(self, session_buffer: SessionBufferProtocol) -> None: + def notify_plugin_on_session_buffer_change_async(self, session_buffer: SessionBufferProtocol) -> None: if not self._plugin: return if isinstance(self._plugin, LspPlugin): @@ -1584,9 +1716,9 @@ def notify_plugin_on_session_buffer_change(self, session_buffer: SessionBufferPr else: self._plugin.on_session_buffer_changed_async(session_buffer) - def _maybe_resolve_code_action( + async def _maybe_resolve_code_action( self, code_action: CodeAction, view: sublime.View | None - ) -> Promise[CodeAction | Error]: + ) -> CodeAction: if "edit" not in code_action: has_capability = self.has_capability("codeActionProvider.resolveProvider") if not has_capability and view: @@ -1594,77 +1726,68 @@ def _maybe_resolve_code_action( has_capability = session_view.has_capability_async("codeActionProvider.resolveProvider") if has_capability: # We must first resolve the command and edit properties, because they can potentially be absent. - request = Request("codeAction/resolve", code_action) - return self.send_request_task(request) - return Promise.resolve(code_action) + return await self.request(Request("codeAction/resolve", code_action)) + return code_action - def _apply_code_action_async( - self, code_action: CodeAction | Error | None, view: sublime.View | None - ) -> Promise[None]: + async def _apply_code_action(self, code_action: CodeAction | Error | None, view: sublime.View | None) -> None: if not code_action: - return Promise.resolve(None) + return if isinstance(code_action, Error): - # TODO: our promise must be able to handle exceptions (or, wait until we can use coroutines) + # TODO: do something with the error? self.window.status_message(f"Failed to apply code action: {code_action}") - return Promise.resolve(None) + return title = code_action['title'] edit = code_action.get("edit") is_refactoring = kind_contains_other_kind(CodeActionKind.Refactor, code_action.get('kind', '')) - promise = self.apply_workspace_edit_async(edit, label=title, is_refactoring=is_refactoring) \ - .then(lambda _: None) if edit else Promise.resolve(None) - command = code_action.get("command") - if command is not None: + if edit: + await self.apply_workspace_edit(edit, label=title, is_refactoring=is_refactoring) + if command := code_action.get("command"): execute_command: ExecuteCommandParams = { "command": command["command"], } arguments = command.get("arguments") if arguments is not None: execute_command['arguments'] = arguments - return promise \ - .then(lambda _: self.execute_command(execute_command, progress=False, view=view, - is_refactoring=is_refactoring)) \ - .then(lambda _: None) - return promise + await self.execute_command(execute_command, progress=False, view=view, is_refactoring=is_refactoring) - def apply_document_changes_async( + async def apply_document_changes( self, document_changes: list[TextDocumentEdit | CreateFile | RenameFile | DeleteFile], change_annotations: dict[ChangeAnnotationIdentifier, ChangeAnnotation], *, label: str | None = None, is_refactoring: bool = False - ) -> Promise[ApplyWorkspaceEditResult]: + ) -> ApplyWorkspaceEditResult: active_sheet = self.window.active_sheet() selected_sheets = self.window.selected_sheets() auto_save = userprefs().refactoring_auto_save if is_refactoring else 'never' index = 0 # Assuming 0-based indexing for the ApplyWorkspaceEditResult.faildedChange value - promise = self._apply_document_changes_recursive_async( + result = await self._apply_document_changes_recursive( document_changes, change_annotations, index, label, auto_save) - promise \ - .then(lambda _: self._set_selected_sheets(selected_sheets)) \ - .then(lambda _: self._set_focused_sheet(active_sheet)) - return promise + self._set_selected_sheets(selected_sheets) + self._set_focused_sheet(active_sheet) + return result - def _apply_document_changes_recursive_async( + async def _apply_document_changes_recursive( self, document_changes: list[TextDocumentEdit | CreateFile | RenameFile | DeleteFile], change_annotations: dict[ChangeAnnotationIdentifier, ChangeAnnotation], index: int, label: str | None, auto_save: str - ) -> Promise[ApplyWorkspaceEditResult]: + ) -> ApplyWorkspaceEditResult: - def apply_text_document_edit( + async def apply_text_document_edit( view: sublime.View | None, uri: DocumentUri, edits: list[TextEdit | AnnotatedTextEdit | SnippetTextEdit], version: int | None, view_state_actions: ViewStateActions - ) -> Promise[str | None]: + ) -> str | None: if not view: - return Promise.resolve(f'Failed to open URI {uri}') + return f'Failed to open URI {uri}' if version is not None and version != (change_count := view.change_count()): - return Promise.resolve(f'Document version for URI {uri} is {change_count}, but required {version}') + return f'Document version for URI {uri} is {change_count}, but required {version}' for edit in edits: # Use more specific label for this particular TextDocumentEdit if available if annotation_id := edit.get('annotationId'): @@ -1673,70 +1796,70 @@ def apply_text_document_edit( else: edit_label = label view.run_command('lsp_apply_text_document_edit', {'edits': edits, 'label': edit_label}) - promise = Promise(lambda resolve: sublime.set_timeout_async(lambda: resolve(None))) if view and view_state_actions: - return promise.then(lambda _: self._set_view_state(view_state_actions, view)) # pyright: ignore[reportReturnType] - return promise + await self._set_view_state(view_state_actions, view) + return None - def _continue(failure_reason: str | None) -> Promise[ApplyWorkspaceEditResult]: + async def _continue(failure_reason: str | None) -> ApplyWorkspaceEditResult: if failure_reason: printf(f'Error while applying WorkspaceEdit: {failure_reason}') - return Promise.resolve({ + return { 'applied': False, 'failureReason': failure_reason, 'failedChange': index - }) - return self._apply_document_changes_recursive_async( + } + return await self._apply_document_changes_recursive( document_changes, change_annotations, index + 1, label, auto_save) try: document_change = document_changes.pop(0) except IndexError: # All document changes were handled - return Promise.resolve({'applied': True}) + return {'applied': True} if is_text_document_edit(document_change): text_document = document_change['textDocument'] uri = text_document['uri'] version = text_document['version'] view_state_actions = self._get_view_state_actions(uri, auto_save) - return self.open_uri_async(uri).then( - lambda view: apply_text_document_edit(view, uri, document_change['edits'], version, view_state_actions) - ).then(_continue) + view = await self.open_uri(uri) + failure_reason = await apply_text_document_edit( + view, uri, document_change['edits'], version, view_state_actions) + return await _continue(failure_reason) if is_create_file(document_change): # TODO: add support for ResourceOperationKind.Create - return Promise.resolve({ + return { 'applied': False, 'failureReason': 'CreateFile not yet supported by client', 'failedChange': index - }) + } if is_rename_file(document_change): # TODO: add support for ResourceOperationKind.Rename - return Promise.resolve({ + return { 'applied': False, 'failureReason': 'RenameFile not yet supported by client', 'failedChange': index - }) + } if is_delete_file(document_change): # TODO: add support for ResourceOperationKind.Delete - return Promise.resolve({ + return { 'applied': False, 'failureReason': 'DeleteFile not yet supported by client', 'failedChange': index - }) + } # Should be unreachable, but must return value on all code paths to satisfy type checker - return Promise.resolve({ + return { 'applied': False, 'failureReason': 'Unknown document change type', 'failedChange': index - }) + } - def apply_workspace_edit_async( + async def apply_workspace_edit( self, edit: WorkspaceEdit, *, label: str | None = None, is_refactoring: bool = False - ) -> Promise[tuple[ApplyWorkspaceEditResult, WorkspaceEditSummary]]: + ) -> tuple[ApplyWorkspaceEditResult, WorkspaceEditSummary]: """ - Apply a WorkspaceEdit, and return a promise that resolves on the async thread again after the edits have been - applied. The resolved promise contains the ApplyWorkspaceEditResult and a summary of the changes in the - WorkspaceEdit. + Apply a WorkspaceEdit. + + Returns the ApplyWorkspaceEditResult and a summary of the changes in the WorkspaceEdit. """ document_changes = edit.get('documentChanges', []) if not document_changes: @@ -1762,12 +1885,31 @@ def apply_workspace_edit_async( summary['renamed_files'] += 1 elif is_delete_file(document_change): summary['deleted_files'] += 1 - return self.apply_document_changes_async( - document_changes, - change_annotations, - label=label, - is_refactoring=is_refactoring or self._is_executing_refactoring_command - ).then(lambda result: (result, summary)) + return ( + await self.apply_document_changes( + document_changes, + change_annotations, + label=label, + is_refactoring=is_refactoring or self._is_executing_refactoring_command + ), + summary + ) + + @deprecated("use Session.apply_workspace_edit instead") + def apply_workspace_edit_async( + self, edit: WorkspaceEdit, *, label: str | None = None, is_refactoring: bool = False + ) -> Promise[tuple[ApplyWorkspaceEditResult, WorkspaceEditSummary]]: + + def ignore_exception( + x: tuple[ApplyWorkspaceEditResult, WorkspaceEditSummary] | BaseException, + ) -> tuple[ApplyWorkspaceEditResult, WorkspaceEditSummary]: + if isinstance(x, BaseException): + return cast('ApplyWorkspaceEditResult', {}), cast('WorkspaceEditSummary', {}) + return x + + return Promise.wrap_task( + self.create_task(self.apply_workspace_edit(edit, label=label, is_refactoring=is_refactoring)) + ).then(ignore_exception) def _get_view_state_actions(self, uri: DocumentUri, auto_save: str) -> ViewStateActions: """ @@ -1866,9 +2008,9 @@ def do_workspace_diagnostics_async(self) -> None: # The server is probably leaving the request open intentionally, in order to continuously stream updates # via $/progress notifications. continue - self._do_workspace_diagnostics_async(identifier) + self.create_task(self._do_workspace_diagnostics(identifier)) - def _do_workspace_diagnostics_async(self, identifier: DiagnosticsIdentifier) -> None: + async def _do_workspace_diagnostics(self, identifier: DiagnosticsIdentifier) -> None: previous_result_ids: list[PreviousResultId] = [ {'uri': uri, 'value': result_id} for (uri, id_), result_id in self.diagnostics_result_ids.items() if id_ == identifier and result_id is not None @@ -1876,47 +2018,41 @@ def _do_workspace_diagnostics_async(self, identifier: DiagnosticsIdentifier) -> params: WorkspaceDiagnosticParams = {'previousResultIds': previous_result_ids} if identifier is not None: params['identifier'] = identifier - self.workspace_diagnostics_pending_responses[identifier] = self.send_request_async( - Request.workspaceDiagnostic( - params, - on_partial_result=partial(self._on_workspace_diagnostics_async, identifier, reset_pending_response=False)), # noqa: E501 - partial(self._on_workspace_diagnostics_async, identifier), - partial(self._on_workspace_diagnostics_error_async, identifier) - ) - def _on_workspace_diagnostics_async( - self, - identifier: DiagnosticsIdentifier, - response: WorkspaceDiagnosticReport, - *, - reset_pending_response: bool = True - ) -> None: - if reset_pending_response: + self.workspace_diagnostics_pending_responses[identifier] = inflight_request = self.stream( + Request.workspaceDiagnostic(params) + ) + try: + async with aclosing(inflight_request) as stream: + async for partial_response in stream: + for diagnostic_report in partial_response['items']: + uri = normalize_uri(diagnostic_report['uri']) + version = diagnostic_report['version'] + # Skip if outdated + if ( + isinstance(version, int) + and (session_buffer := self.get_session_buffer_for_uri_async(uri)) + and version < session_buffer.last_synced_version + ): + continue + self.diagnostics_result_ids[(uri, identifier)] = diagnostic_report.get('resultId') + if is_workspace_full_document_diagnostic_report(diagnostic_report): + self.handle_diagnostics_async(uri, identifier, version, diagnostic_report['items']) + self.workspace_diagnostics_pending_responses[identifier] = None + except Error as e: + if e.code == LSPErrorCodes.ServerCancelled: + if is_diagnostic_server_cancellation_data(e.data) and e.data['retriggerRequest']: + # Retrigger the request after a short delay, but don't reset the pending response variable for this + # moment, to prevent new requests of this type in the meanwhile. The delay is used in order to + # prevent infinite cycles of cancel -> retrigger, in case the server is busy. + + async def retry_later() -> None: + await asyncio.sleep(WORKSPACE_DIAGNOSTICS_RETRIGGER_DELAY / 1000.0) + await self._do_workspace_diagnostics(identifier) + + self.create_task(retry_later()) + return self.workspace_diagnostics_pending_responses[identifier] = None - for diagnostic_report in response['items']: - uri = normalize_uri(diagnostic_report['uri']) - version = diagnostic_report['version'] - # Skip if outdated - if isinstance(version, int) and (session_buffer := self.get_session_buffer_for_uri_async(uri)) and \ - version < session_buffer.last_synced_version: - continue - self.diagnostics_result_ids[(uri, identifier)] = diagnostic_report.get('resultId') - if is_workspace_full_document_diagnostic_report(diagnostic_report): - self.handle_diagnostics_async(uri, identifier, version, diagnostic_report['items']) - - def _on_workspace_diagnostics_error_async(self, identifier: DiagnosticsIdentifier, error: ResponseError) -> None: - if error['code'] == LSPErrorCodes.ServerCancelled: - data = error.get('data') - if is_diagnostic_server_cancellation_data(data) and data['retriggerRequest']: - # Retrigger the request after a short delay, but don't reset the pending response variable for this - # moment, to prevent new requests of this type in the meanwhile. The delay is used in order to prevent - # infinite cycles of cancel -> retrigger, in case the server is busy. - sublime.set_timeout_async( - lambda: self._do_workspace_diagnostics_async(identifier), - WORKSPACE_DIAGNOSTICS_RETRIGGER_DELAY - ) - return - self.workspace_diagnostics_pending_responses[identifier] = None # --- workspace/didChangeConfiguration ----------------------------------------------------------------------------- @@ -1931,10 +2067,10 @@ def on_server_settings_changed(self, settings: DottedDict) -> None: # --- server request handlers -------------------------------------------------------------------------------------- @request_handler('window/showMessageRequest') - def on_window_show_message_request(self, params: ShowMessageRequestParams) -> Promise[MessageActionItem | None]: + async def on_window_show_message_request(self, params: ShowMessageRequestParams) -> MessageActionItem | None: if mgr := self.manager(): - return mgr.handle_message_request(self.config.name, params) - return Promise.resolve(None) + return await mgr.handle_message_request(self.config.name, params) + return None @notification_handler('window/showMessage') def on_window_show_message(self, params: ShowMessageParams) -> None: @@ -1947,11 +2083,11 @@ def on_window_log_message(self, params: LogMessageParams) -> None: mgr.handle_log_message(self.config.name, params) @request_handler('workspace/workspaceFolders') - def on_workspace_workspace_folders(self, _: None) -> Promise[list[LspWorkspaceFolder]]: - return Promise.resolve([wf.to_lsp() for wf in self._workspace_folders]) + async def on_workspace_workspace_folders(self, _: None) -> list[LspWorkspaceFolder]: + return [wf.to_lsp() for wf in self._workspace_folders] @request_handler('workspace/configuration') - def on_workspace_configuration(self, params: ConfigurationParams) -> Promise[list[LSPAny]]: + async def on_workspace_configuration(self, params: ConfigurationParams) -> list[LSPAny]: items: list[LSPAny] = [] requested_items = params.get("items") or [] for requested_item in requested_items: @@ -1960,17 +2096,20 @@ def on_workspace_configuration(self, params: ConfigurationParams) -> Promise[lis items.append(self._plugin.on_workspace_configuration(requested_item, configuration)) else: items.append(configuration) - return Promise.resolve(sublime.expand_variables(items, self._variables)) + return sublime.expand_variables(items, self._variables) @request_handler('workspace/applyEdit') - def on_workspace_apply_edit(self, params: ApplyWorkspaceEditParams) -> Promise[ApplyWorkspaceEditResult]: - is_refactoring = metadata.get('isRefactoring', False) if (metadata := params.get('metadata')) else False - return self.apply_workspace_edit_async( - params['edit'], label=params.get('label'), is_refactoring=is_refactoring - ).then(itemgetter(0)) + async def on_workspace_apply_edit(self, params: ApplyWorkspaceEditParams) -> ApplyWorkspaceEditResult: + return ( + await self.apply_workspace_edit( + params['edit'], + label=params.get('label'), + is_refactoring=metadata.get('isRefactoring', False) if (metadata := params.get('metadata')) else False, + ) + )[0] @request_handler('workspace/codeLens/refresh') - def on_workspace_code_lens_refresh(self, _: None) -> Promise[None]: + async def on_workspace_code_lens_refresh(self, _: None) -> None: def continue_after_response() -> None: visible_session_buffers, not_visible_session_buffers = self.session_buffers_by_visibility() @@ -1979,11 +2118,10 @@ def continue_after_response() -> None: for session_buffer in not_visible_session_buffers: session_buffer.set_pending_refresh(RequestFlags.CODE_LENS) - sublime.set_timeout_async(continue_after_response) - return Promise.resolve(None) + asyncio.get_running_loop().call_soon(continue_after_response) @request_handler('workspace/semanticTokens/refresh') - def on_workspace_semantic_tokens_refresh(self, _: None) -> Promise[None]: + async def on_workspace_semantic_tokens_refresh(self, _: None) -> None: def continue_after_response() -> None: visible_session_buffers, not_visible_session_buffers = self.session_buffers_by_visibility() @@ -1995,11 +2133,10 @@ def continue_after_response() -> None: for session_buffer in not_visible_session_buffers: session_buffer.set_pending_refresh(RequestFlags.SEMANTIC_TOKENS) - sublime.set_timeout_async(continue_after_response) - return Promise.resolve(None) + asyncio.get_running_loop().call_soon(continue_after_response) @request_handler('workspace/inlayHint/refresh') - def on_workspace_inlay_hint_refresh(self, _: None) -> Promise[None]: + async def on_workspace_inlay_hint_refresh(self, _: None) -> None: def continue_after_response() -> None: visible_session_buffers, not_visible_session_buffers = self.session_buffers_by_visibility() @@ -2011,13 +2148,11 @@ def continue_after_response() -> None: for session_buffer in not_visible_session_buffers: session_buffer.set_pending_refresh(RequestFlags.INLAY_HINT) - sublime.set_timeout_async(continue_after_response) - return Promise.resolve(None) + asyncio.get_running_loop().call_soon(continue_after_response) @request_handler('workspace/diagnostic/refresh') - def on_workspace_diagnostic_refresh(self, _: None) -> Promise[None]: - sublime.set_timeout_async(self._refresh_diagnostics) - return Promise.resolve(None) + async def on_workspace_diagnostic_refresh(self, _: None) -> None: + self._refresh_diagnostics() def _refresh_diagnostics(self) -> None: visible_session_buffers, not_visible_session_buffers = self.session_buffers_by_visibility() @@ -2052,6 +2187,8 @@ def clear_diagnostics_for_uri(self, uri: DocumentUri) -> None: if mgr := self.manager(): mgr.on_diagnostics_updated() + # Keep this request handler as backwards-compatible method that returns a Promise, to ensure Promises keep working + # for now. @request_handler('client/registerCapability') def on_client_register_capability(self, params: RegistrationParams) -> Promise[None]: new_diagnostics_provider = False @@ -2088,7 +2225,7 @@ def on_client_register_capability(self, params: RegistrationParams) -> Promise[N inform = partial(sv.on_capability_added_async, registration_id, capability_path, options) # Inform only after the response is sent, otherwise we might start doing requests for capabilities # which are technically not yet done registering. - sublime.set_timeout_async(inform) + asyncio.get_running_loop().call_soon(inform) if capability_path == "didChangeWatchedFilesProvider": capability_options = cast('DidChangeWatchedFilesRegistrationOptions', options) self.register_file_system_watchers(registration_id, capability_options['watchers']) @@ -2099,11 +2236,11 @@ def continue_after_response() -> None: if new_workspace_diagnostics_provider: self.do_workspace_diagnostics_async() - sublime.set_timeout_async(continue_after_response) + asyncio.get_running_loop().call_soon(continue_after_response) return Promise.resolve(None) @request_handler('client/unregisterCapability') - def on_client_unregister_capability(self, params: UnregistrationParams) -> Promise[None]: + async def on_client_unregister_capability(self, params: UnregistrationParams) -> None: unregistrations = params["unregisterations"] # typo in the official specification for unregistration in unregistrations: registration_id = unregistration["id"] @@ -2122,7 +2259,6 @@ def on_client_unregister_capability(self, params: UnregistrationParams) -> Promi if isinstance(discarded, dict): for sv in self.session_views_async(): sv.on_capability_removed_async(registration_id, discarded) - return Promise.resolve(None) def register_file_system_watchers(self, registration_id: str, watchers: list[FileSystemWatcher]) -> None: if not self._watcher_impl: @@ -2154,7 +2290,7 @@ def unregister_file_system_watchers(self, registration_id: str) -> None: file_watcher.destroy() @request_handler('window/showDocument') - def on_window_show_document(self, params: ShowDocumentParams) -> Promise[ShowDocumentResult]: + async def on_window_show_document(self, params: ShowDocumentParams) -> ShowDocumentResult: uri = params.get("uri") def success(b: bool | sublime.View | None) -> ShowDocumentResult: @@ -2167,14 +2303,13 @@ def success(b: bool | sublime.View | None) -> ShowDocumentResult: return ({"success": b}) if params.get("external"): - return Promise.resolve(success(open_externally(uri))) + return success(open_externally(uri)) # TODO: ST API does not allow us to say "do not focus this new view" - return self.open_uri_async(uri, params.get("selection")).then(success) + return success(await self.try_open_uri(uri, params.get("selection"))) @request_handler('window/workDoneProgress/create') - def on_window_work_done_progress_create(self, params: WorkDoneProgressCreateParams) -> Promise[None]: + async def on_window_work_done_progress_create(self, params: WorkDoneProgressCreateParams) -> None: self._progress[params['token']] = None - return Promise.resolve(None) def _invoke_views(self, request: Request[Any, Any], method: str, *args: Any) -> None: if request.view: @@ -2222,7 +2357,7 @@ def on_progress(self, params: ProgressParams) -> None: token = str(token) request_id = int(token[len(_WORK_DONE_PROGRESS_PREFIX):]) request = self._response_handlers[request_id][0] - self._invoke_views(request, "on_request_progress", request_id, params) + lambda: self._invoke_views(request, "on_request_progress", request_id, params) except (TypeError, IndexError, ValueError, KeyError): # The parse failed so possibility (1) is apparently not applicable. At this point we may still be # dealing with possibility (2). @@ -2253,16 +2388,16 @@ def on_progress(self, params: ProgressParams) -> None: # --- shutdown dance ----------------------------------------------------------------------------------------------- - def end_async(self) -> None: - # TODO: Ensure this function is called only from the async thread + async def end(self) -> list[Exception]: if self.exiting: - return + return [] self.exiting = True if self._plugin: self._plugin.on_session_end_async(None, None) self._plugin = None - for sv in self.session_views_async(): - self.shutdown_session_view_async(sv) + exceptions = await gather_and_flatten_exceptions( + *(self.shutdown_session_view(sv) for sv in self.session_views_async()) + ) self.capabilities.clear() self._registrations.clear() for watcher in self._static_file_watchers: @@ -2272,18 +2407,21 @@ def end_async(self) -> None: watcher.destroy() self._dynamic_file_watchers = {} self.state = ClientStates.STOPPING - self.send_request_async(Request.shutdown(), self._handle_shutdown_result, self._handle_shutdown_result) - - def shutdown_session_view_async(self, session_view: SessionViewProtocol) -> None: + exceptions.extend(await self.cancel_all_tasks()) + try: + await self.request(Request.shutdown()) + except Exception as shutdown_exception: + exceptions.append(shutdown_exception) + finally: + await self.exit() + return exceptions + + async def shutdown_session_view(self, session_view: SessionViewProtocol) -> list[Exception]: for status_key in self._status_messages: session_view.view.erase_status(status_key) - session_view.shutdown_async() + return await session_view.shutdown() - def _handle_shutdown_result(self, _: Any) -> None: - self._progress.clear() - self.exit() - - def on_transport_close(self, exit_code: int, exception: Exception | None) -> None: + async def on_transport_close(self, exit_code: int, exception: Exception | None) -> None: self.exiting = True self.state = ClientStates.STOPPING self.transport = None @@ -2291,58 +2429,126 @@ def on_transport_close(self, exit_code: int, exception: Exception | None) -> Non if self._plugin: self._plugin.on_session_end_async(exit_code, exception) self._plugin = None - if self._initialize_error: - # Override potential exit error with a saved one. - exit_code, exception = self._initialize_error if mgr := self.manager(): - if self._init_callback: - self._init_callback(self, True) - self._init_callback = None - mgr.on_post_exit_async(self, exit_code, exception) + await mgr.on_post_exit(self, exit_code, exception) # --- RPC message handling ---------------------------------------------------------------------------------------- - def send_request_async( - self, - request: Request[P, R], - on_result: Callable[[R], None], - on_error: Callable[[ResponseError], None] | None = None - ) -> int: - """You must call this method from Sublime's worker thread. Callbacks will run in Sublime's worker thread.""" + def request(self, r: Request[P, R]) -> CancellableInflightRequest[R]: + """ + Make a request to the language server. + + You must call this method from the asyncio thread. + + ```py + try: + result = await session.request(Request(...)) + print(result) + except Error as error: + print(error.code) + ``` + """ self.request_id += 1 request_id = self.request_id - if request.progress and isinstance(request.params, dict): - request.params["workDoneToken"] = _WORK_DONE_PROGRESS_PREFIX + str(request_id) - if request.on_partial_result and isinstance(request.params, dict): - request.params["partialResultToken"] = _PARTIAL_RESULT_PROGRESS_PREFIX + str(request_id) - on_error = on_error or (lambda _: None) - self._response_handlers[request_id] = (request, on_result, on_error) - self._invoke_views(request, "on_request_started_async", request_id, request) + loop = asyncio.get_running_loop() + future = loop.create_future() + result = CancellableInflightRequest(future, request_id, self) + if r.progress and isinstance(r.params, dict): + r.params["workDoneToken"] = _WORK_DONE_PROGRESS_PREFIX + str(request_id) + if r.on_partial_result and isinstance(r.params, dict): + r.params["partialResultToken"] = _PARTIAL_RESULT_PROGRESS_PREFIX + str(request_id) + self._response_handlers[request_id] = (r, future.set_result, lambda x: future.set_exception(Error.from_lsp(x))) + self._invoke_views(r, "on_request_started_async", request_id, r) if self._plugin and isinstance(self._plugin, AbstractPlugin): - self._plugin.on_pre_send_request_async(request_id, request) + self._plugin.on_pre_send_request_async(request_id, r) elif self._plugin: - client_request = cast('ClientRequest', cast('object', {'method': request.method, 'params': request.params})) - self._plugin.on_pre_send_request_async(client_request, request.view) - request.params = cast('P', client_request['params']) - self._logger.outgoing_request(request_id, request.method, request.params) - self.send_payload(request.to_payload(request_id)) - return request_id + client_request = cast('ClientRequest', cast('object', {'method': r.method, 'params': r.params})) + self._plugin.on_pre_send_request_async(client_request, r.view) + r.params = cast('P', client_request['params']) + self._logger.outgoing_request(request_id, r.method, r.params) + self.create_task(self.send_payload(r.to_payload(request_id))) + return result + + def stream(self, r: Request[P, R]) -> CancellableInflightStreamingRequest[R]: + """ + Stream partial results from the language server. + + You must call this method from the asyncio thread. + Use in combination with `async for` syntax: + + ```py + try: + async .core.aio.aclosing(session.stream(Request(...))) as stream: + async for partial_result in stream: + print(partial_result) + except Error as error: + print(error.code) + ``` + """ + self.request_id += 1 + request_id = self.request_id + result = CancellableInflightStreamingRequest(request_id, self) + if not isinstance(r.params, dict): + raise TypeError("request should have dict params") + if r.progress: + r.params["workDoneToken"] = _WORK_DONE_PROGRESS_PREFIX + str(request_id) + r.params["partialResultToken"] = _PARTIAL_RESULT_PROGRESS_PREFIX + str(request_id) + r.on_partial_result = result.on_partial_result + self._response_handlers[request_id] = (r, result.on_partial_result, result.on_error) + self._invoke_views(r, "on_request_started_async", request_id, r) + if self._plugin and isinstance(self._plugin, AbstractPlugin): + self._plugin.on_pre_send_request_async(request_id, r) + elif self._plugin: + client_request = cast('ClientRequest', cast('object', {'method': r.method, 'params': r.params})) + self._plugin.on_pre_send_request_async(client_request, r.view) + r.params = cast('P', client_request['params']) + self._logger.outgoing_request(request_id, r.method, r.params) + self.create_task(self.send_payload(r.to_payload(request_id))) + return result + + @deprecated("use Session.request or Session.stream instead") + def send_request_async( + self, + request: Request[P, R], + on_result: Callable[[R], None], + on_error: Callable[[ResponseError], None] | None = None + ) -> int: + """You must call this method from the asyncio loop thread. Callbacks will run in the asyncio thread.""" + result = self.request(request) + + def on_done(future: asyncio.Future[R]) -> None: + if future.cancelled(): + return + if ex := future.exception(): + if callable(on_error) and isinstance(ex, Error): + on_error(ex.to_lsp()) + return + exception_log("Response error is ignored", ex) + return + on_result(future.result()) + + result._future.add_done_callback(on_done) + return result.id + + @deprecated("use Session.request or Session.stream instead") def send_request( - self, - request: Request[P, R], - on_result: Callable[[R], None], - on_error: Callable[[ResponseError], None] | None = None, + self, + request: Request[P, R], + on_result: Callable[[R], None], + on_error: Callable[[ResponseError], None] | None = None, ) -> None: - """You can call this method from any thread. Callbacks will run in Sublime's worker thread.""" - sublime.set_timeout_async(partial(self.send_request_async, request, on_result, on_error)) + """You can call this method from any thread. Callbacks will run in the asyncio thread.""" + call_soon_threadsafe(lambda: self.send_request_async(request, on_result, on_error)) + @deprecated("use Session.request or Session.stream instead") def send_request_task(self, request: Request[P, R]) -> Promise[R | Error]: task: PackagedTask[Any] = Promise.packaged_task() promise, resolver = task - self.send_request_async(request, resolver, lambda x: resolver(Error.from_lsp(x))) + self.send_request(request, resolver, lambda x: resolver(Error.from_lsp(x))) return promise + @deprecated("use Session.request or Session.stream instead") def send_request_task_2(self, request: Request[P, R]) -> tuple[Promise[R | Error], int]: task: PackagedTask[R | Error] = Promise.packaged_task() promise, resolver = task @@ -2357,7 +2563,8 @@ def cancel_request_async(self, request_id: int) -> None: self._invoke_views(request, "on_request_canceled_async", request_id) self._response_handlers[request_id] = (request, lambda *args: None, lambda *args: None) - def send_notification(self, notification: Notification[P]) -> None: + async def notify(self, notification: Notification[P]) -> None: + """Send a notification to the server.""" if self._plugin and isinstance(self._plugin, AbstractPlugin): self._plugin.on_pre_send_notification_async(notification) elif self._plugin: @@ -2366,29 +2573,37 @@ def send_notification(self, notification: Notification[P]) -> None: self._plugin.on_pre_send_notification_async(client_notification) notification.params = cast('P', client_notification['params']) self._logger.outgoing_notification(notification.method, notification.params) - self.send_payload(notification.to_payload()) + await self.send_payload(notification.to_payload()) + + def send_notification_async(self, notification: Notification[P]) -> None: + """Send a notification to the server. Not thread safe. Must be called from the asyncio thread.""" + self.create_task(self.notify(notification)) + + def send_notification(self, notification: Notification[P]) -> None: + """Send a notification to the server. Thread safe. Can be called from any thread.""" + self.create_task_threadsafe(self.notify(notification)) - def send_response(self, response: Response[P]) -> None: + async def send_response(self, response: Response[P]) -> None: self._logger.outgoing_response(response.request_id, response.result) - self.send_payload(response.to_payload()) + await self.send_payload(response.to_payload()) - def send_error_response(self, request_id: int | str, error: Error) -> None: + async def send_error_response(self, request_id: int | str, error: Error) -> None: self._logger.outgoing_error_response(request_id, error) - self.send_payload({'jsonrpc': '2.0', 'id': request_id, 'error': error.to_lsp()}) + await self.send_payload({'jsonrpc': '2.0', 'id': request_id, 'error': error.to_lsp()}) - def exit(self) -> None: - self.send_notification(Notification.exit()) + async def exit(self) -> None: + await self.notify(Notification.exit()) if self.transport: - self.transport.close() + await self.transport.close() self.transport = None - def send_payload(self, payload: JSONRPCMessage) -> None: + async def send_payload(self, payload: JSONRPCMessage) -> None: try: - self.transport.send(payload) # pyright: ignore[reportOptionalMemberAccess] + await self.transport.send(payload) # pyright: ignore[reportOptionalMemberAccess] except AttributeError: pass - def deduce_payload( + async def deduce_payload( self, payload: JSONRPCMessage ) -> tuple[Callable | None, Any, str | int | None, str | None, str | None]: @@ -2400,7 +2615,7 @@ def deduce_payload( req_id = payload["id"] self._logger.incoming_request(req_id, method, result) if handler is None: - self.send_error_response(req_id, Error(ErrorCodes.MethodNotFound, method)) + await self.send_error_response(req_id, Error(ErrorCodes.MethodNotFound, method)) else: return (handler, result, req_id, "request", method) else: @@ -2434,31 +2649,30 @@ def deduce_payload( debug("Unknown payload type: ", payload) # pyright: ignore[reportUnreachable] return (None, None, None, None, None) - def on_payload(self, payload: JSONRPCMessage) -> None: - handler, result, req_id, typestr, method = self.deduce_payload(payload) + async def on_payload(self, payload: JSONRPCMessage) -> None: + handler, result, req_id, typestr, method = await self.deduce_payload(payload) if handler: - result_promise: Promise[Response[Any]] | None = None try: if req_id is None: - # notification or response + # server notification or (response to) client request handler(result) else: - # request + # server request try: - result_promise = cast('Promise[Response[Any]] | None', handler(result, req_id)) + await self.send_response( + self._handle_plugin_on_pre_send_response_async( + method, result, await handler(result, req_id) + ) + ) except Error as err: - self.send_error_response(req_id, err) - return + await self.send_error_response(req_id, err) except Exception as ex: - self.send_error_response(req_id, Error.from_exception(ex)) + await self.send_error_response(req_id, Error.from_exception(ex)) raise except Exception as err: exception_log(f"Error handling {typestr}", err) - return - if isinstance(result_promise, Promise): - result_promise \ - .then(lambda r: self._handle_plugin_on_pre_send_response_async(method, result, r)) \ - .then(self.send_response) + else: + debug("no handler found for payload:", payload) def _handle_plugin_on_pre_send_response_async( self, method: str | None, params: Any, response: Response[Any] @@ -2471,7 +2685,7 @@ def _handle_plugin_on_pre_send_response_async( def response_handler( self, response_id: str | int, response: JSONRPCMessage ) -> tuple[Callable[[ResponseError], None], str | None, Any, bool]: - matching_handler = self._response_handlers.pop(response_id) + matching_handler = self._response_handlers.pop(response_id, None) if not matching_handler: error = {"code": ErrorCodes.InvalidParams, "message": f"unknown response ID {response_id}"} return (print_to_status_bar, None, error, True) diff --git a/plugin/core/signature_help.py b/plugin/core/signature_help.py index 9454b5842..426724eb1 100644 --- a/plugin/core/signature_help.py +++ b/plugin/core/signature_help.py @@ -3,6 +3,7 @@ from ...protocol import SignatureHelp from ...protocol import SignatureHelpTriggerKind from ...protocol import SignatureInformation +from .aio import call_soon_threadsafe from .logging import debug from .registry import LspTextCommand from .views import FORMAT_MARKUP_CONTENT @@ -13,10 +14,10 @@ from typing import TypedDict import html import re -import sublime if TYPE_CHECKING: from .constants import MarkdownLangMap + import sublime class SignatureHelpStyle(TypedDict): @@ -45,7 +46,7 @@ def want_event(self) -> bool: def run(self, _: sublime.Edit) -> None: if listener := self.get_listener(): - sublime.set_timeout_async(lambda: listener.do_signature_help_async(SignatureHelpTriggerKind.Invoked)) + call_soon_threadsafe(listener.do_signature_help_async, SignatureHelpTriggerKind.Invoked) class SigHelp: diff --git a/plugin/core/transports.py b/plugin/core/transports.py index 40b11c416..ce0d06e45 100644 --- a/plugin/core/transports.py +++ b/plugin/core/transports.py @@ -3,34 +3,27 @@ from .constants import ST_PLATFORM from .logging import debug from .logging import exception_log -from .promise import PackagedTask -from .promise import Promise from abc import ABC from abc import abstractmethod -from contextlib import closing -from functools import partial -from queue import Queue from typing import Any from typing import Callable from typing import final -from typing import IO from typing import TYPE_CHECKING from typing_extensions import override +import asyncio +import asyncio.subprocess import contextlib -import http.client import json import os import shutil import socket import sublime import subprocess -import threading import time import weakref if TYPE_CHECKING: from .protocol import JSONRPCMessage - from io import BufferedIOBase try: import orjson @@ -48,7 +41,7 @@ class StopLoopError(Exception): class TransportConfig(ABC): - """The object that does the actual RPC communication.""" + """Config object that can start the transport.""" @staticmethod def resolve_launch_config( @@ -56,6 +49,10 @@ def resolve_launch_config( env: dict[str, str] | None, variables: dict[str, str], ) -> LaunchConfig: + """ + Given the state of this transport configuration, and the provided command/env/vars, create a small object + that has resolved all variables to a concrete command to run. + """ command = sublime.expand_variables(command, variables) command = [os.path.expanduser(arg) for arg in command] resolved_env = os.environ.copy() @@ -68,7 +65,7 @@ def resolve_launch_config( return LaunchConfig(command, resolved_env) @abstractmethod - def start( + async def start( self, command: list[str] | None, env: dict[str, str] | None, @@ -76,6 +73,7 @@ def start( variables: dict[str, str], callbacks: TransportCallbacks, ) -> TransportWrapper: + """Start a communication channel with the language server.""" raise NotImplementedError @@ -87,7 +85,7 @@ class StdioTransportConfig(TransportConfig): """ @override - def start( + async def start( self, command: list[str] | None, env: dict[str, str] | None, @@ -97,7 +95,8 @@ def start( ) -> TransportWrapper: if not command: raise RuntimeError('missing "command" to start a child process for running the language server') - process = TransportConfig.resolve_launch_config(command, env, variables).start( + launch = TransportConfig.resolve_launch_config(command, env, variables) + process = await launch.start( cwd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, @@ -107,8 +106,9 @@ def start( raise Exception('Failed to create transport config due to not being able to pipe stdio') return TransportWrapper( callback_object=callbacks, - transport=FileObjectTransport(encode_json, decode_json, process.stdout, process.stdin), + transport=StreamTransport(encode_json, decode_json, process.stdout, process.stdin), process=process, + process_args=launch.command, error_reader=ErrorReader(callbacks, process.stderr), ) @@ -129,7 +129,7 @@ def __init__(self, port: int | None) -> None: raise RuntimeError("invalid port number") @override - def start( + async def start( self, command: list[str] | None, env: dict[str, str] | None, @@ -138,12 +138,14 @@ def start( callbacks: TransportCallbacks, ) -> TransportWrapper: port = _add_and_resolve_port_variable(variables, self._port) + launch: LaunchConfig | None = None if command: - process = TransportConfig.resolve_launch_config(command, env, variables).start( + launch = TransportConfig.resolve_launch_config(command, env, variables) + process = await launch.start( cwd, - stdout=subprocess.PIPE, - stdin=subprocess.DEVNULL, - stderr=subprocess.STDOUT, + stdout=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.STDOUT, ) if not process.stdout: raise Exception('Failed to create transport config due to not being able to pipe stdout') @@ -151,27 +153,37 @@ def start( else: process = None error_reader = None - return TransportWrapper( - callback_object=callbacks, - transport=SocketTransport(encode_json, decode_json, self._connect(port)), - process=process, - error_reader=error_reader, - ) - - def _connect(self, port: int) -> socket.socket: start_time = time.time() - while time.time() - start_time < TCP_CONNECT_TIMEOUT: + current_time = start_time + delta = 0 + while delta < TCP_CONNECT_TIMEOUT: + time_left = TCP_CONNECT_TIMEOUT - delta try: - return socket.create_connection(('localhost', port)) + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host='127.0.0.1', port=port), timeout=time_left + ) + return TransportWrapper( + callback_object=callbacks, + transport=StreamTransport(encode_json, decode_json, reader, writer), + process=process, + process_args=launch.command if launch else None, + error_reader=error_reader, + ) except ConnectionRefusedError: - pass - raise RuntimeError("failed to connect") + # Can happen when the language server is still starting. Just wait a bit and retry. + await asyncio.sleep(TCP_CONNECT_TIMEOUT / 10) + except TimeoutError: + # We passed the TCP_CONNECT_TIMEOUT and the process didn't respond. + break + current_time = time.time() + delta = current_time - start_time + raise RuntimeError(f"Failed to connect to TCP port {port}") class TcpServerTransportConfig(TransportConfig): """ Transport for communicating to a language server over TCP. The difference, however, is that this transport will - start a TCP listener socket accepting new TCP cliet connections. Once a client connects to this text editor acting + start a TCP listener socket accepting new TCP client connections. Once a client connects to this text editor acting as the TCP server, we'll assume it's the language server we just launched. As such, this tranport requires a "command" for starting the language server subprocess. """ @@ -182,7 +194,7 @@ def __init__(self, port: int | None) -> None: raise RuntimeError("invalid port number") @override - def start( + async def start( self, command: list[str] | None, env: dict[str, str] | None, @@ -194,142 +206,143 @@ def start( raise RuntimeError('missing "command" to start a child process for running the language server') port = _add_and_resolve_port_variable(variables, self._port) launch = TransportConfig.resolve_launch_config(command, env, variables) - listener_socket = socket.socket() - listener_socket.bind(('localhost', port)) - listener_socket.settimeout(TCP_CONNECT_TIMEOUT) - listener_socket.listen(1) - process_task: PackagedTask[subprocess.Popen[bytes] | None] = Promise.packaged_task() - process_promise, resolve_process = process_task - - # We need to be able to start the process while also awaiting a client connection. - def start_in_background() -> None: - # Sleep for one second, because the listener socket needs to be in the "accept" state before starting the - # subprocess. This is hacky, and will get better when we can use asyncio. - time.sleep(1) - resolve_process(launch.start( - cwd, stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)) - - thread = threading.Thread(target=start_in_background) - thread.start() - with closing(listener_socket): - # Await one client connection (blocking!) - sock, _ = listener_socket.accept() - thread.join() - process = process_promise.value - if not process: - raise Exception('Failed to create transport config from separate thread.') - if not process.stderr: - raise Exception('Failed to create transport config due to not being able to pipe stderr') - error_reader = ErrorReader(callbacks, process.stderr) - return TransportWrapper( - callback_object=callbacks, - transport=SocketTransport(encode_json, decode_json, sock), - process=process, - error_reader=error_reader, - ) + + class ClientConnectedCallback: + def __init__(self) -> None: + self.cv = asyncio.Condition() + self.wrapper: TransportWrapper | None = None + self.process: asyncio.subprocess.Process | None = None + self.error_reader: ErrorReader | None = None + + async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + async with self.cv: + transport = StreamTransport(encode_json, decode_json, reader, writer) + self.wrapper = TransportWrapper(callbacks, transport, self.process, command, self.error_reader) + self.cv.notify() + + callback = ClientConnectedCallback() + async with callback.cv: + server = await asyncio.start_server(callback, host='127.0.0.1', port=port, family=socket.AF_INET) + try: + await server.start_serving() + process = await launch.start( + cwd, + stdout=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.STDOUT, + ) + assert process.stdout + callback.process = process + callback.error_reader = ErrorReader(callbacks, process.stdout) + try: + await asyncio.wait_for(callback.cv.wait(), timeout=TCP_CONNECT_TIMEOUT) + except Exception: + process.kill() + await process.wait() + raise + finally: + server.close() + await server.wait_closed() + assert callback.wrapper + return callback.wrapper # --- Transports ------------------------------------------------------------------------------------------------------- class TransportCallbacks: - def on_transport_close(self, exit_code: int, exception: Exception | None) -> None: ... + async def on_transport_close(self, exit_code: int, exception: Exception | None) -> None: ... - def on_payload(self, payload: JSONRPCMessage) -> None: ... + async def on_payload(self, payload: JSONRPCMessage) -> None: ... def on_stderr_message(self, message: str) -> None: ... class Transport(ABC): - def __init__( - self, - encoder: Callable[[JSONRPCMessage], bytes], - decoder: Callable[[bytes], JSONRPCMessage] - ) -> None: + def __init__(self, encoder: Callable[[JSONRPCMessage], bytes], decoder: Callable[[bytes], JSONRPCMessage]) -> None: self._encoder = encoder self._decoder = decoder @abstractmethod - def read(self) -> JSONRPCMessage | None: + async def read(self) -> JSONRPCMessage | None: raise NotImplementedError @abstractmethod - def write(self, payload: JSONRPCMessage) -> None: + async def write(self, payload: JSONRPCMessage) -> None: raise NotImplementedError @abstractmethod - def write_bytes(self, payload: bytes) -> None: + async def write_bytes(self, payload: bytes) -> None: raise NotImplementedError @abstractmethod - def close(self) -> None: + async def close(self) -> None: raise NotImplementedError -class FileObjectTransport(Transport): +async def parse_headers(reader: asyncio.StreamReader) -> dict[str, str]: + headers: dict[str, str] = {} + try: + headers_bytes = (await reader.readuntil(b'\r\n\r\n')).decode("ascii").rstrip() + for line in headers_bytes.split("\r\n"): + key, value = line.split(":", 1) + headers[key.lower()] = value + except asyncio.exceptions.IncompleteReadError: + # May happen when shutting down. parse_content_length will then return None, + # which will cause the read loop to stop. + pass + return headers + + +async def parse_content_length(reader: asyncio.StreamReader) -> int | None: + headers = await parse_headers(reader) + content_length = headers.get("content-length") + return int(content_length) if content_length else None + + +class StreamTransport(Transport): def __init__( self, encoder: Callable[[JSONRPCMessage], bytes], decoder: Callable[[bytes], JSONRPCMessage], - reader: IO[bytes] | BufferedIOBase, - writer: IO[bytes] | BufferedIOBase, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, ) -> None: super().__init__(encoder, decoder) self._reader = reader self._writer = writer @override - def read(self) -> JSONRPCMessage: - headers: http.client.HTTPMessage | None = None - try: - headers = http.client.parse_headers(self._reader) - content_length = headers.get("Content-Length") - if not isinstance(content_length, str): - raise TypeError("Missing Content-Length header") - body = self._reader.read(int(content_length)) - except TypeError as ex: - if str(headers) == "\n": - # Expected on process stopping. Gracefully stop the transport. - raise StopLoopError from None - # Propagate server's output to the UI. - raise Exception(f"Unexpected payload in server's stdout:\n\n{headers}") from ex + async def read(self) -> JSONRPCMessage: + content_length = await parse_content_length(self._reader) + if content_length is None: + raise StopLoopError + body = await self._reader.readexactly(content_length) try: return self._decoder(body) except Exception as ex: raise Exception(f"JSON decode error: {ex}") from ex @override - def write(self, payload: JSONRPCMessage) -> None: + async def write(self, payload: JSONRPCMessage) -> None: body = self._encoder(payload) self._writer.writelines((f"Content-Length: {len(body)}\r\n\r\n".encode("ascii"), body)) - self._writer.flush() + try: + await self._writer.drain() + except ConnectionResetError: + # Can happen when the lang server is shut down or the connection is severed in some way. Just return, + # there's other logic that will make the transport shut down. + pass @override - def write_bytes(self, payload: bytes) -> None: + async def write_bytes(self, payload: bytes) -> None: self._writer.write(payload) - self._writer.flush() + await self._writer.drain() @override - def close(self) -> None: + async def close(self) -> None: self._writer.close() - self._reader.close() - - -class SocketTransport(FileObjectTransport): - def __init__( - self, - encoder: Callable[[JSONRPCMessage], bytes], - decoder: Callable[[bytes], JSONRPCMessage], - sock: socket.socket - ) -> None: - reader_writer_pair = sock.makefile("rwb") - super().__init__(encoder, decoder, reader_writer_pair, reader_writer_pair) - self._socket = sock - - @override - def close(self) -> None: - super().close() - self._socket.close() + await self._writer.wait_closed() # --- TransportWrapper ------------------------------------------------------------------------------------------------- @@ -348,132 +361,104 @@ def __init__( self, callback_object: TransportCallbacks, transport: Transport, - process: subprocess.Popen[bytes] | None, + process: asyncio.subprocess.Process | None, + process_args: list[str] | None, error_reader: ErrorReader | None, ) -> None: - self._closed = False self._callback_object = weakref.ref(callback_object) - self._transport = transport + self._transport: Transport | None = transport self._process = process - self._error_reader = error_reader - self._reader_thread = threading.Thread(target=self._read_loop) - self._writer_thread = threading.Thread(target=self._write_loop) - self._send_queue: Queue[JSONRPCMessage | bytes | None] = Queue(0) - self._reader_thread.start() - self._writer_thread.start() + self._process_args = process_args + self._error_reader: ErrorReader | None = error_reader + self._task = asyncio.get_running_loop().create_task(self._read_loop()) @property - def process_args(self) -> Any: - return self._process.args if self._process else None - - def send(self, payload: JSONRPCMessage) -> None: - self._send_queue.put_nowait(payload) - - def send_bytes(self, payload: bytes) -> None: - self._send_queue.put_nowait(payload) - - def close(self) -> None: - if not self._closed: - self._closed = True - self._send_queue.put_nowait(None) - _join_thread(self._writer_thread) - _join_thread(self._reader_thread) - if self._error_reader: - self._error_reader.on_transport_close() - self._error_reader = None - if self._transport: - self._transport.close() - self._transport = None - - def _read_loop(self) -> None: - exception = None + def process_args(self) -> list[str] | None: + """ + The arguments for the process launched by this wrapper, or None if there is no process launched (such as with + a remote TCP/websocket connection). + """ + return self._process_args + + async def send(self, payload: JSONRPCMessage) -> None: + if self._transport: + await self._transport.write(payload) + + async def send_bytes(self, payload: bytes) -> None: + if self._transport: + await self._transport.write_bytes(payload) + + async def close(self) -> None: + if self._error_reader: + self._error_reader.on_transport_close() + self._error_reader = None + if self._transport: + await self._transport.close() + self._transport = None + + async def _read_loop(self) -> None: + exception: Exception | None = None try: while self._transport: - if (payload := self._transport.read()) is None: + if (payload := await self._transport.read()) is None: continue - - def invoke(p: JSONRPCMessage) -> None: - if self._closed: - return - if callback_object := self._callback_object(): - callback_object.on_payload(p) - - sublime.set_timeout_async(partial(invoke, payload)) - except (AttributeError, BrokenPipeError, StopLoopError): + if callback_object := self._callback_object(): + await callback_object.on_payload(payload) + except (AttributeError, BrokenPipeError, StopLoopError, TypeError): + # TypeError happens when `callback_object` becomes None. + # It can become `None` even when the if-condition above that passes. pass except Exception as ex: + exception_log("unexpected exception while stopping transport", ex) exception = ex - if exception: - self._end(exception) - else: - self._send_queue.put_nowait(None) - - def _end(self, exception: Exception | None) -> None: - exit_code = 0 + exit_code: int | None = None if self._process: if not exception: try: # Allow the process to stop itself. - exit_code = self._process.wait(1) - except (AttributeError, ProcessLookupError, subprocess.TimeoutExpired): + exit_code = await asyncio.wait_for(self._process.wait(), timeout=1) + except (AttributeError, ProcessLookupError, asyncio.TimeoutError): pass - if self._process.poll() is None: + if exit_code is None: try: # The process didn't stop itself. Terminate! self._process.kill() # still wait for the process to die, or zombie processes might be the result # Ignore the exit code in this case, it's going to be something non-zero because we sent SIGKILL. - self._process.wait() + await self._process.wait() except (AttributeError, ProcessLookupError): pass except Exception as ex: exception = ex # TODO: Old captured exception is overwritten - - def invoke() -> None: - callback_object = self._callback_object() - if callback_object: - callback_object.on_transport_close(exit_code, exception) - - sublime.set_timeout_async(invoke) - self.close() - - def _write_loop(self) -> None: - exception: Exception | None = None - try: - while self._transport: - if (d := self._send_queue.get()) is None: - break - if isinstance(d, bytes): - self._transport.write_bytes(d) - else: - self._transport.write(d) - except (BrokenPipeError, AttributeError): - pass - except Exception as ex: - exception = ex - self._end(exception) + if callback_object := self._callback_object(): + await callback_object.on_transport_close(exit_code or 0, exception) + await self.close() class LaunchConfig: + """Small object that can start a process.""" + __slots__ = ("command", "env") def __init__(self, command: list[str], env: dict[str, str] | None = None) -> None: self.command: list[str] = command self.env: dict[str, str] = env or {} - def start( + async def start( self, cwd: str | None, stdin: int, stdout: int, stderr: int, - ) -> subprocess.Popen[bytes]: + ) -> asyncio.subprocess.Process: + """Start a process.""" startupinfo = _fixup_startup_args(self.command) - return _start_subprocess(self.command, stdin, stdout, stderr, startupinfo, self.env, cwd) + return await _start_subprocess(self.command, stdin, stdout, stderr, startupinfo, self.env, cwd) # --- Utils ------------------------------------------------------------------------------------------------------- + class ErrorReader: """ Relays log messages from a raw stream to a (subclass of) TransportCallbacks. @@ -483,28 +468,28 @@ class ErrorReader: via a socket, while it listens for log messages on the stdout/stderr streams of a spawned child process. """ - def __init__(self, callback_object: TransportCallbacks, reader: IO[bytes]) -> None: + def __init__(self, callback_object: TransportCallbacks, reader: asyncio.StreamReader) -> None: self._callback_object = weakref.ref(callback_object) self._reader = reader - self._thread = threading.Thread(target=self._loop) - self._thread.start() + self._task = asyncio.get_running_loop().create_task(self._loop()) def on_transport_close(self) -> None: self._reader = None - _join_thread(self._thread) + self._task.cancel() - def _loop(self) -> None: + async def _loop(self) -> None: try: while self._reader: - message = self._reader.readline().decode("utf-8", "replace") - if not message: - continue - callback_object = self._callback_object() - if callback_object: + raw = await self._reader.readline() + if not raw: + break + message = raw.decode("utf-8", "replace") + if callback_object := self._callback_object(): callback_object.on_stderr_message(message.rstrip()) else: break - except (BrokenPipeError, AttributeError): + except (BrokenPipeError, AttributeError, asyncio.CancelledError): + # debug(f"exiting from ErrorReader._loop with expected error (which is: {type(ex)}, message: {ex})") pass except Exception as ex: exception_log("unexpected exception type in error reader", ex) @@ -531,21 +516,17 @@ def decode_json(message: bytes) -> JSONRPCMessage: # --- Internal --------------------------------------------------------------------------------------------------------- -g_subprocesses: weakref.WeakSet[subprocess.Popen[bytes]] = weakref.WeakSet() +g_subprocesses: weakref.WeakSet[asyncio.subprocess.Process] = weakref.WeakSet() -def kill_all_subprocesses() -> None: +async def kill_all_subprocesses() -> None: subprocesses = list(g_subprocesses) for p in subprocesses: try: p.kill() except Exception: pass - for p in subprocesses: - try: - p.wait() - except Exception: - pass + await asyncio.gather(*[p.wait() for p in subprocesses]) def _fixup_startup_args(args: list[str]) -> Any: @@ -568,7 +549,7 @@ def _fixup_startup_args(args: list[str]) -> Any: return startupinfo -def _start_subprocess( +async def _start_subprocess( args: list[str], stdin: int, stdout: int, @@ -576,10 +557,10 @@ def _start_subprocess( startupinfo: Any, env: dict[str, str], cwd: str | None, -) -> subprocess.Popen[bytes]: +) -> asyncio.subprocess.Process: debug(f"starting {args} in {cwd or os.getcwd()}") - process = subprocess.Popen( - args=args, + process = await asyncio.create_subprocess_exec( + *args, stdin=stdin, stdout=stdout, stderr=stderr, @@ -603,12 +584,3 @@ def _add_and_resolve_port_variable(variables: dict[str, str], port: int | None) port = _find_free_port() variables["port"] = str(port) return port - - -def _join_thread(t: threading.Thread) -> None: - if t.ident == threading.current_thread().ident: - return - try: - t.join(2) - except TimeoutError as ex: - exception_log(f"failed to join {t.name} thread", ex) diff --git a/plugin/core/types.py b/plugin/core/types.py index 12b27033a..dd83579e0 100644 --- a/plugin/core/types.py +++ b/plugin/core/types.py @@ -8,6 +8,8 @@ from ...protocol import TextDocumentSyncKind from ...protocol import TextDocumentSyncOptions from ...protocol import URI +from .aio import run_coroutine_threadsafe +from .aio import TaskContainer from .collections import DottedDict from .constants import LANGUAGE_IDENTIFIERS from .constants import MarkdownLangMap @@ -40,6 +42,7 @@ from wcmatch.glob import globmatch from wcmatch.glob import GLOBSTAR from wcmatch.glob import IGNORECASE +import asyncio import contextlib import fnmatch import os @@ -159,10 +162,11 @@ def sublime_pattern_to_glob(pattern: str, *, is_directory_pattern: bool, root_pa return glob -def debounced(f: Callable[[], Any], timeout_ms: int = 0, condition: Callable[[], bool] = lambda: True, - async_thread: bool = False) -> None: +def debounced(f: Callable[[], Any], timeout_ms: int = 0, condition: Callable[[], bool] = lambda: True) -> None: """ - Possibly run a function at a later point in time, either on the async thread or on the main thread. + Possibly run a function at a later point in time. Always on the asyncio thread. + + Note: use asyncio.sleep(x) and simple condition checking if you're already running in an `async` function. :param f: The function to possibly run. Its return type is discarded. :param timeout_ms: The time in milliseconds after which to possibly to run the function @@ -171,12 +175,12 @@ def debounced(f: Callable[[], Any], timeout_ms: int = 0, condition: Callable[[], main thread """ - def run() -> None: + async def run() -> None: + await asyncio.sleep(timeout_ms / 1000.0) if condition(): f() - runner = sublime.set_timeout_async if async_thread else sublime.set_timeout - runner(run, timeout_ms) + run_coroutine_threadsafe(run()) class SettingsRegistration: @@ -209,12 +213,11 @@ class DebouncerNonThreadSafe: When calling `debounce()` multiple times, if the time span between calls is shorter than the specified `timeout_ms`, the callback function will only be called once, after `timeout_ms` since the last call. - This implementation is not thread safe. You must ensure that `debounce()` is called from the same thread as - was chosen during initialization through the `async_thread` argument. + This implementation is not thread safe. You must ensure that `debounce()` is called from the asyncio thread. """ - def __init__(self, async_thread: bool) -> None: - self._async_thread = async_thread + def __init__(self, task_container: TaskContainer) -> None: + self._task_container = task_container self._current_id = -1 self._next_id = 0 @@ -222,23 +225,23 @@ def debounce( self, f: Callable[[], None], timeout_ms: int = 0, condition: Callable[[], bool] = lambda: True ) -> None: """ - Possibly run a function at a later point in time on the thread chosen during initialization. + Possibly run a function at a later point in time on the asyncio thread. :param f: The function to possibly run :param timeout_ms: The time in milliseconds after which to possibly to run the function :param condition: The condition that must evaluate to True in order to run the function """ - def run(debounce_id: int) -> None: + async def run(debounce_id: int) -> None: + await asyncio.sleep(timeout_ms / 1000.0) if debounce_id != self._current_id: return if condition(): f() - runner = sublime.set_timeout_async if self._async_thread else sublime.set_timeout current_id = self._current_id = self._next_id self._next_id += 1 - runner(lambda: run(current_id), timeout_ms) + self._task_container.create_task(run(current_id)) def cancel_pending(self) -> None: self._current_id = -1 diff --git a/plugin/core/windows.py b/plugin/core/windows.py index d51785d34..e160e7860 100644 --- a/plugin/core/windows.py +++ b/plugin/core/windows.py @@ -14,6 +14,9 @@ from ..api import LspPlugin from ..api import OnPreStartContext from ..api import PluginStartError +from .aio import call_soon_threadsafe +from .aio import gather_and_flatten_exceptions +from .aio import run_coroutine_threadsafe from .configurations import RETRY_COUNT_TIMEDELTA from .configurations import RETRY_MAX_COUNT from .configurations import WindowConfigChangeListener @@ -21,13 +24,13 @@ from .constants import MESSAGE_TYPE_LEVELS from .logging import debug from .logging import exception_log +from .logging import exceptions_log from .message_request_handler import MessageRequestHandler from .panels import LOG_LINES_LIMIT_SETTING_NAME from .panels import MAX_LOG_LINES_LIMIT_OFF from .panels import MAX_LOG_LINES_LIMIT_ON from .panels import PanelManager from .panels import PanelName -from .promise import Promise from .protocol import Error from .protocol import Point from .sessions import AbstractViewListener @@ -49,7 +52,6 @@ from .workspace import ProjectFolders from .workspace import sorted_workspace_folders from .workspace import WorkspaceFolder -from collections import deque from datetime import datetime from subprocess import CalledProcessError from time import perf_counter @@ -59,6 +61,7 @@ from typing_extensions import override from weakref import ref from weakref import WeakSet +import asyncio import functools import json import sublime @@ -88,12 +91,10 @@ class WindowManager(Manager, WindowConfigChangeListener, ViewStatusHandler): def __init__(self, window: sublime.Window, workspace: ProjectFolders, config_manager: WindowConfigManager) -> None: self._window = window self._config_manager = config_manager + self._start_lock: asyncio.Lock | None = None self._sessions: set[Session] = set() self._workspace = workspace - self._pending_listeners: deque[AbstractViewListener] = deque() self._listeners: WeakSet[AbstractViewListener] = WeakSet() - self._new_listener: AbstractViewListener | None = None - self._new_session: Session | None = None self._panel_code_phantoms: sublime.PhantomSet | None = None self._server_log: list[tuple[str, str]] = [] self.panel_manager: PanelManager | None = PanelManager(self._window) @@ -108,6 +109,7 @@ def __init__(self, window: sublime.Window, workspace: ProjectFolders, config_man self._config_manager.add_change_listener(self) @property + @override def window(self) -> sublime.Window: return self._window @@ -153,9 +155,16 @@ def register_listener_async(self, listener: AbstractViewListener) -> None: # Update workspace folders in case the user have changed those since window was created. # There is no currently no notification in ST that would notify about folder changes. self.update_workspace_folders_async() - self._pending_listeners.appendleft(listener) - if self._new_listener is None: - self._dequeue_listener_async() + for config in self._config_manager.match_view(listener.view, self._workspace.get_workspace_folders()): + if plugin := get_plugin(config.name): + if issubclass(plugin, LspPlugin): + context = IsApplicableContext(config, listener.view, self._workspace.get_workspace_folders()) + if plugin.is_applicable_async(context): + run_coroutine_threadsafe(self.start(config, listener)) + elif plugin.is_applicable(listener.view, config): + run_coroutine_threadsafe(self.start(config, listener)) + else: + run_coroutine_threadsafe(self.start(config, listener)) def unregister_listener_async(self, listener: AbstractViewListener) -> None: self._listeners.discard(listener) @@ -169,13 +178,10 @@ def listener_for_view(self, view: sublime.View) -> AbstractViewListener | None: return listener return None - def recheck_is_applicable_async(self, view: sublime.View, config_name: str) -> None: + async def recheck_is_applicable(self, view: sublime.View, config_name: str) -> None: if not (listener := self.listener_for_view(view)): debug(f'No listener for view {view}') return - if listener == self._new_listener: - debug(f'Already starting relevant sessions for view {view}.') - return scheme = parse_uri(listener.get_uri())[0] if (config := self._config_manager.get_config(config_name)) and config.enabled: is_applicable = config.match_view(view, scheme, self.window, self.workspace_folders) @@ -184,62 +190,11 @@ def recheck_is_applicable_async(self, view: sublime.View, config_name: str) -> N if is_applicable and not session_view: listener.on_session_initialized_async(session) elif not is_applicable and session_view: - session.shutdown_session_view_async(session_view) + exceptions_log("Error", await session.shutdown_session_view(session_view)) elif is_applicable: - self.start_async(config, view) - if self._new_session: - self._sessions.add(self._new_session) - listener.on_session_initialized_async(self._new_session) - self._new_session = None - - def _dequeue_listener_async(self) -> None: - listener: AbstractViewListener | None = None - if self._new_listener is not None: - listener = self._new_listener - # debug("re-checking listener", listener) - self._new_listener = None - else: - try: - listener = self._pending_listeners.pop() - if not listener.view.is_valid(): - # debug("listener", listener, "is no longer valid") - self._dequeue_listener_async() - return - # debug("adding new pending listener", listener) - self._listeners.add(listener) - except IndexError: - # We have handled all pending listeners. - self._new_session = None - return - if self._new_session: - self._sessions.add(self._new_session) - self._publish_sessions_to_listener_async(listener) - if self._new_session: - if not any(self._new_session.session_views_async()): - self._sessions.discard(self._new_session) - self._new_session.end_async() - self._new_session = None - if config := self._needed_config(listener.view): - # debug("found new config for listener", listener) - self._new_listener = listener - self.start_async(config, listener.view) - else: - # debug("no new config found for listener", listener) - self._new_listener = None - self._dequeue_listener_async() - - def _publish_sessions_to_listener_async(self, listener: AbstractViewListener) -> None: - inside_workspace = self._workspace.contains(listener.view) - scheme = parse_uri(listener.get_uri())[0] - for session in self._sessions: - if session.can_handle(listener.view, scheme, capability=None, inside_workspace=inside_workspace): - # debug("registering session", session.config.name, "to listener", listener) - try: - listener.on_session_initialized_async(session) - except Exception as ex: - message = f"failed to register session {session.config.name} to listener {listener}" - exception_log(message, ex) + await self.start(config, listener) + @override def get_session(self, config_name: str, file_path: str | None = None) -> Session | None: if file_path: return self._find_session(config_name, file_path) @@ -258,103 +213,100 @@ def _find_session(self, config_name: str, file_path: str) -> Session | None: return session return None - def _needed_config(self, view: sublime.View) -> ClientConfig | None: - configs = self._config_manager.match_view(view, self._workspace.get_workspace_folders()) - handled = False - file_name = view.file_name() - inside = self._workspace.contains(view) - for config in configs: - handled = False - for session in self._sessions: - if config.name == session.config.name and session.handles_path(file_name, inside): - handled = True - break - if not handled: - if plugin := get_plugin(config.name): - if issubclass(plugin, LspPlugin): - context = IsApplicableContext(config, view, self._workspace.get_workspace_folders()) - if plugin.is_applicable_async(context): - return config - elif plugin.is_applicable(view, config): - return config - else: - return config - return None + @override + async def start(self, config: ClientConfig, listener: AbstractViewListener) -> Session | None: + if not self._start_lock: + self._start_lock = asyncio.Lock() + async with self._start_lock: + file_path = listener.view.file_name() or '' + inside = self._workspace.contains(file_path) + for session in list(self._sessions): + if session.config.name == config.name and session.handles_path(file_path, inside): + # OK, this session is already initialized for this view. + self._listeners.add(listener) + session.config.set_view_status(listener.view, "") + # Do not let an exception in listener.on_session_initialized_async cause a failure in this method. + asyncio.get_running_loop().call_soon(listener.on_session_initialized_async, session) + return session + + config = ClientConfig.from_config(config, {}) + config.set_view_status_handler(self) + loop = asyncio.get_running_loop() - def start_async(self, config: ClientConfig, initiating_view: sublime.View) -> None: - config = ClientConfig.from_config(config, {}) - config.set_view_status_handler(self) - file_path = initiating_view.file_name() or '' - if not self._can_start_config(config.name, file_path): - return - try: - workspace_folders = sorted_workspace_folders(self._workspace.folders, file_path) - plugin_class = get_plugin(config.name) - variables = extract_variables(self._window) - cwd = workspace_folders[0].path if workspace_folders else None - context = OnPreStartContext(config, variables, initiating_view, cwd, workspace_folders) - if plugin_class: - if issubclass(plugin_class, LspPlugin): - config.set_view_status(initiating_view, "installing...") - plugin_class.on_pre_start_async(context) - cwd = context.working_directory - else: - if plugin_class.needs_update_or_installation(): - config.set_view_status(initiating_view, "installing...") - plugin_class.install_or_update() - additional_variables = plugin_class.additional_variables() - if isinstance(additional_variables, dict): - variables.update(additional_variables) - cannot_start_reason = plugin_class.can_start( - self._window, initiating_view, workspace_folders, config) - if cannot_start_reason: - raise PluginStartError(cannot_start_reason) - if new_cwd := plugin_class.on_pre_start(self._window, initiating_view, workspace_folders, config): - cwd = new_cwd - config.set_view_status(initiating_view, "starting...") - session = Session(self, self._create_logger(config.name), workspace_folders, config, plugin_class) - transport = config.create_transport_config().start(config.command, config.env, cwd, variables, session) - if plugin_class and issubclass(plugin_class, AbstractPlugin): - plugin_class.on_post_start(self._window, initiating_view, workspace_folders, config) - config.set_view_status(initiating_view, "initialize") - session.initialize_async( - variables=variables, - transport=transport, - working_directory=cwd, - init_callback=functools.partial(self._on_post_session_initialize, initiating_view) - ) - self._new_session = session - except PluginStartError as ex: - config.erase_view_status(initiating_view) - message = f"cannot start {config.name}: {ex!s}" - self._config_manager.disable_config(config.name, only_for_session=True) - # Continue with handling pending listeners - self._new_session = None - sublime.set_timeout_async(self._dequeue_listener_async) - self._window.status_message(message) - except Exception as e: - message = (f'Failed to start {config.name} - disabling for this window for the duration of the current ' - 'session.\nRe-enable by running "LSP: Enable Language Server In Project" from the Command ' - f'Palette.\n\n--- Error: ---\n{e}') - exception_log(f"Unable to initialize language server for {config.name}", e) - if isinstance(e, CalledProcessError): - print("Server output:\n{}".format(e.output.decode('utf-8', 'replace'))) - self._config_manager.disable_config(config.name, only_for_session=True) - config.erase_view_status(initiating_view) - sublime.message_dialog(message) - # Continue with handling pending listeners - self._new_session = None - sublime.set_timeout_async(self._dequeue_listener_async) - - def _on_post_session_initialize( - self, initiating_view: sublime.View, session: Session, is_error: bool = False - ) -> None: - if is_error: - session.config.erase_view_status(initiating_view) - self._new_listener = None - self._new_session = None - else: - sublime.set_timeout_async(self._dequeue_listener_async) + try: + workspace_folders = sorted_workspace_folders(self._workspace.folders, file_path) + plugin_class = get_plugin(config.name) + variables = extract_variables(self._window) + cwd = workspace_folders[0].path if workspace_folders else None + context = OnPreStartContext(config, variables, listener.view, cwd, workspace_folders) + if plugin_class: + if issubclass(plugin_class, LspPlugin): + config.set_view_status(listener.view, "installing...") + if plugin_class.use_asyncio: + await plugin_class.on_pre_start(context) + else: + await loop.run_in_executor(None, plugin_class.on_pre_start_async, context) + cwd = context.working_directory + else: + if plugin_class.needs_update_or_installation(): + config.set_view_status(listener.view, "installing...") + await loop.run_in_executor(None, plugin_class.install_or_update) + additional_variables = plugin_class.additional_variables() + if isinstance(additional_variables, dict): + variables.update(additional_variables) + cannot_start_reason = plugin_class.can_start( + self._window, listener.view, workspace_folders, config) + if cannot_start_reason: + raise PluginStartError(cannot_start_reason) + if new_cwd := plugin_class.on_pre_start(self._window, listener.view, workspace_folders, config): + cwd = new_cwd + config.set_view_status(listener.view, "starting...") + session = Session(self, self._create_logger(config.name), workspace_folders, config, plugin_class) + transport = await config.create_transport_config().start( + config.command, config.env, cwd, variables, session) + if plugin_class and issubclass(plugin_class, AbstractPlugin): + plugin_class.on_post_start(self._window, listener.view, workspace_folders, config) + except PluginStartError as ex: + config.erase_view_status(listener.view) + message = f"cannot start {config.name}: {ex!s}" + self._config_manager.disable_config(config.name, only_for_session=True) + self._window.status_message(message) + return None + except Exception as e: + message = (f'Failed to start {config.name} - disabling for this window for the duration of the current ' + 'session.\nRe-enable by running "LSP: Enable Language Server In Project" from the Command ' + f'Palette.\n\n--- Error: ---\n{e}') + exception_log(f"Unable to start language server for {config.name}", e) + if isinstance(e, CalledProcessError): + print("Server output:\n{}".format(e.output.decode('utf-8', 'replace'))) + self._config_manager.disable_config(config.name, only_for_session=True) + config.erase_view_status(listener.view) + sublime.message_dialog(message) + return None + + try: + config.set_view_status(listener.view, "initializing...") + await session.initialize(variables=variables, transport=transport, working_directory=cwd) + self._sessions.add(session) + self._listeners.add(listener) + # Do not let an exception in listener.on_session_initialized_async cause a failure in this method. + asyncio.get_running_loop().call_soon(listener.on_session_initialized_async, session) + config.set_view_status(listener.view, "") + except Exception as e: + message = ( + f'Failed to initialize {config.name} - disabling for this window for the duration of the current ' + 'session.\nRe-enable by running "LSP: Enable Language Server In Project" from the Command ' + f'Palette.\n\n--- Error: ---\n{e}' + ) + exception_log(f"Unable to initialize language server for {config.name}", e) + if isinstance(e, CalledProcessError): + print("Server output:\n{}".format(e.output.decode('utf-8', 'replace'))) + self._config_manager.disable_config(config.name, only_for_session=True) + sublime.message_dialog(message) + config.erase_view_status(listener.view) + else: + return session + return None def _create_logger(self, config_name: str) -> Logger: logger_map = { @@ -376,26 +328,32 @@ def _create_logger(self, config_name: str) -> Logger: router_logger.append(logger(self, config_name)) return router_logger - def handle_message_request( + @override + async def handle_message_request( self, config_name: str, params: ShowMessageRequestParams - ) -> Promise[MessageActionItem | None]: + ) -> MessageActionItem | None: if view := self._window.active_view(): - return MessageRequestHandler(view, params, config_name).show() - return Promise.resolve(None) + return await MessageRequestHandler(view, params, config_name).show() + return None - def restart_sessions_async(self, config_names: list[str]) -> None: - self._end_sessions_async(config_names) + async def restart_sessions(self, config_names: list[str]) -> list[Exception]: + exceptions = await self._end_sessions(config_names) listeners = list(self._listeners) self._listeners.clear() for listener in listeners: self.register_listener_async(listener) + return exceptions - def _end_sessions_async(self, config_names: list[str] | None = None) -> None: + async def _end_sessions(self, config_names: list[str] | None = None) -> list[Exception]: + coros = [] for session in list(self._sessions): if config_names is None or session.config.name in config_names: - session.end_async() + debug(f"stopping {session.config.name}") + coros.append(session.end()) self._sessions.discard(session) + return await gather_and_flatten_exceptions(*coros) + @override def get_project_path(self, file_path: str) -> str | None: candidate: str | None = None for folder in self._workspace.folders: @@ -404,6 +362,7 @@ def get_project_path(self, file_path: str) -> str | None: candidate = folder return candidate + @override def should_ignore_diagnostics(self, uri: DocumentUri, configuration: ClientConfig) -> str | None: scheme, path = parse_uri(uri) if scheme != "file": @@ -428,18 +387,26 @@ def should_ignore_diagnostics(self, uri: DocumentUri, configuration: ClientConfi return "matches a project's folder_exclude_patterns" return None - def on_post_exit_async(self, session: Session, exit_code: int, exception: Exception | None) -> None: + @override + async def on_post_exit(self, session: Session, exit_code: int, exception: Exception | None) -> None: + debug(f"{session.config.name} has stopped") self._sessions.discard(session) - for listener in self._listeners: - listener.on_session_shutdown_async(session) + exceptions_log( + "Error shutting down listeners", + await gather_and_flatten_exceptions( + *(listener.on_session_shutdown(session) for listener in self._listeners) + ), + ) if exit_code != 0 or exception: config = session.config restart = self._config_manager.record_crash(config.name, exit_code, exception) if not restart: - msg = (f'The {config.name} server has crashed {RETRY_MAX_COUNT} times in the last ' - f'{int(RETRY_COUNT_TIMEDELTA.total_seconds())} seconds.\n\nYou can try to Restart it or you can ' - 'choose Cancel to disable it for this window for the duration of the current session. ' - 'Re-enable by running "LSP: Enable Language Server In Project" from the Command Palette.') + msg = ( + f'The {config.name} server has crashed {RETRY_MAX_COUNT} times in the last ' + f'{int(RETRY_COUNT_TIMEDELTA.total_seconds())} seconds.\n\nYou can try to Restart it or you can ' + 'choose Cancel to disable it for this window for the duration of the current session. ' + 'Re-enable by running "LSP: Enable Language Server In Project" from the Command Palette.' + ) if exception: msg += f"\n\n--- Error: ---\n{exception}" restart = sublime.ok_cancel_dialog(msg, "Restart") @@ -449,16 +416,15 @@ def on_post_exit_async(self, session: Session, exit_code: int, exception: Except else: self._config_manager.disable_config(config.name, only_for_session=True) - def destroy(self) -> None: - """ - Called **from the main thread** when the plugin unloads. In that case we must destroy all sessions - from the main thread. That could lead to some dict/list being mutated while iterated over, so be careful. - """ - self._end_sessions_async() + async def destroy(self) -> list[Exception]: + """Destroy everything related to this instance.""" + result = await self._end_sessions() if self.panel_manager: self.panel_manager.destroy_output_panels() self.panel_manager = None + return result + @override def handle_log_message(self, config_name: str, params: LogMessageParams) -> None: if not userprefs().log_debug: return @@ -469,6 +435,7 @@ def handle_log_message(self, config_name: str, params: LogMessageParams) -> None if message_type == MessageType.Error: self.window.status_message(f"{config_name}: {message}") + @override def handle_stderr_log(self, config_name: str, message: str) -> None: self.handle_server_message_async(config_name, message) @@ -492,6 +459,7 @@ def is_log_lines_limit_enabled(self) -> bool: panel = self.panel_manager and self.panel_manager.get_panel(PanelName.Log) return bool(panel and panel.settings().get(LOG_LINES_LIMIT_SETTING_NAME, True)) + @override def handle_show_message(self, config_name: str, params: ShowMessageParams) -> None: level = MESSAGE_TYPE_LEVELS[params['type']] message = params['message'] @@ -499,6 +467,7 @@ def handle_show_message(self, config_name: str, params: ShowMessageParams) -> No debug(msg) self.window.status_message(msg) + @override def on_diagnostics_updated(self) -> None: self.total_error_count = 0 self.total_warning_count = 0 @@ -558,7 +527,8 @@ def _update_panel_main_thread(self, characters: str, prephantoms: list[tuple[int def on_configs_changed(self, configs: list[ClientConfig]) -> None: config_names = [config.name for config in configs] - sublime.set_timeout_async(lambda: self.restart_sessions_async(config_names)) + # TODO: handle exception list? + run_coroutine_threadsafe(self.restart_sessions(config_names)) # --- Implements ViewStatusHandler --------------------------------------------------------------------------------- @@ -593,14 +563,11 @@ def enable(self) -> None: for window in sublime.windows(): self.lookup(window) - def disable(self) -> None: + async def disable(self) -> list[Exception]: self._enabled = False - for wm in self._windows.values(): - try: - wm.destroy() - except Exception as ex: - exception_log("failed to destroy window", ex) + exceptions = await gather_and_flatten_exceptions(*(wm.destroy() for wm in self._windows.values())) self._windows = {} + return exceptions def lookup(self, window: sublime.Window | None) -> WindowManager | None: if not self._enabled or not window or not window.is_valid(): @@ -620,7 +587,7 @@ def listener_for_view(self, view: sublime.View) -> AbstractViewListener | None: def discard(self, window: sublime.Window) -> None: if wm := self._windows.pop(window.id(), None): - sublime.set_timeout_async(wm.destroy) + run_coroutine_threadsafe(wm.destroy()) # --- Implements LspSettingsChangeListener ------------------------------------------------------------------------- @@ -637,9 +604,9 @@ def on_userprefs_updated(self) -> None: for wm in self._windows.values(): wm.on_diagnostics_updated() for session in wm.get_sessions(): - sublime.set_timeout_async(session.on_userprefs_changed_async) + call_soon_threadsafe(session.on_userprefs_changed_async) for listener in wm.listeners(): - sublime.set_timeout_async(listener.on_userprefs_changed_async) + call_soon_threadsafe(listener.on_userprefs_changed_async) class RequestTimeTracker: diff --git a/plugin/diagnostics.py b/plugin/diagnostics.py index 90e4626e1..08369dc18 100644 --- a/plugin/diagnostics.py +++ b/plugin/diagnostics.py @@ -160,7 +160,11 @@ def on_color_scheme_changed(self) -> None: self._severity_colors = self._get_severity_colors() def _get_severity_colors(self) -> dict[DiagnosticSeverity, str]: - return { - severity: self._view.style_for_scope(scope)['foreground'] - for severity, scope in DIAGNOSTIC_SEVERITY_SCOPES.items() - } + try: + return { + severity: self._view.style_for_scope(scope)['foreground'] + for severity, scope in DIAGNOSTIC_SEVERITY_SCOPES.items() + } + except KeyError: + # Happens when the view is already closed. + return {} diff --git a/plugin/document_link.py b/plugin/document_link.py index 757a8bfdb..5d3541ec6 100644 --- a/plugin/document_link.py +++ b/plugin/document_link.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .core.aio import run_coroutine_threadsafe from .core.logging import debug from .core.open import open_file_uri from .core.open import open_in_browser @@ -9,7 +10,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ..protocol import DocumentLink from ..protocol import URI import sublime @@ -28,24 +28,26 @@ def is_enabled(self, event: dict | None = None, point: int | None = None) -> boo def run(self, edit: sublime.Edit, event: dict | None = None) -> None: if position := get_position(self.view, event): - if session := self.best_session(self.capability, position): - if sv := session.session_view_for_view_async(self.view): - if link := sv.session_buffer.get_document_link_at_point(self.view, position): - if (target := link.get("target")) is not None: - self.open_target(target) - elif session.has_capability("documentLinkProvider.resolveProvider"): - request = Request.resolveDocumentLink(link, self.view) - session.send_request_async(request, self._on_resolved_async) - else: - debug("DocumentLink.target is missing, but the server doesn't support documentLink/resolve") - - def _on_resolved_async(self, response: DocumentLink) -> None: - if target := response.get("target"): - self.open_target(target) - - def open_target(self, target: URI) -> None: + run_coroutine_threadsafe(self._run(position)) + + async def _run(self, position: int) -> None: + if not (session := self.best_session(self.capability, position)): + return + if not (sv := session.session_view_for_view_async(self.view)): + return + if not (link := sv.session_buffer.get_document_link_at_point(self.view, position)): + return + if (target := link.get("target")) is not None: + await self.open_target(target) + elif session.has_capability("documentLinkProvider.resolveProvider"): + if target := (await session.request(Request.resolveDocumentLink(link, self.view))).get("target"): + await self.open_target(target) + else: + debug("DocumentLink.target is missing, but the server doesn't support documentLink/resolve") + + async def open_target(self, target: URI) -> None: if target.startswith("file:"): if window := self.view.window(): - open_file_uri(window, target) + await open_file_uri(window, target) else: open_in_browser(target) diff --git a/plugin/documents.py b/plugin/documents.py index 469e96e1a..4d3d3ddef 100644 --- a/plugin/documents.py +++ b/plugin/documents.py @@ -18,6 +18,10 @@ from .code_actions import filter_quickfix_actions from .code_lens import LspToggleCodeLensesCommand from .completion import QueryCompletionsTask +from .core.aio import call_soon_threadsafe +from .core.aio import gather_and_flatten_exceptions +from .core.aio import run_coroutine_threadsafe +from .core.aio import TaskContainer from .core.constants import ChangeEventAction from .core.constants import CODE_ACTION_ANNOTATION_SCOPE from .core.constants import COMMAND_TO_CHANGE_EVENT_ACTION @@ -32,6 +36,7 @@ from .core.constants import SIGNATURE_HELP_INACTIVE_PARAMETER_SCOPE from .core.constants import ST_VERSION from .core.logging import debug +from .core.logging import exceptions_log from .core.open import open_file_uri from .core.open import open_in_browser from .core.panels import PanelName @@ -72,20 +77,23 @@ from os.path import basename from typing import Any from typing import Callable +from typing import cast from typing import Generator from typing import Iterable from typing import Literal from typing import overload from typing import Sequence from typing import TYPE_CHECKING -from typing import TypeVar from typing_extensions import Concatenate from typing_extensions import override from typing_extensions import ParamSpec from weakref import WeakSet from weakref import WeakValueDictionary +import asyncio +import inspect import itertools import sublime +import sublime_aio import sublime_plugin import weakref import webbrowser @@ -93,24 +101,33 @@ if TYPE_CHECKING: from .core.windows import WindowManager from .session_buffer import SessionBuffer + from collections.abc import Coroutine -P = ParamSpec('P') -R = TypeVar('R') + +P = ParamSpec("P") def requires_session( - func: Callable[Concatenate[DocumentSyncListener, P], R] -) -> Callable[Concatenate[DocumentSyncListener, P], R | None]: - """ - A decorator for the `DocumentSyncListener` event handlers, which immediately returns `None` if there are no - `SessionView`s. - """ + func: Callable[Concatenate[DocumentSyncListener, P], Any], +) -> Callable[Concatenate[DocumentSyncListener, P], Any]: + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(self: DocumentSyncListener, *args: P.args, **kwargs: P.kwargs) -> Any: + if not self.session_views_async(): + return None + return await func(self, *args, **kwargs) + + return cast("Callable[Concatenate[DocumentSyncListener, P], Coroutine[Any, Any, Any]]", async_wrapper) + @wraps(func) - def wrapper(self: DocumentSyncListener, *args: P.args, **kwargs: P.kwargs) -> R | None: + def sync_wrapper(self: DocumentSyncListener, *args: P.args, **kwargs: P.kwargs) -> Any: if not self.session_views_async(): return None return func(self, *args, **kwargs) - return wrapper + + return cast("Callable[Concatenate[DocumentSyncListener, P], Any]", sync_wrapper) def is_regular_view(v: sublime.View) -> bool: @@ -160,11 +177,12 @@ def on_text_changed(self, changes: list[sublime.TextChange]) -> None: change_count = view.change_count() frozen_listeners = WeakSet(self.view_listeners) - def notify(action: ChangeEventAction) -> None: - for listener in list(frozen_listeners): - listener.on_text_changed_async(change_count, changes, action) + async def notify(action: ChangeEventAction) -> None: + await asyncio.gather( + *[listener.on_text_changed(change_count, changes, action) for listener in list(frozen_listeners)] + ) - sublime.set_timeout_async(partial(notify, self._last_edit_action)) + run_coroutine_threadsafe(notify(self._last_edit_action)) self._reset_last_edit_action() def on_reload_async(self) -> None: @@ -188,7 +206,7 @@ def __repr__(self) -> str: return f"TextChangeListener({self.buffer.buffer_id})" -class DocumentSyncListener(sublime_plugin.ViewEventListener, AbstractViewListener): +class DocumentSyncListener(sublime_aio.ViewEventListener, AbstractViewListener, TaskContainer): ACTIVE_DIAGNOSTIC = "lsp_active_diagnostic" debounce_time = FEATURES_TIMEOUT @@ -254,7 +272,7 @@ def _cleanup(self) -> None: self._stored_selection = [] self.view.erase_status(AbstractViewListener.TOTAL_ERRORS_AND_WARNINGS_STATUS_KEY) self._clear_highlight_regions() - self._clear_session_views_async() + # run_coroutine_threadsafe(self._clear_session_views()) def _reset(self) -> None: # Have to do this on the main thread, since __init__ and __del__ are invoked on the main thread too @@ -262,8 +280,8 @@ def _reset(self) -> None: self._setup() for session in self.sessions_async(): session.diagnostics.clear_identifiers_cache_for_view(self.view) - # But this has to run on the async thread again - sublime.set_timeout_async(self.on_activated_async) + # But this has to run on the asyncio thread again + run_coroutine_threadsafe(self._activated_impl()) # --- Implements AbstractViewListener ------------------------------------------------------------------------------ @@ -310,15 +328,16 @@ def on_session_initialized_async(self, session: Session) -> None: sb.session.cancel_request_async(request_id) sb.semantic_tokens.pending_response = None - def on_session_shutdown_async(self, session: Session) -> None: + async def on_session_shutdown(self, session: Session) -> list[Exception]: if removed_session := self._session_views.pop(session.config.name, None): - removed_session.on_before_remove() + result = await removed_session.on_before_remove() if not self._session_views: self.view.settings().erase("lsp_active") self._registered = False - else: - # SessionView was likely not created for this config so remove status here. - session.config.erase_view_status(self.view) + return result + # SessionView was likely not created for this config so remove status here. + session.config.erase_view_status(self.view) + return [] def _diagnostics_async( self, allow_stale: bool = False @@ -380,7 +399,7 @@ def session_views_async(self) -> list[SessionView]: return list(self._session_views.values()) @requires_session - def on_text_changed_async( + async def on_text_changed( self, change_count: int, changes: list[sublime.TextChange], action: ChangeEventAction ) -> None: if self.view.is_primary(): @@ -412,28 +431,34 @@ def get_request_flags(self, session: Session) -> RequestFlags: # --- Callbacks from Sublime Text ---------------------------------------------------------------------------------- - def on_load_async(self) -> None: + async def on_load(self) -> None: + await self._on_load_impl() + + async def _on_load_impl(self) -> None: if not self._registered and is_regular_view(self.view): - self._register_async() + self._register() return if initially_folded_kinds := userprefs().initially_folded: if session := self.session_async('foldingRangeProvider'): params: FoldingRangeParams = {'textDocument': text_document_identifier(self.view)} - session.send_request_async( - Request.foldingRange(params, self.view), - partial(self._on_initial_folding_ranges, initially_folded_kinds)) - self.on_activated_async() + self._on_initial_folding_ranges( + initially_folded_kinds, await session.request(Request.foldingRange(params, self.view)) + ) + await self._activated_impl() - def on_post_move_async(self) -> None: + async def on_post_move(self) -> None: if ST_VERSION < 4184: # Already handled in boot.Listener.on_pre_move return self.on_post_move_window_async() - def on_activated_async(self) -> None: + async def on_activated(self) -> None: + await self._activated_impl() + + async def _activated_impl(self) -> None: if self.view.is_loading() or not is_regular_view(self.view): return if not self._registered: - self._register_async() + self._register() session_views = self.session_views_async() if not session_views: return @@ -454,8 +479,8 @@ def on_activated_async(self) -> None: self._do_code_actions_for_selection_async(self.session_buffers_async('codeActionProvider')) @requires_session - def on_selection_modified_async(self) -> None: - first_region, _ = self._update_stored_selection_async() + async def on_selection_modified(self) -> None: + first_region, _ = self._update_stored_selection() if first_region is None: return if not self._is_in_higlighted_region(first_region.b): @@ -463,11 +488,11 @@ def on_selection_modified_async(self) -> None: if userprefs().show_code_actions: self._code_actions_for_selection.clear() self._clear_code_actions_annotation() - self._when_selection_remains_stable_async( - self._on_selection_modified_debounced_async, first_region, after_ms=self.debounce_time) + self._when_selection_remains_stable( + self._on_selection_modified_debounced, first_region, after_ms=self.debounce_time) self._update_diagnostic_in_status_bar_async() - def _on_selection_modified_debounced_async(self) -> None: + def _on_selection_modified_debounced(self) -> None: if userprefs().document_highlight_style: self._do_highlights_async() if userprefs().show_code_actions: @@ -479,7 +504,7 @@ def _on_selection_modified_debounced_async(self) -> None: if plugin := sv.session.plugin: plugin.on_selection_modified_async(sv) - def on_post_save_async(self) -> None: + async def on_post_save(self) -> None: # Re-determine the URI; this time it's guaranteed to be a file because ST can only save files to a real # filesystem. uri = view_to_uri(self.view) @@ -519,11 +544,10 @@ def _toggle_diagnostics_panel_if_needed_async(self) -> None: elif has_relevant_diagnostcs: panel_manager.show_diagnostics_panel_async() - def on_close(self) -> None: + async def on_close(self) -> None: if self._registered and self._manager: - manager = self._manager - sublime.set_timeout_async(lambda: manager.unregister_listener_async(self)) - self._clear_session_views_async() + self._manager.unregister_listener_async(self) + exceptions_log("Exception while closing document", await self._clear_session_views()) def on_query_context(self, key: str, operator: int, operand: Any, match_all: bool) -> bool | None: # You can filter key bindings by the precense of a provider, @@ -570,7 +594,7 @@ def on_hover(self, point: int, hover_zone: int) -> None: if window.settings().get(HOVER_ENABLED_KEY, True): self.view.run_command("lsp_hover", {"point": point}) elif hover_zone == sublime.HoverZone.GUTTER: - sublime.set_timeout_async(partial(self._on_hover_gutter_async, point)) + call_soon_threadsafe(partial(self._on_hover_gutter_async, point)) def _on_hover_gutter_async(self, point: int) -> None: if userprefs().diagnostics_gutter_marker: @@ -609,11 +633,11 @@ def _on_navigate(self, href: str) -> None: if scheme == CODE_ACTION_SCHEME: session_name, version, action = decode_code_action_uri(href) if version == self.view.change_count() and (session := self.session_by_name(session_name)): - sublime.set_timeout_async(lambda: session.run_code_action_async(action, progress=True, view=self.view)) + self.create_task_threadsafe(session.run_code_action(action, progress=True, view=self.view)) self.view.hide_popup() elif scheme == 'file': if window := self.view.window(): - open_file_uri(window, href) + self.create_task(open_file_uri(window, href)) elif scheme.lower() in {"http", "https"} or (not scheme and href.startswith('www.')): open_in_browser(href) @@ -650,7 +674,9 @@ def on_post_text_command(self, command_name: str, args: dict[str, Any] | None) - if format_on_paste and self.session_async("documentRangeFormattingProvider"): self._should_format_on_paste = True elif command_name in {"next_field", "prev_field"} and args is None: - sublime.set_timeout_async(lambda: self.do_signature_help_async(SignatureHelpTriggerKind.ContentChange)) + call_soon_threadsafe( + lambda: self.do_signature_help_async(SignatureHelpTriggerKind.ContentChange) + ) if not self.view.is_popup_visible(): return if self._is_documenation_popup_open and command_name in {"move", "commit_completion", "delete_word", @@ -662,7 +688,7 @@ def on_query_completions(self, prefix: str, locations: list[int]) -> sublime.Com completion_list = sublime.CompletionList() triggered_manually = self._auto_complete_triggered_manually self._auto_complete_triggered_manually = False # reset state for next completion popup - sublime.set_timeout_async( + call_soon_threadsafe( lambda: self._on_query_completions_async(completion_list, locations[0], triggered_manually)) return completion_list @@ -988,13 +1014,13 @@ def reload_async(self) -> None: # --- Private utility methods -------------------------------------------------------------------------------------- - def _when_selection_remains_stable_async(self, f: Callable[[], None], r: sublime.Region, after_ms: int) -> None: - debounced(f, after_ms, partial(self._is_selection_stable_async, r), async_thread=True) + def _when_selection_remains_stable(self, f: Callable[[], None], r: sublime.Region, after_ms: int) -> None: + debounced(f, after_ms, partial(self._is_selection_stable_async, r)) def _is_selection_stable_async(self, region: sublime.Region) -> bool: return bool(self._stored_selection and self._stored_selection[0] == region) - def _register_async(self) -> None: + def _register(self) -> None: buf = self.view.buffer() if not buf: debug("not tracking bufferless view", self.view.id()) @@ -1022,18 +1048,18 @@ def _register_async(self) -> None: for listener in listeners: if isinstance(listener, DocumentSyncListener): debug("also registering", listener) - listener.on_load_async() + self.create_task(listener._on_load_impl()) def _on_view_updated_async(self) -> None: if self._should_format_on_paste: self._should_format_on_paste = False sublime.get_clipboard_async(self._format_on_paste_async) - first_region, _ = self._update_stored_selection_async() + first_region, _ = self._update_stored_selection() if first_region is None: return if userprefs().document_highlight_style: self._clear_highlight_regions() - self._when_selection_remains_stable_async( + self._when_selection_remains_stable( self._do_highlights_async, first_region, after_ms=self.debounce_time) if userprefs().show_signature_help and (selection := self._stored_selection): if self._sighelp: @@ -1050,7 +1076,7 @@ def _on_view_updated_async(self) -> None: if previous_char in triggers: self.do_signature_help_async(SignatureHelpTriggerKind.TriggerCharacter, previous_char) - def _update_stored_selection_async(self) -> tuple[sublime.Region | None, bool]: + def _update_stored_selection(self) -> tuple[sublime.Region | None, bool]: """ Stores the current selection in a variable. Note that due to this function (supposedly) running in the async worker thread of ST, it can happen that the @@ -1105,16 +1131,11 @@ def run_sync() -> None: sublime.set_timeout(run_sync) - def _clear_session_views_async(self) -> None: + async def _clear_session_views(self) -> list[Exception]: session_views = self._session_views - - def clear_async() -> None: - nonlocal session_views - for session_view in session_views.values(): - session_view.on_before_remove() - session_views.clear() - - sublime.set_timeout_async(clear_async) + exceptions = await gather_and_flatten_exceptions(*(s.on_before_remove() for s in session_views.values())) + session_views.clear() + return exceptions def on_userprefs_changed_async(self) -> None: if userprefs().document_highlight_style: diff --git a/plugin/edit.py b/plugin/edit.py index 33a344a50..3ec53ba49 100644 --- a/plugin/edit.py +++ b/plugin/edit.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .core.aio import run_coroutine_threadsafe from .core.constants import ChangeEventAction from .core.edit import is_snippet_text_edit from .core.edit import parse_lsp_position @@ -91,10 +92,14 @@ class LspApplyWorkspaceEditCommand(LspWindowCommand): def run( self, session_name: str, edit: WorkspaceEdit, label: str | None = None, is_refactoring: bool = False + ) -> None: + run_coroutine_threadsafe(self._run(session_name, edit, label, is_refactoring)) + + async def _run( + self, session_name: str, edit: WorkspaceEdit, label: str | None = None, is_refactoring: bool = False ) -> None: if session := self.session_by_name(session_name): - sublime.set_timeout_async( - lambda: session.apply_workspace_edit_async(edit, label=label, is_refactoring=is_refactoring)) + await session.apply_workspace_edit(edit, label=label, is_refactoring=is_refactoring) else: debug('Could not find session', session_name, 'required to apply WorkspaceEdit') diff --git a/plugin/execute_command.py b/plugin/execute_command.py index 76451544b..4d756cc13 100644 --- a/plugin/execute_command.py +++ b/plugin/execute_command.py @@ -1,7 +1,9 @@ from __future__ import annotations +from .core.aio import run_coroutine_threadsafe from .core.logging import debug from .core.protocol import Error +from .core.protocol import LSPAny from .core.registry import LspTextCommand from .core.views import first_selection_region from .core.views import offset_to_point @@ -16,6 +18,7 @@ if TYPE_CHECKING: from ..protocol import ExecuteCommandParams + from .core.sessions import Session class LspExecuteCommand(LspTextCommand): @@ -32,15 +35,14 @@ def run(self, params: ExecuteCommandParams = {"command": command_name} if command_args: params["arguments"] = self._expand_variables(command_args) + run_coroutine_threadsafe(self._run(session, command_name, params)) - def handle_response(response: Any) -> None: - assert command_name - if isinstance(response, Error): - self.handle_error_async(response, command_name) - return - self.handle_success_async(response, command_name) - - session.execute_command(params, progress=True, view=self.view).then(handle_response) + async def _run(self, session: Session, command_name: str, params: ExecuteCommandParams) -> None: + try: + result: LSPAny = await session.execute_command(params, progress=True, view=self.view) + self.handle_success_async(result, command_name) + except Error as error: + self.handle_error_async(error, command_name) def handle_success_async(self, result: Any, command_name: str) -> None: """ diff --git a/plugin/folding_range.py b/plugin/folding_range.py index 73fb315e7..5992fa216 100644 --- a/plugin/folding_range.py +++ b/plugin/folding_range.py @@ -69,6 +69,7 @@ def is_visible( point: int | None = None ) -> bool: if not prefetch: + return True # There should be a single empty selection in the view, otherwise this functionality would be misleading selection = self.view.sel() @@ -85,7 +86,7 @@ def is_visible( session = self.best_session(self.capability) if session: params: FoldingRangeParams = {'textDocument': text_document_identifier(self.view)} - session.send_request_async( + session.send_request( Request.foldingRange(params, self.view), partial(self._handle_response_async, view_change_count) ) @@ -156,7 +157,7 @@ def run( pt = selection[0].b if session := self.best_session(self.capability): params: FoldingRangeParams = {'textDocument': text_document_identifier(self.view)} - session.send_request_async( + session.send_request( Request.foldingRange(params, self.view), partial(self._handle_response_manual_async, pt, strict) ) @@ -181,7 +182,7 @@ class LspFoldAllCommand(LspTextCommand): def run(self, edit: sublime.Edit, kind: str | None = None, event: dict | None = None) -> None: if session := self.best_session(self.capability): params: FoldingRangeParams = {'textDocument': text_document_identifier(self.view)} - session.send_request_async( + session.send_request( Request.foldingRange(params, self.view), partial(self._handle_response_async, kind)) def _handle_response_async(self, kind: str | None, response: list[FoldingRange] | None) -> None: diff --git a/plugin/formatting.py b/plugin/formatting.py index 0ea1e8dd5..9def2523a 100644 --- a/plugin/formatting.py +++ b/plugin/formatting.py @@ -3,9 +3,9 @@ from ..protocol import TextDocumentSaveReason from ..protocol import TextEdit from .code_actions import CodeActionsOnFormatTask +from .core.aio import run_coroutine_threadsafe from .core.collections import DottedDict from .core.edit import apply_text_edits -from .core.promise import Promise from .core.protocol import Error from .core.registry import LspTextCommand from .core.registry import windows @@ -21,8 +21,6 @@ from .lsp_task import LspTextCommandWithTasks from functools import partial from typing import Any -from typing import Callable -from typing import Iterator from typing import List from typing import TYPE_CHECKING from typing import Union @@ -32,7 +30,7 @@ if TYPE_CHECKING: from .core.sessions import Session -FormatResponse = Union[List[TextEdit], None, Error] +FormatResponse = Union[List[TextEdit], None] def get_formatter(window: sublime.Window | None, base_scope: str) -> str | None: @@ -44,18 +42,18 @@ def get_formatter(window: sublime.Window | None, base_scope: str) -> str | None: isinstance(project_data, dict) else window_manager.formatters.get(base_scope) -def format_document(text_command: LspTextCommand, formatter: str | None = None) -> Promise[FormatResponse]: +async def format_document(text_command: LspTextCommand, formatter: str | None = None) -> FormatResponse: view = text_command.view if formatter: if session := text_command.session_by_name(formatter, LspFormatDocumentCommand.capability): - return session.send_request_task(text_document_formatting(view)) + return await session.request(text_document_formatting(view)) if session := text_command.best_session(LspFormatDocumentCommand.capability): # Either use the documentFormattingProvider ... - return session.send_request_task(text_document_formatting(view)) + return await session.request(text_document_formatting(view)) if session := text_command.best_session(LspFormatDocumentRangeCommand.capability): # ... or use the documentRangeFormattingProvider and format the entire range. - return session.send_request_task(text_document_range_formatting(view, entire_content_region(view))) - return Promise.resolve(None) + return await session.request(text_document_range_formatting(view, entire_content_region(view))) + return None class WillSaveWaitTask(LspTask): @@ -63,30 +61,21 @@ class WillSaveWaitTask(LspTask): def is_applicable(cls, view: sublime.View) -> bool: return bool(view.file_name()) - def __init__(self, task_runner: LspTextCommand, on_complete: Callable[[], None]) -> None: - super().__init__(task_runner, on_complete) - self._session_iterator: Iterator[Session] | None = None + def __init__(self, text_command: LspTextCommand) -> None: + super().__init__(text_command) - def run_async(self) -> None: - super().run_async() - self._session_iterator = self._task_runner.sessions('textDocumentSync.willSaveWaitUntil') - self._handle_next_session_async() - - def _handle_next_session_async(self) -> None: - session = next(self._session_iterator, None) if self._session_iterator else None - if session: + async def run(self) -> None: + await super().run() + for session in self._text_command.sessions('textDocumentSync.willSaveWaitUntil'): self._purge_changes_async() - view = self._task_runner.view - session.send_request_task(will_save_wait_until(view, reason=TextDocumentSaveReason.Manual)) \ - .then(self._on_response_async) - else: - self._on_complete() - - def _on_response_async(self, response: FormatResponse) -> None: - promise: Promise[None] = Promise.resolve(None) - if response and not isinstance(response, Error) and not self._cancelled: - promise.then(lambda _: apply_text_edits(self._task_runner.view, response, label="Format on Save")) - promise.then(lambda _: self._handle_next_session_async()) + view = self._text_command.view + try: + if text_edits := await session.request( + will_save_wait_until(view, reason=TextDocumentSaveReason.Manual) + ): + await apply_text_edits(self._text_command.view, text_edits, label="Format on Save") + except Exception as ex: + sublime.status_message(f"Failed to apply Will Save Task: {ex}") class FormatOnSaveTask(LspTask): @@ -99,27 +88,27 @@ def is_applicable(cls, view: sublime.View) -> bool: return enabled and bool(view.window()) and bool(view.file_name()) @override - def run_async(self) -> None: - super().run_async() + async def run(self) -> None: + await super().run() self._purge_changes_async() - syntax = self._task_runner.view.syntax() + syntax = self._text_command.view.syntax() if not syntax: return base_scope = syntax.scope - formatter = get_formatter(self._task_runner.view.window(), base_scope) - format_document(self._task_runner, formatter).then(self._on_response_async) - - def _on_response_async(self, response: FormatResponse) -> None: - promise: Promise[None] = Promise.resolve(None) - if response and not isinstance(response, Error) and not self._cancelled: - promise.then(lambda _: apply_text_edits(self._task_runner.view, response, label="Format on Save")) - promise.then(lambda _: self._on_complete()) + formatter = get_formatter(self._text_command.view.window(), base_scope) + try: + if text_edits := await format_document(self._text_command, formatter): + await apply_text_edits(self._text_command.view, text_edits, label="Format On Save") + except Exception as ex: + sublime.status_message(f"Failed to apply Format On Save: {ex}") class LspFormatDocumentCommand(LspTextCommandWithTasks): capability = 'documentFormattingProvider' + label = 'Format File' + @property @override def tasks(self) -> list[type[LspTask]]: @@ -132,7 +121,7 @@ def is_enabled(self, event: dict | None = None, select: bool = False) -> bool: return super().is_enabled() or bool(self.best_session(LspFormatDocumentRangeCommand.capability)) @override - def on_tasks_completed(self, *, select: bool = False, **kwargs: dict[str, Any]) -> None: + async def on_tasks_completed(self, *, select: bool = False, **kwargs: dict[str, Any]) -> None: session_names = [session.config.name for session in self.sessions(self.capability)] syntax = self.view.syntax() if not syntax: @@ -144,19 +133,22 @@ def on_tasks_completed(self, *, select: bool = False, **kwargs: dict[str, Any]) if listener := self.get_listener(): listener.purge_changes_async() if len(session_names) > 1: - formatter = get_formatter(self.view.window(), base_scope) - if formatter: - session = self.session_by_name(formatter, self.capability) - if session: - session.send_request_task(text_document_formatting(self.view)).then(self.on_result_async) + if formatter := get_formatter(self.view.window(), base_scope): + if session := self.session_by_name(formatter, self.capability): + await self._apply_text_edits( + await session.request(text_document_formatting(self.view)), label=self.label + ) return self.select_formatter(base_scope, session_names) else: - format_document(self).then(self.on_result_async) + await self._apply_text_edits(await format_document(self), label=self.label) - def on_result_async(self, result: FormatResponse) -> None: - if result and not isinstance(result, Error): - apply_text_edits(self.view, result, label="Format File") + async def _apply_text_edits(self, text_edits: list[TextEdit] | None, label: str) -> None: + try: + if text_edits: + await apply_text_edits(self.view, text_edits, label=label) + except Exception as ex: + sublime.status_message(f"Failed to {label}: {ex}") def select_formatter(self, base_scope: str, session_names: list[str]) -> None: if window := self.view.window(): @@ -182,10 +174,16 @@ def on_select_formatter(self, base_scope: str, session_names: list[str], index: window.set_project_data(project_data) else: # Save temporarily for this window window_manager.formatters[base_scope] = session_name - if session := self.session_by_name(session_name, self.capability): - if listener := self.get_listener(): - listener.purge_changes_async() - session.send_request_task(text_document_formatting(self.view)).then(self.on_result_async) + + async def do_format() -> None: + if session := self.session_by_name(session_name, self.capability): + if listener := self.get_listener(): + listener.purge_changes_async() + await self._apply_text_edits( + await session.request(text_document_formatting(self.view)), label=self.label + ) + + run_coroutine_threadsafe(do_format()) class LspFormatDocumentRangeCommand(LspTextCommand): @@ -203,25 +201,26 @@ def is_enabled(self, event: dict | None = None, point: int | None = None) -> boo return False def run(self, edit: sublime.Edit, event: dict | None = None) -> None: + run_coroutine_threadsafe(self._run()) + + async def _run(self) -> None: if listener := self.get_listener(): listener.purge_changes_async() - if has_single_nonempty_selection(self.view): - session = self.best_session(self.capability) - selection = first_selection_region(self.view) - if session and selection is not None: - request = text_document_range_formatting(self.view, selection) - session.send_request_task(request).then(self._handle_response_async) - elif self.view.has_non_empty_selection_region(): - if session := self.best_session('documentRangeFormattingProvider.rangesSupport'): - request = text_document_ranges_formatting(self.view) - session.send_request_task(request).then(self._handle_response_async) - - def _handle_response_async(self, response: FormatResponse) -> None: - if isinstance(response, Error): - sublime.status_message(f'Formatting error: {response}') - return - if response: - apply_text_edits(self.view, response, label="Format Selection") + session: Session | None = None + text_edits: list[TextEdit] | None = None + try: + if has_single_nonempty_selection(self.view): + session = self.best_session(self.capability) + selection = first_selection_region(self.view) + if session and selection is not None: + text_edits = await session.request(text_document_range_formatting(self.view, selection)) + elif self.view.has_non_empty_selection_region(): + if session := self.best_session('documentRangeFormattingProvider.rangesSupport'): + text_edits = await session.request(text_document_ranges_formatting(self.view)) + if text_edits is not None: + await apply_text_edits(self.view, text_edits) + except Error as error: + sublime.status_message(f'Formatting error: {error}') class LspFormatCommand(LspTextCommand): diff --git a/plugin/goto.py b/plugin/goto.py index f39460581..83106c37f 100644 --- a/plugin/goto.py +++ b/plugin/goto.py @@ -5,6 +5,7 @@ from ..protocol import DocumentUri from ..protocol import Location from ..protocol import LocationLink +from .core.aio import run_coroutine_threadsafe from .core.constants import DIAGNOSTIC_KINDS from .core.input_handlers import PreselectedListInputHandler from .core.paths import simple_project_path @@ -25,7 +26,7 @@ from .core.views import to_encoded_filename from .core.views import uri_from_view from .locationpicker import LocationPicker -from .locationpicker import open_location_async +from .locationpicker import open_location from collections import Counter from functools import partial from os.path import basename @@ -105,13 +106,13 @@ def _handle_response_async( ) -> None: if isinstance(response, dict): self.view.run_command("add_jump_record", {"selection": [(r.a, r.b) for r in self.view.sel()]}) - open_location_async(session, response, side_by_side, force_group, group) + run_coroutine_threadsafe(open_location(session, response, side_by_side, force_group, group)) elif isinstance(response, list): if len(response) == 0: self._handle_no_results(fallback, side_by_side) elif len(response) == 1: self.view.run_command("add_jump_record", {"selection": [(r.a, r.b) for r in self.view.sel()]}) - open_location_async(session, response[0], side_by_side, force_group, group) + run_coroutine_threadsafe(open_location(session, response[0], side_by_side, force_group, group)) else: self.view.run_command("add_jump_record", {"selection": [(r.a, r.b) for r in self.view.sel()]}) placeholder = self.placeholder_text + " " + self.view.substr(self.view.word(position)) @@ -352,7 +353,7 @@ def confirm(self, value: DiagnosticData | None) -> None: self._open_file(value) elif session := self._session(value): location: Location = {'uri': self.uri, 'range': value['diagnostic']['range']} - sublime.set_timeout_async(partial(session.open_location_async, location)) + run_coroutine_threadsafe(session.open_location(location)) def _session(self, value: DiagnosticData) -> Session | None: session_name = value['session_name'] diff --git a/plugin/hierarchy.py b/plugin/hierarchy.py index 1f5400a5f..bbaf48b20 100644 --- a/plugin/hierarchy.py +++ b/plugin/hierarchy.py @@ -167,7 +167,7 @@ def run(self, edit: sublime.Edit, event: dict | None = None, point: int | None = if position is None: return params = text_document_position_params(self.view, position) - session.send_request_async( + session.send_request( self.request(params, self.view), partial(self._handle_response_async, weakref.ref(session))) def _handle_response_async( diff --git a/plugin/hover.py b/plugin/hover.py index 5a1e1f495..e71333716 100644 --- a/plugin/hover.py +++ b/plugin/hover.py @@ -9,6 +9,8 @@ from ..protocol import Position from ..protocol import Range from .code_actions import filter_quickfix_actions +from .core.aio import call_soon_threadsafe +from .core.aio import run_coroutine_threadsafe from .core.constants import HOVER_ENABLED_KEY from .core.constants import MarkdownLangMap from .core.constants import RegionKey @@ -138,7 +140,7 @@ def run_async() -> None: ] Promise.all(code_action_promises).then(partial(self._handle_code_actions, listener, hover_point)) - sublime.set_timeout_async(run_async) + call_soon_threadsafe(run_async) def request_symbol_hover_async(self, listener: AbstractViewListener, point: int) -> None: hover_promises: list[Promise[ResolvedHover]] = [] @@ -317,11 +319,11 @@ def _on_navigate(self, uri: str) -> None: pass elif scheme == 'file': if window := self.view.window(): - open_file_uri(window, uri) + run_coroutine_threadsafe(open_file_uri(window, uri)) elif scheme == CODE_ACTION_SCHEME: session_name, version, action = decode_code_action_uri(uri) if version == self.view.change_count() and (session := self.session_by_name(session_name)): - sublime.set_timeout_async(lambda: session.run_code_action_async(action, progress=True, view=self.view)) + run_coroutine_threadsafe(session.run_code_action(action, progress=True, view=self.view)) self.view.hide_popup() elif uri == "quick-panel:DocumentLink": if window := self.view.window(): @@ -338,19 +340,20 @@ def on_select(targets: list[str], idx: int) -> None: if session := self.session_by_name(session_name): position: Position = {"line": row, "character": col_utf16} r: Range = {"start": position, "end": position} - sublime.set_timeout_async(partial(session.open_uri_async, uri, r)) + run_coroutine_threadsafe(session.open_uri(uri, r)) elif scheme.lower() in {"http", "https"} or (not scheme and uri.startswith('www.')): open_in_browser(uri) elif scheme: - sublime.set_timeout_async(partial(self.try_open_custom_uri_async, uri)) + run_coroutine_threadsafe(self.try_open_custom_uri(uri)) - def try_open_custom_uri_async(self, uri: str) -> None: + async def try_open_custom_uri(self, uri: str) -> None: uri_parts = urlsplit(uri) r = lsp_range_from_uri_fragment(uri_parts.fragment) if r: uri = urlunsplit(uri_parts._replace(fragment='')) for session in self.sessions(): - if session.try_open_uri_async(uri, r) is not None: + result = await session.try_open_uri(uri, r) + if isinstance(result, sublime.View) or result is None: return @@ -367,7 +370,7 @@ def is_checked(self) -> bool: def run(self) -> None: enable = not self.is_checked() self.window.settings().set(HOVER_ENABLED_KEY, enable) - sublime.set_timeout_async(partial(self._update_views_async, enable)) + call_soon_threadsafe(self._update_views_async, enable) def _has_hover_provider(self, view: sublime.View) -> bool: listener = windows.listener_for_view(view) diff --git a/plugin/inlay_hint.py b/plugin/inlay_hint.py index 77fc43be2..1da47e5e5 100644 --- a/plugin/inlay_hint.py +++ b/plugin/inlay_hint.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .core.aio import run_coroutine_threadsafe from .core.constants import RequestFlags from .core.constants import ST_VERSION from .core.css import css @@ -61,46 +62,28 @@ class LspInlayHintClickCommand(LspTextCommand): def run(self, _edit: sublime.Edit, session_name: str, inlay_hint: InlayHint, phantom_uuid: str, event: dict | None = None, label_part: InlayHintLabelPart | None = None) -> None: + run_coroutine_threadsafe(self._run(session_name, inlay_hint, phantom_uuid, label_part)) + + async def _run(self, session_name: str, inlay_hint: InlayHint, phantom_uuid: str, + label_part: InlayHintLabelPart | None = None) -> None: # Insert textEdits for the given inlay hint. # If a InlayHintLabelPart was clicked, label_part will be passed as an argument to the LspInlayHintClickCommand # and InlayHintLabelPart.command will be executed. session = self.session_by_name(session_name, 'inlayHintProvider') if session and session.has_capability('inlayHintProvider.resolveProvider'): - request = Request.resolveInlayHint(inlay_hint, self.view) - session.send_request_async( - request, - lambda response: self.handle(session_name, response, phantom_uuid, label_part)) - return - self.handle(session_name, inlay_hint, phantom_uuid, label_part) - - def handle(self, session_name: str, inlay_hint: InlayHint, phantom_uuid: str, - label_part: InlayHintLabelPart | None = None) -> None: - self.handle_inlay_hint_text_edits(session_name, inlay_hint, phantom_uuid) - self.handle_label_part_command(session_name, label_part) - - def handle_inlay_hint_text_edits(self, session_name: str, inlay_hint: InlayHint, phantom_uuid: str) -> None: - session = self.session_by_name(session_name, 'inlayHintProvider') - if not session: - return - text_edits = inlay_hint.get('textEdits') - if not text_edits: - return - for sb in session.session_buffers_async(): - sb.remove_inlay_hint_phantom(phantom_uuid) - apply_text_edits(self.view, text_edits, label="Insert Inlay Hint") - - def handle_label_part_command(self, session_name: str, label_part: InlayHintLabelPart | None = None) -> None: - if not label_part: - return - command = label_part.get('command') - if not command: - return - args = { - "session_name": session_name, - "command_name": command["command"], - "command_args": command.get("arguments") - } - self.view.run_command("lsp_execute", args) + inlay_hint = await session.request(Request.resolveInlayHint(inlay_hint, self.view)) + + if session and (text_edits := inlay_hint.get('textEdits')): + for sb in session.session_buffers_async(): + sb.remove_inlay_hint_phantom(phantom_uuid) + await apply_text_edits(self.view, text_edits, label="Insert Inlay Hint") + + if label_part and (command := label_part.get('command')): + self.view.run_command("lsp_execute", { + "session_name": session_name, + "command_name": command["command"], + "command_args": command.get("arguments") + }) def inlay_hint_to_phantom(view: sublime.View, inlay_hint: InlayHint, session: Session) -> sublime.Phantom: diff --git a/plugin/locationpicker.py b/plugin/locationpicker.py index ae1391eac..36cc7afc6 100644 --- a/plugin/locationpicker.py +++ b/plugin/locationpicker.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .core.aio import run_coroutine_threadsafe from .core.constants import ST_PACKAGES_PATH from .core.constants import SublimeKind from .core.logging import debug @@ -8,7 +9,6 @@ from .core.views import to_encoded_filename from typing import TYPE_CHECKING from urllib.request import url2pathname -import functools import sublime import weakref @@ -20,7 +20,7 @@ from .core.sessions import Session -def open_location_async( +async def open_location( session: Session, location: Location | LocationLink, side_by_side: bool, @@ -32,15 +32,12 @@ def open_location_async( flags |= sublime.NewFileFlags.FORCE_GROUP if side_by_side: flags |= sublime.NewFileFlags.ADD_TO_SELECTION | sublime.NewFileFlags.SEMI_TRANSIENT - - def check_success_async(view: sublime.View | None) -> None: - if not view: - uri = get_uri_and_position_from_location(location)[0] - msg = f"Unable to open URI {uri}" - debug(msg) - session.window.status_message(msg) - - session.open_location_async(location, flags, group).then(check_success_async) + view = await session.open_location(location, flags, group) + if not view: + uri = get_uri_and_position_from_location(location)[0] + msg = f"Unable to open URI {uri}" + debug(msg) + session.window.status_message(msg) def open_basic_file( @@ -128,9 +125,9 @@ def _select_entry(self, index: int) -> None: if not open_basic_file(session, uri, position, flags): self._window.status_message(f"Unable to open {uri}") else: - sublime.set_timeout_async( - functools.partial( - open_location_async, session, location, self._side_by_side, self._force_group, self._group)) + run_coroutine_threadsafe( + open_location(session, location, self._side_by_side, self._force_group, self._group) + ) else: self._window.focus_view(self._view) # When a group was specified close the current highlighted diff --git a/plugin/lsp_task.py b/plugin/lsp_task.py index 13f656922..bf8d6b4c5 100644 --- a/plugin/lsp_task.py +++ b/plugin/lsp_task.py @@ -1,14 +1,12 @@ from __future__ import annotations +from .core.aio import run_coroutine_threadsafe from .core.registry import LspTextCommand -from .core.settings import userprefs from abc import ABC from abc import abstractmethod -from functools import partial from typing import Any -from typing import Callable -from typing import final from typing_extensions import override +import asyncio import sublime @@ -16,7 +14,7 @@ class LspTask(ABC): """ Base class for tasks that run from `LspTextCommandWithTasks` command. - Note: The whole task runs on the async thread. + Note: The whole task runs on the asyncio thread. """ @classmethod @@ -24,86 +22,21 @@ class LspTask(ABC): def is_applicable(cls, view: sublime.View) -> bool: pass - def __init__(self, task_runner: LspTextCommand, on_done: Callable[[], None]) -> None: - self._task_runner = task_runner - self._on_done = on_done - self._completed = False - self._cancelled = False + def __init__(self, task_runner: LspTextCommand) -> None: + self._text_command = task_runner self._status_key = type(self).__name__ - def run_async(self) -> None: + async def run(self) -> None: self._erase_view_status() - sublime.set_timeout_async(self._on_timeout, userprefs().on_save_task_timeout_ms) - - def _on_timeout(self) -> None: - if not self._completed and not self._cancelled: - self._set_view_status(f'LSP: Timeout processing {self.__class__.__name__}') - self._cancelled = True - self._on_done() - - def cancel(self) -> None: - self._cancelled = True - - def _set_view_status(self, text: str) -> None: - self._task_runner.view.set_status(self._status_key, text) - sublime.set_timeout_async(self._erase_view_status, 5000) def _erase_view_status(self) -> None: - self._task_runner.view.erase_status(self._status_key) - - def _on_complete(self) -> None: - assert not self._completed - self._completed = True - if not self._cancelled: - self._on_done() + self._text_command.view.erase_status(self._status_key) def _purge_changes_async(self) -> None: - if listener := self._task_runner.get_listener(): + if listener := self._text_command.get_listener(): listener.purge_changes_async() -@final -class TasksRunner: - def __init__( - self, text_command: LspTextCommand, tasks: list[type[LspTask]], on_complete: Callable[[], None] - ) -> None: - self._text_command = text_command - self._tasks = tasks - self._on_tasks_completed = on_complete - self._pending_tasks: list[LspTask] = [] - self._canceled = False - - def run(self) -> None: - for task in self._tasks: - if task.is_applicable(self._text_command.view): - self._pending_tasks.append(task(self._text_command, self._on_task_completed_async)) - self._process_next_task() - - def cancel(self) -> None: - for task in self._pending_tasks: - task.cancel() - self._pending_tasks = [] - self._canceled = True - - def _process_next_task(self) -> None: - if self._pending_tasks: - # Even though we might be on an async thread already, we want to give ST a chance to notify us about - # potential document changes. - sublime.set_timeout_async(self._run_next_task_async) - else: - self._on_tasks_completed() - - def _run_next_task_async(self) -> None: - if self._canceled: - return - current_task = self._pending_tasks[0] - current_task.run_async() - - def _on_task_completed_async(self) -> None: - self._pending_tasks.pop(0) - self._process_next_task() - - class LspTextCommandWithTasks(LspTextCommand, ABC): @property @@ -113,22 +46,33 @@ def tasks(self) -> list[type[LspTask]]: def __init__(self, view: sublime.View) -> None: super().__init__(view) - self._tasks_runner: TasksRunner | None = None + self._tasks_runner: asyncio.Task | None = None def on_before_tasks(self) -> None: """Override this to execute code before the task handler starts.""" - def on_tasks_completed(self, **kwargs: dict[str, Any]) -> None: + async def on_tasks_completed(self, **kwargs: dict[str, Any]) -> None: """Override this to execute code when all tasks are completed.""" - def _on_tasks_completed(self, **kwargs: dict[str, Any]) -> None: - self._tasks_runner = None - self.on_tasks_completed(**kwargs) - @override def run(self, edit: sublime.Edit, **kwargs: dict[str, Any]) -> None: + run_coroutine_threadsafe(self._run(**kwargs)) + + async def _run(self, **kwargs: dict[str, Any]) -> None: if self._tasks_runner: - self._tasks_runner.cancel() + if self._tasks_runner.cancel(): + await self._tasks_runner + self._tasks_runner = None self.on_before_tasks() - self._tasks_runner = TasksRunner(self, self.tasks, partial(self._on_tasks_completed, **kwargs)) - self._tasks_runner.run() + self._tasks_runner = asyncio.create_task(run_tasks(self, self.tasks)) + try: + await asyncio.wait_for(self._tasks_runner, timeout=1) + except asyncio.exceptions.TimeoutError: + sublime.status_message('Running "on save" tasks took too long!') + await self.on_tasks_completed(**kwargs) + + +async def run_tasks(text_command: LspTextCommandWithTasks, tasks: list[type[LspTask]]) -> None: + for task in tasks: + if task.is_applicable(text_command.view): + await task(text_command).run() diff --git a/plugin/rename_file.py b/plugin/rename_file.py index 6e73b04ce..297ba72ed 100644 --- a/plugin/rename_file.py +++ b/plugin/rename_file.py @@ -1,5 +1,7 @@ from __future__ import annotations +from .core.aio import call_soon_threadsafe +from .core.aio import run_coroutine_threadsafe from .core.edit import show_summary_message from .core.logging import debug from .core.open import open_file_uri @@ -106,9 +108,14 @@ def run(self, new_name: str, paths: list[str] | None = None, prompt_workspace_ed "prompt_workspace_edits": False } label = f"Rename {Path(old_path).name} -> {new_name}" - sublime.set_timeout_async(lambda: self.prompt_rename_async(file_rename, label, rename_command_args)) + + call_soon_threadsafe(self.prompt_rename_async, file_rename, label, rename_command_args) return - self.rename_path(old_path, new_name).then(lambda success: self.on_rename_path(success, file_rename)) + + async def run() -> None: + self.on_rename_path(await self.rename_path(old_path, new_name), file_rename) + + run_coroutine_threadsafe(run()) def on_rename_path(self, success: bool, file_rename: FileRename) -> None: if success: @@ -155,7 +162,7 @@ def on_prompt_for_workspace_edits_concluded( .then(lambda _: accepted) return Promise.resolve(False) - def rename_path(self, old: str, new: str) -> Promise[bool]: + async def rename_path(self, old: str, new: str) -> bool: old_path = Path(old) new_path = Path(new) restore_files: list[tuple[str, tuple[int, int], list[sublime.Region]]] = [] @@ -173,14 +180,14 @@ def rename_path(self, old: str, new: str) -> Promise[bool]: if (new_dir := new_path.parent) and not new_dir.exists(): new_dir.mkdir(parents=True) try: - old_path.rename(new_path) + old_path.rename(new_path) # noqa: ASYNC240 except Exception as error: sublime.status_message(f"Rename error: {error}") - return Promise.resolve(False) - return Promise.all([ - open_file_uri(self.window, file_name, group=group[0]).then(partial(self.restore_view, selection, group)) - for file_name, group, selection in reversed(restore_files) - ]).then(lambda _: self.focus_view(last_active_view)).then(lambda _: True) + return False + for file_name, group, selection in reversed(restore_files): + self.restore_view(selection, group, await open_file_uri(self.window, file_name, group=group[0])) + self.focus_view(last_active_view) + return True def notify_did_rename(self, file_rename: FileRename) -> None: for session in self.sessions(): diff --git a/plugin/save_command.py b/plugin/save_command.py index fb2731ab2..030551840 100644 --- a/plugin/save_command.py +++ b/plugin/save_command.py @@ -2,6 +2,8 @@ from .code_actions import CodeActionsOnFormatOnSaveTask from .code_actions import CodeActionsOnSaveTask +from .core.aio import call_soon_threadsafe +from .core.logging import trace from .formatting import FormatOnSaveTask from .formatting import WillSaveWaitTask from .lsp_task import LspTask @@ -30,11 +32,12 @@ def tasks(self) -> list[type[LspTask]]: @override def on_before_tasks(self) -> None: - sublime.set_timeout_async(self._trigger_on_pre_save_async) + call_soon_threadsafe(self._trigger_on_pre_save_async) @override - def on_tasks_completed(self, **kwargs: dict[str, Any]) -> None: + async def on_tasks_completed(self, **kwargs: dict[str, Any]) -> None: # Triggered from set_timeout to preserve original semantics of on_pre_save handling + trace() sublime.set_timeout(lambda: self.view.run_command('save', kwargs)) def _trigger_on_pre_save_async(self) -> None: diff --git a/plugin/session_buffer.py b/plugin/session_buffer.py index ceb110966..dcd7c0b2c 100644 --- a/plugin/session_buffer.py +++ b/plugin/session_buffer.py @@ -33,6 +33,7 @@ from .api import LspPlugin from .code_lens import CodeLensCache from .code_lens import LspToggleCodeLensesCommand +from .core.aio import TaskContainer from .core.constants import AUTO_CLOSE_BRACKETS from .core.constants import ChangeEventAction from .core.constants import CODE_LENS_ANNOTATION_SCOPE @@ -84,16 +85,20 @@ from typing import Any from typing import Callable from typing import cast +from typing import Coroutine +from typing import Union from typing_extensions import Concatenate from typing_extensions import deprecated from typing_extensions import ParamSpec from typing_extensions import TypeGuard from weakref import WeakSet +import asyncio import itertools import sublime import time P = ParamSpec('P') +MaybeCoroutine = Union[None, Coroutine[None, None, None]] # If the total number of characters in the file exceeds this limit, try to send a semantic tokens request only for the # visible part first when the file was just opened @@ -145,7 +150,7 @@ def __init__(self) -> None: self.pending_response: int | None = None -class SessionBuffer: +class SessionBuffer(TaskContainer): """ Holds state per session per buffer. @@ -155,6 +160,7 @@ class SessionBuffer: """ def __init__(self, session_view: SessionViewProtocol, buffer_id: int, uri: DocumentUri) -> None: + super().__init__() view = session_view.view self.opened = False # Every SessionBuffer has its own personal capabilities due to "dynamic registration". @@ -175,7 +181,7 @@ def __init__(self, session_view: SessionViewProtocol, buffer_id: int, uri: Docum self._document_diagnostic_pending_requests: dict[DiagnosticsIdentifier, PendingDocumentDiagnosticRequest | None] = {} # noqa: E501 self._last_synced_version = 0 self._last_text_change_time = 0.0 - self._diagnostics_debouncer_async = DebouncerNonThreadSafe(async_thread=True) + self._diagnostics_debouncer_async = DebouncerNonThreadSafe(self) self._color_phantoms = sublime.PhantomSet(view, "lsp_color") self._document_links: list[DocumentLink] = [] self.semantic_tokens = SemanticTokensData() @@ -193,7 +199,11 @@ def __init__(self, session_view: SessionViewProtocol, buffer_id: int, uri: Docum self._on_type_formatting_triggers: tuple[str, ...] = () self._update_supported_commands() self._update_on_type_formatting_triggers() - self._update_color_scheme_rules(view) + try: + self._update_color_scheme_rules(view) + except KeyError: + # Happens when the view is already closed in the meantime. + pass @property def session(self) -> Session: @@ -220,7 +230,11 @@ def _check_did_open(self, view: sublime.View) -> None: if not language_id: # we're closing return - self.session.send_notification(did_open(view, language_id)) + try: + self.session.send_notification_async(did_open(view, language_id)) + except MissingUriError: + # Closed tab. Just forget about it. + return self.opened = True version = view.change_count() self._last_synced_version = version @@ -235,7 +249,7 @@ def _check_did_open(self, view: sublime.View) -> None: self.do_code_lenses_async(view) if userprefs().link_highlight_style in {"underline", "none"}: self._do_document_link_async(view, version) - self.session.notify_plugin_on_session_buffer_change(self) + self.session.notify_plugin_on_session_buffer_change_async(self) def _check_did_close(self, view: sublime.View) -> None: if self.opened and self.should_notify_did_close(): @@ -271,13 +285,14 @@ def add_session_view(self, sv: SessionViewProtocol) -> None: self.session_views.add(sv) sv.handle_code_lenses_async(self._filter_supported_code_lenses()) - def remove_session_view(self, sv: SessionViewProtocol) -> None: + async def remove_session_view(self, sv: SessionViewProtocol) -> list[Exception]: self._clear_semantic_token_regions(sv.view) self.session_views.remove(sv) if len(self.session_views) == 0: - self._on_before_destroy(sv.view) + return await self._on_before_destroy(sv.view) + return [] - def _on_before_destroy(self, view: sublime.View) -> None: + async def _on_before_destroy(self, view: sublime.View) -> list[Exception]: self.remove_all_inlay_hints() # With pull diagnostics, the client is responsible to update or clear diagnostics when appropriate. # Clear all diagnostics for this view if the file is outside of the workspace folders, so that they don't @@ -294,6 +309,7 @@ def _on_before_destroy(self, view: sublime.View) -> None: # Only send textDocument/didClose when we are the only view left (i.e. there are no other clones). self._check_did_close(view) self.session.unregister_session_buffer_async(self) + return await self.cancel_all_tasks() def register_capability_async( self, @@ -398,7 +414,7 @@ def on_text_changed_async( .then(partial(self._on_type_formatting_result_async, view, change_count)) else: debounced(lambda: self.purge_changes_async(view), FEATURES_TIMEOUT, - lambda: view.is_valid() and change_count == view.change_count(), async_thread=True) + lambda: view.is_valid() and change_count == view.change_count()) def _cancel_pending_requests_async(self) -> None: for identifier, pending_request in self._document_diagnostic_pending_requests.items(): @@ -413,7 +429,7 @@ def on_revert_async(self, view: sublime.View) -> None: self._pending_changes = None # Don't bother with pending changes version = view.change_count() self.session.send_notification(did_change(view, version, None)) - sublime.set_timeout_async(lambda: self._on_after_change_async(view, version)) + self._on_after_change_async(view, version) on_reload_async = on_revert_async @@ -430,15 +446,14 @@ def purge_changes_async(self, view: sublime.View, suppress_requests: bool = Fals changes = self._pending_changes.changes version = self._pending_changes.version try: - notification = did_change(view, version, changes) - self.session.send_notification(notification) + self.create_task(self.session.notify(did_change(view, version, changes))) self._last_synced_version = version except MissingUriError: return # we're closing finally: self._pending_changes = None - self.session.notify_plugin_on_session_buffer_change(self) - sublime.set_timeout_async(lambda: self._on_after_change_async(view, version, suppress_requests)) + self.session.notify_plugin_on_session_buffer_change_async(self) + self._on_after_change_async(view, version, suppress_requests) def _on_after_change_async(self, view: sublime.View, version: int, suppress_requests: bool = False) -> None: if self._is_saving: @@ -520,11 +535,18 @@ def _reset_pending_refresh(self, flags: RequestFlags) -> None: """Reset the refresh marker for the request type(s) given by `flags`.""" self.pending_refreshes &= ~flags - def _if_view_unchanged(self, f: Callable[Concatenate[sublime.View, P], None], version: int) -> Callable[P, None]: + def _if_view_unchanged( + self, + f: Callable[Concatenate[sublime.View, P], MaybeCoroutine | None], + version: int + ) -> Callable[P, None]: """Ensures that the view is at the same version when we were called, before calling the `f` function.""" def handler(*args: P.args, **kwargs: P.kwargs) -> None: if (view := self.some_view()) and view.change_count() == version: - f(view, *args, **kwargs) + if asyncio.iscoroutinefunction(f): + self.create_task(f(view, *args, **kwargs)) + else: + f(view, *args, **kwargs) return handler @@ -635,11 +657,12 @@ def do_document_diagnostic_async(self, view: sublime.View, version: int, *, forc # If the document content changed in the meanwhile, new diagnostic requests will automatically be triggered # from _on_after_change_async after the didChange notification. return + for identifier in self.session.diagnostics.get_identifiers(view): - self._do_document_diagnostic_async(view, identifier, version, forced_update=forced_update) + self.create_task(self._do_document_diagnostic(view, identifier, version, forced_update=forced_update)) self._reset_pending_refresh(RequestFlags.DIAGNOSTIC) - def _do_document_diagnostic_async( + async def _do_document_diagnostic( self, view: sublime.View, identifier: DiagnosticsIdentifier, version: int, *, forced_update: bool = False ) -> None: if version == self._diagnostics_versions.get(identifier, -1) and not forced_update: @@ -653,45 +676,37 @@ def _do_document_diagnostic_async( params['identifier'] = identifier if (result_id := self.session.diagnostics_result_ids.get((self._last_known_uri, identifier))) is not None: params['previousResultId'] = result_id - request_id = self.session.send_request_async( - Request.documentDiagnostic(params, view), - partial(self._on_document_diagnostic_async, identifier, version), - partial(self._on_document_diagnostic_error_async, view, identifier, version) - ) - self._document_diagnostic_pending_requests[identifier] = \ - PendingDocumentDiagnosticRequest(version, request_id) - - def _on_document_diagnostic_async( - self, identifier: DiagnosticsIdentifier, version: int, response: DocumentDiagnosticReport - ) -> None: - self._diagnostics_versions[identifier] = version - self._document_diagnostic_pending_requests[identifier] = None - self.session.diagnostics_result_ids[(self._last_known_uri, identifier)] = response.get('resultId') - if is_related_full_document_diagnostic_report(response): - self.session.handle_diagnostics_async(self._last_known_uri, identifier, version, response['items']) - if related_documents := response.get('relatedDocuments'): - for uri, diagnostic_report in related_documents.items(): - uri = normalize_uri(uri) - self.session.diagnostics_result_ids[(uri, identifier)] = diagnostic_report.get('resultId') - if is_full_document_diagnostic_report(diagnostic_report): - self.session.handle_diagnostics_async(uri, identifier, None, diagnostic_report['items']) - - def _on_document_diagnostic_error_async( - self, view: sublime.View, identifier: DiagnosticsIdentifier, version: int, error: ResponseError - ) -> None: - self._document_diagnostic_pending_requests[identifier] = None - if error['code'] == LSPErrorCodes.ServerCancelled: - data = error.get('data') - if is_diagnostic_server_cancellation_data(data) and data['retriggerRequest']: - # Retrigger the request after a short delay, but only if there are no additional changes to the buffer - # in the meanwhile, because in that case a new request will be sent automatically after the didChange - # notification. - if version != view.change_count(): - return - sublime.set_timeout_async( - lambda: self._if_view_unchanged(self._do_document_diagnostic_async, version)(identifier, version), - DOCUMENT_DIAGNOSTICS_RETRIGGER_DELAY - ) + req = self.session.request(Request.documentDiagnostic(params, view)) + self._document_diagnostic_pending_requests[identifier] = PendingDocumentDiagnosticRequest(version, req.id) + error: Error | None = None + try: + response = await req + self._diagnostics_versions[identifier] = version + self.session.diagnostics_result_ids[(self._last_known_uri, identifier)] = response.get('resultId') + if is_related_full_document_diagnostic_report(response): + self.session.handle_diagnostics_async(self._last_known_uri, identifier, version, response['items']) + if related_documents := response.get('relatedDocuments'): + for uri, diagnostic_report in related_documents.items(): + uri = normalize_uri(uri) + self.session.diagnostics_result_ids[(uri, identifier)] = diagnostic_report.get('resultId') + if is_full_document_diagnostic_report(diagnostic_report): + self.session.handle_diagnostics_async(uri, identifier, None, diagnostic_report['items']) + except Error as e: + error = e + finally: + self._document_diagnostic_pending_requests[identifier] = None + if ( + error + and error.code == LSPErrorCodes.ServerCancelled + and is_diagnostic_server_cancellation_data(error.data) + and error.data['retriggerRequest'] + ): + # Retrigger the request after a short delay, but only if there are no additional changes to the + # buffer in the meanwhile, because in that case a new request will be sent automatically after the + # didChange notification. + await asyncio.sleep(DOCUMENT_DIAGNOSTICS_RETRIGGER_DELAY) + if version == view.change_count(): + self.create_task(self._do_document_diagnostic(view, identifier, version)) # --- textDocument/publishDiagnostics ------------------------------------------------------------------------------ @@ -807,7 +822,7 @@ def _on_type_formatting_result_async( self, view: sublime.View, version: int, result: list[TextEdit] | Error | None ) -> None: if result and not isinstance(result, Error) and version == view.change_count(): - apply_text_edits(view, result) + self.create_task(apply_text_edits(view, result)) # --- textDocument/semanticTokens ---------------------------------------------------------------------------------- @@ -992,13 +1007,25 @@ def remove_all_inlay_hints(self) -> None: # --- textDocument/codeAction -------------------------------------------------------------------------------------- def request_code_actions_async( + self, + view: sublime.View, + region: sublime.Region, + diagnostics: list[Diagnostic], + kinds: list[str | CodeActionKind] | None = None, + trigger_kind: CodeActionTriggerKind = CodeActionTriggerKind.Automatic, + ) -> Promise[list[Command | CodeAction] | BaseException | None]: + return Promise.wrap_task( + self.create_task(self.request_code_actions(view, region, diagnostics, kinds, trigger_kind)) + ) + + async def request_code_actions( self, view: sublime.View, region: sublime.Region, diagnostics: list[Diagnostic], kinds: list[str | CodeActionKind] | None = None, trigger_kind: CodeActionTriggerKind = CodeActionTriggerKind.Automatic - ) -> Promise[list[Command | CodeAction] | Error | None]: + ) -> list[Command | CodeAction] | Error | None: context: CodeActionContext = { 'diagnostics': diagnostics, 'triggerKind': trigger_kind @@ -1010,8 +1037,7 @@ def request_code_actions_async( 'range': region_to_range(view, region), 'context': context } - request = Request.codeAction(params, view) - return self.session.send_request_task(request) + return await self.session.request(Request.codeAction(params, view)) # --- textDocument/codeLens ---------------------------------------------------------------------------------------- diff --git a/plugin/session_view.py b/plugin/session_view.py index 6bef5477b..5b4038ac4 100644 --- a/plugin/session_view.py +++ b/plugin/session_view.py @@ -84,7 +84,7 @@ def __init__(self, listener: AbstractViewListener, session: Session, uri: Docume self._clear_auto_complete_triggers(settings) self._setup_auto_complete_triggers(settings) - def on_before_remove(self) -> None: + async def on_before_remove(self) -> list[Exception]: settings: sublime.Settings = self.view.settings() self._clear_auto_complete_triggers(settings) self.clear_code_lenses_async() @@ -96,7 +96,7 @@ def on_before_remove(self) -> None: for request_id, data in self._active_requests.items(): if data.request.view and not data.canceled: self.session.cancel_request_async(request_id) - self.session.unregister_session_view_async(self) + await self.session.unregister_session_view(self) self.session.config.erase_view_status(self.view) for severity in reversed(DIAGNOSTIC_STYLES.keys()): self.view.erase_regions(f"{self.diagnostics_key(severity, False)}_icon") @@ -104,9 +104,10 @@ def on_before_remove(self) -> None: self.view.erase_regions(f"{self.diagnostics_key(severity, True)}_icon") self.view.erase_regions(f"{self.diagnostics_key(severity, True)}_underline") self.view.erase_regions(RegionKey.DOCUMENT_LINK) - self.session_buffer.remove_session_view(self) + exceptions = await self.session_buffer.remove_session_view(self) if listener := self.listener(): listener.on_diagnostics_updated_async(self.session_buffer, False) + return exceptions def on_initialized(self) -> None: self.session_buffer.on_session_view_initialized(self._view) @@ -289,9 +290,10 @@ def on_capability_removed_async(self, registration_id: str, discarded_capabiliti def has_capability_async(self, capability_path: str) -> bool: return self.session_buffer.has_capability(capability_path) - def shutdown_async(self) -> None: + async def shutdown(self) -> list[Exception]: if listener := self.listener(): - listener.on_session_shutdown_async(self.session) + return await listener.on_session_shutdown(self.session) + return [] def diagnostics_key(self, severity: DiagnosticSeverity, multiline: bool) -> str: return "lsp{}d{}{}".format(self.session.config.name, "m" if multiline else "s", severity) diff --git a/plugin/symbols.py b/plugin/symbols.py index 45cd7ad01..52b18f3f4 100644 --- a/plugin/symbols.py +++ b/plugin/symbols.py @@ -8,6 +8,7 @@ from ..protocol import SymbolKind from ..protocol import SymbolTag from ..protocol import WorkspaceSymbol +from .core.aio import run_coroutine_threadsafe from .core.constants import SYMBOL_KINDS from .core.input_handlers import DynamicListInputHandler from .core.input_handlers import PreselectedListInputHandler @@ -328,25 +329,23 @@ class LspWorkspaceSymbolsCommand(LspWindowCommand): capability = 'workspaceSymbolProvider' def run(self, symbol: WorkspaceSymbolValue) -> None: + run_coroutine_threadsafe(self._run(symbol)) + + async def _run(self, symbol: WorkspaceSymbolValue) -> None: session_name = symbol['session'] if session := self.session_by_name(session_name): if location := symbol.get('location'): - session.open_location_async(location, sublime.NewFileFlags.ENCODED_POSITION) + await session.open_location(location, sublime.NewFileFlags.ENCODED_POSITION) elif workspace_symbol := symbol.get('workspaceSymbol'): - session.send_request( - Request.resolveWorkspaceSymbol(workspace_symbol), - partial(self._on_resolved_symbol_async, session_name)) + workspace_symbol = await session.request(Request.resolveWorkspaceSymbol(workspace_symbol)) + location = cast('Location', workspace_symbol['location']) + await session.open_location(location, sublime.NewFileFlags.ENCODED_POSITION) def input(self, args: dict[str, Any]) -> sublime_plugin.ListInputHandler | None: if 'symbol' not in args: return WorkspaceSymbolsInputHandler(self, args) return None - def _on_resolved_symbol_async(self, session_name: str, response: WorkspaceSymbol) -> None: - if session := self.session_by_name(session_name): - location = cast('Location', response['location']) - session.open_location_async(location, sublime.NewFileFlags.ENCODED_POSITION) - class WorkspaceSymbolsInputHandler(DynamicListInputHandler): diff --git a/plugin/tooling.py b/plugin/tooling.py index 3faaeeb64..882540380 100644 --- a/plugin/tooling.py +++ b/plugin/tooling.py @@ -4,11 +4,13 @@ from .api import LspPlugin from .api import OnPreStartContext from .api import PluginStartError +from .core.aio import run_coroutine_threadsafe from .core.css import css from .core.logging import debug from .core.registry import windows from .core.transports import TransportCallbacks from .core.transports import TransportWrapper +from .core.types import ClientConfig from .core.version import __version__ from .core.views import extract_variables from .core.views import make_command_link @@ -21,18 +23,19 @@ from typing import Callable from typing import cast from typing import TYPE_CHECKING +import asyncio import json import mdpopups import os import sublime import sublime_plugin import textwrap +import traceback import urllib.parse import urllib.request if TYPE_CHECKING: from .core.types import Capabilities - from .core.types import ClientConfig from .session_buffer import SessionBuffer @@ -326,19 +329,15 @@ def on_selected(self, selected_index: int, configs: list[ClientConfig], active_v output_sheet = mdpopups.new_html_sheet( self.window, f'Server: {config.name}', '# Running server test...', css=css().sheets, wrapper_class=css().sheets_classname) - sublime.set_timeout_async(lambda: self.test_run_server_async(config, self.window, active_view, output_sheet)) - - def test_run_server_async(self, config: ClientConfig, window: sublime.Window, - active_view: sublime.View, output_sheet: sublime.HtmlSheet) -> None: - server = ServerTestRunner( - config, window, active_view, + # Store the instance so that it's not GC'ed before it's finished. + self.test_runner: ServerTestRunner | None = ServerTestRunner( + config, self.window, active_view, lambda resolved_command, output, exit_code: self.update_sheet( config, active_view, output_sheet, resolved_command, output, exit_code)) - # Store the instance so that it's not GC'ed before it's finished. - self.test_runner: ServerTestRunner | None = server + run_coroutine_threadsafe(self.test_runner.run()) def update_sheet(self, config: ClientConfig, active_view: sublime.View | None, output_sheet: sublime.HtmlSheet, - resolved_command: list[str], server_output: str, exit_code: int) -> None: + resolved_command: list[str] | None, server_output: str, exit_code: int) -> None: self.test_runner = None frontmatter = mdpopups.format_frontmatter({'allow_code_wrap': True}) contents = self.get_contents(config, active_view, resolved_command, server_output, exit_code) @@ -348,7 +347,7 @@ def update_sheet(self, config: ClientConfig, active_view: sublime.View | None, o formatted = f'{frontmatter}{copy_link}\n{contents}' mdpopups.update_html_sheet(output_sheet, formatted, css=css().sheets, wrapper_class=css().sheets_classname) - def get_contents(self, config: ClientConfig, active_view: sublime.View | None, resolved_command: list[str], + def get_contents(self, config: ClientConfig, active_view: sublime.View | None, resolved_command: list[str] | None, server_output: str, exit_code: int) -> str: lines = [] @@ -365,8 +364,9 @@ def line(s: str) -> None: line(f' - exit code: {exit_code}\n - output\n{self.code_block(server_output)}') line('## Server Configuration') - line(f' - command\n{self.json_dump(config.command)}') - line(' - shell command\n{}'.format(self.code_block(list2cmdline(resolved_command), 'sh'))) + if resolved_command: + line(f' - command\n{self.json_dump(config.command)}') + line(' - shell command\n{}'.format(self.code_block(list2cmdline(resolved_command), 'sh'))) line(f' - selector\n{self.code_block(config.selector)}') line(f' - priority_selector\n{self.code_block(config.priority_selector)}') line(' - init_options') @@ -495,45 +495,57 @@ def __init__( config: ClientConfig, window: sublime.Window, initiating_view: sublime.View, - on_close: Callable[[list[str], str, int], None] + on_close: Callable[[list[str] | None, str, int], None] ) -> None: + self._config = config + self._window = window + self._initiating_view = initiating_view self._on_close = on_close self._transport: TransportWrapper | None = None - self._resolved_command: list[str] = [] + self._resolved_command: list[str] | None = None self._stderr_lines: list[str] = [] + + async def run(self) -> None: + view = self._initiating_view + file_path = view.file_name() or '' + config = ClientConfig.from_config(self._config, {}) + loop = asyncio.get_running_loop() + try: - variables = extract_variables(window) + workspace = ProjectFolders(self._window) + workspace_folders = sorted_workspace_folders(workspace.folders, file_path) plugin_class = get_plugin(config.name) - workspace = ProjectFolders(window) - workspace_folders = sorted_workspace_folders(workspace.folders, initiating_view.file_name() or '') - cwd = None + variables = extract_variables(self._window) + cwd = workspace_folders[0].path if workspace_folders else None + context = OnPreStartContext(config, variables, view, cwd, workspace_folders) if plugin_class: - # TODO: We should share this common code with WindowManager.start_async - cwd = workspace_folders[0].path if workspace_folders else None - plugin_context = OnPreStartContext(config, variables, initiating_view, cwd, workspace_folders) + # TODO: We should share this common code with WindowManager.start if issubclass(plugin_class, LspPlugin): - plugin_class.on_pre_start_async(plugin_context) + if plugin_class.use_asyncio: + await plugin_class.on_pre_start(context) + else: + await loop.run_in_executor(None, plugin_class.on_pre_start_async, context) + cwd = context.working_directory else: if plugin_class.needs_update_or_installation(): - plugin_class.install_or_update() + await loop.run_in_executor(None, plugin_class.install_or_update) additional_variables = plugin_class.additional_variables() if isinstance(additional_variables, dict): variables.update(additional_variables) - reason = plugin_class.can_start(window, initiating_view, workspace_folders, config) + reason = plugin_class.can_start( + self._window, view, workspace_folders, config) if reason: raise PluginStartError(f'Plugin.can_start() prevented the start due to: {reason}') - if new_cwd := plugin_class.on_pre_start(window, initiating_view, workspace_folders, config): + if new_cwd := plugin_class.on_pre_start(self._window, view, workspace_folders, config): cwd = new_cwd + transport_config = config.create_transport_config() - self._transport = transport_config.start(config.command, config.env, cwd, variables, self) + self._transport = await transport_config.start(config.command, config.env, cwd, variables, self) self._resolved_command = self._transport.process_args - sublime.set_timeout_async(self.force_close_transport, self.CLOSE_TIMEOUT_SEC * 1000) + await asyncio.sleep(self.CLOSE_TIMEOUT_SEC) + await self._transport.close() except Exception as ex: - self.on_transport_close(-1, ex) - - def force_close_transport(self) -> None: - if self._transport: - self._transport.close() + await self.on_transport_close(-1, ex) def on_payload(self, payload: dict[str, Any]) -> None: pass @@ -541,9 +553,11 @@ def on_payload(self, payload: dict[str, Any]) -> None: def on_stderr_message(self, message: str) -> None: self._stderr_lines.append(message) - def on_transport_close(self, exit_code: int, exception: Exception | None) -> None: - self._transport = None - output = str(exception) if exception else '\n'.join(self._stderr_lines).rstrip() + async def on_transport_close(self, exit_code: int, exception: Exception | None) -> None: + if exception: + output = ''.join(traceback.format_exception(type(exception), exception, exception.__traceback__)) + else: + output = '\n'.join(self._stderr_lines).rstrip() sublime.set_timeout(lambda: self._on_close(self._resolved_command, output, exit_code)) diff --git a/stubs/sublime_aio.pyi b/stubs/sublime_aio.pyi new file mode 100644 index 000000000..62f86e266 --- /dev/null +++ b/stubs/sublime_aio.pyi @@ -0,0 +1,84 @@ +import asyncio +import concurrent +import concurrent.futures +import sublime +import sublime_plugin +from _typeshed import Incomplete +from abc import ABCMeta +from collections.abc import Coroutine +from contextvars import Context +from typing import Any, Callable, TypeVar +from typing_extensions import ParamSpec + +__all__ = ['__version__', 'active_window', 'ApplicationCommand', 'call_coroutine', 'call_soon_threadsafe', 'debounced', 'EventListener', 'InputCancelledError', 'run_coroutine', 'TextChangeListener', 'View', 'ViewCommand', 'ViewEventListener', 'Window', 'WindowCommand', 'windows'] + +P = ParamSpec('P') +T = TypeVar('T') +EL = TypeVar('EL', bound='EventListener') +VEL = TypeVar('VEL', bound='ViewEventListener') +__version__: str + +class ExitEvent: + @classmethod + def aquire(cls) -> None: ... + @classmethod + def release(cls) -> None: ... + @classmethod + def wait(cls) -> None: ... + +def debounced(delay_in_ms: int): ... +def run_coroutine(coro: Coroutine[object, object, T]) -> concurrent.futures.Future[T]: ... +def call_coroutine(coro: Coroutine[object, object, None]) -> asyncio.Handle: ... +def call_soon_threadsafe(callback: Callable[..., None], *args: Any, context: Context | None = None) -> asyncio.Handle: ... +def active_window() -> Window: ... +def windows() -> list[Window]: ... + +class ApplicationCommand(sublime_plugin.ApplicationCommand): + def run_(self, edit_token: int, args: Any) -> None: ... + async def run(self, **kwargs: Any) -> None: ... + +class WindowCommand(sublime_plugin.WindowCommand): + window: Incomplete + def __init__(self, window: sublime.Window) -> None: ... + def run_(self, edit_token: int, args: Any) -> None: ... + async def run(self, **kwargs: Any) -> None: ... + +class ViewCommand(sublime_plugin.TextCommand): + def run_(self, edit_token: int, args: Any) -> None: ... + async def run(self, **kwargs: Any) -> None: ... + +class CoroutineAdapter: + coro_func: Incomplete + def __init__(self, coro_func: Callable[..., Coroutine[object, object, None]]) -> None: ... + def __call__(self, *args, **kwargs: Any) -> None: ... + def callback(self, *args, **kwargs: Any) -> None: ... + +class AsyncEventListenerType(ABCMeta): + def __new__(mcs: type[AsyncEventListenerType], name: str, bases: tuple[type, ...], attrs: dict[str, object]) -> AsyncEventListenerType: ... + +class EventListener(sublime_plugin.EventListener, metaclass=AsyncEventListenerType): ... +class ViewEventListener(sublime_plugin.ViewEventListener, metaclass=AsyncEventListenerType): ... + +class AsyncTextChangeListenerType(ABCMeta): + def __new__(mcs: type[AsyncTextChangeListenerType], name: str, bases: tuple[type, ...], attrs: dict[str, object]) -> AsyncTextChangeListenerType: ... + +class TextChangeListener(sublime_plugin.TextChangeListener, metaclass=AsyncTextChangeListenerType): ... +class InputCancelledError(Exception): ... + +class Window(sublime.Window): + def active_view(self) -> View | None: ... + def new_file(self, flags=..., syntax: str = '') -> View: ... + def open_file(self, fname: str, flags=..., group: int = -1) -> View: ... + def find_open_file(self, fname: str, group: int = -1) -> View | None: ... + def views(self, *, include_transient: bool = False) -> list[View]: ... + def active_view_in_group(self, group: int) -> View | None: ... + def views_in_group(self, group: int) -> list[View]: ... + def transient_view_in_group(self, group: int) -> View | None: ... + def create_output_panel(self, name: str, unlisted: bool = False) -> View: ... + def find_output_panel(self, name: str) -> View | None: ... + async def show_input_panel(self, caption: str, initial_text: str = '', on_change: Callable[[sublime.View, str], Coroutine[object, object, T]] | None = None) -> str: ... + async def show_quick_panel(self, items: list[str] | list[list[str]] | list[sublime.QuickPanelItem], flags: sublime.QuickPanelFlags = ..., selected_index: int = -1, on_highlight: Callable[[int], Coroutine[object, object, T]] | None = None, placeholder: str | None = None) -> int: ... + +class View(sublime.View): + def window(self) -> Window | None: ... + def clones(self) -> list[View]: ... diff --git a/tests/async_test_case.py b/tests/async_test_case.py new file mode 100644 index 000000000..cdaccf6ce --- /dev/null +++ b/tests/async_test_case.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any +from typing import Callable +from typing import Coroutine +from typing import Protocol +from typing_extensions import override +from unittesting import DeferrableTestCase +import asyncio +import inspect + + +class FutureLike(Protocol): + def done(self) -> bool: ... + def result(self) -> Any: ... + def exception(self) -> BaseException | None: ... + def cancelled(self) -> bool: ... + def add_done_callback(self, fn: Callable[[FutureLike], Any]) -> None: ... + + +class AsyncTestCase(DeferrableTestCase): + timeout_ms: int = 2000 + + @classmethod + def run_coroutine(cls, coro: Coroutine) -> FutureLike: + """Override this method and run the given coroutine (using sublime_aio.run_coroutine for instance).""" + raise NotImplementedError + + @classmethod + def _runCoro(cls, coro: Coroutine[Any, Any, Any]) -> Generator: + + async def withTimeout() -> None: + task = asyncio.create_task(coro) + _, pending = await asyncio.wait({task}, timeout=cls.timeout_ms / 1000, return_when=asyncio.FIRST_COMPLETED) + if task in pending: + print("\n=== BEGIN: COROUTINE STACK BEFORE CANCELLATION ===") + task.print_stack() + print("=== END: COROUTINE STACK BEFORE CANCELLATION ===") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + raise TimeoutError + await task + + future = cls.run_coroutine(withTimeout()) + + class Signal: + def __init__(self) -> None: + self.done = False + self.exception: BaseException | None = None + + def check(self) -> bool: + if self.exception: + raise self.exception + return self.done + + signal = Signal() + + def onDone(future: FutureLike) -> None: + if ex := future.exception(): + signal.exception = ex + elif future.done(): + signal.done = True + + future.add_done_callback(onDone) + yield {"condition": signal.check, "timeout": cls.timeout_ms} + + @classmethod + async def asyncSetUpClass(cls) -> None: + pass + + @classmethod + async def asyncTearDownClass(cls) -> None: + pass + + async def asyncDoCleanups(self) -> None: + pass + + @override + @classmethod + def setUpClass(cls) -> Generator: + print("setUpClass was called") + yield from cls._runCoro(cls.asyncSetUpClass()) + + @override + @classmethod + def tearDownClass(cls) -> Generator: + print("tearDownClass was called") + yield from cls._runCoro(cls.asyncTearDownClass()) + + @override + def doCleanups(self) -> Generator: + yield from self._runCoro(self.asyncDoCleanups()) + + @override + def _callSetUp(self) -> Generator | None: + deferred = self.setUp() + if isinstance(deferred, Generator): + yield from deferred + elif inspect.iscoroutine(deferred): + yield from self._runCoro(deferred) + + @override + def _callTestMethod(self, method: Callable[[], Coroutine | Generator | None]) -> Generator | None: + deferred = method() + if isinstance(deferred, Generator): + yield from deferred + elif inspect.iscoroutine(deferred): + yield from self._runCoro(deferred) + + @override + def _callTearDown(self) -> Generator | None: + deferred = self.tearDown() + if isinstance(deferred, Generator): + yield from deferred + elif inspect.iscoroutine(deferred): + yield from self._runCoro(deferred) + + @override + def _callCleanup( + self, function: Callable[..., Coroutine | Generator | None], *args: Any, **kwargs: Any + ) -> Generator | None: + deferred = function(*args, **kwargs) + if isinstance(deferred, Generator): + yield from deferred + elif inspect.iscoroutine(deferred): + yield from self._runCoro(deferred) diff --git a/tests/setup.py b/tests/setup.py index 71df41cea..08e4666da 100644 --- a/tests/setup.py +++ b/tests/setup.py @@ -1,27 +1,36 @@ from __future__ import annotations +from .async_test_case import AsyncTestCase +from .async_test_case import FutureLike from .test_mocks import basic_responses -from collections.abc import Generator +from LSP.plugin.core.aio import next_frame +from LSP.plugin.core.aio import run_coroutine_threadsafe from LSP.plugin.core.collections import DottedDict -from LSP.plugin.core.promise import Promise +from LSP.plugin.core.open import open_file from LSP.plugin.core.protocol import Notification from LSP.plugin.core.protocol import Request from LSP.plugin.core.registry import windows from LSP.plugin.core.settings import client_configs from LSP.plugin.core.types import ClientConfig -from LSP.plugin.core.types import ClientStates +from LSP.plugin.core.url import filename_to_uri from LSP.plugin.documents import DocumentSyncListener from os import environ from os.path import join from sublime_plugin import view_event_listeners from typing import Any +from typing import Callable +from typing import Coroutine from typing import TYPE_CHECKING -from unittesting import DeferrableTestCase +from typing_extensions import override +import asyncio import sublime if TYPE_CHECKING: - from collections.abc import Generator - from LSP.plugin.core.promise import Promise + from LSP.plugin.core.sessions import CancellableInflightRequest + from LSP.plugin.core.sessions import Session + from LSP.plugin.core.windows import WindowManager + from LSP.protocol import CodeAction + from LSP.protocol import LSPAny CI = any(key in environ for key in ("TRAVIS", "CI", "GITHUB_ACTIONS")) @@ -29,24 +38,6 @@ text_config = ClientConfig(name="textls", selector="text.plain", command=[], tcp_port=None) -class YieldPromise: - __slots__ = ("__done", "__result") - - def __init__(self) -> None: - self.__done = False - - def __call__(self) -> bool: - return self.__done - - def fulfill(self, result: Any = None) -> None: - assert not self.__done - self.__result = result - self.__done = True - - def result(self) -> Any: - return self.__result - - def make_stdio_test_config(name: str, init_options: dict[str, Any] | None = None) -> ClientConfig: """Create a config for starting the fake language server in STDIO mode.""" return ClientConfig( @@ -96,10 +87,11 @@ def remove_config(config: ClientConfig) -> None: client_configs.remove_for_testing(config) -def close_test_view(view: sublime.View | None) -> Generator: +async def close_test_view(view: sublime.View | None) -> None: if view: view.set_scratch(True) - yield {"condition": lambda: not view.is_loading(), "timeout": TIMEOUT_TIME} + while view.is_loading(): # noqa: ASYNC110 + await asyncio.sleep(0.05) view.close() @@ -107,45 +99,74 @@ def expand(s: str, w: sublime.Window) -> str: return sublime.expand_variables(s, w.extract_variables()) -class TextDocumentTestCase(DeferrableTestCase): +class SublimeAioTestCase(AsyncTestCase): + timeout_ms = TIMEOUT_TIME + + @classmethod + def run_coroutine(cls, coro: Coroutine) -> FutureLike: + return run_coroutine_threadsafe(coro) + + +class TextDocumentTestCase(SublimeAioTestCase): + + config: ClientConfig + wm: WindowManager + view: sublime.View + session: Session + @classmethod def get_stdio_test_config(cls) -> ClientConfig: return make_stdio_test_config("TEST") + @override @classmethod - def setUpClass(cls) -> Generator: - super().setUpClass() + async def asyncSetUpClass(cls) -> None: + print("asyncSetUpClass") test_name = cls.get_test_name() server_capabilities = cls.get_test_server_capabilities() window = sublime.active_window() filename = expand(join("$packages", "LSP", "tests", f"{test_name}.txt"), window) - open_view = window.find_open_file(filename) - yield from close_test_view(open_view) + await close_test_view(window.find_open_file(filename)) cls.config = cls.get_stdio_test_config() cls.config.initialization_options.set("serverResponse", server_capabilities) add_config(cls.config) - cls.wm = windows.lookup(window) - cls.view = window.open_file(filename) - yield {"condition": lambda: not cls.view.is_loading(), "timeout": TIMEOUT_TIME} - yield cls.ensure_document_listener_created - yield {"condition": lambda: cls.wm.get_session(cls.config.name, filename) is not None, "timeout": TIMEOUT_TIME} - cls.session = cls.wm.get_session(cls.config.name, filename) - yield {"condition": lambda: cls.session.state == ClientStates.READY, "timeout": TIMEOUT_TIME} - cls.initialize_params = yield from cls.await_message("initialize") - yield from cls.await_message("initialized") - yield from close_test_view(cls.view) - - def setUp(self) -> Generator: + if wm := windows.lookup(window): + cls.wm = wm + else: + raise AssertionError("unable to find WindowManager") + if view := await open_file(window, filename_to_uri(filename)): + cls.view = view + else: + raise AssertionError(f"unable to open file {filename}") + if listener := cls.ensure_document_listener_created(): + print("starting", cls.config) + if session := await cls.wm.start(cls.config, listener): + cls.session = session + else: + raise AssertionError("unable to start session") + else: + raise AssertionError(f"unable to find listener for view {cls.view.id()}") + print("awaiting initialize request") + cls.initialize_params = await cls.await_message("initialize") + print("awaiting initialized notification") + await cls.await_message("initialized") + + @override + async def setUp(self) -> None: + print("setUp") window = sublime.active_window() filename = expand(join("$packages", "LSP", "tests", f"{self.get_test_name()}.txt"), window) - open_view = window.find_open_file(filename) - if not open_view: - self.__class__.view = window.open_file(filename) - yield {"condition": lambda: not self.view.is_loading(), "timeout": TIMEOUT_TIME} - self.assertTrue(self.wm.get_config_manager().match_view(self.view, self.wm.workspace_folders)) + if view := await open_file(sublime.active_window(), filename_to_uri(filename)): + self.__class__.view = view + else: + raise AssertionError(f"unable to open file {filename}") self.init_view_settings() - yield self.ensure_document_listener_created - params = yield from self.await_message("textDocument/didOpen") + self.assertIsNotNone(self.ensure_document_listener_created()) + params = await self.await_message("textDocument/didOpen") + self.assertIsInstance(params, dict) + assert isinstance(params, dict) + self.assertIsInstance(params["textDocument"], dict) + assert isinstance(params["textDocument"], dict) self.assertEqual(params["textDocument"]["version"], 0) @classmethod @@ -156,9 +177,9 @@ def get_test_name(cls) -> str: def get_test_server_capabilities(cls) -> dict: return basic_responses["initialize"] - @classmethod - def init_view_settings(cls) -> None: - s = cls.view.settings().set + def init_view_settings(self) -> None: + assert self.view + s = self.view.settings().set s("auto_complete_selector", "text") s("ensure_newline_at_eof_on_save", False) s("rulers", []) @@ -168,7 +189,7 @@ def init_view_settings(cls) -> None: s("lsp_format_on_save", False) @classmethod - def ensure_document_listener_created(cls) -> bool: + def ensure_document_listener_created(cls) -> DocumentSyncListener | None: assert cls.view # Bug in ST3? Either that, or CI runs with ST window not in focus and that makes ST3 not trigger some # events like on_load_async, on_activated, on_deactivated. That makes things not properly initialize on @@ -176,12 +197,17 @@ def ensure_document_listener_created(cls) -> bool: # Revisit this once we're on ST4. for listener in view_event_listeners[cls.view.id()]: if isinstance(listener, DocumentSyncListener): - sublime.set_timeout_async(listener.on_activated_async) - return True - return False + return listener + return None + + @staticmethod + async def wait_until_st_state(condition: Callable[[], bool]) -> None: + """Returns when the given state has been reached.""" + while not condition(): + await next_frame() @classmethod - def await_message(cls, method: str, promise: YieldPromise | None = None) -> Generator[Any, None, Any]: + def await_message(cls, method: str) -> CancellableInflightRequest[LSPAny]: """ Awaits until server receives a request with a specified method. @@ -190,130 +216,76 @@ def await_message(cls, method: str, promise: YieldPromise | None = None) -> Gene request yet, it will wait for it and then respond. :param method: The method type that we are awaiting response for. - :param promise: The optional promise to fullfill on response. - :returns: A generator with resolved value. + :returns: resolved value. """ # cls.assertIsNotNone(cls.session) assert cls.session - if promise is None: - promise = YieldPromise() - - def handler(params: Any) -> None: - promise.fulfill(params) + return cls.session.request(Request("$test/getReceived", {"method": method})) - def error_handler(params: Any) -> None: - print("Got error:", params, "awaiting timeout :(") - - cls.session.send_request(Request("$test/getReceived", {"method": method}), handler, error_handler) - yield from cls.await_promise(promise) - return promise.result() # noqa: B901 - - def make_server_do_fake_request(self, method: str, params: Any) -> YieldPromise: - promise = YieldPromise() - - def on_result(params: Any) -> None: - promise.fulfill(params) - - def on_error(params: Any) -> None: - promise.fulfill(params) - - req = Request("$test/fakeRequest", {"method": method, "params": params}) - self.session.send_request(req, on_result, on_error) - return promise + @classmethod + def make_server_do_fake_request(cls, method: str, params: LSPAny) -> CancellableInflightRequest[LSPAny]: + """Make the fake server do an arbitrary request.""" + assert cls.session + return cls.session.request(Request("$test/fakeRequest", {"method": method, "params": params})) @classmethod - def await_promise(cls, promise: YieldPromise | Promise) -> Generator[Any, None, Any]: - if isinstance(promise, YieldPromise): - yielder = promise - else: - yielder = YieldPromise() - promise.then(yielder.fulfill) - yield {"condition": yielder, "timeout": TIMEOUT_TIME} - return yielder.result() # noqa: B901 - - def await_run_code_action(self, code_action: dict[str, Any]) -> Generator: - promise = YieldPromise() - sublime.set_timeout_async( - lambda: self.session.run_code_action_async(code_action, progress=False, view=self.view).then( - promise.fulfill - ) - ) - yield from self.await_promise(promise) - - def set_response(self, method: str, response: Any) -> None: + async def await_run_code_action(cls, code_action: CodeAction) -> LSPAny: + assert cls.session + return await cls.session.run_code_action(code_action, progress=False, view=cls.view) + + async def mock_response(self, method: str, response: LSPAny) -> None: + """Set up what the fake server should reply when it receives this method.""" self.assertIsNotNone(self.session) assert self.session - self.session.send_notification(Notification("$test/setResponse", {"method": method, "response": response})) + await self.session.notify(Notification("$test/setResponse", {"method": method, "response": response})) - def set_responses(self, responses: list[tuple[str, Any]]) -> Generator: + async def mock_responses(self, responses: list[tuple[str, LSPAny]]) -> None: + """Set up what the fake server should reply, given these request methods.""" self.assertIsNotNone(self.session) assert self.session - promise = YieldPromise() - - def handler(params: Any) -> None: - promise.fulfill(params) - - def error_handler(params: Any) -> None: - print("Got error:", params, "awaiting timeout :(") - payload = [{"method": method, "response": responses} for method, responses in responses] - self.session.send_request(Request("$test/setResponses", payload), handler, error_handler) - yield from self.await_promise(promise) + await self.session.request(Request("$test/setResponses", payload)) - def await_client_notification(self, method: str, params: Any = None) -> Generator: + async def mock_client_notification(self, method: str, params: LSPAny = None) -> LSPAny: + """Emit an arbitrary notification from the fake server.""" self.assertIsNotNone(self.session) assert self.session - promise = YieldPromise() - - def handler(params: Any) -> None: - promise.fulfill(params) + await self.session.request(Request("$test/sendNotification", {"method": method, "params": params})) + return params - def error_handler(params: Any) -> None: - print("Got error:", params, "awaiting timeout :(") - - req = Request("$test/sendNotification", {"method": method, "params": params}) - self.session.send_request(req, handler, error_handler) - yield from self.await_promise(promise) - - def await_clear_view_and_save(self) -> Generator: + async def await_clear_view_and_save(self) -> None: assert isinstance(self.view, sublime.View) self.view.run_command("select_all") self.view.run_command("left_delete") self.view.run_command("save") - yield from self.await_message("textDocument/didChange") - yield from self.await_message("textDocument/didSave") + await self.await_message("textDocument/didChange") + await self.await_message("textDocument/didSave") - def await_view_change(self, expected_change_count: int) -> Generator: + async def await_view_change(self, expected_change_count: int) -> None: assert isinstance(self.view, sublime.View) - - def condition() -> bool: - nonlocal self, expected_change_count - assert self.view - v = self.view - return v.change_count() == expected_change_count - - yield {"condition": condition, "timeout": TIMEOUT_TIME} + await self.wait_until_st_state(lambda: self.view.change_count() == expected_change_count) def insert_characters(self, characters: str) -> int: assert isinstance(self.view, sublime.View) self.view.run_command("insert", {"characters": characters}) return self.view.change_count() + @override @classmethod - def tearDownClass(cls) -> Generator: - if cls.session and cls.wm: - sublime.set_timeout_async(cls.session.end_async) - yield lambda: cls.session.state == ClientStates.STOPPING - if cls.view: - yield lambda: cls.wm.get_session(cls.config.name, cls.view.file_name()) is None - cls.session = None - cls.wm = None - # restore the user's configs - remove_config(cls.config) - super().tearDownClass() - - def doCleanups(self) -> Generator: - if self.view and self.view.is_valid(): - yield from close_test_view(self.view) - yield from super().doCleanups() + async def asyncTearDownClass(cls) -> None: + try: + if cls.session and cls.wm: + await cls.session.end() + finally: + # restore the user's configs + remove_config(cls.config) + await super().asyncTearDownClass() + + @override + async def asyncDoCleanups(self) -> None: + try: + if self.view and self.view.is_valid(): + await close_test_view(self.view) + except Exception: + pass diff --git a/tests/test_code_actions.py b/tests/test_code_actions.py index 8da252a12..734fdcc83 100644 --- a/tests/test_code_actions.py +++ b/tests/test_code_actions.py @@ -14,19 +14,23 @@ from LSP.plugin.core.views import kind_contains_other_kind from LSP.plugin.core.views import versioned_text_document_identifier from LSP.plugin.documents import DocumentSyncListener -from typing import Any -from typing import Generator from typing import TYPE_CHECKING +import asyncio import unittest if TYPE_CHECKING: + from LSP.protocol import CodeAction + from LSP.protocol import Command from LSP.protocol import Range + from LSP.protocol import TextEdit + from LSP.protocol import WorkspaceEdit + from typing import Any import sublime TEST_FILE_URI = filename_to_uri(TEST_FILE_PATH) -def edit_to_lsp(edit: tuple[str, Range]) -> dict[str, Any]: +def edit_to_lsp(edit: tuple[str, Range]) -> TextEdit: return {"newText": edit[0], "range": edit[1]} @@ -37,7 +41,7 @@ def range_from_points(start: Point, end: Point) -> Range: } -def create_code_action_edit(view: sublime.View, version: int, edits: list[tuple[str, Range]]) -> dict[str, Any]: +def create_code_action_edit(view: sublime.View, version: int, edits: list[tuple[str, Range]]) -> WorkspaceEdit: return { "documentChanges": [ { @@ -48,16 +52,16 @@ def create_code_action_edit(view: sublime.View, version: int, edits: list[tuple[ } -def create_command(command_name: str, command_args: list[Any] | None = None) -> dict[str, Any]: - result: dict[str, Any] = {"command": command_name} +def create_command(command_name: str, command_args: list[Any] | None = None) -> Command: + result: Command = {"command": command_name} if command_args is not None: result["arguments"] = command_args return result def create_test_code_action(view: sublime.View, version: int, edits: list[tuple[str, Range]], - kind: str | None = None) -> dict[str, Any]: - action = { + kind: str | None = None) -> CodeAction: + action: CodeAction = { "title": "Fix errors", "edit": create_code_action_edit(view, version, edits) } @@ -67,8 +71,8 @@ def create_test_code_action(view: sublime.View, version: int, edits: list[tuple[ def create_test_code_action2(command_name: str, command_args: list[Any] | None = None, - kind: str | None = None) -> dict[str, Any]: - action = { + kind: str | None = None) -> CodeAction: + action: CodeAction = { "title": "Fix errors", "command": create_command(command_name, command_args) } @@ -101,12 +105,11 @@ def diagnostic_to_lsp(diagnostic: tuple[str, Range]) -> dict: class CodeActionsTestCaseBase(TextDocumentTestCase): - @classmethod - def init_view_settings(cls) -> None: + def init_view_settings(self) -> None: super().init_view_settings() # "quickfix" is not supported but its here for testing purposes - cls.view.settings().set('lsp_code_actions_on_save', {'source.fixAll': True, 'quickfix': True}) - cls.view.settings().set("lsp_format_on_save", False) + self.view.settings().set('lsp_code_actions_on_save', {'source.fixAll': True, 'quickfix': True}) + self.view.settings().set("lsp_format_on_save", False) @classmethod def get_test_server_capabilities(cls) -> dict: @@ -114,18 +117,17 @@ def get_test_server_capabilities(cls) -> dict: capabilities['capabilities']['codeActionProvider'] = {'codeActionKinds': ['quickfix', 'source.fixAll']} return capabilities - def doCleanups(self) -> Generator: - yield from self.await_clear_view_and_save() - yield from super().doCleanups() + async def asyncDoCleanups(self) -> None: + await self.await_clear_view_and_save() + await super().asyncDoCleanups() class CodeActionsOnSaveTaskTestCase(TextDocumentTestCase): - @classmethod - def init_view_settings(cls) -> None: + def init_view_settings(self) -> None: super().init_view_settings() - cls.view.settings().set('lsp_code_actions_on_save', {"source.fixAll": True}) - cls.view.settings().set('lsp_code_actions_on_format', {"source.fixAll.eslint": True}) - cls.view.settings().set('lsp_format_on_save', False) + self.view.settings().set('lsp_code_actions_on_save', {"source.fixAll": True}) + self.view.settings().set('lsp_code_actions_on_format', {"source.fixAll.eslint": True}) + self.view.settings().set('lsp_format_on_save', False) def test_applicable_when_format_on_save_disabled(self) -> None: self.assertTrue(CodeActionsOnSaveTask.is_applicable(self.view)) @@ -136,8 +138,10 @@ def test_applicable_when_format_on_save_enabled(self) -> None: class CodeActionsOnSaveTestCase(CodeActionsTestCaseBase): - def test_applies_matching_kind(self) -> Generator: - yield from self._setup_document_with_missing_semicolon() + async def test_applies_matching_kind(self) -> None: + + # Set up the mock. + await self._setup_document_with_missing_semicolon() code_action_kind = 'source.fixAll' code_action = create_test_code_action( self.view, @@ -145,15 +149,28 @@ def test_applies_matching_kind(self) -> Generator: [(';', range_from_points(Point(0, 11), Point(0, 11)))], code_action_kind ) - self.set_response('textDocument/codeAction', [code_action]) + + await self.mock_response('textDocument/codeAction', [code_action]) + await self.await_message('textDocument/codeAction') + await self.mock_response('textDocument/codeAction', [code_action]) + + # Save the file. self.view.run_command('lsp_save', {'async': True}) - yield from self.await_message('textDocument/codeAction') - yield from self.await_message('textDocument/didSave') - self.assertEqual(entire_content(self.view), 'const x = 1;') + + # The save should have caused a request for code actions. + await self.await_message('textDocument/codeAction') + + # And it should have caused a didSave notification. + await self.await_message('textDocument/didSave') + + # After the didSave, the view should not be dirty (clean?) self.assertEqual(self.view.is_dirty(), False) - def test_requests_with_diagnostics(self) -> Generator: - yield from self._setup_document_with_missing_semicolon() + # The mocked code action should have been applied. + self.assertEqual(entire_content(self.view), 'const x = 1;') + + async def test_requests_with_diagnostics(self) -> None: + await self._setup_document_with_missing_semicolon() code_action_kind = 'source.fixAll' code_action = create_test_code_action( self.view, @@ -161,26 +178,32 @@ def test_requests_with_diagnostics(self) -> Generator: [(';', range_from_points(Point(0, 11), Point(0, 11)))], code_action_kind ) - self.set_response('textDocument/codeAction', [code_action]) + + await self.mock_response('textDocument/codeAction', [code_action]) + await self.await_message('textDocument/codeAction') + await self.mock_response('textDocument/codeAction', [code_action]) + self.view.run_command('lsp_save', {'async': True}) - code_action_request = yield from self.await_message('textDocument/codeAction') + code_action_request = await self.await_message('textDocument/codeAction') + self.assertIsInstance(code_action_request, dict) + assert isinstance(code_action_request, dict) self.assertEqual(len(code_action_request['context']['diagnostics']), 1) self.assertEqual(code_action_request['context']['diagnostics'][0]['message'], 'Missing semicolon') - yield from self.await_message('textDocument/didSave') + await self.await_message('textDocument/didSave') self.assertEqual(entire_content(self.view), 'const x = 1;') self.assertEqual(self.view.is_dirty(), False) - def test_applies_only_one_pass(self) -> Generator: + async def test_applies_only_one_pass(self) -> None: self.insert_characters('const x = 1') initial_change_count = self.view.change_count() - yield from self.await_client_notification( + await self.mock_client_notification( "textDocument/publishDiagnostics", create_test_diagnostics([ ('Missing semicolon', range_from_points(Point(0, 11), Point(0, 11))), ]) ) code_action_kind = 'source.fixAll' - yield from self.set_responses([ + await self.mock_responses([ ( 'textDocument/codeAction', [ @@ -206,10 +229,10 @@ def test_applies_only_one_pass(self) -> Generator: ]) self.view.run_command('lsp_save', {'async': True}) # Wait for the view to be saved - yield lambda: not self.view.is_dirty() + await self.wait_until_st_state(lambda: not self.view.is_dirty()) self.assertEqual(entire_content(self.view), 'const x = 1;') - def test_applies_immediately_after_text_change(self) -> Generator: + async def test_applies_immediately_after_text_change(self) -> None: self.insert_characters('const x = 1') code_action_kind = 'source.fixAll' code_action = create_test_code_action( @@ -218,23 +241,23 @@ def test_applies_immediately_after_text_change(self) -> Generator: [(';', range_from_points(Point(0, 11), Point(0, 11)))], code_action_kind ) - self.set_response('textDocument/codeAction', [code_action]) + await self.mock_response('textDocument/codeAction', [code_action]) self.view.run_command('lsp_save', {'async': True}) - yield from self.await_message('textDocument/codeAction') - yield from self.await_message('textDocument/didSave') + await self.await_message('textDocument/codeAction') + await self.await_message('textDocument/didSave') self.assertEqual(entire_content(self.view), 'const x = 1;') self.assertEqual(self.view.is_dirty(), False) - def test_no_fix_on_non_matching_kind(self) -> Generator: - yield from self._setup_document_with_missing_semicolon() + async def test_no_fix_on_non_matching_kind(self) -> None: + await self._setup_document_with_missing_semicolon() initial_content = 'const x = 1' self.view.run_command('lsp_save', {'async': True}) - yield from self.await_message('textDocument/didSave') + await self.await_message('textDocument/didSave') self.assertEqual(entire_content(self.view), initial_content) self.assertEqual(self.view.is_dirty(), False) - def test_does_not_apply_unsupported_kind(self) -> Generator: - yield from self._setup_document_with_missing_semicolon() + async def test_does_not_apply_unsupported_kind(self) -> None: + await self._setup_document_with_missing_semicolon() code_action_kind = 'quickfix' code_action = create_test_code_action( self.view, @@ -242,15 +265,15 @@ def test_does_not_apply_unsupported_kind(self) -> Generator: [(';', range_from_points(Point(0, 11), Point(0, 11)))], code_action_kind ) - self.set_response('textDocument/codeAction', [code_action]) + await self.mock_response('textDocument/codeAction', [code_action]) self.view.run_command('lsp_save', {'async': True}) - yield from self.await_message('textDocument/didSave') + await self.await_message('textDocument/didSave') self.assertEqual(entire_content(self.view), 'const x = 1') - def _setup_document_with_missing_semicolon(self) -> Generator: + async def _setup_document_with_missing_semicolon(self) -> None: self.insert_characters('const x = 1') - yield from self.await_message("textDocument/didChange") - yield from self.await_client_notification( + await self.await_message("textDocument/didChange") + await self.mock_client_notification( "textDocument/publishDiagnostics", create_test_diagnostics([ ('Missing semicolon', range_from_points(Point(0, 11), Point(0, 11))), @@ -259,14 +282,13 @@ def _setup_document_with_missing_semicolon(self) -> Generator: class CodeActionsOnFormatTestCase(CodeActionsTestCaseBase): - @classmethod - def init_view_settings(cls) -> None: + def init_view_settings(self) -> None: super().init_view_settings() - cls.view.settings().set('lsp_code_actions_on_format', {'source.fixAll': True, 'quickfix': True}) + self.view.settings().set('lsp_code_actions_on_format', {'source.fixAll': True, 'quickfix': True}) - def test_format_document_with_code_actions_on_format(self) -> Generator: + async def test_format_document_with_code_actions_on_format(self) -> None: self.insert_characters(' const x = 1') - yield from self.await_message('textDocument/didChange') + await self.await_message('textDocument/didChange') code_action_kind = 'source.fixAll' code_action = create_test_code_action( @@ -275,9 +297,9 @@ def test_format_document_with_code_actions_on_format(self) -> Generator: [(';', range_from_points(Point(0, 12), Point(0, 12)))], code_action_kind ) - self.set_response('textDocument/codeAction', [code_action]) + await self.mock_response('textDocument/codeAction', [code_action]) - self.set_response('textDocument/formatting', [{ + await self.mock_response('textDocument/formatting', [{ 'newText': "", 'range': { 'start': {'line': 0, 'character': 0}, @@ -286,18 +308,18 @@ def test_format_document_with_code_actions_on_format(self) -> Generator: }]) self.view.run_command('lsp_format_document', {'async': True}) - yield from self.await_message('textDocument/codeAction') - yield from self.await_message('textDocument/formatting') - yield from self.await_message('textDocument/didChange') + await self.await_message('textDocument/codeAction') + await self.await_message('textDocument/formatting') + await self.await_message('textDocument/didChange') # Response is fixed (fixAll added ";") and formatted (removed leading space) self.assertEqual(entire_content(self.view), 'const x = 1;') # Formatting does not save the document self.assertEqual(self.view.is_dirty(), True) - def test_format_on_save_with_code_actions_on_format(self) -> Generator: + async def test_format_on_save_with_code_actions_on_format(self) -> None: self.view.settings().set("lsp_format_on_save", True) self.insert_characters(' const x = 1') - yield from self.await_message("textDocument/didChange") + await self.await_message("textDocument/didChange") code_action_kind = 'source.fixAll' code_action = create_test_code_action( @@ -306,9 +328,9 @@ def test_format_on_save_with_code_actions_on_format(self) -> Generator: [(';', range_from_points(Point(0, 12), Point(0, 12)))], code_action_kind ) - self.set_response('textDocument/codeAction', [code_action]) + await self.mock_response('textDocument/codeAction', [code_action]) - self.set_response('textDocument/formatting', [{ + await self.mock_response('textDocument/formatting', [{ 'newText': "", 'range': { 'start': {'line': 0, 'character': 0}, @@ -317,10 +339,10 @@ def test_format_on_save_with_code_actions_on_format(self) -> Generator: }]) self.view.run_command("lsp_save", {'async': True}) - yield from self.await_message('textDocument/codeAction') - yield from self.await_message('textDocument/formatting') - yield from self.await_message('textDocument/didChange') - yield from self.await_message('textDocument/didSave') + await self.await_message('textDocument/codeAction') + await self.await_message('textDocument/formatting') + await self.await_message('textDocument/didChange') + await self.await_message('textDocument/didSave') # Response is fixed (fixAll added ";") and formatted (removed leading space) self.assertEqual(entire_content(self.view), 'const x = 1;') # Document should be saved @@ -328,12 +350,11 @@ def test_format_on_save_with_code_actions_on_format(self) -> Generator: class CodeActionsOnFormatOnSaveTaskTestCase(TextDocumentTestCase): - @classmethod - def init_view_settings(cls) -> None: + def init_view_settings(self) -> None: super().init_view_settings() - cls.view.settings().set('lsp_code_actions_on_save', {'source.fixAll': True, 'quickfix': True}) - cls.view.settings().set('lsp_code_actions_on_format', {}) - cls.view.settings().set("lsp_format_on_save", False) + self.view.settings().set('lsp_code_actions_on_save', {'source.fixAll': True, 'quickfix': True}) + self.view.settings().set('lsp_code_actions_on_format', {}) + self.view.settings().set("lsp_format_on_save", False) userprefs().lsp_format_on_save = False userprefs().lsp_code_actions_on_save = {} userprefs().lsp_code_actions_on_format = {} @@ -423,12 +444,12 @@ def test_kind_matching(self) -> None: class CodeActionsListenerTestCase(TextDocumentTestCase): - def setUp(self) -> Generator: - yield from super().setUp() + async def setUp(self) -> None: + await super().setUp() self.original_debounce_time = DocumentSyncListener.debounce_time DocumentSyncListener.debounce_time = 0 - def tearDown(self) -> None: + async def tearDown(self) -> None: DocumentSyncListener.debounce_time = self.original_debounce_time super().tearDown() @@ -438,23 +459,23 @@ def get_test_server_capabilities(cls) -> dict: capabilities['capabilities']['codeActionProvider'] = {} return capabilities - def test_requests_with_diagnostics(self) -> Generator: + async def test_requests_with_diagnostics(self) -> None: initial_content = 'a\nb\nc' self.insert_characters(initial_content) - yield from self.await_message('textDocument/didChange') + await self.await_message('textDocument/didChange') range_a = range_from_points(Point(0, 0), Point(0, 1)) range_b = range_from_points(Point(1, 0), Point(1, 1)) range_c = range_from_points(Point(2, 0), Point(2, 1)) - yield from self.await_client_notification( + await self.mock_client_notification( "textDocument/publishDiagnostics", create_test_diagnostics([('issue a', range_a), ('issue b', range_b), ('issue c', range_c)]) ) code_action_a = create_test_code_action(self.view, self.view.change_count(), [("A", range_a)]) code_action_b = create_test_code_action(self.view, self.view.change_count(), [("B", range_b)]) - self.set_response('textDocument/codeAction', [code_action_a, code_action_b]) + await self.mock_response('textDocument/codeAction', [code_action_a, code_action_b]) self.view.run_command('lsp_selection_set', {"regions": [(0, 3)]}) # Select a and b. - yield 100 - params = yield from self.await_message('textDocument/codeAction') + await asyncio.sleep(0.1) + params = await self.await_message('textDocument/codeAction') self.assertEqual(params['range']['start']['line'], 0) self.assertEqual(params['range']['start']['character'], 0) self.assertEqual(params['range']['end']['line'], 1) @@ -465,12 +486,12 @@ def test_requests_with_diagnostics(self) -> Generator: self.assertEqual(annotations_range[0].a, 3) self.assertEqual(annotations_range[0].b, 0) - def test_excludes_disabled_code_actions(self) -> Generator: + async def test_excludes_disabled_code_actions(self) -> None: initial_content = 'a\n' self.insert_characters(initial_content) - yield from self.await_message("textDocument/didChange") + await self.await_message("textDocument/didChange") range_a = range_from_points(Point(0, 0), Point(0, 1)) - yield from self.await_client_notification( + await self.mock_client_notification( "textDocument/publishDiagnostics", create_test_diagnostics([('issue a', range_a)]) ) @@ -479,10 +500,9 @@ def test_excludes_disabled_code_actions(self) -> Generator: self.view.change_count(), [(';', range_a)] ) - self.set_response('textDocument/codeAction', [code_action]) + await self.mock_response('textDocument/codeAction', [code_action]) self.view.run_command('lsp_selection_set', {"regions": [(0, 1)]}) # Select a - yield 100 - yield from self.await_message('textDocument/codeAction') + await self.await_message('textDocument/codeAction') code_action_ranges = self.view.get_regions(RegionKey.CODE_ACTION) self.assertEqual(len(code_action_ranges), 0) @@ -495,70 +515,72 @@ def get_test_server_capabilities(cls) -> dict: capabilities['capabilities']['codeActionProvider'] = {"resolveProvider": True} return capabilities - def test_requests_code_actions_on_newly_published_diagnostics(self) -> Generator: + async def test_requests_code_actions_on_newly_published_diagnostics(self) -> None: self.insert_characters('a\nb') - yield from self.await_message("textDocument/didChange") - yield from self.await_client_notification( + await self.await_message("textDocument/didChange") + await self.mock_client_notification( "textDocument/publishDiagnostics", create_test_diagnostics([ ('issue a', range_from_points(Point(0, 0), Point(0, 1))), ('issue b', range_from_points(Point(1, 0), Point(1, 1))) ]) ) - params = yield from self.await_message('textDocument/codeAction') + params = await self.await_message('textDocument/codeAction') + self.assertIsInstance(params, dict) + assert isinstance(params, dict) self.assertEqual(params['range']['start']['line'], 1) self.assertEqual(params['range']['start']['character'], 1) self.assertEqual(params['range']['end']['line'], 1) self.assertEqual(params['range']['end']['character'], 1) self.assertEqual(len(params['context']['diagnostics']), 1) - def test_applies_code_action_with_matching_document_version(self) -> Generator: + async def test_applies_code_action_with_matching_document_version(self) -> None: code_action = create_test_code_action(self.view, 3, [ ("c", range_from_points(Point(0, 0), Point(0, 1))), ("d", range_from_points(Point(1, 0), Point(1, 1))), ]) self.insert_characters('a\nb') - yield from self.await_message("textDocument/didChange") + await self.await_message("textDocument/didChange") self.assertEqual(self.view.change_count(), 3) - yield from self.await_run_code_action(code_action) - # yield from self.await_message('codeAction/resolve') + await self.await_run_code_action(code_action) + # await self.await_message('codeAction/resolve') self.assertEqual(entire_content(self.view), 'c\nd') - def test_does_not_apply_with_nonmatching_document_version(self) -> Generator: + async def test_does_not_apply_with_nonmatching_document_version(self) -> None: initial_content = 'a\nb' code_action = create_test_code_action(self.view, 0, [ ("c", range_from_points(Point(0, 0), Point(0, 1))), ("d", range_from_points(Point(1, 0), Point(1, 1))), ]) self.insert_characters(initial_content) - yield from self.await_message("textDocument/didChange") - yield from self.await_run_code_action(code_action) + await self.await_message("textDocument/didChange") + await self.await_run_code_action(code_action) self.assertEqual(entire_content(self.view), initial_content) - def test_runs_command_in_resolved_code_action(self) -> Generator: + async def test_runs_command_in_resolved_code_action(self) -> None: code_action = create_test_code_action2("dosomethinguseful", ["1", 0, {"hello": "there"}]) resolved_code_action = deepcopy(code_action) resolved_code_action["edit"] = create_code_action_edit(self.view, 3, [ ("c", range_from_points(Point(0, 0), Point(0, 1))), ("d", range_from_points(Point(1, 0), Point(1, 1))), ]) - self.set_response('codeAction/resolve', resolved_code_action) - self.set_response('workspace/executeCommand', {"reply": "OK done"}) + await self.mock_response('codeAction/resolve', resolved_code_action) + await self.mock_response('workspace/executeCommand', {"reply": "OK done"}) self.insert_characters('a\nb') - yield from self.await_message("textDocument/didChange") + await self.await_message("textDocument/didChange") self.assertEqual(self.view.change_count(), 3) - yield from self.await_run_code_action(code_action) - yield from self.await_message('codeAction/resolve') - params = yield from self.await_message('workspace/executeCommand') + await self.await_run_code_action(code_action) + await self.await_message('codeAction/resolve') + params = await self.await_message('workspace/executeCommand') self.assertEqual(params, {"command": "dosomethinguseful", "arguments": ["1", 0, {"hello": "there"}]}) self.assertEqual(entire_content(self.view), 'c\nd') # Keep this test last as it breaks pyls! - def test_applies_correctly_after_emoji(self) -> Generator: + async def test_applies_correctly_after_emoji(self) -> None: self.insert_characters('🕵️hi') - yield from self.await_message("textDocument/didChange") + await self.await_message("textDocument/didChange") code_action = create_test_code_action(self.view, self.view.change_count(), [ ("bye", range_from_points(Point(0, 3), Point(0, 5))), ]) - yield from self.await_run_code_action(code_action) + await self.await_run_code_action(code_action) self.assertEqual(entire_content(self.view), '🕵️bye') diff --git a/tests/test_completion.py b/tests/test_completion.py index 06ead3681..a7c05463e 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -11,9 +11,8 @@ from LSP.protocol import CompletionItemTag from LSP.protocol import InsertTextFormat from typing import Any -from typing import Callable -from typing import Generator from unittest import TestCase +import asyncio import sublime additional_edits = { @@ -37,11 +36,10 @@ class CompletionsTestsBase(TextDocumentTestCase): - @classmethod - def init_view_settings(cls) -> None: + def init_view_settings(self) -> None: super().init_view_settings() - assert cls.view - cls.view.settings().set("auto_complete_selector", "text.plain") + assert self.view + self.view.settings().set("auto_complete_selector", "text.plain") def type(self, text: str) -> None: self.view.run_command('append', {'characters': text}) @@ -54,64 +52,56 @@ def move_cursor(self, row: int, col: int) -> None: s.clear() s.add(point) - def create_commit_completion_closure( - self, commit_completion_command: str = "commit_completion" - ) -> Callable[[], bool]: - committed = False - current_change_count = self.view.change_count() - - def commit_completion() -> bool: - if not self.view.is_auto_complete_visible(): - return False - nonlocal committed, current_change_count - if not committed: - self.view.run_command(commit_completion_command) - committed = True - return self.view.change_count() > current_change_count + async def wait_until_auto_complete_is_visible(self) -> None: + await self.wait_until_st_state(self.view.is_auto_complete_visible) - return commit_completion + async def commit_completion(self, commit_completion_command: str = "commit_completion") -> None: + current_change_count = self.view.change_count() + await self.wait_until_auto_complete_is_visible() + self.view.run_command(commit_completion_command) + await self.wait_until_st_state(lambda: self.view.change_count() > current_change_count) - def select_completion(self) -> Generator: + async def select_completion(self) -> None: self.view.run_command('auto_complete') - yield self.create_commit_completion_closure() + await self.commit_completion() - def shift_select_completion(self) -> Generator: + async def shift_select_completion(self) -> None: self.view.run_command('auto_complete') - yield self.create_commit_completion_closure("lsp_commit_completion_with_opposite_insert_mode") + await self.commit_completion("lsp_commit_completion_with_opposite_insert_mode") def read_file(self) -> str: return self.view.substr(sublime.Region(0, self.view.size())) - def verify(self, *, completion_items: list[dict[str, Any]], insert_text: str, expected_text: str) -> Generator: + async def verify(self, *, completion_items: list[dict[str, Any]], insert_text: str, expected_text: str) -> None: if insert_text: self.type(insert_text) - self.set_response("textDocument/completion", completion_items) - yield from self.select_completion() - yield from self.await_message("textDocument/completion") - yield from self.await_message("textDocument/didChange") + await self.mock_response("textDocument/completion", completion_items) + await self.select_completion() + await self.await_message("textDocument/completion") + await self.await_message("textDocument/didChange") self.assertEqual(self.read_file(), expected_text) class QueryCompletionsTests(CompletionsTestsBase): - def test_none(self) -> Generator: - self.set_response("textDocument/completion", None) + async def test_none(self) -> None: + await self.mock_response("textDocument/completion", None) self.view.run_command('auto_complete') - yield lambda: self.view.is_auto_complete_visible() is False + await self.wait_until_auto_complete_is_visible() - def test_simple_label(self) -> Generator: - yield from self.verify( + async def test_simple_label(self) -> None: + await self.verify( completion_items=[{'label': 'asdf'}, {'label': 'efcgh'}], insert_text='', expected_text='asdf') - def test_prefer_insert_text_over_label(self) -> Generator: - yield from self.verify( + async def test_prefer_insert_text_over_label(self) -> None: + await self.verify( completion_items=[{"label": "Label text", "insertText": "Insert text"}], insert_text='', expected_text='Insert text') - def test_prefer_text_edit_over_insert_text(self) -> Generator: - yield from self.verify( + async def test_prefer_text_edit_over_insert_text(self) -> None: + await self.verify( completion_items=[{ "label": "Label text", "insertText": "Insert text", @@ -132,22 +122,22 @@ def test_prefer_text_edit_over_insert_text(self) -> Generator: insert_text='', expected_text='Text edit') - def test_simple_insert_text(self) -> Generator: - yield from self.verify( + async def test_simple_insert_text(self) -> None: + await self.verify( completion_items=[{'label': 'asdf', 'insertText': 'asdf()'}], insert_text="a", expected_text='asdf()') - def test_var_prefix_using_label(self) -> Generator: - yield from self.verify(completion_items=[{'label': '$what'}], insert_text="$", expected_text="$what") + async def test_var_prefix_using_label(self) -> None: + await self.verify(completion_items=[{'label': '$what'}], insert_text="$", expected_text="$what") - def test_var_prefix_added_in_insertText(self) -> Generator: + async def test_var_prefix_added_in_insertText(self) -> None: """ https://github.com/sublimelsp/LSP/issues/294. User types '$env:U', server replaces '$env:U' with '$env:USERPROFILE' """ - yield from self.verify( + await self.verify( completion_items=[{ 'filterText': '$env:USERPROFILE', 'insertText': '$env:USERPROFILE', @@ -171,7 +161,7 @@ def test_var_prefix_added_in_insertText(self) -> Generator: insert_text="$env:U", expected_text="$env:USERPROFILE") - def test_pure_insertion_text_edit(self) -> Generator: + async def test_pure_insertion_text_edit(self) -> None: """ https://github.com/sublimelsp/LSP/issues/368. @@ -179,7 +169,7 @@ def test_pure_insertion_text_edit(self) -> Generator: THIS TEST FAILS """ - yield from self.verify( + await self.verify( completion_items=[{ 'textEdit': { 'newText': 'meParam', @@ -201,9 +191,9 @@ def test_pure_insertion_text_edit(self) -> Generator: insert_text="$so", expected_text="$someParam") - def test_space_added_in_label(self) -> Generator: + async def test_space_added_in_label(self) -> None: """Clangd: label=" const", insertText="const" (https://github.com/sublimelsp/LSP/issues/368).""" - yield from self.verify( + await self.verify( completion_items=[{ "label": " const", "sortText": "3f400000const", @@ -229,13 +219,13 @@ def test_space_added_in_label(self) -> Generator: insert_text=' co', expected_text=" const") # NOT 'const' - def test_dash_missing_from_label(self) -> Generator: + async def test_dash_missing_from_label(self) -> None: """ Powershell: label="UniqueId", trigger="-UniqueIdd, text to be inserted = "-UniqueId". (https://github.com/sublimelsp/LSP/issues/572) """ - yield from self.verify( + await self.verify( completion_items=[{ "filterText": "-UniqueId", "documentation": None, @@ -261,9 +251,9 @@ def test_dash_missing_from_label(self) -> Generator: insert_text="u", expected_text="-UniqueId") - def test_edit_before_cursor(self) -> Generator: + async def test_edit_before_cursor(self) -> None: """https://github.com/sublimelsp/LSP/issues/536.""" - yield from self.verify( + await self.verify( completion_items=[{ 'insertTextFormat': 2, 'data': { @@ -294,9 +284,9 @@ def test_edit_before_cursor(self) -> Generator: insert_text='def myF', expected_text='override def myFunction(): Unit = ???') - def test_edit_after_nonword(self) -> Generator: + async def test_edit_after_nonword(self) -> None: """https://github.com/sublimelsp/LSP/issues/645.""" - yield from self.verify( + await self.verify( completion_items=[{ "textEdit": { "newText": "apply($0)", @@ -325,7 +315,7 @@ def test_edit_after_nonword(self) -> Generator: insert_text="List.", expected_text='List.apply()') - def test_filter_text_is_not_a_prefix_of_label(self) -> Generator: + async def test_filter_text_is_not_a_prefix_of_label(self) -> None: """ Metals: "Implement all members". @@ -341,7 +331,7 @@ def test_filter_text_is_not_a_prefix_of_label(self) -> Generator: https://github.com/sublimelsp/LSP/issues/771 """ - yield from self.verify( + await self.verify( completion_items=[{ "label": "Implement all members", "kind": 12, @@ -363,11 +353,11 @@ def test_filter_text_is_not_a_prefix_of_label(self) -> Generator: insert_text='e', expected_text='def foo: Int \u003d ???\n def boo: Int \u003d ???') - def test_additional_edits_if_session_has_the_resolve_capability(self) -> Generator: + async def test_additional_edits_if_session_has_the_resolve_capability(self) -> None: completion_item = { 'label': 'asdf' } - self.set_response("completionItem/resolve", { + await self.mock_response("completionItem/resolve", { 'label': 'asdf', 'additionalTextEdits': [ { @@ -385,13 +375,13 @@ def test_additional_edits_if_session_has_the_resolve_capability(self) -> Generat } ] }) - yield from self.verify( + await self.verify( completion_items=[completion_item], insert_text='', expected_text='import asdf;\nasdf') - def test_prefix_should_include_the_dollar_sign(self) -> Generator: - self.set_response( + async def test_prefix_should_include_the_dollar_sign(self) -> None: + await self.mock_response( 'textDocument/completion', { "items": @@ -415,13 +405,13 @@ def test_prefix_should_include_the_dollar_sign(self) -> Generator: self.type('\n') # move cursor after `$he|` self.move_cursor(2, 3) - yield from self.select_completion() - yield from self.await_message('textDocument/completion') + await self.select_completion() + await self.await_message('textDocument/completion') self.assertEqual(self.read_file(), '\n') - def test_fuzzy_match_plaintext_insert_text(self) -> Generator: - yield from self.verify( + async def test_fuzzy_match_plaintext_insert_text(self) -> None: + await self.verify( completion_items=[{ 'insertTextFormat': 1, 'label': 'aaba', @@ -430,8 +420,8 @@ def test_fuzzy_match_plaintext_insert_text(self) -> Generator: insert_text='aa', expected_text='aaca') - def test_fuzzy_match_plaintext_text_edit(self) -> Generator: - yield from self.verify( + async def test_fuzzy_match_plaintext_text_edit(self) -> None: + await self.verify( completion_items=[{ 'insertTextFormat': 1, 'label': 'aaba', @@ -442,8 +432,8 @@ def test_fuzzy_match_plaintext_text_edit(self) -> Generator: insert_text='aab', expected_text='aaca') - def test_fuzzy_match_snippet_insert_text(self) -> Generator: - yield from self.verify( + async def test_fuzzy_match_snippet_insert_text(self) -> None: + await self.verify( completion_items=[{ 'insertTextFormat': 2, 'label': 'aaba', @@ -452,8 +442,8 @@ def test_fuzzy_match_snippet_insert_text(self) -> Generator: insert_text='aab', expected_text='aaca') - def test_fuzzy_match_snippet_text_edit(self) -> Generator: - yield from self.verify( + async def test_fuzzy_match_snippet_text_edit(self) -> None: + await self.verify( completion_items=[{ 'insertTextFormat': 2, 'label': 'aaba', @@ -464,7 +454,7 @@ def test_fuzzy_match_snippet_text_edit(self) -> Generator: insert_text='aab', expected_text='aaca') - def verify_multi_cursor(self, completion: dict[str, Any]) -> Generator: + async def verify_multi_cursor(self, completion: dict[str, Any]) -> None: """ Check whether `fd` gets replaced by `fmod` when the cursor is at `fd|`. Turning the `d` into an `m` is an important part of the test. @@ -478,20 +468,20 @@ def verify_multi_cursor(self, completion: dict[str, Any]) -> Generator: self.assertEqual(len(selection), 3) for region in selection: self.assertEqual(self.view.substr(self.view.line(region)), "fd") - self.set_response("textDocument/completion", [completion]) - yield from self.select_completion() - yield from self.await_message("textDocument/completion") + await self.mock_response("textDocument/completion", [completion]) + await self.select_completion() + await self.await_message("textDocument/completion") self.assertEqual(self.read_file(), 'fmod()\nfmod()\nfmod()') - def test_multi_cursor_plaintext_insert_text(self) -> Generator: - yield from self.verify_multi_cursor({ + async def test_multi_cursor_plaintext_insert_text(self) -> None: + await self.verify_multi_cursor({ 'insertTextFormat': 1, 'label': 'fmod(a, b)', 'insertText': 'fmod()' }) - def test_multi_cursor_plaintext_text_edit(self) -> Generator: - yield from self.verify_multi_cursor({ + async def test_multi_cursor_plaintext_text_edit(self) -> None: + await self.verify_multi_cursor({ 'insertTextFormat': 1, 'label': 'fmod(a, b)', 'textEdit': { @@ -500,15 +490,15 @@ def test_multi_cursor_plaintext_text_edit(self) -> Generator: } }) - def test_multi_cursor_snippet_insert_text(self) -> Generator: - yield from self.verify_multi_cursor({ + async def test_multi_cursor_snippet_insert_text(self) -> None: + await self.verify_multi_cursor({ 'insertTextFormat': 2, 'label': 'fmod(a, b)', 'insertText': 'fmod($0)' }) - def test_multi_cursor_snippet_text_edit(self) -> Generator: - yield from self.verify_multi_cursor({ + async def test_multi_cursor_snippet_text_edit(self) -> None: + await self.verify_multi_cursor({ 'insertTextFormat': 2, 'label': 'fmod(a, b)', 'textEdit': { @@ -517,10 +507,10 @@ def test_multi_cursor_snippet_text_edit(self) -> Generator: } }) - def test_nontrivial_text_edit_removal(self) -> Generator: + async def test_nontrivial_text_edit_removal(self) -> None: self.type('#include ') self.move_cursor(0, 11) # Put the cursor inbetween 'u' and '>' - self.set_response("textDocument/completion", [{ + await self.mock_response("textDocument/completion", [{ 'filterText': 'uchar.h>', 'label': ' uchar.h>', 'textEdit': { @@ -532,14 +522,14 @@ def test_nontrivial_text_edit_removal(self) -> Generator: 'kind': 17, 'insertTextFormat': 2 }]) - yield from self.select_completion() - yield from self.await_message("textDocument/completion") + await self.select_completion() + await self.await_message("textDocument/completion") self.assertEqual(self.read_file(), '#include ') - def test_nontrivial_text_edit_removal_with_buffer_modifications_clangd(self) -> Generator: + async def test_nontrivial_text_edit_removal_with_buffer_modifications_clangd(self) -> None: self.type('#include ') self.move_cursor(0, 11) # Put the cursor inbetween 'u' and '>' - self.set_response("textDocument/completion", [{ + await self.mock_response("textDocument/completion", [{ 'filterText': 'uchar.h>', 'label': ' uchar.h>', 'textEdit': { @@ -552,23 +542,23 @@ def test_nontrivial_text_edit_removal_with_buffer_modifications_clangd(self) -> 'insertTextFormat': 2 }]) self.view.run_command('auto_complete') # show the AC widget - yield from self.await_message("textDocument/completion") - yield 100 + await self.await_message("textDocument/completion") + await asyncio.sleep(0.1) self.view.run_command('insert', {'characters': 'c'}) # type characters - yield 100 + await asyncio.sleep(0.1) self.view.run_command('insert', {'characters': 'h'}) # while the AC widget - yield 100 + await asyncio.sleep(0.1) self.view.run_command('insert', {'characters': 'a'}) # is visible - yield 100 + await asyncio.sleep(0.1) # Commit the completion. The buffer has been modified in the meantime, so the old text edit that says to # remove "u>" is invalid. The code in completion.py must be able to handle this. - yield self.create_commit_completion_closure() + await self.commit_completion() self.assertEqual(self.read_file(), '#include ') - def test_nontrivial_text_edit_removal_with_buffer_modifications_json(self) -> Generator: + async def test_nontrivial_text_edit_removal_with_buffer_modifications_json(self) -> None: self.type('{"k"}') self.move_cursor(0, 3) # Put the cursor inbetween 'k' and '"' - self.set_response("textDocument/completion", [{ + await self.mock_response("textDocument/completion", [{ 'kind': 10, 'documentation': 'Array of single or multiple keys', 'insertTextFormat': 2, @@ -582,21 +572,21 @@ def test_nontrivial_text_edit_removal_with_buffer_modifications_json(self) -> Ge "insertText": 'keys": [$1]' }]) self.view.run_command('auto_complete') # show the AC widget - yield from self.await_message("textDocument/completion") - yield 100 + await self.await_message("textDocument/completion") + await asyncio.sleep(0.1) self.view.run_command('insert', {'characters': 'e'}) # type characters - yield 100 + await asyncio.sleep(0.1) self.view.run_command('insert', {'characters': 'y'}) # while the AC widget is open - yield 100 + await asyncio.sleep(0.1) # Commit the completion. The buffer has been modified in the meantime, so the old text edit that says to # remove '"k"' is invalid. The code in completion.py must be able to handle this. - yield self.create_commit_completion_closure() + await self.commit_completion() self.assertEqual(self.read_file(), '{"keys": []}') - def test_text_edit_plaintext_with_multiple_lines_indented(self) -> Generator[None, None, None]: + async def test_text_edit_plaintext_with_multiple_lines_indented(self) -> None: self.type("\t\n\t") self.move_cursor(1, 2) - self.set_response("textDocument/completion", [{ + await self.mock_response("textDocument/completion", [{ 'label': 'a', 'textEdit': { 'range': {'start': {'line': 1, 'character': 4}, 'end': {'line': 1, 'character': 4}}, @@ -604,15 +594,15 @@ def test_text_edit_plaintext_with_multiple_lines_indented(self) -> Generator[Non }, 'insertTextFormat': InsertTextFormat.PlainText }]) - yield from self.select_completion() - yield from self.await_message("textDocument/completion") + await self.select_completion() + await self.await_message("textDocument/completion") # the "b" should be intended one level deeper self.assertEqual(self.read_file(), '\t\n\ta\n\t\tb') - def test_insert_insert_mode(self) -> Generator: + async def test_insert_insert_mode(self) -> None: self.type('{{ title }}') self.move_cursor(0, 5) # Put the cursor inbetween 'i' and 't' - self.set_response("textDocument/completion", [{ + await self.mock_response("textDocument/completion", [{ 'label': 'title', 'textEdit': { 'newText': 'title', @@ -620,14 +610,14 @@ def test_insert_insert_mode(self) -> Generator: 'replace': {'start': {'line': 0, 'character': 3}, 'end': {'line': 0, 'character': 8}} } }]) - yield from self.select_completion() - yield from self.await_message("textDocument/completion") + await self.select_completion() + await self.await_message("textDocument/completion") self.assertEqual(self.read_file(), '{{ titletle }}') - def test_replace_insert_mode(self) -> Generator: + async def test_replace_insert_mode(self) -> None: self.type('{{ title }}') self.move_cursor(0, 4) # Put the cursor inbetween 't' and 'i' - self.set_response("textDocument/completion", [{ + await self.mock_response("textDocument/completion", [{ 'label': 'turtle', 'textEdit': { 'newText': 'turtle', @@ -635,8 +625,8 @@ def test_replace_insert_mode(self) -> Generator: 'replace': {'start': {'line': 0, 'character': 3}, 'end': {'line': 0, 'character': 8}} } }]) - yield from self.shift_select_completion() # commit the opposite insert mode - yield from self.await_message("textDocument/completion") + await self.shift_select_completion() # commit the opposite insert mode + await self.await_message("textDocument/completion") self.assertEqual(self.read_file(), '{{ turtle }}') def test_show_deprecated_flag(self) -> None: @@ -657,8 +647,8 @@ def test_show_deprecated_tag(self) -> None: formatted_completion_item = format_completion(item_with_deprecated_tags, 0, False, "", {}, self.view.id()) self.assertIn("DEPRECATED", formatted_completion_item.annotation) - def test_strips_carriage_return_in_insert_text(self) -> Generator: - yield from self.verify( + async def test_strips_carriage_return_in_insert_text(self) -> None: + await self.verify( completion_items=[{ 'label': 'greeting', 'insertText': 'hello\r\nworld' @@ -666,8 +656,8 @@ def test_strips_carriage_return_in_insert_text(self) -> Generator: insert_text='', expected_text='hello\nworld') - def test_strips_carriage_return_in_text_edit(self) -> Generator: - yield from self.verify( + async def test_strips_carriage_return_in_text_edit(self) -> None: + await self.verify( completion_items=[{ 'label': 'greeting', 'textEdit': { @@ -776,7 +766,7 @@ def get_test_server_capabilities(cls) -> dict: capabilities['capabilities']['completionProvider']['resolveProvider'] = False return capabilities - def test_additional_edits_if_session_does_not_have_the_resolve_capability(self) -> Generator: + async def test_additional_edits_if_session_does_not_have_the_resolve_capability(self) -> None: completion_item = { 'label': 'ghjk', 'additionalTextEdits': [ @@ -795,7 +785,7 @@ def test_additional_edits_if_session_does_not_have_the_resolve_capability(self) } ] } - yield from self.verify( + await self.verify( completion_items=[completion_item], insert_text='', expected_text='import ghjk;\nghjk') diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py index 48e57b410..4d27f0902 100644 --- a/tests/test_diagnostics.py +++ b/tests/test_diagnostics.py @@ -5,11 +5,8 @@ from LSP.plugin.core.protocol import Point from LSP.plugin.core.url import filename_to_uri from typing import TYPE_CHECKING -from unittesting import AWAIT_WORKER -import sublime if TYPE_CHECKING: - from collections.abc import Generator from LSP.protocol import Diagnostic from LSP.protocol import PublishDiagnosticsParams from LSP.protocol import Range @@ -46,7 +43,7 @@ def range_from_points(start: Point, end: Point) -> Range: class DiagnosticsTestCase(TextDocumentTestCase): - def test_clear_diagnostics_immediately_after_change(self) -> Generator: + async def test_clear_diagnostics_immediately_after_change(self) -> None: # Trigger specific sequence of events: # 1. document has diagnostic issue # 2. (async) view is modified @@ -54,44 +51,41 @@ def test_clear_diagnostics_immediately_after_change(self) -> Generator: # 4. (async) session gets notified about view changes # # Verify that the diagnostics are properly cleared. - - def insert_text_and_clear_diagnostics_async() -> None: - self.insert_characters('// anything') - next(self.await_client_notification("textDocument/publishDiagnostics", create_test_diagnostics([]))) - self.insert_characters('const x = 1') - yield from self.await_message("textDocument/didChange") - yield from self.await_client_notification( + await self.await_message("textDocument/didChange") + await self.mock_client_notification( "textDocument/publishDiagnostics", create_test_diagnostics([('error', Point(0, 0), Point(0, 11))]) ) session_buffer = self.session.get_session_buffer_for_uri_async(TEST_FILE_URI) self.assertEqual(len(session_buffer.diagnostics), 1) - sublime.set_timeout_async(insert_text_and_clear_diagnostics_async) - yield AWAIT_WORKER + # Insert characters and clear diagnostics. + self.insert_characters('// anything') + await self.mock_client_notification("textDocument/publishDiagnostics", create_test_diagnostics([])) + # Just a dummy wait to ensure that the `textDocument/publishDiagnostics` triggered from async thread # is processed since we can't await it there. - yield from self.await_client_notification('$/dummy', []) + await self.mock_client_notification('$/dummy', []) self.assertEqual(len(session_buffer.diagnostics), 0) - def test_ignores_publish_diagnostics_version(self) -> Generator: + async def test_ignores_publish_diagnostics_version(self) -> None: self.insert_characters('const x = 1') - yield from self.await_message("textDocument/didChange") - yield from self.await_client_notification( + await self.await_message("textDocument/didChange") + await self.mock_client_notification( "textDocument/publishDiagnostics", create_test_diagnostics([('error', Point(0, 0), Point(0, 11))]) ) session_buffer = self.session.get_session_buffer_for_uri_async(TEST_FILE_URI) self.assertEqual(len(session_buffer.diagnostics), 1) - yield from self.await_client_notification( + await self.mock_client_notification( "textDocument/publishDiagnostics", create_test_diagnostics([], version=1000) ) self.assertEqual(len(session_buffer.diagnostics), 0) - def test_handles_unknown_tag_gracefully(self) -> Generator: + async def test_handles_unknown_tag_gracefully(self) -> None: self.insert_characters('const x = 1') - yield from self.await_message("textDocument/didChange") - yield from self.await_client_notification( + await self.await_message("textDocument/didChange") + await self.mock_client_notification( "textDocument/publishDiagnostics", { "uri": TEST_FILE_URI, @@ -107,10 +101,10 @@ def test_handles_unknown_tag_gracefully(self) -> Generator: session_buffer = self.session.get_session_buffer_for_uri_async(TEST_FILE_URI) self.assertEqual(len(session_buffer.diagnostics), 1) - def test_handles_multiple_tags(self) -> Generator: + async def test_handles_multiple_tags(self) -> None: self.insert_characters('const x = 1') - yield from self.await_message("textDocument/didChange") - yield from self.await_client_notification( + await self.await_message("textDocument/didChange") + await self.mock_client_notification( "textDocument/publishDiagnostics", { "uri": TEST_FILE_URI, diff --git a/tests/test_documents.py b/tests/test_documents.py index 14f78d516..7ccb12a7e 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -7,24 +7,22 @@ from .setup import make_tcp_client_test_config from .setup import make_tcp_server_test_config from .setup import remove_config -from .setup import TIMEOUT_TIME -from .setup import YieldPromise -from LSP.plugin.core.logging import debug +from .setup import SublimeAioTestCase +from LSP.plugin.core.open import open_file from LSP.plugin.core.protocol import Request from LSP.plugin.core.registry import windows -from LSP.plugin.core.types import ClientStates +from LSP.plugin.core.url import filename_to_uri from LSP.plugin.documents import DocumentSyncListener from os.path import join from sublime_plugin import view_event_listeners -from typing import Any -from typing import Generator -from unittesting import DeferrableTestCase +from typing_extensions import override +import asyncio import sublime -class WindowDocumentHandlerTests(DeferrableTestCase): +class WindowDocumentHandlerTests(SublimeAioTestCase): - def ensure_document_listener_created(self) -> bool: + def ensure_document_listener_created(self) -> DocumentSyncListener | None: assert self.view # Bug in ST3? Either that, or CI runs with ST window not in focus and that makes ST3 not trigger some # events like on_load_async, on_activated, on_deactivated. That makes things not properly initialize on @@ -32,11 +30,11 @@ def ensure_document_listener_created(self) -> bool: # Revisit this once we're on ST4. for listener in view_event_listeners[self.view.id()]: if isinstance(listener, DocumentSyncListener): - sublime.set_timeout_async(listener.on_activated_async) - return True - return False + return listener + return None - def setUp(self) -> Generator: + @override + async def setUp(self) -> None: initialization_options = { "serverResponse": { "capabilities": { @@ -57,6 +55,8 @@ def setUp(self) -> Generator: self.config2 = make_tcp_client_test_config("TEST-2", initialization_options) self.config3 = make_tcp_server_test_config("TEST-3", initialization_options) self.wm = windows.lookup(self.window) + self.assertIsNotNone(self.wm) + assert self.wm add_config(self.config1) add_config(self.config2) add_config(self.config3) @@ -64,58 +64,50 @@ def setUp(self) -> Generator: self.wm.get_config_manager().all[self.config2.name] = self.config2 self.wm.get_config_manager().all[self.config3.name] = self.config3 - def test_sends_did_open_to_multiple_sessions(self) -> Generator: + async def test_sends_did_open_to_multiple_sessions(self) -> None: filename = expand(join("$packages", "LSP", "tests", "testfile.txt"), self.window) - open_view = self.window.find_open_file(filename) - yield from close_test_view(open_view) - self.view = self.window.open_file(filename) - yield {"condition": lambda: not self.view.is_loading(), "timeout": TIMEOUT_TIME} + await close_test_view(self.window.find_open_file(filename)) + self.view = await open_file(self.window, filename_to_uri(filename)) + self.assertIsNotNone(self.wm) + assert self.wm + assert self.view self.assertTrue(self.wm.get_config_manager().match_view(self.view, self.wm.workspace_folders)) # self.init_view_settings() - yield {"condition": self.ensure_document_listener_created, "timeout": TIMEOUT_TIME} - yield { - "condition": lambda: self.wm.get_session(self.config1.name, self.view.file_name()) is not None, - "timeout": TIMEOUT_TIME} - yield { - "condition": lambda: self.wm.get_session(self.config2.name, self.view.file_name()) is not None, - "timeout": TIMEOUT_TIME} - yield { - "condition": lambda: self.wm.get_session(self.config3.name, self.view.file_name()) is not None, - "timeout": TIMEOUT_TIME} - self.session1 = self.wm.get_session(self.config1.name, self.view.file_name()) - self.session2 = self.wm.get_session(self.config2.name, self.view.file_name()) - self.session3 = self.wm.get_session(self.config3.name, self.view.file_name()) + listener = self.ensure_document_listener_created() + self.assertIsNotNone(listener) + assert listener + self.session1 = await self.wm.start(self.config1, listener) + self.session2 = await self.wm.start(self.config2, listener) + self.session3 = await self.wm.start(self.config3, listener) self.assertIsNotNone(self.session1) self.assertIsNotNone(self.session2) self.assertIsNotNone(self.session3) + assert self.session1 + assert self.session2 + assert self.session3 self.assertEqual(self.session1.config.name, self.config1.name) self.assertEqual(self.session2.config.name, self.config2.name) self.assertEqual(self.session3.config.name, self.config3.name) - yield {"condition": lambda: self.session1.state == ClientStates.READY, "timeout": TIMEOUT_TIME} - yield {"condition": lambda: self.session2.state == ClientStates.READY, "timeout": TIMEOUT_TIME} - yield {"condition": lambda: self.session3.state == ClientStates.READY, "timeout": TIMEOUT_TIME} - yield from self.await_message("initialize") - yield from self.await_message("initialized") - yield from self.await_message("textDocument/didOpen") + await self.assert_rpc_message("initialize") + await self.assert_rpc_message("initialized") + await self.assert_rpc_message("textDocument/didOpen") self.view.run_command("insert", {"characters": "a"}) - yield from self.await_message("textDocument/didChange") - yield from close_test_view(self.view) - yield from self.await_message("textDocument/didClose") + await self.assert_rpc_message("textDocument/didChange") + await close_test_view(self.view) + await self.assert_rpc_message("textDocument/didClose") - def doCleanups(self) -> Generator: + @override + async def tearDown(self) -> None: try: - yield from close_test_view(self.view) + await close_test_view(self.view) except Exception: pass if self.session1: - sublime.set_timeout_async(self.session1.end_async) - yield lambda: self.session1.state == ClientStates.STOPPING + await self.session1.end() if self.session2: - sublime.set_timeout_async(self.session2.end_async) - yield lambda: self.session2.state == ClientStates.STOPPING + await self.session2.end() if self.session3: - sublime.set_timeout_async(self.session3.end_async) - yield lambda: self.session3.state == ClientStates.STOPPING + await self.session3.end() try: remove_config(self.config3) except ValueError: @@ -128,31 +120,16 @@ def doCleanups(self) -> Generator: remove_config(self.config1) except ValueError: pass + assert self.wm self.wm.get_config_manager().all.pop(self.config3.name, None) self.wm.get_config_manager().all.pop(self.config2.name, None) self.wm.get_config_manager().all.pop(self.config1.name, None) - yield from super().doCleanups() - - def await_message(self, method: str) -> Generator: - promise1 = YieldPromise() - promise2 = YieldPromise() - promise3 = YieldPromise() - - def handler1(params: Any) -> None: - promise1.fulfill(params) - - def handler2(params: Any) -> None: - promise2.fulfill(params) - - def handler3(params: Any) -> None: - promise3.fulfill(params) - - def error_handler(params: Any) -> None: - debug("Got error:", params, "awaiting timeout :(") - self.session1.send_request(Request("$test/getReceived", {"method": method}), handler1, error_handler) - self.session2.send_request(Request("$test/getReceived", {"method": method}), handler2, error_handler) - self.session3.send_request(Request("$test/getReceived", {"method": method}), handler3, error_handler) - yield {"condition": promise1, "timeout": TIMEOUT_TIME} - yield {"condition": promise2, "timeout": TIMEOUT_TIME} - yield {"condition": promise3, "timeout": TIMEOUT_TIME} + async def assert_rpc_message(self, method: str) -> None: + assert self.session1 + assert self.session2 + assert self.session3 + timeout = 5 + await asyncio.wait_for(self.session1.request(Request("$test/getReceived", {"method": method})), timeout=timeout) + await asyncio.wait_for(self.session2.request(Request("$test/getReceived", {"method": method})), timeout=timeout) + await asyncio.wait_for(self.session3.request(Request("$test/getReceived", {"method": method})), timeout=timeout) diff --git a/tests/test_edit.py b/tests/test_edit.py index 80346a7bc..773a13659 100644 --- a/tests/test_edit.py +++ b/tests/test_edit.py @@ -234,7 +234,7 @@ def test_sorts_in_application_order2(self) -> None: class ApplyDocumentEditTestCase(TextDocumentTestCase): - def test_applies_text_edit(self) -> None: + async def test_applies_text_edit(self) -> None: self.insert_characters('abc') edits: list[TextEdit] = [{ 'newText': 'x$0y', @@ -249,10 +249,10 @@ def test_applies_text_edit(self) -> None: } } }] - apply_text_edits(self.view, edits) + await apply_text_edits(self.view, edits) self.assertEqual(entire_content(self.view), 'ax$0yc') - def test_applies_text_edit_with_placeholder(self) -> None: + async def test_applies_text_edit_with_placeholder(self) -> None: self.insert_characters('abc') edits: list[TextEdit] = [{ 'newText': 'x$0y', @@ -267,12 +267,12 @@ def test_applies_text_edit_with_placeholder(self) -> None: } } }] - apply_text_edits(self.view, edits, process_placeholders=True) + await apply_text_edits(self.view, edits, process_placeholders=True) self.assertEqual(entire_content(self.view), 'axyc') self.assertEqual(len(self.view.sel()), 1) self.assertEqual(self.view.sel()[0], sublime.Region(2, 2)) - def test_applies_multiple_text_edits_with_placeholders(self) -> None: + async def test_applies_multiple_text_edits_with_placeholders(self) -> None: self.insert_characters('ab') newline_edit: TextEdit = { 'newText': '\n$0', @@ -288,7 +288,7 @@ def test_applies_multiple_text_edits_with_placeholders(self) -> None: } } edits: list[TextEdit] = [newline_edit, newline_edit] - apply_text_edits(self.view, edits, process_placeholders=True) + await apply_text_edits(self.view, edits, process_placeholders=True) self.assertEqual(entire_content(self.view), 'a\n\nb') self.assertEqual(len(self.view.sel()), 2) self.assertEqual(self.view.sel()[0], sublime.Region(2, 2)) diff --git a/tests/test_file_watcher.py b/tests/test_file_watcher.py index 78ac92862..8c4f1662e 100644 --- a/tests/test_file_watcher.py +++ b/tests/test_file_watcher.py @@ -13,7 +13,6 @@ from LSP.plugin.core.types import ClientConfig from LSP.protocol import WatchKind from os.path import join -from typing import Generator from typing import TYPE_CHECKING import sublime @@ -87,24 +86,24 @@ class FileWatcherDocumentTestCase(TextDocumentTestCase): """ @classmethod - def setUpClass(cls) -> None: + async def asyncSetUpClass(cls) -> None: # Don't call the superclass. register_file_watcher_implementation(TestFileWatcher) @classmethod - def tearDownClass(cls) -> None: + async def asyncTearDownClass(cls) -> None: # Don't call the superclass. pass - def setUp(self) -> Generator: + async def setUp(self) -> None: self.assertEqual(len(TestFileWatcher.active_watchers), 0) # Watchers are only registered when there are workspace folders so add a folder. self.folder_root_path = setup_workspace_folder() - yield from super().setUpClass() - yield from super().setUp() + await super().asyncSetUpClass() + await super().setUp() - def tearDown(self) -> Generator: - yield from super().tearDownClass() + async def tearDown(self) -> None: + await super().asyncTearDownClass() self.assertEqual(len(TestFileWatcher.active_watchers), 0) # Restore original project data. window = sublime.active_window() @@ -138,11 +137,12 @@ def test_creates_static_watcher(self) -> None: self.assertEqual(watcher.ignores, ['.git']) self.assertEqual(watcher.root_path, self.folder_root_path) - def test_handles_file_event(self) -> Generator: + async def test_handles_file_event(self) -> None: watcher = TestFileWatcher.active_watchers[0] filepath = join(self.folder_root_path, 'file.js') watcher.trigger_event([('change', filepath)]) - sent_notification = yield from self.await_message('workspace/didChangeWatchedFiles') + sent_notification = await self.await_message('workspace/didChangeWatchedFiles') + assert isinstance(sent_notification, dict) self.assertIs(type(sent_notification['changes']), list) self.assertEqual(len(sent_notification['changes']), 1) change = sent_notification['changes'][0] @@ -152,7 +152,7 @@ def test_handles_file_event(self) -> Generator: class FileWatcherDynamicTests(FileWatcherDocumentTestCase): - def test_handles_dynamic_watcher_registration(self) -> Generator: + async def test_handles_dynamic_watcher_registration(self) -> None: registration_params = { 'registrations': [ { @@ -169,7 +169,7 @@ def test_handles_dynamic_watcher_registration(self) -> Generator: } ] } - yield self.make_server_do_fake_request('client/registerCapability', registration_params) + await self.make_server_do_fake_request('client/registerCapability', registration_params) self.assertEqual(len(TestFileWatcher.active_watchers), 1) watcher = TestFileWatcher.active_watchers[0] self.assertEqual(watcher.patterns, ['*.py']) @@ -178,7 +178,8 @@ def test_handles_dynamic_watcher_registration(self) -> Generator: # Trigger the file event filepath = join(self.folder_root_path, 'file.py') watcher.trigger_event([('create', filepath), ('change', filepath)]) - sent_notification = yield from self.await_message('workspace/didChangeWatchedFiles') + sent_notification = await self.await_message('workspace/didChangeWatchedFiles') + assert isinstance(sent_notification, dict) self.assertIs(type(sent_notification['changes']), list) self.assertEqual(len(sent_notification['changes']), 2) change1 = sent_notification['changes'][0] @@ -188,7 +189,7 @@ def test_handles_dynamic_watcher_registration(self) -> Generator: self.assertEqual(change2['type'], file_watcher_event_type_to_lsp_file_change_type('change')) self.assertTrue(change2['uri'].endswith('file.py')) - def test_aggregates_multiple_registrations_with_common_kind_and_base(self) -> Generator: + async def test_aggregates_multiple_registrations_with_common_kind_and_base(self) -> None: register_options: DidChangeWatchedFilesRegistrationOptions = { 'watchers': [ { @@ -229,7 +230,7 @@ def test_aggregates_multiple_registrations_with_common_kind_and_base(self) -> Ge } ] } - yield self.make_server_do_fake_request('client/registerCapability', registration_params) + await self.make_server_do_fake_request('client/registerCapability', registration_params) self.assertEqual(len(TestFileWatcher.active_watchers), 2) watcher = TestFileWatcher.active_watchers[0] self.assertEqual(watcher.patterns, ['*.py', '*.json', '*.js']) @@ -240,7 +241,7 @@ def test_aggregates_multiple_registrations_with_common_kind_and_base(self) -> Ge self.assertEqual(watcher.events, ['create', 'delete']) self.assertEqual(watcher.root_path, self.folder_root_path) - def test_does_not_aggregate_non_matching_base(self) -> Generator: + async def test_does_not_aggregate_non_matching_base(self) -> None: base_uri_1 = filename_to_uri('/a/b') base_uri_2 = filename_to_uri('/a/c') register_options: DidChangeWatchedFilesRegistrationOptions = { @@ -270,7 +271,7 @@ def test_does_not_aggregate_non_matching_base(self) -> Generator: } ] } - yield self.make_server_do_fake_request('client/registerCapability', registration_params) + await self.make_server_do_fake_request('client/registerCapability', registration_params) self.assertEqual(len(TestFileWatcher.active_watchers), 2) watcher = TestFileWatcher.active_watchers[0] self.assertEqual(watcher.patterns, ['*.py']) diff --git a/tests/test_server_notifications.py b/tests/test_server_notifications.py index 7802e7a03..3e3cfcbcb 100644 --- a/tests/test_server_notifications.py +++ b/tests/test_server_notifications.py @@ -5,15 +5,15 @@ from LSP.protocol import DiagnosticSeverity from LSP.protocol import DiagnosticTag from LSP.protocol import PublishDiagnosticsParams -from typing import Generator +import asyncio import sublime class ServerNotifications(TextDocumentTestCase): - def test_publish_diagnostics(self) -> Generator: + async def test_publish_diagnostics(self) -> None: self.insert_characters("a b c\n") - yield from self.await_message('textDocument/didChange') + await self.await_message('textDocument/didChange') params: PublishDiagnosticsParams = { 'uri': filename_to_uri(self.view.file_name() or ''), 'diagnostics': [ @@ -38,17 +38,20 @@ def test_publish_diagnostics(self) -> Generator: } ] } - yield from self.await_client_notification("textDocument/publishDiagnostics", params) + await self.mock_client_notification("textDocument/publishDiagnostics", params) errors_icon_regions = self.view.get_regions("lspTESTds1_icon") errors_underline_regions = self.view.get_regions("lspTESTds1_underline") warnings_icon_regions = self.view.get_regions("lspTESTds2_icon") warnings_underline_regions = self.view.get_regions("lspTESTds2_underline") info_icon_regions = self.view.get_regions("lspTESTds3_icon") info_underline_regions = self.view.get_regions("lspTESTds3_underline") - yield lambda: len(errors_icon_regions) == len(errors_underline_regions) == 1 - yield lambda: len(warnings_icon_regions) == len(warnings_underline_regions) == 1 - yield lambda: len(info_icon_regions) == len(info_underline_regions) == 1 - yield lambda: len(self.view.get_regions("lspTESTds3_tags")) == 0 + while not ( # noqa: ASYNC110 + len(errors_icon_regions) == len(errors_underline_regions) == 1 + and len(warnings_icon_regions) == len(warnings_underline_regions) == 1 + and len(info_icon_regions) == len(info_underline_regions) == 1 + and len(self.view.get_regions("lspTESTds3_tags")) == 0 + ): + await asyncio.sleep(0.05) self.assertEqual(errors_underline_regions[0], sublime.Region(0, 1)) self.assertEqual(warnings_underline_regions[0], sublime.Region(2, 3)) self.assertEqual(info_underline_regions[0], sublime.Region(4, 5)) diff --git a/tests/test_server_requests.py b/tests/test_server_requests.py index 63711ae2a..13ee044c4 100644 --- a/tests/test_server_requests.py +++ b/tests/test_server_requests.py @@ -1,14 +1,15 @@ from __future__ import annotations from .setup import TextDocumentTestCase +from LSP.plugin import Error from LSP.plugin.core.types import ClientConfig from LSP.plugin.core.url import filename_to_uri from LSP.protocol import ErrorCodes from LSP.protocol import TextDocumentSyncKind from pathlib import Path from typing import Any -from typing import Generator from typing import TYPE_CHECKING +import asyncio import os import sublime import tempfile @@ -26,24 +27,41 @@ def get_auto_complete_trigger(sb: SessionBufferProtocol) -> list[dict[str, str]] return None -def verify(testcase: TextDocumentTestCase, method: str, input_params: Any, expected_output_params: Any) -> Generator: - promise = testcase.make_server_do_fake_request(method, input_params) - yield from testcase.await_promise(promise) - testcase.assertEqual(promise.result(), expected_output_params) +async def verify( + testcase: TextDocumentTestCase, + method: str, + input_params: Any, + expected_output_params: Any, + expected_error_code: ErrorCodes | None = None, +) -> None: + try: + result = await testcase.make_server_do_fake_request(method, input_params) + testcase.assertEqual(result, expected_output_params) + except Error as error: + if expected_error_code is not None: + testcase.assertEqual(error.code, expected_error_code) + else: + testcase.fail(f"method {method} returned error {error}") class ServerRequests(TextDocumentTestCase): + async def test_unknown_method(self) -> None: + await verify( + self, + "foobar/qux", + {}, + {"code": ErrorCodes.MethodNotFound, "message": "foobar/qux"}, + ErrorCodes.MethodNotFound, + ) - def test_unknown_method(self) -> Generator: - yield from verify(self, "foobar/qux", {}, {"code": ErrorCodes.MethodNotFound, "message": "foobar/qux"}) - - def test_m_workspace_workspaceFolders(self) -> Generator: + async def test_m_workspace_workspaceFolders(self) -> None: expected_output = [{"name": os.path.basename(f), "uri": filename_to_uri(f)} for f in sublime.active_window().folders()] self.maxDiff = None - yield from verify(self, "workspace/workspaceFolders", {}, expected_output) + await verify(self, "workspace/workspaceFolders", {}, expected_output) - def test_m_workspace_configuration(self) -> Generator: + async def test_m_workspace_configuration(self) -> None: + assert self.session self.session.config.settings.set("foo.bar", "$hello") self.session.config.settings.set("foo.baz", "$world") self.session.config.settings.set("foo.a", 1) @@ -53,27 +71,28 @@ def test_m_workspace_configuration(self) -> Generator: method = "workspace/configuration" params = {"items": [{"section": "foo"}]} expected_output = [{"bar": "X", "baz": "Y", "a": 1, "b": None, "c": ["asdf X Y"]}] - yield from verify(self, method, params, expected_output) + await verify(self, method, params, expected_output) self.session.config.settings.clear() - def test_m_workspace_applyEdit(self) -> Generator: + async def test_m_workspace_applyEdit(self) -> None: old_change_count = self.insert_characters("hello\nworld\n") edit = { "newText": "there", "range": {"start": {"line": 1, "character": 0}, "end": {"line": 1, "character": 5}}} params = {"edit": {"changes": {filename_to_uri(self.view.file_name()): [edit]}}} - yield from verify(self, "workspace/applyEdit", params, {"applied": True}) - yield lambda: self.view.change_count() > old_change_count + await verify(self, "workspace/applyEdit", params, {"applied": True}) + while self.view.change_count() <= old_change_count: # noqa: ASYNC110 + await asyncio.sleep(0.05) self.assertEqual(self.view.substr(sublime.Region(0, self.view.size())), "hello\nthere\n") - def test_m_workspace_applyEdit_with_nontrivial_promises(self) -> Generator: + async def test_m_workspace_applyEdit_with_nontrivial_promises(self) -> None: with tempfile.TemporaryDirectory() as dirpath: initial_text = ["a b", "c d"] file_paths = [] for i in range(2): file_paths.append(os.path.join(dirpath, f"file{i}.txt")) - Path(file_paths[-1]).write_text(initial_text[i], encoding="utf-8") - yield from verify( + Path(file_paths[-1]).write_text(initial_text[i], encoding="utf-8") # noqa: ASYNC240 + await verify( self, "workspace/applyEdit", { @@ -117,9 +136,9 @@ def test_m_workspace_applyEdit_with_nontrivial_promises(self) -> Generator: self.assertEqual(view.substr(sublime.Region(0, view.size())), expected[i]) view.close() - def test_m_workspace_applyEdit_with_wrong_uri(self) -> Generator: + async def test_m_workspace_applyEdit_with_wrong_uri(self) -> None: uri = "file:///C:/wrong/uri.txt" - yield from verify( + await verify( self, "workspace/applyEdit", { @@ -157,13 +176,13 @@ def test_m_workspace_applyEdit_with_wrong_uri(self) -> Generator: } ) - def test_m_workspace_applyEdit_with_wrong_document_version(self) -> Generator: + async def test_m_workspace_applyEdit_with_wrong_document_version(self) -> None: with tempfile.TemporaryDirectory() as dirpath: file_name = os.path.join(dirpath, "file3.txt") uri = filename_to_uri(file_name) version = 123 - Path(file_name).write_text("a b", encoding="utf-8") - yield from verify( + Path(file_name).write_text("a b", encoding="utf-8") # noqa: ASYNC240 + await verify( self, "workspace/applyEdit", { @@ -201,8 +220,8 @@ def test_m_workspace_applyEdit_with_wrong_document_version(self) -> Generator: } ) - def test_m_client_registerCapability(self) -> Generator: - yield from verify( + async def test_m_client_registerCapability(self) -> None: + await verify( self, "client/registerCapability", { @@ -248,14 +267,14 @@ def test_m_client_registerCapability(self) -> Generator: self.assertTrue(trigger) self.assertEqual(trigger.get("characters"), "!@#") - def test_m_client_unregisterCapability(self) -> Generator: - yield from verify( + async def test_m_client_unregisterCapability(self) -> None: + await verify( self, "client/registerCapability", {"registrations": [{"method": "foo/bar", "id": "hello"}]}, None) self.assertIn("barProvider", self.session.capabilities) - yield from verify( + await verify( self, "client/unregisterCapability", {"unregisterations": [{"method": "foo/bar", "id": "hello"}]}, @@ -279,8 +298,8 @@ def get_stdio_test_config(cls) -> ClientConfig: } ) - def test_m_client_registerCapability(self) -> Generator: - yield from verify( + async def test_m_client_registerCapability(self) -> None: + await verify( self, "client/registerCapability", { diff --git a/tests/test_session.py b/tests/test_session.py index 1277d166b..c3489a000 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -2,7 +2,6 @@ from .test_mocks import TEST_CONFIG from LSP.plugin.core.collections import DottedDict -from LSP.plugin.core.promise import Promise from LSP.plugin.core.sessions import get_initialize_params from LSP.plugin.core.sessions import Logger from LSP.plugin.core.sessions import Manager @@ -47,24 +46,22 @@ def get_project_path(self, file_name: str) -> str | None: def should_ignore_diagnostics(self, uri: DocumentUri, configuration: ClientConfig) -> str | None: return None - def start_async(self, configuration: ClientConfig, initiating_view: sublime.View) -> None: + async def start(self, configuration: ClientConfig, initiating_view: sublime.View) -> Session | None: pass - def on_post_exit_async(self, session: Session, exit_code: int, exception: Exception | None) -> None: + async def on_post_exit(self, session: Session, exit_code: int, exception: Exception | None) -> None: pass def on_diagnostics_updated(self) -> None: pass - def handle_message_request( + async def handle_message_request( self, config_name: str, params: ShowMessageRequestParams - ) -> Promise[MessageActionItem | None]: - return Promise.resolve(None) + ) -> MessageActionItem | None: + return None - def handle_show_message( - self, config_name: str, params: ShowMessageParams - ) -> Promise[MessageActionItem | None]: - return Promise.resolve(None) + def handle_show_message(self, config_name: str, params: ShowMessageParams) -> None: + return None def handle_log_message(self, config_name: str, params: LogMessageParams) -> None: ... diff --git a/tests/test_single_document.py b/tests/test_single_document.py index 0c025d535..8ba36c7bd 100644 --- a/tests/test_single_document.py +++ b/tests/test_single_document.py @@ -1,20 +1,21 @@ from __future__ import annotations from .setup import TextDocumentTestCase -from .setup import TIMEOUT_TIME -from .setup import YieldPromise from copy import deepcopy from LSP.plugin import apply_text_edits from LSP.plugin import Request from LSP.plugin.core.protocol import UINT_MAX from LSP.plugin.core.url import filename_to_uri from LSP.plugin.core.views import entire_content -from typing import Generator from typing import Iterable +from typing import TYPE_CHECKING from unittest import skip import os import sublime +if TYPE_CHECKING: + from LSP.protocol import Command + SELFDIR = os.path.dirname(__file__) TEST_FILE_PATH = os.path.join(SELFDIR, 'testfile.txt') GOTO_RESPONSE = [ @@ -57,9 +58,9 @@ def test_did_open(self) -> None: # -> "shutdown" -> client shut down pass - def test_out_of_bounds_column_for_text_document_edit(self) -> None: + async def test_out_of_bounds_column_for_text_document_edit(self) -> None: self.insert_characters("a\nb\nc\n") - apply_text_edits(self.view, [ + await apply_text_edits(self.view, [ { 'newText': 'hello there', 'range': { @@ -76,27 +77,27 @@ def test_out_of_bounds_column_for_text_document_edit(self) -> None: ]) self.assertEqual(entire_content(self.view), "a\nhello there\nc\n") - def test_did_close(self) -> Generator: + async def test_did_close(self) -> None: self.assertTrue(self.view) self.assertTrue(self.view.is_valid()) self.view.close() - yield from self.await_message("textDocument/didClose") + await self.await_message("textDocument/didClose") - def test_sends_save_with_purge(self) -> Generator: + async def test_sends_save_with_purge(self) -> None: assert self.view self.view.settings().set("lsp_format_on_save", False) self.insert_characters("A") self.view.run_command("lsp_save", {'async': True}) - yield from self.await_message("textDocument/didChange") - yield from self.await_message("textDocument/didSave") - yield from self.await_clear_view_and_save() + await self.await_message("textDocument/didChange") + await self.await_message("textDocument/didSave") + await self.await_clear_view_and_save() - def test_formats_on_save(self) -> Generator: + async def test_formats_on_save(self) -> None: assert self.view self.view.settings().set("lsp_format_on_save", True) self.insert_characters("A") - yield from self.await_message("textDocument/didChange") - self.set_response('textDocument/formatting', [{ + await self.await_message("textDocument/didChange") + await self.mock_response('textDocument/formatting', [{ 'newText': "BBB", 'range': { 'start': {'line': 0, 'character': 0}, @@ -104,22 +105,22 @@ def test_formats_on_save(self) -> Generator: } }]) self.view.run_command("lsp_save", {'async': True}) - yield from self.await_message("textDocument/formatting") - yield from self.await_message("textDocument/didChange") - yield from self.await_message("textDocument/didSave") + await self.await_message("textDocument/formatting") + await self.await_message("textDocument/didChange") + await self.await_message("textDocument/didSave") text = self.view.substr(sublime.Region(0, self.view.size())) self.assertEqual("BBB", text) - yield from self.await_clear_view_and_save() + await self.await_clear_view_and_save() - def test_hover_popup_visible(self) -> Generator: + async def test_hover_popup_visible(self) -> None: assert self.view - self.set_response('textDocument/hover', {"contents": "greeting"}) + await self.mock_response('textDocument/hover', {"contents": "greeting"}) self.view.run_command('insert', {"characters": "Hello Wrld"}) self.assertFalse(self.view.is_popup_visible()) self.view.run_command('lsp_hover', {'point': 3}) - yield self.view.is_popup_visible + await self.wait_until_st_state(self.view.is_popup_visible) - def test_remove_line_and_then_insert_at_that_line_at_end(self) -> Generator: + async def test_remove_line_and_then_insert_at_that_line_at_end(self) -> None: original = ( 'a\n' 'b\n' @@ -140,9 +141,9 @@ def test_remove_line_and_then_insert_at_that_line_at_end(self) -> Generator: # New behavior: # 1) line index 3 is "created" ('a\n', 'b\n', 'c\n', c\n')) # 2) deletes line index 2. - yield from self.__run_formatting_test(original, expected, file_changes) + await self.__run_formatting_test(original, expected, file_changes) - def test_apply_formatting(self) -> Generator: + async def test_apply_formatting(self) -> None: original = ( '\n' '\n' @@ -162,9 +163,9 @@ def test_apply_formatting(self) -> Generator: '\n' '\n' ) - yield from self.__run_formatting_test(original, expected, file_changes) + await self.__run_formatting_test(original, expected, file_changes) - def test_apply_formatting_and_preserve_order(self) -> Generator: + async def test_apply_formatting_and_preserve_order(self) -> None: original = ( 'abcde\n' 'fghij\n' @@ -182,48 +183,48 @@ def test_apply_formatting_and_preserve_order(self) -> Generator: 'a123bcde\n' 'fg456ij\n' ) - yield from self.__run_formatting_test(original, expected, file_changes) + await self.__run_formatting_test(original, expected, file_changes) - def test_tabs_are_respected_even_when_translate_tabs_to_spaces_is_set_to_true(self) -> Generator: + async def test_tabs_are_respected_even_when_translate_tabs_to_spaces_is_set_to_true(self) -> None: original = ' ' * 4 file_changes = [((0, 0), (0, 4), '\t')] expected = '\t' assert self.view self.view.settings().set("translate_tabs_to_spaces", True) - yield from self.__run_formatting_test(original, expected, file_changes) + await self.__run_formatting_test(original, expected, file_changes) # Make sure the user's settings haven't changed self.assertTrue(self.view.settings().get("translate_tabs_to_spaces")) - def __run_formatting_test( + async def __run_formatting_test( self, original: Iterable[str], expected: Iterable[str], file_changes: list[tuple[tuple[int, int], tuple[int, int], str]] - ) -> Generator: + ) -> None: assert self.view original_change_count = self.insert_characters(''.join(original)) # self.assertEqual(original_change_count, 1) - self.set_response('textDocument/formatting', [{ + await self.mock_response('textDocument/formatting', [{ 'newText': new_text, 'range': { 'start': {'line': start[0], 'character': start[1]}, 'end': {'line': end[0], 'character': end[1]}}} for start, end, new_text in file_changes]) self.view.run_command('lsp_format_document') - yield from self.await_message('textDocument/formatting') - yield from self.await_view_change(original_change_count + len(file_changes)) + await self.await_message('textDocument/formatting') + await self.await_view_change(original_change_count + len(file_changes)) edited_content = self.view.substr(sublime.Region(0, self.view.size())) self.assertEqual(edited_content, ''.join(expected)) - def __run_goto_test(self, response: list, text_document_request: str, subl_command_suffix: str) -> Generator: + async def __run_goto_test(self, response: list, text_document_request: str, subl_command_suffix: str) -> None: assert self.view self.insert_characters(GOTO_CONTENT) # Put the cursor back at the start of the buffer, otherwise is_at_word fails in goto.py. self.view.sel().clear() self.view.sel().add(sublime.Region(0, 0)) method = f'textDocument/{text_document_request}' - self.set_response(method, response) + await self.mock_response(method, response) self.view.run_command(f'lsp_symbol_{subl_command_suffix}') - yield from self.await_message(method) + await self.await_message(method) def condition() -> bool: nonlocal self @@ -233,35 +234,35 @@ def condition() -> bool: return False return s[0].begin() > 0 - yield {"condition": condition, "timeout": TIMEOUT_TIME} + await self.wait_until_st_state(condition) first = self.view.sel()[0].begin() self.assertEqual(self.view.substr(sublime.Region(first, first + 1)), "F") - def test_definition(self) -> Generator: - yield from self.__run_goto_test(GOTO_RESPONSE, 'definition', 'definition') + async def test_definition(self) -> None: + await self.__run_goto_test(GOTO_RESPONSE, 'definition', 'definition') - def test_definition_location_link(self) -> Generator: - yield from self.__run_goto_test(GOTO_RESPONSE_LOCATION_LINK, 'definition', 'definition') + async def test_definition_location_link(self) -> None: + await self.__run_goto_test(GOTO_RESPONSE_LOCATION_LINK, 'definition', 'definition') - def test_type_definition(self) -> Generator: - yield from self.__run_goto_test(GOTO_RESPONSE, 'typeDefinition', 'type_definition') + async def test_type_definition(self) -> None: + await self.__run_goto_test(GOTO_RESPONSE, 'typeDefinition', 'type_definition') - def test_type_definition_location_link(self) -> Generator: - yield from self.__run_goto_test(GOTO_RESPONSE_LOCATION_LINK, 'typeDefinition', 'type_definition') + async def test_type_definition_location_link(self) -> None: + await self.__run_goto_test(GOTO_RESPONSE_LOCATION_LINK, 'typeDefinition', 'type_definition') - def test_declaration(self) -> Generator: - yield from self.__run_goto_test(GOTO_RESPONSE, 'declaration', 'declaration') + async def test_declaration(self) -> None: + await self.__run_goto_test(GOTO_RESPONSE, 'declaration', 'declaration') - def test_declaration_location_link(self) -> Generator: - yield from self.__run_goto_test(GOTO_RESPONSE_LOCATION_LINK, 'declaration', 'declaration') + async def test_declaration_location_link(self) -> None: + await self.__run_goto_test(GOTO_RESPONSE_LOCATION_LINK, 'declaration', 'declaration') - def test_implementation(self) -> Generator: - yield from self.__run_goto_test(GOTO_RESPONSE, 'implementation', 'implementation') + async def test_implementation(self) -> None: + await self.__run_goto_test(GOTO_RESPONSE, 'implementation', 'implementation') - def test_implementation_location_link(self) -> Generator: - yield from self.__run_goto_test(GOTO_RESPONSE_LOCATION_LINK, 'implementation', 'implementation') + async def test_implementation_location_link(self) -> None: + await self.__run_goto_test(GOTO_RESPONSE_LOCATION_LINK, 'implementation', 'implementation') - def test_expand_selection(self) -> Generator: + async def test_expand_selection(self) -> None: self.insert_characters("abcba\nabcba\nabcba\n") self.view.run_command("lsp_selection_set", {"regions": [(2, 2)]}) self.assertEqual(len(self.view.sel()), 1) @@ -277,19 +278,19 @@ def test_expand_selection(self) -> Generator: "range": {"start": {"line": 0, "character": 2}, "end": {"line": 0, "character": 3}} }] - def expand_and_check(a: int, b: int) -> Generator: - self.set_response("textDocument/selectionRange", response) + async def expand_and_check(a: int, b: int) -> None: + await self.mock_response("textDocument/selectionRange", response) self.view.run_command("lsp_expand_selection") - yield from self.await_message("textDocument/selectionRange") - yield lambda: self.view.sel()[0] == sublime.Region(a, b) + await self.await_message("textDocument/selectionRange") + await self.wait_until_st_state(lambda: self.view.sel()[0] == sublime.Region(a, b)) - yield from expand_and_check(2, 3) - yield from expand_and_check(1, 3) - yield from expand_and_check(0, 5) + await expand_and_check(2, 3) + await expand_and_check(1, 3) + await expand_and_check(0, 5) - def test_rename(self) -> Generator: + async def test_rename(self) -> None: self.insert_characters("foo\nfoo\nfoo\n") - self.set_response("textDocument/rename", { + await self.mock_response("textDocument/rename", { 'changes': { filename_to_uri(TEST_FILE_PATH): [ { @@ -315,47 +316,39 @@ def test_rename(self) -> Generator: ) self.view.run_command("lsp_selection_set", {"regions": [(0, 0)]}) self.view.run_command("lsp_symbol_rename", {"new_name": "bar"}) - yield from self.await_message("textDocument/rename") - yield from self.await_view_change(9) + await self.await_message("textDocument/rename") + await self.await_view_change(9) self.assertEqual(self.view.substr(sublime.Region(0, self.view.size())), "bar\nbar\nbar\n") - def test_run_command(self) -> Generator: - self.set_response("workspace/executeCommand", {"canReturnAnythingHere": "asdf"}) - promise = YieldPromise() - sublime.set_timeout_async( - lambda: self.session.execute_command( - {"command": "foo", "arguments": ["hello", "there", "general", "kenobi"]}, - progress=False, - view=self.view, - ).then(promise.fulfill) - ) - yield from self.await_promise(promise) - yield from self.await_message("workspace/executeCommand") - self.assertEqual(promise.result(), {"canReturnAnythingHere": "asdf"}) + async def test_run_command(self) -> None: + await self.mock_response("workspace/executeCommand", {"canReturnAnythingHere": "asdf"}) + command: Command = {"command": "foo", "arguments": ["hello", "there", "general", "kenobi"]} + assert self.session + result = await self.session.execute_command(command, progress=False) + await self.await_message("workspace/executeCommand") + self.assertEqual(result, {"canReturnAnythingHere": "asdf"}) - def test_progress(self) -> Generator: - request = Request("foobar", {"hello": "world"}, self.view, progress=True) - self.set_response("foobar", {"general": "kenobi"}) - promise = self.session.send_request_task(request) - yield lambda: "workDoneToken" in request.params - result = yield from self.await_promise(promise) - self.assertEqual(result, {"general": "kenobi"}) + async def test_progress(self) -> None: + # not sure how this tests $/progress ? + await self.mock_response("foobar", {"general": "kenobi"}) + assert self.session + result = self.session.request(Request("foobar", {"hello": "world"}, self.view, progress=True)) + self.assertEqual(await result, {"general": "kenobi"}) class SingleDocumentTestCase2(TextDocumentTestCase): - def test_did_change(self) -> Generator: + async def test_did_change(self) -> None: assert self.view self.maxDiff = None self.insert_characters("A") - yield from self.await_message("textDocument/didChange") + await self.await_message("textDocument/didChange") # multiple changes are batched into one didChange notification self.insert_characters("B\n") self.insert_characters("🙂\n") self.insert_characters("D") - promise = YieldPromise() - yield from self.await_message("textDocument/didChange", promise) - self.assertEqual(promise.result(), { + result = await self.await_message("textDocument/didChange") + self.assertEqual(result, { 'contentChanges': [ {'rangeLength': 0, 'range': {'start': {'line': 0, 'character': 1}, 'end': {'line': 0, 'character': 1}}, 'text': 'B'}, # noqa {'rangeLength': 0, 'range': {'start': {'line': 0, 'character': 2}, 'end': {'line': 0, 'character': 2}}, 'text': '\n'}, # noqa @@ -377,7 +370,7 @@ def get_test_name(cls) -> str: return "testfile2" @skip('Flaky on Windows and Mac') - def test_did_change_before_did_close(self) -> Generator: + async def test_did_change_before_did_close(self) -> None: assert self.view self.view.window().run_command("chain", { "commands": [ @@ -386,9 +379,9 @@ def test_did_change_before_did_close(self) -> Generator: ["close", {}] ] }) - yield from self.await_message('textDocument/didChange') - yield from self.await_message('textDocument/didSave') - yield from self.await_message('textDocument/didClose') + await self.await_message('textDocument/didChange') + await self.await_message('textDocument/didSave') + await self.await_message('textDocument/didClose') class WillSaveWaitUntilTestCase(TextDocumentTestCase): @@ -399,11 +392,11 @@ def get_test_server_capabilities(cls) -> dict: capabilities['capabilities']['textDocumentSync']['willSaveWaitUntil'] = True return capabilities - def test_will_save_wait_until(self) -> Generator: + async def test_will_save_wait_until(self) -> None: assert self.view self.insert_characters("A") - yield from self.await_message("textDocument/didChange") - self.set_response('textDocument/willSaveWaitUntil', [{ + await self.await_message("textDocument/didChange") + await self.mock_response('textDocument/willSaveWaitUntil', [{ 'newText': "BBB", 'range': { 'start': {'line': 0, 'character': 0}, @@ -412,9 +405,9 @@ def test_will_save_wait_until(self) -> Generator: }]) self.view.settings().set("lsp_format_on_save", False) self.view.run_command("lsp_save", {'async': True}) - yield from self.await_message("textDocument/willSaveWaitUntil") - yield from self.await_message("textDocument/didChange") - yield from self.await_message("textDocument/didSave") + await self.await_message("textDocument/willSaveWaitUntil") + await self.await_message("textDocument/didChange") + await self.await_message("textDocument/didSave") text = self.view.substr(sublime.Region(0, self.view.size())) self.assertEqual("BBB", text) - yield from self.await_clear_view_and_save() + await self.await_clear_view_and_save() diff --git a/tests/test_views.py b/tests/test_views.py index 0b47ecdfe..2dfc0ba9b 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -36,12 +36,12 @@ from LSP.protocol import MarkupKind from typing import Any from unittest.mock import MagicMock -from unittesting import DeferrableTestCase import re import sublime +import unittest -class ViewsTest(DeferrableTestCase): +class ViewsTest(unittest.TestCase): def setUp(self) -> None: super().setUp()