Source code for causalexplain.gui.tabs.generate

"""Dataset generation tab for the causalexplain GUI."""

from __future__ import annotations

import asyncio
import os
import time
from typing import Any, Dict, Optional

import numpy as np

from causalexplain.common import DEFAULT_SEED
from causalexplain.common.plot import dag2dot
from causalexplain.generators.generators import AcyclicGraphGenerator
from causalexplain.gui.graph_utils import dag_is_valid
from causalexplain.gui.io_utils import sanitize_output_name
from causalexplain.gui.rendering import render_cytoscape_graph
from causalexplain.gui.ui_helpers import bind_setting


[docs] class GenerateTab: """Build and manage the Generate Dataset tab."""
[docs] def __init__( self, ui: Any, run: Any, storage: Any, settings: Dict[str, Any], notify: Any, ) -> None: """Initialize the generate tab with shared GUI dependencies.""" self.ui = ui self.run = run self.storage = storage self.settings = settings self.notify = notify self.generate_state: Dict[str, Any] = { "graph": None, "data": None, "adjacency": None, } self.generate_dag_container: Optional[Any] = None self.save_button: Optional[Any] = None self.output_dir_input: Optional[Any] = None self.output_name_input: Optional[Any] = None
[docs] def build(self) -> None: """Render the Generate Dataset tab UI.""" with self.ui.element("div").classes("section-card w-full"): self.ui.label("Generate Dataset").classes("section-title") with self.ui.element("div").classes("generate-grid w-full"): with self.ui.element("div").classes("nested-panel"): self.ui.label("Generation controls").classes("subtle") timeout_input = self.ui.number( "t (seconds)", value=self.settings.get("timeout_s", 30), ).props("dense").classes("w-full") retries_input = self.ui.number( "R (max retries)", value=self.settings.get("max_retries", 50), ).props("dense").classes("w-full") min_edges_input = self.ui.number( "Min edges", value=self.settings.get("min_edges", 0), ).props("dense").classes("w-full") max_edges_input = self.ui.number( "Max edges", value=self.settings.get("max_edges", 30), ).props("dense").classes("w-full") with self.ui.element("div").classes("nested-panel"): self.ui.label("Dataset parameters").classes("subtle") mechanism_select = self.ui.select( [ "linear", "polynomial", "sigmoid_add", "sigmoid_mix", "gp_add", "gp_mix", ], value=self.settings.get("mechanism", "linear"), label="Mechanism", ).classes("w-full") nodes_input = self.ui.number( "Variables", value=self.settings.get("nodes", 10), ).props("dense").classes("w-full") samples_input = self.ui.number( "Samples", value=self.settings.get("samples", 500), ).props("dense").classes("w-full") parents_input = self.ui.number( "Max parents", value=self.settings.get("max_parents", 3), ).props("dense").classes("w-full") gen_seed_input = self.ui.number( "Seed", value=self.settings.get("seed", DEFAULT_SEED), ).props("dense").classes("w-full") rescale_switch = self.ui.switch( "Rescale", value=self.settings.get("rescale", True), ) with self.ui.element("div").classes("generate-actions w-full"): self.ui.element("div").classes("generate-spacer") self.ui.button( "Generate", on_click=lambda: asyncio.create_task(self.run_generate()), ).classes("w-full") self.generate_dag_container = self.ui.element("div").classes( "dag-frame w-full" ) render_cytoscape_graph(self.generate_dag_container, None) with self.ui.element("div").classes("save-row w-full"): self.output_dir_input = self.ui.input( "Output directory", value=self.settings.get("output_dir", ""), ).props("dense").classes("w-full") self.output_name_input = self.ui.input( "Dataset name", value=self.settings.get("output_name", "generated_dataset"), ).props("dense").classes("w-full") self.save_button = self.ui.button( "SAVE", on_click=lambda: asyncio.create_task( self.run_save_generated() ), ).classes("w-full") self.save_button.disable() bind_setting( mechanism_select, self.storage, "generate_settings", self.settings, "mechanism", ) bind_setting( timeout_input, self.storage, "generate_settings", self.settings, "timeout_s", ) bind_setting( retries_input, self.storage, "generate_settings", self.settings, "max_retries", ) bind_setting( min_edges_input, self.storage, "generate_settings", self.settings, "min_edges", ) bind_setting( max_edges_input, self.storage, "generate_settings", self.settings, "max_edges", ) bind_setting( nodes_input, self.storage, "generate_settings", self.settings, "nodes", ) bind_setting( samples_input, self.storage, "generate_settings", self.settings, "samples", ) bind_setting( parents_input, self.storage, "generate_settings", self.settings, "max_parents", ) bind_setting( gen_seed_input, self.storage, "generate_settings", self.settings, "seed", ) bind_setting( rescale_switch, self.storage, "generate_settings", self.settings, "rescale", ) bind_setting( self.output_dir_input, self.storage, "generate_settings", self.settings, "output_dir", ) bind_setting( self.output_name_input, self.storage, "generate_settings", self.settings, "output_name", )
def _generate_job(self, settings: Dict[str, Any]) -> Dict[str, Any]: """Generate a synthetic dataset in a worker thread.""" np.random.seed(int(settings.get("seed", DEFAULT_SEED))) timeout_s = max(0.0, float(settings.get("timeout_s", 30))) max_retries = int(settings.get("max_retries", 50)) min_edges = int(settings.get("min_edges", 0)) max_edges = int(settings.get("max_edges", 30)) if min_edges > max_edges: raise ValueError("min_edges must be less than or equal to max_edges.") if max_retries < 0: max_retries = 0 max_attempts = max_retries + 1 start_time = time.monotonic() attempt = 0 while attempt < max_attempts: if (time.monotonic() - start_time) >= timeout_s: break attempt += 1 generator = AcyclicGraphGenerator( settings["mechanism"], points=int(settings["samples"]), nodes=int(settings["nodes"]), parents_max=int(settings["max_parents"]), verbose=False, ) graph, data = generator.generate(rescale=bool(settings["rescale"])) if not dag_is_valid(graph, min_edges, max_edges): continue return { "graph": graph, "data": data, "adjacency": generator.adjacency_matrix, } elapsed = time.monotonic() - start_time if elapsed >= timeout_s: raise TimeoutError("Timeout reached before a valid DAG was found.") raise ValueError("No valid DAG found within the retry limit.")
[docs] async def run_generate(self) -> None: """Generate a dataset and update UI widgets.""" try: result = await self.run.io_bound(self._generate_job, self.settings) except Exception as exc: self.notify(f"Error: {str(exc)}", "negative") if self.save_button is not None and self.generate_state.get("graph") is None: self.save_button.disable() return self.generate_state.update( { "graph": result.get("graph"), "data": result.get("data"), "adjacency": result.get("adjacency"), } ) render_cytoscape_graph( self.generate_dag_container, self.generate_state["graph"] ) if self.save_button is not None: self.save_button.enable() self.notify("Generation completed.", "positive")
def _save_job( self, payload: Dict[str, Any], output_dir: str, output_name: str ) -> None: """Persist the generated dataset to disk.""" graph = payload.get("graph") data = payload.get("data") adjacency = payload.get("adjacency") if graph is None or data is None or adjacency is None: raise ValueError("Generate a dataset first.") output_dir = output_dir.strip() output_name = sanitize_output_name(output_name) if not output_dir: raise ValueError("Output directory is required.") if not output_name: raise ValueError("Dataset name is required.") os.makedirs(output_dir, exist_ok=True) output_base = os.path.join(output_dir, output_name) data.to_csv(f"{output_base}.csv", index=False) dot_obj = dag2dot(graph) if dot_obj is None: raise ValueError("Unable to save a DAG with no edges.") graph_dot_format = dot_obj.to_string() graph_dot_format = f"strict {graph_dot_format[:-9]}\n}}" with open(f"{output_base}.dot", "w") as handle: handle.write(graph_dot_format)
[docs] async def run_save_generated(self) -> None: """Save the generated dataset to disk.""" output_dir = self.settings.get("output_dir", "") output_name = self.settings.get("output_name", "") try: await self.run.io_bound( self._save_job, self.generate_state, output_dir, output_name, ) except Exception as exc: self.notify(f"Error: {str(exc)}", "negative") return self.notify("Dataset saved.", "positive")