RL
Table of Contents
- 1. Markov Decision Process (MDP)
- 2. Bellman Equation
- 3. Dynamic Programming Iterations in RL
- 4. Monte Carlo Methods in RL
- 5. Temporal Difference
- 6. Policy Gradient
- 7. Actor-Critic
- 8. PPO
- 9. LLM Decouple Actor and Critic
- 10. GRPO
- 11. MoE
- 12. Multi-token prediction
- 13. Multi-token training
1. Markov Decision Process (MDP)
- State: the set of states \( S \)
- at a time point, how the world looks, across all of its parts.
- State space: at a time point, all the ways the world could look
- Action: the set of actions \( \mathcal{A}(s) \) is associated with state \( s \in S \)
- at a time point, what I do for each part.
- Action space: at a time point, everything I could do
- Reward: the set of rewards \( \mathcal{R}(s, a) \)
- Trajectory
- episode, return: the sum of rewards discounted by \(\gamma\) over one episode (trajectory)
- State transition probability:
- state transition: at a time point, after I take one action, what the world changes into.
- at a time point, after I take one action, the ways the world could change.
- At state \( s \), taking action \( a \), the probability to transit to state \( s' \) is \( p(s' \mid s, a) \)
- Reward probability: At state \( s \), taking action \( a \), the probability to get reward \( r \) is \( p(r \mid s, a) \)
- Policy
- deterministic & stochastic
- At state \( s \), the probability to choose action \( a \) is \( \pi(a \mid s) \)
Markov property (memoryless property): \[ p(s_{t+1} \mid a_{t+1}, s_t, \ldots, a_1, s_0) = p(s_{t+1} \mid a_{t+1}, s_t) \]
\[ p(r_{t+1} \mid a_{t+1}, s_t, \ldots, a_1, s_0) = p(r_{t+1} \mid a_{t+1}, s_t) \]
2. Bellman Equation
- Return: the sum of all discounted rewards along one complete trajectory
- State Value: the expectation of the Return over all possible trajectories \[ V_{\pi}(s) = \sum_{a} \pi(a|s) \sum_{s', r} p(s', r | s, a) \left[ r + \gamma V_{\pi}(s') \right] \] Relationship to action value: \[ V_{\pi}(s) = \sum_{a \in \mathcal{A}} \pi(a|s) Q_{\pi}(s, a) \]
- Action Value: the expectation of the Return over all possible trajectories after taking a specified action \[ Q_{\pi}(s, a) = \sum_{s', r} p(s', r | s, a) \left[ r + \gamma \sum_{a'} \pi(a'|s') Q_{\pi}(s', a') \right] \] Relationship to state value: \[ Q_{\pi}(s, a) = \sum_{s', r} p(s', r | s, a) \left[ r + \gamma V_{\pi}(s') \right] \]
- Bellman Optimality Equation: the Bellman equation evaluated under the optimal (best) policy. A small numeric policy-evaluation sketch follows.
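As a concrete check of the state-value equation above, here is a minimal sketch of iterative policy evaluation on a made-up two-state MDP (the transition probabilities, rewards, discount factor, and policy are all illustrative assumptions):

```python
# Minimal sketch: iterative policy evaluation on a toy 2-state MDP.
# The MDP, policy, and gamma below are made-up illustration values.
gamma = 0.9

# p[s][a] = list of (probability, next_state, reward)
p = {
    0: {"stay": [(1.0, 0, 0.0)], "go": [(0.8, 1, 1.0), (0.2, 0, 0.0)]},
    1: {"stay": [(1.0, 1, 2.0)], "go": [(1.0, 0, 0.0)]},
}
# pi[s][a] = probability of choosing action a in state s
pi = {0: {"stay": 0.5, "go": 0.5}, 1: {"stay": 0.9, "go": 0.1}}

V = {s: 0.0 for s in p}                      # initial guess v_0 = 0
for _ in range(200):                         # repeat the Bellman expectation update
    V = {
        s: sum(
            pi[s][a] * sum(prob * (r + gamma * V[s2]) for prob, s2, r in p[s][a])
            for a in p[s]
        )
        for s in p
    }
print(V)  # converges to v_pi for this toy policy
```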
3. Dynamic Programming Iterations in RL
3.1. Value Iteration
Value Iteration combines policy improvement and a single step of policy evaluation into one operation.
- Initialization: Start with an arbitrary initial value for all states (e.g., \(v_0 = 0\)).
- Iteration:
- Implicit Policy Update: For every state, evaluate all possible actions (the Q-table) and look one step ahead to find the best action. \[ \pi_{k+1} = \arg\max_{\pi} \left( r_{\pi} + \gamma P_{\pi} v_k \right) \]
- Value Update: Use that best action to update the state value immediately. \[ v_{k+1} = r_{\pi_{k+1}} + \gamma P_{\pi_{k+1}} v_k \]
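A minimal sketch of these two fused steps on a made-up two-state MDP (all numbers are illustrative assumptions):

```python
# Minimal sketch of value iteration on a toy MDP; all numbers are illustrative.
gamma = 0.9
p = {
    0: {"stay": [(1.0, 0, 0.0)], "go": [(0.8, 1, 1.0), (0.2, 0, 0.0)]},
    1: {"stay": [(1.0, 1, 2.0)], "go": [(1.0, 0, 0.0)]},
}

V = {s: 0.0 for s in p}
for _ in range(200):
    # implicit policy update + one-step value update, fused:
    # v_{k+1}(s) = max_a sum_{s',r} p(s',r|s,a) [r + gamma v_k(s')]
    V = {
        s: max(
            sum(prob * (r + gamma * V[s2]) for prob, s2, r in p[s][a])
            for a in p[s]
        )
        for s in p
    }
# greedy policy read off from the converged values
pi = {
    s: max(p[s], key=lambda a, s=s: sum(prob * (r + gamma * V[s2]) for prob, s2, r in p[s][a]))
    for s in p
}
print(V, pi)
```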
3.2. Policy Iteration
Policy Iteration separates the process into two distinct phases. Crucially, the Policy Evaluation phase is itself an iterative process that looks very similar to Value Iteration, but with a fixed policy.
3.2.1. Phase 1: Policy Evaluation (The Inner Loop)
This phase starts with a policy \(\pi\) and calculates its value \(v_{\pi}\).
- Input: The current fixed policy \(\pi_k\).
- Process: We iterate to find the value by performing "Value Updates" repeatedly.
- Initialize a value estimate \(v\) (e.g., \(v=0\)).
- Apply the Bellman Expectation Operator repeatedly (infinite steps): \[ v_{i+1} = r_{\pi_k} + \gamma P_{\pi_k} v_i \]
- Stop when \(v\) converges (i.e., \(i \to \infty\)).
- Output: The converged state-value function \(v_{\pi_k}\).
3.2.2. Phase 2: Policy Improvement (The Outer Loop)
Once we know exactly how good the current policy \(\pi_k\) is (from Phase 1), we make it greedy.
- Process: \[ \pi_{k+1} = \arg\max_{\pi} \left( r_{\pi} + \gamma P_{\pi} v_{\pi_k} \right) \]
3.3. Truncated Policy Iteration
The mathematical connection is defined by the depth of the evaluation step (the number of iterations \(j\) in Phase 1):
- Value Iteration is a special case of Truncated Policy Iteration where the evaluation depth is \(j=1\).
- Policy Iteration is the limit of Truncated Policy Iteration where the evaluation depth \(j \to \infty\).
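A minimal sketch that makes this connection explicit: \(j\) evaluation sweeps per improvement step, with \(j = 1\) recovering value iteration and large \(j\) approaching policy iteration. The toy MDP and \(j\) below are illustrative assumptions:

```python
# Sketch of truncated policy iteration: j evaluation sweeps per improvement step.
# j = 1 recovers value iteration; j -> infinity recovers policy iteration.
gamma, j = 0.9, 5
p = {
    0: {"stay": [(1.0, 0, 0.0)], "go": [(0.8, 1, 1.0), (0.2, 0, 0.0)]},
    1: {"stay": [(1.0, 1, 2.0)], "go": [(1.0, 0, 0.0)]},
}

def q(s, a, V):
    return sum(prob * (r + gamma * V[s2]) for prob, s2, r in p[s][a])

V = {s: 0.0 for s in p}
pi = {s: next(iter(p[s])) for s in p}          # arbitrary initial policy
for _ in range(100):
    for _ in range(j):                         # truncated policy evaluation
        V = {s: q(s, pi[s], V) for s in p}
    pi = {s: max(p[s], key=lambda a, s=s: q(s, a, V)) for s in p}   # greedy improvement
print(V, pi)
```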
4. Monte Carlo Methods in RL
Model-free: without knowing the model, we estimate the action value from sampled (Monte Carlo) returns: \(q_{\pi_k}(s, a) = \mathbb{E}[G_{t} \mid S_{t}=s, A_{t}=a]\). The overall process works the same as Policy Iteration (policy evaluation followed by policy improvement); only the evaluation step is replaced by Monte Carlo sampling.
4.1. MC Basic (The Theoretical Baseline)
This is the simplest form, adapted directly from Policy Iteration logic but using sample returns instead of models.
- Theory:
- Operates in strict distinct steps: Policy Evaluation (wait for many episodes) \(\rightarrow\) Policy Improvement (update policy).
- Uses the Initial-Visit Strategy: Only the start of an episode is used to update the value of the starting state-action pair \((s_0, a_0)\). Intermediate steps in the episode are ignored for updates.
- Limitation:
- Low Sample Efficiency: Wastes data by ignoring subsequent state visits in an episode.
- Impractical: Requires collecting "sufficiently many episodes" for every state-action pair before making a single policy update.
4.2. MC Exploring Starts (The Efficient Simulator)
An extension designed to fix sample efficiency but introduces a strong dependency on environment control.
- Theory:
- Data Efficiency: Uses the Every-Visit Strategy (or sub-episode decomposition). One long episode is broken down into multiple "sub-episodes" to update values for every state visited, not just the start.
- Episode-by-Episode Update: Updates the policy immediately after a single episode (Generalized Policy Iteration), rather than waiting for a batch.
- Limitation:
- The "Exploring Starts" Assumption: It requires the environment to be able to start an episode at any random state-action pair \((s, a)\).
- Real-world Friction: This is often impossible in physical reality (e.g., you cannot initialize a robot in a specific "falling over" state instantly).
4.3. MC Epsilon-Greedy (The Practical Solution)
This method removes the unrealistic "Exploring Starts" assumption, making MC methods viable for real-world learning where you cannot control the starting state.
4.3.1. Core Concept: Soft Policies
Instead of forcing the environment to start randomly, we force the agent to behave randomly occasionally.
- Goal: Ensure all state-action pairs are visited "sufficiently many times" without external resets.
- Mechanism: Uses a Soft Policy ( \(\epsilon\) -greedy), meaning there is always a non-zero probability of taking any action in any state.
4.3.2. Exploration vs. Exploitation
The algorithm balances these two conflicting goals via the parameter \(\epsilon\) (epsilon).
- Exploration ( The \(\epsilon\) component)
- Purpose: To discover new strategies and ensure the agent does not get stuck in a suboptimal loop. By taking random actions, the agent eventually wanders into every possible state, fulfilling the coverage requirement that "Exploring Starts" used to handle.
- Mechanism: With probability \(\epsilon\), the agent ignores its knowledge and chooses randomly.
- Exploitation (The \(1-\epsilon\) component)
- Purpose: To maximize rewards based on current knowledge. The agent selects the "Greedy" action (the one with the highest estimated value \(q(s,a)\)).
- Mechanism: With probability \(1 - \epsilon\), the agent chooses the best known action.
- Continuous Learning: Even if the agent always starts at the same spot (Start Line), the stochastic nature of the policy (\(\epsilon\)) ensures that over many episodes, it will eventually drift into unvisited states.
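A minimal sketch of MC \(\epsilon\)-greedy control on a made-up five-state corridor (the environment, \(\epsilon\), and the step size are illustrative assumptions):

```python
import random
from collections import defaultdict

# Toy corridor: states 0..4, start at 2, actions -1/+1; reaching state 4 gives
# reward 1, state 0 gives 0. Every episode starts at the same state, yet the
# epsilon-greedy policy still visits every state-action pair over time.
gamma, eps, alpha = 0.9, 0.1, 0.1
actions = [-1, +1]
Q = defaultdict(float)

def epsilon_greedy(s):
    if random.random() < eps:                      # explore with probability eps
        return random.choice(actions)
    return max(actions, key=lambda a: Q[(s, a)])   # exploit otherwise

for _ in range(2000):
    s, episode = 2, []
    while s not in (0, 4):                         # generate one episode
        a = epsilon_greedy(s)
        s2 = s + a
        r = 1.0 if s2 == 4 else 0.0
        episode.append((s, a, r))
        s = s2
    G = 0.0
    for s, a, r in reversed(episode):              # every-visit MC update
        G = r + gamma * G
        Q[(s, a)] += alpha * (G - Q[(s, a)])

print({s: max(actions, key=lambda a: Q[(s, a)]) for s in (1, 2, 3)})
```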
5. Temporal Difference
Core concept: TD learning solves the Bellman Expectation Equation for the current policy by formulating it as a root-finding problem and applying the Robbins-Monro stochastic approximation algorithm.
5.1. The Mathematical Driver
- The Goal: Bellman Expectation Equation We seek \(V_{\pi}(s)\) such that: \[V_{\pi}(s) = \mathbb{E}_{\pi} [ R_{t+1} + \gamma V_{\pi}(S_{t+1}) \mid S_t = s ]\]
- The Problem: Root Finding Define the objective function (Bellman Error) \(J(V)\), where we want the root \(J(V)=0\): \[J(V)(s) = \mathbb{E} [ R_{t+1} + \gamma V(S_{t+1}) ] - V(s) = 0\]
- The Solver: Robbins-Monro Algorithm
Iteratively find root \(\theta^*\) of \(f(\theta)=0\) using noisy observations \(\tilde{f}(\theta)\): \[\theta_{t+1} = \theta_t + \alpha_t \cdot \tilde{f}(\theta_t)\]
- If \(\theta_t > \theta^*\), the noisy observation \(\tilde{f}(\theta_t)\) is negative on average (for a decreasing \(f\) such as the Bellman error), so the update pushes \(\theta\) down.
- If \(\theta_t < \theta^*\), it is positive on average, so the update pushes \(\theta\) up.
5.2. The Derivation
- The Noisy Observation (TD Error)
Since we cannot compute the expectation \(\mathbb{E}\), we take a single sample:
- Observation: \(\tilde{J}(V_t) = (r_{t+1} + \gamma V_t(s_{t+1})) - V_t(s_t)\)
- This is the TD Error (\(\delta_t\)).
- The Sample Target
The term \(r_{t+1} + \gamma V_t(s_{t+1})\) acts as the "Label":
- Reality: \(r_{t+1}\) (Ground truth, low variance).
- Guess: \(\gamma V_t(s_{t+1})\) (Bootstrap estimate).
- Function: It acts as the target \(y\) for \(V(s_t)\), providing a better estimate than \(V(s_t)\) alone because it includes real data.
- The Solution (TD Update Rule) Substituting \(\tilde{f}(\theta)\) into the Robbins-Monro update: \[V(s_t) \leftarrow V(s_t) + \alpha \delta_t\] \[V(s_t) \leftarrow V(s_t) + \alpha \underbrace{[ (r_{t+1} + \gamma V(s_{t+1})) - V(s_t) ]}_{\text{TD Error}}\] \[V(s_t) \leftarrow V(s_t) - \alpha [ V(s_t) - \underbrace{(r_{t+1} + \gamma V(s_{t+1}))}_{\text{Target}} ]\]
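A minimal sketch of this tabular TD(0) update; the sampled transition below is a placeholder standing in for real environment interaction:

```python
# Minimal sketch of the tabular TD(0) update rule derived above.
alpha, gamma = 0.1, 0.9
V = {s: 0.0 for s in range(5)}

def td0_update(s_t, r_next, s_next):
    td_error = (r_next + gamma * V[s_next]) - V[s_t]   # delta_t
    V[s_t] += alpha * td_error                         # V(s_t) <- V(s_t) + alpha * delta_t
    return td_error

td0_update(s_t=2, r_next=1.0, s_next=3)   # hypothetical sample (s_t, r_{t+1}, s_{t+1})
print(V[2])   # moved a fraction alpha toward the target r + gamma V(s')
```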
5.3. Unified View of TD-series
5.3.1. The General Update Rule
When we use the action value \(q(s, a)\) instead of the state value for TD learning, all of the algorithms discussed here can be expressed as a stochastic approximation update solving a Bellman equation. The unified update rule (gradient-descent form) is:
\[q_{t+1}(s_t, a_t) = q_t(s_t, a_t) - \alpha_t(s_t, a_t) [ q_t(s_t, a_t) - \bar{q}_t ]\]
- \(q_t(s_t, a_t)\): Current Estimate
- \(\alpha_t\): Learning rate (Step size)
- \(\bar{q}_t\): The Target (The noisy sample derived from environment interaction)
- The term \([q_t - \bar{q}_t]\) represents the error to minimize.
5.3.2. Algorithm Specifics
- Sarsa:
- Target Expression (\(\bar{q}_t\)): \[\bar{q}_t = r_{t+1} + \gamma q_t(s_{t+1}, a_{t+1})\]
- Equation Aimed to Solve: Bellman Expectation Equation (BE) for \(q_{\pi}\): \[q_{\pi}(s, a) = \mathbb{E} [R_{t+1} + \gamma q_{\pi}(S_{t+1}, A_{t+1}) \mid S_t = s, A_t = a]\]
- Note: This is an on-policy method using the action \(a_{t+1}\) actually taken by the current policy.
- Q-learning
- Target Expression (\(\bar{q}_t\)): \[\bar{q}_t = r_{t+1} + \gamma \max_{a} q_t(s_{t+1}, a)\]
- Equation Aimed to Solve: Bellman Optimality Equation (BOE) for \(q_*\): \[q_{*}(s, a) = \mathbb{E} [R_{t+1} + \gamma \max_{a'} q_{*}(S_{t+1}, a') \mid S_t = s, A_t = a]\]
- Note: This is an off-policy method because it updates towards the best possible action (\(\max\)), regardless of the policy actually followed.
- Expected Sarsa
- Target Expression (\(\bar{q}_t\)): \[\bar{q}_t = r_{t+1} + \gamma \sum_{a} \pi_t(a|s_{t+1})q_t(s_{t+1}, a)\]
- Equation Aimed to Solve: Bellman Expectation Equation (BE) for \(q_{\pi}\): \[q_{\pi}(s, a) = \mathbb{E} [R_{t+1} + \gamma \mathbb{E}_{A_{t+1}}[q_{\pi}(S_{t+1}, A_{t+1})] \mid S_t = s, A_t = a]\]
- Note: It reduces variance by taking the expectation over all possible next actions rather than just sampling one.
- N-step Sarsa
- Target Expression (\(\bar{q}_t\)): \[\bar{q}_t = r_{t+1} + \gamma r_{t+2} + \dots + \gamma^{n-1} r_{t+n} + \gamma^n q_t(s_{t+n}, a_{t+n})\]
- Equation Aimed to Solve: Bellman Expectation Equation (BE) for \(q_{\pi}\): \[q_{\pi}(s, a) = \mathbb{E} [R_{t+1} + \gamma R_{t+2} + \dots + \gamma^{n-1} R_{t+n} + \gamma^n q_{\pi}(S_{t+n}, A_{t+n}) \mid S_t = s, A_t = a]\]
- Note: It balances bias and variance by looking \(n\) steps ahead before bootstrapping.
- Monte Carlo (MC)
- Target Expression (\(\bar{q}_t\)): \[\bar{q}_t = r_{t+1} + \gamma r_{t+2} + \dots \text{ (Full Return } G_t)\]
- Equation Aimed to Solve: Bellman Expectation Equation (BE) for \(q_{\pi}\): \[q_{\pi}(s, a) = \mathbb{E} [R_{t+1} + \gamma R_{t+2} + \dots \mid S_t = s, A_t = a]\]
- Note: Can be viewed as the unified expression where \(\alpha_t(s_t, a_t) = 1\), making \(q_{t+1} = \bar{q}_t\) (direct assignment of the return).
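A small sketch of the per-algorithm targets above, written as plain functions over a tabular \(q\) and a policy \(\pi\) (the tables and the sampled transition are illustrative assumptions):

```python
# Targets q_bar_t for the TD family, plus the unified update.
gamma = 0.9

def sarsa_target(r, s2, a2, Q):
    return r + gamma * Q[(s2, a2)]

def q_learning_target(r, s2, Q, actions):
    return r + gamma * max(Q[(s2, a)] for a in actions)

def expected_sarsa_target(r, s2, Q, pi, actions):
    return r + gamma * sum(pi[(s2, a)] * Q[(s2, a)] for a in actions)

def unified_update(Q, s, a, target, alpha=0.1):
    # q_{t+1}(s,a) = q_t(s,a) - alpha [ q_t(s,a) - target ]
    Q[(s, a)] -= alpha * (Q[(s, a)] - target)

# Hypothetical usage on a tiny tabular Q and uniform policy.
actions = [0, 1]
Q = {(s, a): 0.0 for s in range(3) for a in actions}
pi = {(s, a): 0.5 for s in range(3) for a in actions}
unified_update(Q, s=0, a=1, target=q_learning_target(r=1.0, s2=1, Q=Q, actions=actions))
```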
5.4. Deep Q-learning
- Deep Q-learning replaces the tabular \(q(s,a)\) with a parameterized neural network \(\hat{q}(s, a, w)\).
- We want the neural network to satisfy the Bellman Optimality Equation: \[q_*(s, a) = \mathbb{E} [R_{t+1} + \gamma \max_{a'} q_*(S_{t+1}, a') \mid S_t=s, A_t=a]\]
- It aims to minimize the loss function \(J(w)\): \[J(w) = \mathbb{E} \left[ \left( \underbrace{R + \gamma \max_{a' \in \mathcal{A}(S')} \hat{q}(S', a', w)}_{\text{Target (Bellman Optimality)}} - \underbrace{\hat{q}(S, A, w)}_{\text{Prediction}} \right)^2 \right]\]
where the squared term is the TD error (in its Q-learning form): \[\delta = (R + \gamma \max_{a'} \hat{q}(S', a', w)) - \hat{q}(S, A, w)\] By minimizing this distance (error) between the Prediction and the Target, the network \(\hat{q}\) converges towards the optimal value function \(q_*\).
- Algorithm
- Sample: Uniformly draw a mini-batch of samples from the replay buffer \(\mathcal{B}\).
- Calculate Targets:
For each sample \((s, a, r, s')\) in the mini-batch, calculate the target value \(y_T\):
\[y_T = r + \gamma \max_{a \in \mathcal{A}(s')} \hat{q}(s', a, w_T)\]
- Where \(w_T\) is the parameter of the target network.
- Update Main Network:
Update the main network parameter \(w\) to minimize the loss:
\[Loss = (y_T - \hat{q}(s, a, w))^2\]
- This update uses the mini-batch data \(\{(s, a, y_T)\}\).
- Update Target Network: Set \(w_T = w\) every \(C\) iterations.
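A minimal PyTorch-style sketch of one update step, assuming illustrative network sizes and a fake mini-batch standing in for samples drawn from the replay buffer:

```python
import torch
import torch.nn as nn

state_dim, n_actions, gamma = 4, 2, 0.99
q_net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))
target_net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))
target_net.load_state_dict(q_net.state_dict())        # w_T <- w
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)

# Fake mini-batch (s, a, r, s', done) standing in for samples from B.
s  = torch.randn(32, state_dim)
a  = torch.randint(0, n_actions, (32,))
r  = torch.randn(32)
s2 = torch.randn(32, state_dim)
done = torch.zeros(32)

with torch.no_grad():                                  # target uses the frozen target network
    y = r + gamma * (1 - done) * target_net(s2).max(dim=1).values
q_sa = q_net(s).gather(1, a.unsqueeze(1)).squeeze(1)   # q_hat(s, a, w)
loss = nn.functional.mse_loss(q_sa, y)                 # (y_T - q_hat)^2

optimizer.zero_grad()
loss.backward()
optimizer.step()
# Every C updates: target_net.load_state_dict(q_net.state_dict())
```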
6. Policy Gradient
6.1. Metrics
- \(\bar{v}_\pi\) (Discounted Average Value) \[\sum_{s \in S} d(s)\, v_\pi(s)\] \[\mathbb{E}_{S \sim d}[v_\pi(S)]\] \[\mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^{t} R_{t+1}\right]\]
- \(\bar{r}_\pi\) (Average Reward Objective) \[\sum_{s \in S} d_\pi(s)\, r_\pi(s)\] \[\mathbb{E}_{S \sim d_\pi}[r_\pi(S)]\] \[\lim_{n \to \infty} \frac{1}{n} \mathbb{E}\left[\sum_{t=0}^{n-1} R_{t+1}\right]\]
6.2. General objective function for metrics
The gradient of the objective function \(J(\theta)\) is given by the Policy Gradient Theorem: \[\nabla_\theta J(\theta) = \sum_{s \in S} \eta(s) \sum_{a \in A} \nabla_\theta \pi(a \mid s, \theta)\, q_\pi(s, a)\]
where:
- \(\eta(s)\) is the state distribution (discounted or stationary depending on metric)
- \(\nabla_\theta \pi(a \mid s, \theta)\) denotes the gradient of the policy \(\pi\) with respect to the parameters \(\theta\),
- \(q_\pi(s, a)\) is the action-value function.
- Note: The theorem proves equality, but in practice, ignoring the gradient of the state distribution leads to an approximation often called the "proportional" gradient.
Moreover, as an expectation:\[ \nabla_\theta J(\theta) = \mathbb{E}_{S \sim \eta,\; A \sim \pi(S,\theta)} \left[ \nabla_\theta \ln \pi(A \mid S, \theta)\, q_\pi(S, A) \right].\]
6.3. Policy parameters update
\[ \theta_{t+1} = \theta_t + \alpha \nabla_\theta J(\theta_t)\]
Using the expectation form of the policy gradient, this becomes:
\[\theta_{t+1} = \theta_t + \alpha \mathbb{E} [ \nabla_\theta \ln \pi(A \mid S, \theta_{t})\, q_{\pi}(S, A) ]\]
We do not have \(q_{\pi}(S, A)\), so we use a sampled estimate \(\hat{q}(S_t, A_t)\): \[\theta_{t+1} = \theta_t + \alpha \nabla_\theta \ln \pi(A_t \mid S_t, \theta_t)\, \hat{q}(S_t, A_t)\]
- Sampling from MC: REINFORCE
- Sampling from TD: Actor-Critic
6.4. Theory
\[\theta_{t+1} = \theta_t + \alpha \nabla_{\theta} \ln \pi(a_t \mid s_t, \theta_t) q_t(s_t, a_t)\] Because of the log-derivative trick: \[\nabla_{\theta} \ln \pi(a_t \mid s_t, \theta_t) = \frac{\nabla_{\theta} \pi(a_t \mid s_t, \theta_t)}{\pi(a_t \mid s_t, \theta_t)}\] We have: \[\theta_{t+1} = \theta_t + \alpha \frac{\nabla_{\theta} \pi(a_t \mid s_t, \theta_t)}{\pi(a_t \mid s_t, \theta_t)} q_t(s_t, a_t)\] So, defining \(\beta_t = \frac{q_t(s_t, a_t)}{\pi(a_t \mid s_t, \theta_t)}\): \[\theta_{t+1} = \theta_t + \alpha \beta_t \nabla_{\theta} \pi(a_t \mid s_t, \theta_t)\]
- If \(\beta_t \ge 0 \implies\) move in the direction of the gradient:
- \(\implies \pi(a_t | s_t, \theta_{t+1}) \ge \pi(a_t | s_t, \theta_t)\)
- Enhancement: reinforce good actions
- If \(\beta_t < 0 \implies\) move opposite to the gradient:
- \(\implies \pi(a_t | s_t, \theta_{t+1}) < \pi(a_t | s_t, \theta_t)\)
- Decrease: suppress bad actions
6.5. REINFORCE Algorithm
Monte Carlo Policy Gradient
- Initialize: \(\theta\), \(\gamma \in (0,1)\), \(\alpha > 0\)
Loop for each episode \(k\):
- Generate episode \(\{s_0, a_0, r_1, \dots, s_{T-1}, a_{T-1}, r_T\}\) following policy \(\pi(\cdot|\cdot, \theta_k)\)
- For \(t = 0\) to \(T-1\) (starting from \(\theta_0 = \theta_k\)):
- \(G_t \leftarrow \sum_{j=t+1}^{T} \gamma^{j-t-1} r_j\) (calculate the return)
- \(\theta_{t+1} \leftarrow \theta_{t} + \alpha \gamma^t G_t \nabla_{\theta} \ln \pi(a_t | s_t, \theta_{t})\)
- \(\theta_{k+1} \leftarrow \theta_{T}\) (use the updated parameters for the next episode)
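A minimal PyTorch-style REINFORCE sketch for one recorded episode; the policy network, the state encoding, and the episode data are illustrative assumptions:

```python
import torch
import torch.nn as nn

state_dim, n_actions, gamma, alpha = 4, 2, 0.99, 1e-3
policy = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))
optimizer = torch.optim.SGD(policy.parameters(), lr=alpha)

# Fake episode: states, actions, and rewards r_1..r_T gathered by running pi.
states  = [torch.randn(state_dim) for _ in range(5)]
actions = [torch.randint(0, n_actions, ()).item() for _ in range(5)]
rewards = [1.0, 0.0, 0.0, 1.0, 0.0]

# Compute discounted returns G_t backwards, then the policy-gradient loss.
G, returns = 0.0, []
for r in reversed(rewards):
    G = r + gamma * G
    returns.insert(0, G)

loss = 0.0
for t, (s, a, G_t) in enumerate(zip(states, actions, returns)):
    log_pi = torch.log_softmax(policy(s), dim=-1)[a]
    loss = loss - (gamma ** t) * G_t * log_pi      # ascend gamma^t G_t ln pi(a_t|s_t)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```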
7. Actor-Critic
7.1. Q-Actor-Critic (QAC)
7.1.1. Background
Derived from the Policy Gradient theorem, QAC replaces the Monte Carlo (MC) return with a Temporal Difference (TD) value function approximation. The policy parameters are updated using the gradient of the log-probability scaled by the action-value function: \[\theta_{t+1} = \theta_t + \alpha \nabla_{\theta} \ln \pi(a_t \mid s_t, \theta_t) q_t(s_t, a_t)\]
7.1.2. Initialization
- Policy Function: \(\pi(a|s, \theta_0)\) with initial parameters \(\theta_0\).
- Value Function: \(q(s, a, w_0)\) with initial parameters \(w_0\).
- Learning Rates: \(\alpha_w, \alpha_{\theta} > 0\).
- Goal: Maximize the expected return \(J(\theta)\).
7.1.3. Loop (for each time step \(t\) in episode):
- Generate Action: Sample \(a_t \sim \pi(a|s_t, \theta_t)\).
- Observe Reward: Get \(r_{t+1}\) and next state \(s_{t+1}\).
- Select Next Action: Sample \(a_{t+1} \sim \pi(a|s_{t+1}, \theta_t)\).
- Actor (Policy Update): Update the policy parameters in the direction of higher rewards: \[\theta_{t+1} = \theta_t + \alpha_{\theta} \nabla_{\theta} \ln \pi(a_t | s_t, \theta_t) q(s_t, a_t, w_t)\]
- Critic (Value Update): Update the action-value parameters using the semi-gradient TD error: \[w_{t+1} = w_t + \alpha_w \left[ r_{t+1} + \gamma q(s_{t+1}, a_{t+1}, w_t) - q(s_t, a_t, w_t) \right] \nabla_w q(s_t, a_t, w_t)\]
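A minimal PyTorch-style sketch of one QAC step, with an action-value critic \(q(s, a, w)\); the networks and the single transition are illustrative assumptions:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

state_dim, n_actions, gamma = 4, 2, 0.99
actor  = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))
critic = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))  # q(s, ., w)
opt_actor  = torch.optim.Adam(actor.parameters(),  lr=1e-3)
opt_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)

s, s2, r = torch.randn(state_dim), torch.randn(state_dim), 1.0   # fake transition
a  = Categorical(logits=actor(s)).sample()
a2 = Categorical(logits=actor(s2)).sample()        # next action from the same policy (on-policy)

q_sa = critic(s)[a]
actor_loss = -torch.log_softmax(actor(s), dim=-1)[a] * q_sa.detach()   # scale log pi by q(s_t,a_t,w)
td_error = r + gamma * critic(s2)[a2].detach() - q_sa
critic_loss = td_error ** 2                                            # semi-gradient TD update

opt_actor.zero_grad();  actor_loss.backward();  opt_actor.step()
opt_critic.zero_grad(); critic_loss.backward(); opt_critic.step()
```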
7.2. Advantage Actor-Critic
7.2.1. Add baseline
To improve the stability of the Policy Gradient, we introduce a baseline to the update function. Subtracting a state-value function \(v(s)\) from the return reduces variance without introducing bias. The standard gradient update is modified by subtracting the baseline \(v_t(s_t)\): \[ \theta_{t+1} = \theta_t + \alpha \nabla_{\theta} \ln \pi(a_t \mid s_t, \theta_t) [q_t(s_t, a_t) - v_t(s_t)] \]
7.2.2. The Advantage Function (\(\delta_t\))
We define the Advantage Function as the difference between the action-value and the state-value. However, since we often do not know \(q_t\) explicitly, we approximate it using the TD Error \(\delta_t\):
\[ \delta_t(s_t, a_t) = q_t(s_t, a_t) - v_t(s_t) \approx r_{t+1} + \gamma v_t(s_{t+1}) - v_t(s_t) \]
Substituting this back into our update rule gives us a more robust update: \[ \theta_{t+1} = \theta_t + \alpha \nabla_{\theta} \ln \pi(a_t \mid s_t, \theta_t) \delta_t(s_t, a_t) \]
Why use the Advantage Function? The advantage \(\delta_t\) effectively measures how much better an action \(a_t\) was compared to the "average" value of the state. This helps balance exploration and exploitation more effectively than raw returns, as the agent explicitly learns which actions outperform the expected baseline.
7.2.3. Algorithm
- Policy Function (Actor): \(\pi(a|s, \theta_0)\) with initial parameters \(\theta_0\).
- Value Function (Critic): \(v(s, w_0)\) with initial parameters \(w_0\).
- Learning Rates: \(\alpha_w, \alpha_{\theta} > 0\).
- Goal: Learn an optimal policy to maximize expected return \(J(\theta)\).
At time step \(t\) in each episode:
- Generate Action & Observe: Generate \(a_t\) following \(\pi(a|s_t, \theta_t)\), then observe reward \(r_{t+1}\) and next state \(s_{t+1}\).
- Calculate Advantage (TD Error): Compute the TD error using the Critic's current value estimates: \[ \delta_t = r_{t+1} + \gamma v(s_{t+1}, w_t) - v(s_t, w_t) \]
- Actor Update (Policy): Update the policy parameters to encourage actions with high advantage: \[ \theta_{t+1} = \theta_t + \alpha_{\theta} \delta_t \nabla_{\theta} \ln \pi(a_t | s_t, \theta_t) \]
- Critic Update (Value): Update the value function parameters to minimize the TD error: \[ w_{t+1} = w_t + \alpha_w \delta_t \nabla_w v(s_t, w_t) \]
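A minimal PyTorch-style sketch of one advantage actor-critic step; the networks and the single transition \((s_t, a_t, r_{t+1}, s_{t+1})\) are illustrative assumptions:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

state_dim, n_actions, gamma = 4, 2, 0.99
actor  = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))
critic = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, 1))
opt_actor  = torch.optim.Adam(actor.parameters(),  lr=1e-3)
opt_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)

s, s2, r = torch.randn(state_dim), torch.randn(state_dim), 1.0   # fake transition
a = Categorical(logits=actor(s)).sample()

v_s, v_s2 = critic(s).squeeze(), critic(s2).squeeze()
delta = (r + gamma * v_s2 - v_s).detach()               # TD error used as the advantage

log_pi = torch.log_softmax(actor(s), dim=-1)[a]
actor_loss  = -delta * log_pi                           # actor: reinforce high-advantage actions
critic_loss = (r + gamma * v_s2.detach() - v_s) ** 2    # critic: minimize the TD error

opt_actor.zero_grad();  actor_loss.backward();  opt_actor.step()
opt_critic.zero_grad(); critic_loss.backward(); opt_critic.step()
```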
7.3. Off-Policy Actor-Critic
We use a behavior policy \(\beta\) to generate experience samples. To estimate the gradient of the target policy \(\pi\), we must use Importance Sampling.
7.3.1. The Objective Gradient
\[ \nabla_{\theta} J(\theta) = \mathbb{E}_{S \sim \rho, A \sim \beta} \left[ \frac{\pi(A|S, \theta)}{\beta(A|S)} \nabla_{\theta} \ln \pi(A|S, \theta) q_{\pi}(S, A) \right] \]
7.3.2. The Update Rule
The update is similar to A2C but scaled by the importance sampling ratio \(\rho_t\):
\[ \theta_{t+1} = \theta_t + \alpha \underbrace{ \frac{\pi(a_t|s_t, \theta_t)}{\beta(a_t|s_t)} }_{\text{Importance Weight } \rho_t} \delta_t(s_t, a_t) \nabla_{\theta} \ln \pi(a_t \mid s_t, \theta_t) \]
Because of the log-derivative trick: \[ \theta_{t+1} = \theta_t + \alpha \frac{\delta_t(s_t, a_t)}{\beta(a_{t}|s_{t})} \nabla_{\theta} \pi(a_t \mid s_t, \theta_t) \]
The algorithm is otherwise the same as A2C; the only difference is the extra importance-sampling factor involving \(\beta(a_{t}|s_{t})\).
8. PPO
PPO uses the importance-sampling ratio \(r_t(\theta)\), with the old policy \(\pi_{\theta_{\text{old}}}\) playing the role of the behavior policy \(\beta\): \[r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\]
If \(\beta(a|s)\) is very small (the behavior policy rarely took the action) but \(\pi(a|s)\) is large (the new policy really wants to take it), the ratio \(r_t(\theta)\) becomes a huge number.
- This causes massive gradient updates.
- This destroys training stability (the "step" is too big).
PPO limits this potential explosion by taking the minimum of the unclipped and clipped objectives (with \(\delta_t\) the advantage estimate); see the sketch below: \[L^{CLIP} = \mathbb{E} \left[ \min( r_t \delta_t, \text{clip}(r_t, 1-\epsilon, 1+\epsilon) \delta_t ) \right]\]
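A minimal sketch of the clipped objective; the log-probabilities and advantages are placeholder tensors rather than real rollout data:

```python
import torch

eps = 0.2
logp_new = torch.randn(8, requires_grad=True)   # log pi_theta(a_t|s_t)
logp_old = torch.randn(8)                       # log pi_theta_old(a_t|s_t), fixed
adv      = torch.randn(8)                       # advantage estimates delta_t

ratio = torch.exp(logp_new - logp_old)          # r_t(theta)
clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
loss = -torch.min(ratio * adv, clipped * adv).mean()   # maximize L^CLIP = minimize its negative
loss.backward()
```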
9. LLM Decouple Actor and Critic
Unlike traditional RL (e.g., Atari) where state representation is efficiently shared, LLM-based RL often requires decoupling the Actor and Critic due to Negative Transfer and Optimization Divergence.
9.1. Negative Transfer (Objective Interference)
The Actor and Critic optimize fundamentally conflicting objectives, leading to Catastrophic Forgetting in shared architectures.
- Gradient Wash-Out: The Critic's gradients can "wash out" the delicate weights required for language generation.
- Manifold Collapse: When the model maximizes reward aggressively, it loses coherence, causing catastrophic forgetting of the pre-trained language manifold.
9.2. Optimization Divergence (Timescale Mismatch)
Shared weights prevent the necessary decoupling of learning dynamics between the two heads.
- The Critic (Fast Learner): Requires rapidly tracking non-stationary value targets (\(V(s)\) changes constantly as \(\pi\) evolves).
- The Actor (Slow Learner): Requires strict constraints (Trust Region/Clipping) to ensure monotonic improvement.
- The Conflict: In a shared body, you cannot tune optimizers independently. A learning rate high enough for the Critic destabilizes the Actor; a rate low enough for the Actor starves the Critic.
9.3. Phasic Policy Gradient
Phasic Policy Gradient (PPG): separate policy and value-function training into distinct phases.
10. GRPO
Group Relative Policy Optimization (GRPO) resolves these conflicts by eliminating the Critic network entirely.
By removing the Critic part, GRPO removes the source of the interference. There are no value-function gradients to clash with the language modeling objectives. Instead of a learned Value Function \(V(s)\) (which requires a separate network/optimizer), GRPO estimates the baseline using the mean reward of a group of outputs: \[ A_i = \frac{r_i - \text{mean}(r_{1..G})}{\text{std}(r_{1..G})} \]
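A minimal sketch of the group-relative advantage; the group rewards are made-up illustration values:

```python
import torch

# G sampled outputs for the same prompt are scored; each advantage is the
# reward standardized against the group mean and std.
rewards = torch.tensor([0.0, 1.0, 1.0, 0.0, 1.0])          # r_1..r_G for one group
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
print(advantages)   # positive for above-average outputs, negative for below-average
```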
11. MoE
Shared experts are used to capture common knowledge. To keep every routed expert trained (i.e., to balance the load implied by the router probabilities \(P\)):
- Switch Transformer: minimize an auxiliary loss that pushes the per-expert token fraction \(f_i\) and router probability \(P_i\) toward a uniform distribution: \(loss = \alpha \cdot N \cdot \sum_{i=1}^{N} f_{i} \cdot P_{i}\)
- Loss-free balancing: use a self-adjusting bias added to the routing scores (before the softmax) to control \(P\):
- if an expert receives too many tokens, decrease its bias;
- if an expert receives too few tokens, increase its bias.
- DeepSeek uses this bias term, applied before the activation function, to dynamically adjust the token load on each expert; see the sketch below.
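A minimal sketch of both balancing mechanisms with made-up router outputs (top-1 routing for simplicity; the bias step size is an illustrative assumption):

```python
import torch

N, T, alpha = 4, 64, 0.01                            # experts, tokens, aux-loss weight
logits = torch.randn(T, N)                           # fake router logits
P = torch.softmax(logits, dim=-1)                    # router probabilities
assignment = P.argmax(dim=-1)                        # which expert each token goes to

# (1) Switch-Transformer-style auxiliary loss: alpha * N * sum_i f_i * P_i
f = torch.bincount(assignment, minlength=N).float() / T   # fraction of tokens per expert
P_mean = P.mean(dim=0)                                     # mean router prob per expert
aux_loss = alpha * N * torch.sum(f * P_mean)

# (2) Loss-free balancing: a per-expert bias nudged against the current load,
#     added to the routing scores only for expert selection.
bias = torch.zeros(N)
bias = bias - 0.001 * torch.sign(f - 1.0 / N)        # overloaded -> lower bias, underloaded -> raise
biased_assignment = (logits + bias).argmax(dim=-1)
```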
12. Multi-token prediction
Predict multiple tokens per step; some draft tokens come from a small model (or the MTP heads). If the LLM accepts them, it does not need to generate them again (speculative decoding); see the sketch below.
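A minimal sketch of greedy draft-token verification; the drafted tokens and the target model's per-position greedy picks are hypothetical placeholders:

```python
# Keep the longest prefix of drafted tokens that agrees with what the target
# model would have produced at each position (greedy acceptance rule).
draft_tokens  = [11, 42, 7, 99]        # proposed by MTP heads / a small draft model
target_argmax = [11, 42, 5, 99]        # the big model's greedy pick at each position

accepted = []
for drafted, wanted in zip(draft_tokens, target_argmax):
    if drafted != wanted:              # first mismatch: stop, resample from the big model
        break
    accepted.append(drafted)           # matched tokens need no extra decoding steps
print(accepted)                        # -> [11, 42]
```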
13. Multi-token training
13.1. The Core Concept: MTP as an Implicit Critic
We reject GRPO (which requires generating groups for a baseline). Instead, we use Time as our baseline.
We use a single Transformer. The MTP Modules (which normally predict future tokens) are slightly modified to also predict the Value (Expected Future Reward) of those tokens.
- No Separate Critic Network: The MTP heads are the Critic.
- No Group Sampling: We do not compare against a group average. We compare against our own prediction from the next step (Bootstrapping).
- One Network: Parameters \(\theta\) are shared.
13.2. Architecture: The "Value-Aware" MTP Head
Standard DeepSeek MTP predicts tokens \(t_{n+1}, t_{n+2} \dots\). We modify the output projection of the MTP modules to output two things:
- Token Logits: \(P(t_{n+k})\) (What happens next?)
- Scalar Value: \(V_{n+k}\) (How good is it?)
Equation: \[ [ \text{Logits}_{k}, v_{k} ] = \text{MTP Head}_k(h_n) \] Where \(v_k\) represents the expected return starting from step \(n+k\).
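A minimal PyTorch-style sketch of such a head (the hidden size and vocabulary size are illustrative assumptions):

```python
import torch
import torch.nn as nn

d_model, vocab = 512, 1000

class ValueAwareMTPHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_proj = nn.Linear(d_model, vocab)   # token logits for step n+k
        self.value_proj = nn.Linear(d_model, 1)       # scalar value v_k

    def forward(self, h_n):
        return self.token_proj(h_n), self.value_proj(h_n).squeeze(-1)

head = ValueAwareMTPHead()
logits_k, v_k = head(torch.randn(2, d_model))   # batch of 2 hidden states
print(logits_k.shape, v_k.shape)                # torch.Size([2, 1000]) torch.Size([2])
```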
13.3. The Algorithm: Recursive Value Propagation
13.3.1. A. Forward Pass (Generation & Storage)
At time step \(t\), the network produces:
- Action: Sample \(a_t\) from the main policy head.
- Lookahead Values (Internal Value): The MTP heads produce value estimates for future steps:
- MTP Head 1 gives \(v^{(1)}_t\) (Estimate of \(V(s_{t+1})\))
- MTP Head 2 gives \(v^{(2)}_t\) (Estimate of \(V(s_{t+2})\))
We save these internal value estimates \(v^{(k)}_t\) into a buffer.
13.3.2. B. Interaction
Execute action \(a_t\). Observe Reward \(r_{t+1}\) and next state \(s_{t+1}\).
13.3.3. C. The "Next Step" Operation (Bootstrapping)
This is the key step: we use the value calculated at \(t+1\) to update the network at \(t\).
We define the TD Target (Temporal Difference): \[ Y_t = r_{t+1} + \gamma v^{(1)}_{t+1} \] Note: \(v^{(1)}_{t+1}\) is the "Value of the next state" predicted by the MTP head at the next step.
13.3.4. D. The Update (Loss Function)
We update the single network \(\theta\) with three components:
- Policy Loss (Actor): Maximize likelihood of \(a_t\) if the Advantage is positive. \[ \delta_t = Y_t - v^{(1)}_t \quad (\text{Advantage} = \text{Target} - \text{Prediction}) \] \[ L_{policy} = - \delta_t \ln \pi(a_t|s_t) \]
- MTP Value Consistency Loss (The "Saved Value" Update): Force the MTP head at step \(t\) to accurately predict the value at \(t+1\). \[ L_{value} = (v^{(1)}_t - \text{stop\_grad}(Y_t))^2 \]
- MTP Token Loss (Auxiliary): Keep the standard MTP token prediction to ensure the representations remain grounded in language/reasoning. \[ L_{token} = \text{CrossEntropy}(\text{MTP\_Heads}) \]
\[ L_{total} = L_{policy} + \alpha L_{value} + \beta L_{token} \]
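A minimal sketch of the combined update for a single step \(t\), with placeholder scalars standing in for \(r_{t+1}\), \(v^{(1)}_t\), \(v^{(1)}_{t+1}\), \(\ln \pi(a_t|s_t)\), and the MTP cross-entropy:

```python
import torch

gamma, alpha_v, beta_tok = 0.99, 0.5, 0.1

v1_t      = torch.tensor(0.3, requires_grad=True)   # value predicted by MTP head 1 at step t
v1_tplus1 = torch.tensor(0.5)                       # value predicted at step t+1 (bootstrap)
log_pi    = torch.tensor(-1.2, requires_grad=True)  # log pi(a_t | s_t)
r         = 1.0                                     # observed reward r_{t+1}
token_ce  = torch.tensor(2.1, requires_grad=True)   # standard MTP cross-entropy loss

Y_t   = r + gamma * v1_tplus1                       # TD target from the next step
delta = (Y_t - v1_t).detach()                       # advantage = target - prediction

L_policy = -delta * log_pi
L_value  = (v1_t - Y_t.detach()) ** 2               # stop-grad on the target
L_token  = token_ce
L_total  = L_policy + alpha_v * L_value + beta_tok * L_token
L_total.backward()
```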
13.4. Why this meets the criteria
- Single Network: The Policy and Value are fused. The "Value" is just a tiny scalar output on the existing MTP heads.
- No GRPO: We don't need multiple samples to find a baseline. We use the Bellman Consistency (\(V_t \approx r + V_{t+1}\)) as the training signal.
- MTP Integration: The MTP heads are essential. They provide the "Lookahead" capability that stabilizes the single-network value estimation (reducing the noise of a single step).
- Internal Value Saved: The training relies on carrying the scalar \(v\) from step \(t+1\) backward to step \(t\).
13.5. Visualization of Data Flow
Step T:
Input -> [Backbone] -> h_t
|-> Main Head -> Action a_t (Sampled)
|-> MTP Head -> Predicts V_t (Saved)
--- Environment Step (r_{t+1}) --->
Step T+1:
Input -> [Backbone] -> h_{t+1}
|-> MTP Head -> Predicts V_{t+1} (Used as Target)
Update T:
Target = r_{t+1} + gamma * V_{t+1}
Error = Target - V_t
Backprop Error through h_t