add setuptools; add more fields to models; accounts_poll refactoring

Этот коммит содержится в:
Vlad Pronsky 2023-04-30 23:54:17 +03:00
родитель 0c1377d3c6
Коммит 9509378441
13 изменённых файлов: 302 добавлений и 133 удалений

2
.gitignore поставляемый
Просмотреть файл

@ -2,7 +2,7 @@
.ruff_cache/ .ruff_cache/
accounts/ accounts/
results-raw/ results-raw/
results/ results-parsed/
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/

Просмотреть файл

@ -7,6 +7,9 @@ lint:
lint-fix: lint-fix:
ruff check --fix . ruff check --fix .
pylint:
pylint --errors-only twapi
test: test:
pytest --cov=twapi tests/ pytest --cov=twapi tests/

Просмотреть файл

@ -1,3 +1,47 @@
[build-system]
requires = ['setuptools>=61', 'setuptools_scm>=6.2']
build-backend = "setuptools.build_meta"
[project]
name = "tw-api"
version = "0.1.0"
authors = [{name = "vladkens"}]
description = "Twitter GraphQL and Search API implementation with SNScrape data models"
readme = "readme.md"
requires-python = ">=3.10"
keywords = ["twitter", "api", "scrape", "snscrape", "tw-api", "twapi"]
license = {text = "MIT"}
classifiers = [
'Development Status :: 4 - Beta',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
]
dependencies = [
"fake-useragent==1.1.3",
"httpx==0.24.0",
"loguru==0.7.0"
]
[project.optional-dependencies]
dev = [
"pylint==2.17.3",
"pytest-asyncio==0.21.0",
"pytest-cov==4.0.0",
"pytest==7.3.1",
"ruff==0.0.263",
]
[project.urls]
repository = "https://github.com/vladkens/tw-api"
[tool.setuptools]
packages = ['twapi']
[tool.pytest.ini_options]
pythonpath = ["."]
asyncio_mode = "auto"
[tool.isort] [tool.isort]
profile = "black" profile = "black"
@ -6,7 +50,3 @@ line-length = 99
[tool.ruff] [tool.ruff]
line-length = 99 line-length = 99
[tool.pytest.ini_options]
pythonpath = ["."]
asyncio_mode = "auto"

Просмотреть файл

@ -1,6 +1,12 @@
Twitter GraphQL and Search API implementation with [SNScrape](https://github.com/JustAnotherArchivist/snscrape) data models. Twitter GraphQL and Search API implementation with [SNScrape](https://github.com/JustAnotherArchivist/snscrape) data models.
### Usage ## Install
```bash
pip install git+https://github.com/vladkens/tw-api.git
```
## Usage
```python ```python
import asyncio import asyncio
@ -48,6 +54,13 @@ async def main():
# change log level, default info # change log level, default info
set_log_level("DEBUG") set_log_level("DEBUG")
# Tweet & User models can be converted to a regular dict or a JSON string, e.g.:
doc = await api.user_by_id(user_id) # User
doc.dict() # -> python dict
doc.json() # -> json string
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())
``` ```

Просмотреть файл

@ -1,4 +0,0 @@
ruff==0.0.263
pytest==7.3.1
pytest-asyncio==0.21.0
pytest-cov==4.0.0

Просмотреть файл

@ -1,3 +0,0 @@
httpx==0.24.0
fake-useragent==1.1.3
loguru==0.7.0

Просмотреть файл

