首页 > 代码库 > DBUtils 和 pymysql 结合的简单封装

DBUtils 和 pymysql 结合的简单封装

1.使用 xml.dom.minidom 解析 xml

  话不多说,直接上代码。

 1 import sys
 2 import re
 3 import pymysql
 4 import xml.dom.minidom
 5 
 6 from xml.dom.minidom import parse
 7 
 8 class ConfigurationParser(object):
 9     """
10     解析xml
11     - @return configDict = {"jdbcConnectionDictList":jdbcConnectionDictList,"tableList":tableList}
12     """
13     def __init__(self, configFilePath=None):
14         if configFilePath:
15             self.__configFilePath = configFilePath
16         else:
17             self.__configFilePath = sys.path[0] + "/config/config.xml"
18         pass
19 
20     def parseConfiguration(self):
21         """
22         解析xml,返回jdbc配置信息以及需要生成python对象的表集合
23         """
24         # 解析xml文件,获取Document对象
25         DOMTree = xml.dom.minidom.parse(self.__configFilePath)
26         # 获取 generatorConfiguration 节点的NodeList对象
27         configDOM = DOMTree.getElementsByTagName("generatorConfiguration")[0]
28 
29         # 获取 jdbcConnection 节点的 property 节点集合
30         jdbcConnectionPropertyList = configDOM.getElementsByTagName("jdbcConnection")[0].getElementsByTagName("property")
31         # 循环 jdbcConnection 节点的 property 节点集合,获取属性名称和属性值
32         jdbcConnectionDict = {}
33         for property in jdbcConnectionPropertyList:
34             name = property.getAttributeNode("name").nodeValue.strip().lower()
35             if property.hasAttribute("value"):
36                 value = http://www.mamicode.com/property.getAttributeNode("value").nodeValue
37                 if re.match("[0-9]",value) and name != "password" and name != "host":
38                     value =http://www.mamicode.com/ int(value)
39             else:
40                 value =http://www.mamicode.com/ property.childNodes[0].data
41                 if re.match("[0-9]",value) and name != "password" and name != "host":
42                     value =http://www.mamicode.com/ int(value)
43             if name == "charset":
44                 if re.match("utf-8|UTF8", value, re.I):
45                     continue
46             elif name == "port":
47                 value =http://www.mamicode.com/ int(value)
48             elif name == "creator":
49                 if value =http://www.mamicode.com/= "pymysql":
50                     value =http://www.mamicode.com/ pymysql
51             jdbcConnectionDict[name] = value
52         # print(jdbcConnectionDict)
53         return jdbcConnectionDict
54 
if __name__ == "__main__":
    # Manual check: show the settings parsed from the default config file.
    parsedConfig = ConfigurationParser().parseConfiguration()
    print(parsedConfig)

  config.xml

 1 <?xml version="1.0" encoding="utf-8"?>
 2 <generatorConfiguration>
 3     <jdbcConnection>
 4         <property name="creator">pymysql</property>
 5         <property name="host">127.0.0.1</property>
 6         <property name="database">rcddup</property>
 7         <property name="port">3306</property>
 8         <property name="user">root</property>
 9         <property name="password">root</property>
10         <property name="charset">Utf-8</property>
11         <property name="mincached">0</property>
12         <property name="maxcached">10</property>
13         <property name="maxshared">0</property>
14         <property name="maxconnections">20</property>
15     </jdbcConnection>
16 </generatorConfiguration>

  通过调用Python内置的 xml.dom.minidom 对 xml 文件进行解析,获取xml内容。

