Source code for snat_sim.utils.caching
"""The ``caching`` module defines numpy compatible function wrappers
for implementing function level memoization.
Usage Example
-------------
The builtin Python memoization routines (e.g., ``lru_cache``) are not
compatible with ``numpy`` arrays because array objects are not hashable.
The ``Cache`` decorator provides an alternative memoization solution
that supports numpy arguments. Arguments that are numpy arguments must be
specified by name when constructing the decorator:
.. doctest:: python
>>> import numpy as np
>>> from snat_sim.utils.caching import Cache
>>> def add(x: np.array, y: np.array) -> np.array:
... print('The function has been called!')
... return x + y
>>> add = Cache(add, 1000, 'x', 'y')
>>> x_arr = np.arange(1, 5)
>>> y_arr = np.arange(5, 9)
>>> print(add(x_arr, y_arr))
The function has been called!
[ 6 8 10 12]
>>> print(add(x_arr, y_arr))
[ 6 8 10 12]
Module API
----------
"""
from __future__ import annotations
import inspect
import sys
from collections import OrderedDict
from typing import Any, Callable, Hashable, Tuple, Type
import numpy as np
[docs]class MemoryCache(OrderedDict):
"""Ordered dictionary with an imposed limit on overall memory usage"""
[docs] def __init__(self, max_size: int = None):
"""Ordered dictionary with an imposed size limit im memory.
When memory usage exceeds the predefined amount, remove the oldest
entry from the cache.
Args:
max_size: Maximum memory size in bytes
"""
super(MemoryCache, self).__init__()
self.max_size = max_size
if (max_size is not None) and not (isinstance(max_size, int) and not (max_size <= 0)):
raise ValueError('Maximum cache size must be a positive integer')
def __setitem__(self, key: Hashable, value: Any):
"""Update an entry in the hash table."""
OrderedDict.__setitem__(self, key, value)
self._check_size_limit()
def _check_size_limit(self) -> None:
"""Pop items from memory until instance size is <= the size limit."""
if self.max_size is not None:
while self and sys.getsizeof(self) > self.max_size:
self.popitem(last=False)
[docs]class Cache(MemoryCache):
"""Memoization function wrapper"""
[docs] def __init__(self, function: callable, cache_size: int, *numpy_args: str) -> None:
"""Memoization decorator supporting ``numpy`` arrays.
Args:
*numpy_args: Function arguments to treat as numpy arrays
cache_size: Maximum memory to allocate to cached data in bytes
Returns:
A callable function decorator
"""
self.function = function
self.cache_size = cache_size
self.numpy_args = numpy_args
super(Cache, self).__init__(max_size=cache_size)
def __call__(self, *args: Any, **kwargs: Any) -> Callable:
"""Cache return values of the given function
Args:
function: The function to cache returns of
Returns:
The wrapped function
"""
kwargs_for_key = inspect.getcallargs(self.function, *args, **kwargs)
for arg_to_cast in self.numpy_args:
kwargs_for_key[arg_to_cast] = np.array(kwargs_for_key[arg_to_cast]).tobytes()
key = tuple(kwargs_for_key.items())
try:
return self[key]
except KeyError:
new_val = self.function(*args, **kwargs)
self[key] = new_val
return new_val
def __reduce__(self) -> Tuple[Type[Cache], Tuple]:
# Ensures instances can be pickled
return self.__class__, (self.function, self.cache_size, *self.numpy_args)