Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
cancel-in-progress: true
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
os: [ubuntu-latest, windows-latest, macos-latest]
fail-fast: false
env:
Expand Down
39 changes: 33 additions & 6 deletions nonebot/adapters/telegram/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@
import anyio
from pydantic.main import BaseModel
from pydantic.json import pydantic_encoder
from nonebot.utils import escape_tag, logger_wrapper
from nonebot.drivers import URL, Driver, Request, Response, HTTPServerSetup
from nonebot.utils import UNSET, escape_tag, logger_wrapper
from nonebot.drivers import (
URL,
DEFAULT_TIMEOUT,
Driver,
Request,
Timeout,
Response,
HTTPServerSetup,
)

from nonebot.adapters import Adapter as BaseAdapter

Expand All @@ -31,7 +39,7 @@ class Adapter(BaseAdapter):
def __init__(self, driver: Driver, **kwargs: Any):
super().__init__(driver, **kwargs)
self.adapter_config = AdapterConfig(**self.config.model_dump())
self.tasks: list[asyncio.Task] = []
self.tasks: set[asyncio.Task] = set()
self.setup()

@classmethod
Expand Down Expand Up @@ -90,11 +98,13 @@ async def poll(self, bot: Bot):
if update_offset is not None:
for update in updates:
update_offset = update.update_id + 1
asyncio.create_task(
task = asyncio.create_task(
self.__handle_update(
bot, update.model_dump(by_alias=True, exclude_none=True)
)
)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
elif updates:
update_offset = updates[0].update_id
except Exception as e:
Expand All @@ -104,7 +114,9 @@ async def poll(self, bot: Bot):
def setup_polling(self, bot: Bot):
@self.on_ready
async def _():
self.tasks.append(asyncio.create_task(self.poll(bot)))
task = asyncio.create_task(self.poll(bot))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)

@self.driver.on_shutdown
async def _():
Expand All @@ -120,7 +132,9 @@ async def handle_http(self, request: Request) -> Response:
if bot.secret_token == token:
if request.content:
update: dict = json.loads(request.content)
asyncio.create_task(self.__handle_update(bot, update))
task = asyncio.create_task(self.__handle_update(bot, update))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
return Response(204)
return Response(401)

Expand Down Expand Up @@ -152,6 +166,18 @@ async def _call_api(self, bot: Bot, api: str, **data) -> Any:
s.capitalize() for s in api.split("_")[1:]
)
data = _escape_none(data)
request_timeout = UNSET
if api == "getUpdates":
timeout = data.get("timeout")
if not isinstance(timeout, bool) and isinstance(timeout, (int, float)):
# Telegram timeout is server-side long polling; the HTTP read
# timeout must be slightly longer.
request_timeout = Timeout(
total=DEFAULT_TIMEOUT.total,
connect=DEFAULT_TIMEOUT.connect,
read=float(timeout) + 5,
close=DEFAULT_TIMEOUT.close,
)

# 分离文件到 files
files: dict[str, tuple[str, bytes]] = {}
Expand Down Expand Up @@ -233,6 +259,7 @@ async def process_input_file(file: Union[InputFile, str]) -> Optional[str]:
data=data if files else None,
json=data if not files else None,
files=files, # type: ignore
timeout=request_timeout,
proxy=self.adapter_config.proxy,
)
try:
Expand Down
2 changes: 1 addition & 1 deletion nonebot/adapters/telegram/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def call_api(self, api: str, *args: Any, **kargs: Any) -> Any:
)
return await super().call_api(api, **kargs)

def __getattribute__(self, __name: str) -> Any:
def __getattribute__(self, __name: str, /) -> Any:
if not __name.startswith("__") and hasattr(API, __name):
return partial(self.call_api, __name)
return object.__getattribute__(self, __name)
Expand Down
15 changes: 11 additions & 4 deletions nonebot/adapters/telegram/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class EventWithChat(Protocol):


class Event(BaseEvent):
telegram_model: Update = Field(default=None)
telegram_model: Optional[Update] = Field(default=None)

@classmethod
def __parse_event(cls, obj: dict) -> "Event":
Expand Down Expand Up @@ -362,7 +362,7 @@ def get_event_description(self) -> str:


class GroupEditedMessageEvent(EditedMessageEvent):
from_: User = Field(default=None, alias="from")
from_: Optional[User] = Field(default=None, alias="from")
sender_chat: Optional[Chat] = None

@classmethod
Expand All @@ -380,14 +380,17 @@ def get_event_name(self) -> str:

@override
def get_user_id(self) -> str:
assert self.from_ is not None
return str(self.from_.id)

@override
def get_session_id(self) -> str:
assert self.from_ is not None
return f"group_{self.chat.id}_{self.from_.id}"

@override
def get_event_description(self) -> str:
assert self.from_ is not None
return (
f"EditedMessage {self.message_id} from {self.from_.id}"
f"@[Chat {self.chat.id}]: {self.get_message_description()}"
Expand All @@ -403,10 +406,12 @@ def get_event_name(self) -> str:

@override
def get_session_id(self) -> str:
assert self.from_ is not None
return f"group_{self.chat.id}_thread{self.message_thread_id}_{self.from_.id}"

@override
def get_event_description(self) -> str:
assert self.from_ is not None
return (
f"EditedMessage {self.message_id} from {self.from_.id}@[Chat {self.chat.id}"
f" Thread {self.message_thread_id}]: {self.get_message_description()}"
Expand Down Expand Up @@ -470,7 +475,7 @@ class PinnedMessageEvent(NoticeEvent):
sender_chat: Optional[Chat] = None
chat: Chat
date: int
pinned_message: MessageEvent = Field(default=None)
pinned_message: Optional[MessageEvent] = Field(default=None)

@classmethod
def __parse_event(cls, obj: dict):
Expand All @@ -485,10 +490,12 @@ def get_event_name(self) -> str:

@override
def get_message(self) -> Message:
assert self.pinned_message is not None
return self.pinned_message.get_message()

@override
def get_event_description(self) -> str:
assert self.pinned_message is not None
return (
f"PinnedMessage {self.pinned_message.message_id} "
f"@[Chat {self.pinned_message.chat.id}]: {self.get_message_description()}"
Expand Down Expand Up @@ -705,7 +712,7 @@ def get_event_description(self) -> str:


class CallbackQueryEvent(InlineEvent, CallbackQuery):
chat: Chat = Field(default=None)
chat: Optional[Chat] = Field(default=None)

@override
def get_event_name(self) -> str:
Expand Down
Loading
Loading