1 Introduction

An exposure can influence the outcome of interest directly or indirectly. The indirect effects are transmitted through intermediate variables, commonly referred to as mediators. Mediation analysis has garnered significant attention in various fields, including biomedical research (Sun et al. 2021; Zhou and Song 2021), epidemiology (VanderWeele and Vansteelandt 2014), and social-psychological studies (Rucker et al. 2011). Typically, mediation analysis employs the counterfactual framework, also known as the potential outcome framework or Rubin's model (Imai et al. 2010; Rubin 2005), to identify direct and indirect effects. This study focuses on causal mediation analysis within this framework.

Causal mediation analysis has traditionally focused on average causal effects. With the growing importance of personalized medicine, however, there is a need to move beyond averages and estimate conditional average causal effects, or individualized causal effects (ICEs). Identifying heterogeneity in mediation effects has long been a focus in psychology, where it is referred to as moderated mediation analysis (Hayes 2015). In this framework, a covariate that can influence the outcome directly or indirectly is designated a moderator. These methods include first-order interaction terms between a putative moderator and the exposure in the linear regression for the mediator, and between the moderator and the mediator in the linear regression for the outcome, so that conditional direct and indirect effects can be written as closed-form functions of the regression coefficients and specific covariate levels. However, these traditional approaches face two main constraints. First, as detailed in Sect. 7, they rely on separate regressions for the mediator and the outcome and plug the predicted mean of the mediator into the outcome regression through estimated coefficients, which does not satisfy the requirements for identification; as a result, they may not provide reliable estimates of direct and indirect effects. Second, they presuppose specific model structures, which may fail to capture the intricacies of real-world systems: relationships among variables are often complex and nonlinear, and imposing a preconceived model can limit our ability to understand and interpret the data.

Recent work has proposed remedies for analyzing heterogeneous mediation (Dyachenko and Allenby 2018; Hong et al. 2015; Park and Kaplan 2015; Qin and Hong 2017; Rosenbaum 1987; Xue et al. 2022), but these approaches also have limitations. For instance, Dyachenko and Allenby (2018) presented a Bayesian mixture model that combines likelihood functions from two distinct outcome models for ICE estimation, but their method requires a predetermined number of subgroups. Moreover, the methodologies above have predominantly centered on binary treatments. In many practical applications, however, the treatment variable is continuous. For example, when evaluating the impact of non-labor income on labor supply, the causal effect may depend not only on whether non-labor income is received but also on its total amount. Extending existing causal mediation methods to handle continuous treatment variables is therefore essential, as it enables a more comprehensive understanding of the nuanced causal relationships and of how they vary with the level of the continuous treatment.

Estimating causal effects involving continuous treatment variables has received recent attention. In their pioneering work, Huber et al. (2020) identified and estimated the natural direct and indirect effects in the presence of a continuous treatment variable. They developed estimators using the estimated marginal density of the treatment and the inverses of two estimated conditional densities: the conditional density of the treatment given the confounder and the conditional density of the treatment given both the confounder and the mediator. Building upon this research, Huang et al. (2024) proposed a comprehensive framework for estimating direct and indirect effects, which unified the treatment variable types, including binary, multi-valued, continuous treatments, and mixed discrete and continuous treatments. They employed the generalized empirical likelihood (GEL) method to estimate the weights in their approach. However, the research outlined above operated under the assumption of uniformity among individuals and predominantly emphasized average causal mediation effects. It is a common observation that individuals with diverse characteristics can exhibit differing effects, even under identical treatment conditions. The necessity for customized approaches in analyzing heterogeneous mediations involving continuous treatment is increasingly urgent. This underscores the importance and relevance of developing advanced methods to fill the above research gap.

Recently, machine learning-based techniques have been introduced to estimate individual treatment effects (ITEs) because of their ability to capture complex nonlinear relationships without assuming specific model forms. In particular, Yoon et al. (2018) proposed GANITE, a conditional generative adversarial network (CGAN)-based deep learning framework, to estimate ITEs for discrete interventions. Bica et al. (2020) extended GANITE and introduced SCIGAN to evaluate the effects of continuous-valued interventions using GANs. While SCIGAN has demonstrated convincing efficacy in identifying ITEs for continuous interventions, it neither accounts for mediators nor investigates the convergence properties of the method, and thus it provides neither practical tools nor theoretical guarantees for mediation analysis in such settings. To identify the causal effects, as shown in Eq. (4), it is essential to employ a sampling-based technique: multiple values are drawn from the estimated probability distribution of the potential mediator to evaluate the potential outcomes. This approach accommodates a broader spectrum of possibilities and uncertainties during estimation and thereby supports a more thorough analysis of the causal mediation effects. Given that GANs are adept at capturing intricate nonlinear relationships without assuming specific model structures and naturally support sampling from probability distributions, our objective is to develop a novel approach, termed CGAN-based individualized causal mediation analysis under continuous-value treatment (CGAN-ICMA-CT), to fill this research gap.

The proposed CGAN-ICMA-CT framework consists of two fundamental layers: a mediator layer and an outcome layer. Each layer comprises two subblocks, a counterfactual block and an inferential block. By formulating a CGAN in each block, the proposed method enables precise estimation of causal effects at the individual level while accounting for mediators. Moreover, we adapt the theoretical framework of Zhou et al. (2022) to mediation analysis: we establish distribution matching estimation through the Kullback–Leibler divergence and prove the convergence of CGAN-ICMA-CT. Our proposed methodology offers several key advantages. It captures intricate nonlinear relationships without requiring a parametric framework, a crucial aspect for accurately modeling real-world systems. By adopting a sampling-based approach, wherein multiple values are drawn from the estimated probability distribution of the potential mediator to estimate potential outcomes, our method encompasses a broader array of possibilities and uncertainties in the estimation process, resulting in a more thorough analysis of the causal mediation effects. Furthermore, the combination of distribution matching estimation and the convergence theory establishes a robust theoretical underpinning for our method. The convergence result is vital because it ensures that our approach captures the underlying conditional distribution and facilitates dependable conditional sampling. This foundation not only validates the application of our method but also sets the stage for its advancement in individualized causal mediation analysis. To the best of our knowledge, this is the first work to address individualized causal mediation analysis with continuous treatment.

The article is structured as follows. We introduce the motivating Job Corps study and contribution in Sect. 2. Section 3 briefly reviews ICEs with continuous treatment and describes common assumptions in mediation analysis and the problem formulation. Section 4 presents the proposed CGAN-ICMA-CT. Section 5 investigates the convergence of CGAN-ICMA-CT, and Sect. 6 outlines its architecture and implementation. Section 7 compares it with several other approaches through simulation studies. In Sect. 8, we apply the proposed method to the Job Corps dataset to examine the ICEs of spending time in the Job Corps program on the number of arrests. Section 9 concludes. All technical details and additional numerical results are provided in the Supplementary Material.

2 Motivation and contribution

Our study is motivated by analyzing the publicly funded Job Corps dataset, which targets economically disadvantaged young individuals aged 16–24. Previous research (Frölich and Huber 2017; Huber 2014; Schochet et al. 2008) has found that participating in Job Corps programs reduces criminal activity. Building upon these findings, Huber et al. (2020) and Huang et al. (2024) investigated the impact of spending time in Job Corps programs on the number of arrests. However, their studies only estimated average causal effects and were not equipped to handle situations where the causal effects vary across observable characteristics, such as gender. Ignoring heterogeneity and relying on a homogeneous/population-level model can lead to biased results. We aim to propose a novel method to address this issue, enabling the estimation of ICEs and uncovering the underlying causal mechanism behind the impact of Job Corps program participation on the number of arrests.

We propose CGAN-ICMA-CT, prove its convergence, and apply it to estimate the ICEs of the Job Corps program on the number of arrests. We conduct subgroup analysis by considering various covariate-specific groups to examine group causal effects. For example, when performing subgroup analysis for numeric covariates, using times in prison as an example, we observe that within the examined range of treatment intensities, the direct effects of the Job Corps program on the number of arrests become more pronounced as times in prison increase up to a certain threshold (around three times) and then become less pronounced, whereas the indirect effect of the Job Corps program through employment is minimal and close to zero initially but becomes apparent when times in prison reach approximately 4–5. These findings are presented in Fig. 4 of Sect. 8 and highlight the importance of considering an individual's criminal history and specific challenges when evaluating the program's effectiveness. Further discussion can be found in Sect. 8.

Moreover, we estimate the average causal effects of the Job Corps program on the number of arrests, and our results align with those obtained using the series regression estimator in Huang et al. (2024). However, as previously noted, Huang et al. (2024) did not explore potential heterogeneity in their analysis, whereas our approach allows for a highly nuanced analysis by exploring the heterogeneity of causal effects across different subgroups, thereby providing valuable insights into how these effects vary across observable characteristics.

3 Problem formulation

This section presents a comprehensive overview of the standard causal mediation analysis with continuous treatment (Hirano and Imbens 2004; Huang et al. 2024; Huber et al. 2020). We also propose a set of sequential ignorability assumptions and provide a precise problem formulation to enhance understanding and context.

3.1 Preliminary

Let \({\textbf {X}}=(X_1,\ldots ,X_{d_x})\in \mathcal {X}\subset \mathbb {R}^{d_x}\) be a random vector of pre-treatment covariates, \(T\in \mathcal {T}\subset \mathbb {R}\) be a continuous treatment variable, \(M\in \mathcal {M}\subset \mathbb {R}\) be the mediator, and \(Y\in \mathcal {Y}\subset \mathbb {R}\) be the outcome. Without loss of generality, we assume that \(\mathcal {T}=[0,1]\).

Assume that \(({\textbf {X}},T,M,Y)\sim P_{{\textbf {X}},T,M,Y}\) with marginal distributions such as \(({\textbf {X}},T,M)\sim P_{{\textbf {X}},T,M}\), \(({\textbf {X}},T)\sim P_{{\textbf {X}},T}\) and so forth. Furthermore, we have individual distribution functions for each variable, such as \({\textbf {X}}\sim P_{{\textbf {X}}}\), \(T\sim P_{T}\), \(M\sim P_{M}\), \(Y\sim P_{Y}\), etc. We denote the conditional distribution of Y given \(({\textbf {X}},T,M)\) as \(P_{Y|{\textbf {X}},T,M}\), and similar notations are used for other conditional distributions. Let \(p_{{\textbf {X}},T,M,Y}\) be the density function of the distribution \(P_{{\textbf {X}},T,M,Y}\), and similar notations are used for other distributions.

Let M(t) be the potential mediator, representing the value of the mediator if the treatment variable equals \(t\in \mathcal {T}\), and let \(Y(t,m)\) be the potential outcome if one receives treatment \(t\in \mathcal {T}\) and mediator \(m\in \mathcal {M}\). The factual (observed) mediator and the factual (observed) outcome are denoted by \(M=M(T)\) and \(Y=Y(T,M(T))\), respectively, where T is the factual treatment. In addition, we make the consistency assumption throughout: for any individual \({\textbf {X}}={\textbf {x}}\), the potential mediator M(t) equals the observed mediator \(M=M(T)\) if the individual \({\textbf {X}}={\textbf {x}}\) happened to receive treatment level \(T=t\), and similarly for the potential outcome. Denote \(M({\textbf {x}},t):=M(t)|{\textbf {X}}={\textbf {x}}\) and \(Y({\textbf {x}},t,m):=Y(t,m)|{\textbf {X}}={\textbf {x}}\). To emphasize the individualized effects, denote \(M_t({\textbf {x}}):=M({\textbf {x}},t)\) for \(t\in \mathcal {T}\), \(Y_{t}({\textbf {x}},m):=Y({\textbf {x}},t,m)\), and \(Y_{tt'}({\textbf {x}}):=Y({\textbf {x}},t,M_{t'}({\textbf {x}}))\), for \(t,t'\in \mathcal {T}\). For any given sample \({\textbf {X}}={\textbf {x}}\), we aim to obtain the distribution of \(M_t({\textbf {x}})\), then the value of \(\mathbb {E}[Y_{tt'}({\textbf {x}})]\) for \(t,t'\in \mathcal {T}\), and finally the ICEs defined in Sect. 3.3. Let \({\textbf {Z}}\), \({\widehat{{\textbf {Z}}}}\), \(\widetilde{{\textbf {Z}}}\) and \(\overline{{\textbf {Z}}}\) be random vectors independent of \({\textbf {X}}\), T, M, Y and of each other, with known distributions \(P_{\textbf {Z}}\), \(P_{\widehat{{\textbf {Z}}}}\), \(P_{\widetilde{{\textbf {Z}}}}\) and \(P_{\overline{{\textbf {Z}}}}\), respectively. For example, we can take \(P_{\textbf {Z}}\) to be the standard multivariate normal distribution \(N(0,\mathbb {I}_{d_z})\) for a given \(d_z\ge 1\).

3.2 General assumptions

We now introduce several standard assumptions (Hirano and Imbens 2004; Huang et al. 2024; Huber et al. 2020; Imai and Van Dyk 2004) to identify causal effects.

Assumption

(Sequential Ignorability): (I) \(\{Y(t',m), M(t)\} \perp T \mid {\textbf {X}}={\textbf {x}}\) for all \(t,t'\in \mathcal {T}\), \(m\in \mathcal {M}\), and \({\textbf {x}}\in \mathcal {X}\);

(II) \(Y(t',m) \perp M \mid T=t, {\textbf {X}}={\textbf {x}}\) for all \(t,t'\in \mathcal {T}\), \(m\in \mathcal {M}\), and \({\textbf {x}}\in \mathcal {X}\), where the conditional density functions (defined analogously to \(p_{T}\)) satisfy \(p_{T|{\textbf {X}}}(t|{\textbf {x}})>0\) and \(p_{T|M,{\textbf {X}}}(t|m,{\textbf {x}})>0\) for all \((t, m, {\textbf {x}})\in \mathcal {T}\times \mathcal {M}\times \mathcal {X}\).

3.3 Problems

For the dataset \(S_n:=\{{\textbf {X}}={\textbf {x}}_i,T=t_i,M=m_i, Y= y_i\}_{i=1}^{n}\), we aim to approximate the expectation of the potential outcomes for given covariates \({\textbf {X}}={\textbf {x}}\) under different treatments, that is, the map \({\textbf {x}}\mapsto \mathbb {E}[Y_{tt'}({\textbf {x}})]=\mathbb {E}[Y({\textbf {x}}, t, M_{t'}({\textbf {x}}))]\) for \(t,t'\in \mathcal {T}\). ICEs with the continuous treatment can then be approximated by comparing the expected potential outcomes under different treatment and mediator combinations.

Now, we introduce ICEs (Huang et al. 2024; Huber et al. 2020) with a continuous treatment. The individualized natural direct effect (NDE) and natural indirect effect (NIE) under t versus \(t'\) are

$$\begin{aligned} \text{ NDE }=\theta _{t,t'}({\widetilde{t}}; {\textbf {x}})&=\mathbb {E}[Y({\textbf {x}},t, M_{{\widetilde{t}}}({\textbf {x}}))]\nonumber \\&\quad -\mathbb {E}[Y({\textbf {x}}, t', M_{{\widetilde{t}}}({\textbf {x}}))], \ \ {\widetilde{t}} =t' \ \textrm{or} \ t, \end{aligned}$$
(1)
$$\begin{aligned} \text{ NIE }=\delta _{t,t'}({\widetilde{t}}; {\textbf {x}})&=\mathbb {E}[Y({\textbf {x}},{\widetilde{t}}, M_{t}({\textbf {x}}))]\nonumber \\&\quad -\mathbb {E}[Y({\textbf {x}}, {\widetilde{t}}, M_{t'}({\textbf {x}}))], \ \ {\widetilde{t}} =t' \ \textrm{or} \ t, \end{aligned}$$
(2)

for \(t\ne t'\). Then, the individualized total effect (TE) under t versus \(t'\) can be decomposed as

$$\begin{aligned} \tau _{t,t'}({\textbf {x}})&= \mathbb {E}[Y({\textbf {x}},t, M_{t}({\textbf {x}}))]-\mathbb {E}[Y({\textbf {x}}, t', M_{t'}({\textbf {x}}))] \nonumber \\&= \theta _{t,t'}(t'; {\textbf {x}})+\delta _{t,t'}(t; {\textbf {x}}) = \theta _{t,t'}(t; {\textbf {x}})+\delta _{t,t'}(t'; {\textbf {x}}), \end{aligned}$$
(3)

for \(t\ne t'\). For example, in the Job Corps data analysis in Sect. 8, the individualized NDE and NIE denote, for a given unit, the effect of the time spent in Job Corps on criminal activity not operating through employment and operating through employment, respectively. The individualized TE represents the total effect of the time spent in the Job Corps on criminal activity.
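For completeness, the first decomposition in (3) can be verified by adding and subtracting the cross-world mean \(\mathbb {E}[Y({\textbf {x}},t, M_{t'}({\textbf {x}}))]\):

$$\begin{aligned} \tau _{t,t'}({\textbf {x}})&=\mathbb {E}[Y({\textbf {x}},t, M_{t}({\textbf {x}}))]-\mathbb {E}[Y({\textbf {x}},t, M_{t'}({\textbf {x}}))]\\&\quad +\mathbb {E}[Y({\textbf {x}},t, M_{t'}({\textbf {x}}))]-\mathbb {E}[Y({\textbf {x}}, t', M_{t'}({\textbf {x}}))]\\&=\delta _{t,t'}(t; {\textbf {x}})+\theta _{t,t'}(t'; {\textbf {x}}); \end{aligned}$$

the second decomposition follows analogously by adding and subtracting \(\mathbb {E}[Y({\textbf {x}},t', M_{t}({\textbf {x}}))]\).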

To demonstrate the identifiability of the individualized total and decomposed mediation effects, we examine the identification of the relevant potential outcome, \(\mathbb {E}[Y({\textbf {x}}, t', M_t({\textbf {x}}))]\), as follows. By the sequential ignorability,

$$\begin{aligned} \begin{aligned}&\mathbb {E}[Y({\textbf {x}}, t', M_t({\textbf {x}}))]\\&\quad =\int \mathbb {E}[Y | {\textbf {X}}= {\textbf {x}}, T = t', M = m]dP_{M({\textbf {x}}, t)}(m).\\ \end{aligned} \end{aligned}$$
(4)

Thus, ICEs can be identified through the expected potential outcomes as long as the distribution of the potential mediator, \(P_{M({\textbf {x}}, t)}\), can be estimated from the observed data. Next, we present the problem formulation.

Problem in M Layer: Find a deterministic function, called inference function for mediator M, \(I_{{\textbf {M}}}:({\widehat{{\textbf {z}}}},{\textbf {x}},t)\in \mathbb {R}^{d_z}\times \mathcal {X}\times \mathcal {T}\mapsto I_{{\textbf {M}}}({\widehat{{\textbf {z}}}},{\textbf {x}},t)\in \mathcal {M}\), such that for a target t,

$$\begin{aligned}&I_{{\textbf {M}}}({\widehat{{\textbf {Z}}}},{\textbf {x}},t) \sim P_{M|{\textbf {X}}={\textbf {x}},T=t},\ {\textbf {x}}\in \mathcal {X}. \end{aligned}$$
(5)

Thus, for \({\textbf {x}}\in \mathcal {X}\), to sample from the distribution \(P_{M|{\textbf {X}}={\textbf {x}},T=t}\), we can first sample a \({\widehat{{\textbf {z}}}}\sim P_{{\widehat{{\textbf {Z}}}}}\) and then calculate \(I_{{\textbf {M}}}({\widehat{{\textbf {z}}}},{\textbf {x}},t)\). The resulting value \(I_{{\textbf {M}}}({\widehat{{\textbf {z}}}},{\textbf {x}},t)\) is a sample from \(P_{M|{\textbf {X}}={\textbf {x}},T=t}\).

Problem in Y Layer: Find a deterministic function, called inference function for the outcome Y, \(I_{\textbf {Y}}:(\overline{{\textbf {z}}},{\textbf {x}},t,m)\in \mathbb {R}^{d_z}\times \mathcal {X}\times \mathcal {T}\times \mathcal {M}\mapsto I_{\textbf {Y}}(\overline{{\textbf {z}}},{\textbf {x}},t,m)\in \mathcal {Y}\), such that for a target t,

$$\begin{aligned}&I_{{\textbf {Y}}}(\overline{{\textbf {Z}}},{\textbf {x}},t,m) \sim P_{Y|{\textbf {X}}={\textbf {x}},T=t,M=m},\ {\textbf {x}}\in \mathcal {X},\,m\in \mathcal {M}. \end{aligned}$$
(6)

Once the aforementioned problems are resolved, we can calculate \(\mathbb {E}[Y_{tt'}({\textbf {x}})]\) as follows:

$$\begin{aligned} \mathbb {E}[Y_{tt'}({\textbf {x}})] = \mathbb {E}_{\widehat{{\textbf {Z}}} \sim P_{\widehat{{\textbf {Z}}}}, \overline{{\textbf {Z}}} \sim P_{\overline{{\textbf {Z}}}}}[I_{{\textbf {Y}}}(\overline{{\textbf {Z}}}, {\textbf {x}},t, I_{{\textbf {M}}}(\widehat{{\textbf {Z}}}, {\textbf {x}},t'))], \end{aligned}$$

for all \({\textbf {x}}\in \mathcal {X}\). Subsequently, we can utilize these estimates to identify the ICEs. We can use Monte Carlo approximations to evaluate the integral (4). Specifically, we draw \({\widehat{n}}\) samples from \({\widehat{{\textbf {Z}}}}\sim P_{{\widehat{{\textbf {Z}}}}}\) and \(\overline{n}\) samples from \(\overline{{\textbf {Z}}}\sim P_{\overline{{\textbf {Z}}}}\), denoted as \({\widehat{{\textbf {z}}}}_1, {\widehat{{\textbf {z}}}}_2, \ldots , {\widehat{{\textbf {z}}}}_{{\widehat{n}}}\) and \(\overline{{\textbf {z}}}_1, \overline{{\textbf {z}}}_2, \ldots , \overline{{\textbf {z}}}_{\overline{n}}\). Then, NDE and NIE in (1) and (2) can be estimated by

$$\begin{aligned} \theta _{t,t'}({\widetilde{t}}; {\textbf {x}})&\approx \frac{1}{\overline{n}\times {\widehat{n}}}\left( \sum _{j=1}^{\overline{n}}\sum _{i=1}^{{\widehat{n}}}I_{\textbf {Y}}(\overline{{\textbf {z}}}_j,{\textbf {x}},t,I_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}},{\widetilde{t}})) \right. \nonumber \\&\quad \left. - \sum _{j=1}^{\overline{n}}\sum _{i=1}^{{\widehat{n}}}I_{\textbf {Y}}(\overline{{\textbf {z}}}_j,{\textbf {x}},t',I_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}},{\widetilde{t}})) \right) , \ {\widetilde{t}} =t' \ \textrm{or} \ t, \end{aligned}$$
(7)
$$\begin{aligned} \delta _{t,t'}({\widetilde{t}}; {\textbf {x}})&\approx \frac{1}{\overline{n}\times \widehat{n}}\left( \sum _{j=1}^{\overline{n}}\sum _{i=1}^{{\widehat{n}}}I_{\textbf {Y}}(\overline{{\textbf {z}}}_j,{\textbf {x}},\widetilde{t},I_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}},t)) \right. \nonumber \\&\quad \left. - \sum _{j=1}^{\overline{n}}\sum _{i=1}^{\widehat{n}}I_{\textbf {Y}}(\overline{{\textbf {z}}}_j,{\textbf {x}},{\widetilde{t}},I_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}},t')) \right) , \ {\widetilde{t}} =t' \ \textrm{or} \ t, \end{aligned}$$
(8)

and TE in (3) can be estimated by

$$\begin{aligned} \tau _{t,t'}({\textbf {x}})&\approx \frac{1}{\overline{n}\times \widehat{n}}\left( \sum _{j=1}^{\overline{n}}\sum _{i=1}^{{\widehat{n}}}I_{\textbf {Y}}(\overline{{\textbf {z}}}_j,{\textbf {x}},t,I_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}},t)) \right. \nonumber \\&\quad \left. - \sum _{j=1}^{\overline{n}}\sum _{i=1}^{\widehat{n}}I_{\textbf {Y}}(\overline{{\textbf {z}}}_j,{\textbf {x}},t',I_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}},t')) \right) . \end{aligned}$$
(9)
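As an illustration, the sketch below shows how the estimators (7)-(9) could be evaluated once the inference functions are available. Here infer_m and infer_y are hypothetical Python callables standing in for the trained \(I_{{\textbf {M}}}\) and \(I_{{\textbf {Y}}}\), and both noise distributions are taken to be standard normal; this is a minimal sketch rather than the actual implementation.

```python
import numpy as np

def estimate_effects(infer_m, infer_y, x, t, t_prime, d_z, n_hat=100, n_bar=100, seed=0):
    """Monte Carlo approximation of the NDE, NIE and TE estimators (7)-(9).

    infer_m(z_hat, x, t)    -> sampled mediator  (stands in for the trained I_M)
    infer_y(z_bar, x, t, m) -> sampled outcome   (stands in for the trained I_Y)
    """
    rng = np.random.default_rng(seed)
    z_hat = rng.standard_normal((n_hat, d_z))   # draws from P_{Z-hat}, taken as N(0, I)
    z_bar = rng.standard_normal((n_bar, d_z))   # draws from P_{Z-bar}, taken as N(0, I)

    def mean_outcome(t_out, t_med):
        # Approximates E[Y(x, t_out, M_{t_med}(x))] by averaging over both noise sources.
        total = 0.0
        for zh in z_hat:
            m = infer_m(zh, x, t_med)
            for zb in z_bar:
                total += infer_y(zb, x, t_out, m)
        return total / (n_hat * n_bar)

    nde = mean_outcome(t, t_prime) - mean_outcome(t_prime, t_prime)  # theta_{t,t'}(t'; x)
    nie = mean_outcome(t, t) - mean_outcome(t, t_prime)              # delta_{t,t'}(t; x)
    te = mean_outcome(t, t) - mean_outcome(t_prime, t_prime)         # tau_{t,t'}(x)
    return nde, nie, te
```

By construction, the returned nde and nie sum to te, in line with the decomposition in (3).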

4 CGAN-ICMA-CT

Yoon et al. (2018) introduced a CGAN-based deep learning approach named GANITE to generate unobserved counterfactual outcomes for discrete interventions. Bica et al. (2020) further modified the GANITE framework and proposed SCIGAN to tackle the more complex problem of estimating the effects of continuous-valued interventions. Their essential idea is to use a heavily adapted GAN model to acquire the ability to generate counterfactual outcomes, which are subsequently employed to train an inference model through conventional supervised methods, enabling the estimation of such counterfactuals for new samples. To effectively tackle the complexities associated with transitioning to continuous interventions, they construct a hierarchical discriminator that capitalizes on the inherent structure of the continuous-valued intervention framework. In developing our model, we draw inspiration from SCIGAN and adapt the loss function and network structure. Figure 1 provides a concise overview of the architecture of our proposed CGAN-ICMA-CT.

Fig. 1 Architecture of CGAN-ICMA-CT (\(\widehat{{\textbf {m}}}\) is sampled from \({\textbf {G}}_{{\textbf {M}}}^{\widehat{\varvec{\theta }}}\) after \({\textbf {G}}_{{\textbf {M}}}^{\widehat{\varvec{\theta }}}\) has been fully trained, and \(\widehat{{\textbf {y}}}\) is sampled from \({\textbf {G}}^{\widehat{\varvec{\zeta }}}_{\textbf {Y}}\) after \({\textbf {G}}^{\widehat{\varvec{\zeta }}}_{\textbf {Y}}\) has been fully trained). \({\textbf {G}}_{{\textbf {M}}}^{{\varvec{\theta }}}\), \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}\), \({\textbf {G}}^{{\varvec{\zeta }}}_{\textbf {Y}}\), \({\textbf {D}}^{\varvec{\xi }}_{\textbf {Y}}\), \({\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^{\varvec{\omega }}\), and \({\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^{\varvec{\lambda }}\) operate only during training, whereas \({\textbf {I}}_{{\textbf {M}}}^{\varvec{\psi }}\) and \({\textbf {I}}_{{\textbf {Y}}}^{\varvec{\varphi }}\) operate both during training and at run-time

Our model consists of two layers: a mediator layer and an outcome layer. The mediator layer generates potential mediators for new samples, while the outcome layer generates potential outcomes for new samples. The inferential generator in the mediator layer is designed to achieve the objective stated in Eq. (5), generating samples that conform to the target distribution described therein; similarly, the inferential generator in the outcome layer aims to fulfill the objective stated in Eq. (6). Each layer consists of two subblocks, a counterfactual block and an inferential block, each comprising a generator and a discriminator. The counterfactual mediator block generates counterfactual mediators, which are subsequently used to train an inference model for estimating potential mediators in new samples; the outcome layer operates similarly. Once the model is trained, we predict potential outcomes using only the inferential mediator block and the inferential outcome block. Given the covariates and a target treatment level, the inferential mediator block predicts the mediator; this predicted mediator, together with the covariates and the target treatment level, is then fed into the inferential outcome block to obtain potential outcomes and thus estimate the ICEs of interest.

4.1 Counterfactual block in mediator layer

In this block, we first introduce the generator, \({\textbf {G}}_{{\textbf {M}}}: \mathbb {R}^{d_z}\times \mathcal {X}\times \mathcal {T}\times \mathcal {M}\mapsto \mathcal {M}^\mathcal {T}\), which takes the covariates \({\textbf {x}}\) (where \({\textbf {X}}={\textbf {x}}\)), the received treatment variable t (where \(T=t\)), the factual mediator m (where \(M=m\)), and some noise \({\textbf {Z}}\) as inputs. The output is a function from \(\mathcal {T}\) to \(\mathcal {M}\), i.e., \({\textbf {G}}_{{\textbf {M}}}({\textbf {Z}},{\textbf {x}},t,m)(\cdot ):\mathcal {T}\mapsto \mathcal {M}\), which is referred to as the treatment-mediator curve. We write \({\textbf {G}}_{{\textbf {M}}}({\textbf {Z}},{\textbf {x}},t,m)(t')\) for the generated counterfactual mediator at a target treatment level \(t'\), and denote the random variable induced by \({\textbf {G}}_{{\textbf {M}}}\) as \({\textbf {G}}_{{\textbf {M}}}({\textbf {Z}},{\textbf {X}},T,M)(t')\). The job of the counterfactual generator \({\textbf {G}}_{{\textbf {M}}}\) is to generate mediators for unobserved treatments, given \({\textbf {X}}={\textbf {x}}\) (feature vector), \(T=t\) (observed treatment) and \(M=m\) (factual mediator).

Next, we define a discriminator, \({\textbf {D}}_{{\textbf {M}}}\), which acts on a random set of points from the generated treatment-mediator curve to pick out the factual treatment from among the (random set of) generated ones. Let d be the number of treatment levels we aim to compare, and let \(\widetilde{T}=\{T_1,\ldots ,T_d\}\) be a random subset of \(\mathcal {T}\) of size d, which contains \(d-1\) random elements along with the factual treatment level T. Define \(\widetilde{{\textbf {M}}}=( \widetilde{M}_1, \widetilde{M}_2, \ldots ,\widetilde{M}_d)\in \mathcal {M}^d\) to be the vector of mediators corresponding to \({\widetilde{T}}\), where

$$\begin{aligned} \widetilde{M}_i=\left\{ \begin{array}{ll} M, & \quad \textrm{if} \,\, T_i=T\ \text {(factual mediator)}\\ {\textbf {G}}_{{\textbf {M}}}({\textbf {Z}},{\textbf {X}},T,M)(T_i), & \quad \textrm{if} \,\, T_i\ne T\ \text {(counterfactual mediator)}. \end{array} \right. \nonumber \\ \end{aligned}$$
(10)
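To make the construction concrete, a minimal PyTorch-style sketch of forming the comparison set \(\widetilde{T}\) and the mediator vector \(\widetilde{{\textbf {M}}}\) in (10) might look as follows. The callable g_m is a hypothetical stand-in for \({\textbf {G}}_{{\textbf {M}}}\) that accepts a vector of target treatment levels and returns the generated mediators at those levels; all quantities are torch tensors and the batch dimension is omitted.

```python
import torch

def build_mediator_vector(g_m, z, x, t_fact, m_fact, d):
    """Draw a comparison set T-tilde of size d containing the factual treatment and form the
    mediator vector M-tilde as in (10): the factual mediator at T, generated mediators elsewhere.

    g_m(z, x, t_fact, m_fact, t_targets) -> generated mediators at the target treatment levels.
    """
    t_random = torch.rand(d - 1)                      # d - 1 random treatment levels in [0, 1]
    t_tilde = torch.cat([t_fact.view(1), t_random])   # include the factual treatment level
    t_tilde = t_tilde[torch.randperm(d)]              # shuffle so position is uninformative

    m_gen = g_m(z, x, t_fact, m_fact, t_tilde)        # treatment-mediator curve evaluated on T-tilde
    is_factual = (t_tilde == t_fact)                  # indicator of T_j = T
    m_tilde = torch.where(is_factual, m_fact.expand(d), m_gen)
    return t_tilde, m_tilde, is_factual
```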

Then, the discriminator, \({\textbf {D}}_{{\textbf {M}}}: \mathcal {X}\times \mathcal {T}^d\times \mathcal {M}^d\mapsto [0,1]^d\), takes the covariates \({\textbf {x}}\) (where \({\textbf {X}}={\textbf {x}}\)), the generated mediators \(\widetilde{{\textbf {m}}}\) (where \(\widetilde{{\textbf {M}}}=\widetilde{{\textbf {m}}}\)), and the subset \({\widetilde{t}}=(t'_1,\ldots ,t'_d)\) (where \(\widetilde{T}=\widetilde{t}\)) as inputs. The ith component of the output represents the probability that the treatment level \(t'_i\) is the factual treatment, with \(t'_i\) being a member of the given realization of \(\widetilde{T}\). Let \({\textbf {D}}_{{\textbf {M}}}^j\) denote the output of \({\textbf {D}}_{{\textbf {M}}}\) corresponding to treatment level \(T_j\), then the loss function is

$$\begin{aligned} \mathcal {L}_{{\textbf {M}}}({{\textbf {G}}_{{\textbf {M}}}},{{\textbf {D}}_{{\textbf {M}}}})&:=\mathbb {E}\left\{ \sum _{j=1}^{d}\Big (\mathbb {I}_{\{T=T_j\}}\log {\textbf {D}}_{{\textbf {M}}}^j({\textbf {X}},{\widetilde{T}}, \widetilde{{\textbf {M}}})\right. \nonumber \\&\left. \quad + \mathbb {I}_{\{T\ne T_j\}}\log [1-{\textbf {D}}_{{\textbf {M}}}^j({\textbf {X}},{\widetilde{T}},\widetilde{{\textbf {M}}})]\Big )\right\} , \end{aligned}$$
(11)

where the expectation is taken with respect to \({\textbf {X}},\widetilde{T}, \widetilde{{\textbf {M}}}\) and T. Then, we solve the following minimax optimization problem: \( \min _{{\textbf {G}}_{{\textbf {M}}}}\max _{{\textbf {D}}_{{\textbf {M}}}} \mathcal {L}_{{\textbf {M}}}({{\textbf {G}}_{{\textbf {M}}}},{{\textbf {D}}_{{\textbf {M}}}}). \) At the population level, the target conditional generator and discriminator, \({\textbf {G}}_{{\textbf {M}}}^*\) and \({\textbf {D}}_{{\textbf {M}}}^*\), are obtained as solutions to the optimization problem:

$$\begin{aligned} ({\textbf {G}}_{{\textbf {M}}}^*, {\textbf {D}}_{{\textbf {M}}}^*)=\text {argmin}_{{\textbf {G}}_{{\textbf {M}}}}\text {argmax}_{{\textbf {D}}_{{\textbf {M}}}} \mathcal {L}_{{\textbf {M}}}({{\textbf {G}}_{{\textbf {M}}}},{{\textbf {D}}_{{\textbf {M}}}}). \end{aligned}$$
(12)

Regarding distribution matching, we show in Supplementary Material S1.1 that a function \({\textbf {G}}_{{\textbf {M}}}^*: \mathbb {R}^{d_z}\times \mathcal {X}\times \mathcal {T}\times \mathcal {M}\mapsto \mathcal {M}^\mathcal {T}\) is a minimizer of \(\mathbb {L}_{{\textbf {M}}}({\textbf {G}}_{{\textbf {M}}}):=\sup _{{\textbf {D}}_{{\textbf {M}}}}\mathcal {L}_{{\textbf {M}}}({{\textbf {G}}_{{\textbf {M}}}},{{\textbf {D}}_{{\textbf {M}}}})\) if and only if, for any treatment level \(T_j\in \mathcal {T}\), the generated counterfactual mediator \(\widetilde{M}_j\) has the same (marginal) distribution (conditional on the features) as the true marginal distribution of that mediator across all samples.

For the dataset \(\{{\textbf {X}}={\textbf {x}}_i,T=t_i,M=m_i\}_{i=1}^{n}\), independently and identically distributed according to \(P_{{\textbf {X}},T,M}\), \(\{{\textbf {Z}}={\textbf {z}}_i\}_{i=1}^{n}\) generated from \(P_{\textbf {Z}}\), and the given realization of \(\widetilde{T}: \widetilde{t}_i=\{t'_{i1},\ldots ,t'_{id}\}\) for \(i=1,\ldots ,n\), we define the sample set \(S^{M}_n:=\{{\textbf {X}}={\textbf {x}}_i,T=t_i, \widetilde{T}=\widetilde{t}_i, M=m_i,{\textbf {Z}}={\textbf {z}}_i\}_{i=1}^{n}\), which is used to train the estimated conditional generator \({\widehat{{\textbf {G}}}}_{{\textbf {M}}}\) in the counterfactual block in the mediator layer. Additionally, we have a given realization of \(\widetilde{{\textbf {M}}}\) denoted as \(\widetilde{{\textbf {m}}}_i=(\widetilde{m}_{i1},\ldots , \widetilde{m}_{id})\) for \(i=1,\ldots ,n\), where

$$\begin{aligned} \widetilde{m}_{ij}=\left\{ \begin{array}{ll} m_i, & \quad \textrm{if} \ t'_{ij}=t_i\\ {\textbf {G}}_{{\textbf {M}}}({\textbf {z}}_i,{\textbf {x}}_i,t_i,m_i)(t'_{ij}), & \quad \textrm{if} \ t'_{ij}\ne t_i. \end{array} \right. \end{aligned}$$
(13)

Now, we consider the empirical version of \(\mathcal {L}_{{\textbf {M}}}({\textbf {G}}_{\textbf {M}},{\textbf {D}}_{{\textbf {M}}})\), which is defined as follows:

$$\begin{aligned} \widetilde{\mathcal {L}}_{{\textbf {M}}}({{\textbf {G}}_{{\textbf {M}}}},{{\textbf {D}}_{{\textbf {M}}}})&=\frac{1}{n}\sum _{i=1}^{n}\left\{ \sum _{j=1}^{d}\Big (\mathbb {I}_{\{t_i=t'_{ij}\}}\log {\textbf {D}}_{{\textbf {M}}}^j({\textbf {x}}_i,\widetilde{t}_i,\widetilde{{\textbf {m}}}_i)\right. \\&\quad \left. +\mathbb {I}_{\{t_i\ne t'_{ij}\}}\log [1-{\textbf {D}}_{{\textbf {M}}}^j({\textbf {x}}_i,\widetilde{t}_i,\widetilde{{\textbf {m}}}_i)]\Big )\right\} . \end{aligned}$$

By the consistency assumption, we also introduce a supervised loss to enforce the restriction that \({\textbf {G}}_{{\textbf {M}}}({\textbf {z}},{\textbf {x}},t,m)(t)=m\): \( \widetilde{\mathcal {L}}_1({\textbf {G}}_{\textbf {M}})=\frac{1}{n}\sum _{i=1}^{n}\big | {\textbf {G}}_{{\textbf {M}}}({\textbf {z}}_i,{\textbf {x}}_i,t_i,m_i)(t_i)-m_i\big |^2. \) Then, define the following empirical objective function, for a supervised parameter \(\alpha _1\ge 0\):

$$\begin{aligned} \widehat{\mathcal {L}}_{{\textbf {M}}}({\textbf {G}}_{\textbf {M}},{\textbf {D}}_{{\textbf {M}}}):=\widetilde{\mathcal {L}}_{{\textbf {M}}}({\textbf {G}}_{\textbf {M}},{\textbf {D}}_{{\textbf {M}}})+\alpha _1\widetilde{\mathcal {L}}_1({\textbf {G}}_{\textbf {M}}). \end{aligned}$$
(14)
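A sketch of the empirical objective (14), reusing build_mediator_vector from the sketch above; d_m is a hypothetical discriminator network mapping \(({\textbf {x}},\widetilde{t},\widetilde{{\textbf {m}}})\) to d probabilities. The log terms implement \(\widetilde{\mathcal {L}}_{{\textbf {M}}}\) and the squared-error term \(\widetilde{\mathcal {L}}_1\).

```python
import torch

def mediator_objective(g_m, d_m, batch, d, d_z, alpha_1=1.0):
    """Empirical objective (14) on one mini-batch: adversarial term plus alpha_1 times
    the supervised consistency term."""
    adv, sup = 0.0, 0.0
    for x, t_fact, m_fact in batch:                    # batch of (x_i, t_i, m_i) tensors
        z = torch.randn(d_z)                           # noise Z ~ N(0, I_{d_z})
        t_tilde, m_tilde, is_factual = build_mediator_vector(g_m, z, x, t_fact, m_fact, d)

        probs = d_m(x, t_tilde, m_tilde)               # D_M output in [0, 1]^d
        labels = is_factual.float()                    # 1 at the factual treatment, 0 elsewhere
        adv = adv + (labels * torch.log(probs)
                     + (1 - labels) * torch.log(1 - probs)).sum()

        # supervised term: the generated curve at the factual treatment must reproduce m_i
        m_at_fact = g_m(z, x, t_fact, m_fact, t_fact.view(1)).squeeze()
        sup = sup + (m_at_fact - m_fact) ** 2

    n = len(batch)
    return adv / n + alpha_1 * sup / n
```

The generator \({\textbf {G}}_{{\textbf {M}}}\) is trained to minimize this quantity and the discriminator \({\textbf {D}}_{{\textbf {M}}}\) to maximize it; only the adversarial term depends on \({\textbf {D}}_{{\textbf {M}}}\).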

We utilize two feedforward neural networks (FNN) (Goodfellow et al. 2016) to estimate \({\textbf {G}}_{{\textbf {M}}}\) based on the empirical objective function \(\widehat{\mathcal {L}}_{{\textbf {M}}}({{\textbf {G}}_{{\textbf {M}}}},{{\textbf {D}}_{{\textbf {M}}}})\). The conditional generator network is denoted as \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}\), parameterized by \(\varvec{\theta }\), and the conditional discriminator network is denoted as \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}\), parameterized by \(\varvec{\phi }\). For any function \(f({\textbf {x}}): \mathcal {X}\rightarrow \mathbb {R}^d\), denote \(\left\| f\right\| _{L^\infty }=\sup _{{\textbf {x}}\in \mathcal {X}}\left\| f({\textbf {x}})\right\| \), where \(\left\| \cdot \right\| \) is the Euclidean norm.

Let \(\mathcal {G} \equiv \mathcal {G}_{\mathcal {H}, \mathcal {W}, \mathcal {S}, \mathcal {B}}\) be the set of the Exponential Linear Unit (ELU) neural networks \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}: \mathbb {R}^{d_z}\times \mathcal {X}\times \mathcal {T}\times \mathcal {M}\mapsto \mathcal {M}^\mathcal {T}\) parameterized by \(\varvec{\theta }\), with a depth \(\mathcal {H}\), width \(\mathcal {W}\), size \(\mathcal {S}\), and \(\left\| {\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}\right\| _{L^\infty } \le \mathcal {B}\). Here, the depth \(\mathcal {H}\) refers to the number of hidden layers, resulting in a total of \(\mathcal {H}+1\) layers. The width \(\mathcal {W}\) represents the maximum width of the hidden layers, and \(\mathcal {S}\) corresponds to the total number of parameters in the network.

Similarly, let \(\mathcal {D} \equiv \mathcal {D}_{\widetilde{\mathcal {H}}, \widetilde{\mathcal {W}}, \widetilde{\mathcal {S}}, \widetilde{\mathcal {B}}}\) be the set of ELU neural networks \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}: \mathcal {X}\times \mathcal {T}^d \times \mathcal {M}^d\mapsto [0,1]^d\), parameterized by \(\varvec{\phi }\), with a depth \(\widetilde{\mathcal {H}}\), width \(\widetilde{\mathcal {W}}\), size \(\widetilde{\mathcal {S}}\), and \(\left\| {\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}\right\| _{L^\infty } \le \widetilde{\mathcal {B}}\).
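As a concrete, purely illustrative instance of the classes \(\mathcal {G}\) and \(\mathcal {D}\), an ELU feedforward network could be written as follows; the depth, width, and bound shown are placeholder values rather than the settings used in our experiments.

```python
import torch
import torch.nn as nn

class EluMLP(nn.Module):
    """ELU feedforward network with `depth` hidden layers of width `width`; the output is
    passed through a sigmoid (for discriminators, giving values in [0, 1]^d) or clamped to
    [-bound, bound] (playing the role of the sup-norm constraint for generators)."""

    def __init__(self, in_dim, out_dim, depth=3, width=64, bound=None, sigmoid_out=False):
        super().__init__()
        layers, dim = [], in_dim
        for _ in range(depth):                    # depth hidden layers (H), so H + 1 layers in total
            layers += [nn.Linear(dim, width), nn.ELU()]
            dim = width
        layers.append(nn.Linear(dim, out_dim))    # output layer
        self.net = nn.Sequential(*layers)
        self.bound, self.sigmoid_out = bound, sigmoid_out

    def forward(self, x):                         # x: concatenation of the block's inputs
        h = self.net(x)
        if self.sigmoid_out:
            return torch.sigmoid(h)
        return h if self.bound is None else torch.clamp(h, -self.bound, self.bound)
```

For example, a generator that takes \(({\textbf {z}},{\textbf {x}},t,m)\) together with d target treatment levels could be instantiated with input dimension \(d_z+d_x+2+d\) and output dimension d, while a discriminator taking \(({\textbf {x}},\widetilde{t},\widetilde{{\textbf {m}}})\) would have input dimension \(d_x+2d\), output dimension d, and sigmoid_out=True.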

To estimate the parameters \(\varvec{\theta }\) and \(\varvec{\phi }\), we solve the following optimization problem:

$$\begin{aligned} (\widehat{\varvec{\theta }}, \widehat{\varvec{\phi }})=\text {argmin}_{\varvec{\theta }}\text {argmax}_{\varvec{\phi }}{\widehat{\mathcal {L}}_{{\textbf {M}}}({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }},{\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }})}. \end{aligned}$$
(15)

The resulting estimated conditional generator is denoted as \({\widehat{\textbf {G}}_{{\textbf {M}}}}={\textbf {G}}_{{\textbf {M}}}^{\widehat{\varvec{\theta }}}\), and the estimated discriminator is denoted as \({\widehat{\textbf {D}}_{{\textbf {M}}}}={\textbf {D}}_{{\textbf {M}}}^{\widehat{\varvec{\phi }}}\).
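The alternating optimization in (15) could then be carried out with a standard GAN-style training loop, sketched below under the assumption that mediator_objective and the networks g_m, d_m follow the interfaces of the earlier sketches; the optimizer choice, learning rate, and number of epochs are illustrative.

```python
import torch

def train_counterfactual_mediator_block(g_m, d_m, loader, d, d_z,
                                        alpha_1=1.0, epochs=100, lr=1e-4):
    """Alternating stochastic updates approximating argmin_theta argmax_phi of (14)."""
    opt_g = torch.optim.Adam(g_m.parameters(), lr=lr)
    opt_d = torch.optim.Adam(d_m.parameters(), lr=lr)
    for _ in range(epochs):
        for batch in loader:
            # discriminator ascent step: maximize the objective over phi
            opt_d.zero_grad()
            loss_d = -mediator_objective(g_m, d_m, batch, d, d_z, alpha_1)
            loss_d.backward()
            opt_d.step()

            # generator descent step: minimize the objective over theta
            opt_g.zero_grad()
            loss_g = mediator_objective(g_m, d_m, batch, d, d_z, alpha_1)
            loss_g.backward()
            opt_g.step()
    return g_m, d_m
```

In practice, one may take several discriminator steps per generator step or vice versa; the sketch uses a single update of each for simplicity.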

4.2 Inferential block in mediator layer

Once we have learned the counterfactual generator in the mediator layer, we can use it only to obtain (generated) mediators for the samples in the dataset. To generate potential mediators for a new sample, we use the counterfactual generator along with the original data to train an inferential network, whose generator is \(I_{{\textbf {M}}}({\widehat{{\textbf {Z}}}},{\textbf {X}},T'): \mathbb {R}^{d_z}\times \mathcal {X}\times \mathcal {T}\mapsto \mathcal {M}\). In this case, the generator \(I_{{\textbf {M}}}\) generates counterfactual mediators solely based on the given covariates \({\textbf {X}}\) and the target (unobserved) treatment levels \(T'\), without incorporating the factual mediator M or the received treatment T.

For the discriminator, similarly, d is the number of treatment levels we aim to compare, \(\widetilde{T}=\{T_1,\ldots ,T_d\}\) is a random subset of \(\mathcal {T}\) of size d, where it contains \(d-1\) random elements along with the factual treatment level T. Define \(\widetilde{{\textbf {I}}}_{{\textbf {M}}}=(I_{{\textbf {M}}}({\widehat{{\textbf {Z}}}},{\textbf {X}},T_1),\ldots ,I_{{\textbf {M}}}({\widehat{{\textbf {Z}}}},{\textbf {X}},T_d))\in \mathcal {M}^d\) to be the vector of mediators corresponding to the treatments \({\widetilde{T}}\). Then, the discriminator is \({\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}: \mathcal {X}\times \mathcal {T}^d\times \mathcal {M}^d \mapsto [0,1]^d\). Denote \({\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^j\) as the output of \({\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}\) corresponding to treatment level \(T_j\). We adapt the classical CGAN loss as follows:

$$\begin{aligned}&\mathcal {L}_{{\textbf {I}}{\textbf {M}}}({I_{{\textbf {M}}}},{{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}})=\mathbb {E}\Big \{\sum _{j=1}^{d}\log {\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^j({\textbf {X}},\widetilde{T},\widetilde{{\textbf {M}}})\Big \}\nonumber \\&\quad +\mathbb {E}_{{\textbf {I}}_{{\textbf {M}}}}\Big \{\sum _{j=1}^{d} \log [1-{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^j({\textbf {X}},{\widetilde{T}},\widetilde{{\textbf {I}}}_{{\textbf {M}}})]\Big \}, \end{aligned}$$
(16)

where the expectation \(\mathbb {E}\) is taken with respect to \({\textbf {X}},\widetilde{T}, \widetilde{{\textbf {M}}}\) and T, while \(\mathbb {E}_{{\textbf {I}}_{{\textbf {M}}}}\) is taken over \({\textbf {X}},\widetilde{T},\) and \(\widetilde{{\textbf {I}}}_{{\textbf {M}}}\). Here, \({\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^j({\textbf {X}}={\textbf {x}},{\widetilde{T}}={\widetilde{t}},\cdot ): \mathcal {M}^d \mapsto [0,1]\) represents the conditional probability that the jth component of the input came from \({\widetilde{M}}_j\) rather than \(I_{{\textbf {M}}}({\widehat{{\textbf {Z}}}},{\textbf {x}},t'_j)\), given \({\textbf {X}}={\textbf {x}}\) and \(\widetilde{T}=\widetilde{t}=(t'_1,\ldots ,t'_d)\).

At the population level, define \((I_{{\textbf {M}}}^*, {\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^*)=\text {argmin}_{I_{{\textbf {M}}}} \text {argmax}_{{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}} \mathcal {L}_{{\textbf {I}}{\textbf {M}}}({I_{{\textbf {M}}}},{{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}})\) and \(\mathbb {L}_{{\textbf {I}}{\textbf {M}}}(I_{{\textbf {M}}})=\sup _{{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}}\mathcal {L}_{{\textbf {I}}{\textbf {M}}}({I_{{\textbf {M}}}},{{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}})\), and let \(p_{{\widetilde{{\textbf {I}}}}_{{\textbf {M}}}}(\widetilde{{\textbf {I}}}_{{\textbf {M}}}|{\textbf {X}}={\textbf {x}})\) denote the joint density of mediators induced by \(I_{{\textbf {M}}}\) over the treatment levels in \(\widetilde{T}\). Based on the standard theory of CGAN (Goodfellow et al. 2014; Mirza and Osindero 2014), a function \(I_{{\textbf {M}}}^*: \mathbb {R}^{d_z}\times \mathcal {X}\times \mathcal {T}\mapsto \mathcal {M}\) is a minimizer of \(\mathbb {L}_{{\textbf {I}}{\textbf {M}}}(I_{{\textbf {M}}})\) if and only if the mediator generated by \(I_{{\textbf {M}}}\) for any sample has the same (marginal) distribution (conditional on the features) as the true marginal distribution of that mediator; refer to Supplementary Material S1.2 for more details.

Suppose we have a sample set \(S^{M}_n\), a realization of \(\widetilde{T}\) denoted as \(\widetilde{t}_i=\{t'_{i1},\ldots ,t'_{id}\}\), and another set of independently generated samples \(\{{\widehat{{\textbf {Z}}}}={\widehat{{\textbf {z}}}}_i\}_{i=1}^{n}\) from the distribution \(P_{{\widehat{{\textbf {Z}}}}}\). Using the estimated conditional generator \({{\widehat{{\textbf {G}}}}_{{\textbf {M}}}}\) obtained in the previous section, we define a new sample set \(S^{IM}_n:=\{({\textbf {x}}_i,t_i, \widetilde{t}_i,\widehat{{\textbf {m}}}_i, {\widehat{{\textbf {z}}}}_i)\}_{i=1}^n\) to train the estimated conditional generator \({{\widehat{I}}_{{\textbf {M}}}}\) in the inferential block, where \(\widehat{{\textbf {m}}}_i=(\widehat{m}_{i1},\ldots ,\widehat{m}_{id})\) for \(i=1,\ldots ,n\) and

$$\begin{aligned} \widehat{m}_{ij}=\left\{ \begin{aligned}&m_i, \ \ \textrm{if}\ t'_{ij}=t_i \\&{\widehat{{\textbf {G}}}}_{{\textbf {M}}}({\textbf {z}}_i,{\textbf {x}}_i,t_i,m_i)(t'_{ij}), \ \ \textrm{if} \ t'_{ij}\ne t_i. \end{aligned} \right. \end{aligned}$$
(17)
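As a brief sketch, the targets \(\widehat{{\textbf {m}}}_i\) in (17) could be assembled from the fitted counterfactual generator as follows; g_m_hat is a hypothetical callable standing in for \({\widehat{{\textbf {G}}}}_{{\textbf {M}}}\), with the same interface as in the earlier mediator-layer sketches.

```python
import torch

def build_inferential_targets(g_m_hat, x, t_fact, m_fact, t_tilde, z):
    """Targets m-hat_i as in (17): the factual mediator where t'_ij equals t_i,
    and the counterfactual mediator generated by the fitted G_M-hat elsewhere."""
    m_gen = g_m_hat(z, x, t_fact, m_fact, t_tilde)      # generated mediators at all targets
    is_factual = (t_tilde == t_fact)
    return torch.where(is_factual, m_fact.expand(t_tilde.shape[0]), m_gen)
```

These targets, together with the independent noise draws \({\widehat{{\textbf {z}}}}_i\), constitute \(S^{IM}_n\).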

Also, we have the realization of \(\widetilde{{\textbf {I}}}_{{\textbf {M}}}\), denoted as \(\widetilde{{\textbf {I}}}_{{\textbf {M}}}: \widetilde{{\textbf {I}}}_{{\textbf {M}}i}=(I_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}}_i,t'_{i1}),\ldots ,I_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}}_i,t'_{id}))\) for \(i=1,\ldots ,n\). We consider the following empirical version of \(\mathcal {L}_{{\textbf {I}}{\textbf {M}}}(I_{\textbf {M}},{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}})\):

$$\begin{aligned}&\widetilde{\mathcal {L}}_{{\textbf {I}}{\textbf {M}}}(I_{\textbf {M}},{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}};\widehat{\textbf {G}}_{{\textbf {M}}})=\frac{1}{n}\sum _{i=1}^{n}\Big \{\sum _{j=1}^{d}\Big (\log {\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^j({\textbf {x}}_i,\widetilde{t}_i,\widehat{{\textbf {m}}}_i)\\&\quad + \log [1-{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^j({\textbf {x}}_i,\widetilde{t}_i,\widetilde{{\textbf {I}}}_{{\textbf {M}}i})]\Big )\Big \}. \end{aligned}$$

Then, we use a supervised loss to ensure that \(I_{{\textbf {M}}}({\widehat{{\textbf {z}}}},{\textbf {x}},t)=m\): \(\widetilde{\mathcal {L}}_2(I_{\textbf {M}})=\frac{1}{n}\sum _{i=1}^{n}\big |I_{{\textbf {M}}}(\widehat{\textbf {z}}_i,{\textbf {x}}_i,t_i)-m_i\big |^2\). Define the following empirical objective function, for a supervised parameter \(\alpha _2\ge 0\):

$$\begin{aligned} \widehat{\mathcal {L}}_{{\textbf {I}}{\textbf {M}}}(I_{\textbf {M}},{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}};\widehat{\textbf {G}}_{{\textbf {M}}}):=\widetilde{\mathcal {L}}_{{\textbf {I}}{\textbf {M}}}(I_{\textbf {M}},{\textbf {D}}_{{\textbf {I}}_{\textbf {M}}};\widehat{\textbf {G}}_{{\textbf {M}}})+\alpha _2\widetilde{\mathcal {L}}_2(I_{\textbf {M}}). \end{aligned}$$
(18)
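A sketch of the empirical inferential objective (18); here \(\widehat{{\textbf {m}}}_i\) plays the role of the "real" sample and the output of the inferential generator the "fake" one. The callables i_m and d_im are hypothetical stand-ins for \(I_{{\textbf {M}}}\) and \({\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}\).

```python
import torch

def inferential_mediator_objective(i_m, d_im, batch, d_z, alpha_2=1.0):
    """Empirical objective (18): CGAN term comparing m-hat with the inferential generator's
    output, plus alpha_2 times the supervised consistency term."""
    adv, sup = 0.0, 0.0
    for x, t_fact, t_tilde, m_fact, m_hat in batch:     # one tuple per sample in S^IM_n
        z_hat = torch.randn(d_z)                        # noise Z-hat ~ N(0, I_{d_z})
        i_tilde = torch.stack([i_m(z_hat, x, t_j) for t_j in t_tilde])  # I_M at each target level

        adv = adv + torch.log(d_im(x, t_tilde, m_hat)).sum() \
                  + torch.log(1 - d_im(x, t_tilde, i_tilde)).sum()

        # supervised term: at the factual treatment, I_M should reproduce the observed mediator
        sup = sup + (i_m(z_hat, x, t_fact) - m_fact) ** 2

    n = len(batch)
    return adv / n + alpha_2 * sup / n
```

As in the counterfactual block, \(I_{{\textbf {M}}}\) is trained to minimize this objective while \({\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}\) maximizes it; the outcome-layer blocks in Sects. 4.3 and 4.4 admit entirely analogous sketches with the mediator appended to the inputs.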

Again, we use FNNs to estimate \(I_{{\textbf {M}}}\) based on (18), in a manner similar to the estimation of \({\textbf {G}}_{{\textbf {M}}}\); see details in Supplementary Material S1.5. The conditional generator network \(I_{{\textbf {M}}}^{\varvec{\psi }}\) is parameterized by \(\varvec{\psi }\), and the conditional discriminator network \({\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^{\varvec{\omega }}\) is parameterized by \(\varvec{\omega }\). The estimated conditional generator is \({\widehat{I}_{{\textbf {M}}}}\), and the estimated discriminator is \({{\widehat{{\textbf {D}}}}_{{\textbf {I}}_{\textbf {M}}}}\).

4.3 Counterfactual block in outcome layer

Similar to Sect. 4.1, we denote the generator as \({\textbf {G}}_{{\textbf {Y}}}: \mathbb {R}^{d_z}\times \mathcal {X}\times \mathcal {M}\times \mathcal {T}\times \mathcal {Y}\mapsto \mathcal {Y}^ \mathcal {T}\), taking the covariates \({\textbf {x}}\) (where \({\textbf {X}}={\textbf {x}}\)), the factual mediator m (where \(M=m\)), the received treatment variable t (where \(T=t\)), the factual outcome y (where \(Y=y\)), and some noise \({\widetilde{{\textbf {Z}}}}\) as inputs. The output is a function from \(\mathcal {T}\) to \(\mathcal {Y}\), i.e., \({\textbf {G}}_{{\textbf {Y}}}({\widetilde{{\textbf {Z}}}},{\textbf {x}},m,t,y)(\cdot ):\mathcal {T}\mapsto \mathcal {Y}\). We write \({\textbf {G}}_{{\textbf {Y}}}({\widetilde{{\textbf {Z}}}},{\textbf {x}},m,t,y)(t')\) as our generated counterfactual outcome for different treatment levels \(t'\) and denote \({\textbf {G}}_{{\textbf {Y}}}({\widetilde{{\textbf {Z}}}},{\textbf {X}},M,T,Y)(t')\) (the random variable induced by \({\textbf {G}}_{{\textbf {Y}}}\)).

The discriminator, denoted as \({\textbf {D}}_{{\textbf {Y}}}\), aims to identify the factual treatment among the set of generated treatments. Let \(\widetilde{T}=\{T_1,\ldots ,T_d\}\) be a random subset of \(\mathcal {T}\) with size d, where it consists of \(d-1\) randomly selected elements along with the factual treatment level T. Additionally, let \(\widetilde{{\textbf {Y}}}=( \widetilde{Y}_1,\ldots ,\widetilde{Y}_d)\in \mathcal {Y}^d\) denote the vector of outcomes corresponding to the treatments \({\widetilde{T}}\), where

$$\begin{aligned} \widetilde{Y}_i=\left\{ \begin{array}{ll} Y, & \quad \textrm{if} \ T_i=T\ \text {(factual outcome)}\\ {\textbf {G}}_{{\textbf {Y}}}({\widetilde{{\textbf {Z}}}},{\textbf {X}},M,T,Y)(T_i), & \quad \textrm{if}\ T_i\ne T\ \text {(counterfactual outcome)}. \end{array} \right. \nonumber \\ \end{aligned}$$
(19)

Then, the discriminator, \({\textbf {D}}_{{\textbf {Y}}}: \mathcal {X}\times \mathcal {M}\times \mathcal {T}^d\times \mathcal {Y}^d\mapsto [0,1]^d\), takes the covariates \({\textbf {x}}\) (where \({\textbf {X}}={\textbf {x}}\)), the factual mediator m (where \(M=m\)), the generated outcomes \(\widetilde{{\textbf {y}}}\) (where \(\widetilde{{\textbf {Y}}}=\widetilde{{\textbf {y}}}\)), and the subset \({\widetilde{t}}=(t'_1,\ldots ,t'_d)\) (where \(\widetilde{T}=\widetilde{t}\)) as inputs. The ith component of the output represents the probability that the treatment level \(t'_i\) is the factual treatment, where \(t'_i\) belongs to the given realization of \(\widetilde{T}\). Let \({\textbf {D}}_{{\textbf {Y}}}^j\) denote the output of \({\textbf {D}}_{{\textbf {Y}}}\) corresponding to treatment level \(T_j\), then the loss function is

$$\begin{aligned} \mathcal {L}_{{\textbf {Y}}}({{\textbf {G}}_{{\textbf {Y}}}},{{\textbf {D}}_{{\textbf {Y}}}})&:=\mathbb {E}\left\{ \sum _{j=1}^{d}\Big (\mathbb {I}_{\{T=T_j\}}\log {\textbf {D}}_{{\textbf {Y}}}^j({\textbf {X}},M,\widetilde{T},\widetilde{{\textbf {Y}}})\right. \\&\quad \left. + \mathbb {I}_{\{T\ne T_j\}}\log [1-{\textbf {D}}_{{\textbf {Y}}}^j({\textbf {X}},M,\widetilde{T},\widetilde{{\textbf {Y}}})]\Big )\right\} . \end{aligned}$$

At the population level, define \(({\textbf {G}}_{{\textbf {Y}}}^*, {\textbf {D}}_{{\textbf {Y}}}^*):=\text {argmin}_{{\textbf {G}}_{{\textbf {Y}}}}\text {argmax}_{{\textbf {D}}_{{\textbf {Y}}}} \mathcal {L}_{{\textbf {Y}}}({{\textbf {G}}_{{\textbf {Y}}}},{{\textbf {D}}_{{\textbf {Y}}}}).\) Regarding distribution matching, we can show that \({\textbf {G}}_{{\textbf {Y}}}^*\) minimizes \(\sup _{{\textbf {D}}_{{\textbf {Y}}}}\mathcal {L}_{{\textbf {Y}}}({{\textbf {G}}_{{\textbf {Y}}}},{{\textbf {D}}_{{\textbf {Y}}}})\) if and only if, for any treatment level \(T_j\in \mathcal {T}\), the generated counterfactual outcome \(\widetilde{Y}_j\) for any sample has the same (marginal) distribution (conditional on the features and the mediator) as the true marginal distribution of that outcome; refer to Supplementary Material S1.3 for more details.

For the dataset \(S_n=\{{\textbf {X}}={\textbf {x}}_i,T=t_i,M=m_i, Y= y_i\}_{i=1}^{n}\), \(\{{\widetilde{{\textbf {Z}}}}={\widetilde{{\textbf {z}}}}_i\}_{i=1}^{n}\) independently generated from \(P_{\widetilde{\textbf {Z}}}\), and the given realization of \(\widetilde{T}: \widetilde{t}_i=\{t'_{i1},\ldots ,t'_{id}\}\) for \(i=1,\ldots ,n\), we define the sample set \(S^{Y}_n:=\{{\textbf {X}}={\textbf {x}}_i,T=t_i,\widetilde{T}=\widetilde{t}_i,M=m_i, Y= y_i,{\widetilde{{\textbf {Z}}}}={\widetilde{{\textbf {z}}}}_i\}_{i=1}^{n}\), which is used to train the estimated conditional generator \({{\widehat{{\textbf {G}}}}_{{\textbf {Y}}}}\) in the counterfactual block in outcome layer. Also, we have a given realization of \(\widetilde{{\textbf {Y}}}\), denoted as \(\widetilde{{\textbf {y}}}_i=(\widetilde{y}_{i1},\ldots ,\widetilde{y}_{id})\) for \(i=1,\ldots ,n\), where

$$\begin{aligned} \widetilde{y}_{ij}=\left\{ \begin{array}{ll} y_i, & \quad \textrm{if} \ t'_{ij} =t_i \\ {\textbf {G}}_{{\textbf {Y}}}({\widetilde{{\textbf {z}}}}_i,{\textbf {x}}_i,m_i,t_i,y_i)(t'_{ij}), & \quad \textrm{if} \ t'_{ij} \ne t_i. \end{array} \right. \end{aligned}$$
(20)

Now, we consider the empirical version of \(\mathcal {L}_{{\textbf {Y}}}({{\textbf {G}}_{{\textbf {Y}}}},{{\textbf {D}}_{{\textbf {Y}}}})\), which is defined as follows:

$$\begin{aligned}&\widetilde{\mathcal {L}}_{{\textbf {Y}}}({{\textbf {G}}_{{\textbf {Y}}}},{{\textbf {D}}_{{\textbf {Y}}}})\\&\quad =\frac{1}{n}\sum _{i=1}^{n}\Big \{\sum _{j=1}^{d}\Big (\mathbb {I}_{\{t_i=t'_{ij}\}}\log {\textbf {D}}_{{\textbf {Y}}}^j({\textbf {x}}_i,m_i,\widetilde{t}_i,\widetilde{{\textbf {y}}}_i)\\&\qquad + \mathbb {I}_{\{t_i\ne t'_{ij}\}}\log [1-{\textbf {D}}_{{\textbf {Y}}}^j({\textbf {x}}_i,m_i,\widetilde{t}_i,\widetilde{{\textbf {y}}}_i)]\Big )\Big \}. \end{aligned}$$

We also introduce a supervised loss to ensure that \({\textbf {G}}_{{\textbf {Y}}}(\widetilde{{\textbf {z}}},{\textbf {x}},m,t,y)(t)=y\): \( \widetilde{\mathcal {L}}_3({\textbf {G}}_{\textbf {Y}})=\frac{1}{n}\sum _{i=1}^{n}\big | {\textbf {G}}_{{\textbf {Y}}}({\widetilde{{\textbf {z}}}}_i,{\textbf {x}}_i,m_i,t_i,y_i)(t_i)-y_i\big |^2. \) Then, we define the empirical objective function with a supervised parameter \(\alpha _3 \ge 0\):

$$\begin{aligned} \widehat{\mathcal {L}}_{{\textbf {Y}}}({\textbf {G}}_{\textbf {Y}},{\textbf {D}}_{{\textbf {Y}}}):=\widetilde{\mathcal {L}}_{{\textbf {Y}}}({\textbf {G}}_{\textbf {Y}},{\textbf {D}}_{{\textbf {Y}}})+\alpha _3\widetilde{\mathcal {L}}_3({\textbf {G}}_{\textbf {Y}}). \end{aligned}$$
(21)

We again use FNN to estimate \({\textbf {G}}_{\textbf {Y}}\) based on the empirical objective function (21). The details are presented in Supplementary Material S1.6. The conditional generator network \({\textbf {G}}^{\varvec{\zeta }}_{\textbf {Y}}\) is parameterized by \(\varvec{\zeta }\), and the conditional discriminator network \({\textbf {D}}^{\varvec{\xi }}_{\textbf {Y}}\) is parameterized by \(\varvec{\xi }\). The estimated conditional generator is \({{\widehat{{\textbf {G}}}}}_{\textbf {Y}}\) and the estimated discriminator is \({\widehat{{\textbf {D}}}}_{\textbf {Y}}\).

4.4 Inferential block in outcome layer

Similar to the approach described in Sect. 4.2 and using similar notation, we denote the generator as \(I_{{\textbf {Y}}}(\overline{{\textbf {Z}}},{\textbf {X}},M,T'): \mathbb {R}^{d_z}\times \mathcal {X}\times \mathcal {M}\times \mathcal {T}\mapsto \mathcal {Y}\). In this case, the generator generates counterfactual outcomes solely based on the provided covariates, factual mediator, and the target treatment levels, without incorporating factual outcome or the received treatment. \(\widetilde{{\textbf {I}}}_{{\textbf {Y}}}=(I_{{\textbf {Y}}}(\overline{{\textbf {Z}}},{\textbf {X}},M,T_1),\ldots ,I_{{\textbf {Y}}}(\overline{{\textbf {Z}}},{\textbf {X}}, M,T_d)) \in \mathcal {Y}^d\) represents the vector of outcomes corresponding to the treatments \({\widetilde{T}}\). Then, the discriminator is \({\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}: \mathcal {X}\times \mathcal {M}\times \mathcal {T}^d\times \mathcal {Y}^d \mapsto [0,1]^d\). Using \({\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^j\) to denote the output of \({\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}\) corresponding to treatment level \(T_j\), we denote the classical CGAN loss for training by

$$\begin{aligned} \mathcal {L}_{{\textbf {I}}{\textbf {Y}}}({I_{{\textbf {Y}}}},{{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}})&=\mathbb {E}\left\{ \sum _{j=1}^{d}\log {\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^j({\textbf {X}},M,\widetilde{T},\widetilde{{\textbf {Y}}})\right\} \\&\quad +\mathbb {E}_{{\textbf {I}}_{{\textbf {Y}}}}\left\{ \sum _{j=1}^{d} \log [1-{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^j({\textbf {X}},M,{\widetilde{T}},\widetilde{{\textbf {I}}}_{{\textbf {Y}}})]\right\} . \end{aligned}$$

In this expression, the expectation \(\mathbb {E}\) is taken over \({\textbf {X}},M, \widetilde{T}, \widetilde{{\textbf {Y}}},\) and T, while the expectation \(\mathbb {E}_{{\textbf {I}}_{{\textbf {Y}}}}\) is taken over \({\textbf {X}}, M, \widetilde{T},\) and \(\widetilde{{\textbf {I}}}_{{\textbf {Y}}}\). Here, \({\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^j({\textbf {X}}={\textbf {x}},M=m,{\widetilde{T}}={\widetilde{t}},\cdot ): \mathcal {Y}^d \mapsto [0,1]\) represents the conditional probability that the jth component of the input came from \({\widetilde{Y}}_j\) rather than \(I_{{\textbf {Y}}}(\overline{{\textbf {Z}}},{\textbf {x}},m, t'_j)\), given \({\textbf {X}}={\textbf {x}}\), \(M=m\), and \(\widetilde{T}={\widetilde{t}}=(t'_1,\ldots ,t'_d)\).

Define \((I_{{\textbf {Y}}}^*, {\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^*):=\text {argmin}_{I_{{\textbf {Y}}}}\text {argmax}_{{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}} \mathcal {L}_{{\textbf {I}}{\textbf {Y}}}({I_{{\textbf {Y}}}},{{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}})\) and \(\mathbb {L}_{{\textbf {I}}{\textbf {Y}}}(I_{{\textbf {Y}}}):=\sup _{{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}}\mathcal {L}_{{\textbf {I}}{\textbf {Y}}}({I_{{\textbf {Y}}}},{{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}})\), and let \(p_{{\widetilde{{\textbf {I}}}}_{{\textbf {Y}}}}({\widetilde{{\textbf {I}}}}_{{\textbf {Y}}}|{\textbf {X}}={\textbf {x}},M=m)\) denote the joint density of outcomes induced by \(I_{{\textbf {Y}}}\) over the treatment levels in \(\widetilde{T}\). By the standard theory of CGAN (Mirza and Osindero 2014), a function \(I_{{\textbf {Y}}}^*\) is a minimizer of \(\mathbb {L}_{{\textbf {I}}{\textbf {Y}}}(I_{{\textbf {Y}}})\) if and only if the outcome generated by \(I_{{\textbf {Y}}}\) for any sample has the same (marginal) distribution (conditional on the features and the mediator) as the true marginal distribution of that outcome; refer to Supplementary Material S1.4 for more details.

Given the sample set \(S^{Y}_n\), a realization \(\widetilde{T}: \widetilde{t}_i=\{t'_{i1},\ldots ,t'_{id}\}\), and independently generated \(\{\overline{{\textbf {Z}}}=\overline{{\textbf {z}}}_i\}_{i=1}^{n}\) from \(P_{\overline{{\textbf {Z}}}}\), we can use \(\widehat{\textbf {G}}_{{\textbf {Y}}}\) obtained in the previous section to define another sample set \(S^{IY}_n:=\{({\textbf {x}}_i,t_i,\widetilde{t}_i,m_i,\widehat{{\textbf {y}}}_i, \overline{{\textbf {z}}}_i)\}_{i=1}^{n}\) to train the estimated conditional generator \({{\widehat{I}}_{{\textbf {Y}}}}\) in the inferential block. Here, \(\widehat{{\textbf {y}}}_i=(\widehat{y}_{i1},\ldots ,\widehat{y}_{id})\) for \(i=1,\ldots ,n\), and

$$\begin{aligned} \widehat{y}_{ij}=\left\{ \begin{array}{ll} y_i, & \quad \textrm{if} \ t'_{ij}=t_i \\ {\widehat{{\textbf {G}}}}_{{\textbf {Y}}}({\widetilde{{\textbf {z}}}}_i,{\textbf {x}}_i,m_i,t_i,y_i)(t'_{ij}), & \quad \textrm{if} \ t'_{ij}\ne t_i. \end{array} \right. \end{aligned}$$
(22)

For the given realization of \(\widetilde{{\textbf {I}}}_{{\textbf {Y}}}: \widetilde{{\textbf {I}}}_{{\textbf {Y}}i}=(I_{{\textbf {Y}}}(\overline{{\textbf {z}}}_i,{\textbf {x}}_i,m_i,t'_{i1}),\ldots ,I_{{\textbf {Y}}}(\overline{{\textbf {z}}}_i,{\textbf {x}}_i,m_i,t'_{id}))\) for \(i=1,\ldots ,n\), we consider the following empirical version of the objective function \(\mathcal {L}_{{\textbf {I}}{\textbf {Y}}}(I_{\textbf {Y}},{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}})\):

$$\begin{aligned} \widetilde{\mathcal {L}}_{{\textbf {I}}{\textbf {Y}}}(I_{\textbf {Y}},{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}};\widehat{\textbf {G}}_{{\textbf {Y}}})&=\frac{1}{n}\sum _{i=1}^{n}\left\{ \sum _{j=1}^{d}\Big (\log {\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^j({\textbf {x}}_i,m_i,{\widetilde{t}}_i,\widehat{{\textbf {y}}}_i)\right. \\&\left. \quad + \log [1-{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^j({\textbf {x}}_i,m_i, {\widetilde{t}}_i,\widetilde{{\textbf {I}}}_{{\textbf {Y}}i})]\Big )\right\} . \end{aligned}$$

To enforce the restriction that \(I_{{\textbf {Y}}}(\overline{{\textbf {z}}},{\textbf {x}},m,t)=y\), we introduce the supervised loss: \( \widetilde{\mathcal {L}}_4({\textbf {I}}_{\textbf {Y}})=\frac{1}{n}\sum _{i=1}^{n}\big | I_{{\textbf {Y}}}(\overline{{\textbf {z}}}_i,{\textbf {x}}_i,m_i,t_i)-y_i\big |^2. \) Finally, we define an empirical objective function with a supervised parameter \(\alpha _4\ge 0\):

$$\begin{aligned} \begin{aligned} \widehat{\mathcal {L}}_{{\textbf {I}}{\textbf {Y}}}(I_{\textbf {Y}},{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}};\widehat{\textbf {G}}_{{\textbf {Y}}}):=\widetilde{\mathcal {L}}_{{\textbf {I}}{\textbf {Y}}}(I_{\textbf {Y}},{\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}};\widehat{\textbf {G}}_{{\textbf {Y}}})+\alpha _4\widetilde{\mathcal {L}}_4({\textbf {I}}_{\textbf {Y}}). \end{aligned} \end{aligned}$$
(23)
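To make the construction in (22) and the empirical objective (23) concrete, the following is a minimal PyTorch sketch. It assumes the estimated counterfactual generator G_Y_hat, the inferential generator I_Y, and the discriminator D_IY are callables on batched tensors that return one column per treatment level in \(\widetilde{T}\); these signatures, the broadcasting conventions, and the small constant inside the logarithms are illustrative assumptions rather than the paper's implementation.

```python
import torch

def pseudo_outcomes(G_Y_hat, z_tilde, x, m, t, y, t_grid):
    """Build y_hat as in Eq. (22): keep the factual outcome at the observed
    treatment level and fill the remaining levels with counterfactual draws
    from the estimated generator."""
    # t, y: (n, 1); t_grid: (n, d) treatment levels in T~
    y_cf = G_Y_hat(z_tilde, x, m, t, y, t_grid)          # (n, d) counterfactual outcomes
    factual = (t_grid == t).float()                      # indicator of t'_{ij} = t_i
    return factual * y + (1.0 - factual) * y_cf

def inferential_outcome_loss(I_Y, D_IY, z_bar, x, m, t, y, t_grid, y_hat, alpha4=1.0):
    """Empirical objective (23): adversarial term over the d components plus
    the supervised term alpha4 * |I_Y(z_bar, x, m, t) - y|^2."""
    i_y = I_Y(z_bar, x, m, t_grid)                       # (n, d) generated outcomes
    d_real = D_IY(x, m, t_grid, y_hat)                   # (n, d) probabilities on y_hat
    d_fake = D_IY(x, m, t_grid, i_y)                     # (n, d) probabilities on I_Y draws
    adv = (torch.log(d_real + 1e-8)
           + torch.log(1.0 - d_fake + 1e-8)).sum(dim=1).mean()
    factual = (t_grid == t).float()                      # t_i is one of the levels in T~
    i_y_factual = (factual * i_y).sum(dim=1, keepdim=True)
    sup = ((i_y_factual - y) ** 2).mean()
    return adv + alpha4 * sup
```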

Likewise, we use FNN to estimate \(I_{{\textbf {Y}}}\) based on (23); the details are presented in Supplementary Material S1.7. The conditional generator network \(I^{\varvec{\varphi }}_{\textbf {Y}}\) is parameterized by \(\varvec{\varphi }\), and the conditional discriminator network \({\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^{\varvec{\lambda }}\) is parameterized by \(\varvec{\lambda }\). The estimated conditional generator and discriminator are denoted by \({{\widehat{I}}_{{\textbf {Y}}}}\) and \({{\widehat{{\textbf {D}}}}_{{\textbf {I}}_{\textbf {Y}}}}\), respectively.

5 Convergence

This section establishes that, in the mediator layer, the conditional distribution of \(\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {Z}}}},{\textbf {X}}={\textbf {x}},T=t)\) converges to the distribution of \(M({\textbf {X}}={\textbf {x}},T=t)\) in the total variation norm as n approaches infinity, for any target value \(t\in \mathcal {T}\) and \({\textbf {X}}={\textbf {x}}\) with \(p_{{\textbf {X}}}({\textbf {x}})>0\). We then show that, in the outcome layer, for a given \({\textbf {X}}={\textbf {x}}\) and \(M=m\) with \(p_{{\textbf {X}},{\textbf {M}}}({\textbf {x}},m)>0\), the distribution of \(\widehat{I}_{{\textbf {Y}}}(\overline{{\textbf {Z}}},{\textbf {X}}={\textbf {x}},T=t,M=m)\) converges in the same norm to the distribution of \(Y({\textbf {X}}={\textbf {x}},T=t,M=m)\).

We begin by presenting the theoretical result for the counterfactual block of the mediator layer, which allows us to derive the convergence result of the inferential block. Then, we provide the final result for the mediator and outcome layers. Supplementary Material S2 and S3 provide regularity conditions and the detailed proofs of the theoretical results.

5.1 Convergence in mediator layer

Denote \(\widehat{{\textbf {M}}}=( \widehat{M}_1, \widehat{M}_2, \ldots ,\widehat{M}_d)\in \mathcal {M}^d\) as the vector of mediators corresponding to \({\widetilde{T}}\), where

$$\begin{aligned} \widehat{M}_i=\left\{ \begin{array}{ll} M, & \quad \textrm{if} \ T_i=T\\ {\widehat{{\textbf {G}}}}_{{\textbf {M}}}({\textbf {Z}},{\textbf {X}},T,M)(T_i), & \quad \textrm{if} \ T_i\ne T. \end{array} \right. \end{aligned}$$
(24)

Referring to the notation below equation (S1) in the Supplementary Material, we define \(p_{\widehat{{\textbf {M}}}|{\textbf {X}},T}(\widehat{{\textbf {M}}}={\textbf {m}}|{\textbf {X}}={\textbf {x}},T=T_j):=p_{M|{\textbf {X}},T}(M=m_j|{\textbf {X}}={\textbf {x}},T=T_j)p_{{\widehat{{\textbf {G}}}}_{\textbf {M}}}( \widehat{{\textbf {M}}}'_j={\textbf {m}}'_j|{\textbf {X}}={\textbf {x}},M=m_j,T=T_j)\) as the joint density of mediators given the factual treatment level \(T=T_j\), where \(p_{M|{\textbf {X}},T}(M=m_j|{\textbf {X}}={\textbf {x}},T=T_j)\) denotes the true density that generated the observed mediator and \(p_{{\widehat{{\textbf {G}}}}_{\textbf {M}}}( \widehat{{\textbf {M}}}'_j={\textbf {m}}'_j|{\textbf {X}}={\textbf {x}},M=m_j,T=T_j)\) denotes the density induced by \({\widehat{{\textbf {G}}}}_{\textbf {M}}\) over the remaining treatment levels in \(\widetilde{T}\).

For any two probability measures \(P_{{\textbf {Q}}}\) and \(P_{{\textbf {W}}}\) with probability densities \(p_{{\textbf {Q}}}\) and \(p_{{\textbf {W}}}\), \(\Vert P_{{\textbf {Q}}}-P_{{\textbf {W}}}\Vert _{TV}=\frac{1}{2} \Vert p_{{\textbf {Q}}}-p_{{\textbf {W}}}\Vert _{L^1}\), where \(\Vert p_{{\textbf {Q}}}-p_{{\textbf {W}}}\Vert _{L^1}:=\int _{\mathcal {X}\times \mathcal {M}\times \mathcal {M}}|p_{{\textbf {Q}}}({\textbf {q}})-p_{{\textbf {W}}}({\textbf {q}})|d{\textbf {q}}\). Thus, convergence of two probability distributions in the total variation norm is equivalent to their convergence in the \(L^1\) norm. Accordingly, we establish the following convergence as the sample size n tends to infinity for almost all \({\textbf {x}}\in \mathcal {X}\), \(\widetilde{T}=(T_1,\ldots ,T_d)\in \mathcal {T}^d\), and all \(i,j\in \{1,\ldots ,d\}\):

$$\begin{aligned}&\Vert p_{\widehat{{\textbf {M}}}|{\textbf {X}},T}(\widehat{{\textbf {M}}} |{\textbf {X}}={\textbf {x}},T=T_j)\\&\quad -p_{\widehat{{\textbf {M}}}|{\textbf {X}},T}(\widehat{{\textbf {M}}} |{\textbf {X}}={\textbf {x}},T=T_i)\Vert _{L^1}\rightarrow 0. \end{aligned}$$

Theorem 1

Under the assumptions (A.1), (A.2), (B.1) and (B.2), for almost all \({\textbf {x}}\in \mathcal {X}\), \({\widetilde{T}}\in \mathcal {T}^d\) and all \(i,j\in \{1,\ldots ,d\}\), the following statement is valid:

$$\begin{aligned}&\mathbb {E}_{S^M_n}\Vert p_{\widehat{{\textbf {M}}}|{\textbf {X}},T}(\widehat{{\textbf {M}}} |{\textbf {X}}={\textbf {x}},T=T_j)\nonumber \\&-p_{\widehat{{\textbf {M}}}|{\textbf {X}},T}(\widehat{{\textbf {M}}} |{\textbf {X}}={\textbf {x}},T=T_i)\Vert _{L^1}^2\rightarrow 0, \ \ \text {as}\ \ n\rightarrow \infty . \end{aligned}$$
(25)

Given \({\widehat{{\textbf {G}}}}_{{\textbf {M}}} \in \mathcal {G} \equiv \mathcal {G}_{\mathcal {H}, \mathcal {W}, \mathcal {S}, \mathcal {B}}\), let \(p_{{\widehat{{\textbf {I}}}}_{{\textbf {M}}}}(\widehat{{\textbf {I}}}_{{\textbf {M}}}={\textbf {m}}|{\textbf {X}}={\textbf {x}})\) denote the joint density of mediators induced by \({\widehat{I}}_{{\textbf {M}}}\), the estimated conditional generator, over the treatment levels in \(\widetilde{T}\). We aim to prove the following statement: for almost all \({\textbf {x}}\in \mathcal {X}\), \({\widetilde{T}}\in \mathcal {T}^d\) and all \(j\in \{1,\ldots ,d\}\),

$$\begin{aligned}&\Vert p_{\widehat{{\textbf {M}}}|{\textbf {X}},T}(\widehat{{\textbf {M}}} |{\textbf {X}}={\textbf {x}},T=T_j)\\&\quad -p_{\widehat{\textbf {I}}_{{\textbf {M}}}}(\widehat{{\textbf {I}}}_{{\textbf {M}}} |{\textbf {X}}={\textbf {x}})\Vert _{L^1}\rightarrow 0\text { as }n\rightarrow \infty . \end{aligned}$$

Theorem 2

Under the assumptions (A.3), (A.4), (B.3) and (B.4), for almost all \({\textbf {x}}\in \mathcal {X}\), \({\widetilde{T}}\in \mathcal {T}^d\) and all \(i\in \{1,\ldots ,d\}\), the following statement is valid:

$$\begin{aligned}&\mathbb {E}_{S^{IM}_n}\Vert p_{\widehat{{\textbf {M}}}|{\textbf {X}},T}(\widehat{{\textbf {M}}} |{\textbf {X}}={\textbf {x}},T=T_i)\nonumber \\&\quad -p_{{\widehat{{\textbf {I}}}}_{{\textbf {M}}}}({\widehat{{\textbf {I}}}}_{{\textbf {M}}} |{\textbf {X}}={\textbf {x}})\Vert _{L^1}\rightarrow 0, \ \ \text {as}\ \ n\rightarrow \infty . \end{aligned}$$
(26)

By combining the results from Theorems 1 and 2, we can derive the following theorem:

Theorem 3

Under the assumptions (A.1)–(A.4) and (B.1)–(B.4), we have

$$\begin{aligned}&\mathbb {E}_{S^M_n\cup \{{\widehat{{\textbf {z}}}}_i\}_{i=1}^n} \Vert p_{M|{\textbf {X}},T}(M|{\textbf {X}}={\textbf {x}},T=t)\nonumber \\&\quad -p_{{\widehat{I}}_{{\textbf {M}}}|{\textbf {X}},T}({\widehat{I}}_{{\textbf {M}}}(\widehat{\textbf {Z}},{\textbf {X}}={\textbf {x}},T=t))\Vert _{L^1}\rightarrow 0, \ \ \text {as}\ \ n\rightarrow \infty , \end{aligned}$$
(27)

for almost all \({\textbf {x}}\in \mathcal {X}\) and \(t\in \mathcal {T}\); that is, as the sample size n tends to infinity, the distribution of \(\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {Z}}}},{\textbf {x}},t)\) converges to \(P_{M|{\textbf {X}}={\textbf {x}},T=t}\).

5.2 Convergence in outcome layer

By employing a similar argument to that used in the mediator layer, under the assumptions (C.1)–(C.4) and (D.1)–(D.4) we can derive the following theorem:

Theorem 4

Under the assumptions (C.1)–(C.4) and (D.1)–(D.4), for almost all \({\textbf {x}}\in \mathcal {X}\), \(m\in \mathcal {M}\), and \(t\in \mathcal {T}\), the following statement holds:

$$\begin{aligned}&\mathbb {E}_{S^Y_n\cup \{\overline{{\textbf {z}}}_i\}_{i=1}^n} \Vert p_{Y|{\textbf {X}},T,M}(Y|{\textbf {X}}={\textbf {x}},T=t,M=m)\nonumber \\&\quad -p_{{\widehat{I}}_{{\textbf {Y}}}|{\textbf {X}},T,M}({\widehat{I}}_{{\textbf {Y}}}(\overline{{\textbf {Z}}},{\textbf {X}}={\textbf {x}},T=t,M=m))\Vert _{L^1}\rightarrow 0, \nonumber \\&\quad \text {as}\ \ n\rightarrow \infty . \end{aligned}$$
(28)

That is, as the sample size n tends to infinity, the distribution of \(\widehat{I}_{{\textbf {Y}}}(\overline{{\textbf {Z}}},{\textbf {X}}={\textbf {x}},T=t,M=m)\) converges to \(P_{Y|{\textbf {X}}={\textbf {x}},T=t,M=m}\).

6 Architecture and implementation

6.1 Generator architecture

We adopt a multi-task deep learning model (Bica et al. 2020; Zaheer et al. 2017) for the generators \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}, I_{{\textbf {M}}}^{\varvec{\psi }}, {\textbf {G}}^{\varvec{\zeta }}_{\textbf {Y}}\), and \( I_{{\textbf {Y}}}^{\varvec{\varphi }}\). To illustrate, we focus on \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}\) and define a function \(g_M: \mathbb {R}^{d_z}\times \mathcal {X}\times \mathcal {M}\times \mathcal {T}\mapsto \mathcal {H}\), where \(\mathcal {H}\) is typically a latent space such as \(\mathbb {R}^{l}\) for some l. We then introduce a “head” function \(g_{M_t}: \mathcal {H}\times \mathcal {T}\mapsto \mathcal {M}\), which takes inputs from \(\mathcal {H}\) and a treatment level \(t'\) to produce a mediator value \({\widehat{m}}(t')\in \mathcal {M}\). Given observations \(({\textbf {x}},t,m)\), a noise vector \({\textbf {z}}\), and a target treatment level \(t'\), we define the generator \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}({\textbf {z}},{\textbf {x}},t,m)(t')\) as follows:

$$\begin{aligned} {\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}({\textbf {z}},{\textbf {x}},t,m)(t')=g_{M_t}(g_M({\textbf {z}},{\textbf {x}},t,m),t'), \end{aligned}$$
(29)

where \(g_M\) and \(g_{M_t}\) are fully connected networks. A visual representation of our generator architecture, with \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}\) as an example, is provided in Fig. 2 (left panel). The inferential network \(I_{{\textbf {M}}}^{\varvec{\psi }}\) shares the same architecture as \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}\), but does not receive t and m as inputs.
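As a concrete illustration of (29), a minimal PyTorch sketch of the shared body \(g_M\) and the treatment-specific head \(g_{M_t}\) is given below. The layer widths, depth, and the way the target treatment level \(t'\) enters the head are illustrative assumptions; only the overall structure, a shared representation followed by a head that consumes \(t'\), follows the text.

```python
import torch
import torch.nn as nn

class MediatorGenerator(nn.Module):
    """G_M^theta(z, x, t, m)(t') = g_{M_t}(g_M(z, x, t, m), t'), cf. Eq. (29)."""

    def __init__(self, dim_z, dim_x, dim_latent=64, dim_hidden=64):
        super().__init__()
        # shared body g_M: (z, x, t, m) -> latent space H
        self.body = nn.Sequential(
            nn.Linear(dim_z + dim_x + 2, dim_hidden), nn.ELU(),
            nn.Linear(dim_hidden, dim_latent), nn.ELU(),
        )
        # head g_{M_t}: (latent representation, t') -> mediator value
        self.head = nn.Sequential(
            nn.Linear(dim_latent + 1, dim_hidden), nn.ELU(),
            nn.Linear(dim_hidden, 1),
        )

    def forward(self, z, x, t, m, t_prime):
        # all inputs are 2-D tensors of shape (batch, dim); t, m, t_prime have dim 1
        h = self.body(torch.cat([z, x, t, m], dim=1))
        return self.head(torch.cat([h, t_prime], dim=1))
```

In this sketch, the inferential generator \(I_{{\textbf {M}}}^{\varvec{\psi }}\) would reuse the same structure with t and m removed from the body's input.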

Fig. 2: Architecture of generators (e.g., \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}\)) and discriminators (e.g., \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}\))

6.2 Discriminator architecture

We follow Bica et al. (2020) and Zaheer et al. (2017) to construct permutation equivariant deep networks, ensuring that the discriminators treat their inputs as sets. Let \({\mathcal {U}}, {\mathcal {V}}, {\mathcal {C}}\) be certain spaces, and let \(b\in {\mathbb {Z}}^+\). We define permutation equivariance as follows:

Definition 1

A function \(f: {\mathcal {U}}^b\times {\mathcal {V}} \mapsto \mathcal {C}^b\) is permutation equivariant with respect to the space \({\mathcal {U}}^b\) if, for any \({{\textbf {u}}}\in {\mathcal {U}}^b\), \(v\in {\mathcal {V}}\), and permutation \(\sigma \) of \(1,\ldots , b\), we have \(f(u_{\sigma (1)},\ldots , u_{\sigma (b)},v)=(f_{\sigma (1)}({\textbf {u}},v),\ldots , f_{\sigma (b)}({{\textbf {u}}},v))\), where \(f_j({{\textbf {u}}},v)\) denotes the jth element of \(f({{\textbf {u}}},v)\).

The composition of two permutation equivariant functions remains permutation equivariant. For our equivariant functions, we utilize a basic building block defined in terms of an equivariant input \({\textbf {u}}\) and an auxiliary input \({{\textbf {v}}}\):

$$\begin{aligned} f_{equi}({{\textbf {u}}},{{\textbf {v}}})=\sigma (\lambda {{\textbf {I}}}_b{\textbf {u}}+\gamma ({{\textbf {1}}}_b{{\textbf {1}}}_b^T){\textbf {u}}+({{\textbf {1}}}_b\Theta ^T){{\textbf {v}}}), \end{aligned}$$
(30)

where \({{\textbf {I}}}_b\) is a \(b\times b\) identity matrix, \(\lambda \) and \(\gamma \) are scalar parameters, and \(\Theta \) is a vector of weights.

We aim for \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}, {\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^{\varvec{\omega }}, {\textbf {D}}^{\varvec{\xi }}_{\textbf {Y}}\), and \( {\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^{\varvec{\lambda }}\) to possess permutation equivariance. We use \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}\) for illustration. To achieve equivariance with respect to \(({\widetilde{t}},{\widetilde{{\textbf {m}}}})\), \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}\) consists of two layers in the form given by Equation (30). The equivariance input \({{\textbf {u}}}\) for the first layer is \(({\widetilde{t}},{\widetilde{{\textbf {m}}}})\), while for the second layer, it is the output of the first layer. The auxiliary input \({{\textbf {v}}}\) for the first layer is \({\textbf {x}}\), and there is no auxiliary input for the second layer. Figure 2 (right panel) presents a diagram illustrating the architectures of the discriminators, including \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}\).
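The following is a minimal PyTorch sketch of the building block (30) and of a two-layer discriminator in the spirit of \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}\). Stacking the \((t'_j, m_j)\) pairs as two-dimensional set elements, the sigmoid output head, and the parameter initializations are illustrative assumptions.

```python
import torch
import torch.nn as nn

class EquivariantLayer(nn.Module):
    """Building block of Eq. (30):
    f_equi(u, v) = sigma(lambda*u + gamma*(1 1^T)u + (1 Theta^T)v),
    applied along the set dimension of u, hence permutation equivariant."""

    def __init__(self, dim_u, dim_v=0, activation=nn.ELU()):
        super().__init__()
        self.lam = nn.Parameter(torch.tensor(1.0))    # scalar lambda
        self.gam = nn.Parameter(torch.tensor(0.1))    # scalar gamma
        # Theta maps the auxiliary input to the per-element feature dimension
        self.theta = nn.Linear(dim_v, dim_u, bias=False) if dim_v > 0 else None
        self.act = activation

    def forward(self, u, v=None):
        # u: (batch, b, dim_u) set input; v: (batch, dim_v) auxiliary input
        pooled = u.sum(dim=1, keepdim=True)           # (1_b 1_b^T) u
        out = self.lam * u + self.gam * pooled
        if self.theta is not None and v is not None:
            out = out + self.theta(v).unsqueeze(1)    # broadcast (1_b Theta^T) v
        return self.act(out)


class MediatorDiscriminator(nn.Module):
    """Two equivariant layers: the first takes (t~, m~) as the set input and x
    as the auxiliary input, the second takes the first layer's output with no
    auxiliary input; a per-element linear map yields the d probabilities."""

    def __init__(self, dim_x):
        super().__init__()
        self.layer1 = EquivariantLayer(dim_u=2, dim_v=dim_x)  # elements are (t'_j, m_j) pairs
        self.layer2 = EquivariantLayer(dim_u=2)
        self.out = nn.Linear(2, 1)

    def forward(self, t_tilde, m_tilde, x):
        u = torch.stack([t_tilde, m_tilde], dim=-1)            # (batch, d, 2)
        h = self.layer2(self.layer1(u, x))
        return torch.sigmoid(self.out(h)).squeeze(-1)          # (batch, d)
```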

6.3 Implementation

To train the generators \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}, I_{{\textbf {M}}}^{\varvec{\psi }}, {\textbf {G}}^{\varvec{\zeta }}_{\textbf {Y}}, I_{{\textbf {Y}}}^{\varvec{\varphi }}\) and the discriminators \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}, {\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^{\varvec{\omega }}, {\textbf {D}}^{\varvec{\xi }}_{\textbf {Y}}, {\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^{\varvec{\lambda }}\), we use the Exponential Linear Unit (ELU) as the activation function. We train the discriminators and generators iteratively by updating \(\varvec{\theta }, \varvec{\psi }, \varvec{\zeta }, \varvec{\varphi }, \varvec{\phi }, \varvec{\omega }, \varvec{\xi }\), and \(\varvec{\lambda }\) as follows:

  • Keep \(\varvec{\theta }\) fixed, update the discriminator \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}\) by ascending the stochastic gradient of the loss (14) with respect to \(\varvec{\phi }\).

  • Keep \(\varvec{\phi }\) fixed, update the generator \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}\) by descending the stochastic gradient of the loss (14) with respect to \(\varvec{\theta }\).

  • Keep \(\varvec{\zeta }\) fixed, update the discriminator \({\textbf {D}}^{\varvec{\xi }}_{\textbf {Y}}\) by ascending the stochastic gradient of the loss (21) with respect to \(\varvec{\xi }\).

  • Keep \(\varvec{\xi }\) fixed, update the generator \({\textbf {G}}^{\varvec{\zeta }}_{\textbf {Y}}\) by descending the stochastic gradient of the loss (21) with respect to \(\varvec{\zeta }\).

  • Keep \(\varvec{\psi }\) fixed, update the discriminator \({\textbf {D}}_{{\textbf {I}}_{\textbf {M}}}^{\varvec{\omega }}\) by ascending the stochastic gradient of the loss (18) with respect to \(\varvec{\omega }\).

  • Keep \(\varvec{\omega }\) fixed, update the generator \(I_{{\textbf {M}}}^{\varvec{\psi }}\) by descending the stochastic gradient of the loss (18) with respect to \(\varvec{\psi }\).

  • Keep \(\varvec{\varphi }\) fixed, update the discriminator \({\textbf {D}}_{{\textbf {I}}_{\textbf {Y}}}^{\varvec{\lambda }}\) by ascending the stochastic gradient of the loss (23) with respect to \(\varvec{\lambda }\).

  • Keep \(\varvec{\lambda }\) fixed, update the generator \( I_{{\textbf {Y}}}^{\varvec{\varphi }}\) by descending the stochastic gradient of the loss (23) with respect to \(\varvec{\varphi }\).

The algorithm for training CGAN-ICMA-CT is given below.

Algorithm (figures a and b): training procedure for CGAN-ICMA-CT
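The alternating updates above can be written as a short training loop. The sketch below is a minimal schematic for one generator/discriminator pair (for instance, \({\textbf {G}}_{{\textbf {M}}}^{\varvec{\theta }}\) and \({\textbf {D}}_{{\textbf {M}}}^{\varvec{\phi }}\) with loss (14)); the remaining pairs are handled in the same way with their respective losses. The function loss_fn, the optimizer settings, and the batching are illustrative assumptions, not the paper's exact algorithm.

```python
import torch

def train_pair(generator, discriminator, loss_fn, data_loader,
               n_epochs=100, lr=1e-3):
    """Alternate gradient updates for one (generator, discriminator) pair:
    ascend the loss in the discriminator parameters, descend it in the
    generator parameters, as in the bullet list above."""
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr)
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr)

    for _ in range(n_epochs):
        for batch in data_loader:
            # discriminator step: gradient ascent on the loss
            opt_d.zero_grad()
            (-loss_fn(generator, discriminator, batch)).backward()
            opt_d.step()

            # generator step: gradient descent on the loss
            opt_g.zero_grad()
            loss_fn(generator, discriminator, batch).backward()
            opt_g.step()
    return generator, discriminator
```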

Once the model is trained, we can estimate \(\theta _{t,t'}(t; {\textbf {x}}), \theta _{t,t'}(t'; {\textbf {x}}), \delta _{t,t'}(t'; {\textbf {x}}), \delta _{t,t'}(t; {\textbf {x}})\), and \(\tau _{t,t'}({\textbf {x}})\) using the obtained inferential generators \({\widehat{I}}_{{\textbf {M}}}\) and \({\widehat{I}}_{{\textbf {Y}}}\), based on Equations (7), (8), and (9), given the covariates \({\textbf {x}}_e\) with \({\textbf {X}}={\textbf {x}}_e\) and the target treatment levels t and \(t'\). This estimation process can be described in three steps.

Step 1: Estimated Inferential Mediator Generator (\(\widehat{I}_{{\textbf {M}}}\)):

  • Sample \({\widehat{{\textbf {Z}}}}\sim P_{{\widehat{{\textbf {Z}}}}}\) to obtain \(\{{\widehat{{\textbf {z}}}}_1, {\widehat{{\textbf {z}}}}_2, \ldots , {\widehat{{\textbf {z}}}}_{{\widehat{n}}}\}\).

  • Feed \({\textbf {x}}_e\), \(\{{\widehat{{\textbf {z}}}}_i\}_{i=1}^{{\widehat{n}}}\), and t or \(t'\) into the inferential mediator generator (\({\widehat{I}}_{{\textbf {M}}}\)) to predict \(\{\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}}_e,t)\}_{i=1}^{{\widehat{n}}}\) and \(\{\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}}_e,t')\}_{i=1}^{{\widehat{n}}}\).

Step 2: Estimated Inferential Outcome Generator (\(\widehat{I}_{{\textbf {Y}}}\)):

  • Sample \(\overline{{\textbf {Z}}}\sim P_{\overline{{\textbf {Z}}}}\) to obtain \(\{\overline{{\textbf {z}}}_1, \overline{{\textbf {z}}}_2, \ldots , \overline{{\textbf {z}}}_{\overline{n}}\}\).

  • Use the predicted values \(\{\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}}_e,t)\}_{i=1}^{{\widehat{n}}}\) or \(\{\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}}_e,t')\}_{i=1}^{{\widehat{n}}}\), covariates \({\textbf {x}}_e\), noise \(\{\overline{{\textbf {z}}}_i\}_{i=1}^{\overline{n}}\), and the target treatment t or \(t'\) as inputs and feed them into the inferential outcome generator (\(\widehat{I}_{{\textbf {Y}}}\)) to generate outcome samples \({\widehat{I}}_{{\textbf {Y}}}(\overline{{\textbf {z}}}_j,{\textbf {x}}_e,t,\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}}_e,t))\), \({\widehat{I}}_{{\textbf {Y}}}(\overline{{\textbf {z}}}_j,{\textbf {x}}_e,t',\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}}_e,t))\), \({\widehat{I}}_{{\textbf {Y}}}(\overline{{\textbf {z}}}_j,{\textbf {x}}_e,t, \widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}}_e,t'))\), and \({\widehat{I}}_{{\textbf {Y}}}(\overline{{\textbf {z}}}_j,{\textbf {x}}_e,t',\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_i,{\textbf {x}}_e,t'))\), for \(i\in \{1,\ldots , {\widehat{n}}\}\), and \(j\in \{1,\ldots , \overline{n}\}\).

Step 3: By utilizing these outcome samples, we can obtain estimates of \(\theta _{t,t'}(t; {\textbf {x}}), \theta _{t,t'}(t'; {\textbf {x}})\), \(\delta _{t,t'}(t'; {\textbf {x}}), \delta _{t,t'}(t; {\textbf {x}})\), and \(\tau _{t,t'}({\textbf {x}})\) based on Equations (7), (8), and (9), given the covariates \({\textbf {x}}_e\) and the target treatment levels t and \(t'\).
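The three steps can be summarized in a short routine. The sketch below assumes the trained inferential generators are callables I_M_hat(z, x, t) and I_Y_hat(z, x, t, m); the noise dimensions, the loop-based Monte Carlo averaging, and the generic labeling of the contrasts (the exact correspondence with \(\theta _{t,t'}\), \(\delta _{t,t'}\), and \(\tau _{t,t'}\) follows Equations (7)–(9), which are not reproduced here) are illustrative.

```python
import torch

def estimate_ices(I_M_hat, I_Y_hat, x_e, t, t_prime,
                  dim_z_hat, dim_z_bar, n_hat=20, n_bar=20):
    """Monte Carlo estimates of E[Y(x_e, a, M_b(x_e))] for a, b in {t, t'},
    following Steps 1-3; the direct, indirect, and total effects are simple
    contrasts of these four quantities."""
    z_hat = torch.randn(n_hat, dim_z_hat)     # Step 1: noise for the mediator generator
    z_bar = torch.randn(n_bar, dim_z_bar)     # Step 2: noise for the outcome generator

    def mean_outcome(a, b):
        # average of I_Y_hat over both noise sources, mediator drawn at level b
        total = 0.0
        for zh in z_hat:
            m = I_M_hat(zh, x_e, b)
            for zb in z_bar:
                total = total + I_Y_hat(zb, x_e, a, m)
        return total / (n_hat * n_bar)

    y_tt, y_ttp = mean_outcome(t, t), mean_outcome(t, t_prime)
    y_tpt, y_tptp = mean_outcome(t_prime, t), mean_outcome(t_prime, t_prime)

    # direct-effect contrasts fix the mediator level and vary the treatment;
    # indirect-effect contrasts fix the treatment and vary the mediator level
    nde_at_m_tprime = y_ttp - y_tptp   # E[Y(t, M_t')] - E[Y(t', M_t')]
    nde_at_m_t = y_tt - y_tpt          # E[Y(t, M_t)]  - E[Y(t', M_t)]
    nie_at_t = y_tt - y_ttp            # E[Y(t, M_t)]  - E[Y(t, M_t')]
    nie_at_tprime = y_tpt - y_tptp     # E[Y(t', M_t)] - E[Y(t', M_t')]
    total_effect = y_tt - y_tptp       # E[Y(t, M_t)]  - E[Y(t', M_t')]
    return nde_at_m_tprime, nde_at_m_t, nie_at_t, nie_at_tprime, total_effect
```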

7 Simulation study

We conduct simulation studies to evaluate the empirical performance of our approach for estimating ICEs and compare our method with four other approaches: linear regression (LR), support vector machine (SVM), decision tree (DT), and random forest (RF); see details about the competing methods in Supplementary Material S4. We use four metrics to assess the performance: \(\sqrt{\textrm{MISE}_\mathrm{NDE_1}}\), \(\sqrt{\textrm{MISE}_\mathrm{NDE_2}}\), \(\sqrt{\textrm{MISE}_\mathrm{NIE_1}}\), and \(\sqrt{\textrm{MISE}_\mathrm{NIE_2}}\). Their definitions are given in Supplementary Material S5. Hyperparameters for the network in this simulation are reported in Supplementary Material S6.

We consider data generating processes similar to those in Huber et al. (2020) and Huang et al. (2024). The covariate X is generated from the uniform distribution U[1, 4]. The continuous treatment variable is generated as \(T = 0.2x+\epsilon _1\), where \(\epsilon _1\sim U(-0.2, 0.2)\). The mediator and outcome are given as follows:

$$\begin{aligned}&M(x, t)=0.5x^2+2t+xt+\epsilon _2,\\&Y(x, t,m_t(x))=0.3x+0.3t+0.25tx^2+0.1m^2_t(x)\\&\qquad +0.5tm^2_t(x)+\epsilon _3, \end{aligned}$$

where M(x, t) is a scalar mediator, \(Y(x, t,m_t(x))\) is a scalar outcome, \(\epsilon _2\) and \(\epsilon _3\) are error terms following \(U(-0.2, 0.2)\), and \(\epsilon _1, \epsilon _2\), and \(\epsilon _3\) are independent of X and of each other.
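For reference, this data-generating process can be reproduced with a few lines of NumPy; the seed and the vectorized sampling are incidental choices.

```python
import numpy as np

def simulate(n=2000, seed=0):
    """Draw (X, T, M, Y) from the simulation design described above."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(1.0, 4.0, size=n)                      # X ~ U[1, 4]
    t = 0.2 * x + rng.uniform(-0.2, 0.2, size=n)           # T = 0.2 X + eps1
    m = 0.5 * x**2 + 2.0 * t + x * t + rng.uniform(-0.2, 0.2, size=n)
    y = (0.3 * x + 0.3 * t + 0.25 * t * x**2
         + 0.1 * m**2 + 0.5 * t * m**2
         + rng.uniform(-0.2, 0.2, size=n))
    return x, t, m, y
```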

We generate samples of size \(n = 2000\) and \(n = 4000\) from the above settings and use 90% of each sample for training (a training rate of 0.9). Based on 100 generated datasets, we conduct the analysis using the proposed method and the four competing methods. Table 1 reports the average values of the four metrics, along with the corresponding standard deviations (std), where “linear" and “rbf" indicate that SVM uses the linear and Gaussian kernel functions, respectively.

Table 1 Performance of five methods for estimating ICEs

As shown in Table 1, our method consistently outperforms the other methods, achieving the smallest values for the averaged \(\sqrt{\textrm{MISE}_\mathrm{NDE_1}},\sqrt{\textrm{MISE}_\mathrm{NDE_2}},\sqrt{\textrm{MISE}_\mathrm{NIE_1}}\), and \(\sqrt{\textrm{MISE}_\mathrm{NIE_2}}\), suggesting that our method estimates ICEs more accurately than the four others. Except for LR and SVM (linear kernel), all the methods improve with larger sample sizes.

Furthermore, we repeat the above analysis by varying the hyperparameters and training rates while keeping \(n = 2000\). Supplementary Material S7 presents the results, reaffirming that our proposed method consistently outperforms the others in estimating ICEs under various scenarios.

Notably, our main objective is to assess the empirical performance of CGAN-ICMA-CT by comparing it with existing methods. However, since no specific methods are available to address the problem we are investigating, we conduct the above simulation studies and introduce the four alternative methods above for comparison. It is important to note that the methods used for comparison do not successfully identify the ICEs. This limitation arises because they rely on separate regressions for the mediator and outcome layers and incorporate the predicted expected value of the mediator into the outcome regression model for prediction. In contrast, our proposed method, CGAN-ICMA-CT, as illustrated by Eq. (4), takes a different approach. It adopts a sampling-based method, drawing multiple values from the estimated probability distribution of the potential mediator, \(P_{M({\textbf {x}}, t)}\), to estimate potential outcomes. This fundamental difference allows our method to capture a wider range of possibilities and uncertainties in the estimation process, providing a more comprehensive analysis of the causal mediation effects.

8 Application: job corps

This section applies the proposed method to the analysis of the Job Corps dataset. Job Corps is a publicly funded training program in the US specifically designed to assist economically disadvantaged young individuals between the ages of 16 and 24. The program provides participants, who must be legal US residents, with around 1200 h of vocational training and education, together with housing and boarding services, typically spanning an average duration of 8 months. Schochet et al. (2008) indicated that Job Corps participation increases educational attainment and reduces criminal activity. Previous research on the Job Corps program has predominantly utilized binary treatment definitions, categorizing individuals as either participants or non-participants. In contrast to this binary treatment approach, Huber et al. (2020) and Huang et al. (2024) used the total number of hours participants spent in academic or vocational classes over 12 months as a continuous treatment variable (T). However, their studies estimated average causal effects and did not explore potential heterogeneity. By contrast, our study examines the individualized causal mechanism to understand how the causal effects vary with observable characteristics, which can assist policymakers in designing more efficient intervention programs.

We focus on 4,000 individuals who received a positive treatment intensity, denoted as \(T>0\); see a detailed description in Huber et al. (2020). We adopt the same continuous treatment variable, mediator, and outcome as in Huber et al. (2020). The mediator variable (M) is the proportion of weeks employed in the second year following participation. The outcome variable (Y) is the number of times the police arrested an individual in the fourth year after completing the program. The pre-treatment covariates (\({\textbf {X}}\)) include age, gender, ethnicity, language competency, education level, marital status, household size and income, times in prison before participating in the program, previous receipt of social aid, family background (such as parents’ education), health status and health-related behaviors at baseline, participants’ expectations, and interactions with recruiters. The variables in the analysis are normalized using the formula (variable − minimum value) / (maximum value − minimum value), which scales each variable, including the treatment levels, to [0, 1]. This normalization provides a standardized scale, enhances the visibility of effects, and suits the neural-network-based estimation.
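The min-max normalization used here is the standard rescaling to [0, 1]; a small NumPy helper suffices.

```python
import numpy as np

def min_max_normalize(v):
    """(variable - minimum value) / (maximum value - minimum value)."""
    v = np.asarray(v, dtype=float)
    return (v - v.min()) / (v.max() - v.min())
```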

We employ a stratified random split to divide the original dataset S of 4,000 samples into ten mutually exclusive folds \(S_1, \dots , S_{10}\) with equal sizes, each containing 400 samples. In each round \(k \in \{1, \dots , 10\}\), our model is trained on the training set \(S\backslash S_k\) for 5,000 iterations, utilizing the same network hyperparameters as in the simulation. The trained model then generates predictions for the remaining samples in the testing set \(S_k\). To ensure robustness, we repeat this process ten times and calculate the average values of the predicted outcomes. The resulting predictions are presented below.

8.1 Heterogeneous causal effects

We set the benchmark treatment level at \(t'=40\), which represents a relatively low intensity of 40 h participants spent in academic or vocational classes over 12 months. Using the notations defined in Supplementary Material S6, we predict the values of \(\mathbb {E}[Y({\textbf {x}}_{ei},t, M_{t}({\textbf {x}}_{ei}))]\), \(\mathbb {E}[Y({\textbf {x}}_{ei},t, M_{t'}({\textbf {x}}_{ei}))]\), \(\mathbb {E}[Y({\textbf {x}}_{ei},t', M_{t}({\textbf {x}}_{ei}))]\), and \(\mathbb {E}[Y({\textbf {x}}_{ei},t', M_{t'}({\textbf {x}}_{ei}))]\) for \(t'=40\) and \(t \in \{100,200,\ldots ,1900,2000\}\). Denote the predicted conditional samples as \(\big \{{\widehat{I}}_{\textbf {Y}}(\overline{{\textbf {z}}}_{ej},{\textbf {x}}_{ei},t,{\widehat{I}}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_{eh},{\textbf {x}}_{ei},t)),\ {\widehat{I}}_{\textbf {Y}}(\overline{{\textbf {z}}}_{ej},{\textbf {x}}_{ei},t,\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_{eh},{\textbf {x}}_{ei},t')),\ {\widehat{I}}_{\textbf {Y}}(\overline{{\textbf {z}}}_{ej},{\textbf {x}}_{ei},t',\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_{eh},{\textbf {x}}_{ei},t)),\ {\widehat{I}}_{\textbf {Y}}(\overline{{\textbf {z}}}_{ej},{\textbf {x}}_{ei},t',\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_{eh},{\textbf {x}}_{ei},t')):\ h = 1, \ldots, {\widehat{n}}_e,\ j = 1,\ldots , \overline{n}_e\big \}\) for \(t'=40\) and \(t \in \{100,200,\ldots ,1900,2000\}\). Here, we also take \({\widehat{n}}_e=\overline{n}_e=20\).

$$\begin{aligned}&{\widehat{\mathbb {E}}}[Y({\textbf {x}}_{ei},t, M_t({\textbf {x}}_{ei}))]\\&\quad =\frac{1}{\overline{n}_e\times {\widehat{n}}_e}\left( \sum _{j=1}^{\overline{n}_e}\sum _{h=1}^{{\widehat{n}}_e}{\widehat{I}}_{\textbf {Y}}(\overline{{\textbf {z}}}_{ej},{\textbf {x}}_{ei},t,\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_{eh},{\textbf {x}}_{ei},t)) \right) ,\\&{\widehat{\mathbb {E}}}[Y({\textbf {x}}_{ei},t, M_{t'}({\textbf {x}}_{ei}))]\\&\quad =\frac{1}{\overline{n}_e\times {\widehat{n}}_e}\left( \sum _{j=1}^{\overline{n}_e}\sum _{h=1}^{{\widehat{n}}_e}{\widehat{I}}_{\textbf {Y}}(\overline{{\textbf {z}}}_{ej},{\textbf {x}}_{ei},t,\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_{eh},{\textbf {x}}_{ei},t')) \right) ,\\&{\widehat{\mathbb {E}}}[Y({\textbf {x}}_{ei},t', M_t({\textbf {x}}_{ei}))]\\&\quad =\frac{1}{\overline{n}_e\times {\widehat{n}}_e}\left( \sum _{j=1}^{\overline{n}_e}\sum _{h=1}^{{\widehat{n}}_e}{\widehat{I}}_{\textbf {Y}}(\overline{{\textbf {z}}}_{ej},{\textbf {x}}_{ei},t',\widehat{I}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_{eh},{\textbf {x}}_{ei},t)) \right) ,\\&{\widehat{\mathbb {E}}}[Y({\textbf {x}}_{ei},t', M_{t'}({\textbf {x}}_{ei}))] \\&\quad =\frac{1}{\overline{n}_e\times \widehat{n}_e}\left( \sum _{j=1}^{\overline{n}_e}\sum _{h=1}^{{\widehat{n}}_e}{\widehat{I}}_{\textbf {Y}}(\overline{{\textbf {z}}}_{ej},{\textbf {x}}_{ei},t',{\widehat{I}}_{{\textbf {M}}}({\widehat{{\textbf {z}}}}_{eh},{\textbf {x}}_{ei},t')) \right) , \end{aligned}$$

for \(i=1,2,\ldots ,4000\). Given the predicted values, we can estimate the individualized NDE and NIE, denoted as \({\widehat{\theta }}_{t,t'}(t; {\textbf {x}}_{ei}), {\widehat{\delta }}_{t,t'}(t; {\textbf {x}}_{ei}), {\widehat{\theta }}_{t,t'}(t'; {\textbf {x}}_{ei})\), and \({\widehat{\delta }}_{t,t'}(t'; {\textbf {x}}_{ei})\), for \(i = 1, \ldots , 4,000\). We now consider covariate-specific groups to examine group causal effects. Due to the large number (65) of covariates in our study, we randomly select one covariate from each variable type (binary, numeric, and categorical) to perform subgroup analysis.

8.1.1 Subgroup analysis for binary covariates

We select gender as an example of binary covariates. By averaging the individualized NDE and NIE estimates for males and females separately, we obtain the group average NDE and NIE for each gender at the benchmark treatment level \(t'=40\) and varying treatment levels \(t \in \{100,200,\ldots ,1900,2000\}\). For \(t \in \{100,200,\ldots ,1900,2000\}\), denote the group average estimates as \({\widehat{\theta }}_{t,40}^M(40; {\textbf {x}})\), \({\widehat{\theta }}_{t,40}^M(t; {\textbf {x}})\), \({\widehat{\delta }}_{t,40}^M(t; {\textbf {x}})\), and \({\widehat{\delta }}_{t,40}^M(40; {\textbf {x}})\) for males, and \({\widehat{\theta }}_{t,40}^F(40; {\textbf {x}})\), \({\widehat{\theta }}_{t,40}^F(t; {\textbf {x}})\), \({\widehat{\delta }}_{t,40}^F(t; {\textbf {x}})\), and \({\widehat{\delta }}_{t,40}^F(40; {\textbf {x}})\) for females. By connecting the 20 points representing the causal effects at different treatment levels, Fig. 3 plots the trends of the causal effects for males and females over the range of treatment levels.
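A minimal sketch of this subgroup averaging, assuming the individualized estimates have been collected in a pandas DataFrame with one row per individual and treatment level; the column names and the synthetic numbers standing in for the estimates are illustrative.

```python
import numpy as np
import pandas as pd

# Illustrative frame: one row per (individual, treatment level) holding the
# individualized effect estimates produced in the estimation steps above.
rng = np.random.default_rng(0)
n_id, t_levels = 200, np.arange(100, 2001, 100)
df = pd.DataFrame({
    "id": np.repeat(np.arange(n_id), len(t_levels)),
    "t": np.tile(t_levels, n_id),
    "female": np.repeat(rng.integers(0, 2, n_id), len(t_levels)),
    "nde_1": rng.normal(size=n_id * len(t_levels)),
    "nie_1": rng.normal(size=n_id * len(t_levels)),
})

# Group average NDE/NIE per gender and treatment level: one row per point on
# the gender-specific curves plotted in Fig. 3.
group_avg = df.groupby(["female", "t"])[["nde_1", "nie_1"]].mean().reset_index()
```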

Fig. 3: Effects heterogeneity regarding gender, where “Group average NDE_1” means \({\widehat{\theta }}_{t,40}^M(40; {\textbf {x}})\) and \({\widehat{\theta }}_{t,40}^F(40; {\textbf {x}})\), etc. Black and blue lines indicate estimated group average causal effects for males and females, respectively

As shown in Fig. 3, the effect heterogeneity between genders is primarily driven by the direct mechanism. At the benchmark treatment level (\(t'=40\)), the group average NIE for each gender remains close to zero across varying treatment levels. Conversely, the group average NDE for males is negative and consistently decreases as the training time increases, while the group average NDE for females remains close to zero and essentially unchanged with longer training time. Based on these findings, we conclude that the association between training time and criminal activity is stronger among male subjects, primarily through the direct mechanism. In addition, within the investigated range of treatment intensities, the direct effects of the Job Corps program on the number of arrests become larger in magnitude among the male group as the training time increases, implying that longer training time among males leads to a more significant decrease in the number of arrests.

The observed effect heterogeneity between genders, primarily driven by the direct mechanism, can be explained by considering the differences in factors and dynamics that influence criminal activity for males and females. First, societal norms and expectations regarding gender roles can play a role in shaping criminal behavior. Males may face different social pressures and expectations, contributing to a higher likelihood of engaging in criminal activities; this could be due to peer influence, socialization patterns, or exposure to risky environments. In addition, research suggests that males tend to exhibit higher levels of risk-taking behavior than females, and this propensity for risk-taking can increase the likelihood of involvement in criminal activities. The Job Corps program provides vocational training and educational opportunities, offering an alternative path for males to channel their energy and ambitions and decrease their criminal behavior.

8.1.2 Subgroup analysis for numeric covariates

We choose times in prison as an example of numeric covariates; times in prison represents the number of times an individual has been incarcerated. This variable captures the frequency of past imprisonment for each individual, reflecting their criminal history and involvement with the legal system. The dataset includes individuals with times in prison ranging from 0 to 5 (a total of 6 levels). We average the individualized NDE and NIE estimates separately for each level of times in prison to obtain the group average NDE and NIE at the benchmark treatment level \(t'=40\) and varying treatment levels \(t\in \{100,200,\ldots ,1900,2000\}\). These group average estimates are denoted as \({\widehat{\theta }}_{t,40}^m(40; {\textbf {x}})\), \({\widehat{\theta }}_{t,40}^m(t; {\textbf {x}})\), \({\widehat{\delta }}_{t,40}^m(t; {\textbf {x}})\), and \({\widehat{\delta }}_{t,40}^m(40; {\textbf {x}})\), where \(m\in \{0,1,\ldots ,5\}\) represents the times-in-prison value and \(t\in \{100,200,\ldots ,1900,2000\}\). Each of the six lines, one per times-in-prison value, comprises 20 points illustrating the estimated causal effects at different treatment levels. By connecting the six lines and treating times in prison as a continuous covariate, Fig. 4 shows a three-dimensional surface depicting the relationship between treatment levels, times in prison, and the estimated causal effects, providing an overview of how the causal effects vary over treatment levels for different times-in-prison groups.

All group average causal effects are associated with times in prison. The three-dimensional plot illustrates separate surfaces for the group average NIE and NDE concerning times in prison at the benchmark treatment level (\(t'=40\)). The magnitude of the group average NDE increases with times in prison, attains a maximum at three, and then decreases. Meanwhile, the group average NDE exhibits a decreasing trend as the treatment level increases, as represented by the curved surface. The magnitude of the group average NIE initially remains relatively stable and close to zero but experiences a sudden increase when times in prison reach 4–5 and the treatment level increases, as reflected by the initially vertical surface followed by the inclined surface. In summary, within the examined range of treatment intensities, the direct effects of the Job Corps program on the number of arrests become more pronounced with increasing times in prison up to a certain threshold (around three times) and then become less pronounced. The indirect effect of the Job Corps program through employment is minimal and close to zero initially but becomes apparent when times in prison reach approximately 4–5. Similarly, the direct effect of the program becomes more prominent as the training time increases at a given times-in-prison level, suggesting that a longer training duration leads to a more significant decrease in arrests. The indirect effect of the program through employment emerges and becomes more prominent as the training duration increases when times in prison is approximately 4–5.

Fig. 4: Effects heterogeneity for times in prison. “Group average NDE_1” means \({\widehat{\theta }}_{t,40}^m(40; {\textbf {x}})\), etc. From bottom to top, the six lines correspond to the times-in-prison levels of 0, 1, \(\cdots \), 5

Fig. 5: Effect heterogeneity regarding welfare receipt during childhood, where “Group average NDE_1” means \({\widehat{\theta }}_{t,40}^w(40; {\textbf {x}})\), etc., and 1 to 4 represents the frequency of receiving welfare while growing up, with larger values indicating a higher frequency

Fig. 6: Estimated average NDE and NIE. “Average NDE_1” means \({\widehat{\theta }}_{t,40}(40)\), etc.

We can explain the above trend regarding times in prison as follows. When individuals have been incarcerated multiple times (up to a certain threshold, around three times in this case), they may have a more extensive criminal history, leading to a higher risk of recidivism and a greater need for intervention. As a result, the Job Corps program, which provides vocational training and educational opportunities, may have a more pronounced effect on reducing arrests for individuals at a higher times-in-prison level. The program may offer them a path to skill development, providing alternatives to criminal activities and reducing their involvement in illegal behavior. However, as times in prison continues to increase beyond the threshold (around three times), other factors related to repeated incarceration, such as persistent social and economic challenges, may come into play. These factors may diminish the impact of the Job Corps program on reducing arrests, leading to a less pronounced direct effect. The indirect effect of the Job Corps program through employment is close to zero when times in prison is less than three. This phenomenon can be attributed to several factors. Individuals with fewer incarcerations may already possess skills and education that make them more employable than those with more incarcerations. As a result, the Job Corps program’s impact on improving their employment prospects may be small or negligible. Moreover, individuals with fewer incarcerations may face fewer barriers to employment, such as a lack of criminal records or a shorter history of involvement with the legal system, making it easier for them to find employment opportunities without additional support from the Job Corps program. However, as times in prison reaches 4–5, the apparent indirect effect could be due to the following reasons. Individuals with a higher number of incarcerations often face more significant barriers to employment. These barriers can include stigma associated with their criminal history, limited job opportunities due to repeated convictions, and skills gaps resulting from prolonged periods of incarceration. The Job Corps program may provide tailored support and resources to help individuals overcome these barriers, leading to a significant increase in the indirect effect through improved employment outcomes. Overall, the observed trends in the direct and indirect effects of the Job Corps program on the number of arrests highlight the importance of considering an individual’s criminal history and the specific challenges they face when evaluating the program’s effectiveness.

8.1.3 Subgroup analysis for categorical covariates

We select welfare receipt during childhood as an example and only consider 3,726 samples without missingness. We calculate the average individualized NDE and NIE separately for each level of welfare receipt to obtain the group average NDE and NIE estimates for each welfare receipt level at the benchmark treatment level \(t'=40\) and varying treatment levels \(t \in \{100, 200, \ldots , 1900, 2000\}\). These group average estimates are denoted as \({\widehat{\theta }}_{t,40}^w(40; {\textbf {x}})\), \({\widehat{\theta }}_{t,40}^w(t; {\textbf {x}})\), \({\widehat{\delta }}_{t,40}^w(t; {\textbf {x}})\), and \({\widehat{\delta }}_{t,40}^w(40; {\textbf {x}})\), where w denotes the frequency of receiving welfare while growing up, ranging from 1 to 4, with larger values indicating a higher frequency. The variable t ranges from 100 to 2000, representing varying treatment levels. Figure 5 depicts the patterns and trends of the causal effects as the treatment level varies, enabling us to examine how these effects differ according to the frequencies of welfare obtainment.

The direct mechanism mainly drives the effect heterogeneity regarding the frequency of welfare receipt during childhood. At the benchmark treatment level (\(t'=40\)), the group average NIE remains close to zero regardless of the treatment levels and the frequency of receiving welfare during childhood. In contrast, the group average NDE is consistently negative and decreases as the frequency of welfare receipt during childhood decreases across different treatment levels. In summary, within the range of treatment intensities examined, the direct effects of the Job Corps program on the number of arrests are more pronounced when the frequency of receiving welfare during childhood is lower, regardless of the specific training time. This result implies that for a specific duration of the Job Corps program, individuals with a lower frequency of receiving welfare during childhood experience a more substantial decrease in the number of arrests.

We can explain the above trend regarding the frequency of welfare receipt during childhood as follows. Individuals with a lower frequency of receiving welfare during childhood may have had relatively fewer barriers and disadvantages compared to those with a higher frequency. As a result, they may be better positioned to take advantage of the opportunities provided by the Job Corps program, leading to the direct effect of the program. Such a direct effect captures the immediate impact of the Job Corps program on reducing arrests. It becomes more pronounced for individuals with a lower frequency of welfare receipt during childhood, regardless of the specific training time.

8.2 Average causal effects

The estimated individualized NDE and NIE can also be used to calculate the average NDE and NIE, denoted as \({\widehat{\theta }}_{t,40}(40)\), \({\widehat{\theta }}_{t,40}(t)\), \({\widehat{\delta }}_{t,40}(t)\), and \({\widehat{\delta }}_{t,40}(40)\), where t ranges from 100 to 2000. By averaging the estimated individualized NDE and NIE and connecting the 20 points denoting causal effects for different treatment levels, Fig. 6 provides a comprehensive overview of how the average causal effects vary across the range of treatment levels.

At the benchmark treatment level (\(t'=40\)), the average NIE remains relatively close to zero across different treatment levels. In contrast, the average NDE consistently exhibits a negative trend and decreases as the treatment level increases. Within the examined range of treatment intensities, the Job Corps program exhibits direct effects on reducing the number of arrests in the fourth year, and the magnitude of these effects becomes more pronounced as the training time increases. However, within the investigated treatment range, the program-induced employment changes have minimal indirect effects on reducing arrests. These findings align with existing results (Huang et al. 2024).

Notably, Huang et al. (2024) analyzed this dataset using two different estimators: the series regression estimator and the kernel regression estimator. However, they focused on estimating average causal effects and did not explore potential heterogeneity in their analysis. While our findings regarding the average causal effects align with the results obtained using the series regression estimator in their study, our proposed method can examine the presence of effect heterogeneity across different subgroups, which adds depth to the understanding of the causal relationship under investigation.

9 Discussion

This study presents CGAN-ICMA-CT, a novel approach designed for estimating ICEs and exploring personalized causal mechanisms in the presence of continuous treatment. We establish the theoretical foundation of CGAN-ICMA-CT, showing that the estimated distribution of our inferential conditional generator converges to the true conditional distribution under mild conditions. Simulation results demonstrate the superior performance of CGAN-ICMA-CT over competing methods in estimating ICEs. We apply our method to estimate the ICEs of the Job Corps program on the number of arrests and examine how causal effects vary with observable characteristics, providing insights into personalized causal mechanisms.

Several areas warrant further exploration. First, from a theoretical perspective, it would be valuable to derive the convergence rate of the sampling distribution, which would enhance our understanding of the efficiency and speed of convergence of the proposed method. On the algorithmic front, although we employ the Adam optimizer for training CGAN-ICMA-CT, there is potential for discovering and developing optimization algorithms specifically tailored to our model. These algorithms could enhance training efficiency and stability, leading to improved performance. Regarding causal analysis, our current method estimates ICEs with continuous outcomes. Future research could extend our approach to accommodate survival outcomes, allowing for a broader range of analyses. In addition, adapting our method to handle multiple-mediator scenarios would further enhance its applicability to complex causal mediation analysis. This extension enables the exploration of causal mechanisms involving multiple intermediate variables, providing a more comprehensive understanding of the underlying processes. However, in the context of multiple mediators, establishing identification necessitates obtaining the joint distribution of the mediators. This requirement can lead to increased complexity in the inference block network and in the convergence proof. While exploring this avenue could be intriguing, an alternative approach might involve assuming linearity between the outcome and the mediators, which would obviate the need for deriving the joint distribution of multiple mediators. This simplified assumption could potentially lead to more manageable estimation. For further insights on handling multiple-mediator scenarios, one may refer to Huan et al. (2024), which considered such a scenario with binary treatment. Moreover, sensitivity analysis is essential for assessing the robustness of causal conclusions, especially when the unconfoundedness assumption, which is often untestable, is violated. Developing sensitivity analysis strategies tailored to our proposed method would be a valuable future direction. Finally, confidence intervals are usually very important for causal analysis; however, the estimation of confidence intervals in our proposed model represents a weakness due to the lack of explicit probability distributions and the high computational demands required to gauge these intervals accurately. Future research should focus on developing novel methodologies to address these limitations. This includes exploring hybrid models that combine CGAN with other probabilistic frameworks to offer more accurate and interpretable estimates of confidence intervals, thereby enhancing the reliability of CGAN applications across various domains.