Sanity Check #2 - Beyond Autoregressive Generation

Training Large Language Models to Reason in a Continuous Latent Space

Problem Statement

In the standard Chain-of-Thought (CoT) setup, we prompt an LLM to generate a series of intermediate reasoning steps in natural language before it gives the final answer. In mathematical terms, given a question Q (an input sequence of tokens), the model doesn't just output the answer A. Instead, it generates a "chain" of thought tokens in between:

Q → [Step 1] → [Step 2] → ⋯ → [Step N] → A

Each "[Step]" is a sentence or a phrase in natural language, like "First, I need to calculate the cost per item...". Well, this is a bit problematic.

First, it is inefficient: many words in a reasoning chain are only there to keep the text fluent and coherent, meaning they don't contribute to the actual logic. For example, in the sentence "Okay, so the next step is to multiply 3 by 60," the words "Okay, so the next step is to" are mostly filler. Second, and maybe most important, it is hard to backtrack. Since CoT generates text one token after another (an autoregressive process), it's difficult for the model to explore a different path (the equivalent of a human "changing its mind") once it has started down one. It prematurely commits to a single, deterministic path.

The paper positions COCONUT within two main areas of research:

  1. Chain-of-Thought (CoT) Reasoning is a broad category that includes any method where a model generates intermediate language steps before an answer. Researchers have explored making these chains more concise or using them to build tree-search algorithms to improve planning. Theoretical work has shown that CoT is effective because it essentially increases the "depth" of the model's computation by looping outputs back as inputs, which inspired the COCONUT design.

  2. Latent Reasoning in LLMs is the second branch of prior work, centred around adding special learnable tokens to the input to give the model extra computational capacity to "think" before answering. Other approaches trained models on full language-based reasoning chains and then gradually removed words from the beginning of the chain, forcing the model to "internalise" those reasoning steps into its latent representations. COCONUT adopts a similar multi-stage training strategy.

Method

Overall, the idea behind COCONUT is simple and easy to characterise. Instead of feeding the model a sequence of word embeddings at prediction step t, E_t = [e(x_1), …, e(x_t)], where e is the embedding function for each token, the method takes the last hidden state h_{t-1} from the previous step of the autoregressive loop and uses it directly as the input for the current step t. Thus, if the latent mode runs from position i to j, the input sequence of embeddings proposed by the paper for a step t inside this window is E_t = [e(x_1), …, e(x_i), h_i, h_{i+1}, …, h_{t-1}].

Note that the input embeddings after the predefined token <bot> are not word embeddings, but rather the model's internal hidden states. This is what the paper refers to as a "purely internal reasoning process".

Note that the method has two predefined tokens: <bot>, which starts internal reasoning, and <eot>, which signals the model to stop thinking. Deciding when to stop thinking is a challenge of its own, since the <bot> token is inserted immediately after the prompt tokens. The authors propose either training a binary classifier on latent embeddings so the model can autonomously decide when to terminate the reasoning, or always padding the latent thoughts to a constant length. The results later show that both approaches work well.
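As a rough illustration (not the paper's exact implementation), the stop-decision classifier could be as simple as a linear probe on the current latent thought; the hidden size and threshold below are assumptions:

```python
import torch
import torch.nn as nn

class StopThinkingProbe(nn.Module):
    """Binary probe on a latent thought h_t: should the model emit <eot> now?"""
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, 1)

    def forward(self, h_t: torch.Tensor) -> torch.Tensor:
        # Probability of terminating the latent reasoning at this step.
        return torch.sigmoid(self.linear(h_t))

# Hypothetical usage: stop once the probability crosses a chosen threshold.
probe = StopThinkingProbe(hidden_dim=768)   # 768 is an assumption (GPT-2 size)
h_t = torch.randn(1, 768)                   # a latent thought from the model
stop = probe(h_t).item() > 0.5              # the 0.5 threshold is an assumption
```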

Note that it is always possible to use a hidden state as the input for the next autoregressive iteration, because the last hidden state of the transformer has the same dimension as the word embeddings.
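Here is a minimal sketch of the latent-mode loop, assuming a Hugging Face-style causal LM (GPT-2 purely for illustration) and the fixed-length padding variant; the question text and the number of thoughts are made up:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

question = "Janet has 3 boxes with 60 pencils each. How many pencils does she have?"
input_ids = tokenizer(question, return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)   # e(x_1), ..., e(x_i)

num_thoughts = 4  # fixed-length latent reasoning; the value is an assumption
for _ in range(num_thoughts):
    out = model(inputs_embeds=inputs_embeds, output_hidden_states=True)
    h_last = out.hidden_states[-1][:, -1:, :]              # last hidden state h_{t-1}
    # Feed the hidden state back in directly as the next "embedding" (a continuous thought).
    inputs_embeds = torch.cat([inputs_embeds, h_last], dim=1)

# After the latent thoughts, one would append e(<eot>) and decode the answer as usual.
```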

Experimental Tasks for Evaluation

The authors tested COCONUT on two main categories of reasoning tasks to see how it performs in different scenarios. For math reasoning, they used GSM8k, a dataset of grade-school math word problems that require between 2 and 7 reasoning steps to solve. For logical reasoning, they used datasets that require the model to apply a set of given rules to determine whether a conclusion is true, which involves planning and choosing between different reasoning paths. For this sanity check, we will be using the GSM8k dataset.
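For reference, GSM8k is available on the Hugging Face hub; a minimal loading snippet (assuming the datasets library is installed) looks like this:

```python
from datasets import load_dataset

# Grade-school math word problems with step-by-step reference solutions.
gsm8k = load_dataset("gsm8k", "main")

example = gsm8k["train"][0]
print(example["question"])
print(example["answer"])   # reasoning chain; the final line starts with "#### <answer>"
```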

The authors compared the method against several baselines to see its strengths and weaknesses:

  • CoT (Chain-of-Thought): the standard method, where the model is trained with complete reasoning chains and then generates its own reasoning steps before giving an answer at inference time.

  • No-CoT: the simplest approach. The model is trained to directly predict the answer from the question, with no intermediate reasoning steps at all.

  • iCoT (Implicit Chain-of-Thought): a more advanced baseline where the model is trained with a special schedule that gradually removes reasoning steps from the training data, forcing it to "internalize" the thinking process. At inference time, it gives the answer directly.

The model is trained with a multi-stage curriculum, a step-by-step procedure the authors designed to gradually teach the model to reason in the latent space instead of in natural language (the paper also explores variations, such as COCONUT trained without this curriculum). It works in stages:

  • Stage 0: The model is trained on standard CoT data, but with an empty <bot><eot> pair inserted after the question. The model learns to predict the full language-based reasoning chain.

  • Stage 1: The first reasoning step of the language chain is removed. In its place, the model is trained to generate one or more "continuous thoughts" (the paper uses a hyperparameter c for this). The model then has to predict the rest of the language chain.

  • Stage k: This process continues; at each new stage k, the first k language-based reasoning steps are replaced by k × c continuous thoughts. The model's task is always to predict the remaining part of the original language chain.

  • Final stage: Eventually, the entire language-based reasoning chain is replaced by continuous thoughts, and the model learns to generate the final answer directly after its internal latent reasoning process.
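Here is a hedged sketch of how such a stage-k training example could be assembled from a standard CoT example; the token strings and the helper itself are illustrative assumptions, and the real implementation also handles loss masking over the latent positions:

```python
def build_stage_example(question, cot_steps, answer, stage_k, c=1):
    """Replace the first `stage_k` language steps with stage_k * c continuous-thought slots.

    The <thought> slots are placeholders: during training their input embeddings are the
    model's own hidden states from the previous positions, not word embeddings.
    """
    latent_slots = ["<thought>"] * (stage_k * c)   # c continuous thoughts per removed step
    remaining_steps = cot_steps[stage_k:]          # the model still predicts these in words
    return [question, "<bot>"] + latent_slots + ["<eot>"] + remaining_steps + [answer]

# Stage 0 keeps the full language chain; the final stage removes it entirely:
# build_stage_example(q, steps, a, stage_k=0)          -> q <bot> <eot> step_1 ... step_N answer
# build_stage_example(q, steps, a, stage_k=len(steps)) -> q <bot> <thought>*(N*c) <eot> answer
```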

Results

On the GSM8k math reasoning task, COCONUT's performance steadily improved as the number of continuous thoughts per reasoning step increased from 0 to 2. It outperformed other methods trained with similar strategies, including iCoT.

The experiments show that the multi-stage training curriculum is essential. The "COCONUT without curriculum" variant, which was trained without this guided process, performed no better than the simple No-CoT baseline. This indicates that models need to be taught how to use the latent space for reasoning in a structured way.

The most significant result concerns the distribution over potential reasoning paths. The paper reports that a single continuous thought can encode a distribution over multiple possible next steps. This allows the model to perform a more advanced reasoning process, similar to a breadth-first search (BFS), where it explores multiple paths in parallel before committing to one. Interesting indeed 🙂

Limitations and Future Directions

The training process requires multiple sequential forward passes, which is a challenge for parallelism and efficiency.

The authors observed that adding too many continuous thoughts at once (e.g., increasing c to 3) could lead to training instability. Thus, the training procedure needs to be further studied.

Developing strategies to learn effective latent reasoning without supervision from existing CoT data is a key challenge for future work. But an interesting one.

Finally, it seems a promising future direction to pretrain LLMs with continuous thoughts from the start 🙂 

Sincerely,

Keep learning,

MO