From c0204aa49239568a3bf69593927994461b18228d Mon Sep 17 00:00:00 2001 From: Vlad Pronsky Date: Thu, 25 May 2023 05:08:20 +0300 Subject: [PATCH] move http communication from api class; handle network fail cases; add tests --- Makefile | 3 + pyproject.toml | 1 + tests/__init__.py | 0 tests/test_queue_client.py | 94 ++++++++++++++++ tests/utils.py | 13 +++ twscrape/api.py | 216 ++++++++++--------------------------- twscrape/queue_client.py | 134 +++++++++++++++++++++++ 7 files changed, 299 insertions(+), 162 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_queue_client.py create mode 100644 tests/utils.py create mode 100644 twscrape/queue_client.py diff --git a/Makefile b/Makefile index 5129e03..0c4b1a7 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,9 @@ all: @echo "hi" +install: + pip install -e .[dev] + build: python -m build diff --git a/pyproject.toml b/pyproject.toml index f3af44a..c86b145 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dev = [ "pylint==2.17.3", "pytest-asyncio==0.21.0", "pytest-cov==4.0.0", + "pytest-httpx==0.22.0", "pytest==7.3.1", "ruff==0.0.263", ] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_queue_client.py b/tests/test_queue_client.py new file mode 100644 index 0000000..d41b221 --- /dev/null +++ b/tests/test_queue_client.py @@ -0,0 +1,94 @@ +import httpx +from pytest_httpx import HTTPXMock + +from twscrape.logger import set_log_level +from twscrape.queue_client import QueueClient + +from .utils import get_pool + +DB_FILE = "/tmp/test_queue_client.db" +URL = "https://example.com/api" + +set_log_level("ERROR") + + +async def get_client(): + pool = get_pool(DB_FILE) + await pool.add_account("user1", "pass1", "email1", "email_pass1") + await pool.add_account("user2", "pass2", "email2", "email_pass2") + await pool.set_active("user1", True) + await pool.set_active("user2", True) + + client = QueueClient(pool, "search") + return pool, client + + +async def test_should_lock_account_on_queue(httpx_mock: HTTPXMock): + pool, client = await get_client() + assert (await pool.stats())["locked_search"] == 0 + + await client.__aenter__() + assert (await pool.stats())["locked_search"] == 1 + + httpx_mock.add_response(url=URL, json={"foo": "bar"}, status_code=200) + assert (await client.get(URL)).json() == {"foo": "bar"} + + await client.__aexit__(None, None, None) + assert (await pool.stats())["locked_search"] == 0 + + +async def test_should_not_switch_account_on_200(httpx_mock: HTTPXMock): + pool, client = await get_client() + + assert (await pool.stats())["locked_search"] == 0 + await client.__aenter__() + + httpx_mock.add_response(url=URL, json={"foo": "1"}, status_code=200) + httpx_mock.add_response(url=URL, json={"foo": "2"}, status_code=200) + + rep = await client.get(URL) + assert rep.json() == {"foo": "1"} + + rep = await client.get(URL) + assert rep.json() == {"foo": "2"} + + assert (await pool.stats())["locked_search"] == 1 + await client.__aexit__(None, None, None) + + +async def test_should_switch_account_on_http_error(httpx_mock: HTTPXMock): + pool, client = await get_client() + + assert (await pool.stats())["locked_search"] == 0 + await client.__aenter__() + + httpx_mock.add_response(url=URL, json={"foo": "1"}, status_code=403) + httpx_mock.add_response(url=URL, json={"foo": "2"}, status_code=200) + + rep = await client.get(URL) + assert rep.json() == {"foo": "2"} + + assert (await pool.stats())["locked_search"] == 1 # user1 unlocked, user2 locked + await client.__aexit__(None, None, None) + + +async def test_should_retry_with_same_account_on_network_error(httpx_mock: HTTPXMock): + pool, client = await get_client() + await client.__aenter__() + + httpx_mock.add_exception(httpx.ReadTimeout("Unable to read within timeout")) + httpx_mock.add_response(url=URL, json={"foo": "2"}, status_code=200) + + rep = await client.get(URL) + assert rep.json() == {"foo": "2"} + + assert (await pool.stats())["locked_search"] == 1 + + username = getattr(rep, "__username", None) + assert username is not None + + acc1 = await pool.get(username) + assert len(acc1.locks) > 0 + + acc2 = await pool.get("user2" if username == "user1" else "user1") + assert len(acc2.locks) == 0 diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..ad3ff32 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,13 @@ +import os + +from twscrape.accounts_pool import AccountsPool +from twscrape.db import DB + + +def get_pool(db_path: str): + DB._init_once[db_path] = False + if os.path.exists(db_path): + os.remove(db_path) + + pool = AccountsPool(db_path) + return pool diff --git a/twscrape/api.py b/twscrape/api.py index ef5d26c..0652b9b 100644 --- a/twscrape/api.py +++ b/twscrape/api.py @@ -1,31 +1,20 @@ -import json -from datetime import datetime -from typing import Awaitable, Callable - -from httpx import AsyncClient, HTTPStatusError, Response +from httpx import Response from .accounts_pool import AccountsPool from .constants import GQL_FEATURES, GQL_URL, SEARCH_PARAMS, SEARCH_URL from .logger import logger from .models import Tweet, User -from .utils import encode_params, find_obj, get_by_path, to_old_obj, to_old_rep, utc_ts +from .queue_client import QueueClient, req_id +from .utils import encode_params, find_obj, get_by_path, to_old_obj, to_old_rep class API: def __init__(self, pool: AccountsPool, debug=False): self.pool = pool self.debug = debug - self._history: list[Response] = [] # http helpers - def _limit_msg(self, rep: Response): - lr = rep.headers.get("x-rate-limit-remaining", -1) - ll = rep.headers.get("x-rate-limit-limit", -1) - - username = getattr(rep, "__username", "") - return f"{username} {lr}/{ll}" - def _is_end(self, rep: Response, q: str, res: list, cur: str | None, cnt: int, lim: int): new_count = len(res) new_total = cnt + new_count @@ -36,163 +25,37 @@ class API: stats = f"{q} {new_total:,d} (+{new_count:,d})" flags = f"res={int(is_res)} cur={int(is_cur)} lim={int(is_lim)}" - logger.debug(" ".join([stats, flags, self._limit_msg(rep)])) + logger.debug(" ".join([stats, flags, req_id(rep)])) - return new_total, not is_res, not is_cur or is_lim - - def _push_history(self, rep: Response): - self._history.append(rep) - if len(self._history) > 3: - self._history.pop(0) - - def _dump_history(self, extra: str = ""): - if not self.debug: - return - - ts = str(datetime.now()).replace(":", "-").replace(" ", "_") - filename = f"/tmp/api_dump_{ts}.txt" - with open(filename, "w") as fp: - txt = f"{extra}\n" - for rep in self._history: - res = json.dumps(rep.json(), indent=2) - hdr = "\n".join([str(x) for x in list(rep.request.headers.items())]) - div = "-" * 20 - - msg = f"{div}\n{self._limit_msg(rep)}" - msg = f"{msg}\n{rep.request.method} {rep.request.url}" - msg = f"{msg}\n{rep.status_code}\n{div}" - msg = f"{msg}\n{hdr}\n{div}\n{res}\n\n" - txt += msg - - fp.write(txt) - - print(f"API dump ({len(self._history)}) dumped to {filename}") - - async def _inf_req(self, queue: str, cb: Callable[[AsyncClient], Awaitable[Response]]): - while True: - acc = await self.pool.get_for_queue_or_wait(queue) - client = acc.make_client() - - try: - while True: - rep = await cb(client) - setattr(rep, "__username", acc.username) - self._push_history(rep) - rep.raise_for_status() - - yield rep - except HTTPStatusError as e: - rep = e.response - log_id = f"{self._limit_msg(rep)} on queue={queue}" - - # rate limit - if rep.status_code == 429: - logger.debug(f"Rate limit for {log_id}") - reset_ts = int(rep.headers.get("x-rate-limit-reset", 0)) - await self.pool.lock_until(acc.username, queue, reset_ts) - continue - - # possible account banned - if rep.status_code == 403: - logger.warning(f"403 for {log_id}") - reset_ts = utc_ts() + 60 * 60 # + 1 hour - await self.pool.lock_until(acc.username, queue, reset_ts) - continue - - # twitter can return different types of cursors that not transfers between accounts - # just take the next account, the current cursor can work in it - if rep.status_code == 400: - logger.debug(f"Cursor not valid for {log_id}") - continue - - logger.error(f"[{rep.status_code}] {e.request.url}\n{rep.text}") - raise e - finally: - await self.pool.unlock(acc.username, queue) - await client.aclose() + return rep if is_res else None, new_total, is_cur and not is_lim def _get_cursor(self, obj: dict): if cur := find_obj(obj, lambda x: x.get("cursorType") == "Bottom"): return cur.get("value") return None - def _get_ql_entries(self, obj: dict) -> list[dict]: - entries = get_by_path(obj, "entries") - return entries or [] - - async def _ql_items(self, op: str, kv: dict, limit=-1): - queue, cursor, count = op.split("/")[-1], None, 0 - - async def _get(client: AsyncClient): - params = {"variables": {**kv, "cursor": cursor}, "features": GQL_FEATURES} - return await client.get(f"{GQL_URL}/{op}", params=encode_params(params)) - - async for rep in self._inf_req(queue, _get): - obj = rep.json() - - # cursor-top / cursor-bottom always present - entries = self._get_ql_entries(obj) - entries = [x for x in entries if not x["entryId"].startswith("cursor-")] - cursor = self._get_cursor(obj) - - check = self._is_end(rep, queue, entries, cursor, count, limit) - count, end_before, end_after = check - - if end_before: - return - - yield rep - - if end_after: - return - - async def _ql_item(self, op: str, kv: dict, ft: dict = {}): - async def _get(client: AsyncClient): - params = {"variables": {**kv}, "features": {**GQL_FEATURES, **ft}} - return await client.get(f"{GQL_URL}/{op}", params=encode_params(params)) - - queue = op.split("/")[-1] - async for rep in self._inf_req(queue, _get): - return rep - - raise Exception("No response") # todo - # search async def search_raw(self, q: str, limit=-1): - queue, cursor, count = "search", None, 0 + queue, cursor, count, active = "search", None, 0, True - async def _get(client: AsyncClient): - params = {**SEARCH_PARAMS, "q": q, "count": 20} - params["cursor" if cursor else "requestContext"] = cursor if cursor else "launch" - try: - return await client.get(SEARCH_URL, params=params) - except Exception as e: - logger.error(f"Error requesting {q}: {e}") - logger.error(f"Request: {SEARCH_URL}, {params}") - raise e + async with QueueClient(self.pool, queue, self.debug) as client: + while active: + params = {**SEARCH_PARAMS, "q": q, "count": 20} + params["cursor" if cursor else "requestContext"] = cursor if cursor else "launch" - try: - async for rep in self._inf_req(queue, _get): - data = rep.json() + rep = await client.get(SEARCH_URL, params=params) + obj = rep.json() - tweets = data.get("globalObjects", {}).get("tweets", []) - cursor = self._get_cursor(data) + tweets = obj.get("globalObjects", {}).get("tweets", []) + cursor = self._get_cursor(obj) - check = self._is_end(rep, q, tweets, cursor, count, limit) - count, end_before, end_after = check - - if end_before: + rep, count, active = self._is_end(rep, q, tweets, cursor, count, limit) + if rep is None: return yield rep - if end_after: - return - except HTTPStatusError as e: - self._dump_history(f"q={q}\ncount={count}\nwas_cur={cursor}\nnew_cur=None") - raise e - async def search(self, q: str, limit=-1): twids = set() async for rep in self.search_raw(q, limit=limit): @@ -203,12 +66,41 @@ class API: twids.add(x["id_str"]) yield Tweet.parse(x, obj) + # gql helpers + + async def _gql_items(self, op: str, kv: dict, limit=-1): + queue, cursor, count, active = op.split("/")[-1], None, 0, True + + async with QueueClient(self.pool, queue, self.debug) as client: + while active: + params = {"variables": {**kv, "cursor": cursor}, "features": GQL_FEATURES} + + rep = await client.get(f"{GQL_URL}/{op}", params=encode_params(params)) + obj = rep.json() + + entries = get_by_path(obj, "entries") or [] + entries = [x for x in entries if not x["entryId"].startswith("cursor-")] + cursor = self._get_cursor(obj) + + rep, count, active = self._is_end(rep, queue, entries, cursor, count, limit) + if rep is None: + return + + yield rep + + async def _gql_item(self, op: str, kv: dict, ft: dict | None = None): + ft = ft or {} + queue = op.split("/")[-1] + async with QueueClient(self.pool, queue, self.debug) as client: + params = {"variables": {**kv}, "features": {**GQL_FEATURES, **ft}} + return await client.get(f"{GQL_URL}/{op}", params=encode_params(params)) + # user_by_id async def user_by_id_raw(self, uid: int): op = "GazOglcBvgLigl3ywt6b3Q/UserByRestId" kv = {"userId": str(uid), "withSafetyModeUserFields": True} - return await self._ql_item(op, kv) + return await self._gql_item(op, kv) async def user_by_id(self, uid: int): rep = await self.user_by_id_raw(uid) @@ -220,7 +112,7 @@ class API: async def user_by_login_raw(self, login: str): op = "sLVLhk0bGj3MVFEKTdax1w/UserByScreenName" kv = {"screen_name": login, "withSafetyModeUserFields": True} - return await self._ql_item(op, kv) + return await self._gql_item(op, kv) async def user_by_login(self, login: str): rep = await self.user_by_login_raw(login) @@ -251,7 +143,7 @@ class API: "responsive_web_twitter_blue_verified_badge_is_enabled": True, "longform_notetweets_richtext_consumption_enabled": True, } - return await self._ql_item(op, kv, ft) + return await self._gql_item(op, kv, ft) async def tweet_details(self, twid: int): rep = await self.tweet_details_raw(twid) @@ -263,7 +155,7 @@ class API: async def followers_raw(self, uid: int, limit=-1): op = "djdTXDIk2qhd4OStqlUFeQ/Followers" kv = {"userId": str(uid), "count": 20, "includePromotedContent": False} - async for x in self._ql_items(op, kv, limit=limit): + async for x in self._gql_items(op, kv, limit=limit): yield x async def followers(self, uid: int, limit=-1): @@ -277,7 +169,7 @@ class API: async def following_raw(self, uid: int, limit=-1): op = "IWP6Zt14sARO29lJT35bBw/Following" kv = {"userId": str(uid), "count": 20, "includePromotedContent": False} - async for x in self._ql_items(op, kv, limit=limit): + async for x in self._gql_items(op, kv, limit=limit): yield x async def following(self, uid: int, limit=-1): @@ -291,7 +183,7 @@ class API: async def retweeters_raw(self, twid: int, limit=-1): op = "U5f_jm0CiLmSfI1d4rGleQ/Retweeters" kv = {"tweetId": str(twid), "count": 20, "includePromotedContent": True} - async for x in self._ql_items(op, kv, limit=limit): + async for x in self._gql_items(op, kv, limit=limit): yield x async def retweeters(self, twid: int, limit=-1): @@ -305,7 +197,7 @@ class API: async def favoriters_raw(self, twid: int, limit=-1): op = "vcTrPlh9ovFDQejz22q9vg/Favoriters" kv = {"tweetId": str(twid), "count": 20, "includePromotedContent": True} - async for x in self._ql_items(op, kv, limit=limit): + async for x in self._gql_items(op, kv, limit=limit): yield x async def favoriters(self, twid: int, limit=-1): @@ -326,7 +218,7 @@ class API: "withVoice": True, "withV2Timeline": True, } - async for x in self._ql_items(op, kv, limit=limit): + async for x in self._gql_items(op, kv, limit=limit): yield x async def user_tweets(self, uid: int, limit=-1): @@ -347,7 +239,7 @@ class API: "withVoice": True, "withV2Timeline": True, } - async for x in self._ql_items(op, kv, limit=limit): + async for x in self._gql_items(op, kv, limit=limit): yield x async def user_tweets_and_replies(self, uid: int, limit=-1): diff --git a/twscrape/queue_client.py b/twscrape/queue_client.py new file mode 100644 index 0000000..dac5776 --- /dev/null +++ b/twscrape/queue_client.py @@ -0,0 +1,134 @@ +import json +from datetime import datetime + +import httpx + +from .accounts_pool import Account, AccountsPool +from .logger import logger +from .utils import utc_ts + +ReqParams = dict[str, str | int] | None + + +def req_id(rep: httpx.Response): + lr = rep.headers.get("x-rate-limit-remaining", -1) + ll = rep.headers.get("x-rate-limit-limit", -1) + + username = getattr(rep, "__username", "") + return f"{username} {lr}/{ll}" + + +class Ctx: + def __init__(self, acc: Account, clt: httpx.AsyncClient): + self.acc = acc + self.clt = clt + + +class QueueClient: + def __init__(self, pool: AccountsPool, queue: str, debug=False): + self.pool = pool + self.queue = queue + self.debug = debug + self.history: list[httpx.Response] = [] + self.ctx: Ctx | None = None + + async def __aenter__(self): + await self._get_ctx() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._close_ctx() + return self + + async def _close_ctx(self): + if self.ctx is not None: + await self.ctx.clt.aclose() + await self.pool.unlock(self.ctx.acc.username, self.queue) + + async def _get_ctx(self, fresh=False) -> Ctx: + if self.ctx and not fresh: + return self.ctx + + if self.ctx is not None: + await self._close_ctx() + + acc = await self.pool.get_for_queue_or_wait(self.queue) + clt = acc.make_client() + self.ctx = Ctx(acc, clt) + return self.ctx + + def _push_history(self, rep: httpx.Response): + self.history.append(rep) + if len(self.history) > 3: + self.history.pop(0) + + def _dump_history(self, extra: str = ""): + if not self.debug: + return + + ts = str(datetime.now()).replace(":", "-").replace(" ", "_") + filename = f"/tmp/api_dump_{ts}.txt" + with open(filename, "w", encoding="utf-8") as fp: + txt = f"{extra}\n" + for rep in self.history: + res = json.dumps(rep.json(), indent=2) + hdr = "\n".join([str(x) for x in list(rep.request.headers.items())]) + div = "-" * 20 + + msg = f"{div}\n{req_id(rep)}" + msg = f"{msg}\n{rep.request.method} {rep.request.url}" + msg = f"{msg}\n{rep.status_code}\n{div}" + msg = f"{msg}\n{hdr}\n{div}\n{res}\n\n" + txt += msg + + fp.write(txt) + + print(f"API dump ({len(self.history)}) dumped to {filename}") + + async def req(self, method: str, url: str, params: ReqParams = None): + fresh = False # do not get new account on first try + while True: + ctx = await self._get_ctx(fresh=fresh) + fresh = True + + try: + rep = await ctx.clt.request(method, url, params=params) + setattr(rep, "__username", ctx.acc.username) + self._push_history(rep) + rep.raise_for_status() + return rep + except httpx.HTTPStatusError as e: + rep = e.response + log_id = f"{req_id(rep)} on queue={self.queue}" + + # rate limit + if rep.status_code == 429: + logger.debug(f"Rate limit for {log_id}") + reset_ts = int(rep.headers.get("x-rate-limit-reset", 0)) + await self.pool.lock_until(ctx.acc.username, self.queue, reset_ts) + continue + + # possible account banned + if rep.status_code == 403: + logger.warning(f"403 for {log_id}") + reset_ts = utc_ts() + 60 * 60 # + 1 hour + await self.pool.lock_until(ctx.acc.username, self.queue, reset_ts) + continue + + # twitter can return different types of cursors that not transfers between accounts + # just take the next account, the current cursor can work in it + if rep.status_code == 400: + logger.debug(f"Cursor not valid for {log_id}") + continue + + logger.error(f"[{rep.status_code}] {e.request.url}\n{rep.text}") + raise e + except Exception as e: + logger.warning(f"Unknown error, retrying. Err: {e}") + + async def get(self, url: str, params: ReqParams = None): + try: + return await self.req("GET", url, params=params) + except httpx.HTTPStatusError as e: + self._dump_history(f"GET {url} {json.dumps(params)}") + raise e