插件-数据库

由于dataset是只支持同步操作,在性能方面有很大瓶颈,这里提供一个异步协程的插件。

pip install aiomysql
pip install nest_asyncio
import asyncio
import datetime
from types import NoneType

import aiomysql
import logging

import nest_asyncio

# 配置日志记录
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 定义类型映射(这个类型映射,是为了防止在进行操作的时候,没有建表,用于进行自动建表)
type_mapping = {
    int: 'INT',
    float: 'FLOAT',
    # str: 'VARCHAR(255)',
    str: 'TEXT',
    bool: 'BOOLEAN',
    datetime.datetime: 'DATETIME',
    datetime.date: 'DATE',
    None: 'TEXT',
    NoneType: 'TEXT'
}


async def create_pool():
    """
    创建数据库连接池
    """
    pool = await aiomysql.create_pool(
        host='localhost',
        port=3306,
        user='root',
        password='123456',
        db='zhihu',
        autocommit=True,  # 自动提交事务
        minsize=1,  # 连接池最小连接数
        maxsize=10  # 连接池最大连接数
    )
    return pool


async def close_pool(pool):
    """
    关闭连接池
    :param pool: 数据库连接池
    """
    pool.close()
    await pool.wait_closed()


async def exists_table(pool, table_name: str):
    """
    检测表是否存在
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :return: 表是否存在,True 表示存在,False 表示不存在
    """
    # 构建 SQL 检测表是否存在语句
    sql = f'''
        SELECT COUNT(*)
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_name = %s
    '''
    logging.info(sql % (table_name,))
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 检测表是否存在语句
            await cur.execute(sql, (table_name,))
            # 获取查询结果
            result = await cur.fetchone()
            return result[0] > 0  # 返回表是否存在


async def create_table(pool, table_name: str, data: dict):
    """
    创建表
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param data: 数据字典,键为字段名,值为字段类型
    """

    # 自动添加主键 id
    columns = ['id INT AUTO_INCREMENT PRIMARY KEY']

    # 根据 data 字典中的数据类型生成字段定义
    if data:
        for key, value in data.items():
            field_type = type_mapping.get(type(value))
            if field_type:
                columns.append(f"`{key}` {field_type}")
            else:
                raise ValueError(f"Unsupported data type for field '{key}': {type(value)}")
    # 自动添加创建时间和更新时间字段
    columns.append('created_at DATETIME DEFAULT CURRENT_TIMESTAMP')
    columns.append('updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP')

    # 构建 SQL 创建表语句
    sql = f'''
            CREATE TABLE `{table_name}` (
                {', '.join(columns)}
            )
        '''
    logging.info(sql)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 创建表语句
            await cur.execute(sql)


async def drop_table(pool, table_name: str):
    """
    删除表
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    """
    # 构建 SQL 删除表语句
    sql = f'''
        DROP TABLE IF EXISTS `{table_name}`
    '''
    logging.info(sql)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 删除表语句
            await cur.execute(sql)


async def truncate_table(pool, table_name: str):
    """
    截断表
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    """
    # 构建 SQL 截断表语句
    sql = f'''
        TRUNCATE TABLE `{table_name}`
    '''
    logging.info(sql)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 截断表语句
            await cur.execute(sql)


async def update_table(pool, table_name: str, data: dict):
    """
    更新表结构
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param data: 数据字典,键为字段名,值为字段类型
    """
    # 保留的字段
    reserved_fields = {'id', 'created_at', 'updated_at'}
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 获取当前表的字段信息
            await cur.execute(f"DESCRIBE `{table_name}`")
            current_fields = await cur.fetchall()
            current_fields_dict = {field[0]: field[1] for field in current_fields}

            # 构建 ALTER TABLE 语句
            alter_statements = []
            if data:
                for key, value in data.items():
                    field_type = type_mapping.get(type(value))
                    if field_type:
                        if key not in current_fields_dict and key not in reserved_fields:
                            alter_statements.append(f"ADD COLUMN `{key}` {field_type}")
                        elif key in current_fields_dict
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

文子阳

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值