Skip to content

Commit

Permalink
feat: improve ruby parsing (#1085)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronlifton authored Jan 15, 2025
1 parent 0df03dd commit 0a837a4
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 19 deletions.
28 changes: 19 additions & 9 deletions crates/avante-repo-map/queries/tree-sitter-ruby-defs.scm
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
;; Capture top-level methods, class definitions, and methods within classes
(program
(class
(body_statement
(call) @class_call
(assignment) @class_assignment
(method) @method
)
) @class
)

(class
(body_statement
(call)? @class_call
(assignment)? @class_assignment
(method)? @method
)
) @class

(program
(method) @function
)
(program
(assignment) @assignment
)

(module) @module

(module
(body_statement
(call)? @class_call
(assignment)? @class_assignment
(method)? @method
)
)
165 changes: 155 additions & 10 deletions crates/avante-repo-map/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub struct Variable {
pub enum Definition {
Func(Func),
Class(Class),
Module(Class),
Enum(Enum),
Variable(Variable),
Union(Union),
Expand Down Expand Up @@ -158,6 +159,24 @@ fn find_descendant_by_type<'a>(node: &'a Node, child_type: &str) -> Option<Node<
None
}

fn ruby_method_is_private<'a>(node: &'a Node, source: &'a [u8]) -> bool {
let mut prev_sibling = node.prev_sibling();
while let Some(prev_sibling_node) = prev_sibling {
if prev_sibling_node.kind() == "identifier" {
let text = prev_sibling_node.utf8_text(source).unwrap_or_default();
if text == "private" {
return true;
} else if text == "public" || text == "protected" {
return false;
}
} else if prev_sibling_node.kind() == "class" || prev_sibling_node.kind() == "module" {
return false;
}
prev_sibling = prev_sibling_node.prev_sibling();
}
false
}

fn find_child_by_type<'a>(node: &'a Node, child_type: &str) -> Option<Node<'a>> {
node.children(&mut node.walk())
.find(|child| child.kind() == child_type)
Expand Down Expand Up @@ -234,6 +253,30 @@ fn ex_find_parent_module_declaration_name<'a>(node: &'a Node, source: &'a [u8])
None
}

/// Builds the `::`-separated namespace path for `node`.
///
/// Starting at `node` itself and following `parent()` links upward, collects
/// the text of the `name` field of every `module` or `class` node encountered
/// (nodes of those kinds without a `name` field contribute nothing). The
/// collected names are emitted outermost first and joined with `::`.
///
/// Returns `None` when no name was collected.
fn ruby_find_parent_module_declaration_name<'a>(
    node: &'a Node,
    source: &'a [u8],
) -> Option<String> {
    // Names are gathered innermost-first because the walk goes child -> root.
    let mut segments: Vec<String> =
        std::iter::successors(Some(*node), |ancestor| ancestor.parent())
            .filter(|ancestor| matches!(ancestor.kind(), "module" | "class"))
            .filter_map(|scope| scope.child_by_field_name("name"))
            .map(|name_node| get_node_text(&name_node, source))
            .collect();

    if segments.is_empty() {
        return None;
    }

    segments.reverse();
    Some(segments.join("::"))
}

/// Returns the text of `node` within `source` as an owned `String`.
///
/// Yields an empty string when `utf8_text` reports an error.
fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String {
    node.utf8_text(source)
        .map(str::to_string)
        .unwrap_or_default()
}
Expand Down Expand Up @@ -301,6 +344,18 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
});
};

let ensure_module_def = |name: &str, class_def_map: &mut BTreeMap<String, RefCell<Class>>| {
class_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Class {
name: name.to_string(),
type_name: "module".to_string(),
methods: vec![],
properties: vec![],
visibility_modifier: None,
})
});
};

