Skip to content

Commit

Permalink
style: Refine docstrings and comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
T0217 committed Oct 1, 2024
1 parent e1f4a89 commit 80e4d9a
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions src/sdqc_check/causality/causal_analysis.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
import warnings
from typing import Tuple

from abc import ABC
import numpy as np
import pandas as pd

import castle
from castle.algorithms import (
DirectLiNGAM,
GAE,
GOLEM,
GraNDAG,
Notears
DirectLiNGAM, GAE, GOLEM, GraNDAG, Notears
)
from castle.metrics import MetricsDAG

# Ignore warnings
warnings.filterwarnings('ignore')


class CausalAnalysis:
class CausalAnalysis(ABC):
"""
Causal analysis using various causal discovery algorithms.
Expand All @@ -37,6 +33,23 @@ class CausalAnalysis:
The type of device to use (default is 'cpu').
device_id : int, optional
The ID of the device to use (default is 0).
Attributes
----------
raw_data : pd.DataFrame
The input raw data for causal analysis.
synthetic_data : pd.DataFrame
The input synthetic data for causal analysis.
model_name : str
The name of the causal discovery model to use.
random_seed : int
The random seed for reproducibility.
device_type : str
The type of device to use.
device_id : int
The ID of the device to use.
modelList : list
The list of available causal discovery models.
"""

def __init__(
Expand All @@ -63,6 +76,9 @@ def __init__(
)

def compare_adjacency_matrices(self) -> None:
"""
Compare the adjacency matrices of raw and synthetic data using the specified causal discovery method.
"""
raw_matrix, synthetic_matrix = self.compute_causal_matrices()
mt = MetricsDAG(raw_matrix, synthetic_matrix)
return mt
Expand Down

0 comments on commit 80e4d9a

Please sign in to comment.