"""Thread that executes database operations.

SQL requests are pushed onto ``db_req_que`` and executed by the DB worker
thread (``db_process``) using a ``PooledDB`` connection pool. When
``g_sys_conf["platform"] == 0`` the pool connects through an SSH tunnel.
"""
import threading
import queue
import time
from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
import pymysql
from DBUtils.PooledDB import PooledDB
from concurrent.futures import ThreadPoolExecutor
import json
import shutil
import common.sys_comm as sys_comm
from common.sys_comm import LOGDBG, LOGINFO, LOGWARN, LOGERR
from common.sys_comm import get_bj_time_ms

# SSH configuration. These are fallback values only: db_pro_init() replaces
# them with the values from sys_comm.g_sys_conf.
# SECURITY(review): a real host and password are committed here. Prefer empty
# placeholders (db_pro_init loads the real config anyway) and rotate the
# exposed credentials.
ssh_conf = {
    "ssh_host": "119.45.12.173",
    "ssh_port": 22,
    "ssh_user": "root",
    "ssh_pwd": "Hfln@667788",
}

# Database configuration (fallback values; same security caveat as ssh_conf).
db_config = {
    "host": "localhost",
    "user": "root",
    "password": "Hfln@1024",
    "database": "lnxx_dev",
}

# ===================== Global objects =====================
# Request queue consumed by the DB worker thread; a None item asks it to stop.
db_req_que = queue.Queue()
# Active SSH tunnel (only created when platform == 0).
ssh_server = None
# Connection pool object.
db_pool = None
# Whether the DB worker thread is running.
db_worker_running = False


class DBRequest_Async:
    """A queued database request whose result is delivered only via callback.

    Args:
        sql: SQL statement, with placeholders understood by pymysql.
        params: Placeholder values; a falsy value is normalized to ``()``.
        callback: Optional ``callback(result, userdata)`` run after execution.
        userdata: Opaque value passed through to ``callback``.
    """

    def __init__(self, sql: str, params=None, callback=None, userdata=None):
        self.sql = sql
        self.params = params if params else ()
        self.callback = callback
        self.userdata = userdata


class DBRequest_Sync(DBRequest_Async):
    """A database request the submitting thread can block on via ``wait()``."""

    def __init__(self, sql: str, params=None, callback=None, userdata=None):
        super().__init__(sql, params, callback, userdata)
        self._done_event = threading.Event()
        self._result = None
        self._exception = None

    def wait(self, timeout=None):
        """Block until the request has been executed.

        Args:
            timeout: Seconds to wait; None waits forever.

        Returns:
            The result passed to ``set_result``.

        Raises:
            TimeoutError: if the request did not complete within ``timeout``.
                The request itself stays queued and may still run later.
            Exception: the exception passed to ``set_exception``.
        """
        if not self._done_event.wait(timeout):
            raise TimeoutError("DBRequest_Sync timed out")
        if self._exception is not None:
            raise self._exception
        return self._result

    def set_result(self, result):
        """Store a successful result and wake the waiter."""
        self._result = result
        self._done_event.set()

    def set_exception(self, e):
        """Store a failure and wake the waiter, which will re-raise it."""
        self._exception = e
        self._done_event.set()


# ========== Configuration init ==========
def db_pro_init():
    """Load SSH and DB settings from ``sys_comm.g_sys_conf`` into the module globals.

    The config is read under ``sys_comm.g_sys_conf_mtx``. The database name is
    not configurable and stays "lnxx_dev".
    """
    global ssh_conf, db_config
    with sys_comm.g_sys_conf_mtx:
        conf = sys_comm.g_sys_conf
        ssh_conf = {
            "ssh_host": conf["ssh_host"],
            "ssh_port": conf["ssh_port"],
            "ssh_user": conf["ssh_username"],
            "ssh_pwd": conf["ssh_password"],
        }
        db_config = {
            "host": conf["db_host"],
            "user": conf["db_username"],
            "password": conf["db_password"],
            "database": "lnxx_dev",
        }


# ========== SSH init ==========
def initialize_ssh_connection():
    """Start an SSH tunnel to the remote host's 127.0.0.1:3306 if none is active.

    Does nothing when the current tunnel is still active. A stale (inactive)
    tunnel is stopped before being replaced so its local bind is released.
    """
    global ssh_server
    if ssh_server is not None and ssh_server.is_active:
        return
    if ssh_server is not None:
        # Release the dead tunnel's forwarding resources before replacing it.
        try:
            ssh_server.stop()
        except Exception as e:
            LOGWARN(f"Failed to stop stale SSH tunnel: {e}")
    ssh_server = SSHTunnelForwarder(
        (ssh_conf["ssh_host"], ssh_conf["ssh_port"]),
        ssh_username=ssh_conf["ssh_user"],
        ssh_password=ssh_conf["ssh_pwd"],
        remote_bind_address=('127.0.0.1', 3306),
    )
    ssh_server.start()
    LOGINFO("SSH connected")