2.BaseDao

  BaseDao是在 DBUtils 的基础上对 pymysql 操作数据库进行了一些简单的封装。

  其中 queryUtil 用于拼接 SQL 语句,log4py 用于在控制台输出信息,page 为分页对象。

  1 import pymysql
  2 import time
  3 import json
  4 
  5 from DBUtils.PooledDB import PooledDB
  6 from configParser import ConfigurationParser
  7 from queryUtil import QueryUtil
  8 from log4py import Logger
  9 from page import Page
 10 
 11 
 12 global PRIMARY_KEY_DICT_LIST
 13 PRIMARY_KEY_DICT_LIST = []
 14 
 15 class BaseDao(object):
 16     """
 17     Python 操作数据库基类方法
 18     - @Author RuanCheng
 19     - @UpdateDate 2017/5/17
 20     """
 21     __logger = None
 22     __parser = None                 # 获取 xml 文件信息对象
 23     __poolConfigDict = None         # 从 xml 中获取的数据库连接信息的字典对象
 24     __pool = None                   # 数据库连接池
 25     __obj = None                    # 实体类
 26     __className = None              # 实体类类名
 27     __tableName = None              # 实体类对应的数据库名
 28     __primaryKeyDict = {}           # 数据库表的主键字典对象
 29     __columnList = []
 30 
 31     def __init__(self, obj=None):
 32         """
 33         初始化方法:
 34         - 1.初始化配置信息
 35         - 2.初始化 className
 36         - 3.初始化数据库表的主键
 37         """
 38         if not obj:
 39             raise Exception("BaseDao is missing a required parameter --> obj(class object).\nFor example [super().__init__(User)].")
 40         else:
 41             self.__logger = Logger(self.__class__)                                      # 初始化日志对象
 42             self.__logger.start()                                                       # 开启日志
 43             if not self.__parser:                                                       # 解析 xml
 44                 self.__parser = ConfigurationParser()
 45                 self.__poolConfigDict = self.__parser.parseConfiguration()
 46                 self.__pool = PooledDB(**self.__poolConfigDict)
 47             # 初始化参数
 48             if (self.__obj == None) or ( self.__obj != obj):
 49                 global PRIMARY_KEY_DICT_LIST
 50                 if (not PRIMARY_KEY_DICT_LIST) or (PRIMARY_KEY_DICT_LIST.count == 0):
 51                     self.__init_primary_key_dict_list()                                 # 初始化主键字典列表
 52                 self.__init_params(obj)                                                 # 初始化参数
 53                 self.__init_columns()                                                   # 初始化字段列表
 54                 self.__logger.end()                                                     # 结束日志
 55         pass
 56     ################################################# 外部调用方法 #################################################
 57     def selectAll(self):
 58         """
 59         查询所有
 60         """
 61         sql = QueryUtil.queryAll(self.__tableName, self.__columnList)
 62         return self.__executeQuery(sql)
 63 
 64     def selectByPrimaryKey(self, value):
 65         """
 66         按主键查询
 67         - @Param: value 主键
 68         """
 69         if (not value) or (value =http://www.mamicode.com/= ""):
 70             raise Exception("selectByPrimaryKey() is missing a required paramter ‘value‘.")
 71         sql = QueryUtil.queryByPrimaryKey(self.__primaryKeyDict, value, self.__columnList)
 72         return self.__executeQuery(sql)
 73 
 74     def selectCount(self):
 75         """
 76         查询总记录数
 77         """
 78         sql = QueryUtil.queryCount(self.__tableName);
 79         return self.__execute(sql)[0][0]
 80 
 81     def selectAllByPage(self, page=None):
 82         """
 83         分页查询
 84         """
 85         if (not page) or (not isinstance(page,Page)):
 86             raise Exception("Paramter [page] is not correct. Parameter [page] must a Page object instance. ")
 87         sql = QueryUtil.queryAllByPage(self.__tableName, self.__columnList, page)
 88         return self.__executeQuery(sql, logEnable=True)
 89 
 90     def insert(self, obj):
 91         """
 92         新增
 93         - @Param: obj 实体对象
 94         """
 95         if (not obj) or (obj == ""):
 96             raise Exception("insert() is missing a required paramter ‘obj‘.")
 97         sql = QueryUtil.queryInsert(self.__primaryKeyDict, json.loads(str(obj)))
 98         return self.__executeUpdate(sql)
 99     
100     def delete(self, obj=None):
101         """
102         根据实体删除
103         - @Param: obj 实体对象
104         """
105         if (not obj) or (obj == ""):
106             raise Exception("delete() is missing a required paramter ‘obj‘.")
107         sql = QueryUtil.queryDelete(self.__primaryKeyDict, json.loads(str(obj)))
108         return self.__executeUpdate(sql)
109 
110     def deleteByPrimaryKey(self, value=http://www.mamicode.com/None):
111         """
112         根据主键删除
113         - @Param: value 主键
114         """
115         if (not value) or (value =http://www.mamicode.com/= ""):
116             raise Exception("deleteByPrimaryKey() is missing a required paramter ‘value‘.")
117         sql = QueryUtil.queryDeleteByPrimaryKey(self.__primaryKeyDict, value)
118         return self.__executeUpdate(sql)
119     
120     def updateByPrimaryKey(self, obj=None):
121         """
122         根据主键更新
123         - @Param: obj 实体对象
124         """
125         if (not obj) or (obj == ""):
126             raise Exception("updateByPrimaryKey() is missing a required paramter ‘obj‘.")
127         sql = QueryUtil.queryUpdateByPrimaryKey(self.__primaryKeyDict, json.loads(str(obj)))
128         return self.__executeUpdate(sql)
129 
130     ################################################# 内部调用方法 #################################################
131     def __execute(self, sql="", logEnable=True):
132         """
133         执行 SQL 语句(用于内部初始化参数使用):
134         - @Param: sql 执行sql
135         - @Param: logEnable 是否开启输出日志
136         - @return 查询结果
137         """
138         if not sql:
139             raise Exception("Execute method is missing a required parameter --> sql.")
140         try:
141             self.__logger.outSQL(sql, enable=logEnable)
142             conn = self.__pool.connection()
143             cur = conn.cursor()
144             cur.execute(sql)
145             result = cur.fetchall()
146             resultList = []
147             for r in result:
148                 resultList.append(r)
149             return resultList
150         except Exception as e:
151             conn.rollback()
152             raise Exception(e)
153         finally:
154             cur.close()
155             conn.close()
156             pass
157 
158     def __executeQuery(self, sql="", logEnable=True):
159         """
160         执行查询 SQL 语句:
161         - @Param: sql 执行sql
162         - @Param: logEnable 是否开启输出日志
163         - @return 查询结果
164         """
165         if not sql:
166             raise Exception("Execute method is missing a required parameter --> sql.")
167         try:
168             self.__logger.outSQL(sql, enable=logEnable)
169             conn = self.__pool.connection()
170             cur = conn.cursor()
171             cur.execute(sql)
172             resultList = list(cur.fetchall())
173             objList = []
174             for result in resultList:
175                 i = 0
176                 obj = self.__obj()
177                 for col in self.__columnList:
178                     obj.__setattr__(col, result[i])
179                 objList.append(obj)
180             if not objList:
181                 return None
182             elif objList and objList.__len__ == 1:
183                 return objList[0]
184             else:
185                 return objList
186         except Exception as e:
187             conn.rollback()
188             raise Exception(e)
189         finally:
190             cur.close()
191             conn.close()
192             pass
193     
194     def __executeUpdate(self, sql=None, logEnable=True):
195         """
196         执行修改 SQL 语句:
197         - @Param: sql 执行sql
198         - @Param: logEnable 是否开启输出日志
199         - @return 影响行数
200         """
201         try:
202             self.__logger.outSQL(sql, enable=logEnable)
203             conn = self.__pool.connection()
204             cur = conn.cursor()
205             return cur.execute(sql)
206             pass
207         except Exception as e:
208             conn.rollback()
209             raise Exception(e)
210             pass
211         finally:
212             conn.commit()
213             cur.close()
214             conn.close()
215             pass
216 
217     def __init_params(self, obj):
218         """
219         初始化参数
220         - @Param:obj class 对象
221         """
222         self.__obj = obj
223         self.__className = obj.__name__
224         for i in PRIMARY_KEY_DICT_LIST:
225             if i.get("className") == self.__className:
226                 self.__primaryKeyDict = i
227                 self.__className = i["className"]
228                 self.__tableName = i["tableName"]
229                 break
230 
231     def __init_primary_key_dict_list(self):
232         """
233         初始化数据库主键集合:
234         - pk_dict = {"className": {"tableName":tableName,"primaryKey":primaryKey,"auto_increment":auto_increment}}
235         """
236         global PRIMARY_KEY_DICT_LIST
237         sql = """
238             SELECT
239                 t.TABLE_NAME,
240                 c.COLUMN_NAME,
241                 c.ORDINAL_POSITION
242             FROM
243                 INFORMATION_SCHEMA.TABLE_CONSTRAINTS as t,
244                 INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS c
245             WHERE t.TABLE_NAME = c.TABLE_NAME
246                 AND t.TABLE_SCHEMA = "%s"
247                 AND c.CONSTRAINT_SCHEMA = "%s"
248         """%(self.__poolConfigDict.get("database"),self.__poolConfigDict.get("database"))
249         resultList = self.__execute(sql, logEnable=False)
250         for result in resultList:
251             pk_dict = dict()
252             pk_dict["tableName"] = result[0]
253             pk_dict["primaryKey"] = result[1]
254             pk_dict["ordinalPosition"] = result[2]
255             pk_dict["className"] = self.__convertToClassName(result[0])
256             PRIMARY_KEY_DICT_LIST.append(pk_dict)
257         self.__logger.outMsg("initPrimaryKey is done.")
258 
259     def __init_columns(self):
260         """
261         初始化表字段
262         """
263         sql = "SELECT column_name FROM  Information_schema.columns WHERE table_Name = ‘%s‘"%(self.__tableName)
264         resultList = self.__execute(sql, logEnable=False)
265         for result in resultList:
266             self.__columnList.append(result)
267         self.__logger.outMsg("init_columns is done.")
268         # print(self.__columnList)
269         pass
270 
271     def __convertToClassName(self, tableName):
272         """
273         表名转换方法(配置自己特定的数据库表明前缀):
274         - @Param: tableName 表名
275         - @return 转换后的类名
276         """
277         result = None
278         if tableName.startswith("t_md_"):
279             result = tableName.replace("t_md_", "").replace("_","").lower()
280         elif tableName.startswith("t_ac_"):
281             result = tableName.replace("t_ac_","").replace("_","").lower()
282         elif tableName.startswith("t_"):
283             result = tableName.replace("t_","").replace("_","").lower()
284         else:
285             result = tableName
286         return result.capitalize()

 3.简单应用 UserDao

  创建一个 UserDao,继承 BaseDao 之后调用父类初始化方法,把 User 类传递给父类,我们就可以很方便地对 User 进行 CRUD 了。

 1 import random
 2 import math
 3 
 4 from baseDao import BaseDao
 5 from user import User
 6 from page import Page
 7 
 8 
 9 class UserDao(BaseDao):
10 
11     def __init__(self):
12         super().__init__(User)
13         pass
14         
userDao = UserDao()

######################################## CRUD examples (uncomment to run)

# print(userDao.selectAll())
# user = userDao.selectByPrimaryKey(1)

# print(userDao.insert(user))

# print(userDao.delete(user))
# print(userDao.deleteByPrimaryKey(4))

# user = userDao.selectByPrimaryKey(1)
# print(userDao.updateByPrimaryKey(user))    # the entity argument is required

######################################## Update by primary key: random names for ids 1..1000

# strList = list("赵倩顺利王五张三赵丽历史李四八哥大哈萨克龙卷风很快乐节哀顺变风山东矿机感觉到付款了合法更不能")
# for index in range(1000):
#     user = User()
#     user.set_id(index + 1)
#     name = ""
#     for _ in range(random.randint(3, 8)):
#         name += random.choice(strList)
#     user.set_name(name)
#     user.set_status(1)
#     userDao.updateByPrimaryKey(user)

######################################## Update a single row

# user = User()
# user.set_id(2)
# user.set_name("测试更新")
# userDao.updateByPrimaryKey(user)

######################################## Paged query

# page = Page()
# pageNum = 1
# limit = 10
# page.set_page(pageNum)
# page.set_limit(limit)
# total_count = userDao.selectCount()
# page.set_total_count(total_count)
# page.set_total_page(math.ceil(total_count / limit))
# # QueryUtil builds "LIMIT begin,limit" from page.get_begin(), so the offset must be set explicitly.
# page.set_begin((pageNum - 1) * limit)
# print(userDao.selectAllByPage(page))

4. User

  User 对象的属性设置为私有,通过 get/set 方法访问。最后重写 __str__() 方法,把属性序列化为 JSON 字符串;BaseDao 在新增、删除、更新时通过 json.loads(str(obj)) 把实体对象转换为字段字典。

 1 import json
 2 
 3 class User(object):
 4 
 5     def __init__(self):
 6         self.__id = None
 7         self.__name = None
 8         self.__status = None
 9         pass
10 
11     def get_id(self):
12         return self.__id
13 
14     def set_id(self, id):
15         self.__id = id
16 
17     def get_name(self):
18         return self.__name
19 
20     def set_name(self, name):
21         self.__name = name
22 
23     def get_status(self):
24         return self.__status
25 
26     def set_status(self, status):
27         self.__status = status
28 
29 
30     def __str__(self):
31         userDict = {id:self.__id,name:self.__name,status:self.__status}
32         return json.dumps(userDict)

 

5.QueryUtil

  拼接 SQL 语句的工具类。

  1 from page import Page
  2 
  3 class QueryUtil(object):
  4 
  5     def __init__(self):
  6         pass
  7     
  8     @staticmethod
  9     def queryColumns(columnList):
 10         i = 1
 11         s = ""
 12         for col in columnList:
 13             if i != 1:
 14                 s += ", `%s`"%(col)
 15             else:
 16                 s += "`%s`"%(col)
 17             i += 1
 18         return s
 19     @staticmethod    
 20     def queryByPrimaryKey(primaryKeyDict, value, columnList):
 21         """
 22         拼接主键查询
 23         """
 24         sql = SELECT %s FROM `%s` WHERE `%s`="%s"%(QueryUtil.queryColumns(columnList), primaryKeyDict["tableName"], primaryKeyDict["primaryKey"], str(value))
 25         return sql
 26 
 27     @staticmethod
 28     def queryAll(tableName, columnList):
 29         """
 30         拼接查询所有
 31         """
 32         return SELECT %s FROM %s%(QueryUtil.queryColumns(columnList), tableName)
 33 
 34     @staticmethod
 35     def queryCount(tableName):
 36         """
 37         拼接查询记录数
 38         """
 39         return SELECT COUNT(*) FROM %s%(tableName)
 40 
 41     @staticmethod
 42     def queryAllByPage(tableName, columnList, page=None):
 43         """
 44         拼接分页查询
 45         """
 46         if not page:
 47             page = Page()
 48         return SELECT %s FROM %s LIMIT %d,%d%(QueryUtil.queryColumns(columnList), tableName, page.get_begin(), page.get_limit())
 49 
 50 
 51     @staticmethod
 52     def queryInsert(primaryKeyDict, objDict):
 53         """
 54         拼接新增
 55         """
 56         tableName = primaryKeyDict["tableName"]
 57         key = primaryKeyDict["primaryKey"]
 58         columns = list(objDict.keys())
 59         values = list(objDict.values())
 60 
 61         sql = "INSERT INTO `%s`("%(tableName)
 62         for i in range(0, columns.__len__()):
 63             if i == 0:
 64                 sql += `%s`%(columns[i])
 65             else:
 66                 sql += ,`%s`%(columns[i])
 67         sql += ) VALUES(
 68         for i in range(0, values.__len__()):
 69             if values[i] == None or values[i] == "None":
 70                 value = http://www.mamicode.com/"null"
 71             else:
 72                 value = http://www.mamicode.com/"%s"%(values[i])
 73             if i == 0:
 74                 sql += value
 75             else:
 76                 sql += ,%s%(value);
 77         sql += )
 78         return sql
 79     
 80     @staticmethod
 81     def queryDelete(primaryKeyDict, objDict):
 82         """
 83         拼接删除
 84         """
 85         # DELETE FROM `t_user` WHERE `id` = ‘5‘
 86         tableName = primaryKeyDict["tableName"]
 87         key = primaryKeyDict["primaryKey"]
 88         columns = list(objDict.keys())
 89         values = list(objDict.values())
 90 
 91         sql = "DELETE FROM `%s` WHERE 1=1 "%(tableName)
 92         for i in range(0, values.__len__()):
 93             if values[i] != None and values[i] != "None":
 94                 sql += and `%s`="%s"%(columns[i], values[i])
 95         return sql
 96 
 97     @staticmethod
 98     def queryDeleteByPrimaryKey(primaryKeyDict, value=http://www.mamicode.com/None):
 99         """
100         拼接根据主键删除
101         """
102         # DELETE FROM `t_user` WHERE `id` = ‘5‘
103         sql = DELETE FROM `%s` WHERE `%s`="%s"%(primaryKeyDict["tableName"], primaryKeyDict["primaryKey"], value)
104         return sql
105     
106     @staticmethod
107     def queryUpdateByPrimaryKey(primaryKeyDict, objDict):
108         """
109         拼接根据主键更新
110         UPDATE t_user SET name=‘test‘ WHERE id = 1007
111         """
112         tableName = primaryKeyDict["tableName"]
113         key = primaryKeyDict["primaryKey"]
114         columns = list(objDict.keys())
115         values = list(objDict.values())
116         keyValue =http://www.mamicode.com/ None
117         sql = "UPDATE `%s` SET"%(tableName)
118         for i in range(0, columns.__len__()):
119             if (values[i] != None) and (values[i] != "None"):
120                 if columns[i] != key:
121                     sql +=  `%s`="%s", %(columns[i], values[i])
122                 else:
123                     keyValue =http://www.mamicode.com/ values[i]
124         sql = sql[0:len(sql)-2] +  WHERE `%s`="%s"%(key, keyValue)
125         return sql

 

6. Page

  分页对象

import json
import math

class Page(object):
    """
    Pagination parameters and results.

    Setters ignore out-of-range values and keep the previous value:
    page/total_page/limit must be >= 1, total_count/begin must be >= 0.
    ``begin`` is the row offset used by QueryUtil's LIMIT clause and is not
    derived from ``page`` automatically.
    """

    def __init__(self):
        self.__page = 1
        self.__total_page = 1
        self.__total_count = 0
        self.__begin = 0
        self.__limit = 10
        self.__result = []

    def get_page(self):
        return self.__page

    def set_page(self, page):
        # >= 1 (was > 1): going back to page 1 used to be silently ignored.
        if page >= 1:
            self.__page = page

    def get_total_page(self):
        return self.__total_page

    def set_total_page(self, total_page):
        if total_page >= 1:
            self.__total_page = total_page

    def get_total_count(self):
        return self.__total_count

    def set_total_count(self, total_count):
        # >= 0 (was > 0): resetting the count to 0 used to be ignored.
        if total_count >= 0:
            self.__total_count = total_count

    def get_begin(self):
        return self.__begin

    def set_begin(self, begin):
        # >= 0 (was > 0): resetting the offset to the first row used to be ignored.
        if begin >= 0:
            self.__begin = begin

    def get_limit(self):
        return self.__limit

    def set_limit(self, limit):
        if limit > 0:
            self.__limit = limit

    def get_result(self):
        return self.__result

    def set_result(self, result):
        self.__result = result

    def __str__(self):
        # Keys must be string literals; bare names raised NameError.
        pageDict = {"page": self.__page, "total_page": self.__total_page,
                    "total_count": self.__total_count, "begin": self.__begin,
                    "limit": self.__limit, "result": self.__result}
        return json.dumps(pageDict)

 

7.Logger

  简单地用于在控制台输出信息。

 1 import time
 2 
 3 class Logger(object):
 4 
 5     def __init__(self, obj):
 6         self.__obj = obj
 7         self.__start = None
 8         pass
 9     
10     def start(self):
11         self.__start = time.time()
12         pass
13 
14     def end(self):
15         print("%s >>> [%s] Finished [Time consuming %dms]"%(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), self.__obj.__name__, time.time()-self.__start))
16         pass
17 
18     def outSQL(self, msg, enable=True):
19         """
20         输出 SQL 日志:
21         - @Param: msg SQL语句
22         - @Param: enable 日志开关
23         """
24         if enable:
25             print("%s >>> [%s] [SQL] %s"%(str(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())), self.__obj.__name__, msg))
26         pass
27     
28     def outMsg(self, msg, enable=True):
29         """
30         输出消息日志:
31         - @Param: msg 日志信息
32         - @Param: enable 日志开关
33         """
34         if enable:
35             print("%s >>> [%s] [Msg] %s"%(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), self.__obj.__name__, msg))
36         pass
37     
38         

