diff --git a/sqle/driver/mysql/audit_test.go b/sqle/driver/mysql/audit_test.go index 39efc08fe..502557b53 100644 --- a/sqle/driver/mysql/audit_test.go +++ b/sqle/driver/mysql/audit_test.go @@ -6391,6 +6391,10 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect2 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select * from exist_tb_1")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex)) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN SELECT COUNT(1) FROM `exist_tb_1`")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_1", nil, "index", "idx_v1", "idx_v1", 5, nil, 1000, 100.00, "Using index")) handler.ExpectQuery(regexp.QuoteMeta("SELECT COUNT(1) FROM `exist_tb_1`")). WillReturnRows(sqlmock.NewRows([]string{"COUNT(1)"}).AddRow("100")) runSingleRuleInspectCase(rule, t, "", inspect2, "select * from exist_tb_1", newTestResult()) @@ -6398,8 +6402,10 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect3 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select * from exist_tb_1 where id=1")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeAll)) - handler.ExpectQuery(regexp.QuoteMeta("SELECT COUNT(1) FROM `exist_tb_1` WHERE `id`=1")). - WillReturnRows(sqlmock.NewRows([]string{"COUNT(1)"}).AddRow("100")) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN SELECT COUNT(1) FROM `exist_tb_1` WHERE `id`=1")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_1", nil, "ALL", nil, nil, nil, nil, 1000, 100.00, "Using where")) runSingleRuleInspectCase(rule, t, "", inspect3, "select * from exist_tb_1 where id=1", newTestResult()) inspect4 := NewMockInspect(e) @@ -6410,6 +6416,10 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect5 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select * from exist_tb_3 where v2='b'")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex)) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN SELECT COUNT(1) FROM `exist_tb_3` WHERE `v2`='b'")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_3", nil, "index", "idx_v2", "idx_v2", 5, nil, 1000, 100.00, "Using where")) handler.ExpectQuery(regexp.QuoteMeta("SELECT COUNT(1) FROM `exist_tb_3` WHERE `v2`='b'")). WillReturnRows(sqlmock.NewRows([]string{"COUNT(1)"}).AddRow("100000000")) runSingleRuleInspectCase(rule, t, "", inspect5, "select * from exist_tb_3 where v2='b'", newTestResult().addResult(rulepkg.DMLCheckSelectRows)) @@ -6417,6 +6427,11 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect6 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select * from exist_tb_2 where user_id in (select v3 from exist_tb_3)")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range")) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN SELECT COUNT(1) FROM `exist_tb_2` WHERE `user_id` IN (SELECT `v3` FROM `exist_tb_3`)")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_user_id", "idx_user_id", 5, nil, 1000, 100.00, "Using where"). + AddRow(2, "SIMPLE", "exist_tb_3", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where")) handler.ExpectQuery(regexp.QuoteMeta("SELECT COUNT(1) FROM `exist_tb_2` WHERE `user_id` IN (SELECT `v3` FROM `exist_tb_3`)")). WillReturnRows(sqlmock.NewRows([]string{"COUNT(1)"}).AddRow("100000000")) runSingleRuleInspectCase(rule, t, "", inspect6, "select * from exist_tb_2 where user_id in (select v3 from exist_tb_3)", newTestResult().addResult(rulepkg.DMLCheckSelectRows)) @@ -6424,6 +6439,11 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect7 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 limit 10, 10")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range")) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` LIMIT 10,10) as t")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_v1", "idx_v1", 5, nil, 1000, 100.00, "Using where"). + AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where")) handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` LIMIT 10,10) as t")). WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("100000000")) runSingleRuleInspectCase(rule, t, "", inspect7, "select id, v1 as id from exist_tb_2 limit 10, 10", newTestResult().addResult(rulepkg.DMLCheckSelectRows)) @@ -6431,6 +6451,11 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect8 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 group by id, v1")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range")) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`,`v1`) as t")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_v1", "idx_v1", 5, nil, 1000, 100.00, "Using where"). + AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where")) handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`,`v1`) as t")). WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("100000000")) runSingleRuleInspectCase(rule, t, "", inspect8, "select id, v1 as id from exist_tb_2 group by id, v1", newTestResult().addResult(rulepkg.DMLCheckSelectRows)) @@ -6438,6 +6463,11 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect9 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 limit 10, 10")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range")) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` LIMIT 10,10) as t")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_v1", "idx_v1", 5, nil, 1000, 100.00, "Using where"). + AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where")) handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` LIMIT 10,10) as t")). WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10")) runSingleRuleInspectCase(rule, t, "", inspect9, "select id, v1 as id from exist_tb_2 limit 10, 10", newTestResult()) @@ -6445,6 +6475,11 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect10 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 group by id, v1")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range")) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`,`v1`) as t")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_v1", "idx_v1", 5, nil, 1000, 100.00, "Using where"). + AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where")) handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`,`v1`) as t")). WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10")) runSingleRuleInspectCase(rule, t, "", inspect10, "select id, v1 as id from exist_tb_2 group by id, v1", newTestResult()) @@ -6452,6 +6487,11 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect11 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select max(v1) from exist_tb_2 group by id")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range")) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_id", "idx_id", 5, nil, 1000, 100.00, "Using where"). + AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where")) handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")). WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10")) runSingleRuleInspectCase(rule, t, "", inspect11, "select max(v1) from exist_tb_2 group by id", newTestResult()) @@ -6459,6 +6499,11 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect12 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select max(v1) from exist_tb_2 group by id")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range")) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_id", "idx_id", 5, nil, 1000, 100.00, "Using where"). + AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where")) handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")). WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10000000")) runSingleRuleInspectCase(rule, t, "", inspect12, "select max(v1) from exist_tb_2 group by id", newTestResult().addResult(rulepkg.DMLCheckSelectRows)) @@ -6466,6 +6511,11 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect13 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select max(v1) as id, id from exist_tb_2 group by id")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range")) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_id", "idx_id", 5, nil, 1000, 100.00, "Using where"). + AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where")) handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")). WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10")) runSingleRuleInspectCase(rule, t, "", inspect13, "select max(v1) as id, id from exist_tb_2 group by id", newTestResult()) @@ -6473,10 +6523,14 @@ func TestDMLCheckSelectRows(t *testing.T) { inspect14 := NewMockInspect(e) handler.ExpectQuery(regexp.QuoteMeta("select max(v1) as id, id from exist_tb_2 group by id")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range")) + // 添加 EXPLAIN 结果 + handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")). + WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}). + AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_id", "idx_id", 5, nil, 1000, 100.00, "Using where"). + AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where")) handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")). WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10000000")) runSingleRuleInspectCase(rule, t, "", inspect14, "select max(v1) as id, id from exist_tb_2 group by id", newTestResult().addResult(rulepkg.DMLCheckSelectRows)) - } func TestDMLCheckScanRows(t *testing.T) { diff --git a/sqle/driver/mysql/mysql.go b/sqle/driver/mysql/mysql.go index 701e23b29..195024a1d 100644 --- a/sqle/driver/mysql/mysql.go +++ b/sqle/driver/mysql/mysql.go @@ -491,7 +491,7 @@ func (i *MysqlDriverImpl) EstimateSQLAffectRows(ctx context.Context, sql string) return nil, err } - num, err := util.GetAffectedRowNum(ctx, sql, conn) + num, err := util.GetAffectedRowNum(ctx, sql, conn, i.Ctx.GetExecutionPlan) if err != nil && errors.Is(err, util.ErrUnsupportedSqlType) { return &driverV2.EstimatedAffectRows{ErrMessage: err.Error()}, nil } diff --git a/sqle/driver/mysql/rule/rule.go b/sqle/driver/mysql/rule/rule.go index 3d4b86005..56d44b52e 100644 --- a/sqle/driver/mysql/rule/rule.go +++ b/sqle/driver/mysql/rule/rule.go @@ -4621,7 +4621,7 @@ func checkAffectedRows(input *RuleHandlerInput) error { } affectCount, err := util.GetAffectedRowNum( - context.TODO(), input.Node.Text(), input.Ctx.GetExecutor()) + context.TODO(), input.Node.Text(), input.Ctx.GetExecutor(), input.Ctx.GetExecutionPlan) if err != nil { log.NewEntry().Errorf("rule: %v; SQL: %v; get affected row number failed: %v", input.Rule.Name, input.Node.Text(), err) return nil @@ -5193,7 +5193,7 @@ func checkSelectRows(input *RuleHandlerInput) error { if !notUseIndex { return nil } - affectCount, err := util.GetAffectedRowNum(context.TODO(), input.Node.Text(), input.Ctx.GetExecutor()) + affectCount, err := util.GetAffectedRowNum(context.TODO(), input.Node.Text(), input.Ctx.GetExecutor(), input.Ctx.GetExecutionPlan) if err != nil { return err } diff --git a/sqle/driver/mysql/util/util.go b/sqle/driver/mysql/util/util.go index d95a04c29..0480362f3 100644 --- a/sqle/driver/mysql/util/util.go +++ b/sqle/driver/mysql/util/util.go @@ -21,14 +21,14 @@ import ( var ErrUnsupportedSqlType = errors.New("unsupported sql type") -func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Executor) (int64, error) { +func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Executor, explainRecordFunc func(string) ([]*executor.ExplainRecord, error)) (int64, error) { node, err := ParseOneSql(originSql) if err != nil { return 0, err } var newNode ast.Node - var affectRowSql string + var affectedRowSql string var cannotConvert bool // 语法规则文档 @@ -77,7 +77,7 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe } // 移除后缀分号,避免sql语法错误 trimSuffix := strings.TrimRight(newSql, ";") - affectRowSql = fmt.Sprintf("select count(*) from (%s) as t", trimSuffix) + affectedRowSql = fmt.Sprintf("select count(*) from (%s) as t", trimSuffix) } else { sqlBuilder := new(strings.Builder) err = newNode.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, sqlBuilder)) @@ -85,17 +85,50 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe return 0, err } - affectRowSql = sqlBuilder.String() + affectedRowSql = sqlBuilder.String() } // 验证sql语法是否正确,select 字段是否有且仅有 count(*) // 避免在客户机器上执行不符合预期的sql语句 - err = checkSql(affectRowSql) + err = checkSql(affectedRowSql) if err != nil { - return 0, fmt.Errorf("check sql(%v) failed, origin sql(%v), err: %v", affectRowSql, originSql, err) + return 0, fmt.Errorf("check sql(%v) failed, origin sql(%v), err: %v", affectedRowSql, originSql, err) } - _, row, err := conn.Db.QueryWithContext(ctx, affectRowSql) + // explain 全表扫描 (type 为 ALL): 避免执行 SELECT COUNT(1),直接拿EXPLAIN影响行数作为结果 + // 索引访问 ( type 非ALL)如果 rows 较小(小于10W),可以执行 SELECT COUNT(1)。否则依然拿EXPLAIN影响行数作为结果 + // | id | select_type | table | type | possible_keys | key | key_len | ref | rows | Extra | + // |----|-------------|-----------|-------|---------------|---------|---------|---------------|--------|-------------| + // | 1 | SIMPLE | o | ref | idx_status | idx_status | 10 | const | 5000 | Using where | + // | 1 | SIMPLE | c | eq_ref | PRIMARY | PRIMARY | 4 | orders.customer_id | 1 | | + + epRecords, err := explainRecordFunc(affectedRowSql) + if err != nil { + log.NewEntry().Errorf("get execution plan failed, sqle: %v, error: %v", originSql, err) + return 0, err + } + + var notUseIndex bool + var affetcCount int64 + var estimatedRows int64 + + // 检查是否所有记录都使用了索引 + for _, record := range epRecords { + if record.Type == executor.ExplainRecordAccessTypeAll { + notUseIndex = true + } + // 统计查询过程中所有的影响行数 + estimatedRows += record.Rows + // 最后一行记录的row作为结果行数 + affetcCount = record.Rows + } + + // 如果有记录未使用索引,或者统计影响行数大于10W + if notUseIndex || estimatedRows > 100000 { + return affetcCount, nil + } + + _, row, err := conn.Db.QueryWithContext(ctx, affectedRowSql) if err != nil { return 0, err } @@ -103,12 +136,12 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe // 如果下发的 SELECT COUNT(1) 的SQL,返回的结果集为空, 则返回0 // 例: SELECT COUNT(1) FROM test LIMIT 10,10 结果集为空 if len(row) == 0 { - log.NewEntry().Errorf("affected row sql(%v) result row count is 0", affectRowSql) + log.NewEntry().Errorf("affected row sql(%v) result row count is 0", affectedRowSql) return 0, nil } if len(row) < 1 { - return 0, fmt.Errorf("affected row sql(%v) result row count(%v) less than 1", affectRowSql, len(row)) + return 0, fmt.Errorf("affected row sql(%v) result row count(%v) less than 1", affectedRowSql, len(row)) } affectCount, err := strconv.ParseInt(row[0][0].String, 10, 64)