更加优雅地在Python测试代码中查询数据库

作者: 潘峰 / 2022-08-30 / 分类: Work

MySQL, Python

本文记录如何更加优雅地在 Python 测试代码中查询数据库,即从

# "Before": raw PyMySQL — two nested ``with`` blocks plus a hand-written SQL
# string whose %s placeholder is filled from the args tuple.
with connection:
    with connection.cursor() as cursor:
        sql = "SELECT `id`, `password` FROM `users` WHERE `email`=%s"
        cursor.execute(sql, ('webmaster@python.org',))
        result = cursor.fetchone()

->

# "After": one chained expression that builds the query with pypika and runs it.
db.from_(_ := Table("users")).select(_.id, _.password).where(_.email.eq("test")).execute().fetchone()

的过程。

一、基于 PyMySQL 的原始写法

以下样例来自 PyMySQL 官方文档(经过简化):

import pymysql.cursors

# Open a connection; cursorclass=DictCursor makes fetched rows come back as dicts.
connection = pymysql.connect(host='localhost',
                             user='user',
                             password='passwd',
                             database='db',
                             cursorclass=pymysql.cursors.DictCursor)

# The outer ``with`` closes the connection, the inner one closes the cursor.
with connection:
    with connection.cursor() as cursor:
        # Raw SQL; the %s placeholder is filled by PyMySQL from the args tuple.
        sql = "SELECT `id`, `password` FROM `users` WHERE `email`=%s"
        cursor.execute(sql, ('webmaster@python.org',))
        result = cursor.fetchone()

从上面的样例来看,不难发现其存在两个小问题:

  1. 仅从查询部分来看,对于测试代码所面临的场景来说,代码有点冗长,需要两个 with 分别关闭 cursor 和 connection
  2. 需要手写裸 SQL,会使测试代码看起来不那么优雅

二、对 PyMySQL 进行二次封装的入手点

  1. 通过单行语句完成 DB 查询操作(考虑到测试代码的使用场景,可适当忽略语句执行性能)
  2. 可使用类 ORM 语法的 SQL 构建器构建 SQL 语句来替代手写裸 SQL

三、具体实现

1. 对 PyMySQL 进行二次封装

预计达成效果

db = DBEngine(host="localhost", username="user", password="passwd", database="db")

# execute() runs the statement and returns a _Query wrapper over the rows.
# The SQL text is sent verbatim (no placeholders), so never splice untrusted
# values into it. A plain string is enough here: there is nothing to interpolate.
db.execute(sql="SELECT `id`, `password` FROM `users` WHERE `email`='test'").fetchone()

代码实现

class DBError(Exception):
    """Error type raised by the database helpers, e.g. for missing connection options."""


def _raise_db_error(msg):
    """Raise ``DBError`` with *msg*; never returns.

    Because it is an ordinary call, it can sit at the end of an ``or`` chain
    (``value or _raise_db_error("...")``) to reject a missing value inline.
    """
    error = DBError(msg)
    raise error


