Skip to content

Commit

Permalink
use cluster number instead of new iterator
Browse files: browse the repository at this point in the history
  • Loading branch information
singlesp committed Feb 5, 2025
1 parent c4dc38c commit febfe4b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 40 deletions.
87 changes: 62 additions & 25 deletions cubids/tests/test_variant_numbering.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,73 @@
@pytest.fixture
def sample_summary_df():
    """Create a sample summary DataFrame with multiple variant groups.

    All five rows share the same entity set.  Row 0 is the dominant group
    (``ParamGroup`` 1); every other row deviates from it in ``EchoTime``,
    ``RepetitionTime`` or both.  The ``Cluster_EchoTime`` and
    ``Cluster_RepetitionTime`` columns hold the cluster number assigned to
    each row's value.

    Returns
    -------
    pandas.DataFrame
        Summary table with the columns ``EntitySet``, ``ParamGroup``,
        ``EchoTime``, ``RepetitionTime``, ``Cluster_EchoTime`` and
        ``Cluster_RepetitionTime``.
    """
    df = pd.DataFrame(
        {
            "EntitySet": [
                "datatype-dwi_suffix-dwi",  # Dominant group
                "datatype-dwi_suffix-dwi",  # EchoTime variant (cluster 2)
                "datatype-dwi_suffix-dwi",  # EchoTime variant (cluster 3)
                "datatype-dwi_suffix-dwi",  # RepetitionTime variant (cluster 2)
                "datatype-dwi_suffix-dwi",  # Combined variant (both cluster 2)
            ],
            "ParamGroup": [1, 2, 3, 4, 5],
            "EchoTime": ["0.05", "0.03", "0.07", "0.05", "0.03"],
            "RepetitionTime": ["2.5", "2.5", "2.5", "3.0", "3.0"],
        }
    )
    # Add cluster columns; rows with equal parameter values share a cluster.
    df["Cluster_EchoTime"] = [1, 2, 3, 1, 2]
    df["Cluster_RepetitionTime"] = [1, 1, 1, 2, 2]
    return df

def test_variant_numbering_with_clusters(sample_summary_df):
    """Variant names embed the cluster number of each deviating parameter."""
    renamed = assign_variants(sample_summary_df, ["EchoTime", "RepetitionTime"])

    # The dominant group (row 0) must not receive a rename.
    assert pd.isna(renamed.loc[0, "RenameEntitySet"])

    # (row label, fragment expected inside that row's new entity set)
    expected_fragments = (
        (1, "VARIANTEchoTime2"),  # EchoTime cluster 2
        (2, "VARIANTEchoTime3"),  # EchoTime cluster 3
        (3, "VARIANTRepetitionTime2"),  # RepetitionTime cluster 2
        (4, "VARIANTEchoTime2RepetitionTime2"),  # both parameters deviate
    )
    for row, fragment in expected_fragments:
        assert fragment in renamed.loc[row, "RenameEntitySet"]

def test_variant_numbering_mixed_clustering():
    """Test variant numbering with a mix of clustered and non-clustered parameters.

    ``EchoTime`` has a ``Cluster_EchoTime`` column, so its cluster number is
    expected in the variant name.  ``FlipAngle`` has no cluster column, so it
    is expected to appear by name only, with no number attached.
    """
    import re  # local import: only this test needs it

    df = pd.DataFrame(
        {
            "EntitySet": ["datatype-dwi_suffix-dwi"] * 3,
            "ParamGroup": [1, 2, 3],
            "EchoTime": ["0.05", "0.03", "0.07"],
            "FlipAngle": ["90", "45", "90"],
            "Cluster_EchoTime": [1, 2, 3],
        }
    )

    result = assign_variants(df, ["EchoTime", "FlipAngle"])
    new_name = result.loc[1, "RenameEntitySet"]

    # Check that clustered parameter uses cluster number
    assert "VARIANTEchoTime2" in new_name
    # Check that non-clustered parameter appears ...
    assert "FlipAngle" in new_name
    # ... and that it really is unnumbered: a bare substring check would
    # also accept "FlipAngle1".
    assert re.search(r"FlipAngle\d", new_name) is None

