Skip to content

Commit

Permalink
Merge pull request #2889 from actiontech/imporve_get_sql_affect_row_num
Browse files Browse the repository at this point in the history
Imporve get sql affect row num
  • Loading branch information
LordofAvernus authored Jan 24, 2025
2 parents 0286aaf + 6b07a7d commit 02809c4
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 15 deletions.
60 changes: 57 additions & 3 deletions sqle/driver/mysql/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6391,15 +6391,21 @@ func TestDMLCheckSelectRows(t *testing.T) {
inspect2 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select * from exist_tb_1")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN SELECT COUNT(1) FROM `exist_tb_1`")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_1", nil, "index", "idx_v1", "idx_v1", 5, nil, 1000, 100.00, "Using index"))
handler.ExpectQuery(regexp.QuoteMeta("SELECT COUNT(1) FROM `exist_tb_1`")).
WillReturnRows(sqlmock.NewRows([]string{"COUNT(1)"}).AddRow("100"))
runSingleRuleInspectCase(rule, t, "", inspect2, "select * from exist_tb_1", newTestResult())

inspect3 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select * from exist_tb_1 where id=1")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeAll))
handler.ExpectQuery(regexp.QuoteMeta("SELECT COUNT(1) FROM `exist_tb_1` WHERE `id`=1")).
WillReturnRows(sqlmock.NewRows([]string{"COUNT(1)"}).AddRow("100"))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN SELECT COUNT(1) FROM `exist_tb_1` WHERE `id`=1")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_1", nil, "ALL", nil, nil, nil, nil, 1000, 100.00, "Using where"))
runSingleRuleInspectCase(rule, t, "", inspect3, "select * from exist_tb_1 where id=1", newTestResult())

inspect4 := NewMockInspect(e)
Expand All @@ -6410,73 +6416,121 @@ func TestDMLCheckSelectRows(t *testing.T) {
inspect5 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select * from exist_tb_3 where v2='b'")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN SELECT COUNT(1) FROM `exist_tb_3` WHERE `v2`='b'")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_3", nil, "index", "idx_v2", "idx_v2", 5, nil, 1000, 100.00, "Using where"))
handler.ExpectQuery(regexp.QuoteMeta("SELECT COUNT(1) FROM `exist_tb_3` WHERE `v2`='b'")).
WillReturnRows(sqlmock.NewRows([]string{"COUNT(1)"}).AddRow("100000000"))
runSingleRuleInspectCase(rule, t, "", inspect5, "select * from exist_tb_3 where v2='b'", newTestResult().addResult(rulepkg.DMLCheckSelectRows))

