Skip to content

Commit

Permalink
Merge pull request #143 from diffix/paul-expose-stitch-iface
Browse files Browse the repository at this point in the history
Added stitcher and test code
  • Loading branch information
yoid2000 authored Sep 17, 2024
2 parents 44ae817 + aca3ca6 commit a27d232
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 0 deletions.
2 changes: 2 additions & 0 deletions syndiffix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .stitcher import stitch
from .synthesizer import Synthesizer

__all__ = [
"Synthesizer",
"stitch",
]
83 changes: 83 additions & 0 deletions syndiffix/stitcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Tuple

import numpy as np
import pandas as pd

from .clustering.common import MicrodataRow, StitchOwner
from .clustering.measures import measure_entropy
from .clustering.stitching import StitchingMetadata, _do_stitch
from .clustering.strategy import NoClustering
from .common import ColumnId, Combination
from .synthesizer import Synthesizer


def _make_synthesizers(
df_left: pd.DataFrame,
df_right: pd.DataFrame,
col_names_stitch: list[str],
) -> Tuple[Synthesizer, Synthesizer, Synthesizer]:
syn_left = Synthesizer(df_left, clustering=NoClustering())
syn_right = Synthesizer(df_right, clustering=NoClustering())
df_left_short = df_left.iloc[[0]]
df_right_short = df_right.iloc[[0]].copy()
for col_name in col_names_stitch:
df_right_short.rename(columns={col_name: f"__fake__{col_name}"}, inplace=True)
syn_fake = Synthesizer(pd.concat([df_left_short, df_right_short], axis=1), clustering=NoClustering())
syn_fake.forest.snapped_intervals = syn_left.forest.snapped_intervals + syn_right.forest.snapped_intervals
return syn_fake, syn_left, syn_right


def _make_microdata(
df: pd.DataFrame,
syn: Synthesizer,
columns: list[ColumnId],
) -> tuple[list[MicrodataRow], Combination]:
col_names = list(df.columns)
microdata: list[MicrodataRow] = []

for i in range(len(df)):
row = []
for col_name in col_names:
row.append((df[col_name].iloc[i], float(syn.forest.orig_data[col_name].iloc[i])))
microdata.append(row)
return (microdata, tuple(columns))


def stitch(df_left: pd.DataFrame, df_right: pd.DataFrame, shared: bool = True) -> pd.DataFrame:
# Make the needed column names with the original names (not later rename)
col_names_left = list(df_left.columns)
col_names_right = list(df_right.columns)
col_names_stitch = list(set(col_names_left) & set(col_names_right))
col_names_right_minus_stitch = [col for col in col_names_right if col not in col_names_stitch]

syn_fake, syn_left, syn_right = _make_synthesizers(df_left, df_right, col_names_stitch)
entropy_1dim_left = np.array(
[measure_entropy(syn_left.forest.get_tree((ColumnId(i),))) for i in range(len(syn_left.forest.columns))],
dtype=float,
)
entropy_1dim_right = np.array(
[measure_entropy(syn_right.forest.get_tree((ColumnId(i),))) for i in range(len(syn_right.forest.columns))],
dtype=float,
)
stitching_metadata = StitchingMetadata(
syn_left.column_is_integral + syn_right.column_is_integral,
np.concatenate((entropy_1dim_left, entropy_1dim_right)),
)

col_names_all = syn_fake.forest.columns
columns_left = [ColumnId(col_names_all.index(col_name)) for col_name in col_names_left]
columns_right = [ColumnId(col_names_all.index(col_name)) for col_name in col_names_right]
columns_right_minus_stitch = [ColumnId(col_names_all.index(col_name)) for col_name in col_names_right_minus_stitch]
columns_stitch = [ColumnId(col_names_all.index(col_name)) for col_name in col_names_stitch]
owner = StitchOwner.SHARED if shared else StitchOwner.LEFT
derived_cluster = (owner, columns_stitch, columns_right_minus_stitch)

microdata_left = _make_microdata(df_left, syn_left, columns_left)
microdata_right = _make_microdata(df_right, syn_right, columns_right)

(microdata, columns) = _do_stitch(
syn_fake.forest, stitching_metadata, microdata_left, microdata_right, derived_cluster
)
col_names = [col_names_all[col_id] for col_id in columns]
data = [[tup[0] for tup in row] for row in microdata]
return pd.DataFrame(data, columns=col_names)
63 changes: 63 additions & 0 deletions tests/test_stitcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
import pandas as pd

from syndiffix import Synthesizer
from syndiffix.stitcher import stitch


def make_dataframe(N: int) -> pd.DataFrame:
a = np.random.randint(100, 200, size=N)
b = a + np.random.randint(200, 211, size=N)
c = b + np.random.uniform(2.0, 4.0, size=N)
d = np.random.choice(["x", "y", "z"], size=N)
e = np.random.choice(["i", "j", "k"], size=N)
return pd.DataFrame({"i1": a, "i2": b, "f": c, "t1": d, "t2": e})


df = make_dataframe(500)


def test_left_stitch1() -> None:
# Make dataframes with a single stitchable shared column (close but not identical)
df_left = Synthesizer(df[["i1", "i2", "t1"]]).sample()
df_right = Synthesizer(df[["i1", "f", "t2"]]).sample()
# Ensure that the dataframes are of different length
df_right = df_right.sample(n=len(df_left) - 10, random_state=42)
df_stitched = stitch(df_left=df_left, df_right=df_right, shared=False)
assert set(df_left["i1"]) == set(df_stitched["i1"])
if set(df_left["i1"]) != set(df_right["i1"]):
assert set(df_right["i1"]) != set(df_stitched["i1"])
assert len(df_left) == len(df_stitched)


def test_left_stitch2() -> None:
df_left = Synthesizer(df[["i1", "i2", "t1"]]).sample()
df_right = Synthesizer(df[["i1", "i2", "t2"]]).sample()
df_left = df_left.sample(n=len(df_right) - 10, random_state=42)
df_stitched = stitch(df_left=df_left, df_right=df_right, shared=False)
for i in ["i1", "i2"]:
assert set(df_left[i]) == set(df_stitched[i])
if set(df_left[i]) != set(df_right[i]):
assert set(df_right[i]) != set(df_stitched[i])
assert len(df_left) == len(df_stitched)


def test_left_stitch1_1() -> None:
df_left = Synthesizer(df[["i1"]]).sample()
df_right = Synthesizer(df[["i1", "f", "t2"]]).sample()
df_right = df_right.sample(n=len(df_left) - 10, random_state=42)
df_stitched = stitch(df_left=df_left, df_right=df_right, shared=False)
assert set(df_left["i1"]) == set(df_stitched["i1"])
if set(df_left["i1"]) != set(df_right["i1"]):
assert set(df_right["i1"]) != set(df_stitched["i1"])
assert set(df_right["f"]) != set(df_stitched["f"])
assert len(df_left) == len(df_stitched)


def test_shared_stitch1() -> None:
df_left = Synthesizer(df[["i1", "i2", "t1"]]).sample()
df_right = Synthesizer(df[["i1", "f", "t2"]]).sample()
df_right = df_right.sample(n=len(df_left) - 10, random_state=42)
df_stitched = stitch(df_left=df_left, df_right=df_right, shared=True)
assert len(df_left) > len(df_stitched)
assert len(df_right) < len(df_stitched)

0 comments on commit a27d232

Please sign in to comment.