8.Generator

  为了便于创建 user.py 文件,此处提供了自动生成方法:只需在配置文件中简单地配置数据库连接信息以及要生成的表,即可生成对应的 py 实体类文件。

  1 import sys
  2 import re
  3 import pymysql
  4 import time
  5 import os
  6 import xml.dom.minidom
  7 
  8 from xml.dom.minidom import parse
  9 
 10 global _pythonPath
 11 global _daoPath
 12 global _servicePath
 13 global _controllerPath
 14 
 15 class Generator(object):
 16     """
 17     # python类生成器
 18     @param configDict 配置文件信息的字典对象
 19     """
 20     def __init__(self, configFilePath=None):
 21         if not configFilePath:
 22             self.__configDict = ConfigurationParser().parseConfiguration()
 23         else:
 24             if os.path.isabs(configFilePath):
 25                 self.__configDict = ConfigurationParser(configFilePath).parseConfiguration()
 26             else:
 27                 configFilePath = configFilePath.replace(".", sys.path[0])
 28             pass
 29     
 30     def run(self):
 31         """
 32         # 生成器执行方法
 33         """
 34         fieldDict = DBOperator(self.__configDict).queryFieldDict()
 35         PythonGenarator(self.__configDict, fieldDict).run()
 36         # DaoGenarator(self.__configDict).run()
 37         # ServiceGenarator(self.__configDict).run()
 38         # ControllerGenarator(self.__configDict).run()
 39         
 40 
 41 class PythonGenarator(object):
 42     """
 43     # pyEntity文件生成类
 44     @param configDict 配置文件信息的字典对象
 45     """
 46     def __init__(self, configDict, fieldDict):
 47         self.__configDict = configDict
 48         self.__fieldDict = fieldDict
 49         self.__content = ""
 50         pass
 51     
 52     def run(self):
 53         """
 54         执行 py 类生成方法
 55         """
 56         for filePath in self.__configDict["pythonPathList"]:
 57             if not os.path.exists(filePath):
 58                 os.makedirs(os.path.dirname(filePath), exist_ok=True)
 59             # 获取表名
 60             fileName = os.path.basename(filePath).split(".py")[0]
 61             # 表名(首字母大写)
 62             ClassName = fileName.capitalize()
 63             # 打开新建文件
 64             file = open(filePath, "w", encoding="utf-8")
 65             self.writeImport(file)                                  # 生成 import 内容
 66             self.writeHeader(file, ClassName)                       # 生成 class 头部内容
 67             self.writeInit(file, fileName, ClassName)               # 生成 class 的 init 方法
 68             tableDictString = self.writeGetSet(file, fileName)      # 生成 get/set 方法,并返回一个类属性的字典对象
 69             self.writeStr(file, fileName, tableDictString)          # 重写 class 的 str 方法
 70             file.write(self.__content)
 71             file.close()
 72             print("Generator --> %s"%(filePath))
 73         pass
 74 
 75     def writeImport(self,file ,importList = None):
 76         """
 77         # 写import部分
 78         """
 79         self.__content += "import json\r\n"
 80         pass
 81     
 82     def writeHeader(self, file, className, superClass = None):
 83         """
 84         # 写类头部(class ClassName(object):)
 85         """
 86         if not superClass:
 87             self.__content += "class %s(object):\r\n"%(className)
 88         else:
 89             self.__content += "class %s(%s):\r\n"%(className, superClass)
 90         pass
 91         
 92     def writeInit(self, file, fileName, className):
 93         """
 94         # 写类初始化方法
 95         """
 96         self.__content += "\tdef __init__(self):\n\t\t"
 97         for field in self.__fieldDict[fileName]:
 98             self.__content += "self.__%s = None\n\t\t"%(field)
 99         self.__content += "pass\r\n"
