|
- '''
- 执行数据库操作的线程
- '''
- import threading
- import queue
- import time
- from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
- import pymysql
- from DBUtils.PooledDB import PooledDB
- from concurrent.futures import ThreadPoolExecutor
- import json
- import shutil
- import g_config
- import common.sys_comm as sys_comm
- from common.sys_comm import (
- LOGDBG, LOGINFO, LOGWARN, LOGERR, EC
- )
- from common.sys_comm import get_bj_time_ms
# SSH configuration.
# SECURITY NOTE(review): this dict hard-codes a host, a root login and a
# plaintext password in source control. It appears unused in this module
# (initialize_ssh_connection reads the "service" section of g_config instead),
# but other modules may import it - verify, then move the secret into
# configuration/environment and rotate the password.
ssh_conf = {
    "ssh_host": "119.45.12.173",
    "ssh_port": 22,
    "ssh_user": "root",
    "ssh_pwd": "Hfln@147888",
}
# Config sections copied from g_config.g_sys_conf by db_pro_init().
service = {}
db = {}
# ===================== Global objects =====================
# Request queue consumed by the DB worker thread.
db_req_que = queue.Queue()
# SSH tunnel object (only used when platform == 0).
ssh_server = None
# Database connection pool.
db_pool = None
# Flag: is the DB worker thread running?
db_worker_running = False
# Worker thread created by create_db_process(). Initialised here so that
# stop_db_process() does not hit a NameError when create_db_process() was
# never called.
db_thread = None
# Database request types.
class DBRequest_Async:
    """One SQL statement queued for the DB worker, with an optional callback.

    Attributes:
        sql: SQL text handed to ``cursor.execute``.
        params: bound parameters; any falsy value (None, [], {}) becomes ``()``.
        callback: optional ``callback(result, userdata)``, run after the
            statement succeeds.
        userdata: opaque value forwarded unchanged to ``callback``.
    """

    def __init__(self, sql: str, params=None, callback=None, userdata=None):
        self.sql = sql
        self.params = params or ()
        self.callback = callback
        self.userdata = userdata
class DBRequest_Sync(DBRequest_Async):
    """A DB request whose submitter blocks in ``wait()`` for the outcome.

    The DB worker completes it by calling ``set_result()`` or
    ``set_exception()``; either call wakes any thread blocked in ``wait()``.
    """

    def __init__(self, sql: str, params=None, callback=None, userdata=None):
        super().__init__(sql, params, callback, userdata)
        self._done_event = threading.Event()
        self._result = None
        self._exception = None

    def wait(self, timeout=None):
        """Block until the request is completed and return its result.

        Args:
            timeout: maximum seconds to wait; ``None`` waits forever.

        Returns:
            The value passed to ``set_result()``.

        Raises:
            TimeoutError: the request was not completed within ``timeout``.
            Exception: whatever was passed to ``set_exception()``.
        """
        if not self._done_event.wait(timeout):
            raise TimeoutError("DBRequest_Sync timed out")
        pending_error = self._exception
        if pending_error:
            raise pending_error
        return self._result

    def set_result(self, result):
        """Store ``result`` and release the waiter (value is stored first)."""
        self._result = result
        self._done_event.set()

    def set_exception(self, e):
        """Store exception ``e`` and release the waiter (value is stored first)."""
        self._exception = e
        self._done_event.set()
# ========== Config initialisation ==========
def db_pro_init():
    """Copy the "service" and "db" config sections into the module globals.

    The lookups are done while holding ``g_config.g_sys_conf_mtx``. A missing
    section raises ``KeyError``.
    """
    global service, db
    with g_config.g_sys_conf_mtx:
        sys_conf = g_config.g_sys_conf
        service = sys_conf["service"]
        db = sys_conf["db"]
# ========== SSH initialisation ==========
def initialize_ssh_connection():
    """Make sure the module-level SSH tunnel ``ssh_server`` is active.

    If there is no tunnel yet, or the existing one is no longer active, a new
    ``SSHTunnelForwarder`` is created and started. It is built from the
    "service" section of ``g_config.g_sys_conf`` (``ip``, ``username``,
    ``password``). It connects to port 22 of that host and forwards a local
    port to ``localhost:3306`` on the remote side. Callers then read the local
    port from ``ssh_server.local_bind_port``.

    Exceptions raised while creating or starting the tunnel propagate to the
    caller.
    """
    global ssh_server
    if ssh_server is not None and ssh_server.is_active:
        return
    if ssh_server is not None:
        # The old forwarder is about to be replaced. Stop it first so any
        # resources it may still hold are released instead of leaking.
        # This is best effort: a failure here must not block reconnecting.
        try:
            ssh_server.stop()
        except Exception as e:
            LOGWARN(f"Failed to stop inactive SSH tunnel: {e}")
    # Copy the credentials under the config mutex. The network work
    # (start) then runs without holding the lock.
    with g_config.g_sys_conf_mtx:
        service_conf = g_config.g_sys_conf["service"]
        ssh_host = service_conf["ip"]
        ssh_user = service_conf["username"]
        ssh_password = service_conf["password"]
    ssh_server = SSHTunnelForwarder(
        (ssh_host, 22),
        ssh_username=ssh_user,
        ssh_password=ssh_password,
        remote_bind_address=('localhost', 3306),
    )
    ssh_server.start()
    LOGINFO("SSH connected")
