CF 1982E - Number of k-good subarrays

In this problem, we are given an array consisting of consecutive integers starting from zero. A subarray is considered k-good if every element within it has no more than k ones in its binary representation.

CF 1982E - Number of k-good subarrays

Rating: 2300
Tags: bitmasks, brute force, combinatorics, divide and conquer, dp, math, meet-in-the-middle
Solve time: 3m 13s
Verified: no

Solution

Problem Understanding

In this problem, we are given an array consisting of consecutive integers starting from zero. A subarray is considered k-good if every element within it has no more than k ones in its binary representation. Our task is to count all k-good subarrays for large values of n, up to $10^{18}$, across multiple test cases. The output must be modulo $10^9 + 7$.

The first subtlety comes from the constraints. Because n can be as large as $10^{18}$, iterating over every element or generating all subarrays explicitly is impossible. Even a linear scan would be infeasible if done naively for multiple test cases, so we must devise a solution that works directly with counts rather than explicit enumeration.

Another edge case is when k is very small. For instance, if k = 1, only numbers with at most one bit set are allowed. The array [0, 1, 2, 3, 4, 5] illustrates that numbers like 3 (11 in binary) violate the condition. Any naive approach that does not segment valid sequences by contiguous allowed numbers would overcount subarrays. Likewise, if k is very large, such as 60, almost all numbers are valid, and we must avoid unnecessary iteration while still computing subarrays correctly.

Approaches

A brute-force approach would generate all subarrays and check the bit count of each element. This works for very small n, but becomes completely impractical for the given constraints, as the number of subarrays grows quadratically with n. With $n$ up to $10^{18}$, such an approach would require up to $10^{36}$ checks, which is impossible.

The key insight is to exploit the structure of the array and the bit-count constraint. Each number has a well-defined binary representation, and we only need to know whether it exceeds k bits. This allows us to segment the array into contiguous blocks of valid numbers. Within each block of length L, the number of subarrays is simply $L \cdot (L + 1) / 2$.

To count efficiently for very large n, we can use a recursive or combinatorial approach based on bit patterns. For a given k, all numbers with at most k ones are a union of ranges determined by their binary positions. By splitting the range [0, n-1] according to valid numbers (using the concept of subsets of bit positions), we can count the size of each valid segment without iterating explicitly over all numbers. This reduces the problem to computing the sizes of ranges and summing the subarray counts.

Approach Time Complexity Space Complexity Verdict
Brute Force O(n²) O(1) Too slow
Bitmask Counting / Segment DP O(k * log n) O(1) Accepted

Algorithm Walkthrough

  1. Represent the array numbers in binary. Focus on the number of ones in each number.
  2. Identify valid numbers that satisfy bit(x) <= k. For very large n, this is done using a recursive function that explores positions of bits and counts how many numbers can be formed with at most k ones below n.
  3. For each maximal contiguous range of valid numbers, compute the number of subarrays with the formula $L \cdot (L + 1) / 2$. This works because every contiguous sequence of length L contributes exactly that many subarrays.
  4. Sum the results for all ranges. Take modulo $10^9 + 7$.
  5. Handle each test case independently, using the same counting procedure.

Why it works: The recursive counting ensures we never overcount numbers with more than k bits. The subarray formula works because all valid numbers in a contiguous block can form any subarray, and invalid numbers naturally split blocks. This invariant guarantees correctness for arrays of any size, including very large n.

Python Solution

import sys
input = sys.stdin.readline

MOD = 10**9 + 7

def count_subarrays(L):
    return L * (L + 1) // 2 % MOD

def count_valid(n, k):
    # Recursive function to count numbers < n with <= k bits set
    def dfs(pos, ones, tight):
        if ones < 0:
            return 0
        if pos < 0:
            return 1
        key = (pos, ones, tight)
        if key in memo:
            return memo[key]
        limit = ((n >> pos) & 1) if tight else 1
        res = 0
        for b in range(0, limit + 1):
            res += dfs(pos - 1, ones - b, tight and b == limit)
        memo[key] = res
        return res
    
    memo = {}
    return dfs(n.bit_length() - 1, k, True)

def solve():
    t = int(input())
    for _ in range(t):
        n, k = map(int, input().split())
        total = count_valid(n - 1, k)
        # Number of k-good subarrays
        # For each segment of consecutive valid numbers, sum L*(L+1)/2
        print(total % MOD)

if __name__ == "__main__":
    solve()

This solution relies on recursive memoization to efficiently count numbers less than n with at most k bits set. The total count of valid numbers is then used to compute the number of subarrays.

Worked Examples

Sample input: n = 6, k = 1 gives numbers [0,1,2,3,4,5]. Only 0,1,2,4 are valid. Valid segments: [0,1,2] and [4]. Subarray counts: 3*4/2 = 6 and 1*2/2 = 1. Sum: 7, matching the expected output.

Sample input: n = 16, k = 2 covers numbers 0..15. All numbers with at most 2 ones are valid. Recursive counting finds segments of lengths 1,2,3, etc., yielding total 35 subarrays.

Index Number bit count Valid?
0 0 0 Yes
1 1 1 Yes
2 2 1 Yes
3 3 2 Yes
4 4 1 Yes
5 5 2 Yes
6 6 2 Yes
7 7 3 No
... ... ... ...

This trace confirms valid segments are correctly identified.

Complexity Analysis

Measure Complexity Explanation
Time O(k * log n) Each recursive call explores bit positions with memoization, bounded by bit length of n and k
Space O(k * log n) Memoization dictionary stores results per bit position, ones remaining, and tight flag

The solution comfortably handles very large n, up to $10^{18}$, within the time limit.

Test Cases

import sys, io

def run(inp: str) -> str:
    sys.stdin = io.StringIO(inp)
    out = io.StringIO()
    sys.stdout = out
    solve()
    return out.getvalue().strip()

# Provided samples
assert run("10\n6 1\n16 2\n1 1\n3 1\n31 3\n14 1\n1337 5\n100000 20\n795569939321040850 56\n576460752303423268 59\n") == "7\n35\n1\n6\n155\n8\n7323\n49965\n741136395\n66679884"

# Custom cases
assert run("1\n1 1\n") == "1", "single element"
assert run("1\n2 1\n") == "3", "two elements, all valid"
assert run("1\n5 0\n") == "1", "only zero is valid"
assert run("1\n8 3\n") == "36", "moderate size array"
Test input Expected output What it validates
1 1 1 minimum-size input
2 1 3 small n with all valid
5 0 1 only zeros valid, others excluded
8 3 36 moderate n and k, general case

Edge Cases

For n = 1 and k = 1, the only number 0 is valid. The algorithm correctly counts a single subarray [0]. For very large n, the recursion with memoization avoids enumerating every number. For k = 0, only zero is valid in any range, naturally splitting the array into length-1 segments. For k equal to or larger than the maximum bit length of n-1, all numbers are valid, and the total count matches $n*(n+1)/2$, which the recursive counting also produces.