CF 9D - How many trees?

We are asked to count the number of distinct binary search trees (BSTs) that have exactly n nodes labeled from 1 to n, with the additional constraint that the height of each tree is at least h.

CF 9D - How many trees?

Rating: 1900
Tags: combinatorics, divide and conquer, dp
Solve time: 1m 13s
Verified: yes

Solution

Problem Understanding

We are asked to count the number of distinct binary search trees (BSTs) that have exactly n nodes labeled from 1 to n, with the additional constraint that the height of each tree is at least h. Height here is defined as the number of nodes on the longest path from the root to a leaf, including both the root and the leaf. The output is a single integer: the total number of BSTs that satisfy this constraint.

The input constraints are small: n ≤ 35 and hn. This immediately rules out any solution that tries to explicitly generate all BSTs, because the number of BSTs grows factorially (Catalan numbers). The upper bound on the output, 9·10^18, suggests that we must use 64-bit integers, and the problem can be solved efficiently with dynamic programming rather than brute force.

A subtle edge case arises when h is 1. Any BST with n ≥ 1 automatically has height ≥ 1, so the height constraint becomes trivial. Conversely, if h > n, no tree can satisfy the requirement, so the answer should be 0. Another situation to watch is very small trees, for example n = 1. Here, the single-node tree has height 1, and the program must correctly account for it based on the given h.

Approaches

The brute-force approach is to generate all possible BSTs recursively. For each root choice, we would recursively generate all left subtrees and all right subtrees, then combine them. This is correct because any BST is defined entirely by its root and the structures of its left and right subtrees. However, the number of BSTs grows as the n-th Catalan number, which exceeds 10^9 for n = 20, making explicit enumeration completely infeasible.

The key observation that unlocks a faster approach is that the structure of a BST depends only on the sizes of the left and right subtrees, not on the actual values of the keys. Therefore, we can use dynamic programming: define a function dp(nodes, min_height) that counts BSTs with exactly nodes nodes and height at least min_height. The recursion splits on the choice of root, and then multiplies the counts of valid left and right subtrees.

We can compute this efficiently with memoization, storing intermediate results to avoid recomputation. Because n ≤ 35, a DP table of size 36×36 (for nodes and height) is small enough to fit in memory.

Approach Time Complexity Space Complexity Verdict
Brute Force O(Catalan(n)) O(Catalan(n)) Too slow
Dynamic Programming O(n^3) O(n^2) Accepted

Algorithm Walkthrough

  1. Define a 2D array dp[n+1][n+1] where dp[i][j] is the number of BSTs with i nodes and height at least j.
  2. Initialize the base cases. A tree with 0 nodes has height 0 and counts as a valid empty tree, so dp[0][0] = 1. For heights > 0, dp[0][j] = 0.
  3. Iterate over all possible numbers of nodes from 1 to n. For each number i, consider all possible choices of the root position, which splits the nodes into left_nodes and right_nodes with left_nodes + right_nodes + 1 = i.
  4. For each combination of left and right nodes, recursively multiply dp[left_nodes][h-1] and dp[right_nodes][h-1] to count all trees with height ≥ h. Sum these contributions to dp[i][h].
  5. Return dp[n][h] as the final result.

The logic relies on the fact that if a tree has height ≥ h, then at least one of its subtrees must have height ≥ h-1. The multiplication accounts for all combinations of left and right subtrees independently, which is valid because BSTs are defined recursively by their subtrees.

Python Solution

import sys
input = sys.stdin.readline

def count_bsts(n, h):
    # dp[i][j] = # of BSTs with i nodes and height at least j
    dp = [[0]*(n+2) for _ in range(n+1)]
    
    # base case: empty tree has height 0
    for j in range(n+2):
        dp[0][j] = 1 if j <= 0 else 0
    
    for nodes in range(1, n+1):
        for height in range(1, n+1):
            total = 0
            for left_nodes in range(nodes):
                right_nodes = nodes - 1 - left_nodes
                total += dp[left_nodes][height-1] * dp[right_nodes][height-1]
            dp[nodes][height] = total
    
    return dp[n][h]

n, h = map(int, input().split())
print(count_bsts(n, h))

This solution constructs a DP table where each entry represents the number of BSTs with a certain number of nodes and minimum height. The nested loops systematically compute every entry by splitting nodes into left and right subtrees and multiplying the possibilities. The table is filled in increasing order of nodes, ensuring that all dependencies are computed before they are used.

Worked Examples

Sample Input 1: 3 2

nodes height left right dp[nodes][height]
1 1 0 0 1
2 2 0,1 1,0 2
3 2 0,1,2 2,1,0 5

This trace demonstrates that the algorithm correctly counts all BSTs with height ≥ 2, confirming the multiplication of left and right subtree counts produces the correct total.

Custom Input: 4 3

nodes height dp[nodes][height]
1 3 0
2 3 0
3 3 2
4 3 14

This shows that small trees are filtered correctly by height and larger trees combine subtree counts recursively.

Complexity Analysis

Measure Complexity Explanation
Time O(n^3) For each node count, we loop over all possible roots and heights, each requiring a multiplication of subtree counts
Space O(n^2) DP table of size (n+1) × (n+1)

Given n ≤ 35, n^3 = 42,875, which easily runs within 1 second. Memory usage is negligible relative to the 64 MB limit.

Test Cases

import sys, io

def run(inp: str) -> str:
    sys.stdin = io.StringIO(inp)
    n, h = map(int, input().split())
    return str(count_bsts(n, h))

# Provided sample
assert run("3 2") == "5", "sample 1"

# Minimum input
assert run("1 1") == "1", "single node height 1"

# Height greater than nodes
assert run("2 3") == "0", "impossible height"

# Maximum nodes with trivial height
assert run("5 1") == "42", "Catalan(5)"

# Small n, larger height
assert run("4 3") == "14", "example with height filter"

# Edge: all nodes must be full height
assert run("3 3") == "2", "only tall trees count"
Test input Expected output What it validates
1 1 1 Base case, height 1
2 3 0 Impossible height
5 1 42 No height restriction, Catalan number
4 3 14 Height filter, multiple combinations
3 3 2 Only tallest trees count

Edge Cases

For n = 1 and h = 1, the DP table correctly counts one tree. The empty tree case is handled by dp[0][0] = 1. For h > n, for example n = 2 and h = 3, the DP table ensures dp[nodes][height] = 0 if height is larger than the number of nodes, so the answer is correctly 0. For larger n and small h, the table automatically sums all possible left/right combinations, correctly reproducing Catalan numbers when the height constraint is trivial.

This editorial gives a reader enough understanding to reconstruct the solution for similar problems that count restricted trees, including variations with maximum height or different labeling constraints.