# ========== Connection pool initialisation ==========
def initialize_connection_pool():
    """Create the module-level ``db_pool`` (DBUtils ``PooledDB`` of pymysql connections).

    If ``g_sys_conf["platform"] == 0``, connections go through the SSH tunnel
    (see ``initialize_ssh_connection``) to ``localhost:<tunnel local port>``.
    Otherwise they go directly to ``db["host"]:3306``.

    Requires ``db_pro_init()`` to have filled the module-level ``db`` dict
    (``host``, ``username``, ``password``, ``database``). Rows are returned as
    dicts (``DictCursor``).
    """
    global db_pool
    # Read the platform flag under the same mutex that every other
    # g_sys_conf reader in this module uses.
    with g_config.g_sys_conf_mtx:
        platform = g_config.g_sys_conf["platform"]
    if platform == 0:
        initialize_ssh_connection()
        host = "localhost"
        port = ssh_server.local_bind_port
    else:
        host = db["host"]
        port = 3306
    db_pool = PooledDB(
        creator=pymysql,
        maxconnections=10,
        mincached=2,
        maxcached=5,
        blocking=True,
        host=host,
        port=port,
        user=db['username'],
        password=db['password'],
        database=db['database'],
        charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor,
    )
    LOGINFO("DB connection pool initialized")
# ========== Execute one database request ==========
def handle_db_request(db_request):
    """Execute one queued request on a pooled connection and report the outcome.

    Success path:
        The statement is executed and committed, and the connection is
        returned to the pool. Only then is the optional
        ``callback(result, userdata)`` run. For ``DBRequest_Sync`` the
        waiting submitter is then released via ``set_result()``. Callback
        errors are logged and do not affect the sync result.

    Failure path (including a failed commit):
        The transaction is rolled back (best effort) and the error is logged.
        For ``DBRequest_Sync`` the error is forwarded via ``set_exception()``.
        The callback is not called.

    Result value:
        * A statement that produced a result set (SELECT, SHOW, WITH ... SELECT,
          ...) yields the rows from ``fetchall()``. They are dicts, because the
          pool uses ``DictCursor``.
        * INSERT yields ``{"lastrowid": ...}``.
        * Anything else yields ``{"rowcount": ...}``.
    """
    conn = None
    try:
        conn = db_pool.connection()
        with conn.cursor() as cursor:
            cursor.execute(db_request.sql, db_request.params)
            # Per DB-API, ``description`` is non-None exactly when the
            # statement produced a result set. This is more reliable than
            # matching the SQL prefix.
            if cursor.description is not None:
                result = cursor.fetchall()
            elif db_request.sql.strip().lower().startswith("insert"):
                result = {"lastrowid": cursor.lastrowid}
            else:
                result = {"rowcount": cursor.rowcount}
        # Commit before anybody is told about the result. Otherwise a
        # failing commit would be reported after the waiter had already been
        # released with a "successful" result.
        conn.commit()
    except Exception as e:
        LOGERR(f"[DB ERROR] SQL执行失败: {e}, sql: {db_request.sql}")
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                LOGWARN(f"[DB ERROR] rollback failed: {rollback_err}, sql: {db_request.sql}")
        if isinstance(db_request, DBRequest_Sync):
            db_request.set_exception(e)
        return
    finally:
        # Return the connection to the pool before running user callbacks.
        if conn is not None:
            conn.close()

    if db_request.callback:
        try:
            db_request.callback(result, db_request.userdata)
        except Exception as e:
            LOGERR(f"[DB ERROR] 回调执行失败: {e}, sql: {db_request.sql}")
    if isinstance(db_request, DBRequest_Sync):
        db_request.set_result(result)
# ========== Public API ==========
# Synchronous execution
def db_execute_sync(sql: str, params=None, callback=None, userdata=None, timeout=5):
    """Run ``sql`` on the DB worker thread and block until it has finished.

    If ``callback`` is given, it is invoked as ``callback(result, userdata)``
    before this call returns the result. Without a callback the result is
    simply returned.

    Args:
        sql: SQL text.
        params: bound parameters for ``sql``.
        callback: optional completion callback.
        userdata: passed through to ``callback``.
        timeout: seconds to wait. ``None`` waits indefinitely, which is not
            recommended.

    Returns:
        The statement result, or ``EC.EC_FAILED`` if the DB worker is not
        running.

    Raises:
        TimeoutError: no outcome within ``timeout``. The request stays queued
            and may still be executed later.
        Exception: any error raised while executing the statement.

    NOTE: sync requests are executed on the DB worker thread itself. Calling
    this from code that runs on that thread (e.g. a sync request's callback)
    will block until the timeout.
    """
    if db_worker_running:
        request = DBRequest_Sync(sql=sql, params=params, callback=callback, userdata=userdata)
        db_req_que.put(request)
        return request.wait(timeout=timeout)
    LOGERR("DB worker is not running, cannot execute sync request")
    return EC.EC_FAILED
