Simple Mahjong Game Simulation

Today, I played Mahjong. I got crushed. The other players were sharks, and I lost every game. The version I played was an "Americanized" one, following rules explained by Alpaca Alix.

If you’ve never played Mahjong, I recommend trying it. It’s a great mix of strategy and luck. That said, the Chinese characters and tile designs made it hard for me to build a mental model at first.

Game Mechanics

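Modelling this game is simple. There are four players and 144 tiles: four copies each of Bamboo 1 to 9, Dot 1 to 9, Character 1 to 9, the four Winds, and the three Dragons (136 tiles), plus four Season tiles and four Flower tiles.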
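
As a rough sketch, the full tile set could be generated like this. The string labels (`'bamboo 1'`, `'wind 1'`, and so on) and the `build_tiles` helper are assumptions made for illustration, not a fixed convention.

```python
SUITS = ("bamboo", "dot", "char")      # numbered suits, ranks 1 to 9
HONORS = {"wind": 4, "dragon": 3}      # honor kinds and how many faces each has
BONUS = {"season": 4, "flower": 4}     # bonus tiles, a single copy of each

def build_tiles():
    """Return the full tile set as a list of 'kind value' strings."""
    faces = [f"{suit} {rank}" for suit in SUITS for rank in range(1, 10)]
    faces += [f"{kind} {i}" for kind, n in HONORS.items() for i in range(1, n + 1)]
    tiles = faces * 4                  # four copies of each of the 34 faces
    tiles += [f"{kind} {i}" for kind, n in BONUS.items() for i in range(1, n + 1)]
    return tiles

tiles = build_tiles()
print(len(tiles))                      # 144
```

The count works out as 34 faces times 4 copies (136) plus 8 bonus tiles.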

Initially, all tiles are shuffled and placed face down on the board, so they are invisible to all players. Each player then randomly draws 13 tiles to form their hand; a tile in a player's hand is visible only to that player. Dice are then rolled to determine the order of play. In real life (IRL), the player with the highest roll goes first, and play then proceeds counter-clockwise around the table.

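On each player's turn, they draw one tile from the hidden pile and then discard one tile from their hand onto the visible discard pile.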

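The objective of the game is to achieve a winning hand. A winning hand is a 14-tile hand that splits into four sets of three tiles plus one pair of identical tiles, where each set is either a triplet (three identical tiles) or a sequence (three consecutive tiles in the same suit).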

Winning Hand Example 1

Winning Hand Example 2

Winning Hand Example 3

The game ends when a player achieves a winning hand or when the hidden tile pile is exhausted. For the first version of the game, I excluded the Pong and Chow mechanics, which allow players to claim a tile from the visible pile regardless of whose turn it is.

Game State

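Well, the question is: how do we model this game for simulation? Ideally, we want a mathematical framework that tells us the properties and constraints of each state, as well as the transformations between states. The game state consists of the following elements: the four player hands $P_1, P_2, P_3, P_4$, the hidden (face-down) pile $H$, and the visible discard pile $D$, all drawn from the full tile set $T$.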

So, the total state $S$ of the game at any time can be described as a tuple of the following sets:

$$ S = \left( P_1, P_2, P_3, P_4, H, D \right) $$

Every tile is in exactly one of these sets at any time, so the sets are pairwise disjoint and their union is the full tile set $T$:

$$P_1 \cup P_2 \cup P_3 \cup P_4 \cup H \cup D = T$$

And now we can model state changes by turn as a function:

$$ f(S, A) = S' $$

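Where $S$ is the current state, $A$ is the action taken by the player whose turn it is, and $S'$ is the resulting state.

An action consists of drawing a tile and then discarding a tile: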

$$ \text{Player } t \text{ draws tile } x: H \to P_t, \quad H = H \setminus \{ x \}, \quad P_t = P_t \cup \{ x \} $$

$$ \text{Player } t \text{ discards tile } y: P_t \to D, \quad P_t = P_t \setminus \{ y \}, \quad D = D \cup \{ y \} $$
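
To make the state and its two transitions concrete, here is a minimal sketch. The `GameState` class and the `draw` and `discard` helpers are hypothetical names, and tiles are assumed to be integer ids from 0 to 143.

```python
from dataclasses import dataclass, field

@dataclass
class GameState:
    """S = (P_1, ..., P_4, H, D), with tiles identified by integer ids."""
    hands: list                                    # P_1..P_4, one set per player
    hidden: set                                    # H
    discarded: set = field(default_factory=set)    # D

    def all_tiles(self):
        """Union of every pile; should always equal T."""
        return set().union(*self.hands, self.hidden, self.discarded)

def draw(state, t, x):
    """Player t draws tile x: move x from H to P_t."""
    assert x in state.hidden
    state.hidden.remove(x)
    state.hands[t].add(x)

def discard(state, t, y):
    """Player t discards tile y: move y from P_t to D."""
    assert y in state.hands[t]
    state.hands[t].remove(y)
    state.discarded.add(y)

T = set(range(144))
state = GameState(hands=[set() for _ in range(4)], hidden=set(T))
draw(state, 0, 7)
discard(state, 0, 7)
assert state.all_tiles() == T    # the piles still cover T after both moves
```

Checking the union against $T$ after every move is a cheap way to catch bookkeeping bugs in a simulation.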

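Sweet! Now we need to write code to identify a winning hand so we can terminate the game. We can do this by backtracking. Basically, for every tile that appears at least twice in the hand, we set two copies aside as the pair, then recursively try to split the remaining tiles into triplets and sequences, keeping the best-scoring split we find.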

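Now, with all of this, we can write a basic game loop: shuffle the tiles, deal 13 to each player, then let players take turns drawing and discarding until someone has a winning hand or the hidden pile runs out.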
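
A minimal sketch of such a loop is below. The `play_game` function is illustrative only: `is_winning` and `choose_discard` are hypothetical callables you would supply, and a random starting player stands in for the dice roll.

```python
import random

def play_game(tiles, is_winning, choose_discard, num_players=4, hand_size=13, seed=None):
    """Run one simplified game; return the winner's index, or None if nobody wins."""
    rng = random.Random(seed)
    hidden = list(tiles)
    rng.shuffle(hidden)                        # face-down pile H
    hands = [[hidden.pop() for _ in range(hand_size)] for _ in range(num_players)]
    discarded = []                             # visible pile D
    first = rng.randrange(num_players)         # stand-in for the dice roll
    turn = 0
    while hidden:                              # stop when the hidden pile runs out
        player = (first + turn) % num_players
        hand = hands[player]
        hand.append(hidden.pop())              # draw
        if is_winning(hand):
            return player
        tile = choose_discard(hand, discarded)
        hand.remove(tile)                      # discard
        discarded.append(tile)
        turn += 1
    return None

# Example: random discards and a win check that never fires, so nobody wins.
result = play_game(
    list(range(144)),
    is_winning=lambda hand: False,
    choose_discard=lambda hand, discarded: random.choice(hand),
    seed=0,
)
print(result)                                  # None
```

Passing the win check and the discard policy in as functions keeps the loop itself independent of how hands are scored or how the agent decides.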

Simple AI Agent

This is all good, but how do we determine which tile to discard? If we look at the decision tree of a player with respect to the game state, it looks like this:

Mahjong Decision Tree

Note that the branches explode as depth increases. The decision space is massive, which makes it hard to derive a simple policy just by analyzing this tree. A better approach is to train an AI agent by having it play the game over and over, so that we end up with a model that picks the best tile to discard given the player's observation space (only the tiles visible to that player).

It has been a while since I've written RL code, but quality libraries exist nowadays, so a prototype can be up and running rather quickly. I went with PettingZoo and Tianshou, as their APIs seemed the simplest. Tianshou already has an environment wrapper for PettingZoo, so all I had to write was a custom Simplified Mahjong environment using PettingZoo's AEC (Agent Environment Cycle) API.

In summary, this environment would be:

| Field | Mahjong Environment |
|---|---|
| Actions | Discrete |
| Parallel API | No |
| Manual Control | No |
| Agents | ['player_0', 'player_1', 'player_2', 'player_3'] |
| Agents Count | 4 |
| Action Shape | Discrete(1) |
| Action Values | [0, 1, ..., 143] |
| Observation Shape | Discrete(144) |
| Observation Values | [1, 2, 3] |

Where the index of the observation vector represents a unique tile, and the value at the index represents the player's knowledge of it (where it is).

| Observation Space Index Range | Tile Set |
|---|---|
| 0 - 8 | Bamboo 1 to Bamboo 9 |
| 9 - 17 | Dot 1 to Dot 9 |
| 18 - 26 | Character 1 to Character 9 |
| 27 - 30 | Wind 1 to Wind 4 |
| 31 - 33 | Dragon 1 to Dragon 3 |
| 34 - 42 | Bamboo 1 to Bamboo 9 |
| 43 - 51 | Dot 1 to Dot 9 |
| 52 - 60 | Character 1 to Character 9 |
| 61 - 64 | Wind 1 to Wind 4 |
| 65 - 67 | Dragon 1 to Dragon 3 |
| 68 - 76 | Bamboo 1 to Bamboo 9 |
| 77 - 85 | Dot 1 to Dot 9 |
| 86 - 94 | Character 1 to Character 9 |
| 95 - 98 | Wind 1 to Wind 4 |
| 99 - 101 | Dragon 1 to Dragon 3 |
| 102 - 110 | Bamboo 1 to Bamboo 9 |
| 111 - 119 | Dot 1 to Dot 9 |
| 120 - 128 | Character 1 to Character 9 |
| 129 - 132 | Wind 1 to Wind 4 |
| 133 - 135 | Dragon 1 to Dragon 3 |
| 136 - 139 | Season 1 to Season 4 |
| 140 - 143 | Flower 1 to Flower 4 |

| Observation Value | Description |
|---|---|
| 1 | Tile in current player’s hand |
| 2 | Tile in discard pile |
| 3 | Player does not know where the tile is |
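
The index layout and the observation values can be sketched in a few lines. The helper names (`tile_name`, `observe`) and the short labels such as `'char 3'` are assumptions for illustration; only the index ranges and the values 1, 2, 3 come from the tables.

```python
FACES = (
    [f"bamboo {i}" for i in range(1, 10)]
    + [f"dot {i}" for i in range(1, 10)]
    + [f"char {i}" for i in range(1, 10)]
    + [f"wind {i}" for i in range(1, 5)]
    + [f"dragon {i}" for i in range(1, 4)]
)                                              # 34 faces, in table order
BONUS = [f"season {i}" for i in range(1, 5)] + [f"flower {i}" for i in range(1, 5)]

IN_HAND, IN_DISCARD, UNKNOWN = 1, 2, 3

def tile_name(index):
    """Map an observation index (0 to 143) to a tile label."""
    if not 0 <= index < 144:
        raise ValueError(f"tile index out of range: {index}")
    if index >= 136:
        return BONUS[index - 136]              # seasons, then flowers
    return FACES[index % 34]                   # four consecutive blocks of 34 faces

def observe(hand_ids, discarded_ids):
    """Build the 144-long observation vector for one player."""
    obs = [UNKNOWN] * 144
    for i in discarded_ids:
        obs[i] = IN_DISCARD
    for i in hand_ids:
        obs[i] = IN_HAND
    return obs

print(tile_name(0), "|", tile_name(34), "|", tile_name(143))
```

Indices 0 and 34 map to the same face because the first 136 slots are four repeated blocks, one per copy of each tile.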

Since the game itself does not reveal the current score (e.g., we do not know how well each player is doing until the game terminates), it means we really have to pay attention to the reward function. I ended up writing a simple reward function, which takes a hand and returns a score based on how close it is to a winning hand. Now, on each turn, I mark the reward as reward = score(new_hand) - score(old_hand). This way, we can represent each turn with a reward.
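
Here is a small sketch of that per-turn reward. The `toy_score` function is a deliberately crude stand-in (it just counts tiles that sit in a pair or better) and is not the scoring function described above; `shaped_reward` is likewise an illustrative name.

```python
from collections import Counter

def shaped_reward(score, old_hand, new_hand):
    """Reward for one turn: change in hand score from before to after the move."""
    return score(new_hand) - score(old_hand)

def toy_score(hand):
    """Stand-in score: how many tiles sit in a pair or better."""
    return sum(c for c in Counter(hand).values() if c >= 2)

hands = [
    ["dot 1", "dot 1", "char 3"],    # score 2
    ["dot 1", "dot 1", "dot 1"],     # score 3
    ["dot 1", "dot 1", "char 7"],    # score 2
]
rewards = [shaped_reward(toy_score, a, b) for a, b in zip(hands, hands[1:])]
print(rewards)                        # [1, -1]
```

One property worth noting: when each turn's new hand is the next turn's old hand, the undiscounted rewards telescope, so their sum equals the final score minus the initial score.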

I chose a model-free approach with DQN. DQN is likely not the best choice here, but it is the simplest to experiment with, and the model turned out much better than I expected (given the amount of time I spent on it).
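
Since only tiles in hand can be discarded, action selection has to respect an action mask. The sketch below shows the general idea of masked epsilon-greedy selection in plain Python. It is not Tianshou's API; `select_action` and its parameters are hypothetical, and `q_values` stands for whatever per-action estimates the network outputs.

```python
import random

def select_action(q_values, action_mask, epsilon=0.1, rng=random):
    """Epsilon-greedy choice restricted to legal actions (mask value 1)."""
    legal = [a for a, ok in enumerate(action_mask) if ok]
    if not legal:
        raise ValueError("no legal action available")
    if rng.random() < epsilon:
        return rng.choice(legal)                    # explore among legal discards only
    return max(legal, key=lambda a: q_values[a])    # exploit the best legal discard

# Action 0 has the highest value but is masked out, so action 2 is chosen.
print(select_action([5.0, 1.0, 3.0], [0, 1, 1], epsilon=0.0))   # 2
```

Masking before the argmax matters: without it, the agent can pick a tile it does not hold, and the environment has to reject the move.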

Here is some pseudocode to get you started if you are interested in implementing this or any other variation of Mahjong.

from itertools import combinations
from collections import Counter

def suite_and_value(tile):
    """Splits a tile into its suite and value for efficient reuse."""
    suite, value = tile.split(' ')
    return suite, int(value) if value.isdigit() else 0

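def are_same_tile(a, b):
    """Checks if two tiles are identical (same suite and same face)."""
    # Compare the raw labels: suite_and_value() maps every non-numeric face
    # to 0, so 'dragon red' and 'dragon green' would otherwise look equal.
    return a == b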

def form_triplet(a, b, c):
    """Check if three tiles form a valid triplet (same suite and value)."""
    return are_same_tile(a, b) and are_same_tile(a, c)

def form_sequence(a, b, c):
    """Check if three tiles form a valid sequence (consecutive in the same suite)."""
    va, vb, vc = suite_and_value(a)[1], suite_and_value(b)[1], suite_and_value(c)[1]
    sorted_v = sorted([va, vb, vc])
    is_sequence = (
        sorted_v[0] + 1 == sorted_v[1] 
        and sorted_v[1] + 1 == sorted_v[2] 
        and suite_and_value(a)[0] == suite_and_value(b)[0] == suite_and_value(c)[0]
    )
    return is_sequence

def mk_score(result):
    """Calculates the score based on triplets, sequences, and pairs."""
    return 3 * (
        len(result['triplet_sets']) + 
        len(result['seq_sets'])
    ) + len(result['pair'])

def remove_tiles(remaining_tiles, tiles_to_remove):
    """Helper function to remove a set of tiles from the remaining list."""
    new_remaining = remaining_tiles.copy()
    for tile in tiles_to_remove:
        new_remaining.remove(tile)
    return new_remaining

def backtrack(hand, remaining_tiles, triplet_sets, seq_sets, pair_set):
    """Recursive function to try and form sets from remaining tiles."""
    if len(remaining_tiles) == 0:
        return {
            'triplet_sets': triplet_sets, 
            'seq_sets': seq_sets, 
            'pair': pair_set
        }
    
    best_result = {
        'triplet_sets': triplet_sets, 
        'seq_sets': seq_sets, 
        'pair': pair_set
    }

    for combo in combinations(remaining_tiles, 3):
        if form_triplet(*combo):
            new_remaining = remove_tiles(remaining_tiles, combo)
            result = backtrack(
                hand, 
                new_remaining, 
                triplet_sets + [tuple(combo)], 
                seq_sets, 
                pair_set
            )
            if result and mk_score(result) > mk_score(best_result):
                best_result = result

        elif form_sequence(*combo):
            new_remaining = remove_tiles(remaining_tiles, combo)
            result = backtrack(
                hand, 
                new_remaining, 
                triplet_sets, 
                seq_sets + [tuple(combo)], 
                pair_set
            )
            if result and mk_score(result) > mk_score(best_result):
                best_result = result
    
    return best_result

def compute_hand_analysis(hand):
    """Check if the provided hand is a winning hand in Mahjong."""
    tile_count = Counter(hand)
    best_result = {'triplet_sets': [], 'seq_sets': [], 'pair': []}

    for tile, count in tile_count.items():
        if count >= 2:
            # Set two copies of this tile aside as the pair, then try to
            # split the rest of the hand into triplets and sequences.
            remaining_tiles = hand.copy()
            remaining_tiles.remove(tile)
            remaining_tiles.remove(tile)

            result = backtrack(
                hand,
                remaining_tiles,
                [],
                [],
                [tile, tile]
            )

            if result and mk_score(result) > mk_score(best_result):
                best_result = result

    score = mk_score(best_result)
    return score >= 14, score, best_result

ex_winning_hand = [
    'bamboo 4', 
    'bamboo 5', 
    'bamboo 6',
    # --
    'bamboo 8', 
    'bamboo 8', 
    'bamboo 8',
    # --
    'char 3', 
    'char 4', 
    'char 5',
    # --
    'dot 1', 
    'dot 1', 
    'dot 1',
    # --
    'dragon red', 
    'dragon red'
]

winning, score, result = compute_hand_analysis(ex_winning_hand)
print(f"Winning: {winning}, Score: {score}, Result: {result}")
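
The combination-based search above is easy to follow, but it tries every 3-tile combination at each level, which gets slow when it runs on every turn of a simulation. As an alternative sketch (a different technique, not a drop-in replacement, and it only answers win or no win rather than returning a score), the check below works on tile counts and always resolves the smallest remaining tile first. It assumes tiles are `(suit, rank)` tuples rather than strings; `can_split_into_sets` and `is_winning` are illustrative names.

```python
from collections import Counter

NUMBERED = {"bamboo", "dot", "char"}    # suits that can form sequences

def can_split_into_sets(counts):
    """True if the tile counts can be fully split into triplets and sequences."""
    tile = min((t for t, c in counts.items() if c > 0), default=None)
    if tile is None:
        return True                      # nothing left: every tile was used
    suit, rank = tile
    # Option 1: use the smallest remaining tile in a triplet.
    if counts[tile] >= 3:
        counts[tile] -= 3
        ok = can_split_into_sets(counts)
        counts[tile] += 3
        if ok:
            return True
    # Option 2: use it as the low end of a sequence (numbered suits only).
    if suit in NUMBERED:
        run = [(suit, rank), (suit, rank + 1), (suit, rank + 2)]
        if all(counts[t] > 0 for t in run):
            for t in run:
                counts[t] -= 1
            ok = can_split_into_sets(counts)
            for t in run:
                counts[t] += 1
            return ok
    return False

def is_winning(hand):
    """Check for four sets plus one pair in a 14-tile hand."""
    if len(hand) != 14:
        return False
    counts = Counter(hand)
    for tile in list(counts):
        if counts[tile] >= 2:
            counts[tile] -= 2            # try this tile as the pair
            ok = can_split_into_sets(counts)
            counts[tile] += 2
            if ok:
                return True
    return False

example = (
    [("bamboo", 4), ("bamboo", 5), ("bamboo", 6)]
    + [("bamboo", 8)] * 3
    + [("char", 3), ("char", 4), ("char", 5)]
    + [("dot", 1)] * 3
    + [("dragon", "red")] * 2
)
print(is_winning(example))               # True
```

The trick is that the smallest remaining tile can only be part of a triplet or the low end of a sequence, so each step has at most two branches.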

Here's the pseudocode for the Mahjong environment using PettingZoo. This should give you a straightforward foundation for implementation, depending on the specific use case and variation of the game you'd like to play.

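import functools

import numpy as np
from gymnasium.spaces import Box, Discrete
from pettingzoo import AECEnv


class VerySimpleMahjongEnvironment(AECEnv):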
    """The metadata holds environment constants."""
    NUM_TILES = 144
    NUM_PLAYERS = 4
    NUM_TILES_PER_PLAYER = 13
    HIDDEN_TILE_TOKEN = NUM_TILES + 1
    ALL_TILES = set([i for i in range(NUM_TILES)])

    # for observations
    TILE_IN_PLAYER_PILE = 1
    TILE_IN_PLAYER_PILE_BUT_VISIBLE_TO_OTHERS = 2
    TILE_UNKNOWN = 3
    TILE_IN_BOARD_VISIBLE_TO_OTHERS = 4
    
    metadata = {
        "name": "simple_american_mahjong_v0",
    }

    def __init__(self, render_mode = 'ansi'):
        """The init method takes in environment arguments."""
        self.hidden_tiles = set()
        self.discarded_tiles = set()
        self.player_tiles = [set() for _ in range(self.NUM_PLAYERS)]
        self.possible_agents = [str(i) for i in range(self.NUM_PLAYERS)]
        self.agent_selection = 0
        self.render_mode = render_mode

    def reset(self, seed=None, options=None):
        """Reset sets the environment to a starting point."""
        pass

    def observe(self, agent):
        """
        Observe should return the observation of the specified agent. This function
        should return a sane observation (though not necessarily the most up to date possible)
        at any time after reset() is called.
        """
        return self.observations[agent]

    def update_observation(self):
        for player_i in self.agents:
            self.observations[player_i] = {
                "observation": "",  # ....
                "action_mask": "",  # ....
            }

    def get_action_mask(self, player):
        """Returns the action mask for the given player (only tiles in hand can be discarded)."""
        action_mask = np.zeros(self.NUM_TILES, dtype=np.int8)
        action_mask[list(self.player_tiles[player])] = 1
        return action_mask
    
    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent):
        return Box(low=0, high=4, shape=(self.NUM_TILES,), dtype=np.int32)

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent):
        return Discrete(self.NUM_TILES)

    def step(self, action):
        """Takes in an action for the current agent (specified by agent_selection)."""
        curr_player = self.agent_selection
        if action is None:
            return

        # Step 1: Check for winning hand for the current player
        # Player wins, game ends for all players
        # ....

        # Step 2: If there are no more hidden tiles, 
        # terminate game for all players
        # ....
        
        # Step 3: Validate the action (only allow discarding tiles that are in hand)
        # ....
        
        # Step 4: Discard the tile (remove it from player's hand, add to the board)
        # and draw one more tile for the next round, then update observations and masks
        # ....
        
        # Step 5: Apply reward for an action
        # ....
        
        # Step 6: Move to the next agent
        # ....
    
        self.num_moves += 1
        # Under the AEC API, step() returns nothing; callers read the
        # observation, reward, termination and truncation via env.last().

    def render(self):
        pass
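
`reset()` is left as a stub above. One possible way to handle the dealing step is sketched below in plain Python, with no PettingZoo dependency. The `deal` helper and its parameters are assumptions for illustration; the returned sets mirror the `player_tiles` and `hidden_tiles` attributes of the class.

```python
import random

def deal(num_tiles=144, num_players=4, tiles_per_player=13, seed=None):
    """Shuffle tile ids and deal them; returns (player_tiles, hidden_tiles)."""
    rng = random.Random(seed)
    ids = list(range(num_tiles))
    rng.shuffle(ids)
    # Each player gets a consecutive slice of the shuffled ids.
    player_tiles = [
        set(ids[p * tiles_per_player:(p + 1) * tiles_per_player])
        for p in range(num_players)
    ]
    # Everything that was not dealt stays in the hidden pile.
    hidden_tiles = set(ids[num_players * tiles_per_player:])
    return player_tiles, hidden_tiles

player_tiles, hidden_tiles = deal(seed=42)
print([len(p) for p in player_tiles], len(hidden_tiles))   # [13, 13, 13, 13] 92
```

Using a seeded `random.Random` instance keeps episodes reproducible, which helps when debugging training runs.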