1 Introduction

Learning the dynamics of a system is crucial in scientific and engineering domains. Traditionally, dynamics have been expressed in symbolic forms such as equations, programs, and logic, which provide explicit interpretability and generality [1,2,3]. However, the growing volume of data, together with noise, ambiguity, and the complexity of the dynamics themselves, makes it challenging to learn and extract rules from large-scale data [4,5,6,7,8].

To tackle this problem, several studies have combined symbolic methods with deep learning to learn dynamics from large datasets [9, 10]. While some dynamics, such as those of classical mechanics and electromagnetism, can be effectively expressed by quantitative relationships, others are better represented by state transition rules, such as Boolean networks (BNs) [11] and Cellular Automata (CA) [12]. Symbolic methods have shown success in learning these types of dynamics [13,14,15].

Among the proposed methods, the restricted Boltzmann Machine (RBM) has been used to learn latent factors that capture the essence of the dynamics [16, 17]. An RBM is a probabilistic neural network that treats observable data as visible variables and latent features as hidden variables. It can express symbolic representations, such as propositional formulas, learned through maximum likelihood estimation of the RBM energy function [18,19,20]. However, existing RBM-based approaches have not extensively explored learning hidden representations from raw data such as images. Furthermore, prediction of the dynamics has been performed as a black-box process within the network, which limits interpretability.

To address these limitations and challenges, this work proposes a method that extracts a small number of essential factors as hidden variables and expresses their relationships as interpretable rules. Specifically, we introduce the recurrent temporal Gaussian-Bernoulli restricted Boltzmann Machine (RTGB-RBM), which combines the Gaussian-Bernoulli restricted Boltzmann Machine (GB-RBM) [21], to handle continuous visible variables, with the recurrent temporal restricted Boltzmann Machine (RT-RBM) [22, 23], to capture time dependence between discrete latent variables. RTGB-RBM learns the dynamics by representing them as state transitions between hidden layers and reconstructing the original dynamics from these hidden-layer transitions. Additionally, we extract interpretable state transition rules from the trained RTGB-RBM, providing insight into the learned dynamics.

The contributions of this work include the proposed RTGB-RBM, which captures the time dependence of continuous visible and discrete hidden variables, and the method for extracting interpretable state transition rules between hidden variables. Furthermore, our method achieves predictive performance comparable to existing methods for various dynamics and enables interpretable rule learning.

This paper is an extended version of our conference paper [24], in which we have expanded the rule extraction and pruning methods, conducted additional experimental analyses using new datasets, and updated the discussion of related work. The remainder of this paper is organized as follows: Sect. 2 presents related work on Boltzmann Machines, dynamics learning, and rule learning and compares it with our approach. Section 3 provides an overview of GB-RBM and RT-RBM, including their definitions and algorithms. In Sect. 4, we describe our proposed method, RTGB-RBM, along with its learning technique, and explain the process of extracting rules between hidden layers from the trained model. Section 5 evaluates the proposed method using the Bouncing Ball, Moving MNIST, and dSprites datasets. We compare the predictions of RTGB-RBM with those obtained from existing methods and demonstrate the superior performance of our proposed method. We also introduce a method for reducing the model size without compromising prediction performance by pruning unimportant nodes based on rule importance. Finally, Sect. 6 concludes the paper by summarizing the main contributions, discussing the limitations of our approach, and outlining potential avenues for future work.

2 Related Work

This section discusses related work on Boltzmann Machines, dynamics learning, and rule learning.

2.1 Boltzmann Machines

Boltzmann Machine (BM) is a probabilistic neural network that can be viewed as a stochastic variant of the Hopfield network [25]. It has appealing theoretical properties owing to its Hebbian learning rule and its connections to physics. However, training the original BM takes exponential time, making it impractical for large-scale real-world datasets. To address this, the restricted Boltzmann Machine (RBM) [16] was proposed, which prohibits connections between units within the same layer. An efficient training method called Contrastive Divergence (CD) [17] was also developed specifically for the RBM. An RBM consists of a visible layer that handles observed data and a hidden layer that handles latent features.

Several extensions of the RBM have been proposed, including the Gaussian-Bernoulli restricted Boltzmann Machine (GB-RBM) [21], which handles real-valued visible variables and binary hidden variables, and the recurrent temporal restricted Boltzmann Machine (RT-RBM) [22, 23], designed for time-series data. Stacking multiple RBM layers makes the learning of hidden units more efficient and led to the Deep Boltzmann Machine (DBM) [26], which serves as a foundation for modern deep learning approaches. Furthermore, through persistent efforts, RBMs have achieved expressiveness comparable to modern generative models such as the Variational Autoencoder (VAE) and the Generative Adversarial Network (GAN) [27,28,29]. In symbolic computation, such as Knowledge Representation and Reasoning (KRR), RBMs have been used to learn symbolic representations. For instance, [18, 19] express propositional formulas using visible and hidden variables, inferring symbolic knowledge through maximum likelihood estimation of the RBM's energy function. The Logical Boltzmann Machine (LBM) [20] converts propositional formulas in DNF into RBMs to achieve efficient reasoning, and establishes a connection between minimizing the energy of an RBM and Boolean formula satisfiability. Notably, [18] applies these RBM-based symbolic knowledge extraction methods to images, bridging the gap between symbolic representations and real-world data. However, these approaches have not yet been applied to time-series data such as videos, and many aspects of their learning and prediction remain black boxes.

2.2 Dynamics Learning

Many researchers have proposed various methods for learning dynamics, with recent years witnessing a surge in techniques based on deep learning. For instance, [30, 31] employ neural networks to learn visual information from dynamics and make video predictions. Likewise, [32] utilizes audio data to reconstruct original audio and perform speaker recognition. Many of these methods are based on Variational Auto Encoder (VAE) [27], a generative model similar to Boltzmann Machines. These approaches extract hidden representations from inputs and leverage them to reconstruct dynamics, proving highly effective in learning dynamics. However, these methods often operate as black boxes, learning input–output relationships without providing insights into the internal workings of the networks. As a result, understanding important factors for predicting and reconstructing dynamics can be challenging. Several VAE-based methods have been proposed to address this, aiming to disentangle hidden variable dimensions and maximize the separation of independent information within the original data [33]. Additionally, [34] disentangles a latent representation into static and dynamic parts, while [35] treats the latent representation as a categorical condition to control the output. Disentanglement allows for a better grasp of the meaning behind each latent variable. Nevertheless, explicitly expressing these relationships as equations or rules remains challenging. Our method offers a unique contribution by acquiring interpretable rules that capture the dynamic nature and state transitions of latent information. This gives us deeper insights into the internal mechanisms of the network.

2.3 Rule Learning

While our main objective is to propose methods for learning dynamics, it is equally important to represent dynamics in an interpretable form. Symbolic methods, such as Learning from interpretation transition (LFIT) [15], have successfully learned dynamics in Boolean networks (BNs) [11] and Cellular Automata (CA) [12]. LFIT is an unsupervised learning algorithm that derives rules expressing dynamic relationships between observable variables from state transitions. However, LFIT cannot handle dynamical systems that involve unobservable hidden or latent variables; previous studies, including LFIT, have focused on describing dynamics in terms of observed variables. Separately, some methods have combined LFIT with neural networks (NNs) to improve robustness to noisy data and continuous variables. For example, NN-LFIT [36] extracts propositional logic rules from trained NNs, while D-LFIT [37] translates logic programs into embeddings, infers logical values through differentiable semantics, and uses optimization methods and NNs to search for embeddings. These frameworks learn state transition rules as normal logic programs (NLPs). For scalability, the relationship between symbolic representation and computational complexity must also be considered [38].

Dealing with a large number of discrete variables, such as symbols, leads to an exponential increase in the number of possible combinations, posing computational complexity challenges. Therefore, it is essential to use computationally efficient algorithms, omit unnecessary inputs and redundant information, retain necessary information, and describe their relationships. Our proposed method addresses these challenges by simultaneously reducing input dimensionality, extracting essential information, and expressing relationships among the extracted information as rules. Despite having a large input dimension, our method mitigates combinatorial problems due to the reduced number of dimensions in the hidden layer.

3 Preliminaries

3.1 Gaussian-Bernoulli Restricted Boltzmann Machine

The Gaussian-Bernoulli restricted Boltzmann Machine (GB-RBM) [21] is defined on a complete bipartite graph, as shown in Fig. 1. The upper layer is the visible layer, consisting only of visible variables, and the lower layer is the hidden layer, consisting only of hidden variables; V and H denote the index sets of the visible and hidden variables, respectively.

Fig. 1 Graphical representation of GB-RBM

\({\textbf{v}} = \{v_i \in {\mathbb {R}} \ \mid \ i \in V \}\) are real-valued variables directly associated with the input–output data, and \({\textbf{h}} = \{h_j \in \{+1, -1\} \ \mid \ j \in H \}\) are binary hidden variables of the system that are not directly associated with the input–output data. \({\textbf{s}} = \{s_i \ \mid \ i \in V\}\) are parameters controlling the variance of the visible variables. The energy function of the GB-RBM is defined as

$$\begin{aligned} E_{\theta }({\textbf{v}},{\textbf{h}})= & {} \sum _{\ i \in V} \frac{(v_i - b_i)^2}{2 s_i^2} - \sum _{\ i \in V} \sum _{\ j \in H} \frac{w_{ij}}{s_i^2}v_i h_j \nonumber \\{} & {} - \sum _{\ j \in H} c_j h_j \end{aligned}$$
(1)

Here, \({\textbf{b}} = \{b_i \ \mid \ i \in V \}\) and \({\textbf{c}} = \{c_j \ \mid \ j \in H\}\) are the bias parameters for the visible and hidden variables, respectively. \({\textbf{w}} = \{w_{ij} \ \mid \ i \in V, j \in H\}\) is the set of weights connecting the visible and hidden variables, and \(s_i^2\) is the variance of \(v_i\). These model parameters are collectively denoted by \(\theta = \{{\textbf{W}},{\textbf{b}}, {\textbf{c}} \}\). Using the energy function in (1), the joint probability distribution of \({\textbf{v}},{\textbf{h}}\) is defined as

$$\begin{aligned} \begin{aligned} P_{\theta }({\textbf{v}},{\textbf{h}})&= \frac{1}{Z(\theta )} \exp \big (-E_\theta ({\textbf{v}},{\textbf{h}}) \big ) \\ Z(\theta )&= \int _{-\infty }^{+\infty } \sum _{{\textbf{h}}} \exp \big (-E_\theta ({\textbf{v}},{\textbf{h}}) \big ) d{\textbf{v}} \end{aligned} \end{aligned}$$
(2)

\(Z(\theta )\) is the partition function, \(\int _{-\infty }^{+\infty }...d{\textbf{v}}\) denotes the multiple integral over \({\textbf{v}}\), and \(\sum _{{\textbf{h}}}\) denotes the sum over all possible configurations of \({\textbf{h}}\). The conditional distributions of \({\textbf{v}}\) given \({\textbf{h}}\) and of \({\textbf{h}}\) given \({\textbf{v}}\) are, respectively,

$$\begin{aligned} P_{\theta }(v_i \ \mid \ {\textbf{h}})= & {} {\mathcal {N}} \left( v_{i} \ \Bigm \vert \ b_i + \sum _{j \in H} w_{ij} h_{j}, s_i^2 \right) \end{aligned}$$
(3)
$$\begin{aligned} P_{\theta }(h_j = 1 \ \mid \ {\textbf{v}})= & {} \frac{\exp \left( c_j +\sum _{i \in V} w_{ij} \frac{v_i}{s_i^2} \right) }{2 \cosh \left( c_j + \sum _{i \in V} w_{ij} \frac{v_i}{s_i^2}\right) } \end{aligned}$$
(4)

By (3), given \({\textbf{h}}\), the density of \(v_{i}\) is Gaussian, and \(v_{i}\) can be sampled from it. By (4), given \({\textbf{v}}\), the probability that \(h_{j}=1\) is computed, and \(h_{j}\) can be sampled accordingly.
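To illustrate how (3) and (4) are used for sampling, the following is a minimal Python sketch using only the standard library. It assumes hidden units in \(\{+1, -1\}\) as in Sect. 3.1, weights indexed as `W[i][j]` for visible unit i and hidden unit j, and a hidden pre-activation \(\lambda _j = c_j + \sum _i w_{ij} v_i / s_i^2\); the \(1/s_i^2\) scaling is an assumption that matches the form used in (9). All function names are hypothetical. Note that \(\exp (\lambda )/(2\cosh \lambda )\) simplifies to \(\sigma (2\lambda )\), so no hyperbolic functions are needed.

```python
import math
import random


def prob_h_plus(v, W, c, s):
    """P(h_j = +1 | v) for each hidden unit j, with h_j in {+1, -1}.

    Uses exp(lam) / (2 cosh(lam)) = sigmoid(2 * lam), where the assumed
    pre-activation is lam = c_j + sum_i W[i][j] * v_i / s_i**2.
    """
    probs = []
    for j in range(len(c)):
        lam = c[j] + sum(W[i][j] * v[i] / s[i] ** 2 for i in range(len(v)))
        probs.append(1.0 / (1.0 + math.exp(-2.0 * lam)))
    return probs


def sample_h(v, W, c, s, rng):
    """Draw h in {+1, -1}^|H| from P(h | v); units are independent given v."""
    return [1 if rng.random() < p else -1 for p in prob_h_plus(v, W, c, s)]


def sample_v(h, W, b, s, rng):
    """Draw each v_i from N(b_i + sum_j W[i][j] * h_j, s_i**2), as in (3)."""
    # random.Random.gauss takes the standard deviation, i.e. s_i, not s_i**2.
    return [rng.gauss(b[i] + sum(W[i][j] * h[j] for j in range(len(h))), s[i])
            for i in range(len(b))]


# Usage: one visible unit and one hidden unit, a single h -> v sweep.
rng = random.Random(0)
h = sample_h([1.0], [[0.5]], [0.0], [1.0], rng)
v = sample_v(h, [[0.5]], [0.0], [1.0], rng)
```

Because the hidden units are conditionally independent given \({\textbf{v}}\) (and vice versa), each layer is sampled in one pass, which is what makes block Gibbs sampling in RBMs cheap.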

3.2 Recurrent Temporal Restricted Boltzmann Machine

A recurrent temporal restricted Boltzmann Machine (RT-RBM) [22, 23] is an extension of the RBM suited to time-series data. The RT-RBM has connections from the hidden variables of the past k frames \(\{ {\textbf{h}}_{t-k},{\textbf{h}}_{t-k+1},...,{\textbf{h}}_{t-1}\}\) to the current visible variables \({\textbf{v}}_t\) and hidden variables \({\textbf{h}}_t\). This paper assumes that the state at time t depends only on the state at the previous time \(t-1\), and therefore fixes \(k=1\). The RT-RBM for \(k=1\) is shown in Fig. 2. In addition to the parameters \({\textbf{W}}\), \({\textbf{b}}\) and \({\textbf{c}}\) defined in Sect. 3.1, the RT-RBM introduces \({\textbf{U}} = \{u_{jj'} \ \mid \ j,j' \in H\}\), the set of weights connecting the hidden variables at times \(t-1\) and t. These model parameters are collectively denoted by \(\theta = \{{\textbf{W}}, {\textbf{U}}, {\textbf{b}}, {\textbf{c}}\}\). The expected value of the hidden vector \(\hat{{\textbf{h}}}_t\) at time t is then defined as

$$\begin{aligned} \hat{{\textbf{h}}}_t = {\left\{ \begin{array}{ll} \sigma \left( {\textbf{W}}{\textbf{v}}_t + {\textbf{c}} + {\textbf{U}}\hat{{\textbf{h}}}_{t-1}\right) , \text{ if } t > 1 \\ \sigma \big ({\textbf{W}}{\textbf{v}}_t + {\textbf{c}} \big ), \ \ \ \ \ \text{ if } t = 1 \end{array}\right. } \end{aligned}$$
(5)

where \(\sigma\) is the sigmoid function \(\sigma (x) = (1 + \exp (-x))^{-1}\). Given \(\hat{{\textbf{h}}}_{t-1}\), the conditional probability distributions of \(v_{t,i}\) and \(h_{t,j}\) are

$$\begin{aligned} P_\theta (v_{t,i}=1 \ \mid \ {\textbf{h}}_{t}, \hat{{\textbf{h}}}_{t-1})= & {} \sigma \left( \sum _{j \in H} w_{ji} h_{t,j} + b_i \right) \end{aligned}$$
(6)
$$\begin{aligned} P_\theta (h_{t,j}=1 \ \mid \ {\textbf{v}}_{t}, \hat{{\textbf{h}}}_{t-1})= & {} \sigma \left( \sum _{i \in V} w_{ji} v_{t,i} + c_j \right. \nonumber \\{} & {} \left. + \sum _{j' \in H} u_{jj'} {\hat{h}}_{t-1,j'} \right) \end{aligned}$$
(7)

By (6), given \({\textbf{h}}_t\) and \(\hat{{\textbf{h}}}_{t-1}\), the probability that \(v_{t,i}=1\) is computed, and \(v_{t,i}\) can be sampled from it. By (7), given \({\textbf{v}}_t\) and \(\hat{{\textbf{h}}}_{t-1}\), the probability that \(h_{t,j}=1\) is computed, and \(h_{t,j}\) can be sampled from it.
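The mean-field recursion (5) can be sketched as follows. This is an illustrative sketch, not the authors' code: it assumes weights indexed as `W[i][j]` (visible i, hidden j), so that \({\textbf{W}}{\textbf{v}}_t\) in (5) is read as \(\sum _i w_{ij} v_{t,i}\), and `U[j][jp]` for \(u_{jj'}\). The function name `hidden_expectations` is hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def hidden_expectations(vs, W, U, c):
    """Return [h_hat_1, ..., h_hat_T] for a sequence vs = [v_1, ..., v_T] via (5).

    W[i][j] links visible unit i to hidden unit j; U[j][jp] links
    h_hat_{t-1,jp} to h_{t,j}. The recurrent term is dropped at t = 1.
    """
    n_hidden = len(c)
    h_hats = []
    prev = None
    for v in vs:
        h = []
        for j in range(n_hidden):
            act = c[j] + sum(W[i][j] * v[i] for i in range(len(v)))
            if prev is not None:  # t > 1: add U h_hat_{t-1}
                act += sum(U[j][jp] * prev[jp] for jp in range(n_hidden))
            h.append(sigmoid(act))
        h_hats.append(h)
        prev = h
    return h_hats


# Usage: two visible units, one hidden unit, three frames.
vs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
h_hats = hidden_expectations(vs, W=[[0.5], [-0.5]], U=[[1.0]], c=[0.0])
```

Because \(\hat{{\textbf{h}}}_t\) is a deterministic real-valued expectation rather than a sample, the recurrence can be computed once per sequence and then held fixed while sampling at each time step.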

Fig. 2 Graphical representation of RT-RBM

4 Learning State Transition Rules from RBM

This section proposes a new method, the recurrent temporal Gaussian-Bernoulli restricted Boltzmann Machine (RTGB-RBM), which integrates GB-RBM (to handle continuous visible variables) with RT-RBM (to capture time dependencies between discrete hidden variables). RTGB-RBM takes \(\{{\textbf{v}}_0,{\textbf{v}}_1,...,{\textbf{v}}_t \}\) as input and predicts the future states \(\{ {\textbf{v}}_{t+1},{\textbf{v}}_{t+2},...,{\textbf{v}}_{T}\}\). In addition, we extract a set of state transition rules from the trained RTGB-RBM.

4.1 Recurrent Temporal Gaussian-Bernoulli RBM

This subsection describes the recurrent temporal Gaussian-Bernoulli restricted Boltzmann Machine (RTGB-RBM). RTGB-RBM is defined by the set of parameters \(\theta = \{{\textbf{W}}, {\textbf{U}}, {\textbf{b}},{\textbf{c}}, {\textbf{s}}\}\), introduced in Sects. 3.1 and 3.2. In RT-RBM, both visible and hidden variables take binary values, whereas in RTGB-RBM the visible variables take continuous values and the hidden variables take binary values. This allows the visible layer to handle data with a wide range of values, such as images, while the hidden layer handles their features. Unlike GB-RBM, which cannot handle sequences, RTGB-RBM handles sequences in both the visible and hidden layers by defining weights for transitions between hidden layers. By combining RT-RBM and GB-RBM, RTGB-RBM can learn time-series data with a wide range of values, such as video and sound. \({\textbf{v}}_t\) and \({\textbf{h}}_t\) are inferred by the following equations,

$$\begin{aligned} P_{\theta }(v_{t,i} \ \mid \ {\textbf{h}}_t,\ \hat{{\textbf{h}}}_{t-1})= & {} {\mathcal {N}} \left( v_{t,i} \ \Bigm \vert \ b_{i} \right. \nonumber \\{} & {} \left. + \sum _j w_{ji} h_{t,j}, \ s_i^2 \right) \end{aligned}$$
(8)
$$\begin{aligned} P_{\theta }(h_{t,j}=1 \ \mid \ {\textbf{v}}_t,\ \hat{{\textbf{h}}}_{t-1})= & {} \sigma \left( \sum _i w_{ji} \frac{v_{t,i}}{s_i^2} + c_{j}\right. \nonumber \\{} & {} \left. + \sum _{j'}u_{jj'}{\hat{h}}_{t-1,j'} \right) \end{aligned}$$
(9)

Here, \(\hat{{\textbf{h}}}_{t}\) is computed by (5). By (8), given \({\textbf{h}}_t\) and \(\hat{{\textbf{h}}}_{t-1}\), the probability density of \(v_{t,i}\) is obtained, and \(v_{t,i}\) can be sampled from it. By (9), given \({\textbf{v}}_t\) and \(\hat{{\textbf{h}}}_{t-1}\), the probability that \(h_{t,j}=1\) is computed, and \(h_{t,j}\) can be sampled from it. The architecture of RTGB-RBM is shown in Fig. 3.
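A single Gibbs sweep of RTGB-RBM at time t, with \(\hat{{\textbf{h}}}_{t-1}\) held fixed, can be sketched as below: the hidden units are sampled from (9) and the visible units from the Gaussian (8). This is a sketch under stated assumptions, not the authors' implementation: hidden units take values in \(\{0, 1\}\) as in (9), `s[i]` is the per-unit standard deviation \(s_i\), weights are indexed `W[i][j]` (visible i, hidden j) and `U[j][jp]` for \(u_{jj'}\), and all function names are hypothetical.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def p_h_given_v(v, h_prev_hat, W, U, c, s):
    """P(h_{t,j} = 1 | v_t, h_hat_{t-1}) for every hidden unit j, following (9)."""
    probs = []
    for j in range(len(c)):
        act = c[j]
        act += sum(W[i][j] * v[i] / s[i] ** 2 for i in range(len(v)))
        act += sum(U[j][jp] * h_prev_hat[jp] for jp in range(len(h_prev_hat)))
        probs.append(sigmoid(act))
    return probs


def sample_v_given_h(h, W, b, s, rng):
    """Sample v_{t,i} ~ N(b_i + sum_j W[i][j] h_{t,j}, s_i^2), following (8)."""
    return [rng.gauss(b[i] + sum(W[i][j] * h[j] for j in range(len(h))), s[i])
            for i in range(len(b))]


def gibbs_step(v, h_prev_hat, W, U, b, c, s, rng):
    """One v -> h -> v' sweep; h_hat_{t-1} is conditioned on, not resampled."""
    h = [1 if rng.random() < p else 0
         for p in p_h_given_v(v, h_prev_hat, W, U, c, s)]
    return h, sample_v_given_h(h, W, b, s, rng)


# Usage: one visible unit, one hidden unit.
rng = random.Random(0)
h, v_new = gibbs_step([2.0], [1.0], W=[[1.0]], U=[[0.5]], b=[0.0], c=[0.0],
                      s=[2.0], rng=rng)
```

The only difference from a plain GB-RBM sweep is the extra bias \(\sum _{j'} u_{jj'} {\hat{h}}_{t-1,j'}\) on each hidden unit, which is how the past enters the current time step.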

Fig. 3 Graphical representation of RTGB-RBM. The values of the visible variables are computed recursively; \({\textbf{v}}_t\) takes real values, and \({\textbf{h}}_t\) takes binary values

4.2 Training

We update the parameters of RTGB-RBM so that the likelihood L is maximized.

$$\begin{aligned} L= & {} \prod _n^N \prod _t^T P_{\theta } \left( {\textbf{v}}_{t}^{(n)} \bigm \vert \ {\textbf{h}}_{t}^{(n)}, \hat{{\textbf{h}}}_{t-1}^{(n)} \right) \end{aligned}$$

The parameters \(\mathbf {\theta } = \{{\textbf{W}},{\textbf{U}}, {\textbf{b}}, {\textbf{c}}\}\) that maximize L are estimated by gradient ascent, \(\mathbf {\theta } \leftarrow \mathbf {\theta } + \eta \frac{\partial \log L}{\partial \theta }\), where \(\eta\) is the learning rate. The gradients with respect to each parameter are

$$\begin{aligned} \begin{aligned} \frac{\partial log L}{\partial w_{ij}}&= \langle \frac{v_{t,i} {\hat{h}}_{t,j}}{s_i^2} \rangle _{data} - \langle \frac{v_{t,i} {\hat{h}}_{t,j}}{s_i^2} \rangle _{model} \\ \frac{\partial log L}{\partial b_{i}}&= \langle v_{t,i} \rangle _{data} - \langle v_{t,i} \rangle _{model} \\ \frac{\partial log L}{\partial u_{jj'}}&= \langle {\hat{h}}_{t-1,j'}{\hat{h}}_{t,j} \rangle _{data} - \langle {\hat{h}}_{t-1,j'}{\hat{h}}_{t,j} \rangle _{model} \\ \frac{\partial log L}{\partial c_{j}}&= \langle {\hat{h}}_{t,j} \rangle _{data} - \langle {\hat{h}}_{t,j} \rangle _{model} \end{aligned} \end{aligned}$$
(10)

Learning \(w_{ij}\) and \(b_i\) improves the accuracy with which the visible layer is reconstructed from the hidden layer, and learning \(u_{jj'}\) and \(c_{j}\) improves the accuracy of predicting transitions. \(\langle x \rangle _{data}\) is the average of x over the training data, and \(\langle x \rangle _{model}\) is the expected value of x under the model distribution. Computing \(\langle x \rangle _{model}\) exactly requires summing over all configurations of the variables, whose number grows exponentially, so it is intractable in practice. We therefore approximate \(\langle x \rangle _{model}\) using Gibbs sampling. Let \({\textbf{v}}_t(k)\) and \({\textbf{h}}_t(k)\) denote the states obtained after k Gibbs sampling steps, and assume \({\textbf{h}}_{t-1}\) is given. The expected values are then approximated by Algorithm 1.

Algorithm 1

By repeating Gibbs sampling, we get \({\textbf{v}}_{t}(k)\) and \({\textbf{h}}_{t}(k)\) from \({\textbf{v}}_{t}(0)\) and \({\textbf{h}}_{t}(0)\) as,

$$\begin{aligned} {\textbf{h}}_t(0) \rightarrow {\textbf{v}}_t(0) \rightarrow {\textbf{h}}_t(1) \rightarrow {\textbf{v}}_t(1) \rightarrow ... \rightarrow {\textbf{h}}_t(k) \rightarrow {\textbf{v}}_t(k) \end{aligned}$$

If k is large enough, the expected values can be approximated as \(\langle {\textbf{v}}_{t} \rangle _{model} \approx {\textbf{v}}_t(k)\) and \(\langle {\textbf{h}}_{t} \rangle _{model} \approx {\textbf{h}}_t(k)\). Because training with plain Gibbs sampling requires a large k, this study uses the Contrastive Divergence (CD) method [39] for efficient training. In the CD method, \({\textbf{v}}_{t}(0)\) is initialized to the training sample \({\textbf{v}}_t^{(n)}\) instead of a random state, which makes it possible to train the network even when k is small.
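The following minimal sketch shows one CD-k update for a single time step t of a single sequence, using the gradients (10) with \(\hat{{\textbf{h}}}_{t-1}\) held fixed. It is an illustration under stated assumptions, not the authors' Algorithm 1: it uses hidden probabilities (rather than sampled states) for both the data and model statistics, keeps \({\textbf{s}}\) fixed (the parameter set in this subsection omits it), uses a hypothetical learning rate `lr`, and indexes weights as `W[i][j]` and `U[j][jp]`. The loop over sequences and time steps, and the recursion (5) that supplies \(\hat{{\textbf{h}}}_{t-1}\), are omitted.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def hidden_probs(v, h_prev_hat, W, U, c, s):
    """P(h_{t,j} = 1 | v_t, h_hat_{t-1}) for all j, as in (9)."""
    return [sigmoid(c[j]
                    + sum(W[i][j] * v[i] / s[i] ** 2 for i in range(len(v)))
                    + sum(U[j][jp] * h_prev_hat[jp] for jp in range(len(h_prev_hat))))
            for j in range(len(c))]


def cd_k_update(v_data, h_prev_hat, W, U, b, c, s, k, lr, rng):
    """In-place CD-k step for one time step t; s is kept fixed."""
    n_v, n_h = len(b), len(c)
    # Positive phase: the chain starts at the data, v_t(0) = v_t^(n).
    ph_data = hidden_probs(v_data, h_prev_hat, W, U, c, s)
    v_model, ph_model = v_data, ph_data
    # Negative phase: k Gibbs sweeps h -> v -> h with h_hat_{t-1} fixed.
    for _ in range(k):
        h = [1 if rng.random() < p else 0 for p in ph_model]
        v_model = [rng.gauss(b[i] + sum(W[i][j] * h[j] for j in range(n_h)), s[i])
                   for i in range(n_v)]
        ph_model = hidden_probs(v_model, h_prev_hat, W, U, c, s)
    # Gradient estimates from (10): data statistics minus model statistics.
    for i in range(n_v):
        for j in range(n_h):
            W[i][j] += lr * (v_data[i] * ph_data[j] - v_model[i] * ph_model[j]) / s[i] ** 2
        b[i] += lr * (v_data[i] - v_model[i])
    for j in range(n_h):
        for jp in range(n_h):
            U[j][jp] += lr * (ph_data[j] - ph_model[j]) * h_prev_hat[jp]
        c[j] += lr * (ph_data[j] - ph_model[j])


# Usage: two visible units, one hidden unit, CD-1.
rng = random.Random(0)
W, U, b, c, s = [[0.1], [-0.1]], [[0.0]], [0.0, 0.0], [0.0], [1.0, 1.0]
cd_k_update([1.0, 0.0], [0.5], W, U, b, c, s, k=1, lr=0.01, rng=rng)
```

With k = 0 the positive and negative statistics coincide and no parameter changes, which is a quick sanity check that the update is a difference of data and model terms as in (10).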

4.3 Extracting State Transition Rules

We extract state transition rules from the trained RTGB-RBM.

Definition

(State Transition Rule) A state transition rule describes how the state at time t determines the state at time \(t+1\). It takes the form

$$\begin{aligned} p_j {:}{:} L_{t+1,j} \leftarrow L_{t,1} \wedge L_{t,2} \wedge L_{t,3} \wedge .... \wedge \ L_{t,m} \end{aligned}$$
(11)

where m is the number of hidden variables in each hidden layer and \(L_{t,j} (1 \le j \le m)\) is a literal representing either a hidden variable \(h_{t,j}\) or its negation \(\lnot h_{t,j}\). \(L_{t+1,j}\) is the head of the rule, \(L_{t,1} \wedge L_{t,2} \wedge \cdots \wedge L_{t,m}\) is its body, and \(p_j\) is the probability associated with the j-th rule.

Rule (11) reads: “\(L_{t+1,j}\) becomes true with probability \(p_j\) if all of \(L_{t,1}, L_{t,2}, \ldots , L_{t,m}\) are true.” To extract rules of the form (11) from the trained RTGB-RBM, we convert the network parameters into rules as follows:

$$\begin{aligned}&p_j{:}{:} h_{t+1,j} \leftarrow {\displaystyle \bigwedge _{j', u_{jj'} \ge 0}} h_{t,j'} \ \wedge \ {\displaystyle \bigwedge _{j',u_{jj'}<0}} \lnot h_{t,j'} \end{aligned}$$
(12)
$$\begin{aligned}&p_j = \left( 1+exp \Big (- {\displaystyle \sum _{j'} u_{jj'}} \Big ) \right) ^{-1} \end{aligned}$$
(13)

Suppose that the hidden unit \(h_{t+1,j}\) is connected to the hidden unit \(h_{t,j'}\) by the weight \(u_{jj'}\). If \(u_{jj'} \ge 0\), the positive literal \(h_{t,j'}\) is added to the body of the j-th rule; if \(u_{jj'}<0\), the negative literal \(\lnot h_{t,j'}\) is added instead. Equation (13) shows that the larger \(\sum _{j'} u_{jj'}\) is, the higher the probability that \(h_{t+1,j}\) is activated, and the smaller it is, the lower that probability. The extracted rules represent the temporal relationship between the hidden layers at times t and \(t+1\). State transitions in the hidden layer can be computed using the extracted rules, and once the state of the hidden layer is determined, the state of the visible layer can be decoded using (8).
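The conversion in (12) and (13) can be sketched in a few lines of Python. The representation is a hypothetical choice for illustration: each rule is a `(p_j, j, body)` tuple, where `body` lists `(jp, positive)` pairs, `U[j][jp]` stores \(u_{jj'}\), indices are 0-based, and the text rendering uses ASCII `<-`, `&`, and `not`.

```python
import math


def extract_rules(U):
    """Turn recurrent weights U[j][jp] (= u_{jj'}) into rules following (12) and (13).

    Returns (p_j, j, body) tuples; body is a list of (jp, positive) pairs, where
    positive=True encodes the literal h_{t,jp} and False encodes not h_{t,jp}.
    """
    rules = []
    for j, row in enumerate(U):
        body = [(jp, u >= 0) for jp, u in enumerate(row)]  # sign of u_{jj'} picks the literal
        p_j = 1.0 / (1.0 + math.exp(-sum(row)))             # (13): sigmoid of the row sum
        rules.append((p_j, j, body))
    return rules


def format_rule(rule):
    """Render a rule in the form 'p_j :: h[t+1,j] <- h[t,0] & not h[t,1] & ...'."""
    p_j, j, body = rule
    literals = [("h[t,%d]" if positive else "not h[t,%d]") % jp for jp, positive in body]
    return "%.4f :: h[t+1,%d] <- %s" % (p_j, j, " & ".join(literals))


# Usage: the sign pattern of Fig. 4 (two positive weights, one negative).
text = format_rule(extract_rules([[0.8, 1.2, -0.5]])[0])
# text == "0.8176 :: h[t+1,0] <- h[t,0] & h[t,1] & not h[t,2]"
```

Because every \(u_{jj'}\) contributes exactly one literal, each rule body mentions all m hidden variables, so the number of rules equals the number of hidden units.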

Example

Suppose we obtain the rule \(p_1{:}{:} h_{t+1,1}\leftarrow h_{t,1} \wedge h_{t,2} \wedge \lnot h_{t,3}\), as shown in Fig. 4. This rule states that if \(h_{t,1}=1, h_{t,2}=1, h_{t,3}=0\), then \(h_{t+1,1}\) becomes 1 with probability \(p_1\). The value of \(h_{t+1,1}\) is determined by this rule. Similarly, \({\textbf{h}}_{t+1}\) is determined by applying the corresponding rules to \(h_{t+1,2}\) and \(h_{t+1,3}\). Then, \({\textbf{v}}_{t+1}\) is decoded from \({\textbf{h}}_{t+1}\) by (8).
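Rule-based prediction of \({\textbf{h}}_{t+1}\) from a binary \({\textbf{h}}_t\) can be sketched as follows. The text does not specify what happens when a rule body is not satisfied, nor whether the head is sampled or thresholded; this sketch assumes that \(h_{t+1,j} = 0\) when the body of rule j fails, and that \(h_{t+1,j} = 1\) with probability \(p_j\) when it holds. Both choices, the `(p_j, j, body)` rule format, and the function names are hypothetical.

```python
import random


def body_satisfied(body, h_t):
    """True if every literal in the body holds for the binary state h_t.

    body is a list of (jp, positive) pairs: positive=True requires h_{t,jp} = 1,
    positive=False requires h_{t,jp} = 0.
    """
    return all((h_t[jp] == 1) == positive for jp, positive in body)


def predict_next(rules, h_t, rng):
    """Rule-based prediction of h_{t+1} from h_t.

    rules is a list of (p_j, j, body) tuples. Assumption: if the body of rule j
    holds, h_{t+1,j} is set to 1 with probability p_j; otherwise it stays 0.
    """
    h_next = [0] * len(h_t)
    for p_j, j, body in rules:
        if body_satisfied(body, h_t) and rng.random() < p_j:
            h_next[j] = 1
    return h_next


# Usage: the rule of Fig. 4 written with 0-based indices (h_1, h_2, h_3 -> 0, 1, 2).
fig4_rule = (0.9, 0, [(0, True), (1, True), (2, False)])
h_next = predict_next([fig4_rule], [1, 1, 0], random.Random(0))
```

After \({\textbf{h}}_{t+1}\) is obtained this way, the visible frame is decoded through the Gaussian (8), which the rules themselves do not model.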

Fig. 4 An example of rule extraction by (12). Here, \(u_{11}\) and \(u_{21}\) are greater than 0, so the positive literals \(h_1\) and \(h_2\) are added to the body of the rule, whereas \(u_{31}\) is less than 0, so the negative literal \(\lnot h_3\) is added

The overview of our method is illustrated in Fig. 5. We have two types of predictions: model-based predictions and rule-based predictions. Given the observed state sequence \(\{ {\textbf{v}}_{0},{\textbf{v}}_{1},...,{\textbf{v}}_{t}\}\) as input, the former predicts future states using RTGB-RBM and the latter predicts future states by interpretable rules (12) extracted from the trained RTGB-RBM.

Fig. 5 Overview of our method. We extract state transition rules by converting the network parameters into rule format. Future states can be predicted either directly with RTGB-RBM or by determining state transitions with the extracted rules

5 Experiments

We conduct experiments on three datasets, Bouncing Ball [40], Moving MNIST [41] and dSprites [33], to evaluate the proposed method. Table 1 lists, for each dataset, the number of sequences (N), the number of frames per sequence (T), and the number of pixels per frame (n). The datasets differ in nature. For Bouncing Ball, we only need to predict the shape, position, and velocity of a single-color ball. In Moving MNIST, the model must learn ten different shapes (0-9), which makes shape reconstruction harder. dSprites requires learning not only position and shape but also color, size, and rotation. Because the datasets differ in attributes and learning difficulty, we varied the number of hidden variables and the number of training epochs across experiments.

In each experiment, we first train the RTGB-RBM and then extract rules from it. Next, using the probability of each rule as a measure of its contribution to the prediction, we prune hidden units with low contribution from the network. Finally, we confirm that removing these hidden units does not substantially reduce prediction accuracy. The hyperparameters for each experiment, such as the number of hidden variables and the number of sampling iterations, are given in the corresponding subsections.

To evaluate the predictions of our model, we use the loss (14), which measures the error between the prediction \(\{ \hat{{\textbf{v}}}_{T+1},...,\hat{{\textbf{v}}}_{T'} \}\) and the ground truth \(\{ {\textbf{v}}_{T+1},..., {\textbf{v}}_{T'} \}\), given \(\{{\textbf{v}}_{0},...,{\textbf{v}}_{T}\}\), where N is the number of sequences.

$$\begin{aligned} \text{ Loss } = \frac{1}{N}\sum _{n=1}^N \left( \frac{1}{T'-T}\sum _{t=T+1}^{T'} \sum _{i \in V} \big (v_{t,i}^{(n)} - {\hat{v}}_{t,i}^{(n)} \big )^2 \right) \end{aligned}$$
(14)
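The loss (14) can be computed directly as below. This sketch assumes the predicted and ground-truth frames are passed as nested lists indexed `[n][t][i]`, containing only the predicted frames \(T+1, \ldots , T'\) of each of the N sequences; the function name is hypothetical.

```python
def prediction_loss(preds, truths):
    """Loss (14), averaged over the N sequences.

    preds[n][t][i] and truths[n][t][i] hold the predicted and true values of
    pixel i in the t-th predicted frame (frames T+1..T') of sequence n.
    """
    total = 0.0
    for pred_seq, true_seq in zip(preds, truths):
        n_frames = len(true_seq)  # T' - T predicted frames
        sq_err = sum((v - v_hat) ** 2
                     for pred_frame, true_frame in zip(pred_seq, true_seq)
                     for v_hat, v in zip(pred_frame, true_frame))
        total += sq_err / n_frames
    return total / len(truths)  # 1/N over the N sequences


# Usage: one sequence with two predicted frames of two pixels each.
loss = prediction_loss([[[0.0, 0.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 3.0]]])
# squared errors 1 + 0 + 0 + 4 = 5, divided by 2 frames and 1 sequence: loss == 2.5
```

Note that the squared error is summed over pixels but averaged over frames and sequences, so the loss scales with the frame size n.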
Table 1 Description of Datasets. N is the number of sequences in the data set, T is the number of frames in each sequence, and n is the number of pixels in each frame

The experimental environment was as follows: Operating System: Ubuntu 18.04.6 LTS; Python version: 3.9.7 (also tested with 3.6.9, 3.7.6, and 3.8.0); NVIDIA driver version: 470.57.02; CUDA version: 11.4; GPUs: NVIDIA® RTX\(^{\textrm{TM}}\) A6000 and NVIDIA® A100.

5.1 Bouncing Ball

5.1.1 Setting

The Bouncing Ball dataset is generated by the neural physics engine (NPE) [40]. It simulates multiple balls moving in a two-dimensional space enclosed by walls on all four sides; the number of balls, their radius, color, speed, and so on can be configured. We experimented with two types of videos, one with a single ball and one with three balls. In both cases, we generated 10000 sequences, each containing 100 frames of 100 × 100 pixels. For evaluation, we set \(N = 10000\), \(T = 90\), and \(T' = 100\).

5.1.2 Training

We experimented with the one-ball and three-ball cases. In both cases, the number of visible units is set to 10000, the number of hidden units to \(\{10,30,100\}\), and the number of CD iterations K to \(\{3,10,20\}\). The learning curves are shown in Fig. 6.

Fig. 6 Learning curves for the one-ball (top) and three-ball (bottom) cases

Fig. 7 Ball state prediction (\({\textbf{h}} = 100\)). Three steps are given as input, and five steps are predicted. The top row is the ground truth, and the bottom row is our prediction

Fig. 6 shows that the more hidden units, the better the prediction performance. It also shows that the dynamics of the three-ball case are harder to learn than those of the one-ball case, and that the number of CD iterations makes little difference, although fewer iterations perform slightly better. RTGB-RBM learns rapidly in the first epoch, and learning progresses slowly from the second epoch onward. These results show that the proposed method can learn the dynamics of bouncing balls early in training.

An example of prediction by the trained RTGB-RBM is shown in Fig. 7. In the one-ball case, our model predicts ball-to-wall collisions, and in the three-ball case it also predicts ball-to-ball collisions. This result shows that our model learns not only the ball's trajectory but also the concept of collision.

5.1.3 Relationship between rules and features

In this subsection, we show the relationship between the extracted rules and the features learned by RTGB-RBM. To visualize what each hidden variable represents, we compute a feature map for each hidden unit by applying the weights \({\textbf{W}}\) as in (15).

$$\begin{aligned} v_{t,i} = \sigma \left( \sum _{j \in H} w_{ij} h_{t,j} + b_i \right) \ \ ( i \in V) \end{aligned}$$
(15)

The feature maps of \({\textbf{v}}\) for the one-ball case computed by (15) are shown in Fig. 8. These feature maps appear to encode the ball's position and direction. For example, the map in the top left corner indicates that the ball is located near the center of the lower side, and the middle map in the top row indicates that the ball is moving from left to right. By combining these features, our model predicts the trajectory of the ball.
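A feature map as in (15) can be sketched as follows. The text does not state which hidden configuration is used for each map; this sketch assumes the one-hot vector \({\textbf{h}}_t = e_j\) (only unit j active), so the map of unit j is \(\sigma (w_{ij} + b_i)\) over all pixels i, reshaped row-major into an image. The one-hot choice, the `W[i][j]` indexing, and the function name are assumptions for illustration.

```python
import math


def feature_map(W, b, j, width):
    """Feature map of hidden unit j from (15), assuming h_t = e_j (one-hot).

    W[i][j] links pixel i to hidden unit j and b[i] is the visible bias. The
    pixel values sigmoid(W[i][j] + b[i]) are reshaped row-major into rows of
    `width` pixels (e.g. width=100 for 100 x 100 frames).
    """
    pixels = [1.0 / (1.0 + math.exp(-(W[i][j] + b[i]))) for i in range(len(b))]
    return [pixels[r:r + width] for r in range(0, len(pixels), width)]


# Usage: a 2 x 2 "image" (four pixels) and two hidden units.
W = [[2.0, 0.0], [0.0, 0.0], [0.0, -2.0], [0.0, 0.0]]
b = [0.0, 0.0, 0.0, 0.0]
fmap0 = feature_map(W, b, 0, width=2)  # bright top-left pixel for unit 0
```

Pixels with large positive \(w_{ij}\) light up in the map of unit j, which is why the maps in Fig. 8 can be read as locations and directions of motion.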

Fig. 8 Feature maps for the one-ball case (\({\textbf{h}}_t\) = 10)

We examine rule (16) as an example; for simplicity, negative literals are omitted. Mapping the extracted rule onto the feature maps gives Fig. 9. Rule (16) states that if \(h_{t,0}, h_{t,1}, h_{t,2},h_{t,6} = 1\) and \(h_{t,3}, h_{t,4}, h_{t,5}, h_{t,8}, h_{t,9}=0\), then \(h_{t+1,3}\) becomes 1 with probability 0.8732. Figure 9 suggests that \(h_{t,0},h_{t,2}, h_{t,6}\) represent features of movement toward the lower right, and that \(h_{t,3}\) has large values in the lower right corner.

$$\begin{aligned} 0.8732 {:}{:} h_{t+1,3} \leftarrow h_{t,0} \wedge h_{t,1} \wedge h_{t,2}\wedge h_{t,6} \end{aligned}$$
(16)
Fig. 9 Transition rule on feature maps corresponding to (16) (\({\textbf{h}} = 10\))

In Fig. 9, the rule states that the feature map in the head at the next time step is generated by combining the ball's current position and direction, represented by the four features in the body. We extract such rules for all hidden-variable state transitions. Fig. 10 shows an example of predicting \({\textbf{v}}_{t+1}\) from \({\textbf{v}}_t\) using the extracted rules. Applying the learned rules to \({\textbf{h}}_t=[1,1,1,0,0,0,1,0,0,0]\) gives \({\textbf{h}}_{t+1}=[0,1,0,1,0,0,1,0,0,0]\), from which \({\textbf{v}}_{t+1}\) is decoded.

Fig. 10

Decoding \({\textbf{v}}_{t+1}\) from \({\textbf{h}}_{t+1}\), which is obtained by applying the learned transition rules to \({\textbf{h}}_t\)

Although our rules represent temporal relationships between hidden layers, they cannot express relationships between the hidden and visible layers or among visible units within the same layer. Therefore, the state of the visible layer must be decoded by Gibbs sampling based on the network parameters. Through this mechanism, the essential information about the dynamics is determined by interpretable rules, while the visual information is decoded by probabilistic inference.
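A sketch of the decoding step, assuming one common GB-RBM parameterization in which \(p(v_i \mid {\textbf{h}})\) is Gaussian with mean \(b_i + \sum_j w_{ij} h_j\) and standard deviation \(s_i\); the parameterization used in this work may differ, for example in how \(s_i\) scales the weights. Because \({\textbf{h}}_{t+1}\) is fixed by the rules, only the visible half of a Gibbs step is drawn here. The name `decode_visible` is hypothetical.

```python
import numpy as np

def decode_visible(h, W, b, std, rng=None):
    """Decode v from a clamped hidden state in a Gaussian-Bernoulli RBM (assumed form).

    h   : (n_hidden,) binary hidden state, e.g. h_{t+1} produced by the rules
    W   : (n_visible, n_hidden) weights
    b   : (n_visible,) visible biases
    std : (n_visible,) standard deviations of the Gaussian visible units
    """
    mean = b + W @ h                  # conditional mean b_i + sum_j w_ij h_j
    if rng is None:
        return mean                   # mean-field decode without sampling noise
    # One draw of v ~ N(mean, diag(std**2)): the visible half of a Gibbs step.
    return mean + std * rng.standard_normal(mean.shape)
```

Returning the conditional mean is a common choice when the decoded frame is only displayed; drawing a sample keeps the stochastic character of Gibbs sampling.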

5.1.4 Comparative Experiment

We compare RT-RBM, RTGB-RBM, and rule-based predictions. Their learning curves are shown in Fig. 11. In the one-ball case, both RTGB-RBM and rule-based predictions are more accurate than RT-RBM. In the three-ball case, RT-RBM performs best and RTGB-RBM performs comparably, while rule-based prediction is less accurate. This result indicates that some information is lost when rules are extracted from the trained RTGB-RBM.

The deterioration of rule-based prediction is more pronounced in the three-ball case than in the one-ball case, probably because the three-ball dynamics contain more intrinsic information that cannot be expressed as rules.

Fig. 11

Learning curves for the one-ball case (left) and three-ball case (right) (\({\textbf{h}} = 100\), \(K=3\))

Fig. 12 (left) shows the predictions of RTGB-RBM and of the rules when \({\textbf{h}} = 10\). In the one-ball case, both the rules and RTGB-RBM predict the trajectory of the ball. In the three-ball case, however, RTGB-RBM predictions drift increasingly far from the ground truth, and rule-based predictions no longer follow the ground-truth trajectory. This result shows that rules with \({\textbf{h}}=10\) are sufficient to predict the one-ball trajectory but not expressive enough to predict the three-ball trajectories.

When \({\textbf{h}} = 100\), in contrast, both RTGB-RBM and rule-based predictions improve, capturing the balls' trajectories and bounces in Fig. 12 (right). Thus, prediction accuracy increases with the number of hidden units, even for rule-based prediction: with more hidden variables, more state transitions can be expressed by rules. This result indicates that a sufficient number of hidden units is needed to predict the dynamics.

Fig. 12

Ground-truth one-ball and three-ball trajectories (top), prediction of RTGB-RBM (middle), and prediction using rules (bottom). (Left: \({\textbf{h}} = 10\), Right: \({\textbf{h}} = 100\))

The rule-based method has lower prediction accuracy than the other two methods: it predicts the ball's trajectory for the first five or six steps and then gradually deviates from the ground truth. One possible reason is that, because rule-based transitions are probabilistic, noise and errors accumulate as the state is propagated forward, leading to incorrect predictions. Furthermore, the dynamics of multiple objects, such as the three balls, are difficult to describe in our rule form. We suspect that rules describing each ball separately are not expressive enough, and that rules representing the interactions and relationships between balls, such as each ball's position, direction, and collisions, are needed.
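As a rough illustration of how such errors can compound, suppose (a simplifying assumption, not a property of the extracted rule set) that a unit's rule-driven transition is correct with probability \(p = 0.8732\), the probability of rule (16), independently at each step. The chance that \(k\) consecutive transitions are all correct is then

$$\begin{aligned} P(\text{all } k \text{ transitions correct}) &= p^{k}, \\ 0.8732^{5} &\approx 0.508, \qquad 0.8732^{6} \approx 0.443, \end{aligned}$$

so under this assumption the chance of an error-free rollout falls to about one half after five steps.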

Despite these limitations, our method can predict the state transitions of the dynamics and learn rules between hidden layers that correspond to the feature maps. Using hidden variables reduces the size of the state-transition system: while the original dynamics involve many visible variables, our method represents them with a few rules. Furthermore, since the rules are expressed in the form (11), they are interpretable (e.g., Fig. 9), yet rule-based predictions are comparable to RTGB-RBM predictions (e.g., Fig. 11, left).

5.1.5 Network pruning

We evaluate each hidden unit's contribution by the probability \(p_j\) of the j-th rule and prune the units with low \(p_j\) from the network. This enables lightweight inference while preserving the information needed to reconstruct the original dynamics. Because pruning is expected to cause some information loss, this subsection evaluates the relationship between the number of remaining hidden units and the reconstruction error.
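A minimal NumPy sketch of rule-based pruning, assuming each hidden unit \(j\) has one associated rule probability \(p_j\) and that the recurrent hidden-to-hidden weights are stored in a square matrix (named `U` here, a hypothetical name). Pruning keeps the rows and columns of the retained units and discards the rest; `prune_hidden_units` is illustrative, not the authors' code.

```python
import numpy as np

def prune_hidden_units(p, W, c, U, n_keep):
    """Keep the n_keep hidden units with the highest rule probability p_j.

    p : (n_hidden,) rule probability associated with each hidden unit
    W : (n_visible, n_hidden) visible-to-hidden weights
    c : (n_hidden,) hidden biases
    U : (n_hidden, n_hidden) recurrent hidden-to-hidden weights (hypothetical name)
    """
    # Indices of the top-n_keep units by p_j, restored to their original order.
    keep = np.sort(np.argsort(p)[::-1][:n_keep])
    # Drop the pruned units from every parameter that touches the hidden layer.
    return keep, W[:, keep], c[keep], U[np.ix_(keep, keep)]
```

Sweeping `n_keep` from the full size down and measuring the reconstruction loss of each pruned network yields a curve of the kind shown in Fig. 13.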

Fig. 13

Reconstruction loss for the one-ball case (orange) and three-ball case (blue) as hidden units are pruned (K=3)

Fig. 13 shows the relationship between the loss and the number of remaining hidden units when \({\textbf{h}}=100\). For example, the loss at \({\textbf{h}}=60\) is the prediction loss of the network after the 40 hidden units with the lowest \(p_j\) are removed. The graph shows that prediction performance does not decrease significantly even when about 90 of the 100 hidden units are removed. Figure 14 shows example predictions after pruning 70 and 90 of the 100 hidden units. The ball's position and trajectory are predicted well in both cases, although performance drops compared with the original network. This result shows that making state transitions interpretable through rules is also useful for network pruning.

Fig. 14

Reconstructed images: (first line) ground truth, (second line) prediction of original RTGB-RBM, (third line) Prediction of RTGB-RBM with 70 units pruned, (fourth line) Prediction of RTGB-RBM with 90 units pruned

5.2 Moving MNIST

We evaluate our method on Moving MNIST [41], which is more complex than the Bouncing Ball dataset. Moving MNIST consists of sequences of MNIST digits: it contains 10,000 sequences, each with 20 frames in which 2 digits move within a 64×64-pixel area. Each pixel originally takes a value between 0 and 1; to make the digit outlines more distinct, we threshold each pixel value at 0.1. Whereas Bouncing Ball includes both ball-to-ball and wall-to-ball collisions, Moving MNIST has only wall-to-digit collisions; the digits pass through each other. Therefore, the model must learn to reconstruct different digits and to predict trajectories, wall-to-digit collisions, and overlaps.

The learning curves of RT-RBM, RTGB-RBM, and the rule-based method on Moving MNIST are shown in Fig. 15. We trained the networks for 100 epochs, using the first 5 frames as input and predicting the remaining 15 frames. Figure 15 shows that RTGB-RBM predicts better than RT-RBM. In addition, the rule-based method approached the accuracy of RTGB-RBM and eventually became more accurate than RT-RBM. This result suggests that as the network parameters are optimized through training, the expressiveness of the extracted rules increases and the information lost through extraction decreases.
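The thresholding and the 5/15 frame split described above can be sketched as follows, assuming the sequences are stored as an array of shape (number of sequences, 20, 64, 64) with values in [0, 1]. Whether pixels exactly equal to the threshold map to 0 or 1 is not stated, so the strict comparison used here is an assumption; `preprocess_moving_mnist` is a hypothetical name.

```python
import numpy as np

def preprocess_moving_mnist(seqs, threshold=0.1, n_input=5):
    """Binarize pixels and split each sequence into input and target frames.

    seqs : (n_seq, n_frames, 64, 64) array with pixel values in [0, 1] (assumed layout)
    Returns (inputs, targets) with shapes (n_seq, n_input, ...) and
    (n_seq, n_frames - n_input, ...).
    """
    # Pixels strictly above the threshold become 1, all others 0 (assumption).
    binary = (seqs > threshold).astype(np.float32)
    return binary[:, :n_input], binary[:, n_input:]
```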

Fig. 15

Learning curve on Moving MNIST (\({\textbf{h}} = 1000\), \(K=3\))

Fig. 16 shows an example of RTGB-RBM and rule-based prediction for a sequence containing the digits 2 and 9. Our methods reconstruct the distinct shapes of the digits and predict their trajectories. When the digits overlap, the prediction becomes ambiguous, and the digits appear indistinguishable in the ambiguous frames. Nevertheless, the model predicts the trajectories after the overlap. From this result, we infer that the hidden layers contain features that distinguish the digits and that the transitions between hidden layers contain enough information to reconstruct the trajectories after the overlap. In addition, although rule-based prediction is less accurate than RTGB-RBM prediction, it can still predict the trajectories because the essential information is preserved in the extracted rules.

Fig. 16

Prediction of RTGB-RBM and Rule-based method when \(h=1000\) and \(K=3\): (A) Ground truth, (B) RTGB-RBM, (C) Rule-based. Input Reconstruction is decoded by giving the observed data to the visible layer as input, and Future Prediction is decoded only from the information in the hidden layer

The relationship between network pruning and prediction performance is shown in Fig. 17. Prediction performance does not decrease significantly until about 400 hidden units are removed, indicating that about 600 hidden units are needed to represent the dynamics of Moving MNIST adequately. Figure 18 shows that the shapes and trajectories of the digits are predicted well with 600 hidden units, whereas the shapes are not predicted well with 100 hidden units. These results indicate that the remaining 100 hidden units still contain information for predicting the trajectories but not for preserving the shapes; in other words, the 900 pruned hidden units contained information needed to preserve the shapes. Pruning the network based on rules thus lets us evaluate which units are effective for prediction.

Fig. 17

Reconstruction loss for Moving MNIST as hidden units are pruned (K=3)

Fig. 18

Reconstructed images: (first line) ground truth, (second line) prediction of original RTGB-RBM, (third line) Prediction of RTGB-RBM with 400 units pruned, (fourth line) Prediction of RTGB-RBM with 900 units pruned

5.3 dSprite

The dSprites dataset [33] was created to assess the disentanglement properties of unsupervised learning methods. It contains images of 2D shapes (square, ellipse, heart) procedurally generated from six independent ground-truth latent factors: color, shape, scale, rotation, and X and Y position.

To simplify the experiment, we modified the original dSprite to consist of 3 shapes (square, ellipse, heart), 6 scale values, 40 orientation values, and 4 colors (white, red, green, blue), and we restricted motion to simple horizontal or vertical movement. Within each sequence, the scale gradually increases or decreases from frame to frame. Similarly, after a random initial orientation is chosen, the shape is rotated clockwise or counterclockwise in each subsequent frame. The final dataset consists of approximately 100,000 data points. Each sequence contains 8 frames of 32×32 pixels. In this experiment, our method must learn multiple attributes: color, shape, size change, rotation, position, and direction of movement.
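A hypothetical sketch of how the latent factors of one modified sequence could be generated from the description above. The counts of shapes, scales, orientations, and colors come from the text. The 32-position grid per axis, the single index step per frame, and clipping the scale at its smallest or largest value are assumptions (with one step per frame, an 8-frame sequence cannot change monotonically through only 6 scale values without saturating). Rendering the factors into 32×32 images is omitted.

```python
import numpy as np

N_SCALES, N_ORIENTS = 6, 40                  # from the text
GRID = 32                                    # positions per axis (assumption)
SHAPES = ("square", "ellipse", "heart")
COLORS = ("white", "red", "green", "blue")

def sample_factor_sequence(rng, n_frames=8):
    """Per-frame latent factors for one hypothetical modified-dSprite sequence."""
    shape = int(rng.integers(len(SHAPES)))   # fixed for the whole sequence
    color = int(rng.integers(len(COLORS)))   # fixed for the whole sequence
    scale_dir = rng.choice([-1, 1])          # scale grows or shrinks each frame
    rot_dir = rng.choice([-1, 1])            # clockwise or counterclockwise
    axis = rng.integers(2)                   # 0: horizontal, 1: vertical motion
    move_dir = rng.choice([-1, 1])
    scale = rng.integers(N_SCALES)
    orient = rng.integers(N_ORIENTS)         # random initial orientation
    pos = rng.integers(GRID, size=2)         # (x, y)
    frames = []
    for _ in range(n_frames):
        frames.append(dict(shape=shape, color=color, scale=int(scale),
                           orient=int(orient), x=int(pos[0]), y=int(pos[1])))
        scale = np.clip(scale + scale_dir, 0, N_SCALES - 1)    # saturates at the ends
        orient = (orient + rot_dir) % N_ORIENTS                 # rotation wraps around
        pos[axis] = np.clip(pos[axis] + move_dir, 0, GRID - 1)  # stays inside the frame
    return frames
```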

Fig. 19

Learning curve on dSprite (\({\textbf{h}} = 1000\), \(K = 3\))

Learning curves for RTGB-RBM and rule-based predictions are shown in Fig. 19. In dSprite, input values are not binary but range from 0 to 255, so RT-RBM, which can handle only binary values, cannot be applied. Figure 19 shows that the performance of our methods stabilizes after 150 epochs of training. Fig. 20 shows that our method learns multiple attributes, and even the rule-based method can reconstruct and predict them. This result indicates that in the dSprite dynamics, the essential attributes (color, shape, size change, rotation, position, and direction of movement) are preserved even in rule form.

Fig. 20

Prediction of RTGB-RBM and Rule-based method when \(h = 1000\) and \(K = 3\): (A) Ground truth, (B) RTGB-RBM, (C) Rule-based. Three frames are given as input, and the remaining five frames are predicted

The results of network pruning are shown in Fig. 21. Prediction remains relatively good even when around 900 hidden units are pruned, but it deteriorates significantly when more than 900 are pruned. Figure 22 shows the prediction for a rotating red ellipse: even with 100 hidden units, the color, shape, and rotation are predicted successfully, although the loss increases roughly tenfold. Each frame is 32×32×3-dimensional, so after pruning the network predicts the dSprite dynamics with 100 units, about 1/30 of the input dimension. This is probably because units essential to prediction have higher rule importance, so the part of the network that is essential to prediction remains even after the less important hidden units are pruned. These results suggest that the 100 remaining units and their transitions contain enough information to predict these attributes.
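The compression ratio quoted above follows from the frame size:

$$\begin{aligned} \frac{100}{32 \times 32 \times 3} = \frac{100}{3072} \approx \frac{1}{30.7} \end{aligned}$$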

Fig. 21

Reconstruction loss for dSprite as the number of hidden units is reduced (K=3)

Fig. 22

Reconstructed images: (first line) ground truth, (second line) prediction of original RTGB-RBM, (third line) Prediction of RTGB-RBM with 400 units pruned, (fourth line) Prediction of RTGB-RBM with 900 units pruned

5.4 Discussion

Our experimental results demonstrate the effectiveness of our method in learning system dynamics, offering several notable advantages.

First, our approach extracts interpretable rules that capture the underlying dynamics as state transitions between hidden variables. These rules make the system's behavior easier to understand and provide insight into the relationships between variables. Additionally, our method predicts unobserved future states from observed state transitions, capturing the temporal evolution of the system.

Another advantage of our method is its ability to handle both continuous and discrete values. Unlike many existing approaches, which require converting continuous data into discrete values, our method integrates both data types, enabling a more comprehensive analysis of dynamics. This capability is particularly valuable in real-world applications, where data often have mixed types.

In summary, by combining interpretable rules that capture system dynamics with predictive accuracy and flexibility in handling different data types, our approach opens new avenues for understanding and leveraging dynamics in various scientific and engineering domains.

Nevertheless, we acknowledge certain limitations that warrant further attention. Although our learned rules are interpretable, understanding the meaning of each rule, particularly with respect to attributes such as color or shape, can be challenging. To address this, future work will refine our rule-based method, including developing a more structured representation that can handle multiple attributes simultaneously. We also aim to extend our rule form to include both visible and hidden variables, as well as both continuous and discrete values. Furthermore, we recognize the importance of disentangling the interpretation of each variable within the rules. Currently, our rule extraction does not consider the weights of individual hidden units; units with small and large weights alike are treated as simple literals in the rules. We will therefore explore more expressive rules that take the weights into account, contributing to a more comprehensive and accurate understanding of the learned dynamics.

6 Conclusion

In this study, we proposed RTGB-RBM, a method combining GB-RBM and RT-RBM to handle continuous visible variables and time dependence between discrete hidden variables. We also introduced a rule-based approach to extract essential information as hidden variables and represent state transition rules in an interpretable form. Furthermore, we developed a network pruning method to evaluate the contribution of hidden units based on the rules, retaining only those units containing essential information about the dynamics.

Our experimental results demonstrated the effectiveness of our methods in predicting future states of Bouncing Ball, Moving MNIST, and dSprite datasets. Furthermore, by correlating the learned rules with the features represented by the hidden units, we discovered that these rules contain crucial information for determining the future state based on the current information. Additionally, through the network pruning process, we assessed the number of hidden units required to represent the dynamics.

Overall, our proposed method offers the potential to extract latent relationships as rules without the need for prior discretization or preprocessing. We plan to extend and refine our approach based on the insights gained from this study, conducting more comprehensive experiments on various dynamic datasets. Theoretical analysis will also be pursued to provide a deeper understanding of the effectiveness of our method.