# Asynchronous execution
def db_execute_async(sql: str, params=None, callback=None, userdata=None):
    """Queue ``sql`` for background execution and return immediately.

    Args:
        sql: SQL text.
        params: bound parameters for ``sql``.
        callback: optional; called as ``callback(result, userdata)`` once the
            statement has completed successfully.
        userdata: optional; passed through to ``callback``.

    Returns:
        The queued ``DBRequest_Async``, or ``None`` if the DB worker is not
        running.
    """
    if db_worker_running:
        pending = DBRequest_Async(sql=sql, params=params, callback=callback, userdata=userdata)
        db_req_que.put(pending)
        return pending
    LOGERR("DB worker is not running, cannot execute async request")
    return None
# ========== Main DB worker thread ==========
def db_process():
    """Main loop of the DB worker thread.

    Steps:
        1. Load config and build the connection pool.
        2. Dispatch queued requests until a ``None`` sentinel arrives.
           ``DBRequest_Sync`` requests run inline on this thread. All others
           are submitted to an 8-thread executor.

    On exit, whether normal or caused by an exception, the function:
        * waits for in-flight async requests to finish;
        * clears ``db_worker_running``, so producers stop enqueueing into a
          queue nobody reads.

    Unexpected exceptions are logged and re-raised.
    """
    global db_worker_running
    # Set before initialisation so requests submitted during startup are
    # queued rather than rejected.
    db_worker_running = True
    async_executor = None
    try:
        db_pro_init()
        initialize_connection_pool()
        async_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="AsyncDBWorker")
        while True:
            db_request = db_req_que.get()
            if db_request is None:
                # Count the sentinel too, so db_req_que.join() cannot hang.
                db_req_que.task_done()
                break
            try:
                if isinstance(db_request, DBRequest_Sync):
                    # Synchronous request: execute on this thread.
                    handle_db_request(db_request)
                else:
                    # Asynchronous request: hand off to the executor.
                    async_executor.submit(handle_db_request, db_request)
            except Exception as e:
                LOGERR(f"[DB Thread Error] {e}, sql: {db_request.sql}")
            finally:
                db_req_que.task_done()
    except Exception as e:
        LOGERR(f"[DB Thread Error] DB worker aborted: {e}")
        raise
    finally:
        # Exit signal received (or failure): shut the executor down.
        if async_executor is not None:
            async_executor.shutdown(wait=True)
        db_worker_running = False
        LOGINFO("DB process exit gracefully")
# Create the DB worker thread
def create_db_process():
    """Create, but do not start, the daemon DB worker thread.

    The thread is also stored in the module-level ``db_thread`` so that
    ``stop_db_process()`` can join it.

    Returns:
        The unstarted ``threading.Thread`` named "DBWorkerThread" that runs
        ``db_process``.
    """
    global db_thread
    worker = threading.Thread(
        target=db_process,
        name="DBWorkerThread",
        daemon=True,
    )
    db_thread = worker
    return worker
# Stop the DB worker thread
def stop_db_process():
    """Ask the DB worker to exit and wait for it to finish.

    Does nothing if the worker is not running. The ``None`` sentinel is
    queued behind any pending requests, so those are dispatched first.

    The thread is joined only if it was created via ``create_db_process()``,
    is alive, and is not the calling thread. Joining the current thread or an
    unstarted thread would raise ``RuntimeError``.
    """
    if not db_worker_running:
        return
    db_req_que.put(None)
    # ``db_thread`` only exists once create_db_process() has run. A worker
    # started some other way cannot be joined from here.
    worker = globals().get("db_thread")
    if (
        worker is not None
        and worker.is_alive()
        and worker is not threading.current_thread()
    ):
        worker.join()
    LOGINFO("DB worker stopped")
# ========== Examples ==========
# Handle the results returned by the database
def cb_handle_device_data(results, userdata):
    """Example callback: log the result delivered by ``handle_db_request``.

    Args:
        results: the statement result (rows for a SELECT).
        userdata: unused.
    """
    # f-string prefix was missing, so the literal text "{results}" was logged.
    LOGDBG(f"Received results: {results}")
# Example request generator
def request_generator():
    """Example producer: enqueue ``SELECT * FROM dev_info`` once per second, forever.

    Requests are put straight onto ``db_req_que``, bypassing
    ``db_execute_async``'s running check. ``cb_handle_device_data`` is used
    as the callback.
    """
    example_sql = "SELECT * FROM dev_info"  # example query
    while True:
        db_req_que.put(DBRequest_Async(sql=example_sql, callback=cb_handle_device_data))
        time.sleep(1)  # one request per second
def test_main():
    """Example entry point: start the DB worker and the example request producer.

    The worker is created through ``create_db_process()``, so the
    module-level ``db_thread`` is set and ``stop_db_process()`` can join it.
    The previous private local thread could never be joined.

    The producer thread is non-daemon and loops forever, so the process keeps
    running.
    """
    # Start the DB worker thread.
    worker = create_db_process()
    worker.start()
    # Start the request generator.
    producer = threading.Thread(target=request_generator)
    producer.start()
|