add error on unsupported sqlite version

Этот коммит содержится в:
Vlad Pronsky 2023-05-25 05:44:18 +03:00 коммит произвёл vladkens
родитель c0204aa492
Коммит 3333692a32
5 изменённых файлов: 20 добавлений и 7 удалений

2
.github/workflows/test.yml поставляемый
Просмотреть файл

@ -2,8 +2,6 @@ name: test
on:
push:
tags-ignore:
- '**'
env:
PIP_ROOT_USER_ACTION: ignore

Просмотреть файл

@ -4,7 +4,7 @@ from twscrape.accounts_pool import AccountsPool
from twscrape.db import DB
from twscrape.utils import utc_ts
DB_FILE = "/tmp/twscrape_test.db"
DB_FILE = "/tmp/twscrape_test_pool.db"
def remove_db():

Просмотреть файл

@ -6,7 +6,7 @@ from twscrape.queue_client import QueueClient
from .utils import get_pool
DB_FILE = "/tmp/test_queue_client.db"
DB_FILE = "/tmp/twscrape_test_queue_client.db"
URL = "https://example.com/api"
set_log_level("ERROR")

Просмотреть файл

@ -2,5 +2,6 @@
from .account import Account
from .accounts_pool import AccountsPool
from .api import API
from .logger import set_log_level
from .models import * # noqa: F403
from .utils import gather

Просмотреть файл

@ -4,6 +4,8 @@ from collections import defaultdict
import aiosqlite
MIN_SQLITE_VERSION = "3.34"
def lock_retry(max_retries=5, delay=1):
def decorator(func):
@ -23,6 +25,17 @@ def lock_retry(max_retries=5, delay=1):
return decorator
async def check_version(db: aiosqlite.Connection):
async with db.execute("SELECT SQLITE_VERSION()") as cur:
rs = await cur.fetchone()
rs = rs[0] if rs else "3.0.0"
rs = ".".join(rs.split(".")[:2])
if rs < MIN_SQLITE_VERSION:
msg = f"SQLite version '{rs}' is too old, please upgrade to {MIN_SQLITE_VERSION}+"
raise SystemError(msg)
class DB:
_init_queries: defaultdict[str, list[str]] = defaultdict(list)
_init_once: defaultdict[str, bool] = defaultdict(bool)
@ -34,6 +47,7 @@ class DB:
async def __aenter__(self):
db = await aiosqlite.connect(self.db_path)
db.row_factory = aiosqlite.Row
await check_version(db)
if not self._init_once[self.db_path]:
for qs in self._init_queries[self.db_path]:
@ -56,13 +70,13 @@ def add_init_query(db_path: str, qs: str):
@lock_retry()
async def execute(db_path: str, qs: str, params: dict = {}):
async def execute(db_path: str, qs: str, params: dict | None = None):
async with DB(db_path) as db:
await db.execute(qs, params)
@lock_retry()
async def fetchone(db_path: str, qs: str, params: dict = {}):
async def fetchone(db_path: str, qs: str, params: dict | None = None):
async with DB(db_path) as db:
async with db.execute(qs, params) as cur:
row = await cur.fetchone()
@ -70,7 +84,7 @@ async def fetchone(db_path: str, qs: str, params: dict = {}):
@lock_retry()
async def fetchall(db_path: str, qs: str, params: dict = {}):
async def fetchall(db_path: str, qs: str, params: dict | None = None):
async with DB(db_path) as db:
async with db.execute(qs, params) as cur:
rows = await cur.fetchall()