This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new e3d7f7c8d8 [feature](Nereids) add test framework for cost model
(#17071)
e3d7f7c8d8 is described below
commit e3d7f7c8d8bf7e665b1479cfc64691141f0ccc0c
Author: 谢健 <[email protected]>
AuthorDate: Tue Feb 28 20:59:07 2023 +0800
[feature](Nereids) add test framework for cost model (#17071)
Add a test framework for the cost model, following the paper "Testing the
Accuracy of Query Optimizers".
---
tools/cost_model_evaluate/README.MD | 23 +++++++
tools/cost_model_evaluate/config.py | 40 ++++++++++++
tools/cost_model_evaluate/evaluator.py | 91 +++++++++++++++++++++++++++
tools/cost_model_evaluate/index_calculator.py | 69 ++++++++++++++++++++
tools/cost_model_evaluate/main.py | 61 ++++++++++++++++++
tools/cost_model_evaluate/requirements.txt | 17 +++++
tools/cost_model_evaluate/sql_executor.py | 69 ++++++++++++++++++++
7 files changed, 370 insertions(+)
diff --git a/tools/cost_model_evaluate/README.MD
b/tools/cost_model_evaluate/README.MD
new file mode 100644
index 0000000000..2cc9daf45d
--- /dev/null
+++ b/tools/cost_model_evaluate/README.MD
@@ -0,0 +1,23 @@
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+This tool evaluates the accuracy of the cost model in Doris.
+Configure the connection settings and the query to evaluate in main.py.
+
+Before running, install the libraries listed in requirements.txt.
\ No newline at end of file
diff --git a/tools/cost_model_evaluate/config.py
b/tools/cost_model_evaluate/config.py
new file mode 100644
index 0000000000..bfe378a8d2
--- /dev/null
+++ b/tools/cost_model_evaluate/config.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from dataclasses import dataclass
+
+@dataclass
+class Config:
+    """Connection and evaluation settings for one cost-model evaluation run."""
+
+    # user for mysql client
+    user: str
+    # password for mysql client
+    password: str
+    # host of mysql client
+    host: str
+    # port of mysql client
+    port: int
+    # database that the evaluated query runs against
+    database: str
+    # execute times for one plan of the query. Note a query can generate multiple plans
+    # NOTE(review): no code in this commit reads this field - confirm intended use
+    execute_times: int
+    # the number of plans to generate for one query. Note if the number > the possible plans,
+    # we will only use the valid plans.
+    plan_number: int
+    # whether to plot the relation between cost and time
+    plot: bool
+    # how many times to run the query before the real evaluation, to avoid cold-run effects
+    cold_run: int
\ No newline at end of file
diff --git a/tools/cost_model_evaluate/evaluator.py
b/tools/cost_model_evaluate/evaluator.py
new file mode 100644
index 0000000000..e963a183ed
--- /dev/null
+++ b/tools/cost_model_evaluate/evaluator.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.command.config import config
+from config import Config
+from index_calculator import IndexCalculator
+from sql_executor import SQLExecutor
+import matplotlib.pyplot as plt
+
+
+class Evaluator:
+ def __init__(self, config: Config, query: str) -> None:
+ self.config = config
+ self.query = query.lower()
+ self.setup_queries = [
+ "set enable_nereids_planner=true;",
+ "set enable_fallback_to_original_planner=false;",
+ "set enable_profile=true;"
+ ]
+ self.sql_executor = SQLExecutor(
+ config.user,
+ config.password,
+ config.host,
+ config.port,
+ config.database)
+
+ def cold_run(self):
+ for _ in range(self.config.cold_run):
+ self.sql_executor.execute_query(self.query, None)
+
+ def evaluate(self):
+ self.setup()
+ self.cold_run()
+ plans = self.extract_all_plans()
+ res: list[tuple[float, float]] = []
+ for n, (plan, cost) in plans.items():
+ time = self.sql_executor.get_execute_time(plan)
+ res.append((cost, time))
+ if self.config.plot:
+ self.plot(res)
+ print(res)
+ index_calculator = IndexCalculator(res)
+ return index_calculator.calculate()
+
+ def plot(self, data):
+ x_values = [t[0] for t in data]
+ y_values = [t[1] for t in data]
+ fig, ax = plt.subplots()
+ ax.scatter(x_values, y_values)
+ ax.set_xlabel('Cost')
+ ax.set_ylabel('Time')
+ plt.show()
+
+ def setup(self):
+ for q in self.setup_queries:
+ self.sql_executor.execute_query(q, None)
+
+ def extract_all_plans(self):
+ plan_set = set()
+ plan_map: dict[int, tuple[str, float]] = {}
+ for n in range(1, self.config.plan_number):
+ query = self.inject_nth_optimized_hint(n)
+ plan, cost = self.sql_executor.get_plan_with_cost(query)
+ if plan in plan_set:
+ break
+ plan_set.add(plan)
+ plan_map[n] = (query, cost)
+ return plan_map
+
+ def inject_nth_optimized_hint(self, n: int):
+ if ("set_var(" in self.query):
+ query = self.query.replace(
+ "/*+set_var(", f"/*+set_var(nth_optimized_plan={n}, ")
+ else:
+ query = self.query.replace(
+ "select", f"select /*+set_var(nth_optimized_plan={n})*/")
+ return query
diff --git a/tools/cost_model_evaluate/index_calculator.py
b/tools/cost_model_evaluate/index_calculator.py
new file mode 100644
index 0000000000..8422146c9b
--- /dev/null
+++ b/tools/cost_model_evaluate/index_calculator.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import math
+from typing import List, Tuple
+import unittest
+
+# The index is motivated by the paper "Testing the Accuracy of Query Optimizers".
+
+
+class IndexCalculator:
+ def __init__(self, cost_time_list: List[Tuple[float, float]]) -> None:
+ self.cost_time_list = cost_time_list
+ sorted(self.cost_time_list, key=lambda t: t[0])
+ self.max_c = max(self.cost_time_list, key=lambda ct: ct[0])[0]
+ self.min_c = min(self.cost_time_list, key=lambda ct: ct[0])[0]
+ self.max_t = max(self.cost_time_list, key=lambda ct: ct[1])[1]
+ self.min_t = min(self.cost_time_list, key=lambda ct: ct[1])[1]
+
+ def calculate(self) -> float:
+
+ l = len(self.cost_time_list)
+ score = 0.0
+ for j in range(0, l):
+ for i in range(0, j):
+ score += self.weight(i)*self.weight(j) * \
+ self.distance(i, j)*self.sgn(i, j)
+ return score
+
+ def weight(self, i: int) -> float:
+ return self.cost_time_list[0][0]/self.cost_time_list[i][0]
+
+ def distance(self, i: int, j: int) -> float:
+ d0 = (self.cost_time_list[i][0] - self.cost_time_list[j]
+ [0])/(self.max_c - self.min_c + 0.00001)
+ d1 = (self.cost_time_list[i][1] - self.cost_time_list[j]
+ [1])/(self.max_t - self.min_t + 0.00001)
+
+ return math.sqrt(d0*d0 + d1*d1)
+
+ def sgn(self, i: int, j: int) -> float:
+ if self.cost_time_list[j][1] - self.cost_time_list[i][1] >= 0:
+ return 1
+ else:
+ return -1
+
+
+class Test(unittest.TestCase):
+ def test(self):
+ idx_cal = IndexCalculator([(1, 2), (2, 3)])
+ self.assertEqual(round(idx_cal.calculate(), 2), 0.71)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tools/cost_model_evaluate/main.py
b/tools/cost_model_evaluate/main.py
new file mode 100644
index 0000000000..3103fb2316
--- /dev/null
+++ b/tools/cost_model_evaluate/main.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from config import Config
+from evaluator import Evaluator
+
+
+config = Config(
+ "root",
+ "",
+ "127.0.0.1",
+ 9030,
+ "regression_test_nereids_tpch_p0",
+ 2,
+ 50,
+ True,
+ 3
+)
+
+sql = """
+select
+ n_name,
+ sum(l_extendedprice * (1 - l_discount)) as revenue
+from
+ customer,
+ orders,
+ lineitem,
+ supplier,
+ nation,
+ region
+where
+ c_custkey = o_custkey
+ and l_orderkey = o_orderkey
+ and l_suppkey = s_suppkey
+ and c_nationkey = s_nationkey
+ and s_nationkey = n_nationkey
+ and n_regionkey = r_regionkey
+ and r_name = 'ASIA'
+ and o_orderdate >= date '1994-01-01'
+ and o_orderdate < date '1994-01-01' + interval '1' year
+group by
+ n_name
+order by
+ revenue desc;
+"""
+
+print(Evaluator(config, sql).evaluate())
diff --git a/tools/cost_model_evaluate/requirements.txt
b/tools/cost_model_evaluate/requirements.txt
new file mode 100644
index 0000000000..ffbc650c26
--- /dev/null
+++ b/tools/cost_model_evaluate/requirements.txt
@@ -0,0 +1,17 @@
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+matplotlib==3.7.0
+mysql_connector_repackaged==0.3.1
diff --git a/tools/cost_model_evaluate/sql_executor.py
b/tools/cost_model_evaluate/sql_executor.py
new file mode 100644
index 0000000000..511e12c8ad
--- /dev/null
+++ b/tools/cost_model_evaluate/sql_executor.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest import result
+import mysql.connector
+from typing import List, Tuple
+
+
+class SQLExecutor:
+ def __init__(self, user: str, password: str, host: str, port: int,
database: str) -> None:
+ self.connection = mysql.connector.connect(
+ user=user,
+ password=password,
+ host=host,
+ port=port,
+ database=database
+ )
+ self.cursor = self.connection.cursor()
+ self.wait_fetch_time_index = 16
+
+ def execute_query(self, query: str, parameters: Tuple | None) ->
List[Tuple]:
+ if parameters:
+ self.cursor.execute(query, parameters)
+ else:
+ self.cursor.execute(query)
+ results = self.cursor.fetchall()
+ return results
+
+ def get_execute_time(self, query: str) -> float:
+ self.execute_query(query, None)
+ profile = self.execute_query("show query profile\"\"", None)
+ return float(profile[0][self.wait_fetch_time_index].replace("ms", ""))
+
+ def execute_many_queries(self, queries: List[Tuple[str, Tuple]]) ->
List[List[Tuple]]:
+ results = []
+ for query, parameters in queries:
+ result = self.execute_query(query, parameters)
+ results.append(result)
+ return results
+
+ def get_plan_with_cost(self, query: str):
+ result = self.execute_query(f"explain optimized plan {query}", None)
+ cost = float(result[0][0].replace("cost = ", ""))
+ plan = "".join([s[0] for s in result[1:]])
+ return plan, cost
+
+ def commit(self) -> None:
+ self.connection.commit()
+
+ def rollback(self) -> None:
+ self.connection.rollback()
+
+ def close(self) -> None:
+ self.cursor.close()
+ self.connection.close()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]