100         pass
101     
102     def writeGetSet(self, file, fileName):
103         """
104         # 写类getXXX(),setXXX()方法
105         @return tableDictString 表属性字典的字符串对象,用于写__str__()方法
106         """
107         tableDictString = ""
108         i = 1
109         for field in self.__fieldDict[fileName]:
110             if i != len(self.__fieldDict[fileName]):
111                 tableDictString += "‘%s‘:self.__%s,"%(field,field)
112             else:
113                 tableDictString += "‘%s‘:self.__%s"%(field,field)
114             Field = field.capitalize()
115             self.__content += "\tdef get_%(field)s(self):\n\t\treturn self.__%(field)s\n\n\tdef set_%(field)s(self, %(field)s):\n\t\tself.__%(field)s = %(field)s\n\n"%({"field":field})
116             i += 1
117         return tableDictString
118     
119     def writeStr(self, file, fileName, tableDictString):
120         """
121         # 重写__str__()方法
122         """
123         tableDictString = "{" + tableDictString + "}"
124         self.__content += "\n\tdef __str__(self):\n\t\t%sDict = %s\r\t\treturn json.dumps(%sDict)\n"%(fileName, tableDictString, fileName)
125         pass
126 
class DaoGenarator(object):
    """
    # DAO file generator (placeholder)
    @param configDict dict of parsed configuration
    """

    def __init__(self, configDict):
        self.__configDict = configDict

    def run(self):
        """Not implemented yet: does nothing and returns None."""
        return None
