# NOTE: The synchronous dataset pipeline only supports blocking operations,
# which is a significant performance bottleneck. This module provides an
# async-coroutine alternative based on aiomysql.
#
# Dependencies:
#   pip install aiomysql
#   pip install nest_asyncio
import asyncio
import datetime
from types import NoneType
import aiomysql
import logging
import nest_asyncio
# 配置日志记录
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# 定义类型映射(这个类型映射,是为了防止在进行操作的时候,没有建表,用于进行自动建表)
type_mapping = {
int: 'INT',
float: 'FLOAT',
# str: 'VARCHAR(255)',
str: 'TEXT',
bool: 'BOOLEAN',
datetime.datetime: 'DATETIME',
datetime.date: 'DATE',
None: 'TEXT',
NoneType: 'TEXT'
}
async def create_pool():
"""
创建数据库连接池
"""
pool = await aiomysql.create_pool(
host='localhost',
port=3306,
user='root',
password='123456',
db='zhihu',
autocommit=True, # 自动提交事务
minsize=1, # 连接池最小连接数
maxsize=10 # 连接池最大连接数
)
return pool
async def close_pool(pool):
"""
关闭连接池
:param pool: 数据库连接池
"""
pool.close()
await pool.wait_closed()
async def exists_table(pool, table_name: str):
"""
检测表是否存在
:param pool: 数据库连接池
:param table_name: 数据库表名
:return: 表是否存在,True 表示存在,False 表示不存在
"""
# 构建 SQL 检测表是否存在语句
sql = f'''
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = %s
'''
logging.info(sql % (table_name,))
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 检测表是否存在语句
await cur.execute(sql, (table_name,))
# 获取查询结果
result = await cur.fetchone()
return result[0] > 0 # 返回表是否存在
async def create_table(pool, table_name: str, data: dict):
"""
创建表
:param pool: 数据库连接池
:param table_name: 数据库表名
:param data: 数据字典,键为字段名,值为字段类型
"""
# 自动添加主键 id
columns = ['id INT AUTO_INCREMENT PRIMARY KEY']
# 根据 data 字典中的数据类型生成字段定义
if data:
for key, value in data.items():
field_type = type_mapping.get(type(value))
if field_type:
columns.append(f"`{key}` {field_type}")
else:
raise ValueError(f"Unsupported data type for field '{key}': {type(value)}")
# 自动添加创建时间和更新时间字段
columns.append('created_at DATETIME DEFAULT CURRENT_TIMESTAMP')
columns.append('updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP')
# 构建 SQL 创建表语句
sql = f'''
CREATE TABLE `{table_name}` (
{', '.join(columns)}
)
'''
logging.info(sql)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 创建表语句
await cur.execute(sql)
async def drop_table(pool, table_name: str):
"""
删除表
:param pool: 数据库连接池
:param table_name: 数据库表名
"""
# 构建 SQL 删除表语句
sql = f'''
DROP TABLE IF EXISTS `{table_name}`
'''
logging.info(sql)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 删除表语句
await cur.execute(sql)
async def truncate_table(pool, table_name: str):
"""
截断表
:param pool: 数据库连接池
:param table_name: 数据库表名
"""
# 构建 SQL 截断表语句
sql = f'''
TRUNCATE TABLE `{table_name}`
'''
logging.info(sql)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 截断表语句
await cur.execute(sql)
async def update_table(pool, table_name: str, data: dict):
"""
更新表结构
:param pool: 数据库连接池
:param table_name: 数据库表名
:param data: 数据字典,键为字段名,值为字段类型
"""
# 保留的字段
reserved_fields = {'id', 'created_at', 'updated_at'}
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 获取当前表的字段信息
await cur.execute(f"DESCRIBE `{table_name}`")
current_fields = await cur.fetchall()
current_fields_dict = {field[0]: field[1] for field in current_fields}
# 构建 ALTER TABLE 语句
alter_statements = []
if data:
for key, value in data.items():
field_type = type_mapping.get(type(value))
if field_type:
if key not in current_fields_dict and key not in reserved_fields:
alter_statements.append(f"ADD COLUMN `{key}` {field_type}")
elif key in current_fields_dict