diff --git a/tests/conftest.py b/tests/conftest.py
index 83983ee..bb03ab3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,21 +1,31 @@
import pytest
from twscrape.accounts_pool import AccountsPool
+from twscrape.api import API
from twscrape.queue_client import QueueClient
@pytest.fixture
-def poolm(tmp_path) -> AccountsPool: # type: ignore
+def pool_mock(tmp_path) -> AccountsPool: # type: ignore
db_path = tmp_path / "test.db"
yield AccountsPool(db_path)
@pytest.fixture
-async def client_fixture(poolm: AccountsPool):
- await poolm.add_account("user1", "pass1", "email1", "email_pass1")
- await poolm.add_account("user2", "pass2", "email2", "email_pass2")
- await poolm.set_active("user1", True)
- await poolm.set_active("user2", True)
+async def client_fixture(pool_mock: AccountsPool):
+ await pool_mock.add_account("user1", "pass1", "email1", "email_pass1")
+ await pool_mock.add_account("user2", "pass2", "email2", "email_pass2")
+ await pool_mock.set_active("user1", True)
+ await pool_mock.set_active("user2", True)
- client = QueueClient(poolm, "search")
- yield poolm, client
+ client = QueueClient(pool_mock, "search")
+ yield pool_mock, client
+
+
+@pytest.fixture
+async def api_mock(pool_mock: AccountsPool):
+ await pool_mock.add_account("user1", "pass1", "email1", "email_pass1")
+ await pool_mock.set_active("user1", True)
+
+ api = API(pool_mock)
+ yield api
diff --git a/tests/test_api.py b/tests/test_api.py
new file mode 100644
index 0000000..4c071e0
--- /dev/null
+++ b/tests/test_api.py
@@ -0,0 +1,38 @@
+from twscrape.api import API
+from twscrape.logger import set_log_level
+from twscrape.utils import gather
+
+set_log_level("DEBUG")
+
+
+class MockedError(Exception):
+ pass
+
+
+GQL_GEN = [
+ "followers",
+ "following",
+ "retweeters",
+ "favoriters",
+ "user_tweets",
+ "user_tweets_and_replies",
+]
+
+
+async def test_gql_params(api_mock: API, monkeypatch):
+ for func in GQL_GEN:
+ args = []
+
+ def mock_gql_items(*a, **kw):
+ args.append((a, kw))
+ raise MockedError()
+
+ try:
+ monkeypatch.setattr(api_mock, "_gql_items", mock_gql_items)
+ await gather(getattr(api_mock, func)("user1", limit=100, kv={"count": 100}))
+ except MockedError:
+ pass
+
+ assert len(args) == 1, f"{func} not called once"
+ assert args[0][1]["limit"] == 100, f"limit not changed in {func}"
+ assert args[0][0][1]["count"] == 100, f"count not changed in {func}"
diff --git a/tests/test_pool.py b/tests/test_pool.py
index 9282709..7baed32 100644
--- a/tests/test_pool.py
+++ b/tests/test_pool.py
@@ -2,78 +2,78 @@ from twscrape.accounts_pool import AccountsPool
from twscrape.utils import utc_ts
-async def test_add_accounts(poolm: AccountsPool):
+async def test_add_accounts(pool_mock: AccountsPool):
# should add account
- await poolm.add_account("user1", "pass1", "email1", "email_pass1")
- acc = await poolm.get("user1")
+ await pool_mock.add_account("user1", "pass1", "email1", "email_pass1")
+ acc = await pool_mock.get("user1")
assert acc.username == "user1"
assert acc.password == "pass1"
assert acc.email == "email1"
assert acc.email_password == "email_pass1"
# should not add account with same username
- await poolm.add_account("user1", "pass2", "email2", "email_pass2")
- acc = await poolm.get("user1")
+ await pool_mock.add_account("user1", "pass2", "email2", "email_pass2")
+ acc = await pool_mock.get("user1")
assert acc.username == "user1"
assert acc.password == "pass1"
assert acc.email == "email1"
assert acc.email_password == "email_pass1"
# should not add account with different username case
- await poolm.add_account("USER1", "pass2", "email2", "email_pass2")
- acc = await poolm.get("user1")
+ await pool_mock.add_account("USER1", "pass2", "email2", "email_pass2")
+ acc = await pool_mock.get("user1")
assert acc.username == "user1"
assert acc.password == "pass1"
assert acc.email == "email1"
assert acc.email_password == "email_pass1"
# should add account with different username
- await poolm.add_account("user2", "pass2", "email2", "email_pass2")
- acc = await poolm.get("user2")
+ await pool_mock.add_account("user2", "pass2", "email2", "email_pass2")
+ acc = await pool_mock.get("user2")
assert acc.username == "user2"
assert acc.password == "pass2"
assert acc.email == "email2"
assert acc.email_password == "email_pass2"
-async def test_get_all(poolm: AccountsPool):
+async def test_get_all(pool_mock: AccountsPool):
# should return empty list
- accs = await poolm.get_all()
+ accs = await pool_mock.get_all()
assert len(accs) == 0
# should return all accounts
- await poolm.add_account("user1", "pass1", "email1", "email_pass1")
- await poolm.add_account("user2", "pass2", "email2", "email_pass2")
- accs = await poolm.get_all()
+ await pool_mock.add_account("user1", "pass1", "email1", "email_pass1")
+ await pool_mock.add_account("user2", "pass2", "email2", "email_pass2")
+ accs = await pool_mock.get_all()
assert len(accs) == 2
assert accs[0].username == "user1"
assert accs[1].username == "user2"
-async def test_save(poolm: AccountsPool):
+async def test_save(pool_mock: AccountsPool):
# should save account
- await poolm.add_account("user1", "pass1", "email1", "email_pass1")
- acc = await poolm.get("user1")
+ await pool_mock.add_account("user1", "pass1", "email1", "email_pass1")
+ acc = await pool_mock.get("user1")
acc.password = "pass2"
- await poolm.save(acc)
- acc = await poolm.get("user1")
+ await pool_mock.save(acc)
+ acc = await pool_mock.get("user1")
assert acc.password == "pass2"
# should not save account
- acc = await poolm.get("user1")
+ acc = await pool_mock.get("user1")
acc.username = "user2"
- await poolm.save(acc)
- acc = await poolm.get("user1")
+ await pool_mock.save(acc)
+ acc = await pool_mock.get("user1")
assert acc.username == "user1"
-async def test_get_for_queue(poolm: AccountsPool):
+async def test_get_for_queue(pool_mock: AccountsPool):
Q = "test_queue"
# should return account
- await poolm.add_account("user1", "pass1", "email1", "email_pass1")
- await poolm.set_active("user1", True)
- acc = await poolm.get_for_queue(Q)
+ await pool_mock.add_account("user1", "pass1", "email1", "email_pass1")
+ await pool_mock.set_active("user1", True)
+ acc = await pool_mock.get_for_queue(Q)
assert acc is not None
assert acc.username == "user1"
assert acc.active is True
@@ -82,56 +82,56 @@ async def test_get_for_queue(poolm: AccountsPool):
assert acc.locks[Q] is not None
# should return None
- acc = await poolm.get_for_queue(Q)
+ acc = await pool_mock.get_for_queue(Q)
assert acc is None
-async def test_account_unlock(poolm: AccountsPool):
+async def test_account_unlock(pool_mock: AccountsPool):
Q = "test_queue"
- await poolm.add_account("user1", "pass1", "email1", "email_pass1")
- await poolm.set_active("user1", True)
- acc = await poolm.get_for_queue(Q)
+ await pool_mock.add_account("user1", "pass1", "email1", "email_pass1")
+ await pool_mock.set_active("user1", True)
+ acc = await pool_mock.get_for_queue(Q)
assert acc is not None
assert acc.locks[Q] is not None
# should unlock account and make available for queue
- await poolm.unlock(acc.username, Q)
- acc = await poolm.get_for_queue(Q)
+ await pool_mock.unlock(acc.username, Q)
+ acc = await pool_mock.get_for_queue(Q)
assert acc is not None
assert acc.locks[Q] is not None
# should update lock time
end_time = utc_ts() + 60 # + 1 minute
- await poolm.lock_until(acc.username, Q, end_time)
+ await pool_mock.lock_until(acc.username, Q, end_time)
- acc = await poolm.get(acc.username)
+ acc = await pool_mock.get(acc.username)
assert int(acc.locks[Q].timestamp()) == end_time
-async def test_get_stats(poolm: AccountsPool):
+async def test_get_stats(pool_mock: AccountsPool):
Q = "search"
# should return empty stats
- stats = await poolm.stats()
+ stats = await pool_mock.stats()
for k, v in stats.items():
assert v == 0, f"{k} should be 0"
# should increate total
- await poolm.add_account("user1", "pass1", "email1", "email_pass1")
- stats = await poolm.stats()
+ await pool_mock.add_account("user1", "pass1", "email1", "email_pass1")
+ stats = await pool_mock.stats()
assert stats["total"] == 1
assert stats["active"] == 0
# should increate active
- await poolm.set_active("user1", True)
- stats = await poolm.stats()
+ await pool_mock.set_active("user1", True)
+ stats = await pool_mock.stats()
assert stats["total"] == 1
assert stats["active"] == 1
# should update queue stats
- await poolm.get_for_queue(Q)
- stats = await poolm.stats()
+ await pool_mock.get_for_queue(Q)
+ stats = await pool_mock.stats()
assert stats["total"] == 1
assert stats["active"] == 1
assert stats["locked_search"] == 1
diff --git a/twscrape/api.py b/twscrape/api.py
index 0652b9b..7db63ea 100644
--- a/twscrape/api.py
+++ b/twscrape/api.py
@@ -13,7 +13,7 @@ class API:
self.pool = pool
self.debug = debug
- # http helpers
+ # general helpers
def _is_end(self, rep: Response, q: str, res: list, cur: str | None, cnt: int, lim: int):
new_count = len(res)
@@ -34,6 +34,35 @@ class API:
return cur.get("value")
return None
+ # gql helpers
+
+ async def _gql_items(self, op: str, kv: dict, limit=-1):
+ queue, cursor, count, active = op.split("/")[-1], None, 0, True
+
+ async with QueueClient(self.pool, queue, self.debug) as client:
+ while active:
+ params = {"variables": {**kv, "cursor": cursor}, "features": GQL_FEATURES}
+
+ rep = await client.get(f"{GQL_URL}/{op}", params=encode_params(params))
+ obj = rep.json()
+
+ entries = get_by_path(obj, "entries") or []
+ entries = [x for x in entries if not x["entryId"].startswith("cursor-")]
+ cursor = self._get_cursor(obj)
+
+ rep, count, active = self._is_end(rep, queue, entries, cursor, count, limit)
+ if rep is None:
+ return
+
+ yield rep
+
+ async def _gql_item(self, op: str, kv: dict, ft: dict | None = None):
+ ft = ft or {}
+ queue = op.split("/")[-1]
+ async with QueueClient(self.pool, queue, self.debug) as client:
+ params = {"variables": {**kv}, "features": {**GQL_FEATURES, **ft}}
+ return await client.get(f"{GQL_URL}/{op}", params=encode_params(params))
+
# search
async def search_raw(self, q: str, limit=-1):
@@ -66,62 +95,33 @@ class API:
twids.add(x["id_str"])
yield Tweet.parse(x, obj)
- # gql helpers
-
- async def _gql_items(self, op: str, kv: dict, limit=-1):
- queue, cursor, count, active = op.split("/")[-1], None, 0, True
-
- async with QueueClient(self.pool, queue, self.debug) as client:
- while active:
- params = {"variables": {**kv, "cursor": cursor}, "features": GQL_FEATURES}
-
- rep = await client.get(f"{GQL_URL}/{op}", params=encode_params(params))
- obj = rep.json()
-
- entries = get_by_path(obj, "entries") or []
- entries = [x for x in entries if not x["entryId"].startswith("cursor-")]
- cursor = self._get_cursor(obj)
-
- rep, count, active = self._is_end(rep, queue, entries, cursor, count, limit)
- if rep is None:
- return
-
- yield rep
-
- async def _gql_item(self, op: str, kv: dict, ft: dict | None = None):
- ft = ft or {}
- queue = op.split("/")[-1]
- async with QueueClient(self.pool, queue, self.debug) as client:
- params = {"variables": {**kv}, "features": {**GQL_FEATURES, **ft}}
- return await client.get(f"{GQL_URL}/{op}", params=encode_params(params))
-
# user_by_id
- async def user_by_id_raw(self, uid: int):
+ async def user_by_id_raw(self, uid: int, kv=None):
op = "GazOglcBvgLigl3ywt6b3Q/UserByRestId"
- kv = {"userId": str(uid), "withSafetyModeUserFields": True}
+ kv = {"userId": str(uid), "withSafetyModeUserFields": True, **(kv or {})}
return await self._gql_item(op, kv)
- async def user_by_id(self, uid: int):
- rep = await self.user_by_id_raw(uid)
+ async def user_by_id(self, uid: int, kv=None):
+ rep = await self.user_by_id_raw(uid, kv=kv)
res = rep.json()
return User.parse(to_old_obj(res["data"]["user"]["result"]))
# user_by_login
- async def user_by_login_raw(self, login: str):
+ async def user_by_login_raw(self, login: str, kv=None):
op = "sLVLhk0bGj3MVFEKTdax1w/UserByScreenName"
- kv = {"screen_name": login, "withSafetyModeUserFields": True}
+ kv = {"screen_name": login, "withSafetyModeUserFields": True, **(kv or {})}
return await self._gql_item(op, kv)
- async def user_by_login(self, login: str):
- rep = await self.user_by_login_raw(login)
+ async def user_by_login(self, login: str, kv=None):
+ rep = await self.user_by_login_raw(login, kv=kv)
res = rep.json()
return User.parse(to_old_obj(res["data"]["user"]["result"]))
# tweet_details
- async def tweet_details_raw(self, twid: int):
+ async def tweet_details_raw(self, twid: int, kv=None):
op = "zXaXQgfyR4GxE21uwYQSyA/TweetDetail"
kv = {
"focalTweetId": str(twid),
@@ -138,6 +138,7 @@ class API:
"withReactionsPerspective": False,
"withSuperFollowsTweetFields": False,
"withSuperFollowsUserFields": False,
+ **(kv or {}),
}
ft = {
"responsive_web_twitter_blue_verified_badge_is_enabled": True,
@@ -145,70 +146,70 @@ class API:
}
return await self._gql_item(op, kv, ft)
- async def tweet_details(self, twid: int):
- rep = await self.tweet_details_raw(twid)
+ async def tweet_details(self, twid: int, kv=None):
+ rep = await self.tweet_details_raw(twid, kv=kv)
obj = to_old_rep(rep.json())
return Tweet.parse(obj["tweets"][str(twid)], obj)
# followers
- async def followers_raw(self, uid: int, limit=-1):
+ async def followers_raw(self, uid: int, limit=-1, kv=None):
op = "djdTXDIk2qhd4OStqlUFeQ/Followers"
- kv = {"userId": str(uid), "count": 20, "includePromotedContent": False}
+ kv = {"userId": str(uid), "count": 20, "includePromotedContent": False, **(kv or {})}
async for x in self._gql_items(op, kv, limit=limit):
yield x
- async def followers(self, uid: int, limit=-1):
- async for rep in self.followers_raw(uid, limit=limit):
+ async def followers(self, uid: int, limit=-1, kv=None):
+ async for rep in self.followers_raw(uid, limit=limit, kv=kv):
obj = to_old_rep(rep.json())
for _, v in obj["users"].items():
yield User.parse(v)
# following
- async def following_raw(self, uid: int, limit=-1):
+ async def following_raw(self, uid: int, limit=-1, kv=None):
op = "IWP6Zt14sARO29lJT35bBw/Following"
- kv = {"userId": str(uid), "count": 20, "includePromotedContent": False}
+ kv = {"userId": str(uid), "count": 20, "includePromotedContent": False, **(kv or {})}
async for x in self._gql_items(op, kv, limit=limit):
yield x
- async def following(self, uid: int, limit=-1):
- async for rep in self.following_raw(uid, limit=limit):
+ async def following(self, uid: int, limit=-1, kv=None):
+ async for rep in self.following_raw(uid, limit=limit, kv=kv):
obj = to_old_rep(rep.json())
for _, v in obj["users"].items():
yield User.parse(v)
# retweeters
- async def retweeters_raw(self, twid: int, limit=-1):
+ async def retweeters_raw(self, twid: int, limit=-1, kv=None):
op = "U5f_jm0CiLmSfI1d4rGleQ/Retweeters"
- kv = {"tweetId": str(twid), "count": 20, "includePromotedContent": True}
+ kv = {"tweetId": str(twid), "count": 20, "includePromotedContent": True, **(kv or {})}
async for x in self._gql_items(op, kv, limit=limit):
yield x
- async def retweeters(self, twid: int, limit=-1):
- async for rep in self.retweeters_raw(twid, limit=limit):
+ async def retweeters(self, twid: int, limit=-1, kv=None):
+ async for rep in self.retweeters_raw(twid, limit=limit, kv=kv):
obj = to_old_rep(rep.json())
for _, v in obj["users"].items():
yield User.parse(v)
# favoriters
- async def favoriters_raw(self, twid: int, limit=-1):
+ async def favoriters_raw(self, twid: int, limit=-1, kv=None):
op = "vcTrPlh9ovFDQejz22q9vg/Favoriters"
- kv = {"tweetId": str(twid), "count": 20, "includePromotedContent": True}
+ kv = {"tweetId": str(twid), "count": 20, "includePromotedContent": True, **(kv or {})}
async for x in self._gql_items(op, kv, limit=limit):
yield x
- async def favoriters(self, twid: int, limit=-1):
- async for rep in self.favoriters_raw(twid, limit=limit):
+ async def favoriters(self, twid: int, limit=-1, kv=None):
+ async for rep in self.favoriters_raw(twid, limit=limit, kv=kv):
obj = to_old_rep(rep.json())
for _, v in obj["users"].items():
yield User.parse(v)
# user_tweets
- async def user_tweets_raw(self, uid: int, limit=-1):
+ async def user_tweets_raw(self, uid: int, limit=-1, kv=None):
op = "CdG2Vuc1v6F5JyEngGpxVw/UserTweets"
kv = {
"userId": str(uid),
@@ -217,19 +218,20 @@ class API:
"withQuickPromoteEligibilityTweetFields": True,
"withVoice": True,
"withV2Timeline": True,
+ **(kv or {}),
}
async for x in self._gql_items(op, kv, limit=limit):
yield x
- async def user_tweets(self, uid: int, limit=-1):
- async for rep in self.user_tweets_raw(uid, limit=limit):
+ async def user_tweets(self, uid: int, limit=-1, kv=None):
+ async for rep in self.user_tweets_raw(uid, limit=limit, kv=kv):
obj = to_old_rep(rep.json())
for _, v in obj["tweets"].items():
yield Tweet.parse(v, obj)
# user_tweets_and_replies
- async def user_tweets_and_replies_raw(self, uid: int, limit=-1):
+ async def user_tweets_and_replies_raw(self, uid: int, limit=-1, kv=None):
op = "zQxfEr5IFxQ2QZ-XMJlKew/UserTweetsAndReplies"
kv = {
"userId": str(uid),
@@ -238,12 +240,13 @@ class API:
"withCommunity": True,
"withVoice": True,
"withV2Timeline": True,
+ **(kv or {}),
}
async for x in self._gql_items(op, kv, limit=limit):
yield x
- async def user_tweets_and_replies(self, uid: int, limit=-1):
- async for rep in self.user_tweets_and_replies_raw(uid, limit=limit):
+ async def user_tweets_and_replies(self, uid: int, limit=-1, kv=None):
+ async for rep in self.user_tweets_and_replies_raw(uid, limit=limit, kv=kv):
obj = to_old_rep(rep.json())
for _, v in obj["tweets"].items():
yield Tweet.parse(v, obj)