138 
class ServiceGenarator(object):
    """
    # Service file generator (placeholder)
    @param configDict dict of parsed configuration
    """

    def __init__(self, configDict):
        self.__configDict = configDict

    def run(self):
        """Not implemented yet: does nothing and returns None."""
        return None
150 
class ControllerGenarator(object):
    """
    # Controller file generator (placeholder)
    @param configDict dict of parsed configuration
    """

    def __init__(self, configDict):
        self.__configDict = configDict

    def run(self):
        """Not implemented yet: does nothing and returns None."""
        return None
162 
class ConfigurationParser(object):
    """
    Parse the generator's XML configuration file.

    ``parseConfiguration()`` returns a dict with the keys
    "jdbcConnectionDict", "tableList", "pythonPathList", "daoPathList",
    "servicePathList" and "controllerPathList".
    """

    def __init__(self, configFilePath=None):
        """
        :param configFilePath: path of the XML file; when falsy, defaults to
            ``sys.path[0] + "/config/generatorConfig.xml"``.
        """
        if configFilePath:
            self.__configFilePath = configFilePath
        else:
            self.__configFilePath = sys.path[0] + "/config/generatorConfig.xml"
        # Every generated file is placed below this directory.
        self.__generatorBasePath = sys.path[0] + "/src/"

    def parseConfiguration(self):
        """
        Parse the XML file and return the JDBC settings, the tables to
        generate Python code for, and the output path of every generated file.

        :return: dict with keys
            - "jdbcConnectionDict": connection settings built by
              ``__parseJdbcConnection``;
            - "tableList": one ``{"tableName": ..., "alias": ...}`` dict per
              <table> element, both values lower-cased ("alias" is "" when
              the attribute is missing or empty);
            - "pythonPathList", "daoPathList", "servicePathList",
              "controllerPathList": output file paths, parallel to
              "tableList". Each file is named after the table's alias, or
              after its name when no alias is given.
        """
        # Parse the XML file into a minidom Document.
        DOMTree = xml.dom.minidom.parse(self.__configFilePath)
        # The <generatorConfiguration> root element.
        configDOM = DOMTree.getElementsByTagName("generatorConfiguration")[0]

        # Output directories of the four kinds of generated files.
        _pythonPath = self.__getTargetPath(configDOM, "pythonGenerator")
        _servicePath = self.__getTargetPath(configDOM, "serviceGenerator")
        _daoPath = self.__getTargetPath(configDOM, "daoGenerator")
        _controllerPath = self.__getTargetPath(configDOM, "controllerGenerator")

        jdbcConnectionDict = self.__parseJdbcConnection(configDOM)

        tableList = []
        pythonPathList = []
        daoPathList = []
        servicePathList = []
        controllerPathList = []
        for tableDOM in configDOM.getElementsByTagName("table"):
            name = tableDOM.getAttributeNode("name").nodeValue.strip().lower()
            # getAttribute() returns "" for a missing attribute, so "alias" is
            # optional (getAttributeNode() would return None and crash here).
            alias = tableDOM.getAttribute("alias").strip().lower()
            prefix = alias or name
            tableList.append({"tableName": name, "alias": alias})

            pythonPathList.append("%s/%s.py" % (_pythonPath, prefix))
            daoPathList.append("%s/%sDao.py" % (_daoPath, prefix))
            servicePathList.append("%s/%sService.py" % (_servicePath, prefix))
            controllerPathList.append("%s/%sController.py" % (_controllerPath, prefix))

        return {
            "jdbcConnectionDict": jdbcConnectionDict,
            "tableList": tableList,
            "pythonPathList": pythonPathList,
            "daoPathList": daoPathList,
            "servicePathList": servicePathList,
            "controllerPathList": controllerPathList,
        }

    def __parseJdbcConnection(self, configDOM):
        """
        Build the connection settings from the <property> children of the
        first <jdbcConnection> element.

        Unset keys keep the defaults below. "port" is converted to int. A
        charset spelled "utf-8"/"utf8" (any case) keeps the default "utf8";
        any other charset (e.g. "utf8mb4", "gbk") is used verbatim.
        """
        jdbcConnectionDict = {"host": None, "user": None, "password": None,
                              "port": 3306, "database": None, "charset": "utf8"}
        jdbcDOM = configDOM.getElementsByTagName("jdbcConnection")[0]
        for propertyDOM in jdbcDOM.getElementsByTagName("property"):
            name = propertyDOM.getAttributeNode("name").nodeValue.strip().lower()
            value = self.__getPropertyValue(propertyDOM)
            if name == "charset":
                # fullmatch, not match: a prefix match would turn "utf8mb4"
                # into the default "utf8".
                if re.fullmatch(r"utf-?8", value, re.I):
                    continue
            elif name == "port":
                value = int(value)
            jdbcConnectionDict[name] = value
        return jdbcConnectionDict

    @staticmethod
    def __getPropertyValue(propertyDOM):
        """
        Return a <property>'s value.

        The ``value`` attribute wins when present and is returned as written.
        Otherwise all text/CDATA children are joined and stripped, so
        pretty-printed or empty elements do not leak whitespace or raise.
        """
        if propertyDOM.hasAttribute("value"):
            return propertyDOM.getAttribute("value")
        return "".join(
            node.data for node in propertyDOM.childNodes
            if node.nodeType in (node.TEXT_NODE, node.CDATA_SECTION_NODE)
        ).strip()

    def __getTargetPath(self, configDOM, tagName):
        """Return the output directory of the first <tagName> element's targetPath."""
        generatorDOM = configDOM.getElementsByTagName(tagName)[0]
        return self.__getGeneratorPath(generatorDOM.getAttributeNode("targetPath").nodeValue)

    def __getGeneratorPath(self, targetPath):
        """Map a dotted package path such as "cn.rcddup.dao" to a directory below the base path."""
        return self.__generatorBasePath + targetPath.replace(".", "/")
