TimothyQiu's Blog

keep it simple stupid

用 STL 寻找前 K 个最大的数

分类:技术

前几天找资料,顺手看到了这篇博文,讲了博主如何优化自己以前的算法的过程。于是立马想起了 GoingNative 2013 上的演讲 C++ Seasoning 以及前不久看到的这篇 From goto to std::transform。于是不禁想,「寻找前 K 大数」这样的任务,能不能直接用 C++ 标准库来完成?如果能的话,性能又如何呢?于是我重新登录了那个以前做了两道题就尘封多年的 OJ 帐号……

凭着最直观最 naïve 想法,第一个版本我用了 std::sort 来代替手工排序。不出所料,这样的做法跑了 562 ms,要排到 700 名开外的样子。

#include <stdio.h>
#include <algorithm>
#include <functional>
#include <vector>

int main()
{
    int numVillagers;
    int topCount;

    while (scanf("%d %d\n", &numVillagers, &topCount) == 2) {
        if (numVillagers == 0 && topCount == 0) {
            break;
        }

        std::vector<int> wealth(numVillagers);
        for (int i = 0; i < numVillagers; i++) {
            scanf("%d", &(wealth[i]));
        }
        std::sort(wealth.begin(), wealth.end(), std::greater<int>());

        if (topCount > numVillagers) {
            topCount = numVillagers;
        }

        for (int i = 0; i < topCount - 1; i++) {
            printf("%d ", wealth[i]);
        }
        printf("%d\n", wealth[topCount - 1]);
    }
}

于是优化:考虑到每次新建 std::vector 的开销,把 wealth 拎出来放到循环外,reserve() 题目中的最大值,每次在循环里重新 resize() 后,时间顿时缩短到了 296 ms,可以排到 180+ 的样子。

std::vector<int> wealth;
wealth.reserve(100000);

而在前面说过的演讲中,我还听到了一个之前从未留意的函数 std::nth_element。其作用是确保调用后第 N 个元素是范围内第 N 大的元素。调用后,[begin, N) 内的任意元素都小于 (N, end) 内的任意元素。

std::nth_element(wealth.begin(), wealth.begin() + topCount, wealth.end(), std::greater<int>());
std::sort(wealth.begin(), wealth.begin() + topCount, std::greater<int>());

尝试修改成以上的略显罗嗦的代码后,程序运行时间缩短到了 171 ms,可以排到 70+ 的样子。时间上和博文中给出的最小堆实现相同,但是内存占用要比它大很多。

既然上面是先用 std::nth_element 大致排序了一下,而后再用 std::sort 排序前半部分,那么,STL 里是否存在一次性只排 [begin, N] 范围内的数,而无视 (N, end) 内的顺序的函数呢?答案是存在,可以直接使用 std::partial_sort 解决。

#include <stdio.h>
#include <algorithm>
#include <functional>
#include <vector>

int main()
{
    int numVillagers;
    int topCount;

    std::vector<int> wealth;
    wealth.reserve(100000);

    while (scanf("%d %d\n", &numVillagers, &topCount) == 2) {
        if (numVillagers == 0 && topCount == 0) {
            break;
        }

        wealth.resize(numVillagers);
        for (int i = 0; i < numVillagers; i++) {
            scanf("%d", &(wealth[i]));
        }

        if (topCount > numVillagers) {
            topCount = numVillagers;
        }

        std::partial_sort(wealth.begin(), wealth.begin() + topCount, wealth.end(), std::greater<int>());

        for (int i = 0; i < topCount - 1; i++) {
            printf("%d ", wealth[i]);
        }
        printf("%d\n", wealth[topCount - 1]);
    }
}

此时的程序执行时间 156 ms,排名第 25 位。而截止到此时,我个人其实什么都没有干,实际任务都交给了 STL。

如同各种 MyStringMyVector 一样,一般情况下自己去实现这些通用算法在 STL 面前几乎没有任何优势。尤其是 C++11 带来 lambda 表达式以来,大大方便了 <algorithm> 中各种算法的使用,我觉得不去用 STL 的理由越发稀少了。

以上。

C++

已有 2 条评论 »

  1. 尤其是 C++11 带来 lambda 表达式以来,大大方便了 <algorithm> 中各种算法的使用,我觉得不去用 STL 的理由越发稀少了。
    右值引用直接带来了效率上的提升。

  2. wawa wawa

    感谢!

添加新评论 »