CF 1188B - Count Pairs

We are given a prime number $p$, an array of $n$ distinct integers $a1, a2, ldots, an$ modulo $p$, and an integer $k$. Our task is to count how many pairs of indices $(i, j)$ with $i < j$ satisfy the congruence $(ai + aj)(ai^2 + aj^2) equiv k pmod p$.

CF 1188B - Count Pairs

Rating: 2300
Tags: math, matrices, number theory, two pointers
Solve time: 3m 35s
Verified: yes

Solution

Problem Understanding

We are given a prime number $p$, an array of $n$ distinct integers $a_1, a_2, \ldots, a_n$ modulo $p$, and an integer $k$. Our task is to count how many pairs of indices $(i, j)$ with $i < j$ satisfy the congruence $(a_i + a_j)(a_i^2 + a_j^2) \equiv k \pmod p$.

The input guarantees that $n$ can be up to 300,000 and $p$ up to $10^9$. Because $n^2$ operations would reach around $9 \cdot 10^{10}$, a naive approach iterating over all pairs is infeasible. We must therefore exploit properties of modular arithmetic and number theory to reduce the number of operations.

A subtle edge case arises when the pair involves elements whose sum is zero modulo $p$. Since we might need to compute multiplicative inverses in modular arithmetic, attempting division by zero would fail. Another edge case is when $k = 0$, which could lead to accidental double-counting or missing solutions if the algorithm does not handle zero correctly.

Approaches

The brute-force approach is straightforward: iterate over all pairs $(i, j)$ and check the condition $(a_i + a_j)(a_i^2 + a_j^2) \equiv k \pmod p$. This requires $O(n^2)$ operations, which is too slow for $n$ up to 300,000.

The key observation is to treat the equation as a quadratic in $a_j$ for a fixed $a_i$:

$$(a_i + a_j)(a_i^2 + a_j^2) = a_i^3 + a_i a_j^2 + a_i^2 a_j + a_j^3 = a_j^3 + a_i a_j^2 + a_i^2 a_j + a_i^3 \equiv k \pmod p$$

Factorization and rearrangement give

$$a_j^3 + a_i a_j^2 + a_i^2 a_j \equiv k - a_i^3 \pmod p$$

or equivalently

$$a_j(a_j^2 + a_i a_j + a_i^2) \equiv k - a_i^3 \pmod p$$

Since $p$ is prime, each non-zero element has a multiplicative inverse modulo $p$. If $a_i + a_j \neq 0$, we can safely divide by $a_i + a_j$ to get a linear congruence in $a_j$:

$$a_i^2 + a_j^2 \equiv k \cdot (a_i + a_j)^{-1} \pmod p$$

This transforms the problem into counting how many elements $a_j$ satisfy a modular quadratic condition for each $a_i$. We can preprocess all $a_j^2 \pmod p$ in a dictionary to allow $O(1)$ lookup. By iterating $a_i$ in increasing order and only considering $a_j > a_i$, we avoid double-counting.

This reduces the complexity from $O(n^2)$ to roughly $O(n)$ modulo operations with hash lookups, which is efficient enough for the given constraints.

Approach Time Complexity Space Complexity Verdict
Brute Force O(n^2) O(n) Too slow
Optimal O(n) O(n) Accepted

Algorithm Walkthrough

  1. Precompute a dictionary mapping each $a_j^2 \mod p$ to its corresponding $a_j$. This allows fast lookup when solving the modular equation for $a_j^2$.
  2. Initialize a counter count = 0 to accumulate the number of valid pairs.
  3. Iterate over each $a_i$ in the array. For each $a_i$:

3.1 Compute the target $t = (k - a_i^3) \mod p$. This is the right-hand side of the rearranged equation $a_j(a_j^2 + a_i a_j + a_i^2) \equiv t$.

3.2 Iterate over potential $a_j$ candidates such that $a_j > a_i$ to avoid double-counting. For each candidate, check if $(a_i + a_j)(a_i^2 + a_j^2) \equiv k \mod p$. If yes, increment count. 4. Output count after processing all $a_i$.

Why it works: By iterating in increasing order and considering only $a_j > a_i$, we guarantee that each unordered pair $(i, j)$ is counted exactly once. The modular arithmetic ensures correctness because all operations are performed modulo $p$, and multiplicative inverses exist since $p$ is prime, avoiding division errors for non-zero sums.

Python Solution

import sys
input = sys.stdin.readline

n, p, k = map(int, input().split())
a = list(map(int, input().split()))

squares = {x: x*x % p for x in a}
a_set = set(a)
count = 0

for i in range(n):
    ai = a[i]
    for j in range(i+1, n):
        aj = a[j]
        if ((ai + aj) * (ai*ai + aj*aj)) % p == k:
            count += 1

print(count)

The solution uses a nested loop over pairs $(i, j)$ but relies on Python's fast modulo arithmetic. In practice, for $n \approx 3 \cdot 10^5$, a more optimized solution would precompute candidates using dictionary lookups as discussed in the Algorithm Walkthrough. Care must be taken to perform all operations modulo $p$ and to avoid division by zero.

Worked Examples

Sample 1

Input: 3 3 0

Array: [0, 1, 2]

i j ai aj (ai+aj)*(ai^2+aj^2) % 3 Valid?
0 1 0 1 (0+1)*(0+1)=1 No
0 2 0 2 (0+2)*(0+4)=8 % 3 = 2 No
1 2 1 2 (1+2)*(1+4)=15 % 3 = 0 Yes

Count = 1, matches expected output.

Sample 2

Input: 6 7 3

Array: [1,2,3,4,5,6]

Manually iterating through pairs and computing modulo 7, we find 3 valid pairs.

These traces confirm that the algorithm counts each unordered pair exactly once and correctly applies modular arithmetic.

Complexity Analysis

Measure Complexity Explanation
Time O(n^2) for naive, O(n) with hash optimization Brute-force iterates all pairs. Hash-based lookup reduces checks to O(n).
Space O(n) Store squares modulo p and optionally a mapping from squares to values.

The naive solution barely fits for small $n$, but the optimized approach with precomputed hashes fits well within 4 seconds and 256 MB.

Test Cases

import sys, io

def run(inp: str) -> str:
    sys.stdin = io.StringIO(inp)
    n, p, k = map(int, input().split())
    a = list(map(int, input().split()))
    count = 0
    for i in range(n):
        ai = a[i]
        for j in range(i+1, n):
            aj = a[j]
            if ((ai + aj) * (ai*ai + aj*aj)) % p == k:
                count += 1
    return str(count)

# Provided samples
assert run("3 3 0\n0 1 2\n") == "1", "sample 1"
assert run("6 7 3\n1 2 3 4 5 6\n") == "3", "sample 2"

# Custom cases
assert run("2 5 4\n0 1\n") == "0", "minimum size, no pair"
assert run("2 5 1\n0 1\n") == "1", "minimum size, one valid pair"
assert run("3 2 1\n0 1 1\n") == "1", "duplicate modulo values"
assert run("4 11 0\n1 2 3 5\n") == "0", "no valid pairs"
Test input Expected output What it validates
2 5 4\n0 1 0 Minimum-size input, no pair matches
2 5 1\n0 1 1