1 Introduction

Fig. 1: Personalized medicine in sepsis treatment. Credit: Itenov et al. (2018)

Real-world sequential decision making problems often share three important properties — (1) the reward function is often unknown, yet (2) expert demonstrations can be acquired, and (3) the reward and/or dynamics often depend on a static parameter, also known as the context. For a concrete example, consider a dynamic treatment regime (Chakraborty & Murphy 2014), where a clinician acts to improve a patient’s medical condition. While the patient’s dynamic measurements, e.g., heart rate and blood pressure, define the state, there are static parameters, e.g., age and weight, which determine how the patient reacts to certain treatments and what form of treatment is optimal.

The contextual model is motivated by recent trends in personalized medicine, predicted to be one of the technology breakthroughs of 2020 by MIT’s Technology Review (Juskalian et al. 2020). As opposed to traditional medicine, which provides a treatment for the “average patient”, in the contextual model patients are separated into different groups for which the medical decisions are tailored (Fig. 1). Based on these static parameters, the decision maker can provide tailored decisions (e.g., treatments) which are more effective.

For example, in Wesselink et al. (2018), the authors study organ injury, which may occur when a specific measurement (mean arterial pressure) decreases below a certain threshold. They found that this threshold varies across different patient groups (contextual behavior). In other examples, clinicians set treatment goals for the patients, i.e., they take actions to drive the patient measurements towards some predetermined values. For instance, in acute respiratory distress syndrome (ARDS), clinicians argue that these treatment goals should depend on the static patient information (the context) (Berngard et al. 2016).

In addition to the contextual structure, we consider the setting where the reward itself is unknown to the agent. This, too, is motivated by real-world problems, in which serious issues may arise when manually attempting to define a reward signal. For instance, when treating patients with sepsis, the only available signal is the mortality of the patient at the end of the treatment (Komorowski et al. 2018). While the goal is to improve the patients’ medical condition, minimizing mortality does not necessarily capture this objective. This model is illustrated in Fig. 2. The agent observes expert interactions with the environment, either through pre-collected data or through interactive expert interventions. The agent then aims to find a reward which explains the behavior of the expert, meaning that the expert’s policy is optimal with respect to this reward.

To tackle these problems, we propose the Contextual Inverse Reinforcement Learning (COIRL) framework. As in Inverse Reinforcement Learning (Ng & Russell 2000, IRL), given expert demonstrations, the goal in COIRL is to learn a reward function which explains the expert’s behavior, i.e., a reward function for which the expert’s behavior is optimal. In contrast to IRL, in COIRL the reward is a function not only of the state features but also of the context. Our aim is to provide theoretical analysis and insights into this framework. As such, throughout most of the paper we consider a reward which is linear in both the context and the state features. This analysis enables us to propose algorithms, analyze their behavior and provide theoretical guarantees. We further show empirically in Sect. 4 that our method can be easily extended to mappings which are non-linear in the context using deep neural networks.

Fig. 2: The COIRL framework: a context vector parametrizes the environment. For each context, the expert uses the true mapping from contexts to rewards, \(W^*\), and provides demonstrations. The agent learns an estimation of this mapping \({\hat{W}}\) and acts optimally with respect to it.

The paper is organized as follows. In Sect. 2 we introduce Contextual MDPs and provide the relevant notation. In Sect. 3.1 we formulate COIRL, with a linear mapping, as a convex optimization problem. We show that while this loss is not differentiable, it can be minimized using subgradient descent, and we provide methods to compute subgradients. We propose algorithms based on the Mirror Descent Algorithm (MDA) and Evolution Strategies (ES) for solving this task and analyze their sample complexity. In addition, in Sect. 3.2, we adapt the cutting plane (ellipsoid) method to the COIRL domain. In Sect. 3.3 we discuss how existing IRL approaches can be applied to COIRL problems and their limitations. Finally, in Sect. 3.4 we discuss how to efficiently (without re-solving the MDP) perform zero-shot transfer to unseen contexts.

These theoretical approaches are then evaluated, empirically, in Sect. 4. We perform extensive testing of our methods and the relevant baselines, both on toy problems and on a dynamic treatment regime constructed from real data. We evaluate the run-time of IRL vs. COIRL, showing that when the structure is indeed contextual, standard IRL schemes are computationally inefficient. We show that COIRL is capable of generalizing (zero-shot transfer) to unseen contexts, while behavioral cloning (log-likelihood action matching) is sub-optimal and struggles to find a good solution. These results show that in contextual problems, COIRL enables the agent to quickly recover a reward mapping that explains the expert’s behavior and outperforms previous methods across several metrics; it can thus be seen as a promising approach for real-life decision making.

Our contribution is threefold: First, the formulation of the COIRL problem as a convex optimization problem, and the novel adaptation of descent methods to this setting. Second, we provide theoretical analysis of all of the proposed methods in the linear case. Third, we bridge the theoretical results and real-life applications through a series of experiments that aim to apply COIRL to sepsis treatment (Sect. 4).

2 Preliminaries

2.1 Contextual MDPs

A Markov Decision Process (Puterman 1994, MDP) is defined by the tuple \(({\mathcal {S}},{\mathcal {A}},P,\xi ,R,\gamma )\) where \({\mathcal {S}}\) is a finite state space, \({\mathcal {A}}\) a finite action space, \(P : {\mathcal {S}} \times {\mathcal {A}} \times {\mathcal {S}} \rightarrow [0,1]\) the transition kernel, \(\xi \) the initial state distribution, \(R: {\mathcal {S}} \rightarrow \mathbb {R}\) the reward function and \(\gamma \in [0,1)\) the discount factor. A Contextual MDP (Hallak et al. 2015, CMDP) is an extension of an MDP, defined by \(({\mathcal {C}},{\mathcal {S}},{\mathcal {A}},{\mathcal {M}},\gamma )\) where \({\mathcal {C}}\) is the context space, and \({\mathcal {M}}\) is a mapping from contexts \(c \in {\mathcal {C}}\) to MDPs: \({\mathcal {M}}(c) = ({\mathcal {S}}, {\mathcal {A}}, P_c, \xi , R_c, \gamma )\). For consistency with prior work, we consider the discounted infinite-horizon scenario. We emphasize that all the results in this paper can be easily extended to the episodic finite-horizon and average-reward criteria.

We consider a setting in which each state is associated with a feature vector \(\phi : {\mathcal {S}} \rightarrow [0,1]^k\), and the reward for context c is a linear combination of the state features: \(R^*_c(s) = f^*(c)^T\phi (s)\). The goal is to approximate \(f^*(c)\) using a function \(f_W(c)\) with parameters W. This notation allows us to present our algorithms for any function approximator \(f_W(c)\), and in particular a deep neural network (DNN).

For the theoretical analysis, we will further assume a linear setting, in which the reward function and dynamics are linear in the context. Formally:

$$\begin{aligned} f^*(c)= c^TW^*,\, f_W (c) = c^T W,\, W^* \in {\mathcal {W}},\, \text {and } P_c(s'|s,a) = c^T \begin{bmatrix} P_1(s'|s,a) \\ \vdots \\ P_d(s'|s,a) \end{bmatrix} \end{aligned}$$

for some convex set \({\mathcal {W}}\). In order for the contextual dynamics to be well-defined, we assume the context space is the standard \(d-1\) dimensional simplex: \({\mathcal {C}} = \varDelta _{d-1}\). One interpretation of this model is that each row in the mapping \(W^*\) along with the corresponding transition kernels defines a base MDP, and the MDP for a specific context is a convex combination of these base environments.
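To make the linear setting concrete, the following sketch builds the reward and dynamics of the MDP induced by a context as convex combinations of d base rewards and base transition kernels. It is only an illustration of the definitions above; the function and array names (e.g., contextual_mdp, P_base) are ours, not the paper's.

```python
import numpy as np

def contextual_mdp(c, W, P_base, phi):
    """Builds R_c and P_c for a context on the simplex.

    c:      (d,)         context, non-negative entries summing to 1
    W:      (d, k)       mapping W* from contexts to reward weights
    P_base: (d, S, A, S) base transition kernels P_1, ..., P_d
    phi:    (S, k)       state features
    """
    assert np.all(c >= 0) and np.isclose(c.sum(), 1.0)
    r_c = phi @ (W.T @ c)                        # R_c(s) = c^T W phi(s), shape (S,)
    P_c = np.einsum("d,dsat->sat", c, P_base)    # convex mixture of base kernels
    return r_c, P_c
```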

We focus on deterministic policies \(\pi : {\mathcal {S}} \rightarrow {\mathcal {A}}\), which dictate the agent’s behavior at each state. The value of a policy \(\pi \) for context c is:

$$\begin{aligned} V^\pi _c = E_{\xi ,P_c,\pi }\left[ \sum _{t=0}^{\infty } \gamma ^t R^*_c(s_t)\right] = f^*(c)^T\mu ^\pi _c , \end{aligned}$$

where \(\mu ^\pi _c:=E_{\xi ,P_c,\pi }[\sum _{t=0}^{\infty } \gamma ^t \phi (s_t)]\in \mathbb {R}^k\) is called the feature expectations of \(\pi \) for context c. For other RL criteria there exist equivalent definitions of feature expectations; see Zahavy et al. (2020b) for the average reward. We also denote by \(V^\pi _c(s), \mu ^\pi _c(s)\) the value and feature expectations for \(\xi = \mathbb {1}_s\). The action-value function, or the Q-function, is defined by: \(Q_c^\pi (s,a) = R^*_c(s) + \gamma E_{s' \sim P_c(\cdot |s,a)}V^\pi _c(s')\). For the optimal policy with respect to (w.r.t.) a context c, we denote the above functions by \(V^*_c, Q^*_c,\mu ^*_c\). For any context c, \(\pi ^*_c\) denotes the optimal policy w.r.t. \(R_c^*\), and \({\hat{\pi }}_c(W)\) denotes the optimal policy w.r.t. \({{\hat{R}}}_c(s) = f_W (c)^T\phi (s)\).

For simpler analysis, we define a “flattening” operator, converting a matrix to a vector: \(\mathbb {R}^{d\times k}\rightarrow \mathbb {R}^{d\cdot k}\) by \(\underline{W}=\begin{bmatrix} w_{1,1}, \ldots ,w_{1,k}, \ldots ,w_{d,1}, \ldots ,w_{d,k}\end{bmatrix}\). We also define the operator \(\odot \) to be the composition of the flattening operator and the outer product: \(u \odot v = \begin{bmatrix} u_1v_1, \ldots ,u_1v_k, \ldots ,u_dv_1, \ldots ,u_dv_k\end{bmatrix}\). Therefore, the value of policy \(\pi \) for context c is given by \(V^\pi _{c} = c^TW^*\mu ^\pi _c = \underline{W^*}^T(c\odot \mu ^\pi _c),\) where \(||c\odot \mu ^\pi _c||_1 \le \frac{k}{1-\gamma }\).
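As a quick sanity check on this notation, the identity \(V^\pi _{c} = c^TW\mu ^\pi _c = \underline{W}^T(c\odot \mu ^\pi _c)\) can be verified numerically. The sketch below is ours (hypothetical shapes and random inputs); it assumes row-major flattening, matching the flattening operator defined above.

```python
import numpy as np

d, k, gamma = 3, 5, 0.9
rng = np.random.default_rng(0)
W = rng.standard_normal((d, k))
c = rng.dirichlet(np.ones(d))            # context on the simplex
mu = rng.random(k) / (1 - gamma)         # some feature expectations

W_flat = W.reshape(-1)                   # the flattening operator (row-major)
c_odot_mu = np.outer(c, mu).reshape(-1)  # c "odot" mu

assert np.isclose(c @ W @ mu, W_flat @ c_odot_mu)
```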

2.2 Apprenticeship learning and inverse reinforcement learning

In Apprenticeship Learning (AL), the reward function is unknown, and we denote the MDP without the reward function (also commonly called a controlled Markov chain) by MDP\(\backslash \)R. Similarly, we denote a CMDP without a mapping of context to reward by CMDP\(\backslash \)M.

Instead of manually tweaking the reward to produce the desired behavior, the idea is to observe and mimic an expert. The literature on IRL is quite vast and dates back to Ng & Russell (2000) and Abbeel & Ng (2004). In this setting, the reward function (while unknown to the apprentice) is a linear combination of a set of known features, as defined above. The expert demonstrates a set of trajectories that are used to estimate the feature expectations of its policy \(\pi _E\), denoted by \(\mu _E \). The goal is to find a policy \(\pi \) whose feature expectations are close to this estimate, and which hence has a similar return with respect to any weight vector w.

Formally, AL is posed as a two-player zero-sum game, where the objective is to find a policy \(\pi \) that does at least as well as the expert with respect to any reward function of the form \(r(s) = w\cdot \phi (s), w\in {\mathcal {W}}\). That is, we solve

$$\begin{aligned} \max _{\pi \in \varPi }\min _{w\in {\mathcal {W}}} \left[ w \cdot \mu (\pi ) - w \cdot \mu _E \right] \end{aligned}$$
(1)

where \(\varPi \) denotes the set of mixed policies (Abbeel & Ng 2004), in which a deterministic policy is sampled according to a distribution at time 0, and executed from that point on. Thus, this policy class can be represented as a convex set of vectors – the distributions over the deterministic policies.

Denoting the value of Eq. (1) by \(f^\star \), Abbeel & Ng (2004) define the problem of approximately solving Eq. (1) as AL, i.e., finding \(\pi \) such that

$$\begin{aligned} \forall w\in {\mathcal {W}}: w\cdot \mu (\pi ) \ge w\cdot \mu _E - \epsilon + f^\star . \end{aligned}$$
(2)

Due to the von Neumann minimax theorem, we also have that

$$\begin{aligned} f^\star = \min _{w\in {\mathcal {W}}} \max _{\pi \in \varPi } \left[ w \cdot \mu (\pi ) - w \cdot \mu _E \right] . \end{aligned}$$
(3)

We will later use this formulation to define the IRL objective, i.e., finding \(w \in {\mathcal {W}}\) such that

$$\begin{aligned} \forall \pi \in \varPi : w\cdot \mu _E \ge w\cdot \mu (\pi ) - \epsilon - f^\star ; \end{aligned}$$
(4)

Abbeel & Ng (2004) suggested two algorithms to solve Eq. (2) for the case that \({\mathcal {W}}\) is a ball in the Euclidean norm; one that is based on a maximum margin solver, and a simpler projection algorithm. The latter starts with an arbitrary policy \(\pi _0\) and computes its feature expectations \(\mu _0\). At step t they define a reward function using the weight vector \(w_t = \mu _E-{\bar{\mu }}_{t-1}\) and find the policy \(\pi _t\) that maximizes it. \(\bar{\mu }_t\) is a convex combination of the feature expectations of previous (deterministic) policies, \(\bar{\mu }_t = \sum ^t _{j=1}\alpha _j \mu (\pi _j).\) They show that in order to get \(\left\Vert \bar{\mu }_T-\mu _E\right\Vert \le \epsilon \), it suffices to run the algorithm for \(T=O(\frac{k}{(1-\gamma )^2\epsilon ^2}\log (\frac{k}{(1-\gamma )\epsilon }))\) iterations.
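A minimal sketch of this projection variant follows. The helpers solve_mdp(w) (returns an optimal policy for reward weights w) and mu_of(pi) (returns its feature expectations) are assumptions of the sketch and are not specified in the text; the mixed output policy is the corresponding convex combination of the computed deterministic policies.

```python
import numpy as np

def projection_al(mu_E, solve_mdp, mu_of, T):
    """Projection algorithm of Abbeel & Ng (2004), as summarized above."""
    pi0 = solve_mdp(np.zeros_like(mu_E))          # arbitrary initial policy
    mu_bar = mu_of(pi0)                           # bar{mu}_0
    for t in range(1, T + 1):
        w_t = mu_E - mu_bar                       # reward weights at step t
        pi_t = solve_mdp(w_t)                     # best-response policy
        mu_t = mu_of(pi_t)
        denom = (mu_t - mu_bar) @ (mu_t - mu_bar)
        if denom < 1e-12:                         # already matched the expert
            break
        # orthogonal projection of mu_E onto the line through mu_bar and mu_t
        step = (mu_t - mu_bar) @ (mu_E - mu_bar) / denom
        mu_bar = mu_bar + step * (mu_t - mu_bar)
    return mu_bar
```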

Recently, Zahavy et al. (2020a) showed that the projection algorithm is in fact equivalent to a Frank-Wolfe method for finding the projection of the feature expectations of the expert onto the feature expectations polytope – the convex hull of the feature expectations of all the deterministic policies in the MDP. The Frank-Wolfe analysis gives the projection method of Abbeel & Ng (2004) a slightly tighter bound of \(T=O(\frac{k}{(1-\gamma )^2\epsilon ^2}).\) In addition, a variation of the FW method that is based on taking “away steps” (Garber & Hazan 2016; Jaggi 2013) achieves a linear rate of convergence, i.e., it is logarithmic in \(1/\epsilon .\)

Another type of algorithm, based on online convex optimization, was proposed by Syed & Schapire (2008). In this approach, in each round the “reward player” plays an online convex optimization algorithm on the losses \(l_t(w_t) = w_t\cdot (\mu _E - \mu (\pi _t))\), while the “policy player” plays the best response, i.e., the policy \(\pi _t\) that maximizes the return \(\mu (\pi )\cdot w_t\) at time t. The results in Syed & Schapire (2008) use a specific instance of MDA where the optimization set is the simplex and distances are measured w.r.t. \(\left\Vert \cdot \right\Vert _1.\) This version of MDA is known as multiplicative weights or Hedge. The algorithm runs for T steps and returns a mixed policy \(\psi \) that draws each policy \(\pi _t, t=1,\ldots ,T,\) with probability 1/T. Thus,

$$\begin{aligned} f^\star&\le \frac{1}{T}\sum \nolimits _{t=1}^T\max _{\pi \in \varPi } \left[ w_t\cdot \mu (\pi )-w_t\cdot \mu _E\right] \nonumber \\&= \frac{1}{T}\sum \nolimits _{t=1}^T \left[ w_t\cdot \mu (\pi _t)-w_t\cdot \mu _E\right] \end{aligned}$$
(5)
$$\begin{aligned}&\le \min _{w\in {\mathcal {W}}}\frac{1}{T}\sum _{t=1}^T w\cdot \left[ \mu (\pi _t)-\mu _E\right] + {O}\left( \frac{\sqrt{\log (k)}}{(1-\gamma )\sqrt{T}}\right) \end{aligned}$$
(6)
$$\begin{aligned}&= \min _{w\in {\mathcal {W}}} w \cdot \left( \mu (\psi ) - \mu _E\right) + {O}\left( \frac{\sqrt{\log (k)}}{(1-\gamma )\sqrt{T}}\right) , \end{aligned}$$
(7)

where Eq. (5) follows from the fact that the policy player plays the best response, that is, \(\pi _t\) is the optimal policy w.r.t. the reward \(w_t;\) Eq. (6) follows from the fact that the reward player plays a no-regret algorithm, e.g., online MDA. Thus, they get that \(\forall w\in {\mathcal {W}}: w\cdot \mu (\psi ) \ge w\cdot \mu _E + f^\star - {O}\left( \frac{1}{\sqrt{T}}\right) \).

2.3 Learned dynamics

Finally, we note that the majority of AL papers consider the problem of learning the transition kernel and initial state distribution as a ‘supervised learning’ problem that is orthogonal to the AL problem. That is, the algorithm starts by approximating the dynamics from samples and then executes the AL algorithm on the approximated dynamics (Abbeel & Ng 2004; Syed & Schapire 2008). In this paper we adopt this principle. We also note that it is possible to learn a transition kernel and an initial state distribution that are parametrized by the context. Existing methods, such as in Modi et al. (2018), can be used to learn contextual transition kernels. Furthermore, in domains that allow access to the real environment, Abbeel & Ng (2005) provide theoretical bounds for the estimated dynamics of the frequently visited state-action pairs. Thus, we assume \(P_c\) is known when discussing the suggested methods in Sect. 3, which enables the computation of feature expectations for any context and policy. In Sect. 4.5 we present an example of this principle, where we use a context-dependent model to estimate the dynamics.

3 Methods

In the previous section we have seen AL algorithms for finding a policy that satisfies Eq. (2). In a CMDP this policy will have to be a function of the context, but unfortunately, it is not clear how to analyze contextual policies. Instead, we follow the approach that was taken in the CMDP literature and aim to learn the linear mapping from contexts to rewards (Hallak et al. 2015; Modi et al. 2018; Modi & Tewari 2019). This requires us to design an IRL algorithm instead of an AL algorithm, i.e., to solve Eq. (4) rather than Eq. (2). Concretely, the goal in Contextual IRL is to approximate the mapping \(f^*(c)\) by observing an expert (for each context c, the expert provides a demonstration from \(\pi ^*_c\)).

This section is organized as follows. We begin with Sect. 3.1, where we formulate COIRL as a convex optimization problem and derive subgradient descent algorithms for it based on the Mirror Descent Algorithm (MDA). Furthermore, we show that MDA can learn efficiently even when there is only a single expert demonstration per context. This novel approach is designed for COIRL but can be applied to standard IRL problems as well.

In Sect. 3.2 we present a cutting plane method for COIRL that is based on the ellipsoid algorithm. This algorithm requires, in addition to demonstrations, that the expert evaluate the agent’s policy and provide its demonstration only if the agent’s policy is sub-optimal.

In Sect. 3.3 we discuss how existing IRL algorithms can be adapted to the COIRL setting for domains with finite context spaces and how they compare to COIRL, which we later verify in the experiments section. Finally, in Sect. 3.4 we explore methods for efficient transfer to unseen contexts without additional planning.

3.1 Mirror descent for COIRL

3.1.1 Problem formulation

In this section, we derive and analyze convex optimization algorithms for COIRL that minimize the following loss function,

$$\begin{aligned} \text {Loss}(W) = \mathbb {E}_c \left[ \max _{\pi } f_W(c) \cdot \left( \mu ^\pi _c - \mu ^*_c\right) \right] = \mathbb {E}_c \left[ f_W (c) \cdot \left( \mu ^{{{\hat{\pi }}}_c( W)}_c - \mu ^*_c\right) \right] \, . \end{aligned}$$
(8)

Remark 3.1

We analyze the descent methods for the linear mapping \(f(c)=c^{T}W\). It is possible to extend the analysis to general function classes (parameterized by W), where \(\frac{\partial f}{\partial W}\) is computable and f is convex. In this case, \(\frac{\partial f}{\partial W}\) takes the place of the context, c, in the descent direction, and similar sample complexity bounds can be achieved.

The following lemma suggests that if W is a minimizer of Eq. (8), then the expert policy is optimal w.r.t. reward \({{\hat{R}}}_c\) for any context.

Lemma 3.1

\(\text {Loss}(W)\) satisfies the following properties: (1) For any W the loss is greater or equal to zero. (2) If \(\text {Loss}(W)=0\) then for any context, the expert policy is the optimal policy w.r.t. reward \({{\hat{R}}}_c(s)=c^TW\phi (s)\).

Proof

We need to show that \(\forall W\), \(\text {Loss}(W)\ge 0,\) and \(\text {Loss}(W^*) = 0.\) Fix W. For any context c, \({\hat{\pi }}_c(W)\) is the optimal policy w.r.t. the reward \({{\hat{R}}}_c(s) = f_W (c)^T\phi (s),\) thus \(f_W (c) \cdot \big (\mu ^{{{\hat{\pi }}}_c(W)}_c - \mu ^*_c\big )\ge 0.\) Therefore we get that \(\text {Loss}(W)\ge 0.\) For \(W^*,\) we have that \(\mu ^{{\hat{\pi }}_{c}(W^*)}_c = \mu ^*_c,\) thus \(\text {Loss}(W^*) = 0\).

For the second statement, note that \(\text {Loss}(W) = 0\) implies that \(\forall c, f_W (c) \cdot \big (\mu ^{{{\hat{\pi }}}_c(W)}_c - \mu ^*_c\big )= 0.\) This can happen in one of two cases. (1) \(\mu ^{{{\hat{\pi }}}_c(W)}_c = \mu ^*_c;\) in this case \(\pi ^*_{c},{\hat{\pi }}_{c}(W)\) have the same feature expectations and are therefore equivalent in terms of value. (2) \(\mu ^{{\hat{\pi }}_{c}(W)}_c \ne \mu ^*_c,\) but \(f_W (c) \cdot \big (\mu ^{{\hat{\pi }}_{c}(W)}_c - \mu ^*_c\big )= 0.\) In this case, \(\pi ^*_{c},{\hat{\pi }}_{c}(W)\) have different feature expectations, but still achieve the same value w.r.t. the reward \(f_W (c).\) Since \({\hat{\pi }}_{c}(W)\) is an optimal policy w.r.t. this reward, so is \(\pi ^*_{c}.\) \(\square \)

To evaluate the loss, the optimal policy \({{\hat{\pi }}}_c (W)\) and its feature expectations \(\mu ^{{{\hat{\pi }}}_c (W)}_c\) must be computed for all contexts. Finding \({{\hat{\pi }}}_c (W)\) for a specific context can be done using standard RL methods, e.g., value or policy iteration. In addition, computing \(\mu ^{{{\hat{\pi }}}_c (W)}_c\) is equivalent to performing policy evaluation (solving a set of linear equations).
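For a fixed context and policy, the feature expectations indeed solve a linear system, exactly as ordinary policy evaluation does. A minimal sketch (our own, with hypothetical array shapes) is:

```python
import numpy as np

def feature_expectations(P_c, phi, pi, xi, gamma):
    """mu^pi_c = E_xi[sum_t gamma^t phi(s_t)] via (I - gamma * P_pi)^{-1}.

    P_c: (S, A, S) transition kernel for the context, phi: (S, k) features,
    pi:  (S,) deterministic policy, xi: (S,) initial state distribution.
    """
    S = P_c.shape[0]
    P_pi = P_c[np.arange(S), pi]                        # (S, S), rows follow pi
    M = np.linalg.solve(np.eye(S) - gamma * P_pi, phi)  # per-state feature expectations
    return xi @ M                                       # (k,)
```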

However, since the optimal policy is obtained by running an algorithm (e.g., policy iteration), Eq. (8) is not differentiable w.r.t. W. We therefore consider two optimization schemes that do not involve differentiation: (i) subgradients and (ii) random perturbations of the loss function (finite differences). Although the loss is non-differentiable, Lemma 3.2 below shows that in the special case where \(f_W(c)\) is a linear function, Eq. (8) is convex and Lipschitz continuous. Furthermore, it provides a method to compute its subgradients.

Lemma 3.2

Let \(f_W(c) = c^TW\) such that \(\text {Loss}(W),\) denoted by \( L_{\text {lin}}(W)\), is given by

$$\begin{aligned} L_{\text {lin}}(W) = \mathbb {E}_c \left[ c^T W \cdot \left( \mu ^{{\hat{\pi }}_{c}(W)}_c - \mu ^*_c\right) \right] . \end{aligned}$$

We have that:

1. \(L_{\text {lin}}(W)\) is a convex function.

2. \(g(W) = \mathbb {E}_c \left[ c \odot \big (\mu ^{{\hat{\pi }}_{c}(W)}_c - \mu ^*_c\big ) \right] \) is a subgradient of \( L_{\text {lin}}(W)\).

3. \(L_{\text {lin}}\) is a Lipschitz continuous function, with Lipschitz constant \(L = \frac{2}{1-\gamma }\) w.r.t. \(\left\Vert \cdot \right\Vert _\infty \) and \(L = \frac{2\sqrt{dk}}{1-\gamma }\) w.r.t. \(\left\Vert \cdot \right\Vert _2\).

The proof of the lemma is provided in the supplementary material (Appendix A). The proof follows the definitions of convexity and subgradients, using the fact that for each W we compute the optimal policy for the reward \(c^TW\). The Lipschitz continuity of \(L_{\text {lin}}(W)\) is related to the simulation lemma (Kearns & Singh 2002): a small change in the reward results in a small change in the optimal value.

Note that \(g(W)\in \mathbb {R}^{d\times k}\) is a matrix; we will sometimes refer to it as a matrix and sometimes as a flattened vector, depending on what is more convenient. Finally, g(W) is given in expectation over contexts, and in expectation over trajectories (feature expectations). We will later see how to replace g(W) with an unbiased estimate, which can be computed by aggregating state features from a single expert trajectory.

3.1.2 Algorithms

Lemma 3.2 identifies \(L_\text {Lin}(W)\) as a convex function and provides a method to compute its subgradients. A standard method for minimizing a convex function over a convex set is the subgradient projection algorithm (Bertsekas 1997). The algorithm is given by the following iterates:

$$\begin{aligned} W_{t+1} = \text {Proj}_{{\mathcal {W}}} \left\{ W_t - \alpha _t g(W_t)\right\} , \end{aligned}$$

where f is a convex function, \(g(W_t)\) is a subgradient of f at \(W_t\), and \(\alpha _t\) is the learning rate. \({\mathcal {W}}\) is required to be a convex set; we will consider two particular cases, the \(\ell _2\) ball (Abbeel & Ng 2004) and the simplex (Syed & Schapire 2008).

Next, we consider a generalization of the subgradient projection algorithm that is called the mirror descent algorithm (Nemirovsky & Yudin 1983, MDA):

$$\begin{aligned} W_{t+1} = \arg \min _{W\in {\mathcal {W}}} \left\{ W\cdot \nabla f(W_t) + \frac{1}{\alpha _t}D_\psi (W,W_t) \right\} , \end{aligned}$$
(9)

where \(D_\psi (W,W_t)\) is a Bregman distance associated with a strongly convex function \(\psi \). The following theorem characterizes the convergence rate of MDA.

Theorem 3.1

(Convergence rate of MDA) Let \(\psi \) be a \(\sigma \)-strongly convex function on \({\mathcal {W}}\) w.r.t. \(\left\Vert \cdot \right\Vert \), and let \(D^2 = \sup _{W_1,W_2\in {\mathcal {W}}}D_\psi (W _1,W_2)\). Let f be convex and L-Lipschitz continuous w.r.t. \(\left\Vert \cdot \right\Vert \). Then, MDA with \(\alpha _t = \frac{D}{L}\sqrt{\frac{2\sigma }{t}}\) satisfies:

$$\begin{aligned} f\left( \frac{1}{T}\sum _{t=1}^T W_t\right) - f(W^*) \le DL\sqrt{\frac{2}{\sigma T}} . \end{aligned}$$

We refer the reader to Beck & Teboulle (2003) and Bubeck (2015) for the proof. Specific instances of MDA require one to choose a norm and to define the function \(\psi .\) Once those are defined, one can compute \(\sigma , D\) and L which define the learning rate schedule. Below, we provide two MDA instances (see, for example Beck & Teboulle (2003) for derivation) and analyze them for COIRL.

Projected subgradient descent (PSGD): Let \({\mathcal {W}}\) be an \(\ell _2\) ball with radius 1. Fix \(||\cdot ||_2\), and \(\psi (W) = \frac{1}{2}||W||_2^2.\) \(\psi \) is strongly convex w.r.t. \(||\cdot ||_2\) with \(\sigma =1.\) The associated Bregman divergence is given by \(D_\psi (W_1, W_2) = 0.5 ||W_1 - W_2||_2^2.\) Thus, mirror descent is equivalent to PSGD. \(D^2 = \max _{W_1,W_2 \in {\mathcal {W}}}D_\psi (W_1, W_2) \le 1,\) and according to Lemma 3.2, \(L = \frac{2\sqrt{dk}}{1-\gamma }\). Thus, we have that the learning rate is \(\alpha _t = (1-\gamma )\sqrt{\frac{1}{2dkt}}\) and the update to W is given by

$$\begin{aligned} {\tilde{W}} = W_t - \alpha _t g_t, W_{t+1} = {\tilde{W}}/||{\tilde{W}}||_2, \end{aligned}$$

and according to Theorem 3.1 we have that after T iterations,

$$\begin{aligned} L_{\text {lin}}\left( \frac{1}{T}\sum \nolimits _{t=1}^T W_t\right) - L_{\text {lin}}(W^*) \le {\mathcal {O}}\left( \frac{\sqrt{dk}}{(1-\gamma )\sqrt{T}} \right) \, . \end{aligned}$$
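A minimal sketch of one PSGD step under these choices follows. It is our own illustration, not the paper's implementation: best_response(W, c) is an assumed helper returning the feature expectations of \({\hat{\pi }}_c(W)\), and the projection uses the standard Euclidean projection onto the unit \(\ell _2\) ball, which only rescales when the norm exceeds one.

```python
import numpy as np

def psgd_step(W, contexts, mu_expert, best_response, t, gamma):
    """One projected-subgradient step on L_lin; W has shape (d, k)."""
    d, k = W.shape
    # subgradient g(W) = E_c[ c (outer) (mu^{pi_hat_c(W)}_c - mu^*_c) ]
    g = np.mean([np.outer(c, best_response(W, c) - mu_expert[i])
                 for i, c in enumerate(contexts)], axis=0)
    alpha = (1 - gamma) * np.sqrt(1.0 / (2 * d * k * t))   # learning rate from the text
    W_tilde = W - alpha * g
    return W_tilde / max(1.0, np.linalg.norm(W_tilde))     # project onto the unit l2 ball
```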

Exponential Weights (EW): Let \({\mathcal {W}}\) be the standard \({dk-1}\) dimensional simplex. Let \({\psi (W) = \sum _i W(i)\log (W(i))}\). \(\psi \) is strongly convex w.r.t. \(||\cdot ||_1\) with \({\sigma =1}\). We get that the associated Bregman divergence is given by

$$\begin{aligned} D_\psi \left( W_1, W_2\right) = \sum _i W_1(i) \log \left( \frac{W_1(i)}{W_2(i)}\right) , \end{aligned}$$

also known as the Kullback-Leibler divergence. In addition,

$$\begin{aligned} D^2 = \max _{W_1,W_2 \in {\mathcal {W}}}D_\psi \left( W_1, W_2\right) \le \log (dk) \end{aligned}$$

and according to Lemma 3.2, \(L = \frac{2}{1-\gamma }\). Thus, we have that the learning rate is \(\alpha _t = (1-\gamma )\sqrt{\frac{\log (dk)}{2t}}.\) Furthermore, the projection onto the simplex w.r.t. this distance amounts to a simple renormalization \(W \leftarrow W/||W||_1\). Thus, we get that MDA is equivalent to the exponential weights algorithm and the update to W is given by

$$\begin{aligned} \forall i \in [1..dk], {\tilde{W}}(i) = W_{t}(i) \exp \left( -\alpha _t g_t (i)\right) , W_{t+1} = {\tilde{W}}/||{\tilde{W}}||_1. \end{aligned}$$

Finally, according to Theorem 3.1 we have that after T iterations,

$$\begin{aligned} L_{\text {lin}}\left( \frac{1}{T}\sum \nolimits _{t=1}^T W_t\right) - L_{\text {lin}}(W^*) \le {\mathcal {O}}\left( \frac{\sqrt{\log (dk)}}{(1-\gamma )\sqrt{T}} \right) \, . \end{aligned}$$
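A corresponding exponential-weights step over the flattened \(dk\)-dimensional simplex, using the same subgradient as in the PSGD sketch above, looks as follows (again a sketch; the function and argument names are ours):

```python
import numpy as np

def ew_step(W, g, t, gamma):
    """One exponential-weights step; W and g are (d, k), W lies on the dk-simplex."""
    d, k = W.shape
    alpha = (1 - gamma) * np.sqrt(np.log(d * k) / (2 * t))  # learning rate from the text
    W_tilde = W * np.exp(-alpha * g)                        # multiplicative update
    return W_tilde / W_tilde.sum()                          # renormalize onto the simplex
```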

Evolution strategies for COIRL: Next, we consider a derivative-free algorithm for computing subgradients, based on Evolution Strategies (Salimans et al. 2017, ES). For convex optimization problems, ES is a gradient-free descent method based on computing finite differences (Nesterov & Spokoiny 2017). The subgradient in ES is computed by sampling m random perturbations and computing the loss for them, in the following form

$$\begin{aligned}&\text {For } j = 1, \ldots , m \text { do} \\&\quad \text {Sample } u_j \sim {\mathcal {N}}\left( 0,\rho ^{2}\right) \in \mathbb {R}^{dk},\\&\quad g^j = \text {Loss} \left( W_t + \frac{\nu u_j}{||u_j||} \right) \frac{ \nu u_j}{||u_j||} , \\&\text {End For} \end{aligned}$$

and the subgradient is given by

$$\begin{aligned} g_t = \frac{1}{m\rho } \sum _{j=1}^m g^j. \end{aligned}$$
(10)
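A sketch of this estimator, following Eq. (10), is given below; loss_fn is a user-supplied evaluation of \(\text {Loss}(\cdot )\) and is an assumption of the sketch.

```python
import numpy as np

def es_gradient(loss_fn, W, m, rho, nu):
    """Finite-difference (ES) estimate of a descent direction at W (shape (d, k))."""
    g = np.zeros_like(W)
    for _ in range(m):
        u = np.random.normal(0.0, rho, size=W.shape)   # random perturbation
        direction = nu * u / np.linalg.norm(u)
        g += loss_fn(W + direction) * direction        # loss-weighted direction
    return g / (m * rho)
```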

Theorem 3.2 presents the sample complexity of PSGD with the subgradient in Eq. (10) for the case that the loss is convex, as is \(L_{\text {lin}}\). While this method has looser upper-bound guarantees compared to MDA (Theorem 3.1), Nesterov & Spokoiny (2017) observed that in practice it often outperforms subgradient-based methods. Thus, we test ES empirically and compare it with the subgradient method (Sect. 3.1). Additionally, Salimans et al. (2017) have shown the ability of ES to cope with high-dimensional non-convex tasks (DNNs).

Theorem 3.2

(ES Convergence Rate (Nesterov & Spokoiny 2017)) Let \(L_\text {lin}(W)\) be a non-smooth convex function with Lipschitz constant L, such that \(||W_0 - W^*|| \le D\). With a step size of \(\alpha _t = \frac{D}{(dk+4)\sqrt{T+1}L}\) and \(\nu \le \frac{\epsilon }{2L\sqrt{dk}}\), after \(T = \frac{4(dk+4)^2D^2 L^2}{\epsilon ^2}\) iterations ES finds a solution which satisfies \(\mathbb {E}_{U_{T-1}} [L_\text {lin}({{\hat{W}}}_T)] - L_\text {lin}(W^*) \le \epsilon \), where \({U_T = \{ u_0, \ldots , u_T \}}\) denotes the random variables of the algorithm up to time T and \({{\hat{W}}}_T = {\mathrm{arg}}\,{\mathrm{min}}_{t = 1, \ldots , T} L_\text {lin}(W_t)\).

Practical MDA: One of the “miracles” of MDA is its robustness to noise. If we replace \(g_t\) with an unbiased estimate \({\tilde{g}}_t,\) such that \(\mathbb {E} \tilde{g}_t = g_t\) and \(\mathbb {E} \left\Vert \tilde{g}_t\right\Vert \le L\), we obtain the same convergence results as in Theorem 3.1 (Robbins & Monro 1951) (see, for example, Bubeck 2015, Theorem 6.1). Such an unbiased estimate can be obtained in the following manner: (i) sample a context \(c_t\), (ii) compute \(\mu ^{{\hat{\pi }}_{c_t}(W_t)}_{c_t}\), (iii) observe a single expert demonstration \(\tau ^E_i = \{s_0^i,a_0,s_1^i,a_1,\ldots \},\) where the actions are chosen by the expert policy \(\pi ^*_{c_t}\), and (iv) let \({{\hat{\mu }}}_i = \sum _{t\in [0,\ldots ,|\tau ^E_i|-1]}\gamma ^t \phi (s_t^i)\) be the accumulated discounted features along the trajectory, such that \(\mathbb {E} {{\hat{\mu }}}_i = \mu ^*_{c_t}\).

However, for \({{\hat{\mu }}}_i\) to be an unbiased estimate of \(\mu ^*_{c_t}\), \(\tau ^E_i\) needs to be of infinite length. Thus one can either (1) execute the expert trajectory online, and terminate it at each time step with probability \(1-\gamma \) (Kakade & Langford 2002), or (2) execute a trajectory of length \(H=\frac{1}{1-\gamma }\log (1/\epsilon _H)\). The issue with the first approach is that since the trajectory length is unbounded, the estimate \({{\hat{\mu }}} _i \) cannot be shown to concentrate to \(\mu ^*_{c_t}\) via Hoeffding type inequalities. Nevertheless, it is possible to obtain a concentration inequality using the fact that the length of each trajectory is bounded in high probability (similar to Zahavy et al. (2020b)). The second approach can only guarantee that \(\left\Vert g_t - \mathbb {E}{\tilde{g}}_t\right\Vert \le \epsilon _H\) (Syed & Schapire 2008). Hence, using the robustness of MDA to adversarial noise (Zinkevich 2003), we get that MDA converges with an additional error of \(\epsilon _H\), i.e.,

$$\begin{aligned} L_{\text {lin}}\left( \frac{1}{T}\sum _{t=1}^T W_t\right) - L_{\text {lin}}(W^*) \le {\mathcal {O}}\left( \frac{1}{\sqrt{T}}\right) + \epsilon _H . \end{aligned}$$

While this sampling mechanism comes at the cost of a controlled bias, it is usually more practical, in particular when the trajectories are given as a set of demonstrations (offline data).
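A sketch of this practical estimate: accumulate discounted features along a single (truncated) expert trajectory and plug the result into the subgradient in place of \(\mu ^*_{c_t}\). The names (e.g., traj_states) are ours and hypothetical.

```python
import numpy as np

def estimate_mu(traj_states, phi, gamma, H=None):
    """Discounted feature accumulation hat{mu} = sum_t gamma^t phi(s_t)."""
    if H is not None:
        traj_states = traj_states[:H]          # truncation introduces the eps_H bias
    discounts = gamma ** np.arange(len(traj_states))
    return discounts @ phi[traj_states]        # (k,)

def noisy_subgradient(c, mu_agent, mu_hat):
    """tilde{g}_t = c (outer) (mu^{pi_hat}_c - hat{mu}), used in place of g_t in MDA."""
    return np.outer(c, mu_agent - mu_hat)
```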

3.2 Ellipsoid algorithms for COIRL

Fig. 3: The ellipsoid algorithm proceeds in an iterative way, using linear constraints to gradually reduce the size of the ellipsoid until the center defines an \(\epsilon \)-optimal solution.

In this section we present the ellipsoid method, introduced to the IRL setting by Amin et al. (2017). We extend this method to the contextual setting, and focus on finding a linear mapping \(W \in {\mathcal {W}} \) where \({\mathcal {W}} = \{W: ||W||_\infty \le 1\}\), and \(W^*\in {\mathcal {W}}\). The algorithm, illustrated in Fig. 3, maintains an ellipsoid-shaped feasibility set for \(W^*\). In each iteration, the algorithm receives a demonstration which is used to create a linear constraint, halving the feasibility set. The remaining half-ellipsoid, still containing \(W^*\), is then encapsulated by a new ellipsoid. With every iteration, this feasibility set is reduced until it converges to \(W^*\).

Formally, an ellipsoid is defined by its center – a vector u, and by an invertible matrix Q: \(\{x:(x-u)^TQ^{-1}(x-u) \le 1\}\). The feasibility set for \(W^*\) is initialized to be the minimal sphere containing \(\{W: ||W||_\infty \le 1\}\). At every step t, the current estimate \(W_t\) of \(W^*\) is defined as the center of the feasibility set, and the agent acts optimally w.r.t. the reward function \({{\hat{R}}}_c(s) = c^TW_t\phi (s)\). If the agent performs sub-optimally, the expert provides a demonstration in the form of its feature expectations for \(c_t\): \(\mu ^*_{c_t}\). These feature expectations are used to generate a linear constraint (hyperplane) that crosses the center of the ellipsoid. Under this constraint, we construct a new feasibility set that is half of the previous ellipsoid and still contains \(W^*\). For the algorithm to proceed, we compute a new ellipsoid that is the minimum volume enclosing ellipsoid (MVEE) of this “half-ellipsoid”. These updates are guaranteed to gradually reduce the volume of the ellipsoid, as shown in Lemma 3.3, until its center is a mapping which induces \(\epsilon \)-optimal policies for all contexts.
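For reference, the MVEE of a half-ellipsoid produced by a cut through the center has a standard closed form (the central-cut ellipsoid update). The sketch below implements that textbook formula; it is our illustration of the update step and is not copied from the paper's algorithm, which may include additional safeguards.

```python
import numpy as np

def ellipsoid_cut(u, Q, a):
    """Update ellipsoid {x: (x-u)^T Q^{-1} (x-u) <= 1} after the cut a^T (x - u) >= 0.

    u: (D,) center, Q: (D, D) shape matrix, a: (D,) constraint direction, D >= 2.
    Returns the center and shape matrix of the MVEE of the kept half-ellipsoid.
    """
    D = len(u)
    Qa = Q @ a
    b = Qa / np.sqrt(a @ Qa)                 # step direction, normalized in the Q-metric
    u_new = u + b / (D + 1)
    Q_new = (D**2 / (D**2 - 1.0)) * (Q - (2.0 / (D + 1)) * np.outer(b, b))
    return u_new, Q_new
```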

Lemma 3.3

(Boyd & Barratt (1991)) If \(B \subseteq \mathbb {R}^D\) is an ellipsoid with center w, and \(x\in \mathbb {R}^D \backslash \{0\}\), we define \(B^+ = \text {MVEE}(\{ \theta \in B: (\theta -w)^Tx \ge 0 \})\), then: \(\frac{Vol(B^+)}{Vol(B)} \le e^{-\frac{1}{2(D+1)}} .\)

Theorem 3.3 below shows that this algorithm achieves a polynomial upper bound on the number of sub-optimal time-steps. The proof, found in Appendix B, is adapted from Amin et al. (2017) to the contextual setup.


Theorem 3.3

In the linear setting where \(R^*_c(s) = c^TW^*\phi (s)\), for an agent acting according to Algorithm 1, the number of rounds in which the agent is not \(\epsilon \)-optimal is \({\mathcal {O}}(d^2k^2\log (\frac{d k}{(1-\gamma )\epsilon }))\).

Remark 3.2

Note that the ellipsoid method presents a new learning framework, where demonstrations are only provided when the agent performs sub-optimally. Thus, the theoretical results in this section cannot be directly compared with those of the descent methods. We further discuss this in Appendix D.2.1.

Remark 3.3

The ellipsoid method does not require a distribution over contexts - an adversary may choose them. MDA can also be easily extended to the adversarial setting via known regret bounds on online MDA (Hazan 2016).

3.2.1 Practical ellipsoid algorithm

In real-world scenarios, it may be impossible for the expert to evaluate the value of the agent’s policy, i.e., to check whether \(V^*_{c_t} - V^{{{\hat{\pi }}}_t}_{c_t} > \epsilon \), and to provide its policy or feature expectations \(\mu ^*_{c_t}\). To address these issues, we follow Amin et al. (2017) and consider a relaxed approach, in which the expert evaluates each of the individual actions performed by the agent rather than its policy (Algorithm 3). When a sub-optimal action is chosen, the expert provides finite roll-outs instead of its policy or feature expectations. We define the expert criterion for providing a demonstration to be \(Q^*_{c_t}(s,a) + \epsilon < V^*_{c_t}(s)\) for each state-action pair (s, a) in the agent’s trajectory.


Near-optimal experts: In addition, we relax the optimality requirement of the expert and instead assume that, for each context \(c_t\), the expert acts optimally w.r.t. a mapping \(W^*_t\) which is close to \(W^*\); the expert also evaluates the agent w.r.t. this mapping. This allows the agent to learn from different experts, and from non-stationary experts whose judgment and performance vary slightly over time. If a sub-optimal action w.r.t. \(W^*_t\) is played at state s, the expert provides a roll-out of H steps from s to the agent. As this roll-out is a sample of the optimal policy w.r.t. \(W^*_t\), we aggregate n examples to ensure that, with high probability, the linear constraint that we use in the ellipsoid algorithm does not exclude \(W^*\) from the feasibility set. Note that these batches may be constructed across different contexts, different experts, and different states from which the demonstrations start. Theorem 3.4, proven in Appendix B, upper bounds the number of sub-optimal actions that Algorithm 3 chooses.

Theorem 3.4

For an agent acting according to Algorithm 3, \(H=\lceil \frac{1}{1-\gamma }\log (\frac{8k}{(1-\gamma )\epsilon }) \rceil \) and \(n=\lceil \frac{512k^2}{(1-\gamma )^2\epsilon ^2}\log (4dk(dk+1)\log (\frac{16k\sqrt{dk}}{(1-\gamma )\epsilon })/\delta ) \rceil \), with probability of at least \(1-\delta \), if \(\forall t: {\underline{W}}^*_t \in B_\infty ({\underline{W}}^*,\frac{(1-\gamma )\epsilon }{8k}) \cap \varTheta _0\) the number of rounds in which a sub-optimal action is played is \({\mathcal {O}}\Big (\frac{d^2k^4}{(1-\gamma )^2\epsilon ^2}\log \big (\frac{dk}{(1-\gamma )\delta \epsilon }\log (\frac{dk}{(1-\gamma )\epsilon })\big )\Big )\,\).

The theoretical guarantees of the algorithms presented so far are summarized in Table 1. We can see that MDA, in particular EW, achieves the best scalability. In the unrealistic case where the expert can provide its feature expectations, the ellipsoid method has the lowest sample complexity. However, in the realistic scenario where only samples are provided, the sample complexity is identical across all methods. We also note that, unlike MDA and ES, the ellipsoid method cannot be extended to work with DNNs. Overall, the theoretical guarantees favor the MDA methods in the realistic setting.

Table 1 Summary of theoretical guarantees

3.3 Existing approaches

We focus our comparisons to methods that can be used for zero-shot generalization across contexts or tasks. Hence, we omit discussion of “meta inverse reinforcement learning” methods which focus on few-shot generalization (Xu et al. 2018). Our focus is on two approaches: (1) standard IRL methods applied to a model which incorporates the context as part of the state, and (2) contextual policies through behavioral cloning (BC) (Pomerleau 1989).

3.3.1 Application of IRL to COIRL problems

We first examine the straightforward approach of incorporating the contextual information into the state, i.e., defining \({\mathcal {S}}' = {\mathcal {C}} \times {\mathcal {S}}\), and applying standard IRL methods to one environment which captures all contexts. This construction limits the context space to a finite one, whereas COIRL naturally handles an infinite number of contexts. At first glance, this method results in the same scalability and sample complexity as COIRL; however, when considering the inner loop in which an optimal policy is calculated, COIRL has the advantage of a state space smaller by a factor of \(|{\mathcal {C}}|\). This results in significantly better run-time when considering large context spaces. In Sect. 4.1, we present experiments that evaluate the run-time of this approach, compared to COIRL, for increasingly large context spaces. These results demonstrate that the run-time of IRL scales with \(|{\mathcal {C}}|\) while the run-time of COIRL is unaffected by \(|{\mathcal {C}}|\), making COIRL much more practical for environments with many or infinitely many contexts.

3.3.2 Contextual policies

Another possible approach is to use Behavioral Cloning (BC) to learn contextual policies, i.e., policies that are functions of both state and context, \(\pi (c,s)\). In BC, the policy is learned using supervised learning methods, skipping the step of learning the reward function. While BC is an intuitive method, with successful applications in various domains (Bojarski et al. 2016; Ratliff et al. 2007), it has a fundamental flaw: BC violates the i.i.d. assumptions of supervised learning methods, as the learned policy affects the distribution of states it encounters. This results in a covariate shift at test time, leading to compounding errors (Ross & Bagnell 2010; Ross et al. 2011). Methods presented in Ross et al. (2011) and Laskey et al. (2017) mitigate this issue but operate outside of the offline framework. This explains why BC compares unfavorably to IRL methods, especially with a limited number of available demonstrations (Ho & Ermon 2016; Ghasemipour et al. 2019). In Sect. 4.4.2, we provide experimental results that exhibit the same trend. These results demonstrate how matching actions on the train set translates poorly to value on the test set, until much of the expert policy is observed. While a single trajectory per context suffices for COIRL, BC requires more information to avoid encountering unfamiliar states. We also provide a hardness result for learning a contextual policy for a linear separator hypothesis class, further demonstrating the challenges of this approach.

3.4 Transfer across contexts in test-time

In this section, we examine the application of the learned mapping W when encountering a new, unseen context in test-time. Unlike during training, in test-time the available resources and latency requirements may render re-solving the MDP for every new context infeasible. We address this issue by leveraging optimal policies \(\{\pi ^*_{c_j}\}_{j=1}^{N}\) for contexts \(\{c_j\}_{j=1}^{N}\) which were previously calculated during training or test time. We separately handle context-independent dynamics and contextual dynamics by utilizing (1) generalized policy improvement (GPI) (Barreto et al. 2017), and (2) the simulation lemma (Kearns & Singh 2002), respectively.

For context-independent dynamics, the framework of Barreto et al. (2017) can be applied to efficiently transfer knowledge from previously observed contexts \(\{c_j\}_{j=1}^{N}\) to a new context c. As the policies \(\{\pi ^*_{c_j}\}_{j=1}^{N}\) were computed, so were their feature expectations, starting from any state. As the dynamics are context-independent, these feature expectations are also valid for c, enabling fast computation of the corresponding Q-functions, thanks to the linear decomposition of the reward. GPI generalizes policy improvement, allowing us to use these Q-functions to create a new policy that is as good as any of them and potentially strictly better than them all. The following theorem, a parallel of Theorem 2 in Barreto et al. (2017), defines the GPI calculation and provides the lower bound on its value. While these theorems and their proofs are written for \(W^*\), the results hold for any \(W \in {\mathcal {W}}\).

Theorem 3.5

(Barreto et al. (2017)) Let \(\phi _{max} = \max _s ||W^*\phi (s)||_1\), \(\{c_j\}_{j=1}^N \subseteq {\mathcal {C}}\), \(c \in {\mathcal {C}}\), and \(\pi (s) \in \arg \max _a \max _j Q^{\pi ^*_{c_j}}_{c}(s,a)\). If the dynamics are context independent, then:

$$\begin{aligned} V^*_{c} - V^\pi _{c} \le 2 \frac{\phi _{max}}{1-\gamma }\min _j||c-c_j||_\infty . \end{aligned}$$
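A sketch of the GPI policy construction is given below. It assumes the per-state-action discounted feature sums (successor features) of the stored policies are kept in an array psi; that array, and the function name gpi_policy, are our own illustration rather than anything defined in the text.

```python
import numpy as np

def gpi_policy(c, W, psi):
    """Generalized policy improvement over N previously solved contexts.

    c:   (d,)          new context
    W:   (d, k)        learned (or true) mapping
    psi: (N, S, A, k)  successor features of the stored policies pi*_{c_j}
    """
    w = W.T @ c                              # reward weights for the new context
    Q = psi @ w                              # (N, S, A): Q^{pi_j}_c(s, a)
    return Q.max(axis=0).argmax(axis=1)      # pi(s) = argmax_a max_j Q[j, s, a]
```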

When the dynamics are a function of the context, the feature expectations calculated for \(\{c_j\}_{j=1}^N\) are not valid for c, thus GPI cannot be used efficiently. However, due to the linearity, and therefore continuity, of the mapping, similar contexts induce similar environments. Thus, it is intuitive that if we know the optimal policy for a context, it should transfer well to nearby contexts without additional planning. This intuition is formalized in the simulation lemma, which is used to provide bounds on the performance of a transferred policy in the following theorem.

Theorem 3.6

Let \(c,c_j\in {\mathcal {C}}, \phi _{max} = \max _s ||W^*\phi (s)||_1\), \(V_{max} = \max _{c,s} |V^*_c(s)| \). Then:

$$\begin{aligned} V^*_{c} - V^{\pi ^*_{c_j}}_{c} \le 2 \frac{\phi _{max} + \gamma d V_{max}}{\gamma (1-\gamma )}||c-c_j||_\infty . \end{aligned}$$

Remark 3.4

The bound depends on \({\mathcal {W}}\). For example, for \({\mathcal {W}}=\varDelta _{dk-1}\), the bound is \(2 \frac{1 - \gamma + \gamma d}{\gamma (1-\gamma )^2}||c-c_j||_\infty \), and for \({\mathcal {W}}=B_\infty (0,1)\) the bound is \(\frac{2dk}{\gamma (1-\gamma )^2}||c-c_j||_\infty \).

Remark 3.5

If the dynamics are independent of the context, the term \(\gamma d V_{max}\) is omitted from the bound.

Using these methods, one can efficiently find a good policy for a new context c, either as a good starting point for policy/value iteration which will converge faster or as the final policy to be used in test-time. The last thing to consider is the construction of the set \(\{c_j\}_{j=1}^{N}\). As COIRL requires computing the optimal policies for W during training, the training contexts are a natural set to use. In addition, as suggested in Barreto et al. (2017), we may reduce this set or enhance it in a way that maintains a covering radius in \({\mathcal {C}}\) and guarantees a desired level of performance. If the above methods are used as initializations for calculating the optimal policy, the set can be updated in test-time as well.

4 Experiments

In the previous sections we described the theoretical COIRL problem, proposed methods to solve it and analyzed them. In this section our goal is to take COIRL from theory to practice. We present the process and the guidelines we follow to achieve this goal in a step-by-step manner, bridging the gap between theoretical and real-life problems through a set of experiments.

We begin by focusing on the grid world and autonomous driving simulation environments. As these are relatively small domains, for which we can easily compute the optimal policy, they provide accessible insight into the behavior of each method and allow us to eliminate methods that are less likely to work well in practice. Then we use the sepsis treatment simulator in a series of experiments to test and adjust the methods towards real-life application. The simulator is constructed from real-world data in accordance with the theoretical assumptions of COIRL. Throughout the experiments we gradually strip these assumptions from the simulator and show that the methods perform well in an offline setting. Furthermore, we show that a DNN estimator achieves high performance when the mapping from the context to the reward is non-linear.

Finally, we test the methods in sepsis treatment – without the simulator. Here, we use real clinicians’ trajectories for training and testing. For COIRL, we estimate a CMDP\(\backslash \)M model from the training data (states and dynamics), which is used for training purposes. We then show that COIRL achieves high action matching on unseen clinicians’ trajectories.

4.1 Grid world

Fig. 4: Run-time comparison between COIRL and AL. AL run-time grows as the number of contexts grows, while COIRL run-time stays fixed.

The grid world domain is an n-by-m grid, yielding \(|{\mathcal {S}}|=n\cdot m\) states. The actions are \({\mathcal {A}}=\{left,up,right,down\}\) and the dynamics are deterministic for each action, i.e., if the action taken is up, the next state will be the state above the current state in the grid (with cyclic transitions on the borders, e.g., taking the action right at state \((n-1, y)\) will transition to (0, y)). The features are one-hot vectors (\(\phi (s_i)=e_i \in \mathbb {R}^{n\cdot m}\)). The contexts correspond to “preferences” for certain states on the grid, and are sampled from a uniform distribution over the \(n\cdot m\) dimensional simplex.
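A sketch of how such a cyclic grid world can be constructed is shown below; this is our own illustrative construction (not the authors' code), and the sign convention for up/down is arbitrary.

```python
import numpy as np

def grid_world(n, m):
    """Deterministic n-by-m cyclic grid: states are cells, features are one-hot."""
    S, A = n * m, 4                                  # actions: left, up, right, down
    moves = [(-1, 0), (0, -1), (1, 0), (0, 1)]       # (dx, dy) per action
    P = np.zeros((S, A, S))
    for s in range(S):
        x, y = s % n, s // n
        for a, (dx, dy) in enumerate(moves):
            s_next = (x + dx) % n + ((y + dy) % m) * n   # cyclic transitions
            P[s, a, s_next] = 1.0
    phi = np.eye(S)                                  # one-hot features, k = n * m
    return P, phi
```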

This domain is used to evaluate the application of IRL to COIRL problems. We compare the performance of PSGD (COIRL) and the projection algorithm (AL) of Abbeel & Ng (2004) as a function of the context space size. This framework is applied to a \(3\times 4\) grid, 12 states overall. The PSGD method trains on a CMDP model and the projection algorithm trains on a large MDP model, with a state space that includes the contexts, as noted in Sect. 3.3.1. The new states are \(s' = (s,c)\), and the new features are \(\phi (s')=c\odot \phi (s)\). We measure the run-time of every iteration. The most time-consuming part of both methods is computing the optimal policy for a given reward. Both methods use the same implementation of value iteration to enable a fair run-time comparison.

The results in Fig. 4 show that the projection algorithm in the large MDP requires significantly more time to run as the number of contexts grows, while the run-time of PSGD is not affected by the number of contexts.

4.2 Conclusion

Applying IRL methods in a large MDP environment limits the number of contexts that can be used and, as seen in the results, its run-time grows as the number of contexts increases. We conclude that applying IRL to COIRL problems is inefficient and exclude this method from the following experiments (Sect. 4.2 through Sect. 4.4).

4.3 Autonomous driving simulation

Fig. 5: An illustration of the driving simulator.

While the grid world focused on comparing COIRL with the standard IRL method, in this section we compare the various methods for performing COIRL in an autonomous driving simulator (Fig. 5). This domain involves a three-lane highway with two visible cars, cars A and B. The agent, controlling car A, can drive both on the highway and off-road. Car B drives in a fixed lane, at a slower speed than car A. Upon leaving the frame, car B is replaced by a new car, appearing in a random lane at the top of the screen. The features denote the speed of car A, whether or not it is on the road, and whether it has collided with car B. The context implies different priorities for the agent: should it prefer speed or safety? Is going off-road a valid option? For example, an ambulance will prioritize speed and may drive off-road, as long as it goes fast and avoids collisions, while a bus will prioritize avoiding both collisions and off-road driving, as safety is its primary concern. The mapping from the contexts to the true reward is constructed in a way that induces different behaviors for different contexts, making generalization a challenging task.

4.3.1 Ellipsoid setting

The ellipsoid method requires its own framework. Here, the agent’s policy is evaluated by an expert for every new context revealed. Only if its value is not \(\epsilon \)-close to the optimal value does the expert provide a demonstration (the feature expectations of an expert policy for the revealed context). While the ellipsoid method can only perform a single update for each demonstration, the descent methods can utilize all of the previously revealed demonstrations and perform update steps until convergence. We measure the accumulated number of expert demonstrations given at each time-step, and the value of the agent on a holdout test set for each new given demonstration.

The number of given demonstrations is important in the ellipsoid framework, as it is equal to the number of times the agent is not \(\epsilon \)-close to the optimal value. In addition, it is a way to measure how much intervention is required from an external expert. We would expect a ‘good’ method to be \(\epsilon \)-optimal for most revealed contexts, and therefore it should observe a small number of demonstrations.

The results, presented in Fig. 6, show that all methods eventually reach the expert’s value; however, the descent methods are more sample efficient than the ellipsoid method and require fewer expert demonstrations. While according to the theoretical guarantees (Table 1, feature expectations setting) the ellipsoid method should have better sample complexity, in practice it is surpassed by the descent methods. Note that in this experiment each demonstration may be used more than once by the descent methods, hence the theoretical results do not directly apply to them.

Fig. 6: Comparison of the ellipsoid method with the ES and PSGD methods in the autonomous driving simulation. The graph on the left compares the number of demonstrations required by each method, while the graph on the right compares the performance at each time-step. We observe that while, as theoretically shown, all methods eventually find an \(\epsilon \)-optimal solution, the descent methods attain better sample efficiency (converge faster and require less expert interaction).

4.3.2 Online setting

Here, we compare the descent methods presented in Sect. 3 in an online setting. Each descent step is performed on a context-\(\mu \) pair, where the context is sampled uniformly from the simplex and \(\mu \) is the feature expectations of a policy that is optimal for this context. For each method, we measure the normalized value of the proposed policies with respect to the real reward, the loss (Eq. (8)), and the accuracy, which represents how often the expert and agent policies match. These criteria are evaluated on a holdout set of contexts, unseen during training. The x-axis corresponds to the number of contexts seen during training, i.e., the number of subgradient steps taken.

In this setting we use two setups, which differ in the observed feature expectations. First, in the feature expectations setup, we assume that the whole optimal policy can be observed; therefore, for training we use the feature expectations of the expert’s policy. The results are shown in Fig. 7. They show a strong correlation between ‘loss minimization’ and ‘value maximization’. EW converges faster than PSGD, and the ES method consistently lies between EW and PSGD, displaying comparable sample complexity. These results match the theoretical guarantees (Table 1, feature expectations), as EW has tighter bounds in terms of scalability compared to PSGD and ES.

Fig. 7: Online learning curve in the autonomous driving simulation—learning from feature expectations. The expert demonstrations are provided in the form of the feature expectations of the expert’s policy. We compare the loss, value and accuracy, where the value and accuracy are relative to the expert’s behavior. As can be seen, all descent methods minimize the loss and achieve high value. Additionally, we observe that while they do attain relatively high accuracy, they find policies which are optimal yet differ from the expert in the actions taken.

The second setup we use is the trajectories setup. Here we construct the feature expectations using a finite number of samples taken from the expert’s policy; each context corresponds to a finite rollout of the expert (motivated by real-life limitations). The results in Fig. 8 show that all three descent methods attain high value and accuracy in this setup. As in the feature expectations setting, the results validate the theoretical sample complexity, with the exception that ES performs slightly better than PSGD. Comparing the results of the different setups, we observe similar performance for training with the whole expert policy or with a sample of it, as expected (Sect. 3.1, practical MDA). Training with trajectories is closer to the data available in real-life applications, since only samples of policies are provided.

Fig. 8
figure 8

Online learning curve in the autonomous driving simulation—learning from trajectories. While in Fig. 7 the demonstrations were provided in the form of feature expectations, here we provide trajectories, a less informative form of demonstration. Nevertheless, we observe that, similarly to Fig. 7, all methods perform well, attaining performance comparable to that obtained with full information

4.3.3 Conclusion

The ellipsoid method is not as sample efficient as the descent methods. Furthermore, it demands constant expert monitoring, which might be unavailable in real-world problems. In many real-world tasks, such as the sepsis treatment domain, there is an abundance of offline data, yet real-life evaluation may not be available. Thus, we do not include experiments with the ellipsoid method in the sepsis treatment domain.

The ES and EW methods also have their drawbacks: ES requires evaluating the loss function at a large number of points for every descent step, which makes it computationally expensive and prevents it from scaling to larger environments. The EW method assumes that the model parameters lie within the simplex, an assumption that limits the policy space in the linear case and may not hold in the non-linear case, where the mapping between the context and the reward is modeled by a neural network. As such, we do not include these methods in the sepsis treatment domain.

4.4 Sepsis treatment simulator

This domain simulates a decision-making process for treating sepsis. Sepsis is a severe, life-threatening infection, where the treatment applied to a patient is crucial for saving their life. To create a sepsis treatment simulator, we leverage the MIMIC-III data set (Johnson et al. 2016). This data set includes data from hospital electronic databases, social security, and archives from critical care information systems, acquired during routine hospital care. We follow the data processing steps taken in Jeter et al. (2019) to extract the relevant data in the form of normalized measurements of sepsis patients during their hospital admission, together with the treatments given to each patient. The measurements include dynamic measures, e.g., heart rate, blood pressure, weight, body temperature, and standard blood analysis measures (glucose, albumin, platelet count, minerals, etc.), as well as static measures such as age, gender, re-admission (of the patient), and more.

From the processed data we construct a dynamic treatment regime, modeled as a CMDP, in which a clinician acts to improve a sick patient’s medical condition. The context represents patient features that are constant during treatment, such as age and height. The state summarizes dynamic measurements of the patient, e.g., blood pressure and EEG readouts. The actions represent different combinations of fluids and vasopressors, drugs commonly provided to restore and maintain blood pressure in sepsis patients. The mapping from the context to the true reward is constructed from the data. Dynamic treatment regimes are particularly useful for managing chronic disorders and fit well into the broader paradigm of personalized medicine (Komorowski et al. 2018; Prasad et al. 2017). Furthermore, dynamic treatment regimes have contextual properties; what is defined as healthy blood pressure for a patient differs based on age and weight (Wesselink et al. 2018). In our setting, \(W^*\) captures this information – mapping from contextual (e.g., age) and dynamic information (e.g., blood pressure) to reward.

As noted in previous sections, we move toward real-life applications and exclude the inefficient methods. In this section we evaluate the PSGD method and compare it with GPI (Sect. 3.4) and contextual BC (Sect. 3.3.2).

4.4.1 Online setting

In this setting we evaluate only the PSGD method. Similarly to the autonomous driving simulation, we use two setups: (1) we train with the expert’s feature expectations for each context, and (2) we replace the expert’s feature expectations with an estimate computed from a single expert trajectory for each context (Sect. 3.1, practical MDA). We present the results of both setups in the same figure to allow a direct comparison.

We observe in Fig. 9 that PSGD performs well in both setups, with slightly better performance when given feature expectations, as expected. This supports the theory, as using samples should not affect the convergence results and truncation after 40 steps should incur only a small penalty. An important observation is that high accuracy is not necessary for high value, as our agents achieve near-perfect value with relatively low accuracy. This reinforces the use of IRL for imitation learning tasks, as it supports the claim that the reward function, not the policy, is the most concise and accurate representation of the task (Abbeel & Ng 2004).

Fig. 9
figure 9

Online setting in sepsis treatment. We compare the relative value and accuracy when the agent is provided with the feature expectations or with finite-length trajectories. We observe that, since the feature expectations are more informative, the performance with them is slightly better. However, the difference is negligible, amounting to less than 0.5% in relative value

4.4.2 Offline setting

Here, we evaluate the COIRL, GPI and contextual BC methods. We test the ability of these methods to generalize with a limited amount of data. The motivation for this experiment comes from real-life applications, where the available data is often limited in size. The data, similarly to the online setting, is constructed from context-trajectory pairs. In this setting we minimize the loss function (Eq. (8)) by taking descent steps on mini-batches sampled from the data set, with repetition, which invalidates the theoretical results. We conduct two experiments that evaluate the performance as a function of the train-set size (the number of context-trajectory pairs used for training). We consider two mappings from the context to the reward: a linear mapping and a non-linear mapping. For the non-linear mapping we use a DNN estimator, which constitutes another step towards real-world applicability.
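The sketch below illustrates such an offline training loop under the stated setup (mini-batches sampled with repetition). The helper functions and the averaging of per-pair subgradients are assumptions for illustration, not the authors’ exact implementation.

```python
import numpy as np

def offline_coirl(W0, dataset, subgrad_fn, project_fn, lr_schedule, n_steps,
                  batch_size=32, seed=0):
    """Offline COIRL training loop (illustrative sketch).

    dataset is a list of (context, estimated feature expectations) pairs built
    from the context-trajectory data. Mini-batches are sampled *with repetition*,
    the per-pair subgradients are averaged, and a projected step is taken.
    subgrad_fn and project_fn are assumed helpers (e.g. the outer-product
    subgradient and projection from the online sketch above).
    """
    rng = np.random.default_rng(seed)
    W = W0.copy()
    for t in range(n_steps):
        idx = rng.choice(len(dataset), size=batch_size, replace=True)
        g = np.mean([subgrad_fn(W, *dataset[i]) for i in idx], axis=0)
        W = project_fn(W - lr_schedule(t) * g)
    return W
```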

Remark 4.1

Contextual BC is a method to learn a contextual policy, instead of a contextual reward. In its implementation we use a DNN that, given a context and state-features, computes a probability vector, \({\hat{\pi }}_c(s)\), representing the agent’s policy – i.e., the probability of taking action \(a\in {\mathcal {A}}\) is the a’th element of the DNN output, \({\hat{\pi }}_c(s, a)\). The state-features given as input greatly affect BC performance, especially when compared to COIRL, which uses the real dynamics as well as features that represent each state. BC can make good use of the dynamics, as states with similar dynamics should be mapped to similar actions. To improve BC’s performance, we use the same state-features that COIRL uses (heart rate, blood pressure, etc.), in addition to a feature-vector that represents the dynamics. For each state, \(s\in {\mathcal {S}}\), the dynamics can be represented as a concatenation of the probability vectors, \(\big \{P(s,a)\big \}_{a\in {\mathcal {A}}}\), where \(P(s,a)[i]=P(s,a,s_i)\). The dimension of this representation is \(|{\mathcal {S}}|\cdot |{\mathcal {A}}|\) per state, which is relatively large in the sepsis treatment simulator; hence we reduce its dimensionality with PCA.
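A minimal sketch of how such dynamics features could be constructed is given below; the function name, the use of scikit-learn’s PCA, and the choice of the number of components are illustrative assumptions rather than the paper’s implementation.

```python
import numpy as np
from sklearn.decomposition import PCA

def bc_dynamics_features(P, n_components):
    """Per-state dynamics features for contextual BC (illustrative sketch).

    P has shape (|S|, |A|, |S|). For each state s, the rows P(s, a, .) over all
    actions are concatenated into one vector of length |S|*|A|, and the
    dimensionality is then reduced with PCA, as described in Remark 4.1.
    The number of components is a free choice here.
    """
    n_states = P.shape[0]
    flat = P.reshape(n_states, -1)      # row s = concatenation of P(s, a, .) over actions
    return PCA(n_components=n_components).fit_transform(flat)
```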

Fig. 10
figure 10

Offline setting in sepsis treatment. The x-axis denotes the number of contexts in the training set. Results on the train data are shown as circles and x’s; results on a holdout test set are shown as lines. Given a sufficient number of observed contexts, GPI is comparable to re-solving the domain, hence the large overlap between the results of GPI and COIRL. Contextual BC requires much more data to generalize well

In Fig. 10 we compare the performance of COIRL, GPI and contextual BC in the linear setting, when provided with a fixed amount of data. The results show that in the sepsis treatment domain, the COIRL and GPI methods perform similarly and are able to generalize well from a small amount of training data, in contrast to contextual BC. As expected, in Fig. 10(b) BC attains better accuracy on the train data, while in Fig. 10(a) the COIRL and GPI methods attain better value on the train data. Another observation is that COIRL achieves similar performance on the training data and on the test data; it is able to generalize to unseen contexts even when the amount of training data is small. On the other hand, BC achieves almost perfect accuracy and high value on the train data but performs poorly on the test data. This generalization gap closes only when a large amount of training data is available.

Fig. 11
figure 11

Offline setting in sepsis treatment: non-linear mapping. The x-axis denotes the number of contexts in the training set (logarithmic scale). Results on the train data are shown as circles and x’s; results on a holdout test set are shown as lines. Similar to the linear setup, GPI and COIRL generalize well from a small amount of training data, with similar performance on the train and test data. Contextual BC attains almost perfect performance on the train set, while good performance on the test data requires a large number of expert demonstrations

The results of the non-linear setup are presented in Fig. 11. Here, the x-axis is on a logarithmic scale. The performance of all methods is similar to the linear setup: the COIRL and GPI methods perform similarly and generalize to unseen contexts even when given a small amount of training data, while contextual BC generalizes to unseen contexts only when given a large amount of training data. As in the linear setup, the BC method attains better accuracy while the COIRL and GPI methods attain better value.

4.5 Sepsis treatment in real-life

In the previous subsections we focused on analyzing COIRL in simulated environments. We have taken a sequence of steps with the aim of making the simulations increasingly realistic. In all of these simulations, the expert trajectories were always generated from the optimal policy (for a given context) w.r.t. the true context-reward mapping. Our results suggest that the reward estimated by COIRL induces a policy that attains close-to-expert value in both the linear and non-linear settings. We now turn to examine our algorithms on a real-world data set. Since the true mapping is no longer known, we can only measure the accuracy of the resulting policies. In previous sections we observed that while accuracy does not necessarily imply value (i.e., a policy can have optimal value but not be \(100\%\) accurate), these measures are often correlated. In addition, since the true dynamics of the MDP are now unknown, we estimate them from the data itself.

4.5.1 Data processing

We follow the steps of Komorowski et al. (2018) to construct time-series data of static and dynamic measurements. The data is divided into trajectories, where each trajectory represents a different patient. We consider only trajectories of length greater than 10, i.e., spanning more than 40 hours. The processed data consists of 14,591 trajectories, divided into a 60-20-20 train-validation-test partition. Each trajectory corresponds to a vector of static measurements and a time series of dynamic measurement vectors, with time steps of 4 hours. In the following experiments each model is trained on the training set until an early stopping criterion is met on the validation set. We then report the accuracy (action matching with the clinicians’ actions) on the holdout test set.

4.5.2 Model fitting

As in Sect. 4.4, the contexts and the states are constructed from static and dynamic measurements, respectively. In our model, the contexts are in \(\mathbb {R}^7\) and include gender, age, weight, GCS, Elixhauser comorbidity score, whether the patient was mechanically ventilated at \(s_0\), and whether the patient has been re-admitted to the hospital. The actions are defined as the amount of vasopressors given to a patient in each time slot, and five discrete actions are constructed by dividing the possible values into five bins. The state space is constructed by clustering the observed dynamic patient measurements with K-means (MacQueen et al. 1967). The clustering process is repeated for different numbers of states and different weights for each measurement (to control the importance of each measurement in the state space). Each model is evaluated by two terms: (1) the number of different actions taken in the same state for the same patient: \(\mathbb {E}_\tau \big [\mathbb {E}_{s\in \tau }[ |\hat{{\mathcal {A}}}_s^{\tau }|]\big ]\), where \(\hat{{\mathcal {A}}}_s^{\tau }=\{a\in {\mathcal {A}} : (s,a) \in \tau \}\); (2) the number of different states in each trajectory: \(\mathbb {E}_\tau \big [ |\hat{{\mathcal {S}}}^\tau | \big ]\), where \(\hat{{\mathcal {S}}}^\tau = \{s\in {\mathcal {S}} : s \in \tau \}\). In both terms, \(\tau \) is a trajectory drawn from the data. We require the first term to be as small as possible, to obtain consistent expert policies in the CMDP model, and the second term to be large, to force the resulting model to distinguish between different states within the same patient’s trajectory. In addition, the model should be as small as possible, to enable generalization. The chosen model consists of 5,000 states.
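The two selection terms can be computed directly from the clustered trajectories; the sketch below assumes each trajectory is available as a list of (state, action) pairs and is intended only to make the definitions concrete.

```python
from collections import defaultdict

def model_selection_terms(trajectories):
    """Compute the two model-selection terms defined above (illustrative sketch).

    Each trajectory is assumed to be a list of (state, action) pairs after
    clustering. Term (1): the average, over trajectories and over the states
    visited in each trajectory, of the number of distinct actions taken in that
    state. Term (2): the average number of distinct states per trajectory.
    """
    per_traj_action_counts, per_traj_state_counts = [], []
    for tau in trajectories:
        actions_in_state = defaultdict(set)
        for s, a in tau:
            actions_in_state[s].add(a)
        per_traj_action_counts.append(
            sum(len(acts) for acts in actions_in_state.values()) / len(actions_in_state))
        per_traj_state_counts.append(len(actions_in_state))
    term1 = sum(per_traj_action_counts) / len(per_traj_action_counts)  # want small
    term2 = sum(per_traj_state_counts) / len(per_traj_state_counts)    # want large
    return term1, term2
```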

While processing the data, we noticed that the clinicians’ behavior with respect to some measurements appears random. To address this, we consulted with clinicians and defined a set of important dynamic measurements; from this set, we used the clustering process to select the dynamic measurements most relevant for the states. States were clustered using each single measurement in turn, and the five best dynamic measurements were chosen: mean blood pressure, diastolic blood pressure, shock index, cumulative fluid balance, and the fluids given to the patient. The features in this CMDP are action-dependent and are set to be a concatenation of \(e_i \in \mathbb {R}^{|{\mathcal {S}}|}\) and \(e_j \in \mathbb {R}^{|{\mathcal {A}}|}\), where \(e_i\) is a one-hot vector representing the state and \(e_j\) is a one-hot vector representing the action; overall there are 5,005 features for each state-action pair.

As described in Sect. 2, learning the transition kernel is a problem orthogonal to COIRL, and can be viewed as part of the model fitting process. Our dynamics model is context-dependent: the contexts (patients) are clustered into five clusters, and the dynamics of each cluster are then estimated from the training data.
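A minimal sketch of such a context-clustered dynamics estimator is shown below, assuming trajectories are given as (s, a, s') transitions over the clustered state space; the sparse dictionary representation and the use of scikit-learn’s KMeans are illustrative choices rather than the authors’ implementation.

```python
from collections import Counter, defaultdict
from sklearn.cluster import KMeans

def estimate_context_dynamics(contexts, trajectories, n_clusters=5, seed=0):
    """Context-dependent transition model (illustrative sketch).

    contexts holds one static-feature vector per trajectory; trajectories are
    lists of (s, a, s') transitions over the clustered state space. Contexts are
    grouped into n_clusters clusters and an empirical kernel is estimated per
    cluster by counting transitions (stored sparsely, since |S| is large).
    """
    labels = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10).fit_predict(contexts)
    counts = defaultdict(Counter)                      # (cluster, s, a) -> counts over s'
    for k, tau in zip(labels, trajectories):
        for s, a, s_next in tau:
            counts[(k, s, a)][s_next] += 1
    P = {key: {s_next: n / sum(cnt.values()) for s_next, n in cnt.items()}
         for key, cnt in counts.items()}               # normalized transition probabilities
    return P, labels
```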

4.5.3 Methods

For COIRL we report results for both the linear and the non-linear mappings. In both setups, we use a discount factor \(\gamma =0.7\) and a mini-batch of size 32. The stopping criterion halts training when five consecutive steps do not increase the validation accuracy. To speed up validation, we sample a subset of 300 patients from the validation partition at the beginning of each seed and use it to validate the model. In the linear setup the step size is \(\alpha _t = 0.25\cdot 0.95^t\). The non-linear setup uses a DNN to learn the mapping \(f_W:{\mathcal {C}}\longrightarrow \mathbb {R}^{|S|+|A|}=\mathbb {R}^{5,005} \approx \mathbb {R}^{5K}\). It has four linear layers of sizes 20K, 10K, 10K, and 5K, respectively, with a Leaky ReLU activation and batch normalization between the first and second layers, and a Leaky ReLU activation between the second and third layers. Here, the step size is \(\alpha _t = 0.2\cdot 0.95^t\).
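Under these stated sizes and activation placements, the non-linear mapping could be implemented as in the following PyTorch sketch. The exact layer ordering is our reading of the description above; initialization, regularization and other training details are not specified in the paper.

```python
import torch
import torch.nn as nn

class ContextToRewardNet(nn.Module):
    """Non-linear context-to-reward mapping f_W (illustrative sketch).

    Layer sizes follow the text (20K, 10K, 10K, 5K ~= 5,005); the placement of
    the batch normalization and Leaky ReLU activations is our reading of the
    description, not the authors' released code.
    """
    def __init__(self, context_dim=7, sizes=(20000, 10000, 10000, 5005)):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(context_dim, sizes[0]),
            nn.BatchNorm1d(sizes[0]),      # batch-norm + Leaky ReLU between layers 1 and 2
            nn.LeakyReLU(),
            nn.Linear(sizes[0], sizes[1]),
            nn.LeakyReLU(),                # Leaky ReLU between layers 2 and 3
            nn.Linear(sizes[1], sizes[2]),
            nn.Linear(sizes[2], sizes[3]), # output: one reward weight per state/action feature
        )

    def forward(self, context):            # context: (batch, 7)
        return self.net(context)           # (batch, 5005) estimated reward vector
```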

Table 2 Results on real world data. We measure the accuracy of each method over a holdout test set. In the non-linear setting, COIRL achieves the best accuracy and outperforms BC

For BC we also use a DNN for function approximation, as we found it to work much better than a linear model. We also experimented with different sets of input features. The features that gave the best performance were computed in a similar manner to the BC features in Remark 4.1, using the dynamics of the estimated CMDP, resulting in 5K features that represent each state. Concretely, the DNN receives a concatenation of the context and the features representing the current state (5,007 values in total) and outputs a stochastic policy (a softmax over the outputs of the last layer). The network architecture is composed of three linear layers of sizes 625, 125, and 5, respectively. Each layer is followed by a Leaky ReLU activation, and a Softmax activation is used on the output. Similar to COIRL, the model is trained on the training partition, and the stopping criterion halts training after 5 epochs of non-increasing validation accuracy. The loss of the DNN is the binary cross-entropy between the DNN output and the one-hot encoding of the observed action, \(e_i \in \mathbb {R}^{|{\mathcal {A}}|}\). The mini-batch size is 32 and the optimizer is SGD with step size \(\alpha _t = 0.1\cdot \frac{1}{1+10^{-7}t}\).
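A PyTorch sketch of this BC network, following the textual description (layer sizes, Leaky ReLU activations, and a softmax output), is given below; it is an illustration under those stated assumptions rather than the authors’ released code.

```python
import torch
import torch.nn as nn

class ContextualBCNet(nn.Module):
    """Contextual BC policy network (illustrative sketch, following the text).

    Input: concatenation of the context and the 5K state features (5,007 values);
    three linear layers of sizes 625, 125 and 5, each followed by a Leaky ReLU,
    with a softmax over the 5 discrete actions at the output.
    """
    def __init__(self, input_dim=5007, n_actions=5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 625), nn.LeakyReLU(),
            nn.Linear(625, 125), nn.LeakyReLU(),
            nn.Linear(125, n_actions), nn.LeakyReLU(),
        )

    def forward(self, x):
        return torch.softmax(self.net(x), dim=-1)   # stochastic policy over the actions
```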

Each method is trained and evaluated over five seeds, and the results are presented in Table 2. We can see that COIRL with a non-linear mapping attains the best performance, while the linear mapping achieves poor accuracy. BC performs well overall, but not as well as COIRL. In Lee et al. (2019) the authors use a similar data set and action space. Their methods, TRIL and DSFN, achieve \(80 \pm 2\%\) and \(79 \pm 5\%,\) respectively, which is lower accuracy than COIRL, with higher variance. These results suggest that the contextual hypothesis better represents the real world, i.e., that physicians indeed use personalized treatments based on context.

5 Discussion

Motivated by current trends in personalized medicine (Juskalian et al. 2020), we proposed the Contextual Inverse Reinforcement Learning framework. While most works in RL assume the agent is provided with a reward signal, we focused on a more realistic setting, where the reward is unknown to the agent and, instead, it observes and receives feedback from an expert. As opposed to the standard IRL setting, in the contextual case each context defines a new MDP. This leads to a new form of generalization in RL, where the agent learns how to act optimally on a set of training contexts and is then evaluated on a set of contexts to which it was not exposed during training.

We show that the COIRL objective can be solved by minimizing a convex function. As this objective is not differentiable, we proposed two schemes based on subgradient descent (MDA and ES) and an adaptation of cutting plane methods (the ellipsoid method). We analyzed the convergence properties of each algorithm and summarized the results in Table 1.

All of the proposed methods assume that the dynamics are known, but in many applications the dynamics and even the state space are unknown. Following the description in Sect. 2, any method that learns the dynamics efficiently can be used prior to COIRL. For example, in online frameworks, where the expert provides demonstrations in an online manner, the dynamics can be learned as proposed in Abbeel & Ng (2005). In this case, the dynamics estimation and COIRL should run iteratively, such that every change in the estimation of the dynamics introduces a new COIRL problem that should be solved. In offline frameworks the dynamics can be estimated prior to COIRL, similarly to Sect. 4.5.

In addition to the theoretical analysis, we performed an extensive empirical evaluation of all proposed algorithms, including baseline approaches. Here, we see a mixed correlation between the theoretical and practical results. Regarding the ellipsoid schemes, we observe that, as shown theoretically, they are indeed sub-optimal compared to the other methods. However, comparing MDA to ES, we see that ES matches and sometimes outperforms MDA even though the theoretical upper bounds are tighter for MDA. These results are consistent with the observations of Nemirovsky & Yudin (1983), where ES often provides better empirical results.

Aside from comparing our proposed methods to one another, we also compared them to a common learning scheme—behavioral cloning. While IRL aims to find a reward function which explains the expert’s behavior, behavioral cloning (log-likelihood action matching) simply converts the RL task into a supervised learning problem. Previous work (Abbeel & Ng 2004) discusses the importance of IRL compared to BC. In our experiments we see this clearly. While the reward/value is smooth (Lipschitz) w.r.t. the context, the policy is not. As a small change in the context may lead to a large switch in the policy (the optimal actions change in certain states), we observe that BC struggles. This can also be seen in the fact that COIRL often reaches imperfect action-matching (accuracy) yet attains the optimal return.

We demonstrated how existing policies can be transferred to new contexts, avoiding planning at test-time. This is important, as planning complexity grows with the size of the MDP; thus, this form of transfer may be crucial in real-world scenarios. Our experiments illustrate how combining offline COIRL with GPI eases the computational load on the agent while maintaining strong performance with few training examples.

Finally, COIRL achieved the highest accuracy in the challenging task of predicting the clinicians’ treatments in the real-world sepsis treatment data set. This suggests that sepsis treatment can be modeled as a contextual MDP; we hope that these findings will motivate future work on using contextual MDPs to model real-world decision making.

To conclude, we proposed the COIRL framework and analyzed it under a linear mapping assumption. In real-world cases, where the linear assumption holds, COIRL can be used effectively. Future work may combine COIRL with schemes such as meta-learning (Finn et al. 2017) in order to cope with infinitely large MDPs and non-linear mappings.