1 Introduction

Compared with currently popular neural networks, decision trees have several favorable properties, such as good interpretability and high computational efficiency. In many practical ML (Machine Learning) applications (Chen & Guestrin, 2016; Ke et al., 2017; Wu et al., 2020; Li et al., 2020), the decision tree has proven its worth and achieved great success.

A decision tree is a hierarchical structure for the supervised learning task, composed of internal nodes and leaf nodes. The parameters \(\theta\) of a decision tree can be divided into three parts: (1) internal parameters \(\omega\), which decide the direction of each instance \(\varvec{x}\) by the router function s; (2) leaf parameters \(\upsilon\), which are the prediction outputs of all leaves \(\mathcal {T}\); (3) architecture parameters \(\gamma\), which define the architecture of the tree. Then, an instance \(\varvec{x}\) is recursively directed to the left or right child of the internal node i by the router function \(s(\varvec{x}; \omega _i)\). When a leaf t is reached (i.e., \(\gamma _t = 0\)), the leaf parameter \(\upsilon _t\) is used as the prediction. Traditionally, splitting rules are used to learn the node (i.e., internal and leaf) parameters \((\omega, \upsilon )\), and stopping rules are used to learn the architecture parameters \(\gamma\).

Fig. 1

Overview of a hard tree (a) and a soft tree (b). a The dashed lines indicate the pruned nodes (i.e., \(\gamma _i = 0\)). The path where the instance \(\varvec{x}\) is routed is shown in purple. As two parts of the internal parameters \(\omega _i\), \(w_i\) and \(b_i\) indicate the feature weight and the feature threshold respectively. b For the soft tree, the color transparency indicates the path probability

As shown in Fig. 1, decision trees can be divided into hard trees and soft trees. For the hard tree in Fig. 1a, the response at node i has the following recursive definition:

$$\begin{aligned} f_i (\varvec{x};\theta ) = {\left\{ \begin{array}{ll} \upsilon _i, & \text {if } i \text { is a leaf (i.e., } \gamma _i = 0 \text {)} \\ f_{L(i)}(\varvec{x}; \theta ), & \text {if } s(\varvec{x}; \omega _i) = [1, 0]^T \\ f_{R(i)}(\varvec{x}; \theta ), & \text {if } s(\varvec{x}; \omega _i) = [0, 1]^T \end{array}\right. } \end{aligned}$$
(1)

where \(\gamma _i \in \{0, 1\}\), \(s: \varvec{x} \rightarrow \{[1, 0]^T, [0, 1]^T\}\) is an axis-parallel split with a one-hot vector \(\omega _i\), and L(i)/R(i) is the left/right child of node i. For example, CART (Breiman et al., 1984) greedily chooses the split feature and threshold by minimizing the Gini Index in the current node. Furthermore, top-down pre-pruning and minimal cost-complexity post-pruning (Breiman et al., 1984) are used as the stopping rules to improve the tree’s generalization.
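
To make the greedy splitting rule concrete, the following is a minimal sketch (not the CART implementation itself) of choosing an axis-parallel split by minimizing the weighted Gini Index; the function names and the exhaustive threshold scan are illustrative assumptions:

```python
import numpy as np

def gini(labels):
    # Gini impurity of a label vector: 1 - sum_k p_k^2.
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_axis_parallel_split(X, y):
    # Scan every (feature, threshold) pair and keep the split with the
    # lowest weighted Gini impurity of the two children.
    n, d = X.shape
    best_feature, best_threshold, best_score = None, None, np.inf
    for j in range(d):
        for thr in np.unique(X[:, j]):
            left = X[:, j] <= thr
            if left.all() or not left.any():
                continue
            score = (left.sum() * gini(y[left])
                     + (~left).sum() * gini(y[~left])) / n
            if score < best_score:
                best_feature, best_threshold, best_score = j, thr, score
    return best_feature, best_threshold, best_score
```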

The greedy splitting and stopping rules of hard trees inevitably lead to sub-optimal decision trees. Moreover, despite their low empirical error, decision trees are easily overfitted (Kotsiantis, 2013). To improve learning through end-to-end training with back-propagation, soft trees (Norouzi et al., 2015; Irsoy et al., 2014; Hehn et al., 2019) have been proposed.

As shown in Fig. 1b, the soft tree relaxes the parameters \(\theta\) to be continuous, e.g., the discrete choices of \(\gamma\) (i.e., whether to prune) and s (i.e., which path to route). Thus, the response at node i can be expressed as follows:

$$\begin{aligned} f_i(\varvec{x}; \theta ) = \left( 1 - \gamma _i \right) \cdot \upsilon _i + \gamma _i \cdot s\left( \varvec{x}; \omega _i \right) ^T \begin{bmatrix} f_{L(i)} \left( \varvec{x}; \theta \right) \\ f_{R(i)} \left( \varvec{x}; \theta \right) \end{bmatrix} \end{aligned}$$
(2)

where \(\gamma _i \in [0, 1]\), and \(s: \varvec{x} \rightarrow [0, 1]^2\) is an oblique split with \(\omega _i \in \mathbb {R}^{d}\). Due to the continuous characteristics, various techniques (e.g., Gradient Descent (Bottou, 2012; Mukkamala & Hein, 2017) and Regularization (Prechelt, 1998; Bousquet et al., 2004)) in deep learning can be used for end-to-end tree training.

However, designing effective soft trees is a challenging task. Soft Decision Tree (Irsoy et al., 2012; Norouzi et al., 2015) only considers the global optimization of the node parameters \((\omega, \upsilon )\), omitting the architecture parameters \(\gamma\). Breiman et al. (1984) pointed out that tree quality depends more on good stopping rules than on splitting rules; in that sense, \(\gamma\) is the more crucial part. Budding Tree (Irsoy et al., 2014) considers the architecture parameters, but its randomly pruned nodes fail to bud again. All of the above methods directly use probabilistic trees at test time, which loses the interpretability of decision trees. End2End Tree (Hehn et al., 2019) maintains a probabilistic tree for training and discretizes it into a deterministic one for testing. However, a performance gap remains after the discretization. Moreover, End2End Tree is a two-stage method that first learns the node parameters end-to-end and then searches \(\gamma\) with a greedy algorithm.

In this work, we propose One-Stage Tree, which is, to the best of our knowledge, the first soft tree that maintains discretization during training. In contrast to the two-stage methods (e.g., CART and End2End Tree) that build and then prune, we first formalize the joint search for node and architecture parameters as a bilevel optimization problem. Then, we keep the path and architecture discrete during training. Specifically, we directly sample leaves via the Gumbel Softmax to predict instances according to the path probability, and we propose an optimization strategy for the discrete \(\gamma\) via proximal iterations. Benefiting from the discretization, we directly find the closed-form optimal solution of \(\upsilon\). Moreover, we reduce the performance gap and maintain interpretability.

Extensive experimental results on both classification and regression tasks reveal the effectiveness of One-Stage Tree. One-Stage Tree achieves a significant improvement over the classical CART. Compared with the existing soft trees, One-Stage Tree achieves better performance on most datasets. Moreover, One-Stage Tree is competitive with other standard ML methods. The implementation of One-Stage Tree is publicly available on GitHub.Footnote 1

To summarize, our main contributions can be highlighted as follows:

  • We introduce One-Stage Tree to search the node and architecture parameters jointly through a bilevel optimization problem.

  • The reparameterization trick and proximal iterations are leveraged to keep the tree discrete during training. In this way, we can reduce the performance gap between training and testing and maintain interpretability.

  • Extensive experimental results on both classification and regression tasks demonstrate that One-Stage Tree outperforms CART and the existing soft trees.

The rest of the paper is structured as follows: after introducing related work in soft trees (Sect. 2), we describe our approach in detail (Sect. 3). We then report experimental results on classification and regression datasets (Sect. 4) before concluding the paper (Sect. 5).

2 Related work

2.1 Soft tree

The decision tree is among the most popular machine learning algorithms, given its interpretability and simplicity. First, owing to the axis-parallel split at each internal node, a decision tree can learn from little training data and is easy to interpret. Second, thanks to the hierarchical architecture, a decision tree is computationally efficient: for a complete binary tree, only \(\mathcal {O}(\log {|\mathcal {I} |})\) of the \(|\mathcal {I} |\) internal nodes need to be visited.

The decision tree structure depends on the internal parameters \(\omega\) and the router function \(s(\varvec{x}; \omega _i)\). The most typical variant is the univariate discrete tree (Quinlan, 1986, 1996; Breiman et al., 1984), also called a hard tree, where \(\Vert \omega _i \Vert _0 = 1\) and \(\Vert s(\varvec{x}; \omega _i) \Vert _0 = 1\). A hard tree selects a sub-path for each instance according to a specific feature and threshold. In the multivariate tree (Irsoy et al., 2012; Norouzi et al., 2015; Irsoy et al., 2014; Hehn et al., 2019), also called a soft tree, \(\omega _i\) is a continuous variable and \(s(\varvec{x}; \omega _i)\) defines an oblique split.

Soft Decision Tree (Irsoy et al., 2012) takes the sigmoid function as the router function and builds a multivariate dense tree whose prediction is contributed by leaves with different probabilities. For Soft Decision Tree, \(s(\varvec{x}; \omega _i) = [g_i (\varvec{x}; \omega _i), 1 - g_i(\varvec{x}; \omega _i)]^T\), where \(g_i(\varvec{x};\omega _i) = \frac{1}{1 + \exp {(-\omega _i^T \varvec{x})}}\), routes instances to all its children with probabilities. Although it has a smoother fit and lower bias around the split boundaries, all the leaves’ paths are traversed. The computational overhead increases from \(\mathcal {O} ( \log (|I |) )\) to \(\mathcal {O}(|I |)\), where I denotes the set of internal nodes.

Unlike Soft Decision Tree, which only searches the splitting rule, Budding Tree (Irsoy et al., 2014) relaxes \(\gamma\) and fits the tree architecture. The bud node i can be an internal node and a leaf at the same time according to \(\gamma _i\). By gradient descent, Budding Tree splits and prunes the tree in the learning phase. However, \(\gamma _i\) will never be updated once it equals 0, which is known as the dying \(\gamma\) problem.

Being aware of the benefits of discretization in terms of interpretability, End2End Tree (Hehn et al., 2019) proposes a multivariate discrete tree. End2End Tree is fully probabilistic at train time but becomes deterministic at test time after an annealing process. The performance gap between training and testing still exists. Moreover, the tree architecture is still searched greedily, with the risk of a sub-optimal architecture.

In this paper, we propose One-Stage Tree to build and prune the tree jointly. One-Stage Tree directly samples the leaf node as prediction and keeps the discretization of the architecture during training. We leverage the reparameterization trick and proximal iterations to optimize the multivariate discrete tree.

2.2 Proximal algorithm

PA (Proximal Algorithm) (Parikh & Boyd, 2014) is a popular optimization technique for handling the following problem:

$$\begin{aligned} \mathop {\min }_{x\in \mathcal {C}} f(x) + g(x) \end{aligned}$$
(3)

where f and g are closed proper convex, f is differentiable, and \(\mathcal {C}\) is the feasible space. The crux of PA is the standard proximal step:

$$\begin{aligned} \begin{aligned}&x^{(k+1)} = \text {prox}_{\mathcal {C}, \epsilon g} \left( x^{(k)} - \epsilon \nabla f \left( x^{(k)} \right) \right) \\ \text {where} \ \epsilon > 0&\text { and } \text {prox}_{\mathcal {C}, \epsilon g}(x) = \mathop {\arg \min }_{z \in \mathcal {C}} \frac{1}{2} {\Vert z - x \Vert }_2^2 + \epsilon g(z) \end{aligned} \end{aligned}$$
(4)

\(\text {prox}_{\mathcal {C}, \epsilon g} \left( \cdot \right)\) represents the standard proximal operator of g with scale parameter \(\epsilon\) constrained by \(\mathcal {C}\) (Parikh & Boyd, 2014). In this form, we split the objective into two terms, one of which is differentiable. Since g can be extended-valued, it can be used to encode constraints on x.

In machine learning, PA is widely used to solve continuously differentiable optimization problems with a constraint \(\mathcal {C}\), i.e., \(\mathop {\min }_x f(x) \text { where } x \in \mathcal {C}\). Since \(g(x) = 0\), the proximal step simplifies to \(x^{(k+1)} =\text {prox}_{\mathcal {C}} (x^{(k)} - \epsilon \nabla f(x^{(k)}))\). Due to its excellent theoretical guarantees and good empirical performance, it has been applied to many deep learning problems (e.g., network binarization (Bai et al., 2018) and recommender systems (Yao et al., 2020)). Another variant of PA with a lazy proximal step (Xiao, 2010) maintains two copies of x, i.e.,

$$\begin{aligned} x^{(k+1)} = x^{(k)} - \epsilon \nabla f(\bar{x}^{(k)}) \text { where } \bar{x}^{(k)} = \text {prox}_{\mathcal {C}}(x^{(k)}) \end{aligned}$$
(5)

Although it has no convergence guarantee in the non-convex case, it performs well empirically on deep learning tasks (Courbariaux et al., 2015; Hou et al., 2017).
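
As an illustration, here is a minimal sketch of the standard step in Eq. (4) (for the special case \(g(x) = 0\)) and the lazy step in Eq. (5), with a box constraint \(\mathcal {C} = [0, 1]^n\) so that the proximal operator reduces to a Euclidean projection; the objective and step size are placeholders:

```python
import numpy as np

def prox_box(x):
    # Proximal operator of the indicator of C = [0, 1]^n,
    # i.e., the Euclidean projection onto the box.
    return np.clip(x, 0.0, 1.0)

def standard_proximal_step(x, grad_f, eps=0.1):
    # Eq. (4) with g = 0: a gradient step followed by a projection.
    return prox_box(x - eps * grad_f(x))

def lazy_proximal_step(x, grad_f, eps=0.1):
    # Eq. (5): evaluate the gradient at the projected copy x_bar,
    # but keep updating the continuous copy x.
    x_bar = prox_box(x)
    return x - eps * grad_f(x_bar)
```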

3 Methodology

In this section, we introduce One-Stage Tree, which is trained with the reparameterization trick and proximal iterations in an end-to-end manner.

3.1 Problem formulation

Consider a supervised task with input space \(\mathcal {X} \subset \mathbb {R}^d\) and output space \(\mathcal {Y} \subset \mathbb {R}^c\). Let \(\mathcal {D}\) be the data distribution. We denote the training set sampled from \(\mathcal {D}\) as \(\mathcal {D}_{\textit{train}}\), also written as \(\{\varvec{x}_1, \ldots, \varvec{x}_N\} \subset \mathcal {X}\) with the corresponding targets \(\{\varvec{y}_1, \ldots, \varvec{y}_N\} \subset \mathcal {Y}\). Let l be a differentiable convex function that measures the difference between the prediction and the target. We denote the overall loss on a given dataset as \(\mathcal {L}(\theta )\).

The soft tree relaxes the parameters \(\theta\) to be continuous, recursively calculates the probability \(s(\varvec{x}; \omega _i)\) from an internal node i to its children, and finally gets the path probability \(\mu _t (\varvec{x}; \omega, \gamma )\) from the root to each leaf t. Let \(\text {Path}(t)\) be the node set from the root to leaf t. The path probability \(\mu _t (\varvec{x}; \omega, \gamma )\) is calculated as follows:

$$\begin{aligned} \mu _t (\varvec{x}; \omega, \gamma ) = \prod _{i \in \text {Path}(t)} \gamma _i \cdot s(\varvec{x}; \omega _i)_{\left( {1-\mathbb {I}_{L(i) \in \text {Path}(t)}}\right) } \end{aligned}$$
(6)

where \(\mathbb {I}\) represents the 0-1 indicator function and L(i) denotes the left child of node i. If \(L(i) \in \text {Path}(t)\) is satisfied, the indicator function equals 1. For example, \(s(\varvec{x}; \omega _i)_0\) indicates the probability of routing to the left child. The response of the soft tree is the probability-weighted sum of the leaf values.

$$\begin{aligned} f(\varvec{x}; \theta ) = \sum _{t \in \mathcal {T}} \upsilon _t \cdot \mu _t (\varvec{x}; \omega, \gamma ) \end{aligned}$$
(7)

After the continuous relaxation, the goal is to jointly learn the tree parameters and find the global optima, which can be optimized by gradient descent:

$$\begin{aligned} \theta ^{*} = \mathop {\arg \min }_{\theta } \mathcal {L} (\theta ) = \mathop {\arg \min }_{\theta } \sum _{(\varvec{x}, \varvec{y}) \sim \mathcal {D}} l \left( f \left( \varvec{x}; \theta \right), \varvec{y} \right) \end{aligned}$$
(8)
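
To make Eqs. (6)-(8) concrete, below is a minimal NumPy sketch of the relaxed forward pass for a complete binary tree stored in array form (node i has children 2i+1 and 2i+2); the sigmoid router, the array layout, and scalar leaf values are illustrative assumptions rather than the exact implementation:

```python
import numpy as np

def sigmoid_router(x, w):
    # Oblique split: probability of routing to the left / right child.
    p_left = 1.0 / (1.0 + np.exp(-w @ x))
    return np.array([p_left, 1.0 - p_left])

def soft_tree_predict(x, omega, gamma, upsilon, depth):
    # Accumulate the path probability mu_t (Eq. 6) for every node and
    # return the probability-weighted sum of leaf values (Eqs. 2 and 7).
    n_nodes = 2 ** (depth + 1) - 1
    mu = np.zeros(n_nodes)
    mu[0] = 1.0
    prediction = 0.0                     # scalar leaf values for simplicity
    for i in range(n_nodes):
        left, right = 2 * i + 1, 2 * i + 2
        if left >= n_nodes:              # node at the maximum depth acts as a leaf
            prediction += mu[i] * upsilon[i]
            continue
        s = sigmoid_router(x, omega[i])
        # gamma_i interpolates between "leaf" (1 - gamma_i) and "internal" (gamma_i).
        prediction += mu[i] * (1.0 - gamma[i]) * upsilon[i]
        mu[left] += mu[i] * gamma[i] * s[0]
        mu[right] += mu[i] * gamma[i] * s[1]
    return prediction
```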

3.2 One-Stage Tree

Although the continuous relaxation allows the whole tree to be differentiable, soft trees have significant limitations:

  1. Interpretability Although the oblique split \(\omega _i\) indicates feature importance, path probabilities at large depths are difficult to interpret.

  2. Performance The continuous architecture \(\gamma\) needs to be discretized to \(\{0, 1\}\) at test time, resulting in inconsistent performance between training and testing.

Recall that in hard trees, the trees are all discrete when updating node parameters. Such discretization naturally alleviates the above limitations. Thus, we aim to search the differentiable soft tree but keep discrete architectures and paths when updating the parameters. Like decision trees that use the validation set for pruning, we divide the training set into two parts (i.e., the training set and validation set) and minimize the following objective:

$$\begin{aligned} \mathop {\min }_{\gamma } \quad & \mathcal {L}_{\textit{val}} (\omega ^{*}, \upsilon ^{*}, \gamma ) \\ \text {s.t.} \quad & \omega ^{*}, \upsilon ^{*} = \mathop {\arg \min }_{\omega, \upsilon } \mathcal {L}_{\textit{train}} (\omega, \upsilon, \gamma ) \\ & \forall \text { node } i, \ \gamma _i \in \{0, 1\} \text { and } s: \varvec{x} \rightarrow \{[0, 1]^T, [1, 0]^T\} \end{aligned}$$
(9)

where \(\mathcal {L}_{\textit{train}}\) and \(\mathcal {L}_{\textit{val}}\) are the losses on the training and validation sets, respectively.

We call it One-Stage Tree because it keeps the discretization while simultaneously completing the two stages of building and pruning. The one-stage optimization is achieved by solving the bilevel optimization problem in Eq. (9). However, the problem of discretization remains. Specifically, the discretization in One-Stage Tree can be divided into two parts:

  • Discretization of Probabilistic Path A discrete path routes an instance from the root to a leaf. In soft trees, the paths are probabilistic and summed as the prediction \(\sum _{t \in \mathcal {T}} \mu _t (\varvec{x}; \omega, \gamma ) \cdot \upsilon _t\). To discretize the path, a straightforward idea is to sample a path as the prediction according to its probability. The Monte Carlo method (Metropolis & Ulam, 1949) can be used to estimate the expectation of the loss. We use the reparameterization trick (Blum et al., 2015) to make it differentiable w.r.t. \(\omega\). To make the sampling estimator a good approximation of the true expectation, we use the Gumbel Softmax (Maddison et al., 2014) for reparameterization.

  • Discretization of Continuous Architecture The continuous relaxation of \(\gamma\) unifies the forms of internal nodes and leaves. However, it makes the tree architecture difficult to interpret and leads to low computational efficiency, because all nodes need to be visited. To keep \(\gamma\) discrete but differentiable, we propose an architecture optimization strategy via proximal iterations (Parikh & Boyd, 2014; Xiao, 2010), inspired by NAS (Neural Architecture Search) (Liu et al., 2018; Yao et al., 2020).

Fig. 2

Overview of One-Stage Tree. Each node i can be either an internal node (i.e., \(\omega _i\) is emphasized) or a leaf (i.e., \(\upsilon _i\) is emphasized) according to \(\gamma _i\). The path along which the instance \(\varvec{x}\) (including a bias term) is routed is shown in purple. Due to the discretization, \(\varvec{x}\) is routed to the left or right child at each node by \(s(\varvec{x};\omega )\) and finally reaches the leaf used as the prediction

To summarize, as shown in Fig. 2, One-Stage Tree retains the advantages of hard trees as discrete inference models and improves learning through end-to-end training with back-propagation.

3.3 Gumbel-softmax path

The probabilistic path from the root to a leaf t lacks interpretability and requires iterative optimization of the leaf parameters \(\upsilon\). To discretize the probabilistic path, we express the router at node i as a random variable \(s(\varvec{x}; \omega _i) \sim q_{\omega _i}(s|\varvec{x})\). Thus, a discrete path is sampled from a continuous distribution parameterized by \(\omega\). In this way, we can explore diversity, since each instance may be assigned to different leaves, while exploiting the best path with the highest probability. Traditionally, each instance belongs to a fixed leaf determined by the splitting rule, and each leaf computes \(\upsilon\) from its instances (e.g., the average of the labels in CART). Here, by sampling from the random variable during training, we explore the case where each instance can be held by different leaves in the optimization process. Moreover, at test time we directly choose the path with the highest probability as exploitation.

$$\begin{aligned} s(\varvec{x}; \omega _i) = g_{\omega _i} (\epsilon, \varvec{x}) \quad \text {with } \epsilon \sim p(\epsilon ) \end{aligned}$$
(10)

However, as a result of sampling, the loss cannot be propagated backward to \(\omega\). To train \(\omega\), we reparameterize the random variable s using a differentiable transformation \(g_{\omega _i}(\epsilon, \varvec{x})\), where g is parameterized by \(\omega _i\) and \(\epsilon\) is an (auxiliary) noise variable with independent marginal \(p(\epsilon )\).

Using the reparameterization trick, we can now form MC (Monte Carlo) estimates (Metropolis & Ulam, 1949) of the expectation of the loss, which is differentiable w.r.t. \(\omega\), as follows:

$$\begin{aligned} \mathbb {E}_{q_{\omega }(s | \varvec{x})} [ \mathcal {L} ( \theta ) ] &= \mathbb {E}_{q_{\omega }(s | \varvec{x})} \left[ \sum _{(\varvec{x}, \varvec{y}) \sim \mathcal {D}} l \left( f \left( \varvec{x}; \theta \right), \varvec{y} \right) \right] \\ & =\mathbb {E}_{p(\epsilon )} \left[ \sum _{(\varvec{x}, \varvec{y}) \sim \mathcal {D}} l \left( \sum _{t \in \mathcal {T}} v_t \prod _{i \in \text {Path}(t)} \gamma _i \cdot g_{\omega _i}\left( \epsilon, \varvec{x}\right) _{\left( {1-\mathbb {I}_{L(i) \in \text {Path}(t)}}\right) }, \varvec{y} \right) \right] \\ & \approx \frac{1}{M} \sum _{m=1}^M \sum _{(\varvec{x}, \varvec{y}) \sim \mathcal {D}} l \left( \sum _{t \in \mathcal {T}} v_t \prod _{i \in \text {Path}(t)} \gamma _i \cdot g_{\omega _i}(\epsilon ^{(m)}, \varvec{x})_{\left( {1-\mathbb {I}_{L(i) \in \text {Path}(t)}}\right) }, \varvec{y} \right) \end{aligned}$$
(11)

where M is the number of samples and \(\epsilon ^{(m)} \sim p(\epsilon )\). If \(L(i) \in \text {Path}(t)\) is satisfied, the indicator function \(\mathbb {I}\) equals 1 and the probability of routing to the left child is \(g_{\omega _i}(\epsilon, \varvec{x})_0\).

Specifically, we choose to sample noise from Gumbel Distribution (Gumbel, 1954), which can smoothly approximate the expectation (Maddison et al., 2014; Jang et al., 2016). Correspondingly, the Gumbel Softmax of \(g_{\omega _i}(\epsilon, \varvec{x})\) is expressed as follows:

$$\begin{aligned} & g_{\omega _i}(\epsilon, \varvec{x})_k = \frac{e^{\hat{s}(\varvec{x}; \omega _i)_k/\tau }}{\sum _{j=0}^1 e^{\hat{s}(\varvec{x}; \omega _i)_j/\tau }}\\ & \text {where } \epsilon \sim \textit{Gumbel}(0),\ k \in \{0,1\}, \text { and } \hat{s}(\varvec{x}; \omega _i)_k = s(\varvec{x}; \omega _i)_k+\epsilon \end{aligned}$$
(12)

\(\tau\) denotes the temperature of the Gumbel Softmax. In practice, the value of \(\tau\) can be empirically set to 1.
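
A minimal sketch of the relaxation in Eq. (12); the two routing logits \(s(\varvec{x}; \omega _i)\) and \(\tau = 1\) follow the text, while drawing independent Gumbel noise per logit is a standard assumption of the trick:

```python
import numpy as np

def gumbel_softmax_router(logits, tau=1.0, rng=None):
    # logits: the two routing scores s(x; omega_i) for [left, right].
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.gumbel(loc=0.0, scale=1.0, size=logits.shape)
    z = (logits + noise) / tau           # perturbed logits, Eq. (12)
    z = z - z.max()                      # numerical stability
    probs = np.exp(z) / np.exp(z).sum()
    return probs

# During training, a child is sampled according to `probs`; at test time the
# child with the higher routing probability is chosen deterministically.
```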

3.4 Architecture search via proximal iterations

Equation (9) implies a bilevel optimization problem with \(\gamma\) as the upper-level variable and \(\omega\) as the lower-level variable. We also relax the discrete choice of whether to prune or not (i.e. \(\gamma _i \in [0, 1]\)). As a result, \(\gamma\) can be optimized w.r.t. its validation set performance by gradient descent.

Budding Tree (Irsoy et al., 2014) propagates the error backwards from the root towards the leaves. \(f_i\) denotes the response at node i (Eq.  (2)). Define \(\text {pa}(i)\) as the parent of node i and \(\delta _i = \partial \mathcal {L} / \partial f_i\) as the responsibility of node i. Deriving \(\mathcal {L}\) w.r.t. \(\gamma _i\), we have:

$$\begin{aligned} \frac{\partial \mathcal {L}}{\partial \gamma _i} &= \delta _{i} \left( -\upsilon _i + \cdots \right) \\ \text {with} \quad \delta _i &= {\left\{ \begin{array}{ll} \frac{\partial \mathcal {L}}{\partial f_i}, & \text {if } i \text { is the root} \\ \delta _{\text {pa}(i)} \times \gamma _i \times \cdots, & \text {if } i \text { is a child} \end{array}\right. } \end{aligned}$$
(13)

The detailed derivation of Eq. (13) can be found in “Appendix 1”. From Eq. (13), we can see that once an internal node is pruned (i.e., \(\gamma _i = 0\)), it can never bud again because its gradient is 0, which we call the dying \(\gamma\) problem. Moreover, it is prohibitive to evaluate the gradient exactly due to the expensive inner optimization \(\mathop {\arg \min }_{\omega, \upsilon } \mathcal {L}_{\textit{train}}(\omega, \upsilon, \gamma )\). Following common practice (e.g., in meta-learning (Finn et al., 2017) and NAS (Liu et al., 2018)), we use a one-step gradient approximation to the optimal internal parameters \(\omega ^{*}\) to improve efficiency. Thus, the gradient of the architecture parameters \(\gamma\) is as follows:

$$\begin{aligned} & \nabla _{\gamma } \mathcal {L}_{\textit{val}} \left( \omega ^{*}, \upsilon ^{*}, \gamma \right) \\ \approx \ & \nabla _{\gamma } \mathcal {L}_{\textit{val}} \left( \omega - \xi \nabla _{\omega } \mathcal {L}_{\textit{train}} \left( \omega, \upsilon ^{*}, \gamma \right), \upsilon ^{*}, \gamma \right) \\ = \ & \nabla _{\gamma } \mathcal {L}_{\textit{val}} \left( \omega ', \upsilon ^{*}, \gamma \right) - \xi \nabla ^{2}_{\gamma, \omega } \mathcal {L}_{\textit{train}} \left( \omega, \upsilon ^{*}, \gamma \right) \nabla _{\omega '} \mathcal {L}_{\textit{val}} \left( \omega ', \upsilon ^{*}, \gamma \right) \end{aligned}$$
(14)

More specifically, the approximate procedure alternately optimizes the node parameters \((\omega,\upsilon )\) and the architecture parameters \(\gamma\). At step k, given the current architecture \(\gamma ^{(k)}\), we first calculate \(\upsilon ^{(k+1)}\) in closed form (Sect. 3.5). Then, we obtain \(\omega ^{(k+1)}\) by descending \(\nabla _{\omega ^{(k)}} \mathcal {L}_{\text {train}}(\omega ^{(k)},\upsilon ^{(k+1)}, \gamma ^{(k)})\) with step size \(\xi\), as a one-step optimization of \(\omega ^{*}\) under \(\gamma ^{(k)}\). Finally, we update the architecture parameters \(\gamma ^{(k)}\) so as to minimize the validation loss; the architecture gradient is given in Eq. (14). We omit the step index k for brevity. \(\upsilon ^{*}\) denotes the optimal leaf values and \(\omega ^{*}\) denotes the internal parameters after a one-step gradient descent.

However, two problems arise in solving Eq. (14). First, evaluating the second-order derivative \(\nabla ^2\) is expensive due to the large number of parameters. Second, the continuous relaxation leads to a performance gap caused by discretizing \(\gamma _i \in [0, 1]\) at training time to \(\{0, 1\}\) at test time.

To address these two issues, we employ a variant of the Proximal Algorithm (Yao et al., 2020) to optimize \(\gamma\) efficiently. Equivalently, we transform \(\gamma _i\) into a 2-d one-hot vector that indicates whether to prune or not, i.e., \(\gamma _i \in \{[0, 1], [1, 0]\}\). Let the feasible space of \(\gamma\) be \(\mathcal {C} = \{\gamma |\ \forall i, \ \Vert \gamma _i \Vert _0 = 1 \ \wedge \ 0 \le \gamma _{i, j} \le 1 \}\). We denote it as the intersection of two feasible spaces (i.e., \(\mathcal {C} = \mathcal {C}_1 \cap \mathcal {C}_2\)), where \(\mathcal {C}_1 = \{\gamma | \ \forall i, \ \Vert \gamma _i \Vert _0 = 1 \}\) and \(\mathcal {C}_2 = \{\gamma | \ \forall i, \ 0 \le \gamma _{i,j} \le 1 \}\). With such a constrained form, we can apply the composition of lazy and standard proximal steps proposed in Yao et al. (2020).

Specifically, as shown in Eq. (15), in each proximal iteration we first obtain a discrete architecture \(\bar{\gamma }\) constrained by \(\mathcal {C}_1\). Then, we derive gradients w.r.t. \(\bar{\gamma }\) and keep the optimized \(\gamma\) as a continuous variable constrained by \(\mathcal {C}_2\):

$$\begin{aligned} \displaystyle \gamma ^{(k + 1)} = \text {prox}_{\mathcal {C}_2} \left( \gamma ^{(k)} - \epsilon \nabla _{\bar{\gamma }^{(k)}} \mathcal {L}_{\text {val}} \left( \bar{\gamma }^{(k)} \right) \right) \text {, where } \bar{\gamma }^{(k)} = \text {prox}_{\mathcal {C}_1} (\gamma ^{(k)}) \end{aligned}$$
(15)
Algorithm 1

Algorithm 1 shows the overall workflow of One-Stage Tree, which searches the architecture parameters \(\gamma\) via proximal iterations. In the k-th iteration, the architecture and node parameters are updated alternately. As the lazy proximal step first projects \(\gamma\) into the discrete feasible space \(\mathcal {C}_1\), we obtain the discrete architecture \(\bar{\gamma }^{(k)} = \text {prox}_{\mathcal {C}_1}(\gamma ^{(k)})\) (Line 3). Then, we calculate the optimal leaf values \(\upsilon\) in closed form (Sect. 3.5) and update the internal parameters \(\omega\) on the training dataset (Sect. 3.3) based on \(\bar{\gamma }^{(k)}\) (Lines 4-5). After taking the one-step update of \(\omega\) as in Eq. (14), we optimize \(\gamma ^{(k)}\) with the gradient derived at \(\bar{\gamma }^{(k)}\), treated as a continuous variable, and then project it into \(\mathcal {C}_2\) (Line 6).

In each proximal iteration, we keep the architecture \(\gamma\) discrete when training, which contributes to reducing the performance gap caused by discretizing architecture from a continuous one. Moreover, we can ignore the second-order derivative of small magnitude \(\epsilon \cdot \xi\) because \(\gamma\) will be projected into the discrete feasible space \(\mathcal {C}_1\) in the next iteration, i.e., \(\text {prox}_{\mathcal {C}_1}^{(k + 1)}( \text {prox}_{\mathcal {C}_2}^{(k)} ( \gamma ^{(k)} - \epsilon ( \nabla _{\bar{\gamma }^{(k)}} -\xi \nabla ^2_{\bar{\gamma }^{(k)}, \omega ^{(k)}}\nabla _{\omega ^{(k+1)}} ) ) ) \approx \text {prox}_{\mathcal {C}_1}^{(k + 1)}( \text {prox}_{\mathcal {C}_2}^{(k)} ( \gamma ^{(k)} - \epsilon \nabla _{\bar{\gamma }^{(k)}}) )\). Thus, the computational efficiency of updating \(\gamma\) can be significantly improved.
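
For reference, the following is a hedged Python-style sketch of one proximal iteration (Lines 3-6 of Algorithm 1); `leaf_closed_form`, `train_step`, and `val_grad_gamma` are placeholder helpers for the steps of Sects. 3.3-3.5, not the released implementation:

```python
import numpy as np

def prox_C1(gamma):
    # Lazy step: project each gamma_i onto the one-hot set {[1, 0], [0, 1]}.
    one_hot = np.zeros_like(gamma)
    one_hot[np.arange(len(gamma)), gamma.argmax(axis=1)] = 1.0
    return one_hot

def prox_C2(gamma):
    # Standard step: clip every coordinate back into [0, 1].
    return np.clip(gamma, 0.0, 1.0)

def proximal_iteration(omega, upsilon, gamma, train_set, val_set,
                       leaf_closed_form, train_step, val_grad_gamma,
                       xi=0.01, eps=0.01):
    # Line 3: discrete architecture used in this iteration.
    gamma_bar = prox_C1(gamma)
    # Line 4: optimal leaves in closed form under gamma_bar (Sect. 3.5).
    upsilon = leaf_closed_form(omega, gamma_bar, train_set)
    # Line 5: one gradient step on omega with Gumbel-Softmax paths (Sect. 3.3).
    omega = train_step(omega, upsilon, gamma_bar, train_set, lr=xi)
    # Line 6: update gamma with the gradient taken at gamma_bar,
    # then project it back into [0, 1] (Eq. 15).
    grad = val_grad_gamma(omega, upsilon, gamma_bar, val_set)
    gamma = prox_C2(gamma - eps * grad)
    return omega, upsilon, gamma
```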

3.5 Optimal leaves in closed-form

Unlike multivariate dense trees (Irsoy et al., 2012, 2014), where leaf values are iteratively optimized by gradient descent, we can solve for \(\upsilon\) in closed form thanks to the discretization of the path and architecture. The prediction \(f(\varvec{x})\) is \(\upsilon _t\) when \(\mu _t (\varvec{x}) = 1\). Defining \((\varvec{x}, \varvec{y})\) as an instance and \(I_t = \{ (\varvec{x}, \varvec{y}) | \forall (\varvec{x}, \varvec{y}) \sim \mathcal {D}, \mu _t(\varvec{x}; \omega, \gamma ) = 1 \}\) as the instance set of leaf t, the derivative of the loss function can be expressed as:

$$\begin{aligned} \begin{aligned} \frac{ \partial \mathcal {L} }{\partial \upsilon _t}&= \sum _{\varvec{x}, \varvec{y}} \frac{ \partial l }{\partial f} \frac{\partial f}{\partial \upsilon _t} = \sum _{\varvec{x}, \varvec{y}} \frac{ \partial l }{\partial f} \frac{ \partial \sum _{t \in \mathcal {T} } \upsilon _t \cdot \mu _t (\varvec{x}; \omega, \gamma ) }{\partial \upsilon _t} \\&= \sum _{\varvec{x}, \varvec{y}} \frac{ \partial l }{\partial f} \mu _t (\varvec{x}; \omega, \gamma ) = \sum _{(\varvec{x}, \varvec{y}) \in I_t} \frac{ \partial l }{\partial f} \end{aligned} \end{aligned}$$
(16)

From Eq. (16), the optimal leaves are solved in closed form by simply differentiating l w.r.t. the tree prediction. Setting \(\partial \mathcal {L} / \partial \upsilon _t = 0\), we obtain the optimal leaf values under the common MSE and CrossEntropy losses:

  • MSE

    $$\begin{aligned} \text {Let } l(\varvec{y}, f) &= \frac{1}{2} \left( \varvec{y} - f \right) ^2 \\ \therefore {} \frac{\partial \mathcal {L}}{\partial \upsilon _t} &= \sum _{(\varvec{x}, \varvec{y}) \in I_t} \frac{ \partial l }{\partial f} = \sum _{(\varvec{x}, \varvec{y}) \in I_t} \left( f - \varvec{y} \right) = \sum _{(\varvec{x}, \varvec{y}) \in I_t} \left( \upsilon _t - \varvec{y} \right) \\ \text {Let } \frac{\partial \mathcal {L}}{\partial \upsilon _t} &= 0 \quad \therefore {} \upsilon _t^{*} \ = \ \displaystyle \frac{ \sum _{(\varvec{x}, \varvec{y}) \in I_t} \varvec{y} }{ |I_t |} \end{aligned}$$
  • CrossEntropy with constraint \(\sum _{i=0}^c f_i = 1\) (i.e., c is the number of classes, and the probability sum is 1):

    $$\begin{aligned} \text {Let } l(\varvec{y}, f, \lambda ) &= -\sum _{i=0}^c \varvec{y}_i \log {f_i} + \lambda \left( 1 - \sum _{i=0}^c f_i\right) \\ \therefore \ \frac{\partial \mathcal {L}}{\partial \upsilon _{t, i}} &= \sum _{(\varvec{x}, \varvec{y}) \in I_t} \frac{ \partial l }{\partial f_i} = \sum _{(\varvec{x}, \varvec{y}) \in I_t} \left( - \frac{\varvec{y}_i}{f_i} - \lambda \right) = \sum _{(\varvec{x}, \varvec{y}) \in I_t} \left( - \frac{\varvec{y}_i}{\upsilon _{t, i}} - \lambda \right) \\ \text {Let } \frac{\partial \mathcal {L}}{\partial \upsilon _{t, i}} &= 0 \quad \therefore \ \upsilon ^{*}_{t, i} = \frac{ -\sum _{(\varvec{x}, \varvec{y}) \in I_t} \varvec{y}_i }{\lambda |I_t |}\\ \text {Let } \frac{\partial l}{\partial \lambda } &= 1 - \sum _{i=0}^c f_i = 1 - \sum _{i=0}^c \upsilon _{t, i} = 1 - \sum _{i=0}^c \frac{ -\sum _{(\varvec{x}, \varvec{y}) \in I_t} \varvec{y}_i }{\lambda |I_t |} = 0 \\ \therefore \ \lambda ^{*} &= \frac{ -\sum _{(\varvec{x}, \varvec{y}) \in I_t} \sum _{i = 0}^c \varvec{y}_i }{|I_t |} = -1, \quad \upsilon _t^{*} = \frac{ \sum _{(\varvec{x}, \varvec{y}) \in I_t} \varvec{y} }{ |I_t |} \end{aligned}$$

In summary, the optimal solution for \(\upsilon\) under both MSE and CrossEntropy is:

$$\begin{aligned} \upsilon _t^{*} \ = \ \displaystyle \frac{ \sum _{(\varvec{x}, \varvec{y}) \in I_t} \varvec{y} }{ |I_t |} \end{aligned}$$
(17)
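
A minimal sketch of Eq. (17): with discrete paths, each leaf value is simply the mean of the targets routed to it; `leaf_index` is an assumed helper output giving the leaf reached by each instance:

```python
import numpy as np

def closed_form_leaves(Y, leaf_index, n_leaves):
    # Y: (N, c) one-hot labels for classification or (N, 1) targets for
    # regression; leaf_index: (N,) leaf id reached by each instance.
    upsilon = np.zeros((n_leaves, Y.shape[1]))
    for t in range(n_leaves):
        mask = leaf_index == t
        if mask.any():
            upsilon[t] = Y[mask].mean(axis=0)   # Eq. (17)
    return upsilon
```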

3.6 In-depth discussion

Table 1 Characteristic comparison between One-Stage Tree and other tree models including hard tree and existing soft trees

Table 1 shows the characteristic comparison between One-Stage Tree and other tree models from both training and inference perspectives.

In the training phase, One-Stage Tree improves learning in an end-to-end manner. Unlike the existing soft trees, One-Stage Tree achieves joint optimization of the node and architecture parameters. Soft Decision Tree does not support the optimization of the architecture parameters; lacking any pruning strategy, it is prone to overfitting. Budding Tree considers the search for the architecture parameters, but the dying \(\gamma\) problem may occur. In contrast to One-Stage Tree, End2End Tree is a two-stage method that first learns the node parameters end-to-end and then searches the architecture parameters greedily. Moreover, due to the discretization of the path and architecture, One-Stage Tree can efficiently solve for \(\upsilon\) in closed form.

In the inference phase, One-Stage Tree can keep the same advantage of interpretability as hard trees due to maintaining discretization. Unlike End2End Tree, which transforms from the probabilistic tree to the deterministic one during inference, One-Stage Tree does not require the additional transformation and thus can reduce the performance gap between training and testing.

4 Experiments

In this section, we conduct extensive experiments on public datasets to answer the following research questions:

  • RQ1 How effective is the proposed One-Stage Tree?

  • RQ2 Is One-Stage Tree robust to hyperparameters?

  • RQ3 How do different components of One-Stage Tree (e.g., Proximal Algorithm) contribute to the performance?

  • RQ4 How does One-Stage Tree preserve the interpretability of decision trees?

4.1 Experimental setting

We use a total of 22 public datasets from OpenML,Footnote 2 the UCI repository,Footnote 3 and Kaggle.Footnote 4 There are 17 classification (C) datasets and 5 regression (R) datasets, with 5 to 57 features and 100 to 30,000 instances.

Benefiting from the soft-tree formulation, One-Stage Tree can be trained using tools from deep learning. We choose the Adam optimizer (Kingma & Ba, 2014) to train One-Stage Tree. The maximum number of epochs is 200, the batch size is 32, and the learning rate is 0.01. The other hyperparameters of the Adam optimizer are left at their default settings. EarlyStopping (Prechelt, 1998), which monitors the validation loss, is used to prevent overfitting with a patience of 15. Except for Sect. 4.3.1, the depths of all tree models are set to 6 for comparison.

We use MSE loss for regression tasks and CrossEntropy loss for classification tasks in all experiments. Moreover, to evaluate the trees, we use r2-score and accuracy for regression tasks and classification tasks respectively. For clarity, we multiply all metrics by 100 in all tables.

Due to the bilevel optimization, One-Stage Tree splits the raw data as 6:2:2 (train:validation:test) and uses the validation set for optimizing the architecture parameters (i.e., tree pruning). For the other methods (i.e., CART, Soft Decision Tree, and End2End Tree), the raw data is divided using a ratio of 8:2 (train:test). These methods do not require a validation set during training.

4.2 Effectiveness of One-Stage Tree (RQ1)

In this subsection, we demonstrate the effectiveness of One-Stage Tree.

4.2.1 Comparison with trees

We compare One-Stage Tree on 22 datasets with the state-of-the-art and baseline tree methods, including:

  1. CART (Breiman et al., 1984): the most typical univariate discrete tree, which uses MSE for regression and the Gini Index for classification as the splitting rules. We choose the widely-used sklearn.tree packageFootnote 5 to run CART.

  2. Soft Decision Tree (Irsoy et al., 2012): a multivariate dense tree, in which all the paths to all the leaves contribute to the final prediction with different probabilities. It only supports classification tasks. We use the open-source codeFootnote 6 with the most stars on GitHub to obtain the experimental results.

  3. Budding Tree (Irsoy et al., 2014): a multivariate dense tree, which searches the tree architecture in the learning phase. It supports both classification and regression tasks.

  4. End2End Tree (Hehn et al., 2019): the state-of-the-art multivariate discrete tree, which is fully probabilistic in the training phase but becomes deterministic after an annealing process in the testing phase. It is open-sourceFootnote 7 and only supports classification tasks.

Table 2 Comparison between One-Stage Tree, CART, and the existing soft trees on classification datasets from UCIrvine
Table 3 Comparison between One-Stage Tree and CART on regression datasets
Table 4 Comparison between One-Stage Tree, CART, and Budding Tree with the same experimental setting in Irsoy et al. (2014)

For the open-source methods, i.e., CART, Soft Decision Tree, and End2End Tree, we directly use the default hyperparameters of the open-source codes. To investigate the stability of the training process, we run each method with 5 random seeds and report the mean and standard deviation of the trees' performance. Table 2 shows the comparison results between One-Stage Tree and the open-source methods (CART, Soft Decision Tree, and End2End Tree) on classification datasets. Since Soft Decision Tree and End2End Tree do not support regression tasks, we only show the comparison with CART in Table 3.

Moreover, since Budding Tree is not completely open-source, we directly use the available datasets and the experimental results reported in the original paper (Irsoy et al., 2014). To match their experimental setting, we hold out 1/3 of each dataset as a test set to evaluate the final performance. The comparison results between One-Stage Tree and Budding Tree are shown in Table 4. According to the comparison results, we can observe that:

  • The comparison results shown in Table 2 indicate that One-Stage Tree outperforms the existing tree methods and achieves the best performance in all but 1 case on classification datasets.

  • Besides the classification tasks, One-Stage Tree can also support regression tasks. As shown in Table 3, for regression datasets from different sources, One-Stage Tree outperforms CART on 3/5 datasets.

  • The comparison results shown in Table 4 demonstrate that One-Stage Tree outperforms Budding Tree and CART on 5/8 and 6/8 datasets respectively. Although CART and Budding Tree also search the node and architecture parameters in the training phase, One-Stage Tree shows better performance in learning the tree parameters \(\theta\).

  • For soft trees, end-to-end training based on gradient descent inevitably has a degree of randomness. Table 2 shows that One-Stage Tree has a smaller standard deviation on most datasets, achieving more stable performance compared to the existing soft trees.

4.2.2 Statistical comparison

Table 5 p-values for each pairwise comparison using the Nemenyi post-hoc test for the soft tree models (Significance level \(\alpha =0.05\))

To further statistically evaluate the difference between the soft trees, we perform the Friedman test (Demšar, 2006), which is a non-parametric equivalent of the repeated-measures ANOVA. It is used to determine whether or not there is a statistically significant difference between the soft tree models.

For the comparison results in Table 2, we first calculate the Friedman statistic. Let \(r_i^j\) be the rank of the j-th of k tree models (k = 4, i.e., CART, Soft Decision Tree, End2End Tree, and One-Stage Tree) on the i-th of N classification datasets. The Friedman test compares the average ranks of the models, \(R_j = \frac{1}{N}\sum _i{r_i^j}\). The null hypothesis states that all the tree models are equivalent and so their ranks \(R_j\) should be equal. We employ the scipy toolFootnote 8 to calculate the Friedman statistic. The Friedman statistic is 19.837209 and the corresponding p-value is 0.00018. Since the p-value is less than 0.05, we can reject the null hypothesis that the performance is the same for all four tree models. In other words, we have sufficient evidence to conclude that the trees lead to statistically significant differences in terms of performance. Since the Friedman test is significant, we perform the Nemenyi post-hoc test (Nemenyi, 1963) to further determine exactly which trees differ. Table 5 shows the p-values for each pairwise comparison. We can conclude that One-Stage Tree is significantly different from the other trees at a significance level of \(\alpha =0.05\). Additionally, according to the Friedman test, there is no significant difference among the three trees compared in Table 4.
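
A hedged sketch of this procedure; `scores` is an assumed (datasets × models) matrix of accuracies, `scipy.stats.friedmanchisquare` is the scipy routine referred to in the text, and the Nemenyi post-hoc test is taken from the scikit-posthocs package as one possible implementation:

```python
import numpy as np
from scipy.stats import friedmanchisquare
import scikit_posthocs as sp  # assumption: Nemenyi test via scikit-posthocs

def compare_tree_models(scores, alpha=0.05):
    # scores[i, j]: accuracy of model j on dataset i, e.g., the four tree
    # models of Table 2 over the classification datasets.
    stat, p_value = friedmanchisquare(*[scores[:, j] for j in range(scores.shape[1])])
    print(f"Friedman statistic = {stat:.4f}, p-value = {p_value:.5f}")
    if p_value < alpha:
        # Pairwise Nemenyi post-hoc test on the same score matrix.
        print(sp.posthoc_nemenyi_friedman(scores))
```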

For the regression tasks shown in Table 3, we perform the Wilcoxon signed-rank test (Demšar, 2006) to statistically compare the two tree models. The statistic is 5.0 and the corresponding p-value is 0.24886. Thus, there is no statistically significant difference between CART and One-Stage Tree on the regression tasks.
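
A minimal sketch of this pairwise comparison, assuming the per-dataset r2-scores of the two models from Table 3 are passed in as arrays:

```python
from scipy.stats import wilcoxon

def compare_two_models(scores_a, scores_b):
    # Paired, non-parametric comparison of two models over the same
    # regression datasets (e.g., CART vs. One-Stage Tree in Table 3).
    stat, p_value = wilcoxon(scores_a, scores_b)
    return stat, p_value
```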

4.2.3 Comparison with other standard ML methods

Fig. 3

Comparison between One-Stage Tree and other ML Methods on classification datasets

Fig. 4

Comparison between One-Stage Tree and other ML Methods on regression datasets

Table 6 Average Rank of One-Stage Tree and other ML Methods

Moreover, we compare One-Stage Tree with other standard ML methods (i.e., XGBoost (Chen & Guestrin, 2016),Footnote 9 Support Vector Machine,Footnote 10 and Multi-Layer PerceptronFootnote 11) for reference. In particular, a linear kernel is used for SVM. We use the open-source implementations of the above methods and perform grid search to select the best hyperparameters for the learners on each dataset. The hyperparameter search space can be seen in “Appendix 2”. Similarly, 5 randomly selected seeds are used to obtain the performance mean and standard deviation. Figures 3 and 4 show the performance comparison on the classification and regression datasets respectively. The mean performance is presented in a bar chart with an error line (i.e., the standard deviation). MLP and SVM perform very poorly on the dataset BikeShare DC due to the large range of regression values. Thus, we truncate their performance to 0 in Fig. 4. This also reflects the advantage of tree models, where each leaf takes the mean value of its samples as the prediction output.

To present the comparison results clearly, we further report the average rank of each method in Table 6. One-Stage Tree outperforms the traditional machine learning methods MLP and SVM. Although XGBoost ensembles 100 trees with a maximum depth of 6 via GBDT (Friedman, 2001), One-Stage Tree is still competitive on several datasets. In future work, we also consider combining One-Stage Tree with ensemble learning methods such as bagging and boosting to further improve the performance.

In summary, our proposed One-Stage Tree is effective for both classification and regression tasks and outperforms CART and the existing soft trees on most datasets. One-Stage Tree also shows good performance in comparison with other standard machine learning methods.

4.3 Robustness of One-Stage Tree (RQ2)

In this subsection, we evaluate whether One-Stage Tree is sensitive to the key hyperparameters, i.e., the tree depth and the validation size \(\rho\). We perform experiments on all classification datasets with the same experimental setting as in RQ1.

4.3.1 Tree depth

Table 7 Mean number of internal nodes per tree depth
Fig. 5

Effect of tree depth

The tree depth ranges from 1 to 9. We run One-Stage Tree 5 times with different random seeds and report the mean number of internal nodes per tree depth. As shown in Table 7, the optimization of the architecture parameters in One-Stage Tree is effective for tree pruning. At a depth of 9, One-Stage Tree can even prune \(60\%\) of the internal nodes. Even with a larger maximum depth, trees may thus be pruned into shallow ones to avoid the risk of overfitting.

As the tree depth increases, the search space of tree structures grows, and the ability to find effective trees becomes crucial. Figure 5 shows the performance curves of all tree models with respect to the tree depth. From Fig. 5, we can observe that:

  • Due to the joint optimization of the node and architecture parameters, One-Stage Tree achieves better performance than the other tree models at different depths. Moreover, from a global perspective, the performance of One-Stage Tree increases with the tree depth. When the tree depth is 9, One-Stage Tree can still achieve performance improvements on several datasets (e.g., Ionosphere and Wine Quality White).

  • Since tree pruning is not supported, Soft Decision Tree easily falls into overfitting. Especially at larger depths, its performance may decrease dramatically.

  • End2End Tree can achieve stable performance at different depths. However, as a two-stage method of building and then pruning, it does not perform as well as One-Stage Tree.

  • For the hard tree CART, which greedily performs building and pruning, the performance at a large tree depth may fall into a local optimum. For example, at a depth of 9, the performance of CART is even much worse than that at a depth of 1 on datasets such as SpectF and Hepatitis.

In summary, One-Stage Tree not only achieves stable performance at different tree depths but also improves as the tree depth increases. Moreover, considering that the best performance may be achieved at a depth of 5, 6, or 7, the uncertainty inherent in deep learning is worth noting. More regularization techniques need to be added to alleviate such problems in future work.

4.3.2 Validation size

Table 8 Comparison results of One-Stage Tree with different validation sizes \(\rho\)

To evaluate the impact of the validation size, we use the same experimental setting as in RQ1 (e.g., a depth of 6 and a patience of 15). As shown in Table 8, the performance of One-Stage Tree remains stable with respect to the validation size. Moreover, the comparison results in Table 8 demonstrate that the optimal \(\rho\) increases as the number of features decreases. The validation size can be seen as a trade-off between the training of the internal and architecture parameters. As \(\rho\) grows, more instances are used to train the architecture parameters \(\gamma\), while the internal parameters \(\omega\) are trained with fewer instances. The number of internal parameters is d (i.e., the number of features) times larger than the number of architecture parameters. Thus, for datasets with more features, training \(\omega\) requires more instances (i.e., a smaller \(\rho\)).

4.4 Ablation study (RQ3)

In this subsection, we conduct experiments to check whether the discretization of One-Stage Tree influences the performance gap. To validate the effectiveness of proximal iterations, we propose two variants:

  • Joint Tree which is a variant of One-Stage Tree without proximal iterations. Joint Tree optimizes the node and architecture parameters according to Eq. (8).

  • Gumbel Tree which is another variant of One-Stage Tree that discretizes the architecture by the Gumbel Softmax in the outer minimization of Eq. (9).

    The details of Gumbel Tree can be found in “Appendix 3”.

We perform the experiments on all classification and regression datasets used in RQ1 with the same experimental settings.

Fig. 6

Performance comparison between One-Stage Tree and Joint Tree (Bar chart shows the comparison in the test set, and line chart shows the comparison in the training set)

Table 9 Statistics (i.e., the number of top-ranked datasets for each tree) on the performance of One-Stage Tree, Joint Tree and Gumbel Tree (0.5 means that two tree models tie on the dataset)

The performance comparison is presented in Fig. 6 and Table 9. In Fig. 6, the bar chart shows the performance on the test set to indicate the generalization of the trees, and the line chart shows the performance on the training set to represent the fit of the trees. The height difference between points and bars of the same color represents the performance gap of the tree between training and testing. For greater clarity, we count the number of top-ranked datasets for each tree in Table 9.

Without splitting off a validation set to optimize \(\gamma\), Joint Tree obtains a better fit on the training set on all datasets. However, One-Stage Tree achieves a significant performance improvement over Joint Tree on the test set. Compared with Joint Tree, One-Stage Tree reduces the fit to the training set and greatly improves the generalization ability on the test set. The performance gap between training and testing is indeed reduced by proximal iterations.

Additionally, from Fig. 6 and Table 9, we can see that One-Stage Tree performs much better than Gumbel Tree. Figure 6 shows that Gumbel Tree has the worst performance on the training set on most datasets, which indicates that Gumbel Tree is not fully trained. The main reason why the underfitting problem occurs in Gumbel Tree is that the Gumbel Softmax tends to sample different architectures in the early stage of training. Since the internal parameters \(\omega\) and the architecture parameters \(\gamma\) are optimized alternately in Eq. (9), sampling a significantly different architecture each time has a negative impact on the one-step approximation of the optimal internal parameters \(\omega ^{*}\). In contrast, One-Stage Tree gradually optimizes the current architecture with small changes, which ensures effective training.

Furthermore, as discussed in Sect. 4.2.2, we perform the Friedman test (Demšar, 2006) to statistically compare One-Stage Tree with the two variants. The Friedman statistic is 14.2121 and the corresponding p-value is 0.00082. Since the p-value is less than 0.05, we conclude that the trees lead to statistically significant differences. Next, according to the Nemenyi post-hoc test, we further conclude that One-Stage Tree is significantly different from the other two variants at a significance level of \(\alpha =0.05\).

4.5 Discussion of interpretability (RQ4)

Table 10 Factors for the interpretability of decision trees

The interpretability of hard trees helps to understand the mechanism of the tree model and explore the patterns of the dataset. In contrast to hard trees, One-Stage Tree cannot be transformed into several ’if-else’ rules based on features and thresholds. In One-Stage Tree, different instances may be routed based on different features at the same node. Thus, unlike hard trees, it is difficult to visualize One-Stage Tree. In Table 10, we summarize the factors for the interpretability of decision trees. Taking a tree of depth 6 trained on the dataset PimaIndian as an example, we discuss the interpretability of One-Stage Tree.

4.5.1 Routing rule

Fig. 7

Feature contribution of 3 different instances in Node\(_0\)

Hard trees use a greedy algorithm to select a split feature and threshold at each node. In other words, a hard tree learns rules from the dataset to route instances to different leaves for prediction. Such ’if-else’-based routing rules are dataset-wise, i.e., each node routes all instances based on the same feature in the inference phase. Thus, they are easy to understand.

In One-Stage Tree, the routing rules are instance-wise. At the same node, the features contribute differently to the routing of different instances. In the inference phase, the path with the highest probability is chosen, which is equivalent to:

$$\begin{aligned} s(\varvec{x}; \omega _i) = {\left\{ \begin{array}{ll} {[}1, 0]^T, &{} \text {if } (\omega _{i, 0}- \omega _{i, 1})^T \varvec{x} > 0\\ {[}0, 1]^T, &{} \text {o.w.} \end{array}\right. } \end{aligned}$$
(18)

where \(\omega _i \in \mathbb {R}^{(d + 1) \times 2}\) and \(\omega _{i, 0}^T \varvec{x}\) denotes the logit for routing \(\varvec{x}\) to the left child. Thus, \(|(\omega _{i, 0, j} - \omega _{i, 1, j}) \cdot x_j |\) can be viewed as the contribution of feature j at node i for the instance \(\varvec{x}\). We visualize the contribution of each feature to the node router through a pie chart to obtain the instance-wise routing rules. As shown in Fig. 7a, b, different features (i.e., BMI and Age) make the major contribution for the two instances that are both routed to the right node. Meanwhile, in Fig. 7a, c, two instances are routed to different children mainly by the same feature BMI.
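
A minimal sketch of this instance-wise routing rule and the per-feature contributions; \(\omega _i\) is assumed to be stored as a \((d+1) \times 2\) matrix whose last row corresponds to the bias term, matching Eq. (18):

```python
import numpy as np

def route_and_contributions(x, omega_i):
    # x: feature vector with a trailing bias term 1, shape (d + 1,).
    # omega_i: node parameters, shape (d + 1, 2); column 0 gives the left logit.
    diff = omega_i[:, 0] - omega_i[:, 1]
    go_left = diff @ x > 0                    # routing decision, Eq. (18)
    contributions = np.abs(diff * x)          # per-feature contribution
    return go_left, contributions / contributions.sum()   # shares as in Fig. 7
```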

Since the traditional soft trees (e.g., Soft Decision Tree and Budding Tree) are probabilistic, instances are not directly routed to child nodes, which makes them difficult to interpret. Compared to probabilistic trees, One-Stage Tree directly routes an instance to a child node rather than weighting it by probabilities. Thus, One-Stage Tree is more interpretable.

4.5.2 Feature importance

Fig. 8

Feature importance of CART and One-Stage Tree in the dataset PimaIndian

Feature importance is an important way to explore data patterns. For hard trees, feature importance is calculated as the decrease in node impurity weighted by the probability of reaching that node. One-Stage Tree also provides feature importance to explore the dataset. Since it shares the discrete architecture of hard trees, One-Stage Tree calculates feature importance in the same way but weights the impurity decrease by \(|\omega _{i, 0, j} - \omega _{i, 1, j}|\) for each feature j at node i.
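
A hedged sketch of this importance computation; the per-node impurity decreases and instance counts are assumed to be available from a trained tree, and the bias row of \(\omega _i\) is dropped before weighting:

```python
import numpy as np

def feature_importance(omega, internal_nodes, impurity_decrease, n_node, n_total, d):
    # Weight each node's impurity decrease by the fraction of instances
    # reaching it and distribute it over features j in proportion to
    # |omega_{i,0,j} - omega_{i,1,j}|.
    importance = np.zeros(d)
    for i in internal_nodes:
        w = np.abs(omega[i][:d, 0] - omega[i][:d, 1])   # drop the bias row
        w = w / w.sum()
        importance += (n_node[i] / n_total) * impurity_decrease[i] * w
    return importance / importance.sum()
```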

In Fig. 8, we show the feature importance of One-Stage Tree and CART in the dataset PimaIndian. We observe that two features, i.e., Glucose and BMI, play an important role in both trees.

4.5.3 Node instance distribution, node impurity, and predicted value

Fig. 9

Instance distribution in leaves of One-Stage Tree trained on the dataset PimaIndian

Benefiting from the discrete architecture, One-Stage Tree deterministically routes each instance to a leaf for prediction. As a result, the instance distribution and node impurity can be computed in the same way as for the hard tree. In Fig. 9, we show the instance distribution in two leaves of One-Stage Tree trained on the dataset PimaIndian. Furthermore, we can calculate the node impurity and other criteria according to the instance distribution.

4.6 Discussion

Tabular data is generally dominated by tree models, and One-Stage Tree can be seen as an attempt to apply deep learning to tabular data. Although One-Stage Tree inherits the advantages of the decision tree, it also has the disadvantages of deep learning. For example, due to the use of gradient descent for optimization, the efficiency of constructing One-Stage Tree needs to be improved. Moreover, One-Stage Tree is a single decision tree. To achieve better performance, we need to further construct a tree ensemble model based on One-Stage Tree. Due to its continuous nature, we could then perform joint tuning on all trees.

5 Conclusion and future work

In this work, we proposed One-Stage Tree, which retains the advantages of traditional decision trees as the inference model and improves learning through end-to-end training with back-propagation. Based on the continuous relaxation of soft trees, One-Stage Tree optimizes the node and architecture parameters jointly through a bilevel optimization problem. Moreover, One-Stage Tree leverages the reparameterization trick and proximal iterations to keep the tree discrete while training the continuous parameters. As a benefit, the performance gap between training and testing is reduced and the interpretability is maintained. The experimental results show that One-Stage Tree is effective on both classification and regression tasks and can outperform CART and the existing soft trees.

In the future, we plan to improve the efficiency of soft trees on GPU by parallelizing sequential decisions. Additionally, using ensemble methods such as bagging and boosting to build a forest of One-Stage Tree is also an important future work.