Problem Description

Design a data structure that follows the constraints of a Least Recently Used (LRU) cache.

Implement the LRUCache class:

  • LRUCache(int capacity) Initialize the LRU cache with positive size capacity.
  • int get(int key) Return the value of the key if the key exists, otherwise return -1.
  • void put(int key, int value) Update the value of the key if the key exists. Otherwise, add the key-value pair to the cache. If the number of keys exceeds the capacity from this operation, evict the least recently used key.

The functions get and put must each run in O(1) average time complexity.

Example 1:

  • Input: ["LRUCache", "put", "put", "get", "put", "get", "put", "get", "get", "get"] [[2], [1, 1], [2, 2], [1], [3, 3], [2], [4, 4], [1], [3], [4]]
  • Output: [null, null, null, 1, null, -1, null, -1, 3, 4]
  • Explanation:
LRUCache lRUCache = new LRUCache(2);
lRUCache.put(1, 1); // cache is {1=1}
lRUCache.put(2, 2); // cache is {1=1, 2=2}
lRUCache.get(1);    // return 1
lRUCache.put(3, 3); // LRU key was 2, evicts key 2, cache is {1=1, 3=3}
lRUCache.get(2);    // returns -1 (not found)
lRUCache.put(4, 4); // LRU key was 1, evicts key 1, cache is {4=4, 3=3}
lRUCache.get(1);    // return -1 (not found)
lRUCache.get(3);    // return 3
lRUCache.get(4);    // return 4

My Idea

The idea is to combine a doubly linked list, which stores the (key, value) pairs in order of use, with a hash map from each key to its list node, which gives O(1) lookup. The attribute head points to the least recently used element (the next one to be evicted), and tail points to the most recently used element.

My solution

class ListNode:
    """Node of the doubly linked list that records usage order.

    The key is stored next to the value so that evicting a node can also
    remove the matching hash-map entry without searching for it.
    """

    def __init__(self, key=0, val=0, next=None, prev=None):
        self.key = key
        self.val = val
        self.next = next  # neighbour towards the tail (more recently used)
        self.prev = prev  # neighbour towards the head (less recently used)


class LRUCache:
    """Fixed-capacity cache that evicts the least recently used key.

    ``records`` maps each key to its ``ListNode``; the nodes form a doubly
    linked list ordered from ``head`` (least recently used) to ``tail``
    (most recently used).  ``get`` and ``put`` each perform one dictionary
    lookup plus a constant number of pointer updates, so both run in O(1)
    average time.
    """

    def __init__(self, capacity: int):
        """Create an empty cache that holds at most ``capacity`` keys."""
        self.capacity = capacity
        self.records = {}  # key -> ListNode
        self.head = None   # least recently used node, or None when empty
        self.tail = None   # most recently used node, or None when empty

    def _unlink(self, node) -> None:
        """Detach ``node`` from the list, repairing head/tail as needed."""
        before, after = node.prev, node.next
        if before is None:
            # No predecessor means the node was the head.
            self.head = after
        else:
            before.next = after
        if after is None:
            # No successor means the node was the tail.
            self.tail = before
        else:
            after.prev = before
        node.prev = node.next = None

    def _append(self, node) -> None:
        """Attach a detached ``node`` at the tail (most recently used end)."""
        node.prev = self.tail
        node.next = None
        if self.tail is None:
            # The list was empty, so the node is also the head.
            self.head = node
        else:
            self.tail.next = node
        self.tail = node

    def get(self, key: int) -> int:
        """Return the value stored for ``key`` and mark it most recently used.

        Returns -1 when ``key`` is not in the cache.
        """
        node = self.records.get(key)
        if node is None:
            return -1
        if node is not self.tail:
            # Already-newest nodes need no relinking.
            self._unlink(node)
            self._append(node)
        return node.val

    def add(self, key: int, value: int) -> None:
        """Insert ``key`` as the most recently used entry.

        Performs no capacity check; ``put`` is the entry point that evicts.
        If ``key`` is already present, its old node is unlinked first so the
        list never keeps a node that the hash map no longer points to.
        """
        stale = self.records.get(key)
        if stale is not None:
            self._unlink(stale)
        node = ListNode(key, value)
        self.records[key] = node
        self._append(node)

    def delete(self, key) -> None:
        """Remove ``key`` from the cache; do nothing if it is absent."""
        node = self.records.pop(key, None)
        if node is not None:
            self._unlink(node)

    def put(self, key: int, value: int) -> None:
        """Store ``value`` under ``key`` and mark it most recently used.

        An existing key is updated in place.  Inserting a new key into a
        full cache first evicts the least recently used entry.  A cache
        whose capacity is not positive stores nothing.
        """
        if self.capacity <= 0:
            # Nothing can be kept, and there is no head to evict.
            return

        node = self.records.get(key)
        if node is not None:
            node.val = value
            if node is not self.tail:
                self._unlink(node)
                self._append(node)
            return

        if len(self.records) >= self.capacity:
            self.delete(self.head.key)
        self.add(key, value)