From 3333692a3253a5dc67669212961705ae8c8a19e0 Mon Sep 17 00:00:00 2001 From: Vlad Pronsky Date: Thu, 25 May 2023 05:44:18 +0300 Subject: [PATCH] add error on unsupported sqlite version --- .github/workflows/test.yml | 2 -- tests/test_pool.py | 2 +- tests/test_queue_client.py | 2 +- twscrape/__init__.py | 1 + twscrape/db.py | 20 +++++++++++++++++--- 5 files changed, 20 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 479c323..2f5287c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,8 +2,6 @@ name: test on: push: - tags-ignore: - - '**' env: PIP_ROOT_USER_ACTION: ignore diff --git a/tests/test_pool.py b/tests/test_pool.py index 729dde5..c0d3bfd 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -4,7 +4,7 @@ from twscrape.accounts_pool import AccountsPool from twscrape.db import DB from twscrape.utils import utc_ts -DB_FILE = "/tmp/twscrape_test.db" +DB_FILE = "/tmp/twscrape_test_pool.db" def remove_db(): diff --git a/tests/test_queue_client.py b/tests/test_queue_client.py index d41b221..fbca4d0 100644 --- a/tests/test_queue_client.py +++ b/tests/test_queue_client.py @@ -6,7 +6,7 @@ from twscrape.queue_client import QueueClient from .utils import get_pool -DB_FILE = "/tmp/test_queue_client.db" +DB_FILE = "/tmp/twscrape_test_queue_client.db" URL = "https://example.com/api" set_log_level("ERROR") diff --git a/twscrape/__init__.py b/twscrape/__init__.py index e00efb4..3b4c14a 100644 --- a/twscrape/__init__.py +++ b/twscrape/__init__.py @@ -2,5 +2,6 @@ from .account import Account from .accounts_pool import AccountsPool from .api import API +from .logger import set_log_level from .models import * # noqa: F403 from .utils import gather diff --git a/twscrape/db.py b/twscrape/db.py index 066b864..d306dbb 100644 --- a/twscrape/db.py +++ b/twscrape/db.py @@ -4,6 +4,8 @@ from collections import defaultdict import aiosqlite +MIN_SQLITE_VERSION = "3.34" + def lock_retry(max_retries=5, delay=1): def decorator(func): @@ -23,6 +25,17 @@ def lock_retry(max_retries=5, delay=1): return decorator +async def check_version(db: aiosqlite.Connection): + async with db.execute("SELECT SQLITE_VERSION()") as cur: + rs = await cur.fetchone() + rs = rs[0] if rs else "3.0.0" + rs = ".".join(rs.split(".")[:2]) + + if rs < MIN_SQLITE_VERSION: + msg = f"SQLite version '{rs}' is too old, please upgrade to {MIN_SQLITE_VERSION}+" + raise SystemError(msg) + + class DB: _init_queries: defaultdict[str, list[str]] = defaultdict(list) _init_once: defaultdict[str, bool] = defaultdict(bool) @@ -34,6 +47,7 @@ class DB: async def __aenter__(self): db = await aiosqlite.connect(self.db_path) db.row_factory = aiosqlite.Row + await check_version(db) if not self._init_once[self.db_path]: for qs in self._init_queries[self.db_path]: @@ -56,13 +70,13 @@ def add_init_query(db_path: str, qs: str): @lock_retry() -async def execute(db_path: str, qs: str, params: dict = {}): +async def execute(db_path: str, qs: str, params: dict | None = None): async with DB(db_path) as db: await db.execute(qs, params) @lock_retry() -async def fetchone(db_path: str, qs: str, params: dict = {}): +async def fetchone(db_path: str, qs: str, params: dict | None = None): async with DB(db_path) as db: async with db.execute(qs, params) as cur: row = await cur.fetchone() @@ -70,7 +84,7 @@ async def fetchone(db_path: str, qs: str, params: dict = {}): @lock_retry() -async def fetchall(db_path: str, qs: str, params: dict = {}): +async def fetchall(db_path: str, qs: str, params: dict | None = None): async with DB(db_path) as db: async with db.execute(qs, params) as cur: rows = await cur.fetchall()