#!/usr/bin/env python
# encoding: utf-8
# Stdlib:
from itertools import combinations
from collections import Counter, deque
import math
import logging
LOGGER = logging.getLogger(__name__)
# Internal:
from munin.song import Song
from munin.helper import sliding_window, centering_window, RunningMean
from munin.history import ListenHistory, RuleIndex
import munin.plot
class Database:
    """Class managing Database concerns.

    Usually you access this as the ``.database`` attribute of
    :class:`munin.session.Session`.

    You can do the following tasks with it:

    * Trigger updates (:func:`rebuild`).
    * Get a plot of the graph for debugging purposes (:func:`plot`).
    * Iterate over the database (``for song in database``).
    * Get a song by its uid (``database[song.uid]``).

    .. note::

        The division of :class:`munin.session.Session` and
        :class:`Database` is purely cosmetic. Neither class can exist
        on its own.
    """

    def __init__(self, session):
        """Create an empty database bound to *session*.

        :param session: The owning session; its ``config`` mapping and
            ``provider_for_key()`` are used by the other methods.
        """
        self._session = session
        # Songs are stored at the index equal to their uid; a removed song
        # leaves a ``None`` hole behind.
        self._song_list = []
        # Uids of the holes in ``_song_list``. This belongs to the song list,
        # not to the listen history, so it must survive _reset_history().
        self._revoked_uids = set()
        self._reset_history()

    def _reset_history(self):
        """Throw away the listen history, the rule index and all playcounts."""
        config = self._session.config
        self._listen_history = ListenHistory(
            maxlen=config['history_max_pkg'],
            time_threshold_sec=config['history_timeout'],
            max_group_size=config['history_pkg_size']
        )
        self._rule_index = RuleIndex(
            maxlen=config['history_max_rules']
        )
        self._playcounts = Counter()

    def __iter__(self):
        """Iterate over all songs, skipping the holes left by removed songs."""
        # Compare against None explicitly: filter(None, ...) would also drop
        # any song object that happens to evaluate as false.
        return (song for song in self._song_list if song is not None)

    def __len__(self):
        """Return the number of songs, not counting revoked slots."""
        return len(self._song_list) - len(self._revoked_uids)

    def __getitem__(self, idx):
        """Lookup a certain song by its uid.

        The uid is used directly as an index into the internal song list,
        so the slot of a removed song yields ``None``.

        :param idx: A uid previously returned by :func:`add` or :func:`insert`.
        :returns: The :class:`munin.song.Song` stored under that uid.
        :raises IndexError: If the uid is outside of the song list.
        """
        try:
            return self._song_list[idx]
        except IndexError:
            raise IndexError('song uid #{} is invalid'.format(idx))

    def _current_uid(self):
        """Return the uid the next added song should get.

        A revoked uid is reused (and taken out of the revoked set) if one is
        available; otherwise the uid is the next index past the song list.
        """
        if self._revoked_uids:
            return self._revoked_uids.pop()
        return len(self._song_list)

    def plot(self, width=1000, height=1000, **kwargs):
        """Plot the current graph for debugging purposes.

        Delegates to :func:`munin.plot.plot` and returns nothing.

        :param width: Width of the plotted image in pixels.
        :param height: Height of the plotted image in pixels.
        :param kwargs: Passed through to :func:`munin.plot.plot` unchanged.
        """
        munin.plot.plot(self, width, height, **kwargs)

    def playcount(self, song):
        """Return how often *song* was fed to the history (0 if never)."""
        return self._playcounts.get(song, 0)

    def playcounts(self, n=0):
        """Return the recorded playcounts.

        :param n: If less than 1, the whole counter is returned; otherwise
            only the *n* most common entries.
        :returns: The ``Counter`` itself, or a list of ``(song, count)``
            tuples as produced by ``Counter.most_common``.
        """
        if n < 1:
            return self._playcounts
        return self._playcounts.most_common(n)

    def feed_history(self, song):
        """Record that *song* was listened to.

        When the listen history reports a truthy result for the feed, the
        rules it finds are inserted into the rule index. The playcount of
        *song* is incremented in any case.
        """
        if self._listen_history.feed(song):
            rules = self._listen_history.find_rules()
            self._rule_index.insert_rules(rules)
        self._playcounts[song] += 1

    def _provider_for(self, key):
        """Return the session's provider for *key*.

        :raises KeyError: With a descriptive message if the session does not
            know *key*.
        """
        # Only the lookup is guarded, so a KeyError raised elsewhere is not
        # mislabelled as an unknown attribute.
        try:
            return self._session.provider_for_key(key)
        except KeyError:
            raise KeyError('key "{k}" is not in mask'.format(k=key))

    def _process_values(self, value_dict, keep_none=True):
        """Return a new dict with every value run through its provider.

        The passed dict is left untouched.

        :param value_dict: Mapping of attribute key to raw value.
        :param keep_none: If true, ``None`` values are stored as ``None``
            without being handed to the provider.
        :raises KeyError: If a key is unknown to the session.
        """
        processed = {}
        for key, value in value_dict.items():
            # Look the provider up even for None values, so that unknown
            # keys are always rejected.
            provider = self._provider_for(key)
            if keep_none and value is None:
                processed[key] = None
            else:
                processed[key] = provider.process(value)
        return processed

    def find_matching_attributes(self, subset, max_numeric_offset=None):
        """Yield all songs whose attributes match *subset*.

        :param subset: Mapping of attribute key to the raw value to look for.
        :param max_numeric_offset: If ``None`` an exact match is done
            (:func:`find_matching_attributes_generic`), otherwise a ranged
            one (:func:`find_matching_attributes_numeric`).
        :returns: A generator of matching songs.
        """
        if max_numeric_offset is None:
            return self.find_matching_attributes_generic(subset)
        return self.find_matching_attributes_numeric(subset, max_numeric_offset)

    def find_matching_attributes_numeric(self, subset, max_offset):
        """Yield songs whose values lie within *max_offset* of *subset*.

        Only the first element of the processed value and of the song's value
        is compared. Keys of *subset* that a song does not have at all are
        not checked for that song; a key whose value is ``None`` disqualifies
        the song.

        :param subset: Mapping of attribute key to the raw value to look for.
        :param max_offset: Maximum allowed absolute difference (inclusive).
        :raises KeyError: On first iteration, if a key is unknown to the
            session.
        """
        wanted = self._process_subset(subset)
        for song in self:
            for key in wanted.keys() & song.keys():
                value = song.get(key)
                if value is None:
                    break
                center = wanted[key][0]
                if not (center - max_offset <= value[0] <= center + max_offset):
                    break
            else:
                # No key disqualified the song.
                yield song

    def find_matching_attributes_generic(self, subset):
        """Yield songs whose values equal the processed values of *subset*.

        Every key is compared against the value requested for that very key;
        a value requested for one key never matches under another key.

        :param subset: Mapping of attribute key to the raw value to look for.
        :raises KeyError: On first iteration, if a key is unknown to the
            session.
        """
        wanted = self._process_subset(subset)
        for song in self:
            if all(song[key] == value for key, value in wanted.items()):
                yield song

    def _process_subset(self, subset):
        """Process a search subset; ``None`` values go to the provider too."""
        return self._process_values(subset, keep_none=False)

    def _rebuild_step_base(self, mean_counter, window_size, step_size):
        """Do the base iterations.

        This involves three iterations over the database:

        * :func:`munin.helper.sliding_window` with ``window_size`` and
          ``step_size``.
        * :func:`munin.helper.centering_window` with ``window_size // 2``.
        * :func:`munin.helper.centering_window` with ``window_size // 2``
          and ``parallel=False``.

        For every window the distances of all pairs in it are computed,
        added to both songs and sampled into *mean_counter*.

        :param mean_counter: A RunningMean counter to sample the initial mean/sd.
        :param window_size: The max. size of the window in which combinations
            are taken; ``None`` reads ``rebuild_window_size`` from the config.
        :param step_size: The movement of the window per iteration; ``None``
            reads ``rebuild_step_size`` from the config.
        """
        if window_size is None:
            window_size = self._session.config['rebuild_window_size']
        if step_size is None:
            step_size = self._session.config['rebuild_step_size']

        iterators = (
            sliding_window(self, window_size, step_size),
            centering_window(self, window_size // 2),
            centering_window(self, window_size // 2, parallel=False),
        )

        # Prebind the functions for performance reasons.
        compute = Song.distance_compute
        add = Song.distance_add

        for idx, iterator in enumerate(iterators):
            LOGGER.debug('|-- Applying iteration #{}: {}'.format(idx + 1, iterator))
            for window in iterator:
                for song_a, song_b in combinations(window, 2):
                    distance = compute(song_a, song_b)
                    add(song_a, song_b, distance)
                    # Sample the newly calculated distance.
                    mean_counter.add(distance.distance)

    def _rebuild_step_refine(self, mean_counter, num_passes=None, mean_scale=None):
        """Do the refinement step.

        In every pass each song is compared with the songs returned by its
        ``distance_indirect_iter(threshold)``; the resulting distances are
        added afterwards.

        .. seealso:: :func:`rebuild`

        :param mean_counter: RunningMean counter; read for the threshold and
            fed with every distance computed here.
        :param num_passes: How many times the song list shall be iterated at
            most; ``None`` reads ``rebuild_refine_passes`` from the config.
        :param mean_scale: Scale used in the threshold formula; ``None``
            reads ``rebuild_mean_scale`` from the config.
        """
        if num_passes is None:
            num_passes = self._session.config['rebuild_refine_passes']
        if mean_scale is None:
            mean_scale = self._session.config['rebuild_mean_scale']

        # Prebind the functions for performance reasons:
        add = Song.distance_add
        compute = Song.distance_compute

        # Counted separately from the loop variable, so the log line below
        # also works when no pass runs at all.
        passes_done = 0
        for _ in range(num_passes):
            passes_done += 1
            threshold = (mean_counter.mean * mean_scale - mean_counter.sd) / mean_scale
            newly_found = 0

            for song in self:
                # Collect first and add later: the distances must not be
                # modified while the indirect neighbors are being iterated.
                pending = []
                for ind_ngb in set(song.distance_indirect_iter(threshold)):
                    distance = compute(song, ind_ngb)
                    pending.append((ind_ngb, distance))
                    mean_counter.add(distance.distance)

                # Count how many of the distances were actually added.
                for ind_ngb, dist in pending:
                    newly_found += add(song, ind_ngb, dist)

            # Stop when this pass added fewer new distances than half the
            # number of songs. This usually only triggers for high num_passes.
            if newly_found < len(self) // 2:
                break

        LOGGER.debug('Did {}x (of max. {}) refinement steps.'.format(passes_done, num_passes))

    def rebuild_stupid(self):
        """(Re)build the graph by calculating the combination of all songs.

        This is a *very* expensive operation which takes quadratic time and
        should only ever be used for a small number of songs, where accuracy
        matters even more than time.
        """
        # Iterate the database, not the raw list: the list may contain
        # ``None`` holes left by removed songs.
        for song_a, song_b in combinations(self, 2):
            distance = Song.distance_compute(song_a, song_b)
            Song.distance_add(song_a, song_b, distance)

    def rebuild(self, window_size=None, step_size=None, refine_passes=None, stupid_threshold=None):
        """Rebuild all distances and the associated graph.

        This will be triggered for you automatically after a transaction.
        The listen history, rule index and playcounts are reset afterwards.

        :param int window_size: The size of the sliding window in the base iteration.
        :param int step_size: The amount to move the window per iteration.
        :param int refine_passes: How often step #2 should be repeated.
        :param int stupid_threshold: If there are fewer songs than this, all
            combinations of songs are calculated by brute force. ``None``
            reads ``rebuild_stupid_threshold`` from the config.
        """
        if stupid_threshold is None:
            stupid_threshold = self._session.config['rebuild_stupid_threshold']

        if len(self) < stupid_threshold:
            LOGGER.debug('+ Step #1 + 2: Brute Force calculation due to few songs')
            self.rebuild_stupid()
        else:
            # Average and Standard Deviation Counter:
            mean_counter = RunningMean()

            LOGGER.debug('+ Step #1: Calculating base distance (sliding window)')
            self._rebuild_step_base(
                mean_counter,
                window_size=window_size,
                step_size=step_size
            )
            LOGGER.debug('|-- Mean Distane: {:f} (sd: {:f})'.format(
                mean_counter.mean, mean_counter.sd
            ))

            LOGGER.debug('+ Step #2: Applying refinement:')
            self._rebuild_step_refine(
                mean_counter,
                num_passes=refine_passes
            )
            LOGGER.debug('|-- Mean Distane: {:f} (sd: {:f})'.format(
                mean_counter.mean, mean_counter.sd
            ))

        self._reset_history()

    def _new_song(self, value_dict):
        """Create a Song from processed values using the session's limits."""
        config = self._session.config
        return Song(
            self._session, value_dict,
            max_neighbors=config['max_neighbors'],
            max_distance=config['max_distance']
        )

    def add(self, value_dict):
        """Create a song from *value_dict* and store it, without linking it.

        The values are processed by their providers; the passed dict itself
        is not modified. No distances are computed here.

        :param value_dict: Mapping of attribute key to raw value (or ``None``).
        :returns: The uid of the new song.
        :raises KeyError: If a key is unknown to the session.
        """
        new_song = self._new_song(self._process_values(value_dict))
        new_song.uid = self._current_uid()

        if new_song.uid >= len(self._song_list):
            self._song_list.append(new_song)
        else:
            # A revoked uid is being reused; fill its hole.
            self._song_list[new_song.uid] = new_song
        return new_song.uid

    def fix_graph(self):
        """Call ``distance_finalize()`` on every song and sanity-check order."""
        for song in self:
            song.distance_finalize()

            # This is just some sort of assert and has no functionality:
            last = None
            for other, dist in song.distance_iter():
                if last is not None and last > dist:
                    LOGGER.critical('!! warning: unsorted elements: !({} < {})'.format(dist, last))
                last = dist

    def modify(self, song, sub_value_dict, star_threshold=0.75, iterstep_threshold=50):
        """Replace *song* by a copy with some attributes changed.

        The old song is removed, and a new song with the merged values takes
        over its uid and is inserted into the graph again. The passed
        *sub_value_dict* is not modified.

        :param song: The song to modify.
        :param sub_value_dict: Mapping of the attributes to change to their
            new raw values (or ``None``).
        :param star_threshold: See :func:`insert`.
        :param iterstep_threshold: See :func:`insert`.
        :returns: The uid of the modified song.
        :raises KeyError: If a key is unknown to the session.
        """
        value_dict = song.to_dict()
        value_dict.update(self._process_values(sub_value_dict))
        new_song = self._new_song(value_dict)

        uid = self.remove(song.uid)
        new_song.uid = uid
        self._song_list[uid] = new_song
        # remove() marked the uid as revoked; the slot is occupied again, so
        # it must neither be handed out by add() nor subtracted in len().
        self._revoked_uids.discard(uid)

        # Clear all known distances:
        new_song.distance_reset()
        return self._insert_song_to_graph(
            new_song, star_threshold, iterstep_threshold
        )

    def insert(self, value_dict, star_threshold=0.75, iterstep_threshold=50):
        """Add a song like :func:`add` and link it into the existing graph.

        :param value_dict: Mapping of attribute key to raw value (or ``None``).
        :param star_threshold: Songs whose distance to the new song is above
            this value get their neighbors compared with the new song too.
        :param iterstep_threshold: With fewer songs than this, the new song
            is compared with every song instead of a sample.
        :returns: The uid of the new song.
        :raises KeyError: If a key is unknown to the session.
        """
        new_song = self._song_list[self.add(value_dict)]
        return self._insert_song_to_graph(
            new_song, star_threshold, iterstep_threshold
        )

    def _insert_song_to_graph(self, new_song, star_threshold=0.75, iterstep_threshold=50):
        """Compute distances between *new_song* and (a sample of) the others.

        :returns: The uid of *new_song*.
        """
        next_len = len(self._song_list)
        if len(self) < iterstep_threshold:
            iterstep = 1
        else:
            # Sample only every ln(n)-th song for larger databases.
            iterstep = round(max(1, math.log(max(next_len, 1))))

        # Step 1: Find samples with similar songs (similar to the base step)
        distances = []
        for song in self._song_list[::iterstep]:
            # Skip holes, and the new song itself: it is already part of the
            # list and a distance to itself carries no information.
            if song is None or song is new_song:
                continue
            distance = Song.distance_compute(song, new_song)
            distances.append((song, distance))
            new_song.distance_add(song, distance)

        # Step 2: Short refinement step
        # NOTE(review): this follows songs *above* star_threshold; confirm
        # that this direction of the comparison is intended.
        for song, distance in distances:
            if distance.distance > star_threshold:
                for neighbor in song.neighbors():
                    neighbor_distance = new_song.distance_compute(neighbor)
                    new_song.distance_add(neighbor, neighbor_distance)
        return new_song.uid

    def remove(self, uid):
        """Remove the song with the given uid from the database.

        The slot is set to ``None``, the uid is remembered for reuse and the
        song's ``disconnect()`` is called.

        :param uid: The uid of the song to remove.
        :returns: The uid that was passed in.
        :raises ValueError: If the uid is out of range or was already removed.
        """
        # Reject negative uids as well: list indexing would silently wrap
        # around and remove a song from the end.
        if not 0 <= uid < len(self._song_list):
            raise ValueError('Invalid UID #{}'.format(uid))

        song = self._song_list[uid]
        if song is None:
            # Already removed; fail before touching any state.
            raise ValueError('Invalid UID #{}'.format(uid))

        self._song_list[uid] = None
        self._revoked_uids.add(uid)

        # Patch the hole:
        song.disconnect()
        return uid
###########################################################################
# Test Stuff #
###########################################################################
if __name__ == '__main__':
    # Self-test and small demo: runs the unit tests by default, or builds a
    # sample session when called with ``--cli`` (optionally with ``--euler``
    # and ``--plot``).
    import unittest
    import sys

    from munin.session import Session

    class DatabaseTests(unittest.TestCase):
        """Tests of the Database, driven through a Session."""

        def setUp(self):
            self._session = Session('session_test', {
                'genre': (None, None, 0.2),
                'artist': (None, None, 0.3)
            })

        def test_modify(self):
            """Modifying a song changes both its value and its distances."""
            from munin.distance.rating import RatingDistance

            session = Session('session_test_modify', {
                'rating': (None, RatingDistance(), 1),
            })
            with session.transaction():
                for i in range(0, 6):
                    session.add({'rating': i})

            self.assertAlmostEqual(session[5].distance_get(session[0]).distance, 0.5)
            # Tuples are compared exactly; assertAlmostEqual would raise a
            # TypeError instead of a test failure if they ever differed.
            self.assertEqual(session[5]['rating'], (5, ))

            with session.fix_graph():
                session.modify(5, {'rating': 0})

            self.assertAlmostEqual(session[5].distance_get(session[0]).distance, 0.0)
            self.assertEqual(session[5]['rating'], (0, ))

        def test_basics(self):
            """Adding songs inside a transaction works."""
            with self._session.transaction():
                N = 20
                for i in range(N):
                    self._session.database.add({
                        'genre': i / N,
                        'artist': i / N
                    })

        def test_no_match(self):
            """Keys that are not in the mask are rejected."""
            with self.assertRaisesRegex(KeyError, '.*mask.*'):
                self._session.database.add({
                    'not_in_session': 42
                })

        def test_insert_remove_song(self):
            """Songs can be inserted, removed and inserted again."""
            songs = []
            with self._session.transaction():
                for _ in range(4):
                    songs.append(self._session.add({'genre': [0], 'artist': [0]}))
            # self._session.database.plot(250, 250)

            with self._session.fix_graph():
                self._session.insert({'genre': [0], 'artist': [0]})
            # self._session.database.plot(250, 250)

            for song in self._session.database:
                for other in self._session.database:
                    # Compare the two songs, not the test case, so that only
                    # distinct pairs are checked.
                    if song is not other:
                        self.assertAlmostEqual(song.distance_get(other).distance, 0.0)

            self._session.remove(4)
            # self._session.database.plot(250, 250)

            with self._session.fix_graph():
                self._session.insert({'genre': [0], 'artist': [0]})
            # self._session.database.plot(250, 250)

        def test_find_matching_attributes_numeric(self):
            """Ranged lookup finds the songs within the given offset."""
            from munin.helper import pairup

            session = Session('session_find_test', {
                'x': pairup(None, None, 1),
                'y': pairup(None, None, 1)
            })
            a = session[session.add({
                'x': 21,
                'y': 42,
            })]
            b = session[session.add({
                'x': 0,
                'y': 100,
            })]
            session[session.add({
                'x': 51,
                'y': 50,
            })]

            find = session.database.find_matching_attributes_numeric
            self.assertEqual(list(find({'x': 10}, 20)), [a, b])
            self.assertEqual(list(find({'y': 100}, 0)), [b])
            self.assertEqual(list(find({'x': 10, 'y': 40}, 20)), [a])
            self.assertEqual(list(find({'x': 10, 'y': 10}, 0)), [])

        def test_find_matching_attributes_generic(self):
            """Exact lookup finds the songs matching all given attributes."""
            from munin.provider import GenreTreeProvider
            from munin.distance import GenreTreeDistance
            from munin.helper import pairup

            session = Session('session_find_test', {
                'genre': pairup(GenreTreeProvider(), GenreTreeDistance(), 5),
                'artist': pairup(None, None, 1)
            })
            session.add({
                'artist': 'Berta',
                'genre': 'death metal'
            })
            session.add({
                'artist': 'Hans',
                'genre': 'metal'
            })
            session.add({
                'artist': 'Berta',
                'genre': 'pop'
            })

            found = list(session.find_matching_attributes({'genre': 'metal'}))
            self.assertEqual(len(found), 1)
            self.assertEqual(found[0], session[1])

            found = list(session.find_matching_attributes(
                {'genre': 'metal', 'artist': 'Berta'}
            ))
            self.assertEqual(len(found), 0)

            found = list(session.find_matching_attributes(
                {'genre': 'metal', 'artist': 'Hans'}
            ))
            self.assertEqual(len(found), 1)
            self.assertEqual(found[0], session[1])

            found = list(session.find_matching_attributes(
                {'genre': 'pop', 'artist': 'Berta'}
            ))
            self.assertEqual(len(found), 1)
            self.assertEqual(found[0], session[2])

            found = list(session.find_matching_attributes({'artist': 'Berta'}))
            self.assertEqual(len(found), 2)
            self.assertEqual(found[0], session[0])
            self.assertEqual(found[1], session[2])

    def main():
        """Build a demo session with debug logging; plot it on ``--plot``."""
        from munin.testing import DummyDistanceFunction
        import math

        LOGGER.setLevel(logging.DEBUG)
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        # add formatter to ch
        ch.setFormatter(
            logging.Formatter('%(name)s - %(levelname)s - %(message)s')
        )
        # add ch to logger
        LOGGER.addHandler(ch)

        session = Session('session_test', {
            'genre': (None, DummyDistanceFunction(), 0.2),
            'artist': (None, DummyDistanceFunction(), 0.3)
        })

        def euler(x):
            # Pseudo-Random, but deterministic: fractional part of e ** x.
            return math.fmod(math.e ** x, 1.0)

        with_euler = '--euler' in sys.argv
        with session.transaction():
            N = 100
            for i in range(int(N / 2) + 1):
                session.add({
                    'genre': 1.0 - i / N,
                    'artist': 1.0 - i / N
                })
                if with_euler:
                    session.database.add({
                        'genre': euler((i + 1) % 30),
                        'artist': euler((N - i + 1) % 30)
                    })

        LOGGER.debug('+ Step #3: Layouting and Plotting')
        if '--plot' in sys.argv:
            session.database.plot(1000, 500)

    if '--cli' in sys.argv:
        main()
    else:
        unittest.main()