Skip to content

Commit

Permalink
Minor fixes.
Browse files Browse the repository at this point in the history
Signed-off-by: Aliwoto <[email protected]>
  • Loading branch information
ALiwoto committed Dec 15, 2024
1 parent 319b22e commit b21585a
Show file tree
Hide file tree
Showing 6 changed files with 393 additions and 236 deletions.
2 changes: 1 addition & 1 deletion pyrogram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.

__version__ = "2.0.149"
__version__ = "2.0.150"
__license__ = "GNU Lesser General Public License v3.0 (LGPL-3.0)"
__copyright__ = "Copyright (C) 2017-present Dan <https://github.com/delivrance>"

Expand Down
131 changes: 114 additions & 17 deletions pyrogram/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from io import StringIO, BytesIO
from mimetypes import MimeTypes
from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator
from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple

import pyrogram
from pyrogram import __version__, __license__
Expand All @@ -45,14 +45,17 @@
SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate,
BadRequest, AuthBytesInvalid,
FloodWait, FloodPremiumWait
FloodWait, FloodPremiumWait,
ChannelInvalid, PersistentTimestampInvalid, PersistentTimestampOutdated
)
from pyrogram.handlers.handler import Handler
from pyrogram.methods import Methods
from pyrogram.session import Auth, Session
from pyrogram.storage import Storage, FileStorage, MemoryStorage
from pyrogram.types import User, TermsOfService
from pyrogram.utils import ainput
from .connection import Connection
from .connection.transport import TCP, TCPAbridged
from .dispatcher import Dispatcher
from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types
Expand Down Expand Up @@ -264,7 +267,9 @@ def __init__(
max_message_cache_size: int = MAX_MESSAGE_CACHE_SIZE,
storage_engine: Optional[Storage] = None,
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
init_connection_params: Optional["raw.base.JSONValue"] = None
init_connection_params: Optional["raw.base.JSONValue"] = None,
connection_factory: Type[Connection] = Connection,
protocol_factory: Type[TCP] = TCPAbridged
):
super().__init__()

Expand Down Expand Up @@ -299,6 +304,8 @@ def __init__(
self.max_message_cache_size = max_message_cache_size
self.client_platform = client_platform
self.init_connection_params = init_connection_params
self.connection_factory = connection_factory
self.protocol_factory = protocol_factory

self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler")

Expand Down Expand Up @@ -541,48 +548,51 @@ def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]):
async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, raw.types.Channel]]) -> bool:
is_min = False
parsed_peers = []
parsed_usernames = []

for peer in peers:
if getattr(peer, "min", False):
is_min = True
continue

usernames = None
usernames = []
phone_number = None

if isinstance(peer, raw.types.User):
peer_id = peer.id
access_hash = peer.access_hash
usernames = (
[peer.username.lower()] if peer.username
else [username.username.lower() for username in peer.usernames] if peer.usernames
else None
)
phone_number = peer.phone
peer_type = "bot" if peer.bot else "user"

if peer.username:
usernames.append(peer.username.lower())
elif peer.usernames:
usernames.extend(username.username.lower() for username in peer.usernames)
elif isinstance(peer, (raw.types.Chat, raw.types.ChatForbidden)):
peer_id = -peer.id
access_hash = 0
peer_type = "group"
elif isinstance(peer, raw.types.Channel):
peer_id = utils.get_channel_id(peer.id)
access_hash = peer.access_hash
usernames = (
[peer.username.lower()] if peer.username
else [username.username.lower() for username in peer.usernames] if peer.usernames
else None
)
peer_type = "channel" if peer.broadcast else "supergroup"

if peer.username:
usernames.append(peer.username.lower())
elif peer.usernames:
usernames.extend(username.username.lower() for username in peer.usernames)
elif isinstance(peer, raw.types.ChannelForbidden):
peer_id = utils.get_channel_id(peer.id)
access_hash = peer.access_hash
peer_type = "channel" if peer.broadcast else "supergroup"
else:
continue

parsed_peers.append((peer_id, access_hash, peer_type, usernames, phone_number))
parsed_peers.append((peer_id, access_hash, peer_type, phone_number))
parsed_usernames.append((peer_id, usernames))

await self.storage.update_peers(parsed_peers)
await self.storage.update_usernames(parsed_usernames)

return is_min

Expand Down Expand Up @@ -639,10 +649,11 @@ async def handle_updates(self, updates):
)]
),
pts=pts - pts_count,
limit=pts
limit=pts,
force=False
)
)
except ChannelPrivate:
except (ChannelPrivate, PersistentTimestampOutdated, PersistentTimestampInvalid):
pass
else:
if not isinstance(diff, raw.types.updates.ChannelDifferenceEmpty):
Expand Down Expand Up @@ -688,6 +699,92 @@ async def handle_updates(self, updates):
elif isinstance(updates, raw.types.UpdatesTooLong):
log.info(updates)

async def recover_gaps(self) -> Tuple[int, int]:
states = await self.storage.update_state()

message_updates_counter = 0
other_updates_counter = 0

if not states:
log.info("No states found, skipping recovery.")
return (message_updates_counter, other_updates_counter)

for state in states:
id, local_pts, _, local_date, _ = state

prev_pts = 0

while True:
try:
diff = await self.invoke(
raw.functions.updates.GetChannelDifference(
channel=await self.resolve_peer(id),
filter=raw.types.ChannelMessagesFilterEmpty(),
pts=local_pts,
limit=10000,
force=False
) if id < 0 else
raw.functions.updates.GetDifference(
pts=local_pts,
date=local_date,
qts=0
)
)
except (ChannelPrivate, ChannelInvalid, PersistentTimestampOutdated, PersistentTimestampInvalid):
break

if isinstance(diff, raw.types.updates.DifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.DifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.Difference):
local_pts = diff.state.pts
elif isinstance(diff, raw.types.updates.DifferenceSlice):
local_pts = diff.intermediate_state.pts
local_date = diff.intermediate_state.date

if prev_pts == local_pts:
break

prev_pts = local_pts
elif isinstance(diff, raw.types.updates.ChannelDifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.ChannelDifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.ChannelDifference):
local_pts = diff.pts

users = {i.id: i for i in diff.users}
chats = {i.id: i for i in diff.chats}

for message in diff.new_messages:
message_updates_counter += 1
self.dispatcher.updates_queue.put_nowait(
(
raw.types.UpdateNewMessage(
message=message,
pts=local_pts,
pts_count=-1
),
users,
chats
)
)

for update in diff.other_updates:
other_updates_counter += 1
self.dispatcher.updates_queue.put_nowait(
(update, users, chats)
)

if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)):
break

await self.storage.update_state(id)

log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter)
return (message_updates_counter, other_updates_counter)

async def load_session(self):
await self.storage.open()

Expand Down
Loading

0 comments on commit b21585a

Please sign in to comment.