Source code for endogen.endogen

from .config import InputModel, ExogenModel, Differences, Lags, Rolling, Transform
from .variables import (
    Variable,
    VariableLag,
    VariableRolling,
    VariableDifference,
    VariableTransform,
)
from .tools import measure, flatten, flatten_recursive
from .data_utilities import (
    read_input_data,
    drop_superfluous,
    drop_missing_units,
    generate_comparison_report,
)

from .utilities import PanelUnits
from .adapter_mlforecast import forecast_mlforecast

import xarray
import pandas as pd
import numpy as np
from dataclasses import dataclass, field
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import collections

from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError

from mlforecast.forecast import MLForecast
from mlforecast.utils import PredictionIntervals

import os

from formulae import design_matrices
from typing import Optional, Sequence, Iterable, Any, Mapping, Tuple, Literal

import logging

log = logging.getLogger(__name__)


[docs] @dataclass(frozen=True) class ModelSchedule: delta_t: int schedule: Iterable[str | Iterable[str]]
class ModelController:
    """A controller for organizing and scheduling models.

    Holds the loaded models and a ``networkx.DiGraph`` compiled from them.
    The graph is rebuilt from *all* loaded models every time models are
    added, so repeated calls to :meth:`add_models` are supported.
    """

    def __init__(self):
        self._models = []
        self._graph = nx.DiGraph()
        # Populated by _models_to_graph; empty until models are added so that
        # _model_schedule and consumers of these lists work on an empty system.
        self.prepare_nodes = []
        self.derived_nodes = []
        self.forecast_nodes = []

    def add_models(
        self, models: InputModel | ExogenModel | Sequence[InputModel | ExogenModel]
    ) -> None:
        """Add one or more models to the system and rebuild the model graph.

        Parameters
        ----------
        models : InputModel, ExogenModel, or a sequence of them
            Model(s) to add. A model whose ``output_var`` already has a model,
            either previously loaded or earlier in the same call, is skipped
            with a warning.

        Raises
        ------
        ValueError
            If the resulting system fails validation in ``_models_to_graph``.
        """
        if not isinstance(models, Sequence):
            models = [models]

        known_vars = {m.output_var for m in self._models}
        new_models = []
        models_already_exists = []
        for model in models:
            if model.output_var in known_vars:
                models_already_exists.append(model.output_var)
            else:
                # Track within-call duplicates as well as previously loaded ones.
                known_vars.add(model.output_var)
                new_models.append(model)

        if models_already_exists:
            log.warning(
                f"Models for variables: {models_already_exists} are already loaded in the system. Please remove before adding."
            )
        self._models = [*self._models, *new_models]
        self._models_to_graph()

    @property
    def models(self) -> Sequence[InputModel | ExogenModel]:
        """The models loaded into the controller, in insertion order."""
        return self._models

    def plot(self, path: Optional[str] = None) -> None:
        """Draw the model graph using a multipartite layout.

        Parameters
        ----------
        path : str, optional
            If given, the figure is saved to this path; otherwise it is shown
            interactively. The figure is closed in both cases.
        """
        fig = plt.figure()
        nx.draw(self._graph, with_labels=True, pos=nx.multipartite_layout(self._graph))
        if path is not None:
            plt.savefig(path)
        else:
            plt.show()
        plt.close(fig)

    def _models_to_graph(self) -> None:
        """Rebuild the model graph from every loaded model and validate it.

        The graph and node lists are only replaced once all checks pass.

        Raises
        ------
        ValueError
            If the variables do not span exactly subsets 0 and 1, if the edges
            within one subset form a cycle, or if a graph node has no model.
        networkx.NetworkXNoCycle
            Raised by ``nx.find_cycle`` when no cycle is found searching from
            some node of the complete graph.
        """
        variable_types: Sequence[str] = ["lags", "differences", "rolling", "transforms"]
        input_models = [m for m in self.models if isinstance(m, InputModel)]
        exogen_models = [m for m in self.models if isinstance(m, ExogenModel)]

        derived_variables = []
        for var_type in variable_types:
            for model in input_models:
                recipes = getattr(model, var_type)
                if len(recipes) > 0:
                    derived_variables.append([recipe.get_variables() for recipe in recipes])
        variables: Sequence[Variable] = flatten_recursive(
            [derived_variables, input_models, exogen_models]
        )

        if {var.subset for var in variables} != {0, 1}:
            raise ValueError(
                'Models must contain subsets 0 ("before forecast") and 1 ("after forecast"), and not any others.'
            )

        # There cannot be cycles before the forecast or during the forecast
        # (each subset must be a DAG on its own). Each subset is checked in a
        # fresh graph: the complete graph is cyclic by design (see the final
        # check below), so testing against a previously built graph would
        # reject every call to add_models after the first one.
        for subset in (0, 1):
            subset_edges = flatten_recursive(
                [
                    var.edges
                    for var in variables
                    if var.subset == subset and not isinstance(var, ExogenModel)
                ]
            )
            subset_graph = nx.DiGraph()
            subset_graph.add_edges_from(subset_edges)
            if not nx.is_directed_acyclic_graph(subset_graph):
                cycle = nx.find_cycle(subset_graph)
                # Explicit raise (not assert) so the check survives `python -O`.
                raise ValueError(
                    f"Edges in subset {subset} would introduce a cyclic graph: {cycle}."
                )

        graph = nx.DiGraph()
        graph.add_edges_from(
            flatten_recursive([var.edges for var in variables if not isinstance(var, ExogenModel)])
        )
        graph.add_nodes_from([var.node for var in variables])

        # There cannot be any variables without a model (nodes with no data).
        empty_nodes: Sequence[str] = [n for n, data in graph.nodes(data=True) if not data]
        if empty_nodes:
            raise ValueError(f"A model is missing for these variables: {empty_nodes}")

        # The complete graph should be completely cyclic (all nodes should have a cycle).
        # nx.find_cycle raises NetworkXNoCycle when it finds none from `node`.
        for node in list(graph.nodes):
            nx.find_cycle(graph, node)

        self._graph = graph
        self.prepare_nodes = [
            var.node[0] for var in variables if not isinstance(var, InputModel)
        ]
        self.derived_nodes = [var.node[0] for var in variables if var.subset == 0]
        self.forecast_nodes = [var.node[0] for var in variables if var.subset == 1]

    @property
    def _model_schedule(self) -> Tuple[ModelSchedule, ModelSchedule]:
        """Return the (pre-forecast, forecast) schedules in topological order."""
        t0: ModelSchedule = ModelSchedule(
            delta_t=0,
            schedule=list(
                nx.topological_generations(self._graph.subgraph(self.derived_nodes)),
            ),
        )
        t1: ModelSchedule = ModelSchedule(
            delta_t=1,
            schedule=list(
                nx.topological_generations(self._graph.subgraph(self.forecast_nodes)),
            ),
        )
        return t0, t1