let ensure_enum_def = |name: &str, enum_def_map: &mut BTreeMap<String, RefCell<Enum>>| {
enum_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Enum {
Expand Down Expand Up @@ -395,6 +450,19 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
.unwrap_or(node_text)
.to_string()
}
"ruby" => {
let name = node
.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or(node_text)
.to_string();
if *capture_name == "class" || *capture_name == "module" {
ruby_find_parent_module_declaration_name(&node, source.as_bytes())
.unwrap_or(name)
} else {
name
}
}
_ => node
.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
Expand Down Expand Up @@ -423,6 +491,11 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
};
}
}
"module" => {
if !name.is_empty() {
ensure_module_def(&name, &mut class_def_map);
}
}
"enum_item" => {
let visibility_modifier_node =
find_descendant_by_type(&node, "visibility_modifier");
Expand Down Expand Up @@ -623,6 +696,9 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
.and_then(|n| n.utf8_text(source.as_bytes()).ok())
.unwrap_or("")
.to_string()
} else if language == "ruby" {
ruby_find_parent_module_declaration_name(&node, source.as_bytes())
.unwrap_or_default()
} else if let Some(impl_item) = impl_item_node {
let impl_type_node = impl_item.child_by_field_name("type");
impl_type_node
Expand All @@ -649,9 +725,17 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,

let accessibility_modifier_node =
find_descendant_by_type(&node, "accessibility_modifier");
let accessibility_modifier = accessibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
let accessibility_modifier = if language == "ruby" {
if ruby_method_is_private(&node, source.as_bytes()) {
"private"
} else {
""
}
} else {
accessibility_modifier_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("")
};

let func = Func {
name: name.to_string(),
Expand Down Expand Up @@ -679,12 +763,17 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or("");
let value_type = get_node_type(&node, source.as_bytes());
let class_name = get_closest_ancestor_name(&node, source);
if !class_name.is_empty()
&& language == "go"
&& !is_first_letter_uppercase(&class_name)
{
continue;
let mut class_name = get_closest_ancestor_name(&node, source);
if !class_name.is_empty() {
if language == "ruby" {
if let Some(namespaced_name) =
ruby_find_parent_module_declaration_name(&node, source.as_bytes())
{
class_name = namespaced_name;
}
} else if language == "go" && !is_first_letter_uppercase(&class_name) {
continue;
}
}
if class_name.is_empty() {
continue;
Expand Down Expand Up @@ -1057,6 +1146,7 @@ fn stringify_definitions(definitions: &Vec<Definition>) -> String {
for definition in definitions {
match definition {
Definition::Class(class) => res = format!("{res}{}", stringify_class(class)),
Definition::Module(module) => res = format!("{res}{}", stringify_class(module)),
Definition::Enum(enum_def) => res = format!("{res}{}", stringify_enum(enum_def)),
Definition::Union(union_def) => res = format!("{res}{}", stringify_union(union_def)),
Definition::Func(func) => res = format!("{res}{}", stringify_function(func)),
Expand Down Expand Up @@ -1434,7 +1524,62 @@ mod tests {
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
// FIXME:
let expected = "var test_var;func test_func(a, b) -> void;";
let expected = "var test_var;func test_func(a, b) -> void;class InnerClassInFunc{func initialize(a, b) -> void;func test_method(a, b) -> void;};class TestClass{func initialize(a, b) -> void;func test_method(a, b) -> void;};";
assert_eq!(stringified, expected);
}

// Runs `extract_definitions` in "ruby" mode over a fixture containing a
// top-level variable and method, two nested modules (`A`, `A::B`), and a class
// `C` inside them with a constant, an instance variable, two public methods,
// and one method declared after a bare `private` line. The stringified result
// is compared against a fixed expected string.
#[test]
fn test_ruby2() {
let source = r#"
# frozen_string_literal: true
require('jwt')
top_level_var = 1
def top_level_func
inner_var_in_func = 2
end
module A
module B
@module_var = :foo
def module_method
@module_var
end
class C < Base
TEST_CONST = 1
@class_var = :bar
attr_accessor :a, :b
def initialize(a, b)
@a = a
@b = b
super
end
def bar
inner_var_in_method = 1
true
end
private
def baz(request, params)
auth_header = request.headers['Authorization']
parts = auth_header.try(:split, /\s+/)
JWT.decode(parts.last)
end
end
end
end
"#;
let definitions = extract_definitions("ruby", source).unwrap();
let stringified = stringify_definitions(&definitions);
// Printed so the actual output is visible in the test log.
println!("{stringified}");
// Expected output: modules and the class appear under `::`-joined namespaced
// names, `baz` carries the `private` prefix, and locals assigned inside
// method bodies (e.g. `inner_var_in_func`) do not appear.
let expected = "var top_level_var;func top_level_func() -> void;module A{};module A::B{func module_method() -> void;var @module_var;};class A::B::C{func initialize(a, b) -> void;func bar() -> void;private func baz(request, params) -> void;var TEST_CONST;var @class_var;};";
assert_eq!(stringified, expected);
}

Expand Down

0 comments on commit 0a837a4

Please sign in to comment.