Home [Algorithm] Quick Select: K 번째 원소 찾기
Post
Cancel

[Algorithm] Quick Select: K 번째 원소 찾기

1. Selection Algorithm

정렬되지 않은 리스트나 배열에서 k 번째 수를 찾는 문제selection 문제라고 한다.

selection 문제를 해결하는 방법은 크게 세 가지로 나뉜다.

  1. 오름차순 정렬 후 index로 접근 ($O(n\log n)$)
  2. heap 사용 ($O(n\log k)$)
  3. selection 알고리즘 (ex. quick select)

효율적인 정렬 알고리즘의 시간 복잡도가 $O(n\log n)$ 이므로, selection 문제는 $O(n\log n)$ 내에 해결할 수 있다.


2. Quick Select

quick sort와 유사하게 partitioning을 이용하지만, binary search 처럼 k 번째 요소가 속한 부분만을 확인한다는 점에서 다르다.

partitioning이란, 임의의 값인 pivot을 기준으로 pivot 보다 작은 수는 왼쪽에, 큰 수는 오른쪽에 위치하도록 요소를 배치하는 방법이다.

partitioning 결과로 나뉜 두 부분 중, k 번째 요소가 속한 부분에 대해서만 재귀적으로 quick select를 수행하면 된다.

pivot_index의 크기동작
pivot_index > kleft part에 대해 재귀 수행
pivot_index < kright part에 대해 재귀 수행
pivot_index == kk 번째 원소 발견! 🎉


3. Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def partition(left, right):
    # pivot보다 작은 원소들 이후의 첫 번째 index를 가리키게 되므로,
    # 마지막에 pivot과 swap 되면 pivot의 index를 나타냄
    p = left

    # 배열의 맨 오른쪽 값을 pivot으로 설정
    pivot = arr[right]

    for i in range(left, right):
        if arr[i] <= pivot:
            arr[i], arr[p] = arr[p], arr[i]
            p += 1

    arr[p], arr[right] = arr[right], arr[p]

    return p

def quick_select(left, right):
    pivot_index = partition(left, right)

    if pivot_index > k:     # -- left part
        return quick_select(left, pivot_index - 1)
    elif pivot_index < k:   # -- right part
        return quick_select(pivot_index + 1, right)
    else:                   # -- find kth element
        return arr[pivot_index]

quick_select(0, len(nums) - 1)


4. Time Complexity

quick sort와 마찬가지로 pivot 값에 따라 알고리즘의 성능이 결정된다.

  • best case: $O(n)$

    선택된 pivot 값이 k 번째 수인 경우

  • worst case: $O(n^2)$

    선택된 pivot 값이 배열 내 가장 작거나 큰 값인 경우

  • average: $O(n)$

    pivot 값을 기준으로 배열이 1/2로 나누어지는 경우, $n+\frac{n}{2}+\frac{n}{4}+\dots+\frac{n}{2^k}=2n$


LeetCode의 973. K Closest Points to Origin 문제를 풀어보자!

[방법 #1] Max Heap 구성하기

max heap의 길이를 최대 k로 유지한다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import heapq

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        """
        heapq
        - TC: O(nlogk)
        - SC: O(k)
        """

        def get_distance(point):
            x, y = point
            return x * x + y * y
        
        q = []
        for point in points:
            heapq.heappush(q, (-get_distance(point), point))  # as max heap

            if len(q) > k:
                heapq.heappop(q)
        
        return [point for _, point in q]


[방법 #2] Quick Select 응용하기 (Randomized Quick Select)

  1. partition(left, right)에서 pivot을 선택할 때 random 하게 선택하는 randomized quick select를 사용하였다.

    • pivot을 맨 오른쪽(right) 원소로 설정한 코드와 비교했을 때, 실제로 제출 시 실행 시간이 1/10로 단축되었다.

      quick select도 quick sort와 마찬가지로, pivot을 최댓값 혹은 최솟값으로 설정하면 $O(n^2)$이기 때문이다.

    • randomized quick select를 구현하기 위해서는 (1) randint(left, right)random을 구하고, (2) 이를 right와 swap만 해주면 된다. 이렇게 되면 선택된 random pivot이 맨 오른쪽(right)에 위치하게 된다.

  2. 이 문제에서는 “값”이 아닌 “거리”를 기준으로 하고 있으므로, 단순히 pivot의 값을 비교하는 것이 아닌 i가 가리키는 원소의 distance를 비교해야 함에 주의한다.

    points[right]가 아닌 get_distance(points[right])를, points[i]가 아닌 get_distance(points[i])와 비교해야 한다.

  3. 적당한 pivot_index를 찾은 순간 실행을 종료하기 위해 더 직관적인 iterative 형식으로 작성했다.

    recursive 형식뿐만 아니라 iterative 형식으로도 작성 가능하다!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        """
        quick select
        - TC: O(n)
        - SC: O(1)
        """

        def get_distance(point):
            x, y = point
            return x * x + y * y

        def partition(left, right):
            p = left

            # pivot를 random하게 선택하고, right와 swap
            # 즉, pivot을 right 위치로 이동시킴
            random = randint(left, right)
            points[random], points[right] = points[right], points[random]

            # pivot의 distance 구하기
            pivot_distance = get_distance(points[right])

            for i in range(left, right):
                # 현재 보고 있는 point의 distance가 pivot distance 이하이면 p와 swap
                if get_distance(points[i]) <= pivot_distance:
                    points[i], points[p] = points[p], points[i]
                    p += 1
            
            # 마지막으로 p와 right swap
            # - p를 기준으로, points[p]의 distance보다 작은 distance인 point는 왼쪽에,
            # - points[p]의 distance보다 큰 distance인 point은 오른쪽에 위치
            points[p], points[right] = points[right], points[p]

            return p

        # iterative quick select
        left, right, pivot_index = 0, len(points) - 1, len(points)
        while pivot_index != k:
            pivot_index = partition(left, right)
            if pivot_index < k:
                left = pivot_index + 1
            else:
                right = pivot_index - 1
        
        return points[:k]


References

This post is licensed under CC BY 4.0 by the author.

[Python] 파이썬의 반올림은 사사오입? 오사오입? (+ 부동 소수점, Decimal)

[Python] 메타클래스(Metaclass)