[docs] @dataclass class EndogenousSystem: """An endogenous panel-data system of models/nodes with associated methods for correct scheduling of model forecasts. Parameters ---------- input_data : str or pandas.DataFrame Panel data (or path to data) that includes all variables required by the forecasting system (and possibly fitting of models). time_var : str The variable name indicating time in input_data. unit_var : str The variable name indicating units in input_data. nsim : int The number of independent simulations of the endogenous system. start: int The number on the same scale as time_var when forecasting should start. end : int The number on the same scale as time_var when forecasting should end. vars : Optional[Sequence[str]] A subset of variables in input_data. Defaults to all variables in input_data. include_past_n : Optional[int] How much of the past to include when fitting statistical models. """ input_data: str | os.PathLike | pd.DataFrame time_var: str unit_var: str nsim: int end: int vars: Optional[Sequence[str]] = field(default_factory=list) start: Optional[int] = None include_past_n: Optional[int] = None def __repr__(self): return f"EndogenousSystem({self._xa})" def __post_init__(self): self.pnames = PanelUnits(self.time_var, self.unit_var) # Read input-data self.input_data = read_input_data(self.input_data) # Use variables in the input-data unless specified if len(self.vars) == 0: self.vars = [ var for var in self.input_data.columns if var not in [self.time_var, self.unit_var] ] else: self.vars = [ var for var in self.vars if var not in [self.time_var, self.unit_var] ] if self.start == None: self.start = self.input_data[self.pnames.time_var].max() + 1 self._last_train = self.start - 1 data_to_xarray = self.input_data[ self.input_data[self.pnames.time_var] < self.start ] report = generate_comparison_report( data_to_xarray, time_var=self.pnames.time_var, unit_var=self.pnames.unit_var, alternative_time_comparison=self._last_train, ) self.missing_units 
= set().union( *report.loc[slice(self._last_train - self.include_past_n, self._last_train)][ "missing" ].tolist() ) if len(self.missing_units) > 0: log.warning( f"The following units were removed to attain a balanced dataset over {self.include_past_n} years: {self.missing_units}." ) data_to_xarray = data_to_xarray[ ~data_to_xarray[self.pnames.unit_var].isin(self.missing_units) ] data_to_xarray = drop_superfluous( data_to_xarray, time_var=self.pnames.time_var, unit_var=self.pnames.unit_var, alternative_time_comparison=self._last_train, ) data_to_xarray = drop_missing_units( data_to_xarray, time_var=self.pnames.time_var, unit_var=self.pnames.unit_var, alternative_time_comparison=self._last_train, ) data_to_xarray = data_to_xarray.rename(columns=self.pnames.to_dict()).set_index( self.pnames.internal_index ) self._past = data_to_xarray[self.vars].dropna().to_xarray() self._past = self._past.sel( ds=slice(self._last_train - self.include_past_n, self._last_train) ) del data_to_xarray # Initialize the model-controller self.models = ModelController() @classmethod def _make_container( cls, vars: Sequence[str], nsim: int, unit_index: pd.Index, time_index: pd.Index, ones: bool = False, ): nvar = len(vars) nunit = len(unit_index) ntime = len(time_index) if ones: arr = np.ones(shape=(nvar, ntime, nunit, nsim), dtype=np.float32) else: arr = np.zeros(shape=(nvar, ntime, nunit, nsim), dtype=np.float32) return xarray.DataArray( data=arr, dims=["vars", "ds", "unique_id", "sim"], coords={ "vars": vars, "ds": time_index, "unique_id": unit_index, }, ).to_dataset(dim="vars") def create_forecast_container(self): # To update self._past and self.vars with any transformations in self.models.models self.prepare_data() time_index = pd.Index( range(self._last_train - self.include_past_n, self.end), name="ds" ) unit_index, _ = self._past.indexes.values() self._xa = self._make_container( vars=self.vars, nsim=self.nsim, unit_index=unit_index, time_index=time_index, ) if 
isinstance(self.include_past_n, int) and self.include_past_n > 0: data_to_include = [] single_slice = self._past.sel(ds=slice(time_index.start, time_index.stop)) for _ in itertools.repeat(None, self.nsim): data_to_include.append(single_slice) data_to_include = xarray.concat(data_to_include, dim="sim").transpose( "ds", "unique_id", "sim" ) self.update_sim(data_to_include) def prepare_data(self): t0, t1 = self.models._model_schedule for node_schedule in t0.schedule: for node in node_schedule: if isinstance(node, str) and node in self.models.prepare_nodes: if "model" in self.models._graph.nodes[node]: self._past = xarray.merge( [ self._past, self.models._graph.nodes[node]["model"].calc(xd=self._past), ] ) # for node_schedule in t1.schedule: # for node in node_schedule: # if isinstance(node, str) and node in self.models.prepare_nodes: # self._past = xarray.merge( # [ # self._past, # self.models._graph.nodes[node]["model"].calc(xd=self._past), # ] # ) self._past = self._past.to_dataframe().dropna().to_xarray() self.vars = list(self._past.keys()) def fit_models(self): for model in self.models.models: if isinstance(model, ExogenModel): # Fitting the model is equvivalent to writing the data into all simulations df = read_input_data(model.exogen_data) df = df.rename(columns=self.pnames.to_dict()).set_index(self.pnames.internal_index)[model.output_var].to_xarray() df = df.sel(ds=self._xa.ds.values) self._xa[model.output_var][np.searchsorted(self._xa.ds.values, df.ds.values)] = df else: if isinstance(model.model, str): pass if isinstance(model.model, BaseEstimator): df = self._past.to_dataframe() y, X = df[model.output_var], df[model.input_vars] model.model.fit(X, y) if isinstance(model.model, MLForecast): data_variables = list( itertools.chain([model.output_var], model.input_vars) ) df = self._past.to_dataframe()[data_variables] df = df.rename(columns={model.output_var: "y"}) df.reset_index(inplace=True) model.model.fit( df, static_features=[], 
prediction_intervals=PredictionIntervals(n_windows=4, h=1), ) def simulate(self): levels = [5, 15, 25, 35, 45, 55, 65, 75, 85, 95] t0, t1 = self.models._model_schedule for t in range(self.start, self.end): ds_index = (self._xa.ds.values == t).nonzero()[0][0] for schedules in [t0, t1]: for node_schedule in schedules.schedule: for node in node_schedule: if any([isinstance(m, ExogenModel) for m in self.models.models if m.output_var == node]): continue # If model is Exogen, data is already in simulation from fitting. model = self.models._graph.nodes[node]["model"] input_vars = [ m.input_vars for m in self.models.models if m.output_var == node ] if len(input_vars) == 1: input_vars = input_vars[0] else: del input_vars match model: case VariableTransform(): self._xa[node][ds_index] = ( model.calc(xd=self._xa.sel(ds=t)) )[node] case VariableDifference() | VariableLag() | VariableRolling(): self._xa[node] = model.calc(xd=self._xa) case BaseEstimator(): for s in range(self.nsim): self._xa[node][ds_index, :, s] = model.predict( self._xa[input_vars].to_dataframe().loc[t, :, s] ) case MLForecast(): for s in range(self.nsim): self._xa[node][ ds_index, :, s ] = forecast_mlforecast( t, s, model, self._xa, self.pnames, node, input_vars, levels, ) case str(): for s in range(self.nsim): df = self._xa[input_vars].to_dataframe().loc[t,:,s] ind = df.index res = design_matrices(f'0 + {model}', df, na_action = "pass").common.as_dataframe() varname = res.columns[0] self._xa[node][ ds_index, :, s ] = res.rename(columns={varname: node}).set_index(ind)[node].to_xarray() case _: raise NotImplementedError( f"Model of type {type(model)} is not implemented." 
) def update_sim(self, value: xarray.Dataset | xarray.DataArray): if isinstance(value, xarray.Dataset): vars = list(value.keys()) t_index = [ i for i, k in enumerate(self._xa.ds.values) if k in value.ds.values ] for var in vars: self._xa[var][t_index] = value[var] elif isinstance(value, xarray.DataArray): t_index = [ i for i, k in enumerate(self._xa.ds.values) if k in value.ds.values ] self._xa[value.name][t_index] = value else: raise ValueError( "Value must be either a xarray.Dataset or an xarray.DataArray object" )
[docs] def plot( self, var: str, unit: Optional[Sequence[int]] = None, start: Optional[int] = None, *args, **kwargs, ): """Plot method for historical and forecasted data. Parameters ---------- var : str Name of the variable you want to plot unit : Optional[Sequence[int]] List of subset of units you want to plot in facets. Plot will otherwise show global statistics. start : Optional[int] Alternative start time. Plot will otherwise start as early as possible with the data given. args : Other arguments to pass to seaborn.relplot kwargs : key, value pairings Dictionary of keyword arguments to pass to seaborn.relplot Returns ------- seaborn.FacetGrid An object managing one or more subplots that correspond to conditional data subsets with convenient methods for batch-setting of axes attributes. """ if start == None: start = self._past.coords["ds"].min() if unit == None: units = self._xa.coords["unique_id"] else: units = unit forecast = self._xa.sel(unique_id=units, ds=slice(self.start + 1, self.end))[ var ] past = self._past.sel(unique_id=units, ds=slice(start, self.start + 1))[var] forecast = forecast.to_dataframe() past = past.to_dataframe() forecast["type"] = "forecast" past["type"] = "historical" past["sim"] = 0 past = past.reset_index().set_index(["ds", "unique_id", "sim"]) forecast = forecast.reset_index().set_index(["ds", "unique_id", "sim"]) if unit == None: plot_object = sns.relplot( pd.concat([past, forecast]), x="ds", y=var, hue="type", errorbar="sd", kind="line", *args, **kwargs, ) else: plot_object = sns.relplot( pd.concat([past, forecast]), x="ds", y=var, col="unique_id", hue="type", errorbar="sd", kind="line", *args, **kwargs, ) for ax in plot_object.axes.flat: ax.axvline(x=self.start + 0.5, color="black", linestyle="--") return plot_object