Source code for causalexplain.models._columnar
import pandas as pd
import torch
from torch.utils.data import Dataset
[docs]
class ColumnsDataset(Dataset):
[docs]
def __init__(self, target_name, df: pd.DataFrame):
target = df.loc[:, target_name].values.reshape(-1, 1)
features = df.drop(target_name, axis=1).values
self.features = torch.tensor(features, dtype=torch.float32)
self.target = torch.tensor(target, dtype=torch.float32)
def __len__(self):
return len(self.target)
def __getitem__(self, idx):
return [self.features[idx], self.target[idx]]