diff --git a/.gitignore b/.gitignore index f9bc356..52c8f85 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .DS_Store +.ruff_cache/ sessions/ # Byte-compiled / optimized / DLL files diff --git a/.vscode/settings.json b/.vscode/settings.json index c0dc805..7f3f159 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,7 +1,11 @@ { + "files.exclude": { + ".ruff_cache": true, + ".pytest_cache": true + }, "[python]": { "editor.formatOnSave": true, "editor.codeActionsOnSave": ["source.organizeImports"] }, - "python.formatting.provider": "black", + "python.formatting.provider": "black" } diff --git a/twapi/models.py b/twapi/models.py new file mode 100644 index 0000000..64cf3cb --- /dev/null +++ b/twapi/models.py @@ -0,0 +1,186 @@ +import email.utils +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + +from .utils import get_or, int_or_none + + +@dataclass +class Coordinates: + longitude: float + latitude: float + + +@dataclass +class Place: + id: str + fullName: str + name: str + type: str + country: str + countryCode: str + + @staticmethod + def parse(obj: dict): + return Place( + id=obj["id"], + fullName=obj["full_name"], + name=obj["name"], + type=obj["place_type"], + country=obj["country"], + countryCode=obj["country_code"], + ) + + +@dataclass +class TextLink: + url: str + text: str | None + tcourl: str | None + indices: tuple[int, int] + + @staticmethod + def parse(obj: dict): + return TextLink( + url=obj["expanded_url"], + text=obj["display_url"], + tcourl=obj["url"], + indices=tuple(obj["indices"]), + ) + + +@dataclass +class UserRef: + id: int + username: str + displayname: str + + @staticmethod + def parse(obj: dict): + return UserRef( + id=obj["id"], + username=obj["screen_name"], + displayname=obj["name"], + ) + + +@dataclass +class User: + id: int + username: str + displayname: str + rawDescription: str + created: datetime + followersCount: int + friendsCount: int + statusesCount: int + 
favouritesCount: int + listedCount: int + mediaCount: int + location: str + profileImageUrl: str + profileBannerUrl: str | None = None + protected: bool | None = None + verified: bool | None = None + + # descriptionLinks: typing.Optional[typing.List[TextLink]] = None + # link: typing.Optional[TextLink] = None + # label: typing.Optional["UserLabel"] = None + + @property + def url(self) -> str: + return f"https://twitter.com/{self.username}" + + @staticmethod + def parse(obj: dict): + return User( + id=int(obj["id_str"]), + username=obj["screen_name"], + displayname=obj["name"], + rawDescription=obj["description"], + created=email.utils.parsedate_to_datetime(obj["created_at"]), + followersCount=obj["followers_count"], + friendsCount=obj["friends_count"], + statusesCount=obj["statuses_count"], + favouritesCount=obj["favourites_count"], + listedCount=obj["listed_count"], + mediaCount=obj["media_count"], + location=obj["location"], + profileImageUrl=obj["profile_image_url_https"], + profileBannerUrl=obj.get("profile_banner_url"), + verified=obj.get("verified"), + protected=obj.get("protected"), + ) + + +@dataclass +class Tweet: + id: int + date: datetime + user: User + lang: str + rawContent: str + replyCount: int + retweetCount: int + likeCount: int + quoteCount: int + conversationId: int + hashtags: list[str] + cashtags: list[str] + mentionedUsers: list[UserRef] + links: list[TextLink] + viewCount: int | None = None + retweetedTweet: Optional["Tweet"] = None + quotedTweet: Optional["Tweet"] = None + place: Optional[Place] = None + coordinates: Optional[Coordinates] = None + + @property + def url(self): + return f"https://twitter.com/{self.user.username}/status/{self.id}" + + # renderedContent: str + # source: str | None = None + # sourceUrl: str | None = None + # sourceLabel: str | None = None + # media: typing.Optional[typing.List["Medium"]] = None + # inReplyToTweetId: typing.Optional[int] = None + # inReplyToUser: typing.Optional["User"] = None + # card: 
typing.Optional["Card"] = None + # vibe: typing.Optional["Vibe"] = None + + @staticmethod + def parse(obj: dict, res: dict): + rt_obj = get_or(res, f"globalObjects.tweets.{obj.get('retweeted_status_id_str')}") + qt_obj = get_or(res, f"globalObjects.tweets.{obj.get('quoted_status_id_str')}") + + coordinates: Coordinates | None = None + if obj.get("coordinates"): + coords = obj["coordinates"]["coordinates"] + coordinates = Coordinates(coords[0], coords[1]) + elif obj.get("geo"): + coords = obj["geo"]["coordinates"] + coordinates = Coordinates(coords[1], coords[0]) + + return Tweet( + id=obj["id"], + date=email.utils.parsedate_to_datetime(obj["created_at"]), + user=User.parse(res["globalObjects"]["users"][obj["user_id_str"]]), + lang=obj["lang"], + rawContent=obj["full_text"], + replyCount=obj["reply_count"], + retweetCount=obj["retweet_count"], + likeCount=obj["favorite_count"], + quoteCount=obj["quote_count"], + conversationId=int(obj["conversation_id_str"]), + hashtags=[x["text"] for x in get_or(obj, "entities.hashtags", [])], + cashtags=[x["text"] for x in get_or(obj, "entities.symbols", [])], + mentionedUsers=[UserRef.parse(x) for x in get_or(obj, "entities.user_mentions", [])], + links=[TextLink.parse(x) for x in get_or(obj, "entities.urls", [])], + viewCount=int_or_none(obj, "ext_views.count"), + retweetedTweet=Tweet.parse(rt_obj, res) if rt_obj else None, + quotedTweet=Tweet.parse(qt_obj, res) if qt_obj else None, + place=Place.parse(obj["place"]) if obj.get("place") else None, + coordinates=coordinates, + ) diff --git a/twapi/pool.py b/twapi/pool.py index f37e3bf..82f133d 100644 --- a/twapi/pool.py +++ b/twapi/pool.py @@ -1,7 +1,5 @@ import asyncio -from typing import AsyncGenerator, Callable, Tuple -from httpx import AsyncClient, HTTPStatusError, Response from loguru import logger from .client import UserClient @@ -30,30 +28,3 @@ class AccountsPool: else: logger.debug(f"No accounts available for queue '{queue}' (sleeping for 5 sec)") await asyncio.sleep(5) - 
- async def execute( - self, - queue: str, - cb: Callable[ - [AsyncClient, str | None], AsyncGenerator[Tuple[Response, dict, str | None], None] - ], - cursor: str | None = None, - ): - while True: - account = await self.get_account_or_wait(queue) - - try: - client = account.make_client() - async for x in cb(client, cursor): - rep, data, cursor = x - yield rep, data, cursor - return # exit if no more results - except HTTPStatusError as e: - if e.response.status_code == 429: - account.update_limit(queue, e.response) - logger.debug(f"Rate limit reached for account {account.username}") - continue - else: - raise e - finally: - account.unlock(queue) diff --git a/twapi/search.py b/twapi/search.py index 646e194..141b1db 100644 --- a/twapi/search.py +++ b/twapi/search.py @@ -1,9 +1,12 @@ import json +from typing import Awaitable, Callable -from httpx import AsyncClient, Response +from httpx import AsyncClient, HTTPStatusError, Response from loguru import logger +from .models import Tweet from .pool import AccountsPool +from .utils import find_item BASIC_SEARCH_PARAMS = """ include_profile_interstitial_type=1 @@ -70,34 +73,64 @@ SEARCH_URL = "https://api.twitter.com/2/search/adaptive.json" SEARCH_PARAMS = dict(x.split("=") for x in BASIC_SEARCH_PARAMS.splitlines() if x) -def json_params(params: dict): - return {k: json.dumps(v, separators=(",", ":")) for k, v in params.items()} +def filter_null(obj: dict): + try: + return {k: v for k, v in obj.items() if v is not None} + except AttributeError: + return obj + + +def json_params(obj: dict): + return {k: json.dumps(filter_null(v), separators=(",", ":")) for k, v in obj.items()} def get_ql_entries(obj: dict) -> list[dict]: - try: - key = list(obj["data"].keys())[0] - return obj["data"][key]["timeline"]["instructions"][0]["entries"] - except Exception: - return [] - - -def get_ql_cursor(obj: dict) -> str | None: - for entry in get_ql_entries(obj): - if entry["entryId"].startswith("cursor-bottom-"): - return 
entry["content"]["value"] - return None - - -def rep_info(rep: Response) -> str: - return f"[{rep.status_code} ~ {rep.headers['x-rate-limit-remaining']}/{rep.headers['x-rate-limit-limit']}]" + entries = find_item(obj, "entries") + return entries or [] class Search: def __init__(self, pool: AccountsPool): self.pool = pool - def get_next_cursor(self, res: dict) -> str | None: + async def _inf_req(self, queue: str, cb: Callable[[AsyncClient], Awaitable[Response]]): + while True: + account = await self.pool.get_account_or_wait(queue) + client = account.make_client() + + try: + while True: + rep = await cb(client) + rep.raise_for_status() + yield rep + except HTTPStatusError as e: + if e.response.status_code == 429: + logger.debug(f"Rate limit for account={account.username} on queue={queue}") + account.update_limit(queue, e.response) + continue + else: + logger.error(f"[{e.response.status_code}] {e.request.url}\n{e.response.text}") + raise e + finally: + account.unlock(queue) + + def _check_stop(self, rep: Response, txt: str, cnt: int, res: list, cur: str | None, lim: int): + els = len(res) + is_res, is_cur, is_lim = els > 0, cur is not None, lim > 0 and cnt >= lim + + msg = [ + f"{txt} {cnt:,d} (+{els:,d}) res={int(is_res)} cur={int(is_cur)} lim={int(is_lim)}", + f"[{rep.headers['x-rate-limit-remaining']}/{rep.headers['x-rate-limit-limit']}]", + ] + logger.debug(" ".join(msg)) + + end_before = not is_res + end_after = not is_cur or is_lim + return cnt + els, end_before, end_after + + # search + + def get_search_cursor(self, res: dict) -> str | None: try: for x in res["timeline"]["instructions"]: entry = x.get("replaceEntry", None) @@ -111,86 +144,112 @@ class Search: logger.debug(e) return None - async def get_items(self, client: AsyncClient, q: str, cursor: str | None): - while True: + async def search_raw(self, q: str, limit=-1): + queue, cursor, all_count = "search", None, 0 + + async def _get(client: AsyncClient): params = {**SEARCH_PARAMS, "q": q, "count": 20} 
params["cursor" if cursor else "requestContext"] = cursor if cursor else "launch" + return await client.get(SEARCH_URL, params=params) - rep = await client.get(SEARCH_URL, params=params) - rep.raise_for_status() - + async for rep in self._inf_req(queue, _get): data = rep.json() - cursor = self.get_next_cursor(data) + + cursor = self.get_search_cursor(data) tweets = data.get("globalObjects", {}).get("tweets", []) - if not tweets or not cursor: - is_result = len(tweets) > 0 - is_cursor = cursor is not None - logger.debug(f"{q} - no more items [res={is_result} cur={is_cursor}]") + + check = self._check_stop(rep, q, all_count, tweets, cursor, limit) + all_count, end_before, end_after = check + + if end_before: return - yield rep, data, cursor - - async def search(self, q: str): - total_count = 0 - async for x in self.pool.execute("search", lambda c, cur: self.get_items(c, q, cur)): - rep, data, cursor = x - - tweets = data.get("globalObjects", {}).get("tweets", []) - total_count += len(tweets) - logger.debug(f"{q} - {total_count:,d} (+{len(tweets):,d}) {rep_info(rep)}") - yield rep + if end_after: + return + + async def search(self, q: str, limit=-1): + async for rep in self.search_raw(q, limit=limit): + data = rep.json() + items = list(data.get("globalObjects", {}).get("tweets", {}).values()) + for x in items: + yield Tweet.parse(x, data) + + # graphql + + def get_ql_cursor(self, obj: dict) -> str | None: + try: + for entry in get_ql_entries(obj): + if entry["entryId"].startswith("cursor-bottom-"): + return entry["content"]["value"] + return None + except Exception: + return None + async def graphql_items(self, op: str, variables: dict, features: dict = {}, limit=-1): url = f"https://twitter.com/i/api/graphql/{op}" features = {**BASE_FEATURES, **features} - cursor, all_count, queue = None, 0, op.split("/")[-1] - while True: - account = await self.pool.get_account_or_wait(queue) - client = account.make_client() + queue, cursor, all_count = op.split("/")[-1], None, 0 - 
try: - params = {"variables": {**variables, "cursor": cursor}, "features": features} - rep = await client.get(url, params=json_params(params)) - logger.debug(f"{url} {rep_info(rep)}") - rep.raise_for_status() + async def _get(client: AsyncClient): + params = {"variables": {**variables, "cursor": cursor}, "features": features} + return await client.get(url, params=json_params(params)) - data = rep.json() - entries, cursor = get_ql_entries(data), get_ql_cursor(data) + async for rep in self._inf_req(queue, _get): + data = rep.json() + entries, cursor = get_ql_entries(data), self.get_ql_cursor(data) - # cursor-top / cursor-bottom always present - now_count = len([x for x in entries if not x["entryId"].startswith("cursor-")]) - all_count += now_count + # cursor-top / cursor-bottom always present + items = [x for x in entries if not x["entryId"].startswith("cursor-")] + check = self._check_stop(rep, queue, all_count, items, cursor, limit) + all_count, end_before, end_after = check - yield rep + if end_before: + return - if not cursor or not now_count or (limit > 0 and all_count >= limit): - return - finally: - account.unlock(queue) + yield rep + + if end_after: + return async def graphql_item(self, op: str, variables: dict, features: dict = {}): - res: list[Response] = [] - async for x in self.graphql_items(op, variables, features): - res.append(x) - break - return res[0] + url = f"https://twitter.com/i/api/graphql/{op}" + features = {**BASE_FEATURES, **features} + + async def _get(client: AsyncClient): + params = {"variables": {**variables}, "features": features} + return await client.get(url, params=json_params(params)) + + queue = op.split("/")[-1] + async for rep in self._inf_req(queue, _get): + msg = [ + f"{queue}", + f"[{rep.headers['x-rate-limit-remaining']}/{rep.headers['x-rate-limit-limit']}]", + ] + logger.debug(" ".join(msg)) + + return rep async def user_by_login(self, login: str): - v = {"screen_name": login, "withSafetyModeUserFields": True} - return await 
self.graphql_item("sLVLhk0bGj3MVFEKTdax1w/UserByScreenName", v) + op = "sLVLhk0bGj3MVFEKTdax1w/UserByScreenName" + kv = {"screen_name": login, "withSafetyModeUserFields": True} + return await self.graphql_item(op, kv) async def user_by_id(self, uid: int): - v = {"userId": str(uid), "withSafetyModeUserFields": True} - return await self.graphql_item("GazOglcBvgLigl3ywt6b3Q/UserByRestId", v) + op = "GazOglcBvgLigl3ywt6b3Q/UserByRestId" + kv = {"userId": str(uid), "withSafetyModeUserFields": True} + return await self.graphql_item(op, kv) async def retweeters(self, twid: int, limit=-1): - v = {"tweetId": str(twid), "count": 20, "includePromotedContent": True} - async for x in self.graphql_items("U5f_jm0CiLmSfI1d4rGleQ/Retweeters", v, limit=limit): + op = "U5f_jm0CiLmSfI1d4rGleQ/Retweeters" + kv = {"tweetId": str(twid), "count": 20, "includePromotedContent": True} + async for x in self.graphql_items(op, kv, limit=limit): yield x async def favoriters(self, twid: int, limit=-1): - v = {"tweetId": str(twid), "count": 20, "includePromotedContent": True} - async for x in self.graphql_items("vcTrPlh9ovFDQejz22q9vg/Favoriters", v, limit=limit): + op = "vcTrPlh9ovFDQejz22q9vg/Favoriters" + kv = {"tweetId": str(twid), "count": 20, "includePromotedContent": True} + async for x in self.graphql_items(op, kv, limit=limit): yield x diff --git a/twapi/utils.py b/twapi/utils.py new file mode 100644 index 0000000..ed06771 --- /dev/null +++ b/twapi/utils.py @@ -0,0 +1,37 @@ +from typing import Any, TypeVar + +T = TypeVar("T") + + +# https://stackoverflow.com/a/43184871 +def find_item(obj: dict, key: str, default=None): + stack = [iter(obj.items())] + while stack: + for k, v in stack[-1]: + if k == key: + return v + elif isinstance(v, dict): + stack.append(iter(v.items())) + break + elif isinstance(v, list): + stack.append(iter(enumerate(v))) + break + else: + stack.pop() + return default + + +def get_or(obj: dict, key: str, default_value: T = None) -> Any | T: + for part in 
key.split("."):
+        if not isinstance(obj, dict) or part not in obj:
+            return default_value
+        obj = obj[part]
+    return obj
+
+
+def int_or_none(obj: dict, key: str):
+    try:
+        val = get_or(obj, key)
+        return int(val) if val is not None else None
+    except Exception:
+        return None