Merge pull request #143 from diffix/paul-expose-stitch-iface
Added stitcher and test code
Showing 3 changed files with 148 additions and 0 deletions.
File 1 of 3 (modified; +2 lines):
@@ -1,5 +1,7 @@
from .stitcher import stitch
from .synthesizer import Synthesizer

__all__ = [
    "Synthesizer",
    "stitch",
]
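Since "stitch" is now listed in __all__ alongside "Synthesizer", the function should be importable from the package root. A minimal sketch, assuming the file above is the package's top-level __init__.py (an assumption, since the diff view does not show the file path):

# Sketch only; assumes the module shown above is syndiffix/__init__.py.
from syndiffix import Synthesizer, stitch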
File 2 of 3 (new file; +83 lines):
@@ -0,0 +1,83 @@
from typing import Tuple

import numpy as np
import pandas as pd

from .clustering.common import MicrodataRow, StitchOwner
from .clustering.measures import measure_entropy
from .clustering.stitching import StitchingMetadata, _do_stitch
from .clustering.strategy import NoClustering
from .common import ColumnId, Combination
from .synthesizer import Synthesizer


def _make_synthesizers(
    df_left: pd.DataFrame,
    df_right: pd.DataFrame,
    col_names_stitch: list[str],
) -> Tuple[Synthesizer, Synthesizer, Synthesizer]:
    # Build one synthesizer per side (clustering disabled), plus a "fake" combined
    # synthesizer whose forest spans all columns. The shared right-hand columns are
    # renamed so the concatenated single-row frame has unique column names.
    syn_left = Synthesizer(df_left, clustering=NoClustering())
    syn_right = Synthesizer(df_right, clustering=NoClustering())
    df_left_short = df_left.iloc[[0]]
    df_right_short = df_right.iloc[[0]].copy()
    for col_name in col_names_stitch:
        df_right_short.rename(columns={col_name: f"__fake__{col_name}"}, inplace=True)
    syn_fake = Synthesizer(pd.concat([df_left_short, df_right_short], axis=1), clustering=NoClustering())
    syn_fake.forest.snapped_intervals = syn_left.forest.snapped_intervals + syn_right.forest.snapped_intervals
    return syn_fake, syn_left, syn_right


def _make_microdata(
    df: pd.DataFrame,
    syn: Synthesizer,
    columns: list[ColumnId],
) -> tuple[list[MicrodataRow], Combination]:
    # Convert a dataframe into the row format expected by _do_stitch: each row is a
    # list of (value, float) pairs, the float taken from the forest's original data.
    col_names = list(df.columns)
    microdata: list[MicrodataRow] = []

    for i in range(len(df)):
        row = []
        for col_name in col_names:
            row.append((df[col_name].iloc[i], float(syn.forest.orig_data[col_name].iloc[i])))
        microdata.append(row)
    return (microdata, tuple(columns))


def stitch(df_left: pd.DataFrame, df_right: pd.DataFrame, shared: bool = True) -> pd.DataFrame:
    # Collect the needed column names up front, using the original names
    # (before the rename that happens inside _make_synthesizers).
    col_names_left = list(df_left.columns)
    col_names_right = list(df_right.columns)
    col_names_stitch = list(set(col_names_left) & set(col_names_right))
    col_names_right_minus_stitch = [col for col in col_names_right if col not in col_names_stitch]

    syn_fake, syn_left, syn_right = _make_synthesizers(df_left, df_right, col_names_stitch)
    # Per-column (1-dimensional) entropies feed the stitching metadata.
    entropy_1dim_left = np.array(
        [measure_entropy(syn_left.forest.get_tree((ColumnId(i),))) for i in range(len(syn_left.forest.columns))],
        dtype=float,
    )
    entropy_1dim_right = np.array(
        [measure_entropy(syn_right.forest.get_tree((ColumnId(i),))) for i in range(len(syn_right.forest.columns))],
        dtype=float,
    )
    stitching_metadata = StitchingMetadata(
        syn_left.column_is_integral + syn_right.column_is_integral,
        np.concatenate((entropy_1dim_left, entropy_1dim_right)),
    )

    # Map column names to their ColumnIds in the combined ("fake") forest and build
    # the derived cluster: stitch on the shared columns, attach the right-only ones.
    col_names_all = syn_fake.forest.columns
    columns_left = [ColumnId(col_names_all.index(col_name)) for col_name in col_names_left]
    columns_right = [ColumnId(col_names_all.index(col_name)) for col_name in col_names_right]
    columns_right_minus_stitch = [ColumnId(col_names_all.index(col_name)) for col_name in col_names_right_minus_stitch]
    columns_stitch = [ColumnId(col_names_all.index(col_name)) for col_name in col_names_stitch]
    owner = StitchOwner.SHARED if shared else StitchOwner.LEFT
    derived_cluster = (owner, columns_stitch, columns_right_minus_stitch)

    microdata_left = _make_microdata(df_left, syn_left, columns_left)
    microdata_right = _make_microdata(df_right, syn_right, columns_right)

    (microdata, columns) = _do_stitch(
        syn_fake.forest, stitching_metadata, microdata_left, microdata_right, derived_cluster
    )
    # Drop the float companions and rebuild a dataframe in the stitched column order.
    col_names = [col_names_all[col_id] for col_id in columns]
    data = [[tup[0] for tup in row] for row in microdata]
    return pd.DataFrame(data, columns=col_names)
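For context, here is a short usage sketch of the new interface (not part of the commit). It follows the pattern used by the tests below; the input table, column names, and sizes are invented for illustration, and Synthesizer(...).sample() is the existing syndiffix API as exercised in those tests.

# Usage sketch (not from the commit): stitch two independently synthesized
# dataframes that share the "age" column. Data and column names are invented.
import numpy as np
import pandas as pd

from syndiffix import Synthesizer
from syndiffix.stitcher import stitch

rng = np.random.default_rng(0)
original = pd.DataFrame(
    {
        "age": rng.integers(20, 70, size=300),
        "income": rng.normal(50_000, 10_000, size=300),
        "city": rng.choice(["a", "b", "c"], size=300),
    }
)

# Synthesize two column subsets that overlap on "age".
df_left = Synthesizer(original[["age", "income"]]).sample()
df_right = Synthesizer(original[["age", "city"]]).sample()

# shared=False makes the left side the owner (StitchOwner.LEFT); per the tests
# below, the result then keeps the left frame's row count. shared=True marks
# the shared columns as StitchOwner.SHARED instead.
df_stitched = stitch(df_left=df_left, df_right=df_right, shared=False)
print(sorted(df_stitched.columns))       # ['age', 'city', 'income']
print(len(df_stitched) == len(df_left))  # True when shared=False

With shared=True, the stitched length can instead fall between the two input lengths, as the final test below asserts.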
File 3 of 3 (new file; +63 lines):
@@ -0,0 +1,63 @@
import numpy as np
import pandas as pd

from syndiffix import Synthesizer
from syndiffix.stitcher import stitch


def make_dataframe(N: int) -> pd.DataFrame:
    a = np.random.randint(100, 200, size=N)
    b = a + np.random.randint(200, 211, size=N)
    c = b + np.random.uniform(2.0, 4.0, size=N)
    d = np.random.choice(["x", "y", "z"], size=N)
    e = np.random.choice(["i", "j", "k"], size=N)
    return pd.DataFrame({"i1": a, "i2": b, "f": c, "t1": d, "t2": e})


df = make_dataframe(500)


def test_left_stitch1() -> None:
    # Make dataframes with a single stitchable shared column (close but not identical)
    df_left = Synthesizer(df[["i1", "i2", "t1"]]).sample()
    df_right = Synthesizer(df[["i1", "f", "t2"]]).sample()
    # Ensure that the dataframes have different lengths
    df_right = df_right.sample(n=len(df_left) - 10, random_state=42)
    df_stitched = stitch(df_left=df_left, df_right=df_right, shared=False)
    assert set(df_left["i1"]) == set(df_stitched["i1"])
    if set(df_left["i1"]) != set(df_right["i1"]):
        assert set(df_right["i1"]) != set(df_stitched["i1"])
    assert len(df_left) == len(df_stitched)


def test_left_stitch2() -> None:
    df_left = Synthesizer(df[["i1", "i2", "t1"]]).sample()
    df_right = Synthesizer(df[["i1", "i2", "t2"]]).sample()
    df_left = df_left.sample(n=len(df_right) - 10, random_state=42)
    df_stitched = stitch(df_left=df_left, df_right=df_right, shared=False)
    for i in ["i1", "i2"]:
        assert set(df_left[i]) == set(df_stitched[i])
        if set(df_left[i]) != set(df_right[i]):
            assert set(df_right[i]) != set(df_stitched[i])
    assert len(df_left) == len(df_stitched)


def test_left_stitch1_1() -> None:
    df_left = Synthesizer(df[["i1"]]).sample()
    df_right = Synthesizer(df[["i1", "f", "t2"]]).sample()
    df_right = df_right.sample(n=len(df_left) - 10, random_state=42)
    df_stitched = stitch(df_left=df_left, df_right=df_right, shared=False)
    assert set(df_left["i1"]) == set(df_stitched["i1"])
    if set(df_left["i1"]) != set(df_right["i1"]):
        assert set(df_right["i1"]) != set(df_stitched["i1"])
    assert set(df_right["f"]) != set(df_stitched["f"])
    assert len(df_left) == len(df_stitched)


def test_shared_stitch1() -> None:
    df_left = Synthesizer(df[["i1", "i2", "t1"]]).sample()
    df_right = Synthesizer(df[["i1", "f", "t2"]]).sample()
    df_right = df_right.sample(n=len(df_left) - 10, random_state=42)
    df_stitched = stitch(df_left=df_left, df_right=df_right, shared=True)
    # With shared=True, the stitched length falls between the two input lengths.
    assert len(df_left) > len(df_stitched)
    assert len(df_right) < len(df_stitched)