diff --git a/sql/engines/mongo.py b/sql/engines/mongo.py index 08bcef05df..7a0f480a42 100644 --- a/sql/engines/mongo.py +++ b/sql/engines/mongo.py @@ -13,7 +13,6 @@ from pymongo.errors import OperationFailure from dateutil.parser import parse from bson.objectid import ObjectId -from datetime import datetime from . import EngineBase from .models import ResultSet, ReviewSet, ReviewResult @@ -429,7 +428,7 @@ def execute(self, db_name=None, sql=""): line += 1 logger.debug("执行结果:" + r) # 如果执行中有错误 - rz = r.replace(" ", "").replace('"', "").lower() + rz = r.replace(" ", "").replace('"', "") tr = 1 if ( r.lower().find("syntaxerror") >= 0 @@ -438,7 +437,7 @@ def execute(self, db_name=None, sql=""): or rz.find("ReferenceError") >= 0 or rz.find("getErrorWithCode") >= 0 or rz.find("failedtoconnect") >= 0 - or rz.find("Error: field") >= 0 + or rz.find("Error:") >= 0 ): tr = 0 if (rz.find("errmsg") >= 0 or tr == 0) and ( @@ -454,14 +453,36 @@ def execute(self, db_name=None, sql=""): sql=exec_sql, ) else: + try: + r = json.loads(r) + except Exception as e: + logger.info(str(e)) + finally: + methodStr = exec_sql.split(").")[-1].split("(")[0].strip() + if "." in methodStr: + methodStr = methodStr.split(".")[-1] + if methodStr == "insert": + actual_affected_rows = r.get("nInserted", 0) + elif methodStr in ("insertOne", "insertMany"): + actual_affected_rows = r.count("ObjectId") + elif methodStr == "update": + actual_affected_rows = r.get("nModified", 0) + elif methodStr in ("updateOne", "updateMany"): + actual_affected_rows = r.get("modifiedCount", 0) + elif methodStr in ("deleteOne", "deleteMany"): + actual_affected_rows = r.get("deletedCount", 0) + elif methodStr == "remove": + actual_affected_rows = r.get("nRemoved", 0) + else: + actual_affected_rows = 0 # 把结果转换为ReviewSet result = ReviewResult( id=line, errlevel=0, stagestatus="执行结束", - errormessage=r, + errormessage=str(r), execute_time=round(end - start, 6), - actual_affected_rows=0, # todo============这个值需要优化 + affected_rows=actual_affected_rows, sql=exec_sql, ) execute_result.rows += [result] @@ -571,9 +592,9 @@ def execute_check(self, db_name=None, sql=""): check_result.rows += [result] continue else: - methodStr = ( - sql_str.split(".")[-1].split("(")[0].strip() - ) # 最后一个.和括号(之间的字符串作为方法 + methodStr = sql_str.split(").")[-1].split("(")[0].strip() + if "." in methodStr: + methodStr = methodStr.split(".")[-1] if methodStr in is_exist_premise_method and not is_in: check_result.error = "文档不存在" result = ReviewResult( @@ -651,6 +672,75 @@ def execute_check(self, db_name=None, sql=""): sql=check_sql, execute_time=0, ) + if methodStr == "insertOne": + count = 1 + elif methodStr in ("insert", "insertMany"): + insert_str = re.search( + rf"{methodStr}\((.*)\)", sql_str, re.S + ).group(1) + first_char = insert_str.replace(" ", "").replace( + "\n", "" + )[0] + if first_char == "{": + count = 1 + elif first_char == "[": + insert_values = re.search( + r"\[(.*?)\]", insert_str, re.S + ).group(0) + de = JsonDecoder() + insert_values = de.decode(insert_values) + count = len(insert_values) + else: + count = 0 + elif methodStr in ( + "update", + "updateOne", + "updateMany", + "deleteOne", + "deleteMany", + "remove", + ): + if sql_str.find("find(") > 0: + count_sql = sql_str.replace(methodStr, "count") + else: + count_sql = ( + sql_str.replace(methodStr, "find") + ".count()" + ) + query_dict = self.parse_query_sentence(count_sql) + count_sql = f"""db.getCollection("{query_dict["collection"]}").find({query_dict["condition"]}).count()""" + query_result = self.query(db_name, count_sql) + count = json.loads(query_result.rows[0][0]).get( + "count", 0 + ) + if ( + methodStr == "update" + and "multi:true" + not in sql_str.replace(" ", "") + .replace('"', "") + .replace("'", "") + .replace("\n", "") + ) or methodStr in ("deleteOne", "updateOne"): + count = 1 if count > 0 else 0 + if methodStr in ( + "insertOne", + "insert", + "insertMany", + "update", + "updateOne", + "updateMany", + "deleteOne", + "deleteMany", + "remove", + ): + result = ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage="检测通过", + affected_rows=count, + sql=check_sql, + execute_time=0, + ) else: result = ReviewResult( id=line, @@ -1061,7 +1151,7 @@ def parse_tuple(self, cursor, db_name, tb_name, projection=None): dd = re.findall(re_date, str(value)) for d in dd: t = int(d.split(":")[1].strip()[:-1]) - e = datetime.fromtimestamp(t / 1000) + e = datetime.datetime.fromtimestamp(t / 1000) value = str(value).replace( d, e.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] ) diff --git a/sql/engines/tests.py b/sql/engines/tests.py index 2955624109..0948d4e13c 100644 --- a/sql/engines/tests.py +++ b/sql/engines/tests.py @@ -1925,6 +1925,28 @@ def test_execute_check(self, mock_get_all_tables, mock_get_table_conut): check_result.rows[0].__dict__["errormessage"], row.__dict__["errormessage"] ) + @patch("sql.engines.mongo.MongoEngine.get_all_tables") + def test_execute_check_include_dot(self, mock_get_all_tables): + sql = """db.job.insert({ + fileName: "现金明细20230103075728.xls", + contentType: ".xls", + createdTime: ISODate("2023-01-03T12:05:27.402Z"), + reportDate: ISODate("2023-01-03T12:05:27.402Z"), + updatedTime: ISODate("2023-01-03T12:09:30.88Z") + });;""" + mock_get_all_tables.return_value.rows = "job" + check_result = self.engine.execute_check("some_db", sql) + self.assertEqual( + check_result.rows[0].__dict__["stagestatus"], "Audit completed" + ) + + @patch("sql.engines.mongo.MongoEngine.get_all_tables") + def test_execute_check_on_dml(self, mock_get_all_tables): + sql = """db.job.insert([{"orderCode":1001},{"orderCode":1002}]);""" + mock_get_all_tables.return_value.rows = "job" + check_result = self.engine.execute_check("some_db", sql) + self.assertEqual(check_result.rows[0].__dict__["affected_rows"], 2) + @patch("sql.engines.mongo.MongoEngine.exec_cmd") @patch("sql.engines.mongo.MongoEngine.get_master") def test_execute(self, mock_get_master, mock_exec_cmd): @@ -1940,6 +1962,34 @@ def test_execute(self, mock_get_master, mock_exec_cmd): mock_get_master.assert_called_once() self.assertEqual(check_result.rows[0].__dict__["errlevel"], 0) + @patch("sql.engines.mongo.MongoEngine.exec_cmd") + @patch("sql.engines.mongo.MongoEngine.get_master") + def test_execute_on_dml(self, mock_get_master, mock_exec_cmd): + sql = """db.job.insertMany([{"title":"test1"},{"title":test2"},{"title":test3"}]);""" + mock_exec_cmd.return_value = """{ + "acknowledged" : true, + "insertedIds" : [ + ObjectId("63b77b53afab4917dfd48a20"), + ObjectId("63b77b53afab4917dfd48a21"), + ObjectId("63b77b53afab4917dfd48a22") + ] + }""" + + check_result = self.engine.execute("some_db", sql) + mock_get_master.assert_called_once() + self.assertEqual(check_result.rows[0].__dict__["affected_rows"], 3) + + @patch("sql.engines.mongo.MongoEngine.exec_cmd") + @patch("sql.engines.mongo.MongoEngine.get_master") + def test_execute_return_error(self, mock_get_master, mock_exec_cmd): + sql = """db.job.insertMany({"title":"test1"},{"title":test2"},{"title":test3"});""" + mock_exec_cmd.return_value = ( + """uncaught exception: TypeError: documents.map is not a function""" + ) + check_result = self.engine.execute("some_db", sql) + mock_get_master.assert_called_once() + self.assertEqual(check_result.rows[0].__dict__["stagestatus"], "异常终止") + def test_fill_query_columns(self): columns = ["_id", "title", "tags", "likes"] cursor = [