@ -58,52 +58,70 @@ async def test_search():
items = await gather(api.search("elon musk lang:en", limit=20)) items = await gather(api.search("elon musk lang:en", limit=20))
assert len(items) > 0 assert len(items) > 0
for x in items: for doc in items:
assert x.id is not None assert doc.id is not None
assert x.user is not None assert doc.user is not None
tw_dict = x.json() obj = doc.dict()
assert x.id == tw_dict["id"] assert doc.id == obj["id"]
assert x.user.id == tw_dict["user"]["id"] assert doc.user.id == obj["user"]["id"]
assert "url" in obj
assert "url" in obj["user"]
txt = doc.json()
assert isinstance(txt, str)
assert str(doc.id) in txt
async def test_user_by_id(): async def test_user_by_id():
api = API(AccountsPool()) api = API(AccountsPool())
mock_rep(api, "user_by_id_raw") mock_rep(api, "user_by_id_raw")
rep = await api.user_by_id(2244994945) doc = await api.user_by_id(2244994945)
assert rep.id == 2244994945 assert doc.id == 2244994945
assert rep.username == "TwitterDev" assert doc.username == "TwitterDev"
obj = rep.json() obj = doc.dict()
assert rep.id == obj["id"] assert doc.id == obj["id"]
assert rep.username == obj["username"] assert doc.username == obj["username"]
txt = doc.json()
assert isinstance(txt, str)
assert str(doc.id) in txt
async def test_user_by_login(): async def test_user_by_login():
api = API(AccountsPool()) api = API(AccountsPool())
mock_rep(api, "user_by_login_raw") mock_rep(api, "user_by_login_raw")
rep = await api.user_by_login("twitterdev") doc = await api.user_by_login("twitterdev")
assert rep.id == 2244994945 assert doc.id == 2244994945
assert rep.username == "TwitterDev" assert doc.username == "TwitterDev"
obj = rep.json() obj = doc.dict()
assert rep.id == obj["id"] assert doc.id == obj["id"]
assert rep.username == obj["username"] assert doc.username == obj["username"]
txt = doc.json()
assert isinstance(txt, str)
assert str(doc.id) in txt
async def test_tweet_details(): async def test_tweet_details():
api = API(AccountsPool()) api = API(AccountsPool())
mock_rep(api, "tweet_details_raw") mock_rep(api, "tweet_details_raw")
rep = await api.tweet_details(1649191520250245121) doc = await api.tweet_details(1649191520250245121)
assert rep.id == 1649191520250245121 assert doc.id == 1649191520250245121
assert rep.user is not None assert doc.user is not None
obj = rep.json() obj = doc.dict()
assert rep.id == obj["id"] assert doc.id == obj["id"]
assert rep.user.id == obj["user"]["id"] assert doc.user.id == obj["user"]["id"]
txt = doc.json()
assert isinstance(txt, str)
assert str(doc.id) in txt
async def test_followers(): async def test_followers():
@ -113,13 +131,17 @@ async def test_followers():
users = await gather(api.followers(2244994945)) users = await gather(api.followers(2244994945))
assert len(users) > 0 assert len(users) > 0
for user in users: for doc in users:
assert user.id is not None assert doc.id is not None
assert user.username is not None assert doc.username is not None
obj = user.json() obj = doc.dict()
assert user.id == obj["id"] assert doc.id == obj["id"]
assert user.username == obj["username"] assert doc.username == obj["username"]
txt = doc.json()
assert isinstance(txt, str)
assert str(doc.id) in txt
async def test_following(): async def test_following():
@ -129,13 +151,17 @@ async def test_following():
users = await gather(api.following(2244994945)) users = await gather(api.following(2244994945))
assert len(users) > 0 assert len(users) > 0
for user in users: for doc in users:
assert user.id is not None assert doc.id is not None
assert user.username is not None assert doc.username is not None
obj = user.json() obj = doc.dict()
assert user.id == obj["id"] assert doc.id == obj["id"]
assert user.username == obj["username"] assert doc.username == obj["username"]
txt = doc.json()
assert isinstance(txt, str)
assert str(doc.id) in txt
async def test_retweters(): async def test_retweters():
@ -145,13 +171,17 @@ async def test_retweters():
users = await gather(api.retweeters(1649191520250245121)) users = await gather(api.retweeters(1649191520250245121))
assert len(users) > 0 assert len(users) > 0
for user in users: for doc in users:
assert user.id is not None assert doc.id is not None
assert user.username is not None assert doc.username is not None
obj = user.json() obj = doc.dict()
assert user.id == obj["id"] assert doc.id == obj["id"]
assert user.username == obj["username"] assert doc.username == obj["username"]
txt = doc.json()
assert isinstance(txt, str)
assert str(doc.id) in txt
async def test_favoriters(): async def test_favoriters():
@ -161,13 +191,17 @@ async def test_favoriters():
users = await gather(api.favoriters(1649191520250245121)) users = await gather(api.favoriters(1649191520250245121))
assert len(users) > 0 assert len(users) > 0
for user in users: for doc in users:
assert user.id is not None assert doc.id is not None
assert user.username is not None assert doc.username is not None
obj = user.json() obj = doc.dict()
assert user.id == obj["id"] assert doc.id == obj["id"]
assert user.username == obj["username"] assert doc.username == obj["username"]
txt = doc.json()
assert isinstance(txt, str)
assert str(doc.id) in txt
async def test_user_tweets(): async def test_user_tweets():
@ -177,13 +211,17 @@ async def test_user_tweets():
tweets = await gather(api.user_tweets(2244994945)) tweets = await gather(api.user_tweets(2244994945))
assert len(tweets) > 0 assert len(tweets) > 0
for tweet in tweets: for doc in tweets:
assert tweet.id is not None assert doc.id is not None
assert tweet.user is not None assert doc.user is not None
obj = tweet.json() obj = doc.dict()
assert tweet.id == obj["id"] assert doc.id == obj["id"]
assert tweet.user.id == obj["user"]["id"] assert doc.user.id == obj["user"]["id"]
txt = doc.json()
assert isinstance(txt, str)
assert str(doc.id) in txt
async def test_user_tweets_and_replies(): async def test_user_tweets_and_replies():
@ -193,13 +231,17 @@ async def test_user_tweets_and_replies():
tweets = await gather(api.user_tweets_and_replies(2244994945)) tweets = await gather(api.user_tweets_and_replies(2244994945))
assert len(tweets) > 0 assert len(tweets) > 0
for tweet in tweets: for doc in tweets:
assert tweet.id is not None assert doc.id is not None
assert tweet.user is not None assert doc.user is not None
obj = tweet.json() obj = doc.dict()
assert tweet.id == obj["id"] assert doc.id == obj["id"]
assert tweet.user.id == obj["user"]["id"] assert doc.user.id == obj["user"]["id"]
txt = doc.json()
assert isinstance(txt, str)
assert str(doc.id) in txt
async def main(): async def main():
@ -207,7 +249,7 @@ async def main():
# you need to have some account to perform this # you need to have some account to perform this
pool = AccountsPool() pool = AccountsPool()
pool.load_from_dir() pool.restore()
api = API(pool) api = API(pool)

