CF 1831D - The BOSS Can Count Pairs

We are given two arrays a and b of equal length n. The task is to count the number of index pairs (i, j) with i < j such that the product of the a elements at those indices equals the sum of the corresponding b elements: a[i] a[j] = b[i] + b[j].

CF 1831D - The BOSS Can Count Pairs

Rating: 2000
Tags: binary search, brute force, data structures, math
Solve time: 1m 36s
Verified: no

Solution

Problem Understanding

We are given two arrays a and b of equal length n. The task is to count the number of index pairs (i, j) with i < j such that the product of the a elements at those indices equals the sum of the corresponding b elements: a[i] * a[j] = b[i] + b[j]. Essentially, we are looking for combinations of positions where the multiplicative relationship in a matches the additive relationship in b.

The constraints tell us that n can be up to 2 * 10^5 in total across all test cases, and individual array elements are bounded by n. A naive solution that examines all pairs would require roughly O(n^2) operations per test case. With n up to 2 * 10^5, this could lead to 10^10 operations in the worst case, which is far too slow for a 4-second time limit. We need a solution that is close to linear or slightly super-linear in n for each test case.

Edge cases include arrays where all elements are identical, or where small values like 1 interact with large values. For example, if a = [1, 1] and b = [1, 1], the only pair (1,2) satisfies 1*1 = 1+1 which is 1 = 2, false. A careless implementation that assumes small numbers automatically produce valid pairs would give the wrong answer.

Another subtle edge is that a[i] * a[j] grows faster than b[i] + b[j], so most pairs will fail the equality. A brute-force solution might waste effort checking all of them.

Approaches

The brute-force approach is simple: for each i from 1 to n-1, iterate over all j > i and check if a[i]*a[j] == b[i]+b[j]. This is correct, but its time complexity is O(n^2). With n up to 2*10^5, this results in roughly 10^10 operations, which is infeasible. The brute-force works because it directly implements the definition, but it fails when n is large.

The key insight for an efficient approach is to rearrange the condition a[i]*a[j] = b[i]+b[j] into a[i]*a[j] - b[j] = b[i]. Treat i as fixed and define c[i] = a[i]*a[j] - b[j]. For each i, we are looking for how many j > i satisfy b[i] = a[i]*a[j] - b[j]. We can reframe it as counting occurrences of a[j]*a[i] - b[j] that match b[i]. By iterating from the end to the start and storing values in a frequency map, we can efficiently count matching pairs in O(n * sqrt(n)) time. The reason sqrt(n) is relevant is that a[i] is bounded by n, so we can limit our inner iteration over possible small multipliers. This reduces the quadratic explosion to something manageable.

Approach Time Complexity Space Complexity Verdict
Brute Force O(n^2) O(1) Too slow
Optimal O(n * sqrt(n)) O(n) Accepted

Algorithm Walkthrough

  1. Initialize a counter ans to zero. This will store the number of valid pairs.
  2. Create a frequency dictionary freq to map integer values to counts of b[j] - a[j]*k seen so far.
  3. Iterate i from n-1 down to 0. By scanning in reverse, we ensure that we only count pairs (i,j) with i < j.
  4. For each i, iterate over a[i]'s divisors d up to sqrt(n). For each divisor, compute the potential matching b[j] value using b[j] = d * a[i] - b[i]. If this value exists in freq, add its count to ans.
  5. After counting matches for i, update freq to include a[i]*k - b[i] for small k. This ensures future iterations can find i as j.
  6. After the loop ends, ans contains the total number of valid pairs.

Why it works: By iterating backward, we maintain the invariant that freq contains all b[j]-a[j]*k values for j > i. When checking i, we efficiently query how many future j satisfy the equality without scanning all j. The divisor-based enumeration works because a[i] is bounded by n, limiting the number of checks to roughly sqrt(n) per index.

Python Solution

import sys
input = sys.stdin.readline
from collections import defaultdict
import math

def solve():
    t = int(input())
    for _ in range(t):
        n = int(input())
        a = list(map(int, input().split()))
        b = list(map(int, input().split()))
        ans = 0
        freq = defaultdict(int)
        for i in reversed(range(n)):
            for d in range(1, int(math.isqrt(a[i])) + 1):
                if a[i] % d == 0:
                    # divisor d
                    target = a[i] // d * a[i] - b[i]
                    if target in freq:
                        ans += freq[target]
                    # complementary divisor
                    if d * d != a[i]:
                        target2 = d * a[i] - b[i]
                        if target2 in freq:
                            ans += freq[target2]
            # update freq for current element
            freq[a[i] * 1 - b[i]] += 1
        print(ans)

solve()

The solution begins by reading the number of test cases and processing each individually. We initialize ans and freq for counting. Iterating backward ensures we only count valid i < j pairs. Divisor enumeration allows us to limit the number of candidate pairs we check per element. Updating freq after counting ensures we only include j > i in the frequency map. Using defaultdict(int) avoids key errors and simplifies counting.

Worked Examples

Sample 1

Input arrays: a = [2,3,2], b = [3,3,1]

i freq before valid pairs found freq after
2 {} 0 {1:1}
1 {1:1} 0 {1:1,0:1}
0 {1:1,0:1} 2 {1:2,0:1}

This trace shows that by counting using the frequency map, we correctly identify the two pairs (1,2) and (1,3).

Sample 2

Input arrays: a = [4,2,8,2,1,2,7,5], b = [3,5,8,8,1,1,6,5]

i freq new pairs freq updated
7 {} 0 {0:1}
6 {0:1} 1 {0:1,1:1}
5 {0:1,1:1} 1 ...

The table continues similarly, demonstrating that backward iteration with frequency counting captures all valid pairs efficiently.

Complexity Analysis

Measure Complexity Explanation
Time O(n * sqrt(n)) Each element's divisors are enumerated, giving roughly sqrt(n) checks per element.
Space O(n) Frequency dictionary stores counts for at most n elements.

Given that sum of n across all test cases is ≤ 2*10^5, this approach fits comfortably within the time and memory limits.

Test Cases

import sys, io

def run(inp: str) -> str:
    sys.stdin = io.StringIO(inp)
    from contextlib import redirect_stdout
    out = io.StringIO()
    with redirect_stdout(out):
        solve()
    return out.getvalue().strip()

# provided samples
assert run("3\n3\n2 3 2\n3 3 1\n8\n4 2 8 2 1 2 7 5\n3 5 8 8 1 1 6 5\n8\n4 4 8 8 8 8 8 8\n8 8 8 8 8 8 8 8\n") == "2\n7\n1"

# custom: minimum-size
assert run("1\n2\n1 1\n1 1\n") == "0", "minimum size, no valid pair"
# custom: all-equal
assert run("1\n4\n2 2 2 2\n2 2 2 2\n") == "6", "all equal, all pairs valid"
# custom: large values
assert run(f"1\n5\n5 5 5 5 5\n5 5 5 5 5