Skip to content

Commit

Permalink
Add Semantic Deduplication (#9)
Browse files Browse the repository at this point in the history
Add a semantic deduplication operator. Deduplication is performed based on
semantic similarity computed via an embedding model. Pairs of elements whose
similarity exceeds `threshold` are considered duplicates.

Example
```
data = {
    "Text": [
        "Probability and Random Processes",
        "Optimization Methods in Engineering",
        "Digital Design and Integrated Circuits",
        "Computer Security",
        "I don't know what day it is",
        "I don't know what time it is",
        "Harry potter and the Sorcerer's Stone",
    ]
}
df = pd.DataFrame(data)
df = df.sem_index("Text", "index_dir").sem_dedup("Text", threshold=0.815)
print(df)
```

Will print
```
                                    Text
1    Optimization Methods in Engineering
4            I don't know what day it is
6  Harry potter and the Sorcerer's Stone
```
  • Loading branch information
sidjha1 authored Sep 29, 2024
1 parent 9f0db7d commit ff20f9b
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 0 deletions.
22 changes: 22 additions & 0 deletions examples/op_examples/dedup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pandas as pd

import lotus
from lotus.models import E5Model

# Register an E5 model as the retrieval model (`rm`) in the global LOTUS settings.
lotus.settings.configure(rm=E5Model())

# Sample corpus: four course-like titles, two sentences that differ by a
# single word, and one unrelated book title.
texts = [
    "Probability and Random Processes",
    "Optimization Methods in Engineering",
    "Digital Design and Integrated Circuits",
    "Computer Security",
    "I don't know what day it is",
    "I don't know what time it is",
    "Harry potter and the Sorcerer's Stone",
]
df = pd.DataFrame({"Text": texts})

# Index the "Text" column (using "index_dir" as the index location), then
# drop rows whose pairwise similarity score is above 0.815.
indexed_df = df.sem_index("Text", "index_dir")
df = indexed_df.sem_dedup("Text", threshold=0.815)
print(df)
2 changes: 2 additions & 0 deletions lotus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
sem_search,
sem_sim_join,
sem_cluster_by,
sem_dedup,
sem_topk,
)
from lotus.settings import settings
Expand All @@ -36,6 +37,7 @@
"sem_sim_join",
"sem_cluster_by",
"sem_search",
"sem_dedup",
"settings",
"nl_expression",
"templates",
Expand Down
1 change: 1 addition & 0 deletions lotus/sem_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
"sem_sim_join",
"sem_cluster_by",
"sem_partition_by",
"sem_dedup",
]
82 changes: 82 additions & 0 deletions lotus/sem_ops/sem_dedup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from collections import defaultdict

import pandas as pd

import lotus


@pd.api.extensions.register_dataframe_accessor("sem_dedup")
class SemDedupByDataframe:
    """DataFrame accessor for semantic deduplication.

    Registered on ``pandas.DataFrame`` as ``df.sem_dedup(col_name, threshold)``.
    Values of ``col_name`` whose pairwise similarity score is above
    ``threshold`` are treated as duplicates; duplicates are grouped
    transitively and one value per group is kept.
    """

    def __init__(self, pandas_obj):
        self._validate(pandas_obj)
        self._obj = pandas_obj

    @staticmethod
    def _validate(obj):
        """Raise ``AttributeError`` unless ``obj`` is a ``pandas.DataFrame``."""
        if not isinstance(obj, pd.DataFrame):
            raise AttributeError("Must be a DataFrame")

    @staticmethod
    def _find_connected_components(pairs):
        """Group values linked (directly or transitively) by duplicate pairs.

        Args:
            pairs: Iterable of ``(left, right)`` tuples of hashable values.

        Returns:
            list[list]: One list of values per connected component of the
            undirected graph whose edges are ``pairs``.
        """
        graph = defaultdict(set)
        for left_val, right_val in pairs:
            graph[left_val].add(right_val)
            graph[right_val].add(left_val)

        visited = set()
        components = []
        for start in graph:
            if start in visited:
                continue
            # Iterative depth-first traversal; avoids recursion limits on
            # large duplicate groups.
            component = []
            stack = [start]
            while stack:
                current = stack.pop()
                if current in visited:
                    continue
                visited.add(current)
                component.append(current)
                stack.extend(graph[current] - visited)
            components.append(component)

        return components

    def __call__(
        self,
        col_name: str,
        threshold: float,
    ) -> pd.DataFrame:
        """
        Perform semantic deduplication on the DataFrame.

        The DataFrame is similarity-joined with itself on ``col_name``; every
        pair of distinct values scoring strictly above ``threshold`` is an
        edge, and for each connected group of values only the one that
        appears first in the DataFrame is kept.

        Note:
            Pairs with equal left/right values are discarded (this is how
            self-matches are removed), so rows holding the exact same value
            are never deduplicated against each other: they are kept or
            removed together.

        Args:
            col_name (str): The column name to deduplicate on.
            threshold (float): The threshold for similarity score.

        Returns:
            pd.DataFrame: The DataFrame with duplicates removed.
        """
        # Nothing to compare in an empty frame; also avoids asking the
        # similarity join for zero matches per row.
        if self._obj.empty:
            return self._obj.copy()

        left_col_name, right_col_name = f"{col_name}_l", f"{col_name}_r"

        # Self-join requesting len(df) matches per row (presumably "K", i.e.
        # all candidates -- confirm against sem_sim_join's signature).
        joined_df = self._obj.sem_sim_join(self._obj, col_name, col_name, len(self._obj), lsuffix="_l", rsuffix="_r")
        dedup_df = joined_df[joined_df["_scores"] > threshold]
        # Drop pairs whose two sides carry the same value (self-matches).
        dedup_df = dedup_df[dedup_df[left_col_name] != dedup_df[right_col_name]]
        lotus.logger.debug(f"dedup_df: {dedup_df}")

        pairs = set(zip(dedup_df[left_col_name], dedup_df[right_col_name]))

        connected_components = self._find_connected_components(pairs)
        lotus.logger.debug(f"dedup connected components: {connected_components}")

        # Position of each value's first occurrence in the original frame.
        # Used to pick the survivor of every component deterministically:
        # component order derives from set iteration, which varies between
        # processes because string hashing is randomized.
        first_position = {}
        for position, value in enumerate(self._obj[col_name]):
            first_position.setdefault(value, position)
        unknown_position = len(self._obj)

        def position_of(value):
            return first_position.get(value, unknown_position)

        removed_vals = []
        for component in connected_components:
            keeper = min(component, key=position_of)
            removed_vals.extend(value for value in component if value is not keeper)

        return self._obj[~self._obj[col_name].isin(removed_vals)]

0 comments on commit ff20f9b

Please sign in to comment.