算法学习笔记(5):Dijkstra

在刷Hackerrank的时候碰到了这道题,其本身倒是十分普通的一道题,照本宣科的把Dijkstra实现一遍就可以做出来了。在讨论区和别人讨论发现的时间复杂度也是会被Accept的。

如果使用数组储存所有unvisited点的距离,在找当前最近点的时候,会需要遍历整个数组,而这是的时间复杂度。如果使用一个最小优先级队列,那么只需要将队首的元素取出,然后min-heapify一下,保持最小堆的性质,而这是的时间复杂度。能够显著减少时间消耗。

但是如果使用最小堆的话,降低堆中元素的优先级的操作就很麻烦了。首先如果想要找到堆中的某个元素就需要的时间复杂度,而修改了这个元素的优先级之后,想要维持最小堆的性质又要的时间复杂度。所以整体而言就变得和没有使用最小堆一样了。

为了解决这个问题有一个弥补的方法。就是在一个散列表中存储优先级队列中对应点的元素的handle。

如果想要降低某个点的优先级,那么可以通过查询散列表,获得最小堆中对应元素的handle,然后直接标志该点已经被删除(但是不实际从堆中删除),然后在储存堆的数组的最后新添加一个储存新优先级和该点的元素。标志删除的操作因为没有破坏最小堆的性质,只是 的时间复杂度,而最小堆中插入一个元素是 的时间复杂度。最后就实现了 的降低优先级的操作。

如果使用python实现的话,python的官方文档中简单介绍了使用heapq实现最小优先级队列的方法。

不过残念的是,add_task()方法的第一行就出现了个遍历dict,而这是的,所以推荐额外再使用一个set来维护entry_finder.keys(),这样查找和删除操作就都是 了。

另外一个有意思的事情是,heappush()原地的,所以pq会在原地修改。而pq中保存的是[distance, vertex]的list,而list是mutable的。所以在remove_vertex()把entry[-1]也就是vertex标志城REMOVED了之后,pq中的对应entry会自动修改。

最后挂上自己的代码。

# Enter your code here. Read input from STDIN. Print output to STDOUT
from heapq import *
 
def add_vertex(vertex, distance=0):
    if vertex in entry_set:
        remove_vertex(vertex)
    entry = [distance, vertex]
    entry_finder[vertex] = entry
    entry_set.add(vertex)
    heappush(pq, entry)
 
def remove_vertex(vertex):
    entry = entry_finder.pop(vertex)
    entry[-1] = REMOVED
    entry_set.remove(vertex)
 
def pop_vertex():
    while pq:
        priority, vertex = heappop(pq)
        if vertex is not REMOVED:
            del entry_finder[vertex]
            return priority, vertex
 
T = int(raw_input())
for T_ in range(T):
    pq = []
    entry_finder = {}
    REMOVED = "REMOVED"
 
    N, M = [int(i) for i in raw_input().split(" ")]
    connection = [[] for conn_ in range(N)]
    for M_ in range(M):
        x, y, r = [int(i)-1 for i in raw_input().split(" ")]
        connection[x].append([y, r+1])
        connection[y].append([x, r+1])
    S = int(raw_input())-1
     
    entry_set = set()
    unvisited = set()
    distance = {}
 
    for node in range(M):
        unvisited.add(node)
        if node!=S:
            distance[node] = float("Inf")
            add_vertex(vertex = node, distance = float("Inf"))
        else:
            distance[node] = 0
            add_vertex(vertex = node, distance = 0)
 
    while len(unvisited)!=0:
        current_distance, current_node = pop_vertex()
        if current_distance==float("Inf"):
            break
        for node, edge in connection[current_node]:
            if node in unvisited:
                new_distance = current_distance + edge
                if new_distance < distance[node]:
                    distance[node] = new_distance
                    add_vertex(node, new_distance)
                distance[node] = new_distance if new_distance < distance[node] else distance[node]
        unvisited.remove(current_node)
 
    for node in range(N):
        if node!=S:
            if distance[node]==float("Inf"):
                print -1,
            else:
                print distance[node],
    print