Skip to content

Commit

Permalink
fix pytest for the new visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
gesuwen committed Mar 4, 2024
2 parents 05aaad2 + 1d57713 commit b5fa0ea
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
1 change: 0 additions & 1 deletion torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
update_functorch_imports_visitor = _UpdateFunctorchImports()
new_module = new_module.visit(update_functorch_imports_visitor)


if fixes_count == 0 and not update_functorch_imports_visitor.changed:
raise codemod.SkipFile("No changes")

Expand Down
4 changes: 2 additions & 2 deletions torchfix/visitors/vision/models_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ def visit_Import(self, node: cst.Import) -> None:
if (
isinstance(imported_item.name.value, cst.Name)
and imported_item.name.value.value == "torchvision"
and isinstance(imported_item.name.attr, cst.Name)
and imported_item.name.attr.value == "models"
and imported_item.asname is not None
and isinstance(imported_item.asname.name, cst.Name)
and imported_item.asname.name.value == "models"
):
print(imported_item.asname.name.value)
position = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
# print(position)
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
Expand Down

0 comments on commit b5fa0ea

Please sign in to comment.