# ========== Connection pool init ==========
def initialize_connection_pool():
    """Create the global ``db_pool``.

    On platform 0 the pool connects through the SSH tunnel's local port;
    otherwise it connects directly to ``db_config["host"]:3306``.
    """
    global db_pool
    # Read under the same lock db_pro_init() uses for g_sys_conf.
    with sys_comm.g_sys_conf_mtx:
        platform = sys_comm.g_sys_conf["platform"]

    if platform == 0:
        initialize_ssh_connection()
        host, port = "127.0.0.1", ssh_server.local_bind_port
    else:
        host, port = db_config["host"], 3306

    db_pool = PooledDB(
        creator=pymysql,
        maxconnections=10,
        mincached=2,
        maxcached=5,
        blocking=True,
        host=host,
        port=port,
        user=db_config['user'],
        password=db_config['password'],
        database=db_config['database'],
        charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor,
    )
    LOGINFO("DB connection pool initialized")


# ========== Request execution ==========
def _extract_result(cursor, sql):
    """Shape the outcome of an executed statement (see handle_db_request)."""
    sql_lower = sql.strip().lower()
    # DB-API: description is non-None whenever the statement produced a result
    # set, which also covers WITH ... SELECT / SHOW / CALL, not just "select".
    if sql_lower.startswith("select") or cursor.description is not None:
        return cursor.fetchall()
    if sql_lower.startswith("insert"):
        return {"lastrowid": cursor.lastrowid}
    return {"rowcount": cursor.rowcount}


def _run_sql(db_request):
    """Execute the request's SQL on a pooled connection and commit it.

    Rolls back on failure (re-raising the error) and always returns the
    connection to the pool.
    """
    conn = db_pool.connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(db_request.sql, db_request.params)
            result = _extract_result(cursor, db_request.sql)
        conn.commit()
        return result
    except Exception:
        try:
            conn.rollback()
        except Exception as rb_err:
            LOGWARN(f"[DB WARN] rollback failed: {rb_err}, sql: {db_request.sql}")
        raise
    finally:
        try:
            conn.close()
        except Exception as close_err:
            LOGWARN(f"[DB WARN] failed to return connection: {close_err}")


def handle_db_request(db_request):
    """Execute one request and deliver its result.

    Result shape:
      * statements starting with "select", or any statement that produces a
        result set: the rows from ``fetchall()``;
      * INSERT: ``{"lastrowid": ...}``;
      * anything else: ``{"rowcount": ...}``.

    The transaction is committed before the callback runs and before a
    DBRequest_Sync waiter is released. Callers therefore never observe a
    result for a write that later fails to commit.

    On failure the error is logged and a DBRequest_Sync receives the exception.
    Callback exceptions are logged and do not prevent a sync waiter from
    receiving the result. The callback runs on whichever thread calls this
    function.
    """
    try:
        result = _run_sql(db_request)
    except Exception as e:
        LOGERR(f"[DB ERROR] SQL执行失败: {e}, sql: {db_request.sql}")
        if isinstance(db_request, DBRequest_Sync):
            db_request.set_exception(e)
        return

    if db_request.callback:
        try:
            db_request.callback(result, db_request.userdata)
        except Exception as e:
            LOGERR(f"[DB ERROR] 回调执行失败: {e}, sql: {db_request.sql}")

    if isinstance(db_request, DBRequest_Sync):
        db_request.set_result(result)


# ========== Public interface ==========
# Synchronous execution
def db_execute_sync(sql: str, params=None, callback=None, userdata=None, timeout=5):
    """Queue a request and block until the DB worker has executed it.

    If ``callback`` is given, it is invoked before this function returns;
    otherwise the result is simply returned. Passing ``timeout=None`` waits
    indefinitely (not recommended).

    NOTE: sync requests are executed inline on the DB worker thread, so a
    callback of a sync request must not itself call db_execute_sync. The
    worker would wait on itself until ``timeout`` expires.

    Returns:
        The result delivered by handle_db_request, or -1 if the DB worker is
        not running.

    Raises:
        TimeoutError: if the request did not finish within ``timeout`` seconds.
            The request stays queued and may still execute later.
        Exception: any error raised while executing the SQL.
    """
    if not db_worker_running:
        LOGERR("DB worker is not running, cannot execute sync request")
        return -1
    req = DBRequest_Sync(sql=sql, params=params, callback=callback, userdata=userdata)
    db_req_que.put(req)
    return req.wait(timeout=timeout)