def test_variant_numbering_fieldmap():
    """Test variant numbering with fieldmap-related variants.

    Fieldmap columns have no cluster column, so their labels carry no
    number.  Row 1 differs from the dominant group (row 0) in both
    ``HasFieldmap`` and ``UsedAsFieldmap``; ``assign_variants`` sorts the
    variant labels before joining them, so the combined name is
    ``VARIANTIsUsedNoFmap``.  Row 2 differs only in ``HasFieldmap``.
    """
    df = pd.DataFrame(
        {
            "EntitySet": ["datatype-dwi_suffix-dwi"] * 3,
            "ParamGroup": [1, 2, 3],
            "HasFieldmap": ["True", "False", "False"],
            "UsedAsFieldmap": ["False", "True", "False"],
        }
    )

    result = assign_variants(df, ["HasFieldmap", "UsedAsFieldmap"])

    # Both fieldmap columns deviate on row 1.  The labels are concatenated
    # in sorted order, so "VARIANTNoFmap" is not a substring of this name.
    assert "VARIANTIsUsed" in result.loc[1, "RenameEntitySet"]
    assert "VARIANTIsUsedNoFmap" in result.loc[1, "RenameEntitySet"]

    # Only HasFieldmap deviates on row 2, so NoFmap directly follows VARIANT.
    assert "VARIANTNoFmap" in result.loc[2, "RenameEntitySet"]

def test_variant_numbering_basic(sample_summary_df):
"""Test basic variant numbering functionality."""
rename_cols = ["EchoTime", "RepetitionTime", "FlipAngle"]
Expand Down Expand Up @@ -81,21 +133,6 @@ def test_variant_numbering_acquisition_handling():
# Check that acquisition entity is properly handled
assert "acquisition-baseVARIANTEchoTime1" in result.loc[1, "RenameEntitySet"]

def test_variant_numbering_fieldmap():
    """Check the names given to variants that differ in fieldmap usage."""
    # NOTE(review): in this commit view this copy appears to be the
    # pre-change version of the test (expected names end in a counter,
    # "...1") -- confirm against the diff markers.
    fieldmap_summary = {
        "EntitySet": ["datatype-dwi_suffix-dwi"] * 3,
        "ParamGroup": [1, 2, 3],
        "HasFieldmap": ["True", "False", "False"],
        "UsedAsFieldmap": ["False", "True", "False"],
    }
    df = pd.DataFrame(fieldmap_summary)

    result = assign_variants(df, ["HasFieldmap", "UsedAsFieldmap"])

    # Row 1 is expected to carry both fieldmap variant labels.
    for fragment in ("VARIANTNoFmap1", "VARIANTIsUsed1"):
        assert fragment in result.loc[1, "RenameEntitySet"]

def test_variant_numbering_consistency():
"""Test that variant numbering is consistent across multiple runs."""
df = pd.DataFrame({
Expand Down
28 changes: 13 additions & 15 deletions cubids/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,8 +769,6 @@ def assign_variants(summary, rename_cols):
The updated summary DataFrame with a new column "RenameEntitySet"
containing the new entity set names for each file.
"""
# Track variant counts for each parameter
variant_counts = {}

# loop through summary tsv and create dom_dict
dom_dict = {}
Expand Down Expand Up @@ -799,6 +797,7 @@ def assign_variants(summary, rename_cols):

if summary.loc[row, "ParamGroup"] != 1 and not renamed:
variant_params = []
variant_clusters = []
# now we know we have a deviant param group
entity_set = summary.loc[row, "EntitySet"]
for col in rename_cols:
Expand All @@ -808,32 +807,31 @@ def assign_variants(summary, rename_cols):
if f"Cluster_{col}" in dom_entity_set.keys():
if summary.loc[row, f"Cluster_{col}"] != dom_entity_set[f"Cluster_{col}"]:
variant_params.append(col)
variant_clusters.append(str(summary.loc[row, f"Cluster_{col}"]))
elif summary.loc[row, col] != dom_entity_set[col]:
if col == "HasFieldmap":
variant_params.append("NoFmap" if dom_entity_set[col] == "True" else "HasFmap")
elif col == "UsedAsFieldmap":
variant_params.append("Unused" if dom_entity_set[col] == "True" else "IsUsed")
else:
variant_params.append(col)
variant_clusters.append("")

# Sort params to ensure consistent ordering
variant_params.sort()
sorted_pairs = sorted(zip(variant_params, variant_clusters))
variant_params = [p for p, _ in sorted_pairs]
variant_clusters = [c for _, c in sorted_pairs]

# Create variant string
if variant_params:
variant_key = "".join(variant_params)
if variant_key not in variant_counts:
variant_counts[variant_key] = 1
else:
variant_counts[variant_key] += 1

acq_str = f"VARIANT{''.join(variant_params)}{variant_counts[variant_key]}"
variant_str = "VARIANT"
for param, cluster in zip(variant_params, variant_clusters):
variant_str += param
if cluster:
variant_str += cluster
acq_str = variant_str
else:
if "Other" not in variant_counts:
variant_counts["Other"] = 1
else:
variant_counts["Other"] += 1
acq_str = f"VARIANTOther{variant_counts['Other']}"
acq_str = "VARIANTOther"

if "acquisition" in entities.keys():
acq = f"acquisition-{entities['acquisition'] + acq_str}"
Expand Down

0 comments on commit febfe4b

Please sign in to comment.