Skip to content

Commit

Permalink
add unit test and initial logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nimakarimipour committed Nov 1, 2023
1 parent 71cbb5f commit 9812b94
Show file tree
Hide file tree
Showing 7 changed files with 353 additions and 215 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ public static Symbol.MethodSymbol getClosestOverriddenMethod(
return null;
}

/**
* Checks if the given element is a constructor.
*
* @param element The element to check.
* @return true if the given element is a constructor, false otherwise.
*/
public static boolean isConstructor(Element element) {
if (!(element instanceof Symbol.MethodSymbol)) {
return false;
}
return ((Symbol.MethodSymbol) element).name.toString().equals("<init>");
}

/**
* Gets the type of the given element. If the given element is a method, then the return type of
* the method is returned.
Expand All @@ -90,6 +103,9 @@ public static Symbol.MethodSymbol getClosestOverriddenMethod(
* @return The type of the given element.
*/
public static Type getType(Element element) {
if (isConstructor(element)) {
return ((Symbol.MethodSymbol) element).enclClass().type;
}
return element instanceof Symbol.MethodSymbol
? ((Symbol.MethodSymbol) element).getReturnType()
: ((Symbol) element).type;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package edu.ucr.cs.riple.taint.ucrtainting.serialization.location;

import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.tree.JCTree;
import edu.ucr.cs.riple.taint.ucrtainting.serialization.visitors.LocationVisitor;

public class ClassDeclarationLocation extends AbstractSymbolLocation{


public ClassDeclarationLocation(LocationKind kind, Symbol target, JCTree tree) {
super(kind, target, tree);
}

@Override
public <R, P> R accept(LocationVisitor<R, P> v, P p) {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ public enum LocationKind {
PARAMETER,
LOCAL_VARIABLE,
METHOD,
CLASS_DECLARATION,
POLY_METHOD;

public boolean isField() {
Expand All @@ -26,4 +27,8 @@ public boolean isPoly() {
public boolean isLocalVariable() {
return this == LOCAL_VARIABLE;
}

public boolean isClassDeclaration() {
return this == CLASS_DECLARATION;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package edu.ucr.cs.riple.taint.ucrtainting.serialization.visitors;

import com.sun.source.tree.Tree;
import com.sun.tools.javac.util.Context;
import edu.ucr.cs.riple.taint.ucrtainting.FoundRequired;
import edu.ucr.cs.riple.taint.ucrtainting.UCRTaintingAnnotatedTypeFactory;
import edu.ucr.cs.riple.taint.ucrtainting.serialization.Fix;
import java.util.Set;

public class ClassDeclarationVisitor extends SpecializedFixComputer {

public ClassDeclarationVisitor(
Context context, UCRTaintingAnnotatedTypeFactory typeFactory, FixComputer fixComputer) {
super(context, typeFactory, fixComputer);
}

public Set<Fix> compute(Tree node, FoundRequired pair) {
throw new RuntimeException("Not implemented");
// return Set.of();
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
package edu.ucr.cs.riple.taint.ucrtainting.serialization.visitors;

import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.Tree;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.code.Types;
import com.sun.tools.javac.util.Context;
import edu.ucr.cs.riple.taint.ucrtainting.FoundRequired;
import edu.ucr.cs.riple.taint.ucrtainting.UCRTaintingAnnotatedTypeFactory;
import edu.ucr.cs.riple.taint.ucrtainting.serialization.Fix;
import edu.ucr.cs.riple.taint.ucrtainting.serialization.Utility;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Predicate;
import javax.lang.model.element.Element;
import javax.lang.model.type.TypeMirror;
import org.checkerframework.framework.type.AnnotatedTypeFactory;
import org.checkerframework.framework.type.AnnotatedTypeMirror;
import org.checkerframework.framework.util.AnnotatedTypes;
Expand All @@ -22,9 +29,12 @@
*/
public class MethodTypeArgumentFixVisitor extends SpecializedFixComputer {

private final Types types;

public MethodTypeArgumentFixVisitor(
Context context, UCRTaintingAnnotatedTypeFactory factory, FixComputer fixComputer) {
super(context, factory, fixComputer);
this.types = Types.instance(context);
}

@Override
Expand All @@ -47,13 +57,27 @@ public Set<Fix> visitMethodInvocation(MethodInvocationTree node, FoundRequired p
Type paramType = calledMethod.getParameters().get(i).type;
boolean changed = updateAnnotatedTypeMirror(requiredParam, paramType, typeVar);
if (changed) {
ExpressionTree arg = node.getArguments().get(i);
AnnotatedTypeMirror paramAnnotatedType = typeFactory.getAnnotatedType(arg);
FoundRequired newPair =
new FoundRequired(paramAnnotatedType, requiredParam, pair.depth);
TypeMirror paramTypeMirror = paramAnnotatedType.getUnderlyingType();
if ((paramTypeMirror instanceof Type.ClassType)
&& (requiredParam.getUnderlyingType() instanceof Type.ClassType)) {
Type.ClassType nodeClassType = (Type.ClassType) paramTypeMirror;
Type.ClassType requiredClassType = (Type.ClassType) requiredParam.getUnderlyingType();
// check we are type matching a raw type on a class with type.
if (nodeClassType.tsym.type.getTypeArguments().isEmpty()
&& !requiredClassType.tsym.type.getTypeArguments().isEmpty()) {
Set<Fix> onDeclaration =
computeFixesOnClassDeclarationForRawType(arg, newPair, typeVar);
if (!onDeclaration.isEmpty()) {
return onDeclaration;
}
}
}
fixes.addAll(
node.getArguments()
.get(i)
.accept(
new FixComputer(context, typeFactory),
new FoundRequired(
paramsAnnotatedTypeMirrors.get(i), requiredParam, pair.depth)));
node.getArguments().get(i).accept(new FixComputer(context, typeFactory), newPair));
}
}
}
Expand Down Expand Up @@ -181,4 +205,40 @@ private boolean updateAnnotatedTypeMirror(
}
return updated;
}

public Set<Fix> computeFixesOnClassDeclarationForRawType(
Tree node, FoundRequired pair, Type.TypeVar typeVar) {
Type type = Utility.getType(TreeUtils.elementFromTree(node));
if (!(type.tsym instanceof Symbol.ClassSymbol)) {
return Set.of();
}
Symbol.ClassSymbol classType = (Symbol.ClassSymbol) type.tsym;
if (!typeFactory.isAnnotatedPackage(type.tsym.packge().fullname.toString())
|| !(pair.required instanceof AnnotatedTypeMirror.AnnotatedDeclaredType)) {
return Set.of();
}
AnnotatedTypeMirror.AnnotatedDeclaredType required =
(AnnotatedTypeMirror.AnnotatedDeclaredType) pair.required;
Type.ClassType requiredType =
(Type.ClassType) ((Type.ClassType) required.getUnderlyingType()).tsym.type;
Type.ClassType inheritedType = locateInheritedTypeOnExtendOrImplement(classType, requiredType);
if(inheritedType == null){
return Set.of();
}

// We intentionally limit the search to only the first level of inheritance. The type must
// either extend or implement the required type explicitly at the declaration.
throw new RuntimeException("Not implemented");
// return Set.of();
}

private Type.ClassType locateInheritedTypeOnExtendOrImplement(Symbol.ClassSymbol classType, Type.ClassType requiredType) {
// Look for interfaces
for (Type type : ((Type.ClassType) classType.type).interfaces_field) {
if(type.tsym.equals(requiredType.tsym)){
return (Type.ClassType) type;
}
}
return null;
}
}
178 changes: 89 additions & 89 deletions tests/polytaintserialization/foo/bar/DepthTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,93 +6,93 @@

public class DepthTest {

String foo1(String param) {
String s = param;
String ans = foo2(s);
return ans;
}

String foo2(String param) {
String s = param;
String ans = foo3(s);
return ans;
}

String foo3(String param) {
String s = param;
String ans = foo4(s);
return ans;
}

String foo4(String param) {
String s = param;
String ans = foo5(s);
return ans;
}

String foo5(String param) {
String s = param;
String ans = foo6(s);
return ans;
}

String foo6(String param) {
String s = param;
String ans = foo7(s);
return ans;
}

String foo7(String param) {
String s = param;
String ans = foo8(s);
return ans;
}

String foo8(String param) {
String s = param;
String ans = foo9(s);
return ans;
}

String foo9(String param) {
String s = param;
String ans = foo10(s);
return ans;
}

String foo10(String param) {
String s = param;
String ans = s;
return ans;
}

String bar1(String param) {
String s = param;
String ans = bar2(s);
return ans;
}

String bar2(String param) {
String s = param;
String ans = bar3(s);
return ans;
}

String bar3(String param) {
String s = param;
String ans = bar4(s);
return ans;
}

String bar4(String param) {
String s = param;
return s;
}

public void test(String param) {
// :: error: assignment
@RUntainted String outOfBound = foo1(param);
// :: error: assignment
@RUntainted String inBound = bar1(param);
}
// String foo1(String param) {
// String s = param;
// String ans = foo2(s);
// return ans;
// }
//
// String foo2(String param) {
// String s = param;
// String ans = foo3(s);
// return ans;
// }
//
// String foo3(String param) {
// String s = param;
// String ans = foo4(s);
// return ans;
// }
//
// String foo4(String param) {
// String s = param;
// String ans = foo5(s);
// return ans;
// }
//
// String foo5(String param) {
// String s = param;
// String ans = foo6(s);
// return ans;
// }
//
// String foo6(String param) {
// String s = param;
// String ans = foo7(s);
// return ans;
// }
//
// String foo7(String param) {
// String s = param;
// String ans = foo8(s);
// return ans;
// }
//
// String foo8(String param) {
// String s = param;
// String ans = foo9(s);
// return ans;
// }
//
// String foo9(String param) {
// String s = param;
// String ans = foo10(s);
// return ans;
// }
//
// String foo10(String param) {
// String s = param;
// String ans = s;
// return ans;
// }
//
// String bar1(String param) {
// String s = param;
// String ans = bar2(s);
// return ans;
// }
//
// String bar2(String param) {
// String s = param;
// String ans = bar3(s);
// return ans;
// }
//
// String bar3(String param) {
// String s = param;
// String ans = bar4(s);
// return ans;
// }
//
// String bar4(String param) {
// String s = param;
// return s;
// }
//
// public void test(String param) {
// // :: error: assignment
// @RUntainted String outOfBound = foo1(param);
// // :: error: assignment
// @RUntainted String inBound = bar1(param);
// }
}
Loading

0 comments on commit 9812b94

Please sign in to comment.