Просмотреть файл

@ -2,4 +2,5 @@
from .account import Account from .account import Account
from .accounts_pool import AccountsPool from .accounts_pool import AccountsPool
from .api import API from .api import API
from .models import *
from .utils import gather from .utils import gather

Просмотреть файл

@ -1,5 +1,4 @@
import json import json
import os
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum from enum import Enum
@ -19,10 +18,8 @@ class Status(str, Enum):
class Account: class Account:
BASE_DIR = "accounts"
@classmethod @classmethod
def load(cls, filepath: str): def load_from_file(cls, filepath: str):
try: try:
with open(filepath) as f: with open(filepath) as f:
data = json.load(f) data = json.load(f)
@ -81,15 +78,8 @@ class Account:
"status": self.status, "status": self.status,
} }
def save(self):
os.makedirs(self.BASE_DIR, exist_ok=True)
data = self.dump()
with open(f"{self.BASE_DIR}/{self.username}.json", "w") as f:
json.dump(data, f, indent=2)
def update_limit(self, queue: str, reset_ts: int): def update_limit(self, queue: str, reset_ts: int):
self.limits[queue] = datetime.fromtimestamp(reset_ts, tz=timezone.utc) self.limits[queue] = datetime.fromtimestamp(reset_ts, tz=timezone.utc)
self.save()
def can_use(self, queue: str): def can_use(self, queue: str):
if self.locked.get(queue, False) or self.status != Status.ACTIVE: if self.locked.get(queue, False) or self.status != Status.ACTIVE:
@ -126,7 +116,6 @@ class Account:
if e.response.status_code == 403: if e.response.status_code == 403:
logger.error(f"403 error {log_id}") logger.error(f"403 error {log_id}")
self.status = Status.LOGIN_ERROR self.status = Status.LOGIN_ERROR
self.save()
return return
self.client.headers["x-csrf-token"] = self.client.cookies["ct0"] self.client.headers["x-csrf-token"] = self.client.cookies["ct0"]
@ -134,7 +123,6 @@ class Account:
logger.info(f"logged in success {log_id}") logger.info(f"logged in success {log_id}")
self.status = Status.ACTIVE self.status = Status.ACTIVE
self.save()
async def get_guest_token(self): async def get_guest_token(self):
rep = await self.client.post("https://api.twitter.com/1.1/guest/activate.json") rep = await self.client.post("https://api.twitter.com/1.1/guest/activate.json")