# Asynchronous execution
def db_execute_async(sql: str, params=None, callback=None, userdata=None):
    """Queue a request without waiting for it.

    Args:
        callback: Optional ``callback(result, userdata)``, run on one of the
            worker's executor threads once the request has been handled.
        userdata: Optional value passed through to ``callback``.

    Returns:
        The queued DBRequest_Async, or None if the DB worker is not running.
    """
    if not db_worker_running:
        LOGERR("DB worker is not running, cannot execute async request")
        return None
    req = DBRequest_Async(sql=sql, params=params, callback=callback, userdata=userdata)
    db_req_que.put(req)
    return req


# ========== Main database thread ==========
def _release_db_resources():
    """Best-effort close of the connection pool and SSH tunnel after the worker exits."""
    global db_pool, ssh_server
    if db_pool is not None:
        try:
            db_pool.close()
        except Exception as e:
            LOGWARN(f"[DB WARN] failed to close connection pool: {e}")
        db_pool = None
    if ssh_server is not None:
        try:
            ssh_server.stop()
        except Exception as e:
            LOGWARN(f"[DB WARN] failed to stop SSH tunnel: {e}")
        ssh_server = None


def db_process():
    """Main loop of the DB worker thread.

    Loads the config and builds the connection pool, then consumes
    ``db_req_que`` until a None sentinel arrives:
      * DBRequest_Sync items are executed inline on this thread;
      * other requests are handed to an 8-thread executor.

    ``task_done()`` is called when a request has been dispatched. For async
    requests that is before they finish, so ``db_req_que.join()`` does not
    imply completion.

    On exit, or if initialization fails, ``db_worker_running`` is reset to
    False and the pool/tunnel are released. An initialization error is
    re-raised.
    """
    global db_worker_running
    db_worker_running = True

    try:
        db_pro_init()
        initialize_connection_pool()
    except Exception as e:
        # Without this reset the flag stays True forever and every sync caller
        # would silently wait for its timeout.
        db_worker_running = False
        LOGERR(f"[DB Thread Error] initialization failed: {e}")
        _release_db_resources()
        raise

    # Bounded pool for async requests, limiting concurrent DB work.
    async_executor = ThreadPoolExecutor(max_workers=8)
    try:
        while True:
            db_request = db_req_que.get()
            if db_request is None:
                # Balance the sentinel's get() so Queue.join() can return.
                db_req_que.task_done()
                break
            try:
                if isinstance(db_request, DBRequest_Sync):
                    handle_db_request(db_request)
                else:
                    async_executor.submit(handle_db_request, db_request)
            except Exception as e:
                LOGERR(f"[DB Thread Error] {e}, sql: {db_request.sql}")
            finally:
                db_req_que.task_done()
    finally:
        # Stop accepting new requests first, then drain in-flight async work.
        db_worker_running = False
        async_executor.shutdown(wait=True)
        _release_db_resources()
        LOGINFO("DB process exit gracefully")


# Thread created by create_db_process(); None until it has been called.
db_thread = None


# Create the database thread
def create_db_process():
    """Create (but do not start) the daemon DB worker thread and remember it."""
    global db_thread
    db_thread = threading.Thread(target=db_process, daemon=True)
    return db_thread


# Stop the database thread
def stop_db_process(timeout=None):
    """Ask the DB worker to exit and wait for it.

    Sends the None sentinel; requests queued after it are not processed.

    Args:
        timeout: Seconds to wait for the thread to exit; None waits forever.
    """
    if not db_worker_running:
        return
    db_req_que.put(None)
    # Joining is impossible without a recorded thread, or from the worker itself.
    if db_thread is not None and db_thread is not threading.current_thread():
        db_thread.join(timeout)
        if db_thread.is_alive():
            LOGWARN("DB worker did not stop within timeout")
            return
    LOGINFO("DB worker stopped")


# ========== Examples ==========
# Handle the results returned by the database
def cb_handle_device_data(results, userdata):
    """Example callback: log the rows returned by the query."""
    LOGDBG(f"Received results: {results}")


# Example request generator
def request_generator(interval=1):
    """Example producer: queue ``SELECT * FROM dev_info`` every ``interval`` seconds, forever.

    Requests are put on the queue directly, so they are queued even if the
    worker has not yet marked itself as running.
    """
    sql_query = "SELECT * FROM dev_info"  # example query
    while True:
        db_req_que.put(DBRequest_Async(sql=sql_query, callback=cb_handle_device_data))
        time.sleep(interval)


def test_main():
    """Example entry point: start the DB worker thread and the demo request producer."""
    # Go through create_db_process() so the module-level db_thread is set and
    # stop_db_process() can join it (a local variable would be invisible to it).
    create_db_process().start()
    # Non-daemon and infinite: keeps the process alive for the demo.
    request_gen_thread = threading.Thread(target=request_generator)
    request_gen_thread.start()