forked from apache/doris
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluator.py
91 lines (82 loc) · 3.15 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
# NOTE(review): `config` from distutils appears unused in this file, and
# distutils was removed from the standard library in Python 3.12 (PEP 632).
from distutils.command.config import config

import matplotlib.pyplot as plt

from config import Config
from index_calculator import IndexCalculator
from sql_executor import SQLExecutor
class Evaluator:
    """Pair the optimizer's cost for each candidate plan with its run time.

    For one query, the evaluator asks the server for successive candidate
    plans (selected through the ``nth_optimized_plan`` session variable,
    injected as a ``set_var`` hint), records the cost reported for each,
    times each hinted query, and hands the ``(cost, time)`` pairs to
    ``IndexCalculator``.
    """

    # Opening of an existing hint comment, e.g. "/*+set_var(" or
    # "/*+ SET_VAR (".  Whitespace and case are tolerated so that a hint
    # that is present is never silently left without the plan selector.
    _SET_VAR_HINT = re.compile(r"/\*\+\s*set_var\s*\(", re.IGNORECASE)
    # The SELECT keyword as a whole word, so identifiers that merely
    # contain "select" (e.g. "selected") are left alone.
    _SELECT_KEYWORD = re.compile(r"\bselect\b", re.IGNORECASE)

    def __init__(self, config: "Config", query: str) -> None:
        """Store the configuration and query and create the SQL executor.

        Args:
            config: Settings object; this class reads ``user``, ``password``,
                ``host``, ``port``, ``database``, ``cold_run``,
                ``plan_number`` and ``plot`` from it.
            query: The SQL text to evaluate.
        """
        self.config = config
        # Keep the query text verbatim: lowercasing it would also rewrite
        # string literals and quoted identifiers and so change what the
        # query computes.  Keyword matching is done case-insensitively.
        self.query = query
        # Session statements issued by setup() before anything is measured.
        self.setup_queries = [
            "set enable_nereids_planner=true;",
            "set enable_fallback_to_original_planner=false;",
            "set enable_profile=true;"
        ]
        self.sql_executor = SQLExecutor(
            config.user,
            config.password,
            config.host,
            config.port,
            config.database)

    def cold_run(self):
        """Execute the unhinted query ``config.cold_run`` times.

        The results are discarded; presumably this warms server-side caches
        before timing starts (confirm against the executor).
        """
        for _ in range(self.config.cold_run):
            self.sql_executor.execute_query(self.query, None)

    def evaluate(self):
        """Run the full evaluation and return the calculator's result.

        Applies the session setup, performs the cold runs, collects every
        distinct candidate plan, times each one, optionally plots the
        points, prints them, and returns ``IndexCalculator(...).calculate()``.
        """
        self.setup()
        self.cold_run()
        plans = self.extract_all_plans()
        # Each entry maps n -> (hinted query, cost); only the values matter.
        res: list[tuple[float, float]] = [
            (cost, self.sql_executor.get_execute_time(hinted_query))
            for hinted_query, cost in plans.values()
        ]
        if self.config.plot:
            self.plot(res)
        print(res)
        return IndexCalculator(res).calculate()

    def plot(self, data):
        """Show a scatter plot of ``(cost, time)`` pairs.

        Args:
            data: Iterable of two-element sequences; element 0 is plotted on
                the x axis ("Cost") and element 1 on the y axis ("Time").
        """
        x_values = [point[0] for point in data]
        y_values = [point[1] for point in data]
        _, ax = plt.subplots()
        ax.scatter(x_values, y_values)
        ax.set_xlabel('Cost')
        ax.set_ylabel('Time')
        plt.show()

    def setup(self):
        """Issue every statement in ``self.setup_queries`` on the executor."""
        for statement in self.setup_queries:
            self.sql_executor.execute_query(statement, None)

    def extract_all_plans(self) -> dict[int, tuple[str, float]]:
        """Collect the distinct candidate plans for the query.

        Returns:
            A mapping from ``n`` to ``(hinted_query, cost)`` for each ``n``
            in ``1..config.plan_number`` tried before a repeated plan was
            seen.
        """
        seen_plans = set()
        plan_map: dict[int, tuple[str, float]] = {}
        # n is 1-based, so the upper bound must include plan_number itself;
        # range(1, plan_number) would try one candidate fewer than requested
        # (and none at all when plan_number == 1).
        for n in range(1, self.config.plan_number + 1):
            hinted_query = self.inject_nth_optimized_hint(n)
            plan, cost = self.sql_executor.get_plan_with_cost(hinted_query)
            if plan in seen_plans:
                # A repeated plan ends the enumeration; presumably the server
                # has no further alternatives beyond this point (confirm).
                break
            seen_plans.add(plan)
            plan_map[n] = (hinted_query, cost)
        return plan_map

    def inject_nth_optimized_hint(self, n: int) -> str:
        """Return the query with ``nth_optimized_plan=n`` added as a hint.

        If the query already carries a ``/*+set_var(...)*/`` hint, the
        setting is prepended to that hint's argument list.  Otherwise a new
        hint comment is inserted after every ``select`` keyword.

        Args:
            n: Ordinal of the candidate plan to request.

        Returns:
            The rewritten SQL text; ``self.query`` itself is not modified.
        """
        setting = f"nth_optimized_plan={n}"
        hinted, replaced = self._SET_VAR_HINT.subn(
            lambda m: f"{m.group(0)}{setting}, ", self.query)
        if replaced:
            return hinted
        return self._SELECT_KEYWORD.sub(
            lambda m: f"{m.group(0)} /*+set_var({setting})*/", self.query)