Просмотреть файл

@ -1,27 +1,34 @@
import asyncio import asyncio
import json
import os import os
from .account import Account, Status from .account import Account, Status
from .logger import logger from .logger import logger
from .utils import shuffle
class AccountsPool: class AccountsPool:
BASE_DIR = "accounts" def __init__(self, base_dir: str | None = None):
def __init__(self):
self.accounts: list[Account] = [] self.accounts: list[Account] = []
self.base_dir = base_dir or "accounts"
def load_from_dir(self, folder: str | None = None): def restore(self):
folder = folder or self.BASE_DIR files = [os.path.join(self.base_dir, x) for x in os.listdir(self.base_dir)]
files = os.listdir(folder)
files = [x for x in files if x.endswith(".json")] files = [x for x in files if x.endswith(".json")]
files = [os.path.join(folder, x) for x in files]
for file in files: for file in files:
account = Account.load(file) self._load_account_from_file(file)
if account:
self.accounts.append(account) def _load_account_from_file(self, filepath: str):
account = Account.load_from_file(filepath)
if account:
username = set(x.username for x in self.accounts)
if account.username in username:
raise ValueError(f"Duplicate username {account.username}")
self.accounts.append(account)
return account
def _get_filename(self, username: str):
return f"{self.base_dir}/{username}.json"
def add_account( def add_account(
self, self,
@ -32,14 +39,20 @@ class AccountsPool:
proxy: str | None = None, proxy: str | None = None,
user_agent: str | None = None, user_agent: str | None = None,
): ):
filepath = os.path.join(self.BASE_DIR, f"{login}.json") account = self._load_account_from_file(self._get_filename(login))
account = Account.load(filepath)
if account: if account:
self.accounts.append(account)
return return
account = Account(login, password, email, email_password, user_agent, proxy) account = Account(
self.accounts.append(account) login,
password,
email,
email_password,
proxy=proxy,
user_agent=user_agent,
)
self.save_account(account)
self._load_account_from_file(self._get_filename(login))
async def login(self): async def login(self):
for x in self.accounts: for x in self.accounts:
@ -48,7 +61,8 @@ class AccountsPool:
await x.login() await x.login()
except Exception as e: except Exception as e:
logger.error(f"Error logging in to {x.username}: {e}") logger.error(f"Error logging in to {x.username}: {e}")
pass finally:
self.save_account(x)
def get_username_by_token(self, auth_token: str) -> str: def get_username_by_token(self, auth_token: str) -> str:
for x in self.accounts: for x in self.accounts:
@ -57,7 +71,8 @@ class AccountsPool:
return "UNKNOWN" return "UNKNOWN"
def get_account(self, queue: str) -> Account | None: def get_account(self, queue: str) -> Account | None:
for x in self.accounts: accounts = shuffle(self.accounts) # make random order each time
for x in accounts:
if x.can_use(queue): if x.can_use(queue):
return x return x
return None return None
@ -72,3 +87,15 @@ class AccountsPool:
else: else:
logger.debug(f"No accounts available for queue '{queue}' (sleeping for 5 sec)") logger.debug(f"No accounts available for queue '{queue}' (sleeping for 5 sec)")
await asyncio.sleep(5) await asyncio.sleep(5)
def save_account(self, account: Account):
    """Persist `account` as indented JSON in this pool's account file.

    The target path comes from `_get_filename(account.username)`; the parent
    directory is created when it does not exist yet.
    """
    target = self._get_filename(account.username)
    payload = account.dump()

    # Dump first, then touch the filesystem, so a failing dump leaves no directory behind.
    os.makedirs(os.path.dirname(target), exist_ok=True)
    with open(target, "w") as fh:
        json.dump(payload, fh, indent=2)