inspect6 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select * from exist_tb_2 where user_id in (select v3 from exist_tb_3)")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN SELECT COUNT(1) FROM `exist_tb_2` WHERE `user_id` IN (SELECT `v3` FROM `exist_tb_3`)")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_user_id", "idx_user_id", 5, nil, 1000, 100.00, "Using where").
AddRow(2, "SIMPLE", "exist_tb_3", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where"))
handler.ExpectQuery(regexp.QuoteMeta("SELECT COUNT(1) FROM `exist_tb_2` WHERE `user_id` IN (SELECT `v3` FROM `exist_tb_3`)")).
WillReturnRows(sqlmock.NewRows([]string{"COUNT(1)"}).AddRow("100000000"))
runSingleRuleInspectCase(rule, t, "", inspect6, "select * from exist_tb_2 where user_id in (select v3 from exist_tb_3)", newTestResult().addResult(rulepkg.DMLCheckSelectRows))

inspect7 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 limit 10, 10")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` LIMIT 10,10) as t")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_v1", "idx_v1", 5, nil, 1000, 100.00, "Using where").
AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where"))
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` LIMIT 10,10) as t")).
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("100000000"))
runSingleRuleInspectCase(rule, t, "", inspect7, "select id, v1 as id from exist_tb_2 limit 10, 10", newTestResult().addResult(rulepkg.DMLCheckSelectRows))

inspect8 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 group by id, v1")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`,`v1`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_v1", "idx_v1", 5, nil, 1000, 100.00, "Using where").
AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where"))
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`,`v1`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("100000000"))
runSingleRuleInspectCase(rule, t, "", inspect8, "select id, v1 as id from exist_tb_2 group by id, v1", newTestResult().addResult(rulepkg.DMLCheckSelectRows))

inspect9 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 limit 10, 10")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` LIMIT 10,10) as t")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_v1", "idx_v1", 5, nil, 1000, 100.00, "Using where").
AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where"))
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` LIMIT 10,10) as t")).
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10"))
runSingleRuleInspectCase(rule, t, "", inspect9, "select id, v1 as id from exist_tb_2 limit 10, 10", newTestResult())

inspect10 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select id, v1 as id from exist_tb_2 group by id, v1")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`,`v1`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_v1", "idx_v1", 5, nil, 1000, 100.00, "Using where").
AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where"))
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`,`v1`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10"))
runSingleRuleInspectCase(rule, t, "", inspect10, "select id, v1 as id from exist_tb_2 group by id, v1", newTestResult())

inspect11 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select max(v1) from exist_tb_2 group by id")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_id", "idx_id", 5, nil, 1000, 100.00, "Using where").
AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where"))
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10"))
runSingleRuleInspectCase(rule, t, "", inspect11, "select max(v1) from exist_tb_2 group by id", newTestResult())

inspect12 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select max(v1) from exist_tb_2 group by id")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_id", "idx_id", 5, nil, 1000, 100.00, "Using where").
AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where"))
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10000000"))
runSingleRuleInspectCase(rule, t, "", inspect12, "select max(v1) from exist_tb_2 group by id", newTestResult().addResult(rulepkg.DMLCheckSelectRows))

inspect13 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select max(v1) as id, id from exist_tb_2 group by id")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_id", "idx_id", 5, nil, 1000, 100.00, "Using where").
AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where"))
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10"))
runSingleRuleInspectCase(rule, t, "", inspect13, "select max(v1) as id, id from exist_tb_2 group by id", newTestResult())

inspect14 := NewMockInspect(e)
handler.ExpectQuery(regexp.QuoteMeta("select max(v1) as id, id from exist_tb_2 group by id")).
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(executor.ExplainRecordAccessTypeIndex).AddRow("range"))
// 添加 EXPLAIN 结果
handler.ExpectQuery(regexp.QuoteMeta("EXPLAIN select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"id", "select_type", "table", "partitions", "type", "possible_keys", "key", "key_len", "ref", "rows", "filtered", "Extra"}).
AddRow(1, "SIMPLE", "exist_tb_2", nil, "index", "idx_id", "idx_id", 5, nil, 1000, 100.00, "Using where").
AddRow(2, "SIMPLE", "exist_tb_2", nil, "range", nil, nil, nil, nil, 1000, 100.00, "Using where"))
handler.ExpectQuery(regexp.QuoteMeta("select count(*) from (SELECT 1 FROM `exist_tb_2` GROUP BY `id`) as t")).
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow("10000000"))
runSingleRuleInspectCase(rule, t, "", inspect14, "select max(v1) as id, id from exist_tb_2 group by id", newTestResult().addResult(rulepkg.DMLCheckSelectRows))

}

func TestDMLCheckScanRows(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion sqle/driver/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ func (i *MysqlDriverImpl) EstimateSQLAffectRows(ctx context.Context, sql string)
return nil, err
}

num, err := util.GetAffectedRowNum(ctx, sql, conn)
num, err := util.GetAffectedRowNum(ctx, sql, conn, i.Ctx.GetExecutionPlan)
if err != nil && errors.Is(err, util.ErrUnsupportedSqlType) {
return &driverV2.EstimatedAffectRows{ErrMessage: err.Error()}, nil
}
Expand Down
4 changes: 2 additions & 2 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -4621,7 +4621,7 @@ func checkAffectedRows(input *RuleHandlerInput) error {
}

affectCount, err := util.GetAffectedRowNum(
context.TODO(), input.Node.Text(), input.Ctx.GetExecutor())
context.TODO(), input.Node.Text(), input.Ctx.GetExecutor(), input.Ctx.GetExecutionPlan)
if err != nil {
log.NewEntry().Errorf("rule: %v; SQL: %v; get affected row number failed: %v", input.Rule.Name, input.Node.Text(), err)
return nil
Expand Down Expand Up @@ -5193,7 +5193,7 @@ func checkSelectRows(input *RuleHandlerInput) error {
if !notUseIndex {
return nil
}
affectCount, err := util.GetAffectedRowNum(context.TODO(), input.Node.Text(), input.Ctx.GetExecutor())
affectCount, err := util.GetAffectedRowNum(context.TODO(), input.Node.Text(), input.Ctx.GetExecutor(), input.Ctx.GetExecutionPlan)
if err != nil {
return err
}
Expand Down
51 changes: 42 additions & 9 deletions sqle/driver/mysql/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ import (

var ErrUnsupportedSqlType = errors.New("unsupported sql type")

func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Executor) (int64, error) {
func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Executor, explainRecordFunc func(string) ([]*executor.ExplainRecord, error)) (int64, error) {
node, err := ParseOneSql(originSql)
if err != nil {
return 0, err
}

var newNode ast.Node
var affectRowSql string
var affectedRowSql string
var cannotConvert bool

// 语法规则文档
Expand Down Expand Up @@ -77,38 +77,71 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe
}
// 移除后缀分号,避免sql语法错误
trimSuffix := strings.TrimRight(newSql, ";")
affectRowSql = fmt.Sprintf("select count(*) from (%s) as t", trimSuffix)
affectedRowSql = fmt.Sprintf("select count(*) from (%s) as t", trimSuffix)
} else {
sqlBuilder := new(strings.Builder)
err = newNode.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, sqlBuilder))
if err != nil {
return 0, err
}

affectRowSql = sqlBuilder.String()
affectedRowSql = sqlBuilder.String()
}

// 验证sql语法是否正确,select 字段是否有且仅有 count(*)
// 避免在客户机器上执行不符合预期的sql语句
err = checkSql(affectRowSql)
err = checkSql(affectedRowSql)
if err != nil {
return 0, fmt.Errorf("check sql(%v) failed, origin sql(%v), err: %v", affectRowSql, originSql, err)
return 0, fmt.Errorf("check sql(%v) failed, origin sql(%v), err: %v", affectedRowSql, originSql, err)
}

_, row, err := conn.Db.QueryWithContext(ctx, affectRowSql)
// explain 全表扫描 (type 为 ALL): 避免执行 SELECT COUNT(1),直接拿EXPLAIN影响行数作为结果
// 索引访问 ( type 非ALL)如果 rows 较小(小于10W),可以执行 SELECT COUNT(1)。否则依然拿EXPLAIN影响行数作为结果
// | id | select_type | table | type | possible_keys | key | key_len | ref | rows | Extra |
// |----|-------------|-----------|-------|---------------|---------|---------|---------------|--------|-------------|
// | 1 | SIMPLE | o | ref | idx_status | idx_status | 10 | const | 5000 | Using where |
// | 1 | SIMPLE | c | eq_ref | PRIMARY | PRIMARY | 4 | orders.customer_id | 1 | |

epRecords, err := explainRecordFunc(affectedRowSql)
if err != nil {
log.NewEntry().Errorf("get execution plan failed, sqle: %v, error: %v", originSql, err)
return 0, err
}

var notUseIndex bool
var affetcCount int64
var estimatedRows int64

// 检查是否所有记录都使用了索引
for _, record := range epRecords {
if record.Type == executor.ExplainRecordAccessTypeAll {
notUseIndex = true
}
// 统计查询过程中所有的影响行数
estimatedRows += record.Rows
// 最后一行记录的row作为结果行数
affetcCount = record.Rows
}

// 如果有记录未使用索引,或者统计影响行数大于10W
if notUseIndex || estimatedRows > 100000 {
return affetcCount, nil
}

_, row, err := conn.Db.QueryWithContext(ctx, affectedRowSql)
if err != nil {
return 0, err
}

// 如果下发的 SELECT COUNT(1) 的SQL,返回的结果集为空, 则返回0
// 例: SELECT COUNT(1) FROM test LIMIT 10,10 结果集为空
if len(row) == 0 {
log.NewEntry().Errorf("affected row sql(%v) result row count is 0", affectRowSql)
log.NewEntry().Errorf("affected row sql(%v) result row count is 0", affectedRowSql)
return 0, nil
}

if len(row) < 1 {
return 0, fmt.Errorf("affected row sql(%v) result row count(%v) less than 1", affectRowSql, len(row))
return 0, fmt.Errorf("affected row sql(%v) result row count(%v) less than 1", affectedRowSql, len(row))
}

affectCount, err := strconv.ParseInt(row[0][0].String, 10, 64)
Expand Down

0 comments on commit 02809c4

Please sign in to comment.