CF 1831D - The BOSS Can Count Pairs
We are given two arrays a and b of equal length n. The task is to count the number of index pairs (i, j) with i < j such that the product of the a elements at those indices equals the sum of the corresponding b elements: a[i] * a[j] = b[i] + b[j].
Rating: 2000
Tags: binary search, brute force, data structures, math
Solve time: 1m 36s
Verified: no
Solution
Problem Understanding
We are given two arrays a and b of equal length n. The task is to count the number of index pairs (i, j) with i < j such that the product of the a elements at those indices equals the sum of the corresponding b elements: a[i] * a[j] = b[i] + b[j]. Essentially, we are looking for combinations of positions where the multiplicative relationship in a matches the additive relationship in b.
The constraints tell us that the sum of n over all test cases is at most 2 * 10^5, and every a[i] and b[i] lies between 1 and n. A naive solution that examines all pairs needs O(n^2) operations per test case. With n up to 2 * 10^5, that is on the order of 10^10 operations in the worst case, which is far too slow for a 4-second time limit. We need something around O(n * sqrt(n)) per test case.
Edge cases include arrays where all elements are identical, or where small values like 1 interact with large values. For example, if a = [1, 1] and b = [1, 1], the only pair (1, 2) would need 1 * 1 = 1 + 1, that is 1 = 2, which is false, so the answer is 0. A careless implementation that assumes small numbers automatically produce valid pairs would give the wrong answer.
Another subtle point is that the sum b[i] + b[j] never exceeds 2n, so any valid pair must have a[i] * a[j] ≤ 2n. Pairs in which both a values are large can never satisfy the equality, and a brute-force solution wastes most of its effort checking them.
Approaches
The brute-force approach is simple: for each i, iterate over all j > i and check whether a[i]*a[j] == b[i]+b[j]. It is correct because it implements the definition directly, but it takes O(n^2) time. With n up to 2*10^5 that is on the order of 10^10 checks, which is infeasible.
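The pair-by-pair check described above can be sketched as follows. The function name `count_pairs_brute` is illustrative and not part of the original write-up. It is only practical for small inputs, but it is handy as a reference when checking a faster solution.

```python
def count_pairs_brute(a, b):
    """Count pairs i < j with a[i] * a[j] == b[i] + b[j] by checking every pair."""
    n = len(a)
    count = 0
    for i in range(n - 1):
        for j in range(i + 1, n):
            if a[i] * a[j] == b[i] + b[j]:
                count += 1
    return count


# First sample from the statement: the valid pairs are (1, 2) and (1, 3).
print(count_pairs_brute([2, 3, 2], [3, 3, 1]))  # prints 2
```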
The key insight for an efficient approach is that the right-hand side is small. Since every b value is at most n, we have b[i] + b[j] ≤ 2n, so any valid pair has a[i]*a[j] ≤ 2n, and the smaller of the two a values is at most sqrt(2n). Fix a candidate value x for that smaller factor and rearrange the condition into b[i] = x*a[j] - b[j], where a[i] = x. If we store how often each b value occurs among the elements with a[i] = x in a frequency array, then for every j with a[j] ≥ x a single lookup tells us how many partners i it has. There are about sqrt(2n) candidate values of x and each needs one linear scan, which gives O(n * sqrt(n)) time in total.
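The reason a square root appears can be written out as a short derivation. It assumes only the stated bounds 1 ≤ b_i ≤ n; the symbol x for the smaller of the two a values is introduced here for illustration.

```latex
\begin{align*}
1 \le b_i,\, b_j \le n
  &\implies a_i a_j = b_i + b_j \le 2n, \\
x = \min(a_i, a_j)
  &\implies x^2 \le a_i a_j \le 2n
  \implies x \le \sqrt{2n}, \\
a_i = x
  &\implies b_i = x\, a_j - b_j .
\end{align*}
```

So for each of the at most sqrt(2n) values of x, the required b_i is determined by (a_j, b_j) alone, which is what makes a frequency lookup possible.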
| Approach | Time Complexity | Space Complexity | Verdict |
|---|---|---|---|
| Brute Force | O(n^2) | O(1) | Too slow |
| Optimal | O(n * sqrt(n)) | O(n) | Accepted |
Algorithm Walkthrough
- Initialize a counter ans to zero. It accumulates the number of valid pairs.
- For each candidate value x from 1 up to floor(sqrt(2n)), treat x as the smaller a value of the pair.
- Build a frequency array freq of size n + 1, where freq[v] is the number of indices i with a[i] = x and b[i] = v.
- For every index j with a[j] > x, compute target = x * a[j] - b[j]. If 1 ≤ target ≤ n, add freq[target] to ans.
- For every index j with a[j] = x, compute target = x * x - b[j]. If 1 ≤ target ≤ n, add freq[target] to a separate counter same, subtracting 1 when target = b[j] so that j is not paired with itself. Each such pair is seen from both ends, so add same / 2 to ans.
- After all values of x have been processed, ans contains the total number of valid pairs.
Why it works: the condition a[i]*a[j] = b[i]+b[j] is symmetric in i and j, so counting pairs with i < j is the same as counting unordered pairs. Every valid pair has a product of at most 2n, so its smaller a value x satisfies x*x ≤ 2n and is visited by the outer loop. A pair with two different a values is counted exactly once, when x equals the smaller one and j is the index with the larger one. A pair with two equal a values is counted twice in same and then halved. Indices with a[j] < x are skipped because those pairs were already handled when x was the smaller value.
Python Solution
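import sys
from math import isqrt


def count_pairs(n, a, b):
    """Count pairs i < j with a[i] * a[j] == b[i] + b[j]."""
    ans = 0
    # b[i] + b[j] <= 2n, so the smaller factor x satisfies x * x <= 2n.
    for x in range(1, isqrt(2 * n) + 1):
        # freq[v] = number of indices i with a[i] == x and b[i] == v
        freq = [0] * (n + 1)
        for ai, bi in zip(a, b):
            if ai == x:
                freq[bi] += 1
        same = 0  # ordered pairs (i, j), i != j, with a[i] == a[j] == x
        for aj, bj in zip(a, b):
            if aj < x:
                continue  # already handled when x was the smaller value
            target = x * aj - bj  # the b[i] that a partner must have
            if not 1 <= target <= n:
                continue
            if aj > x:
                ans += freq[target]
            else:
                # freq[target] includes j itself when target == bj
                same += freq[target] - (target == bj)
        ans += same // 2  # each equal-value pair was counted from both ends
    return ans


def solve():
    data = sys.stdin.read().split()
    t, pos = int(data[0]), 1
    out = []
    for _ in range(t):
        n = int(data[pos])
        a = list(map(int, data[pos + 1:pos + 1 + n]))
        b = list(map(int, data[pos + 1 + n:pos + 1 + 2 * n]))
        pos += 1 + 2 * n
        out.append(str(count_pairs(n, a, b)))
    print("\n".join(out))


if __name__ == "__main__":
    solve()
The solution reads all input at once and handles each test case with count_pairs. The outer loop tries every value x that can be the smaller factor of a valid pair. The first inner scan builds freq from the elements with a[i] = x, and the second scan looks up the required b value for every element with a[j] ≥ x. Pairs with two equal a values are accumulated separately in same, with self-pairings removed, and halved because each is seen from both ends. A plain list indexed by b value is used instead of a dictionary because all b values lie in 1..n. Note that solve reads from sys.stdin at call time, so the input stream can be swapped out for testing.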
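Since this write-up is marked as unverified, one way to gain confidence is a randomized comparison against the brute force. The sketch below is self-contained: `brute`, `fast` and `stress` are illustrative names, and `fast` is a compact version of the small-factor enumeration (try every x with x*x ≤ 2n as the smaller a value), which is an assumption about the intended approach rather than a copy of any reference solution.

```python
import random
from math import isqrt


def brute(a, b):
    """Reference answer: check every pair i < j."""
    n = len(a)
    return sum(a[i] * a[j] == b[i] + b[j]
               for i in range(n) for j in range(i + 1, n))


def fast(a, b):
    """Enumerate the smaller factor x (x * x <= 2n) and look up partners."""
    n = len(a)
    ans = 0
    for x in range(1, isqrt(2 * n) + 1):
        freq = [0] * (n + 1)  # freq[v]: indices with a == x and b == v
        for ai, bi in zip(a, b):
            if ai == x:
                freq[bi] += 1
        same = 0
        for aj, bj in zip(a, b):
            target = x * aj - bj
            if aj < x or not 1 <= target <= n:
                continue
            if aj > x:
                ans += freq[target]
            else:
                same += freq[target] - (target == bj)  # drop self-pairing
        ans += same // 2
    return ans


def stress(rounds=300, max_n=8, seed=0):
    """Return a mismatching (a, b) input, or None if all rounds agree."""
    rng = random.Random(seed)
    for _ in range(rounds):
        n = rng.randint(2, max_n)
        a = [rng.randint(1, n) for _ in range(n)]
        b = [rng.randint(1, n) for _ in range(n)]
        if brute(a, b) != fast(a, b):
            return a, b
    return None


print(stress())
```

Small arrays are used on purpose: with values in 1..n and n ≤ 8, collisions and equal-value pairs are frequent, which exercises the self-pairing correction.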
Worked Examples
Sample 1
Input arrays: a = [2,3,2], b = [3,3,1]
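Here n = 3, so 2n = 6 and only x = 1 and x = 2 can be the smaller a value of a valid pair. For each x, freq maps a b value to the number of elements with a = x that have it.
| x | freq (b values where a = x) | pairs with a[j] > x | pairs with a[i] = a[j] = x | running total |
|---|---|---|---|---|
| 1 | {} | 0 | 0 | 0 |
| 2 | {3:1, 1:1} | 1 | 1 | 2 |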
This trace shows that by counting using the frequency map, we correctly identify the two pairs (1,2) and (1,3).
Sample 2
Input arrays: a = [4,2,8,2,1,2,7,5], b = [3,5,8,8,1,1,6,5]
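Here n = 8, so 2n = 16 and the smaller a value x of a valid pair ranges over 1 to 4. For each x, freq maps a b value to the number of elements with a = x that have it.
| x | freq (b values where a = x) | pairs with a[j] > x | pairs with a[i] = a[j] = x | running total |
|---|---|---|---|---|
| 1 | {1:1} | 3 | 0 | 3 |
| 2 | {5:1, 8:1, 1:1} | 4 | 0 | 7 |
| 3 | {} | 0 | 0 | 7 |
| 4 | {3:1} | 0 | 0 | 7 |
The final total of 7 matches the expected output for this sample.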
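As a cross-check of Sample 2, the valid pairs can be listed by brute force and grouped by the smaller of their two a values. The helper name `pairs_by_smaller_factor` is made up for this illustration; indices are printed 1-based to match the problem statement.

```python
from collections import defaultdict


def pairs_by_smaller_factor(a, b):
    """Map min(a[i], a[j]) -> list of valid 1-based pairs (i, j), i < j."""
    groups = defaultdict(list)
    n = len(a)
    for i in range(n):
        for j in range(i + 1, n):
            if a[i] * a[j] == b[i] + b[j]:
                groups[min(a[i], a[j])].append((i + 1, j + 1))
    return dict(groups)


a = [4, 2, 8, 2, 1, 2, 7, 5]
b = [3, 5, 8, 8, 1, 1, 6, 5]
for x, pairs in sorted(pairs_by_smaller_factor(a, b).items()):
    print(x, pairs)
# 1 [(1, 5), (5, 6), (5, 7)]
# 2 [(1, 2), (2, 8), (3, 4), (4, 7)]
```

All seven pairs have a smaller factor of 1 or 2, well below sqrt(2n) = 4, which illustrates why only small values need to be enumerated.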
Complexity Analysis
| Measure | Complexity | Explanation |
|---|---|---|
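| Time | O(n * sqrt(n)) | About sqrt(2n) candidate values x are tried, and each requires a linear scan of the arrays. |
| Space | O(n) | The input arrays plus one frequency array of size n + 1, rebuilt for each x. |
Given that the sum of n across all test cases is ≤ 2*10^5, the worst case is roughly 1.3 * 10^8 element visits per scan pass. That fits the time and memory limits in a compiled language; in CPython the inner loops may be too slow without PyPy or further optimization.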
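To put a number on the bound, the following back-of-the-envelope sketch multiplies the total input size by the number of candidate small factors, assuming about sqrt(2n) candidates per test case. `worst_case_steps` is a hypothetical helper, and the result is an estimate of element visits for one scan pass, not a measured running time.

```python
from math import isqrt


def worst_case_steps(total_n):
    """Approximate element visits: one linear scan per candidate x <= sqrt(2n)."""
    return total_n * isqrt(2 * total_n)


print(worst_case_steps(200_000))  # 200000 * 632 = 126400000
```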
Test Cases
import sys, io
def run(inp: str) -> str:
    sys.stdin = io.StringIO(inp)
    from contextlib import redirect_stdout
    out = io.StringIO()
    with redirect_stdout(out):
        solve()
    return out.getvalue().strip()
# provided samples
assert run("3\n3\n2 3 2\n3 3 1\n8\n4 2 8 2 1 2 7 5\n3 5 8 8 1 1 6 5\n8\n4 4 8 8 8 8 8 8\n8 8 8 8 8 8 8 8\n") == "2\n7\n1"
# custom: minimum-size
assert run("1\n2\n1 1\n1 1\n") == "0", "minimum size, no valid pair"
# custom: all-equal
assert run("1\n4\n2 2 2 2\n2 2 2 2\n") == "6", "all equal, all pairs valid"
# custom: large values
assert run(f"1\n5\n5 5 5 5 5\n5 5 5 5 5