move http communication from api class; handle network fail cases; add tests

Vlad Pronsky, 2023-05-25 05:08:20 +03:00, committed by vladkens
parent 1ad12528d2
commit c0204aa492
7 changed files with 299 additions and 162 deletions

Makefile

@@ -3,6 +3,9 @@
all:
	@echo "hi"
install:
	pip install -e .[dev]
build:
	python -m build

pyproject.toml

@@ -29,6 +29,7 @@ dev = [
    "pylint==2.17.3",
    "pytest-asyncio==0.21.0",
    "pytest-cov==4.0.0",
    "pytest-httpx==0.22.0",
    "pytest==7.3.1",
    "ruff==0.0.263",
]

tests/__init__.py Normal file (0 additions)

tests/test_queue_client.py Normal file (94 additions)

@@ -0,0 +1,94 @@
import httpx
from pytest_httpx import HTTPXMock

from twscrape.logger import set_log_level
from twscrape.queue_client import QueueClient

from .utils import get_pool

DB_FILE = "/tmp/test_queue_client.db"
URL = "https://example.com/api"

set_log_level("ERROR")


async def get_client():
    pool = get_pool(DB_FILE)
    await pool.add_account("user1", "pass1", "email1", "email_pass1")
    await pool.add_account("user2", "pass2", "email2", "email_pass2")
    await pool.set_active("user1", True)
    await pool.set_active("user2", True)
    client = QueueClient(pool, "search")
    return pool, client


async def test_should_lock_account_on_queue(httpx_mock: HTTPXMock):
    pool, client = await get_client()
    assert (await pool.stats())["locked_search"] == 0

    await client.__aenter__()
    assert (await pool.stats())["locked_search"] == 1

    httpx_mock.add_response(url=URL, json={"foo": "bar"}, status_code=200)
    assert (await client.get(URL)).json() == {"foo": "bar"}

    await client.__aexit__(None, None, None)
    assert (await pool.stats())["locked_search"] == 0


async def test_should_not_switch_account_on_200(httpx_mock: HTTPXMock):
    pool, client = await get_client()
    assert (await pool.stats())["locked_search"] == 0

    await client.__aenter__()

    httpx_mock.add_response(url=URL, json={"foo": "1"}, status_code=200)
    httpx_mock.add_response(url=URL, json={"foo": "2"}, status_code=200)

    rep = await client.get(URL)
    assert rep.json() == {"foo": "1"}

    rep = await client.get(URL)
    assert rep.json() == {"foo": "2"}

    assert (await pool.stats())["locked_search"] == 1
    await client.__aexit__(None, None, None)


async def test_should_switch_account_on_http_error(httpx_mock: HTTPXMock):
    pool, client = await get_client()
    assert (await pool.stats())["locked_search"] == 0

    await client.__aenter__()

    httpx_mock.add_response(url=URL, json={"foo": "1"}, status_code=403)
    httpx_mock.add_response(url=URL, json={"foo": "2"}, status_code=200)

    rep = await client.get(URL)
    assert rep.json() == {"foo": "2"}
    assert (await pool.stats())["locked_search"] == 1  # user1 unlocked, user2 locked

    await client.__aexit__(None, None, None)


async def test_should_retry_with_same_account_on_network_error(httpx_mock: HTTPXMock):
    pool, client = await get_client()
    await client.__aenter__()

    httpx_mock.add_exception(httpx.ReadTimeout("Unable to read within timeout"))
    httpx_mock.add_response(url=URL, json={"foo": "2"}, status_code=200)

    rep = await client.get(URL)
    assert rep.json() == {"foo": "2"}
    assert (await pool.stats())["locked_search"] == 1

    username = getattr(rep, "__username", None)
    assert username is not None

    acc1 = await pool.get(username)
    assert len(acc1.locks) > 0

    acc2 = await pool.get("user2" if username == "user1" else "user1")
    assert len(acc2.locks) == 0
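
These tests rely on pytest-httpx queuing mock responses in FIFO order, so the first response exercises the failure branch and the second one proves the client recovered with a working account. A natural companion case is the 429 rate-limit branch of QueueClient.req, which locks the current account until x-rate-limit-reset and moves on to the next one; a hypothetical test for it (not part of this commit; the header values and lock bookkeeping are assumptions) could look like this:

async def test_should_switch_account_on_rate_limit(httpx_mock: HTTPXMock):
    # Hypothetical companion test (not in this commit): exercise the 429 branch.
    # A reset timestamp of "0" is already in the past, so the account stays usable.
    pool, client = await get_client()
    await client.__aenter__()

    headers = {"x-rate-limit-remaining": "0", "x-rate-limit-reset": "0"}
    httpx_mock.add_response(url=URL, json={"foo": "1"}, status_code=429, headers=headers)
    httpx_mock.add_response(url=URL, json={"foo": "2"}, status_code=200)

    rep = await client.get(URL)
    assert rep.json() == {"foo": "2"}
    assert (await pool.stats())["locked_search"] == 1

    await client.__aexit__(None, None, None)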

tests/utils.py Normal file (13 additions)

@@ -0,0 +1,13 @@
import os

from twscrape.accounts_pool import AccountsPool
from twscrape.db import DB


def get_pool(db_path: str):
    DB._init_once[db_path] = False

    if os.path.exists(db_path):
        os.remove(db_path)

    pool = AccountsPool(db_path)
    return pool
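
get_pool gives each test module an isolated database: it resets the one-time init flag for that path in DB._init_once and removes any stale file, so the schema is (presumably) created again for the fresh file. A minimal standalone sketch of the same pattern (the path and credentials are placeholders, and it assumes the repository root is importable):

import asyncio

from tests.utils import get_pool  # assumes the repo root is on sys.path


async def example():
    pool = get_pool("/tmp/example_pool.db")  # placeholder path, one per test module
    await pool.add_account("user1", "pass1", "email1", "email_pass1")
    await pool.set_active("user1", True)
    print(await pool.stats())


asyncio.run(example())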

twscrape/api.py

@@ -1,31 +1,20 @@
import json
from datetime import datetime
from typing import Awaitable, Callable
from httpx import AsyncClient, HTTPStatusError, Response
from httpx import Response
from .accounts_pool import AccountsPool
from .constants import GQL_FEATURES, GQL_URL, SEARCH_PARAMS, SEARCH_URL
from .logger import logger
from .models import Tweet, User
from .utils import encode_params, find_obj, get_by_path, to_old_obj, to_old_rep, utc_ts
from .queue_client import QueueClient, req_id
from .utils import encode_params, find_obj, get_by_path, to_old_obj, to_old_rep
class API:
def __init__(self, pool: AccountsPool, debug=False):
self.pool = pool
self.debug = debug
self._history: list[Response] = []
# http helpers
def _limit_msg(self, rep: Response):
lr = rep.headers.get("x-rate-limit-remaining", -1)
ll = rep.headers.get("x-rate-limit-limit", -1)
username = getattr(rep, "__username", "<UNKNOWN>")
return f"{username} {lr}/{ll}"
def _is_end(self, rep: Response, q: str, res: list, cur: str | None, cnt: int, lim: int):
new_count = len(res)
new_total = cnt + new_count
@@ -36,163 +25,37 @@ class API:
stats = f"{q} {new_total:,d} (+{new_count:,d})"
flags = f"res={int(is_res)} cur={int(is_cur)} lim={int(is_lim)}"
logger.debug(" ".join([stats, flags, self._limit_msg(rep)]))
logger.debug(" ".join([stats, flags, req_id(rep)]))
return new_total, not is_res, not is_cur or is_lim
def _push_history(self, rep: Response):
self._history.append(rep)
if len(self._history) > 3:
self._history.pop(0)
def _dump_history(self, extra: str = ""):
if not self.debug:
return
ts = str(datetime.now()).replace(":", "-").replace(" ", "_")
filename = f"/tmp/api_dump_{ts}.txt"
with open(filename, "w") as fp:
txt = f"{extra}\n"
for rep in self._history:
res = json.dumps(rep.json(), indent=2)
hdr = "\n".join([str(x) for x in list(rep.request.headers.items())])
div = "-" * 20
msg = f"{div}\n{self._limit_msg(rep)}"
msg = f"{msg}\n{rep.request.method} {rep.request.url}"
msg = f"{msg}\n{rep.status_code}\n{div}"
msg = f"{msg}\n{hdr}\n{div}\n{res}\n\n"
txt += msg
fp.write(txt)
print(f"API dump ({len(self._history)}) dumped to {filename}")
async def _inf_req(self, queue: str, cb: Callable[[AsyncClient], Awaitable[Response]]):
while True:
acc = await self.pool.get_for_queue_or_wait(queue)
client = acc.make_client()
try:
while True:
rep = await cb(client)
setattr(rep, "__username", acc.username)
self._push_history(rep)
rep.raise_for_status()
yield rep
except HTTPStatusError as e:
rep = e.response
log_id = f"{self._limit_msg(rep)} on queue={queue}"
# rate limit
if rep.status_code == 429:
logger.debug(f"Rate limit for {log_id}")
reset_ts = int(rep.headers.get("x-rate-limit-reset", 0))
await self.pool.lock_until(acc.username, queue, reset_ts)
continue
# possible account banned
if rep.status_code == 403:
logger.warning(f"403 for {log_id}")
reset_ts = utc_ts() + 60 * 60 # + 1 hour
await self.pool.lock_until(acc.username, queue, reset_ts)
continue
# Twitter can return different types of cursors that don't transfer between accounts,
# so just take the next account; the current cursor may still work there
if rep.status_code == 400:
logger.debug(f"Cursor not valid for {log_id}")
continue
logger.error(f"[{rep.status_code}] {e.request.url}\n{rep.text}")
raise e
finally:
await self.pool.unlock(acc.username, queue)
await client.aclose()
return rep if is_res else None, new_total, is_cur and not is_lim
def _get_cursor(self, obj: dict):
if cur := find_obj(obj, lambda x: x.get("cursorType") == "Bottom"):
return cur.get("value")
return None
def _get_ql_entries(self, obj: dict) -> list[dict]:
entries = get_by_path(obj, "entries")
return entries or []
async def _ql_items(self, op: str, kv: dict, limit=-1):
queue, cursor, count = op.split("/")[-1], None, 0
async def _get(client: AsyncClient):
params = {"variables": {**kv, "cursor": cursor}, "features": GQL_FEATURES}
return await client.get(f"{GQL_URL}/{op}", params=encode_params(params))
async for rep in self._inf_req(queue, _get):
obj = rep.json()
# cursor-top / cursor-bottom always present
entries = self._get_ql_entries(obj)
entries = [x for x in entries if not x["entryId"].startswith("cursor-")]
cursor = self._get_cursor(obj)
check = self._is_end(rep, queue, entries, cursor, count, limit)
count, end_before, end_after = check
if end_before:
return
yield rep
if end_after:
return
async def _ql_item(self, op: str, kv: dict, ft: dict = {}):
async def _get(client: AsyncClient):
params = {"variables": {**kv}, "features": {**GQL_FEATURES, **ft}}
return await client.get(f"{GQL_URL}/{op}", params=encode_params(params))
queue = op.split("/")[-1]
async for rep in self._inf_req(queue, _get):
return rep
raise Exception("No response") # todo
# search
async def search_raw(self, q: str, limit=-1):
queue, cursor, count = "search", None, 0
queue, cursor, count, active = "search", None, 0, True
async def _get(client: AsyncClient):
params = {**SEARCH_PARAMS, "q": q, "count": 20}
params["cursor" if cursor else "requestContext"] = cursor if cursor else "launch"
try:
return await client.get(SEARCH_URL, params=params)
except Exception as e:
logger.error(f"Error requesting {q}: {e}")
logger.error(f"Request: {SEARCH_URL}, {params}")
raise e
async with QueueClient(self.pool, queue, self.debug) as client:
while active:
params = {**SEARCH_PARAMS, "q": q, "count": 20}
params["cursor" if cursor else "requestContext"] = cursor if cursor else "launch"
try:
async for rep in self._inf_req(queue, _get):
data = rep.json()
rep = await client.get(SEARCH_URL, params=params)
obj = rep.json()
tweets = data.get("globalObjects", {}).get("tweets", [])
cursor = self._get_cursor(data)
tweets = obj.get("globalObjects", {}).get("tweets", [])
cursor = self._get_cursor(obj)
check = self._is_end(rep, q, tweets, cursor, count, limit)
count, end_before, end_after = check
if end_before:
rep, count, active = self._is_end(rep, q, tweets, cursor, count, limit)
if rep is None:
return
yield rep
if end_after:
return
except HTTPStatusError as e:
self._dump_history(f"q={q}\ncount={count}\nwas_cur={cursor}\nnew_cur=None")
raise e
async def search(self, q: str, limit=-1):
twids = set()
async for rep in self.search_raw(q, limit=limit):
@@ -203,12 +66,41 @@ class API:
twids.add(x["id_str"])
yield Tweet.parse(x, obj)
# gql helpers
async def _gql_items(self, op: str, kv: dict, limit=-1):
queue, cursor, count, active = op.split("/")[-1], None, 0, True
async with QueueClient(self.pool, queue, self.debug) as client:
while active:
params = {"variables": {**kv, "cursor": cursor}, "features": GQL_FEATURES}
rep = await client.get(f"{GQL_URL}/{op}", params=encode_params(params))
obj = rep.json()
entries = get_by_path(obj, "entries") or []
entries = [x for x in entries if not x["entryId"].startswith("cursor-")]
cursor = self._get_cursor(obj)
rep, count, active = self._is_end(rep, queue, entries, cursor, count, limit)
if rep is None:
return
yield rep
async def _gql_item(self, op: str, kv: dict, ft: dict | None = None):
ft = ft or {}
queue = op.split("/")[-1]
async with QueueClient(self.pool, queue, self.debug) as client:
params = {"variables": {**kv}, "features": {**GQL_FEATURES, **ft}}
return await client.get(f"{GQL_URL}/{op}", params=encode_params(params))
# user_by_id
async def user_by_id_raw(self, uid: int):
op = "GazOglcBvgLigl3ywt6b3Q/UserByRestId"
kv = {"userId": str(uid), "withSafetyModeUserFields": True}
return await self._ql_item(op, kv)
return await self._gql_item(op, kv)
async def user_by_id(self, uid: int):
rep = await self.user_by_id_raw(uid)
@@ -220,7 +112,7 @@ class API:
async def user_by_login_raw(self, login: str):
op = "sLVLhk0bGj3MVFEKTdax1w/UserByScreenName"
kv = {"screen_name": login, "withSafetyModeUserFields": True}
return await self._ql_item(op, kv)
return await self._gql_item(op, kv)
async def user_by_login(self, login: str):
rep = await self.user_by_login_raw(login)
@@ -251,7 +143,7 @@ class API:
"responsive_web_twitter_blue_verified_badge_is_enabled": True,
"longform_notetweets_richtext_consumption_enabled": True,
}
return await self._ql_item(op, kv, ft)
return await self._gql_item(op, kv, ft)
async def tweet_details(self, twid: int):
rep = await self.tweet_details_raw(twid)
@@ -263,7 +155,7 @@ class API:
async def followers_raw(self, uid: int, limit=-1):
op = "djdTXDIk2qhd4OStqlUFeQ/Followers"
kv = {"userId": str(uid), "count": 20, "includePromotedContent": False}
async for x in self._ql_items(op, kv, limit=limit):
async for x in self._gql_items(op, kv, limit=limit):
yield x
async def followers(self, uid: int, limit=-1):
@@ -277,7 +169,7 @@ class API:
async def following_raw(self, uid: int, limit=-1):
op = "IWP6Zt14sARO29lJT35bBw/Following"
kv = {"userId": str(uid), "count": 20, "includePromotedContent": False}
async for x in self._ql_items(op, kv, limit=limit):
async for x in self._gql_items(op, kv, limit=limit):
yield x
async def following(self, uid: int, limit=-1):
@@ -291,7 +183,7 @@ class API:
async def retweeters_raw(self, twid: int, limit=-1):
op = "U5f_jm0CiLmSfI1d4rGleQ/Retweeters"
kv = {"tweetId": str(twid), "count": 20, "includePromotedContent": True}
async for x in self._ql_items(op, kv, limit=limit):
async for x in self._gql_items(op, kv, limit=limit):
yield x
async def retweeters(self, twid: int, limit=-1):
@@ -305,7 +197,7 @@ class API:
async def favoriters_raw(self, twid: int, limit=-1):
op = "vcTrPlh9ovFDQejz22q9vg/Favoriters"
kv = {"tweetId": str(twid), "count": 20, "includePromotedContent": True}
async for x in self._ql_items(op, kv, limit=limit):
async for x in self._gql_items(op, kv, limit=limit):
yield x
async def favoriters(self, twid: int, limit=-1):
@@ -326,7 +218,7 @@ class API:
"withVoice": True,
"withV2Timeline": True,
}
async for x in self._ql_items(op, kv, limit=limit):
async for x in self._gql_items(op, kv, limit=limit):
yield x
async def user_tweets(self, uid: int, limit=-1):
@@ -347,7 +239,7 @@ class API:
"withVoice": True,
"withV2Timeline": True,
}
async for x in self._ql_items(op, kv, limit=limit):
async for x in self._gql_items(op, kv, limit=limit):
yield x
async def user_tweets_and_replies(self, uid: int, limit=-1):
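
With the request loop removed, API is now a thin layer: every public method just builds GraphQL variables and delegates paging, account rotation, and error handling to QueueClient via _gql_item / _gql_items. A rough usage sketch of the refactored class follows (the DB path and handle are placeholders; it assumes accounts were already added to the pool and logged in, which is outside this diff):

import asyncio

from twscrape.accounts_pool import AccountsPool
from twscrape.api import API


async def main():
    pool = AccountsPool("/tmp/example_accounts.db")  # placeholder path
    api = API(pool, debug=False)

    # user_by_login -> _gql_item: a single QueueClient round-trip on the
    # "UserByScreenName" queue; the account is unlocked when the context closes.
    user = await api.user_by_login("jack")  # placeholder handle
    print(user.id)

    # user_tweets -> _gql_items: pages by cursor until `limit` is reached,
    # switching accounts transparently on 429/403/400 responses.
    async for tweet in api.user_tweets(user.id, limit=20):
        print(tweet.id)


asyncio.run(main())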

twscrape/queue_client.py Normal file (134 additions)

@@ -0,0 +1,134 @@
import json
from datetime import datetime

import httpx

from .accounts_pool import Account, AccountsPool
from .logger import logger
from .utils import utc_ts

ReqParams = dict[str, str | int] | None


def req_id(rep: httpx.Response):
    lr = rep.headers.get("x-rate-limit-remaining", -1)
    ll = rep.headers.get("x-rate-limit-limit", -1)
    username = getattr(rep, "__username", "<UNKNOWN>")
    return f"{username} {lr}/{ll}"


class Ctx:
    def __init__(self, acc: Account, clt: httpx.AsyncClient):
        self.acc = acc
        self.clt = clt


class QueueClient:
    def __init__(self, pool: AccountsPool, queue: str, debug=False):
        self.pool = pool
        self.queue = queue
        self.debug = debug
        self.history: list[httpx.Response] = []
        self.ctx: Ctx | None = None

    async def __aenter__(self):
        await self._get_ctx()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self._close_ctx()
        return self

    async def _close_ctx(self):
        if self.ctx is not None:
            await self.ctx.clt.aclose()
            await self.pool.unlock(self.ctx.acc.username, self.queue)

    async def _get_ctx(self, fresh=False) -> Ctx:
        if self.ctx and not fresh:
            return self.ctx

        if self.ctx is not None:
            await self._close_ctx()

        acc = await self.pool.get_for_queue_or_wait(self.queue)
        clt = acc.make_client()
        self.ctx = Ctx(acc, clt)
        return self.ctx

    def _push_history(self, rep: httpx.Response):
        self.history.append(rep)
        if len(self.history) > 3:
            self.history.pop(0)

    def _dump_history(self, extra: str = ""):
        if not self.debug:
            return

        ts = str(datetime.now()).replace(":", "-").replace(" ", "_")
        filename = f"/tmp/api_dump_{ts}.txt"
        with open(filename, "w", encoding="utf-8") as fp:
            txt = f"{extra}\n"
            for rep in self.history:
                res = json.dumps(rep.json(), indent=2)
                hdr = "\n".join([str(x) for x in list(rep.request.headers.items())])
                div = "-" * 20

                msg = f"{div}\n{req_id(rep)}"
                msg = f"{msg}\n{rep.request.method} {rep.request.url}"
                msg = f"{msg}\n{rep.status_code}\n{div}"
                msg = f"{msg}\n{hdr}\n{div}\n{res}\n\n"
                txt += msg

            fp.write(txt)

        print(f"API dump ({len(self.history)}) dumped to {filename}")

    async def req(self, method: str, url: str, params: ReqParams = None):
        fresh = False  # do not get new account on first try
        while True:
            ctx = await self._get_ctx(fresh=fresh)
            fresh = True

            try:
                rep = await ctx.clt.request(method, url, params=params)
                setattr(rep, "__username", ctx.acc.username)
                self._push_history(rep)
                rep.raise_for_status()
                return rep
            except httpx.HTTPStatusError as e:
                rep = e.response
                log_id = f"{req_id(rep)} on queue={self.queue}"

                # rate limit
                if rep.status_code == 429:
                    logger.debug(f"Rate limit for {log_id}")
                    reset_ts = int(rep.headers.get("x-rate-limit-reset", 0))
                    await self.pool.lock_until(ctx.acc.username, self.queue, reset_ts)
                    continue

                # possible account banned
                if rep.status_code == 403:
                    logger.warning(f"403 for {log_id}")
                    reset_ts = utc_ts() + 60 * 60  # + 1 hour
                    await self.pool.lock_until(ctx.acc.username, self.queue, reset_ts)
                    continue

                # Twitter can return different types of cursors that don't transfer between accounts,
                # so just take the next account; the current cursor may still work there
                if rep.status_code == 400:
                    logger.debug(f"Cursor not valid for {log_id}")
                    continue

                logger.error(f"[{rep.status_code}] {e.request.url}\n{rep.text}")
                raise e
            except Exception as e:
                logger.warning(f"Unknown error, retrying. Err: {e}")

    async def get(self, url: str, params: ReqParams = None):
        try:
            return await self.req("GET", url, params=params)
        except httpx.HTTPStatusError as e:
            self._dump_history(f"GET {url} {json.dumps(params)}")
            raise e
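
Taken on its own, QueueClient is a small async context manager: entering it locks one account for the given queue, req() moves to the next account on 429/403/400 (locking the failed one via lock_until) and retries after logging unexpected network errors, and exiting unlocks the account. A minimal usage sketch outside the API class (the DB path and URL are placeholders; accounts must already exist in the pool):

import asyncio

from twscrape.accounts_pool import AccountsPool
from twscrape.queue_client import QueueClient, req_id


async def main():
    pool = AccountsPool("/tmp/example.db")  # placeholder path

    async with QueueClient(pool, "search") as client:
        # get() goes through req(), which tags the response with the username
        # of the account that served it (req_id uses that tag for logging).
        rep = await client.get("https://example.com/api")  # placeholder URL
        print(rep.status_code, req_id(rep))


asyncio.run(main())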