265 
class DBOperator(object):
    """
    Read column metadata for the configured tables from a MySQL database.
    """

    def __init__(self, configDict=None):
        """
        :param configDict: the dict returned by
            ``ConfigurationParser.parseConfiguration``; this class reads its
            "jdbcConnectionDict" and "tableList" entries.
        :raises Exception: if ``configDict`` is None.
        """
        if configDict is None:
            raise Exception("Error in DBOperator >>> jdbcConnectionDict is None")
        self.__configDict = configDict

    def queryFieldDict(self):
        """
        Return the lower-cased column names of every configured table.

        One connection is opened for all tables and is always closed, even if
        a query fails.

        :return: dict mapping each table's ``alias`` to the list of its column
            names in table-definition order, e.g. ``{"user": ["id", "name"]}``.
        """
        jdbcConnectionDict = self.__configDict["jdbcConnectionDict"]
        database = jdbcConnectionDict["database"]
        # Values are passed as query parameters so the driver escapes them,
        # instead of splicing names from the config file into the SQL text.
        sql = ("SELECT COLUMN_NAME AS name, DATA_TYPE AS type "
               "FROM information_schema.columns "
               "WHERE table_schema = %s AND table_name = %s "
               "ORDER BY ORDINAL_POSITION")

        fieldDict = {}
        conn = pymysql.Connect(**jdbcConnectionDict)
        try:
            for table in self.__configDict["tableList"]:
                # The cursor is closed when the with-block exits, even on error.
                with conn.cursor() as cursor:
                    cursor.execute(sql, (database, table["tableName"]))
                    results = cursor.fetchall()
                fieldDict[table["alias"]] = [row[0].lower() for row in results]
        finally:
            conn.close()
        return fieldDict
