Code BreakDown - MCTS in LLM

Re-implementation of rStar MCTS

rStar

10.20

rStar is a fairly simple method, yet it achieves great performance on small models. The basic idea is to introduce various types of queries (actions) during MCTS: decomposing the question, answering sub-questions, or rephrasing sub-questions. This mimics cognitive reasoning and adds exploration to the tree search.

Its codebase is interesting to read. The key elements live in a single Python file, run_src/MCTS_for_reasoning.py, which defines three main components:

  • Generator

  • Reasoning_MCTS_Node

  • search_for_answers

In short, the codebase defines a complex language node data structure, Reasoning_MCTS_Node, and sets up an MCTS searcher that navigates the tree by creating child nodes inside Reasoning_MCTS_Node, each of which inherits all necessary information from its parent.

I found it interesting because I have been participating in another open-source project, OpenR (big shoutout to the team!), where we also support a vanilla version of MCTS reasoning. The main implementation difference is that rStar passes the LLM call down through nested nodes, whereas OpenR follows the conventional RL framework and makes LLM calls in a centralized env entity. I feel the latter is more user-friendly, and I would like to port rStar to our in-development codebase.
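To make the contrast concrete, here is a minimal sketch of the two styles. Every name in it (NestedNode, PlainNode, ReasoningEnv, fake_llm) is hypothetical and chosen for illustration only; neither rStar nor OpenR necessarily looks like this.

```python
from typing import Callable, List, Optional

LLM = Callable[[str], List[str]]


def fake_llm(prompt: str) -> List[str]:
    """Stand-in for a real model call (assumption: returns candidate texts)."""
    return [f"step for: {prompt}"]


class NestedNode:
    """Nested style: each node carries the LLM handle and expands itself."""

    def __init__(self, text: str, llm: Optional[LLM] = None,
                 parent: Optional["NestedNode"] = None):
        self.text = text
        self.parent = parent
        # The handle is copied from parent to child on construction.
        self.llm = parent.llm if parent is not None else llm
        self.children: List["NestedNode"] = []

    def expand(self) -> List["NestedNode"]:
        self.children = [NestedNode(t, parent=self) for t in self.llm(self.text)]
        return self.children


class PlainNode:
    """Centralized style: a node is plain data and knows nothing about the LLM."""

    def __init__(self, text: str, parent: Optional["PlainNode"] = None):
        self.text = text
        self.parent = parent
        self.children: List["PlainNode"] = []


class ReasoningEnv:
    """Centralized style: the env owns the LLM and performs every expansion."""

    def __init__(self, llm: LLM):
        self.llm = llm

    def expand(self, node: PlainNode) -> List[PlainNode]:
        node.children = [PlainNode(t, parent=node) for t in self.llm(node.text)]
        return node.children
```

In the centralized version, swapping the model or batching calls only touches ReasoningEnv, which is roughly why I find that style easier to work with.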

The language node class Reasoning_MCTS_Node contains basic attributes such as parent, depth, and node type, as well as key information for building children, such as the generator function and the question status. A node first inherits almost everything from its parent and then defines detailed rules for action generation. Simple as that.
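A rough sketch of that shape, assuming a dataclass layout. SketchNode, NodeType, and the USER_QUESTION root type are hypothetical names for illustration; the real Reasoning_MCTS_Node has more fields and different details.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, List, Optional


class NodeType(Enum):
    USER_QUESTION = "user_question"  # hypothetical root type for this sketch
    SUBQUESTION = "subquestion"
    DIRECT_ANSWER = "direct_answer"
    OST = "ost"


@dataclass
class SketchNode:
    node_type: NodeType
    text: str
    parent: Optional["SketchNode"] = None
    # Key information for building children; only the root sets these.
    generator: Optional[Callable[[str], List[str]]] = None
    question: str = ""
    depth: int = 0
    children: List["SketchNode"] = field(default_factory=list)

    def __post_init__(self) -> None:
        # A non-root node inherits the shared state from its parent.
        if self.parent is not None:
            self.generator = self.parent.generator
            self.question = self.parent.question
            self.depth = self.parent.depth + 1
```

Only the root needs a generator and a question; every descendant picks them up on construction.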

The essence of the project is in its _create_children function. There are five ways of generating actions, written as def do_action_generate_xxx. Each one queries the generator to build a specific prompt and generate responses, which become child nodes. All child nodes are added to the current language node and wait for selection and value updates.
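The pattern looks roughly like the sketch below, which shows only two illustrative actions instead of five. TinyNode, fake_generator, the prompt strings, and the rule for which actions apply are all assumptions of mine, not code from the rStar repo.

```python
from typing import Callable, List, Optional


def fake_generator(prompt: str) -> List[str]:
    """Stand-in for the LLM-backed generator (assumption: returns texts)."""
    return [f"reply to <{prompt}>"]


class TinyNode:
    def __init__(self, kind: str, text: str,
                 parent: Optional["TinyNode"] = None,
                 generator: Optional[Callable[[str], List[str]]] = None):
        self.kind = kind
        self.text = text
        self.parent = parent
        # Children reuse the parent's generator.
        self.generator = parent.generator if parent is not None else generator
        self.children: List["TinyNode"] = []

    def create_children(self) -> List["TinyNode"]:
        # Each action builds its own prompt, queries the generator,
        # and attaches one child per response.
        def do_action_generate_direct_answer() -> None:
            prompt = f"Answer directly: {self.text}"
            for reply in self.generator(prompt):
                self.children.append(TinyNode("direct_answer", reply, parent=self))

        def do_action_generate_subquestion() -> None:
            prompt = f"Ask one sub-question for: {self.text}"
            for reply in self.generator(prompt):
                self.children.append(TinyNode("subquestion", reply, parent=self))

        # Illustrative rule: a direct answer is terminal, so it gets no children.
        if self.kind == "direct_answer":
            return self.children
        do_action_generate_direct_answer()
        do_action_generate_subquestion()
        return self.children
```

The MCTS loop then only has to pick among node.children and back up values; it never needs to know which action produced a child.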

10.25

OK, I have now almost finished the re-implementation of rStar in OpenR. At least I made it executable :). The motivation is that I hope to scale it up by borrowing OpenR's use of the Ray package. I have also noticed what looks like a tiny bug in the original rStar repo (I might be wrong): is_valid_solution_node seems to accept the DIRECT ANSWER, SUBQUESTION, and OST node types, but answer extraction throws OST nodes away. My current task is to run experiments and demonstrate that this week’s work was not wasted! Fingers crossed!
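Here is a toy reproduction of the mismatch as I understand it. The function names, the dict-based nodes, and the two type sets are my own simplification and assumption, not the repo's code, so treat it as a picture of the suspicion and not as proof.

```python
from typing import Dict, List

# Types the validity check appears to accept (my reading of the repo).
VALID_SOLUTION_TYPES = {"DIRECT_ANSWER", "SUBQUESTION", "OST"}
# Types the extraction step appears to keep (assumption: OST is missing).
EXTRACTED_TYPES = {"DIRECT_ANSWER", "SUBQUESTION"}


def toy_is_valid_solution(node: Dict[str, str]) -> bool:
    return node["type"] in VALID_SOLUTION_TYPES


def toy_extract_answers(nodes: List[Dict[str, str]]) -> List[str]:
    # Only nodes that pass both filters contribute an answer.
    return [n["text"] for n in nodes
            if toy_is_valid_solution(n) and n["type"] in EXTRACTED_TYPES]


def toy_dropped_solutions(nodes: List[Dict[str, str]]) -> List[str]:
    # Valid solutions that never reach the answer list.
    return [n["text"] for n in nodes
            if toy_is_valid_solution(n) and n["type"] not in EXTRACTED_TYPES]


nodes = [
    {"type": "DIRECT_ANSWER", "text": "5"},
    {"type": "SUBQUESTION", "text": "five"},
    {"type": "OST", "text": "so the answer is 5"},
]
```

With these three nodes, toy_dropped_solutions(nodes) returns only the OST entry: it counts as a valid solution but contributes no answer.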