class _Engine:
    """Thin wrapper that runs SQL through a lazily created ``PersistentDB`` pool.

    Connection options are captured at construction; the pool itself is only
    created on the first call to :meth:`create_conn` (via :meth:`execute`).
    """

    def __init__(self, **db_options):
        # Raw keyword options; validated lazily by ``create_conn``.
        self._db_options = db_options
        # Created on first use, reset to None by ``close_conn``.
        self._db_conn: Optional[PersistentDB] = None

    def create_conn(self, db_options, debug: bool = False):
        """Return a connection from the (lazily created) ``PersistentDB`` pool.

        Accepted option aliases: ``ip``/``host`` and ``username``/``user``;
        ``port`` defaults to 3306 and ``database`` is optional.

        Raises:
            DBError: via ``_raise_db_error`` when host, username or password
                is missing or empty.
        """
        host = db_options.get("ip") or db_options.get("host") or _raise_db_error("ip/host 必传")
        username = db_options.get("username") or db_options.get("user") or _raise_db_error("username 必传")
        password = db_options.get("password") or _raise_db_error("password 必传")
        port = db_options.get("port") or 3306
        database = db_options.get("database")
        if debug:
            # Only the first two characters of the password are logged.
            logger.debug(f"数据库连接信息:{username}:{password[:2]}??????@{host}:{port}")
        if not self._db_conn:
            # The options are baked into the pool only once; later calls
            # reuse the existing pool even if *db_options* differ.
            self._db_conn = PersistentDB(
                creator=pymysql, setsession=[], host=host, port=int(port),
                user=username, password=password, database=database
            )
        return self._db_conn.connection(shareable=False)

    def execute(self, sql, commit=True, debug=False):
        """Run *sql* and return all rows fetched through a ``DictCursor``.

        Args:
            sql: Complete SQL text. It is executed as-is (no parameter
                binding), so it must already be safely built.
            commit: When truthy, commit after executing and call
                ``conn.close()``. When falsy, the transaction is left open so
                the caller can finish it with :meth:`commit`.
            debug: Log the connection details and the SQL being executed.

        Returns:
            The value of ``cursor.fetchall()``.

        Raises:
            Whatever executing or committing raised, after the transaction
            has been rolled back and ``conn.close()`` has been called.
        """
        conn = self.create_conn(self._db_options, debug)
        if debug:
            logger.debug(f"开始执行 SQL:{sql}")
        failed = False
        try:
            with conn.cursor(cursor=pymysql.cursors.DictCursor) as cursor:
                cursor.execute(sql)
                query = cursor.fetchall()
            if commit:
                conn.commit()
            return query
        except Exception:
            failed = True
            conn.rollback()
            # Bare ``raise`` re-raises the original error with its traceback.
            raise
        finally:
            # On failure always release the connection; on success only when
            # the work was committed, since ``commit()`` still needs it otherwise.
            if self._db_conn and (failed or commit):
                conn.close()

    def commit(self):
        """Commit via the pool's dedicated connection; no-op when no pool exists."""
        if self._db_conn:
            self._db_conn.dedicated_connection().commit()

    def close_conn(self):
        """Close the pool's dedicated connection (if a pool exists) and drop the pool."""
        # NOTE(review): with DBUtils' default ``closeable=False`` this close()
        # may not physically close the socket — confirm against the DBUtils
        # settings in use.
        if self._db_conn:
            self._db_conn.dedicated_connection().close()
        self._db_conn = None


class _Query:
    """Read-only view over the rows returned by one executed statement.

    Mirrors the DB-API ``fetch*`` method names, but every row is already held
    in memory, so calls do not consume anything: each one reads from the
    start of the same sequence.
    """

    def __init__(self, res: "Sequence[dict]"):
        # ``res`` is a sequence of row mappings (it is indexed and sliced
        # below), not a single dict as the previous annotation claimed.
        self.res = res

    def fetchone(self):
        """Return the first row, or an empty dict when there are no rows."""
        return self.res[0] if self.res else {}

    def fetchmany(self, num):
        """Return at most the first *num* rows (same sequence type as ``res``)."""
        return self.res[:num]

    def fetchall(self):
        """Return every row (the very object passed to the constructor)."""
        return self.res


class DBEngine(_Engine):
    """Public engine: same as ``_Engine`` but returns results wrapped in ``_Query``."""

    def execute(self, sql, commit=True, debug=False):
        """Execute *sql* and hand back its rows as a ``_Query``."""
        return _Query(super().execute(sql, commit, debug))

2. 借助开源第三方库 pypika 进行 SQL 语句构建

from pypika import MySQLQuery, Table

# Build the SELECT with pypika instead of raw SQL; the walrus binds the Table
# to ``_`` so its columns can be referenced in the same expression.
sql = MySQLQuery.from_(_ := Table("users")).select(_.id, _.password).where(_.email.eq("test"))
# str() renders the query object to SQL text before handing it to execute().
db.execute(sql=str(sql)).fetchone()

3. 代码缝合

预计达成效果

# Target: build (pypika) and execute (DBEngine) in a single chained expression.
db.from_(_ := Table("users")).select(_.id, _.password).where(_.email.eq("test")).execute().fetchone()

代码实现