def update_limit(self, account: Account, queue: str, reset_ts: int):
    """Update the limit of `account` for `queue` and save the account to disk.

    Delegates to `account.update_limit(queue, reset_ts)` and then calls
    `save_account(account)` so the change is written to the account file.
    """
    account.update_limit(queue, reset_ts)
    self.save_account(account)

Просмотреть файл

@ -7,7 +7,7 @@ from .accounts_pool import AccountsPool
from .constants import GQL_FEATURES, GQL_URL, SEARCH_PARAMS, SEARCH_URL from .constants import GQL_FEATURES, GQL_URL, SEARCH_PARAMS, SEARCH_URL
from .logger import logger from .logger import logger
from .models import Tweet, User from .models import Tweet, User
from .utils import encode_params, find_item, to_old_obj, to_search_like from .utils import encode_params, get_by_path, to_old_obj, to_search_like
class API: class API:
@ -52,13 +52,13 @@ class API:
if e.response.status_code == 429: if e.response.status_code == 429:
logger.debug(f"Rate limit for account={account.username} on queue={queue}") logger.debug(f"Rate limit for account={account.username} on queue={queue}")
reset_ts = int(e.response.headers.get("x-rate-limit-reset", 0)) reset_ts = int(e.response.headers.get("x-rate-limit-reset", 0))
account.update_limit(queue, reset_ts) self.pool.update_limit(account, queue, reset_ts)
continue continue
if e.response.status_code == 403: if e.response.status_code == 403:
logger.debug(f"Account={account.username} is banned on queue={queue}") logger.debug(f"Account={account.username} is banned on queue={queue}")
reset_ts = int(time.time() + 60 * 60) # 1 hour reset_ts = int(time.time() + 60 * 60) # 1 hour
account.update_limit(queue, reset_ts) self.pool.update_limit(account, queue, reset_ts)
continue continue
logger.error(f"[{e.response.status_code}] {e.request.url}\n{e.response.text}") logger.error(f"[{e.response.status_code}] {e.request.url}\n{e.response.text}")
@ -80,13 +80,13 @@ class API:
logger.debug(e) logger.debug(e)
return None return None
def get_ql_entries(self, obj: dict) -> list[dict]: def _get_ql_entries(self, obj: dict) -> list[dict]:
entries = find_item(obj, "entries") entries = get_by_path(obj, "entries")
return entries or [] return entries or []
def _get_ql_cursor(self, obj: dict) -> str | None: def _get_ql_cursor(self, obj: dict) -> str | None:
try: try:
for entry in self.get_ql_entries(obj): for entry in self._get_ql_entries(obj):
if entry["entryId"].startswith("cursor-bottom-"): if entry["entryId"].startswith("cursor-bottom-"):
return entry["content"]["value"] return entry["content"]["value"]
return None return None
@ -104,7 +104,7 @@ class API:
obj = rep.json() obj = rep.json()
# cursor-top / cursor-bottom always present # cursor-top / cursor-bottom always present
entries = self.get_ql_entries(obj) entries = self._get_ql_entries(obj)
entries = [x for x in entries if not x["entryId"].startswith("cursor-")] entries = [x for x in entries if not x["entryId"].startswith("cursor-")]
cursor = self._get_ql_cursor(obj) cursor = self._get_ql_cursor(obj)
@ -141,11 +141,18 @@ class API:
params["cursor" if cursor else "requestContext"] = cursor if cursor else "launch" params["cursor" if cursor else "requestContext"] = cursor if cursor else "launch"
return await client.get(SEARCH_URL, params=params) return await client.get(SEARCH_URL, params=params)
retries = 0
async for rep in self._inf_req(queue, _get): async for rep in self._inf_req(queue, _get):
data = rep.json() data = rep.json()
cursor = self._get_search_cursor(data)
tweets = data.get("globalObjects", {}).get("tweets", []) tweets = data.get("globalObjects", {}).get("tweets", [])
if not tweets and retries < 3:
retries += 1
continue
else:
retries = 0
cursor = self._get_search_cursor(data)
check = self._is_end(rep, q, tweets, cursor, count, limit) check = self._is_end(rep, q, tweets, cursor, count, limit)
count, end_before, end_after = check count, end_before, end_after = check

Просмотреть файл