313 
if __name__ == "__main__":
    # Run the code generator when this file is executed as a script.
    generator = Generator()
    generator.run()

  generatorConfig.xml

 1 <?xml version="1.0" encoding="utf-8"?>
 2 <generatorConfiguration>
 3     <jdbcConnection>
 4         <property name="host">127.0.0.1</property>
 5         <property name="database">rcddup</property>
 6         <property name="port">3306</property>
 7         <property name="user">root</property>
 8         <property name="password">root</property>
 9         <property name="charset">UTF-8</property>
10     </jdbcConnection>
11     <!-- targetPath 文件生成路径 -->
12     <pythonGenerator targetPath="cn.rcddup.entity"></pythonGenerator>
13     <daoGenerator targetPath="cn.rcddup.dao"></daoGenerator>
14     <serviceGenerator targetPath="cn.rcddup.service"></serviceGenerator>
15     <controllerGenerator targetPath="cn.rcddup.controller"> </controllerGenerator>
16 
 17     <!-- name:数据库表名,alias:生成的 class 类名 -->
18     <table name="t_user" alias="User" ></table>
19 </generatorConfiguration>

  到这里,最近一段时间的 python 学习成果就分享完了,有兴趣的可以加群:626787819。如果你是小白,可以来这里提问;如果你是大牛,希望不要嫌弃我们小白,一起交流学习。

  本程序代码在 github 上可以下载,下载地址:https://github.com/ruancheng77/baseDao
  

  2017-05-20

 

DBUtils 和 pymysql 结合的简单封装