在上述封装的基础上,通过定制实例的属性查找过程(基于 __getattr__),为 DBEngine 增加链式调用的能力:DBEngine 上未定义的属性会被转交给 pypika 的 MySQLQuery 来构建查询,链条中遇到 execute() 时,再把生成的 SQL 字符串交给 DBEngine.execute 执行。

class InstanceGen:
    """Proxy that builds a chain of method calls on a wrapped builder object.

    Each ``gen.<name>(...)`` calls ``<name>`` on the wrapped instance and
    returns a *new* proxy around the result, so intermediate chains can be
    saved and reused without affecting each other. Calling
    ``gen.execute(...)`` ends the chain and returns
    ``trigger_method("execute", instance, *args, **kwargs)``.

    Args:
        trigger_method: Callable invoked when ``execute`` is called.
        method_name: Name of the method the next ``__call__`` will invoke.
        instance: The object the chain currently wraps.
    """

    def __init__(self, trigger_method, method_name, instance):
        self.__trigger_method = trigger_method
        self.__method_name = method_name
        self.__instance = instance

    def __getattr__(self, method_name):
        # Only reached for names normal lookup did not find. Private/dunder
        # names are refused so protocol probes (hasattr, copy, pickle) and
        # missing internal state raise AttributeError instead of being
        # mistaken for a chained builder method.
        if method_name.startswith("_"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {method_name!r}"
            )
        # Return a fresh proxy instead of mutating ``self``: a saved
        # intermediate (e.g. ``q = db.from_(t)``) must not change when another
        # attribute is looked up on it.
        return type(self)(self.__trigger_method, method_name, self.__instance)

    def __call__(self, *args, **kwargs):
        if self.__method_name == "execute":
            # Forward positional arguments too, so e.g. ``.execute(False)``
            # reaches the trigger instead of being silently dropped.
            return self.__trigger_method("execute", self.__instance, *args, **kwargs)
        result = getattr(self.__instance, self.__method_name)(*args, **kwargs)
        return type(self)(self.__trigger_method, self.__method_name, result)


class _Method:
    """Mixin that turns unknown public attribute access into a pypika query chain.

    ``obj.<name>`` (for a public name not otherwise defined on the object)
    starts an ``InstanceGen`` chain on ``MySQLQuery.<name>``. When the chain
    calls ``execute``, ``__trigger`` renders the built query with ``str()``
    and passes it to ``self.execute``.
    """

    def __getattr__(self, item):
        # ``__getattr__`` is only consulted for names normal lookup missed.
        # Refuse private/dunder names so protocol probes (hasattr, copy,
        # pickle) and missing internal state such as ``_db_conn`` raise
        # AttributeError instead of silently becoming a query chain.
        if item.startswith("_"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return InstanceGen(self.__trigger, item, instance=MySQLQuery)

    def __trigger(self, method, instance, *args, **kwargs):
        """Call ``self.<method>`` with the SQL text of *instance* plus any extra arguments."""
        return getattr(self, method)(str(instance), *args, **kwargs)


class DBEngine(_Engine, _Method):
    """Engine that accepts both raw SQL and pypika-style chained queries.

    ``_Method`` supplies the chain entry points (``db.from_(...)`` and so on);
    a chain finishes by calling :meth:`execute` below with the rendered SQL.
    """

    def execute(self, sql, commit=True, debug=False):
        """Execute *sql* and return its rows wrapped in a ``_Query``."""
        rows = super().execute(sql, commit, debug)
        wrapped = _Query(rows)
        return wrapped

四、基于 test-x 进行使用

# Install the package from PyPI.
pip install test-x
# NOTE(review): "test-x" is not a valid Python module name (a hyphen cannot
# appear in an import statement), so this line is a SyntaxError as written;
# confirm the package's actual top-level import name before copying it.
from test-x.database import DBEngine, Table

db = DBEngine(host="localhost", username="user", password="passwd", database="db")
# Same chained style as above; this example filters on ``name`` rather than ``email``.
db.from_(_ := Table("users")).select(_.id, _.password).where(_.name.eq("test")).execute().fetchone()