@ -1,16 +1,24 @@
import email.utils import email.utils
from dataclasses import asdict, dataclass import json
import re
from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from .utils import get_or, int_or_none from snscrape.modules import twitter
from .logger import logger
from .utils import find_item, get_or, int_or_none
@dataclass @dataclass
class JSONTrait: class JSONTrait:
def json(self): def dict(self):
return asdict(self) return asdict(self)
def json(self):
return json.dumps(self.dict(), default=str)
@dataclass @dataclass
class Coordinates(JSONTrait): class Coordinates(JSONTrait):
@ -80,6 +88,7 @@ class UserRef(JSONTrait):
@dataclass @dataclass
class User(JSONTrait): class User(JSONTrait):
id: int id: int
url: str
username: str username: str
displayname: str displayname: str
rawDescription: str rawDescription: str
@ -100,14 +109,11 @@ class User(JSONTrait):
# link: typing.Optional[TextLink] = None # link: typing.Optional[TextLink] = None
# label: typing.Optional["UserLabel"] = None # label: typing.Optional["UserLabel"] = None
@property
def url(self) -> str:
return f"https://twitter.com/{self.username}"
@staticmethod @staticmethod
def parse(obj: dict): def parse(obj: dict):
return User( return User(
id=int(obj["id_str"]), id=int(obj["id_str"]),
url=f'https://twitter.com/{obj["screen_name"]}',
username=obj["screen_name"], username=obj["screen_name"],
displayname=obj["name"], displayname=obj["name"],
rawDescription=obj["description"], rawDescription=obj["description"],
@ -129,6 +135,7 @@ class User(JSONTrait):
@dataclass @dataclass
class Tweet(JSONTrait): class Tweet(JSONTrait):
id: int id: int
url: str
date: datetime date: datetime
user: User user: User
lang: str lang: str
@ -147,30 +154,28 @@ class Tweet(JSONTrait):
quotedTweet: Optional["Tweet"] = None quotedTweet: Optional["Tweet"] = None
place: Optional[Place] = None place: Optional[Place] = None
coordinates: Optional[Coordinates] = None coordinates: Optional[Coordinates] = None
inReplyToTweetId: int | None = None
inReplyToUser: UserRef | None = None
source: str | None = None
sourceUrl: str | None = None
sourceLabel: str | None = None
# renderedContent: str # renderedContent: str
# source: str | None = None
# sourceUrl: str | None = None
# sourceLabel: str | None = None
# media: typing.Optional[typing.List["Medium"]] = None # media: typing.Optional[typing.List["Medium"]] = None
# inReplyToTweetId: typing.Optional[int] = None
# inReplyToUser: typing.Optional["User"] = None
# card: typing.Optional["Card"] = None # card: typing.Optional["Card"] = None
# vibe: typing.Optional["Vibe"] = None # vibe: typing.Optional["Vibe"] = None
@property
def url(self):
return f"https://twitter.com/{self.user.username}/status/{self.id}"
@staticmethod @staticmethod
def parse(obj: dict, res: dict): def parse(obj: dict, res: dict):
tw_usr = User.parse(res["users"][obj["user_id_str"]])
rt_obj = get_or(res, f"tweets.{obj.get('retweeted_status_id_str')}") rt_obj = get_or(res, f"tweets.{obj.get('retweeted_status_id_str')}")
qt_obj = get_or(res, f"tweets.{obj.get('quoted_status_id_str')}") qt_obj = get_or(res, f"tweets.{obj.get('quoted_status_id_str')}")
return Tweet( return Tweet(
id=int(obj["id_str"]), id=int(obj["id_str"]),
url=f'https://twitter.com/{tw_usr.username}/status/{obj["id_str"]}',
date=email.utils.parsedate_to_datetime(obj["created_at"]), date=email.utils.parsedate_to_datetime(obj["created_at"]),
user=User.parse(res["users"][obj["user_id_str"]]), user=tw_usr,
lang=obj["lang"], lang=obj["lang"],
rawContent=obj["full_text"], rawContent=obj["full_text"],
replyCount=obj["reply_count"], replyCount=obj["reply_count"],
@ -187,4 +192,40 @@ class Tweet(JSONTrait):
quotedTweet=Tweet.parse(qt_obj, res) if qt_obj else None, quotedTweet=Tweet.parse(qt_obj, res) if qt_obj else None,
place=Place.parse(obj["place"]) if obj.get("place") else None, place=Place.parse(obj["place"]) if obj.get("place") else None,
coordinates=Coordinates.parse(obj), coordinates=Coordinates.parse(obj),
inReplyToTweetId=int_or_none(obj, "in_reply_to_status_id_str"),
inReplyToUser=_get_reply_user(obj, res),
source=obj.get("source", None),
sourceUrl=_get_source_url(obj),
sourceLabel=_get_source_label(obj),
) )
def _get_reply_user(tw_obj: dict, res: dict):
    """Resolve the user a tweet replies to as a `UserRef`, or None.

    Lookup order:
      1. `res["users"]`, keyed by the tweet's `in_reply_to_user_id_str`;
      2. the tweet's own `entities.user_mentions`, matched on `id_str`.

    Returns None when the tweet has no `in_reply_to_user_id_str`, or when
    neither source contains the user (the miss is logged at debug level).
    """
    reply_to_id = tw_obj.get("in_reply_to_user_id_str", None)
    if reply_to_id is None:
        return None

    known_users = res["users"]
    if reply_to_id in known_users:
        return UserRef.parse(known_users[reply_to_id])

    def _is_reply_target(entry):
        return entry["id_str"] == tw_obj["in_reply_to_user_id_str"]

    # Fall back to the mention entities embedded in the tweet itself.
    candidates = get_or(tw_obj, "entities.user_mentions", [])
    matched = find_item(candidates, _is_reply_target)
    if matched:
        return UserRef.parse(matched)

    logger.debug(f'{tw_obj["in_reply_to_user_id_str"]}\n{json.dumps(res)}')
    return None
def _get_source_url(tw_obj: dict):
source = tw_obj.get("source", None)
if source and (match := re.search(r'href=[\'"]?([^\'" >]+)', source)):
return str(match.group(1))
return None
def _get_source_label(tw_obj: dict):
source = tw_obj.get("source", None)
if source and (match := re.search(r">([^<]*)<", source)):
return str(match.group(1))
return None

Просмотреть файл

@ -1,6 +1,7 @@
import json import json
import random
from collections import defaultdict from collections import defaultdict
from typing import Any, AsyncGenerator, TypeVar from typing import Any, AsyncGenerator, Callable, TypeVar
from httpx import HTTPStatusError, Response from httpx import HTTPStatusError, Response
@ -53,7 +54,7 @@ def int_or_none(obj: dict, key: str):
# https://stackoverflow.com/a/43184871 # https://stackoverflow.com/a/43184871
def find_item(obj: dict, key: str, default=None): def get_by_path(obj: dict, key: str, default=None):
stack = [iter(obj.items())] stack = [iter(obj.items())]
while stack: while stack:
for k, v in stack[-1]: for k, v in stack[-1]:
@ -70,6 +71,13 @@ def find_item(obj: dict, key: str, default=None):
return default return default
def find_item(lst: list[T], fn: Callable[[T], bool]) -> T | None:
    """Return the first element of `lst` for which `fn(element)` is truthy.

    Elements are tested in order and the scan stops at the first hit.
    Returns None when nothing matches or `lst` is empty.
    """
    return next((candidate for candidate in lst if fn(candidate)), None)
def get_typed_object(obj: dict, res: defaultdict[str, list]): def get_typed_object(obj: dict, res: defaultdict[str, list]):
obj_type = obj.get("__typename", None) obj_type = obj.get("__typename", None)
if obj_type is not None: if obj_type is not None:
@ -100,3 +108,9 @@ def to_search_like(obj: dict):
users = {str(x["rest_id"]): to_old_obj(x) for x in users} users = {str(x["rest_id"]): to_old_obj(x) for x in users}
return {"tweets": tweets, "users": users} return {"tweets": tweets, "users": users}
def shuffle(lst: list) -> list:
    """Return a new list with the elements of `lst` in random order.

    The input is never modified. Any iterable is accepted (list, tuple,
    generator, ...); the result is always a fresh `list`.
    """
    # list(...) rather than .copy(), so inputs without a .copy() method also work.
    shuffled = list(lst)
    random.shuffle(shuffled)
    return shuffled