db_process.py 7.7 KB


  1. '''
  2. 执行数据库操作的线程
  3. '''
  4. import threading
  5. import queue
  6. import time
  7. from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
  8. import pymysql
  9. from DBUtils.PooledDB import PooledDB
  10. from concurrent.futures import ThreadPoolExecutor
  11. import json
  12. import shutil
  13. import g_config
  14. import common.sys_comm as sys_comm
  15. from common.sys_comm import (
  16. LOGDBG, LOGINFO, LOGWARN, LOGERR, EC
  17. )
  18. from common.sys_comm import get_bj_time_ms
  # SSH settings. Credentials must not live in source control, so every value is
  # read from the environment (DB_SSH_HOST, DB_SSH_PORT, DB_SSH_USER, DB_SSH_PWD).
  # NOTE(review): nothing visible in this module reads ssh_conf -- the tunnel in
  # initialize_ssh_connection() takes its settings from g_config's "service"
  # section -- so confirm whether this dict can be removed. The root password that
  # used to be hard-coded here should be treated as leaked and rotated.
  import os

  ssh_conf = {
      "ssh_host": os.environ.get("DB_SSH_HOST", ""),
      "ssh_port": int(os.environ.get("DB_SSH_PORT", "22")),
      "ssh_user": os.environ.get("DB_SSH_USER", ""),
      "ssh_pwd": os.environ.get("DB_SSH_PWD", ""),
  }
  # "service" section of g_config.g_sys_conf, filled in by db_pro_init().
  service = {}
  # "db" section of g_config.g_sys_conf, filled in by db_pro_init().
  db = {}
  # ===================== Global objects =====================
  # Request queue consumed by the DB worker thread; a None item asks it to exit.
  db_req_que = queue.Queue()
  # SSH tunnel used to reach the database when platform == 0.
  ssh_server = None
  # Connection pool, created by initialize_connection_pool().
  db_pool = None
  # True while the DB worker thread (db_process) is running.
  db_worker_running = False
  # Worker thread created by create_db_process(). Declared up front so code that
  # reads it (e.g. stop_db_process) never hits a NameError before it is created.
  db_thread = None
  # Database request types.
  class DBRequest_Async:
      """One SQL statement queued for the DB worker (fire-and-forget).

      Attributes:
          sql: statement text handed to cursor.execute().
          params: bind parameters; any falsy value (None, [], {}) is stored as ().
          callback: optional callable(result, userdata) the worker invokes after
              the statement has executed successfully.
          userdata: opaque value passed back to callback.
      """

      def __init__(self, sql: str, params=None, callback=None, userdata=None):
          self.sql = sql
          self.params = params or ()
          self.callback = callback
          self.userdata = userdata
  class DBRequest_Sync(DBRequest_Async):
      """A request whose submitter blocks in wait() until the worker reports back.

      The worker reports the outcome through set_result() or set_exception();
      wait() then returns the result or re-raises the reported exception.
      """

      def __init__(self, sql: str, params=None, callback=None, userdata=None):
          super().__init__(sql, params, callback, userdata)
          # Outcome slots; filled before _done_event is set.
          self._result = None
          self._exception = None
          self._done_event = threading.Event()

      def wait(self, timeout=None):
          """Block until the request completes and return its result.

          Args:
              timeout: seconds to wait, or None to wait indefinitely.

          Raises:
              TimeoutError: the request did not complete in time. Timing out does
                  not cancel it; the worker may still execute it later.
              Exception: whatever the worker passed to set_exception().
          """
          completed = self._done_event.wait(timeout)
          if completed:
              if self._exception:
                  raise self._exception
              return self._result
          raise TimeoutError("DBRequest_Sync timed out")

      def set_result(self, result):
          """Record a successful result and wake the waiter."""
          self._publish("_result", result)

      def set_exception(self, e):
          """Record a failure and wake the waiter."""
          self._publish("_exception", e)

      def _publish(self, slot, value):
          # Store the outcome before setting the event so the woken waiter sees it.
          setattr(self, slot, value)
          self._done_event.set()
  # ========== Configuration snapshot ==========
  def db_pro_init():
      """Load the "service" and "db" sections of g_config.g_sys_conf into module globals.

      The lookups happen while holding g_config.g_sys_conf_mtx. The sections are
      stored by reference (no copy is made).
      """
      global service, db
      with g_config.g_sys_conf_mtx:
          sys_conf = g_config.g_sys_conf
          service = sys_conf["service"]
          db = sys_conf["db"]
  # ========== SSH tunnel ==========
  def initialize_ssh_connection(ssh_port=22, remote_bind_address=("localhost", 3306)):
      """Make sure the module-global ``ssh_server`` holds an active SSH tunnel.

      An already-active tunnel is reused. Otherwise any stale forwarder is stopped
      and a new SSHTunnelForwarder is started to the host in g_config's "service"
      section (keys "ip", "username", "password"), forwarding to
      ``remote_bind_address`` on that host. Callers read the local end of the
      tunnel via ``ssh_server.local_bind_port``.

      Args:
          ssh_port: SSH port on the service host (default 22).
          remote_bind_address: (host, port) of MySQL as seen from the service
              host (default ("localhost", 3306)).

      Raises:
          Whatever SSHTunnelForwarder.start() raises when the tunnel cannot be
          opened; ``ssh_server`` is left as None in that case.
      """
      global ssh_server
      if ssh_server is not None and ssh_server.is_active:
          return

      # Only copy the settings under the config mutex; the SSH handshake below
      # may block on the network and should not hold up other config readers.
      with g_config.g_sys_conf_mtx:
          service = g_config.g_sys_conf["service"]
          ssh_host = service["ip"]
          ssh_user = service["username"]
          ssh_password = service["password"]

      if ssh_server is not None:
          # The previous tunnel is no longer active; release its resources before
          # replacing it instead of just dropping the reference.
          try:
              ssh_server.stop()
          except Exception as e:
              LOGWARN(f"Failed to stop stale SSH tunnel: {e}")
          ssh_server = None

      tunnel = SSHTunnelForwarder(
          (ssh_host, ssh_port),
          ssh_username=ssh_user,
          ssh_password=ssh_password,
          remote_bind_address=remote_bind_address,
      )
      tunnel.start()
      # Publish only a started tunnel, so a failed start leaves ssh_server None.
      ssh_server = tunnel
      LOGINFO("SSH connected")
  # ========== Connection pool ==========
  def initialize_connection_pool(maxconnections=10, mincached=2, maxcached=5):
      """Create the module-global ``db_pool`` (DBUtils PooledDB over pymysql).

      When g_config.g_sys_conf["platform"] == 0, the database is reached through
      the SSH tunnel. initialize_ssh_connection() is called and the pool connects
      to localhost on the tunnel's local port. On any other platform it connects
      directly to db["host"] on port 3306. Rows are returned as dicts
      (DictCursor) and the connection charset is utf8mb4.

      Requires db_pro_init() to have filled the module-level ``db`` dict (keys
      "username", "password", "database", plus "host" when not on platform 0).

      Args:
          maxconnections, mincached, maxcached: PooledDB sizing. The defaults are
              the values this module has always used. With blocking=True, threads
              wait for a free connection once maxconnections are in use. Keep it
              at least the number of query threads: the async executor in
              db_process plus the dispatcher thread.
      """
      global db_pool
      # Every other g_sys_conf read in this module holds the config mutex.
      with g_config.g_sys_conf_mtx:
          platform_id = g_config.g_sys_conf["platform"]

      if platform_id == 0:
          initialize_ssh_connection()
          host, port = "localhost", ssh_server.local_bind_port
      else:
          host, port = db["host"], 3306

      db_pool = PooledDB(
          creator=pymysql,
          maxconnections=maxconnections,
          mincached=mincached,
          maxcached=maxcached,
          blocking=True,
          host=host,
          port=port,
          user=db["username"],
          password=db["password"],
          database=db["database"],
          charset="utf8mb4",
          cursorclass=pymysql.cursors.DictCursor,
      )
      LOGINFO("DB connection pool initialized")
  # ========== Execute a database request ==========
  def _collect_result(cursor, sql):
      """Build handle_db_request's result value from an executed cursor."""
      head = sql.strip().lower()
      if head.startswith("select"):
          return cursor.fetchall()
      if head.startswith("insert"):
          return {"lastrowid": cursor.lastrowid}
      # DB-API: description is non-None only when the statement produced a result
      # set. This keeps the rows of SHOW / WITH ... SELECT / CALL etc., which would
      # otherwise fall through to rowcount and be lost.
      if cursor.description is not None:
          return cursor.fetchall()
      return {"rowcount": cursor.rowcount}


  def handle_db_request(db_request):
      """Execute one request on a pooled connection, commit, then report the outcome.

      Result value:
        * SELECT, or any other statement that returns a result set:
          cursor.fetchall()
        * INSERT: {"lastrowid": cursor.lastrowid}
        * anything else: {"rowcount": cursor.rowcount}

      The transaction is committed and the connection returned to the pool before
      the callback runs and before a DBRequest_Sync waiter is woken. So neither
      is told "success" for a statement whose commit then fails.

      On a database error the error is logged and a DBRequest_Sync receives it
      via set_exception(); the callback is not called. Callback exceptions are
      logged and do not change the result delivered to a sync waiter.
      """
      conn = None
      try:
          conn = db_pool.connection()
          with conn.cursor() as cursor:
              cursor.execute(db_request.sql, db_request.params)
              result = _collect_result(cursor, db_request.sql)
          conn.commit()
      except Exception as e:
          LOGERR(f"[DB ERROR] SQL执行失败: {e}, sql: {db_request.sql}")
          if isinstance(db_request, DBRequest_Sync):
              db_request.set_exception(e)
          return
      finally:
          if conn is not None:
              # Delivery happens after this point, so a failing close must not
              # prevent the waiter from being woken.
              try:
                  conn.close()
              except Exception as e:
                  LOGWARN(f"[DB WARN] failed to return connection to pool: {e}, sql: {db_request.sql}")

      if db_request.callback:
          try:
              db_request.callback(result, db_request.userdata)
          except Exception as e:
              LOGERR(f"[DB ERROR] 回调执行失败: {e}, sql: {db_request.sql}")
      if isinstance(db_request, DBRequest_Sync):
          db_request.set_result(result)
  # ========== Public API ==========
  # Synchronous execution.
  def db_execute_sync(sql: str, params=None, callback=None, userdata=None, timeout=5):
      """Queue a statement and block until the DB worker has executed it.

      If callback is given, the worker calls callback(result, userdata) first and
      this function then returns the same result.

      Args:
          sql, params: statement and bind parameters for cursor.execute().
          callback, userdata: optional completion hook and its extra argument.
          timeout: seconds to wait; None waits indefinitely (not recommended).

      Returns:
          The statement result, or EC.EC_FAILED if the worker is not running.

      Raises:
          TimeoutError: the worker did not finish in time. The request is not
              withdrawn from the queue and may still execute afterwards.
          Exception: any error the worker reported while executing the statement.

      Sync requests run on the worker's own dispatch thread. Calling this from the
      callback of another sync request therefore waits on itself until the
      timeout expires.
      """
      if db_worker_running:
          request = DBRequest_Sync(sql=sql, params=params, callback=callback, userdata=userdata)
          db_req_que.put(request)
          return request.wait(timeout=timeout)
      LOGERR("DB worker is not running, cannot execute sync request")
      return EC.EC_FAILED
  # Asynchronous execution.
  def db_execute_async(sql: str, params=None, callback=None, userdata=None):
      """Queue a statement for background execution and return immediately.

      Args:
          sql, params: statement and bind parameters for cursor.execute().
          callback: optional callable(result, userdata); the worker calls it only
              after the statement executes successfully.
          userdata: optional value passed through to callback.

      Returns:
          The queued DBRequest_Async, or None if the worker is not running.
      """
      if db_worker_running:
          pending = DBRequest_Async(sql=sql, params=params, callback=callback, userdata=userdata)
          db_req_que.put(pending)
          return pending
      LOGERR("DB worker is not running, cannot execute async request")
      return None
  # ========== DB worker thread ==========
  def _drain_pending_requests():
      """Empty db_req_que after the worker stops, failing sync requests so no caller waits forever."""
      while True:
          try:
              pending = db_req_que.get_nowait()
          except queue.Empty:
              return
          try:
              if pending is None:
                  # Stray shutdown sentinel; nothing to report.
                  continue
              LOGWARN(f"[DB Thread] worker stopped, dropping request, sql: {pending.sql}")
              if isinstance(pending, DBRequest_Sync):
                  pending.set_exception(RuntimeError("DB worker is not running"))
          finally:
              db_req_que.task_done()


  def db_process():
      """Body of the DB worker thread.

      Loads config (db_pro_init), builds the connection pool, then consumes
      db_req_que until it receives the ``None`` sentinel:
        * DBRequest_Sync items run inline on this thread, one at a time;
        * other requests are submitted to an 8-thread executor.

      On exit (normal shutdown, or an exception during init or dispatch) it does
      three things: clears db_worker_running, waits for in-flight async work, and
      fails whatever is still queued. Sync callers are thus released instead of
      blocking until their timeout (or forever with timeout=None). Exceptions are
      logged and re-raised.
      """
      global db_worker_running
      # Set before init so requests submitted while the pool is being built are
      # queued rather than rejected by db_execute_sync/db_execute_async.
      db_worker_running = True
      async_executor = None
      try:
          db_pro_init()
          initialize_connection_pool()
          async_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="AsyncDBWorker")
          while True:
              db_request = db_req_que.get()
              if db_request is None:
                  # Account for the sentinel too, so db_req_que.join() can return.
                  db_req_que.task_done()
                  break
              try:
                  if isinstance(db_request, DBRequest_Sync):
                      # Synchronous: run here; the caller is blocked in wait().
                      handle_db_request(db_request)
                  else:
                      # Asynchronous: hand off to the executor.
                      async_executor.submit(handle_db_request, db_request)
              except Exception as e:
                  LOGERR(f"[DB Thread Error] {e}, sql: {db_request.sql}")
              finally:
                  db_req_que.task_done()
      except Exception as e:
          LOGERR(f"[DB Thread Error] DB worker aborted: {e}")
          raise
      finally:
          # Reject new submissions first, then finish in-flight async work, then
          # release anything still queued.
          db_worker_running = False
          if async_executor is not None:
              async_executor.shutdown(wait=True)
          _drain_pending_requests()
          LOGINFO("DB process exit gracefully")
  # Create the DB worker thread.
  def create_db_process():
      """Build (but do not start) the daemon "DBWorkerThread" that runs db_process.

      The thread is also stored in the module-global ``db_thread`` so that
      stop_db_process() can join it. The caller must call .start() on it.
      """
      global db_thread
      worker = threading.Thread(
          target=db_process,
          name="DBWorkerThread",
          daemon=True,
      )
      db_thread = worker
      return worker
  # Stop the DB worker thread.
  def stop_db_process(timeout=None):
      """Send the shutdown sentinel to the DB worker and wait for its thread.

      If the worker is not running, this only logs. The thread joined is the one
      recorded by create_db_process(); a worker started some other way is
      signalled but not joined.

      Args:
          timeout: optional seconds to wait in join(); None (the default) waits
              until the worker has exited.

      Calling this from the callback of an async request can deadlock.
      db_process waits for the executor threads to finish while that callback
      waits for db_process.
      """
      if db_worker_running:
          db_req_que.put(None)
          # db_thread only exists once create_db_process() has run; look it up
          # without risking a NameError.
          worker = globals().get("db_thread")
          # A thread cannot join itself (a sync-request callback runs on the
          # worker thread); the sentinel alone is enough in that case.
          if worker is not None and worker is not threading.current_thread():
              worker.join(timeout)
      LOGINFO("DB worker stopped")
  # ========== Examples ==========
  # Handle the result returned for the demo query.
  def cb_handle_device_data(results, userdata):
      """Example callback: log the query result at debug level.

      Args:
          results: the value handle_db_request delivered for the request.
          userdata: the request's userdata; unused here.
      """
      # Must be an f-string so the actual results are interpolated rather than
      # the literal text "{results}".
      LOGDBG(f"Received results: {results}")
  # Example request generator.
  def request_generator():
      """Demo producer: enqueue an async SELECT on dev_info once per second, forever.

      It puts DBRequest_Async objects straight onto db_req_que, bypassing the
      db_worker_running check in db_execute_async. Results go to
      cb_handle_device_data. This function never returns.
      """
      demo_sql = "SELECT * FROM dev_info"  # example query
      period_s = 1  # one request per second
      while True:
          db_req_que.put(DBRequest_Async(sql=demo_sql, callback=cb_handle_device_data))
          time.sleep(period_s)
  def test_main():
      """Demo entry point: start the DB worker and a producer thread.

      The worker is created through create_db_process(), so it is recorded in the
      module-global db_thread and can later be stopped with stop_db_process().
      The producer (request_generator) is a non-daemon thread that never ends, so
      the process keeps running.
      """
      # Start the database thread.
      worker = create_db_process()
      worker.start()
      # Start the request generator.
      producer = threading.Thread(target=request_generator)
      producer.start()