segunda-feira, 6 de setembro de 2010

Combinação(n, r) em Python

Gerar todas as combinações possíveis de uma coleção de itens é muito útil para jogos, por exemplo, para criar uma inteligência artificial para o computador. Python oferece funções prontas para isso, como pode ser visto em http://docs.python.org/library/itertools.html. De qualquer modo, em algumas situações pode ser útil que as combinações sejam geradas iterativamente, uma de cada vez, como é sugerido pelo algoritmo de Rosen. Segue uma implementação que fiz desse algoritmo:


import math


class CombinationGenerator(object):
    """Step through the r-element combinations of ``range(n)`` one at a time.

    Each combination is a list of ``r`` strictly increasing indexes taken
    from ``0 .. n-1``. Combinations are produced in lexicographic order,
    starting at ``[0, 1, ..., r-1]`` and ending at ``[n-r, ..., n-1]``
    (the algorithm is credited to Rosen in the original code). Calling
    ``next()`` after the last combination wraps around to the first one.
    """

    def __init__(self, n, r):
        """Create a generator for choosing ``r`` items out of ``n``.

        Raises:
            ValueError: if ``n < 1``, ``r < 1`` or ``r > n``. ``ValueError``
                is a subclass of ``Exception``, so callers that catch
                ``Exception`` keep working.
        """
        if n < 1:
            raise ValueError(
                'Invalid argument: n (%d) must be greater than or equal to 1.'
                % (n,))
        elif r < 1:
            raise ValueError(
                'Invalid argument: r (%d) must be greater than or equal to 1.'
                % (r,))
        elif r > n:
            raise ValueError(
                'Invalid argument: r (%d) must be lower than or equal to '
                'n (%d)' % (r, n))

        self.n = n
        self.r = r
        # total = n! / (r! * (n - r)!)
        # Floor division keeps the result an exact integer; true division
        # would yield a float on Python 3 and lose precision for large n.
        self._total = math.factorial(n) // (
            math.factorial(r) * math.factorial(n - r))
        # An empty list means "no combination has been produced yet".
        self._current = []

    def first_combination(self):
        """Return the lexicographically first combination, [0, ..., r-1]."""
        return list(range(self.r))

    def last_combination(self):
        """Return the lexicographically last combination, [n-r, ..., n-1]."""
        return list(range(self.n - self.r, self.n))

    def __get_index_of_last_valid_element_to_increment(self):
        """Return the rightmost position that can still be incremented.

        Returns None when no position can grow, i.e. the current
        combination is the last one.
        """
        for index in reversed(range(self.r)):
            # Position `index` may hold at most n - (r - index); anything
            # larger leaves no room for the increasing values to its right.
            if self._current[index] + 1 <= self.n - (self.r - index):
                return index
        return None

    def total(self):
        """Return the number of distinct combinations, C(n, r)."""
        return self._total

    def reset(self):
        """Forget the current position; the next call to next() restarts."""
        self._current = []

    def current(self):
        """Return a copy of the most recently produced combination.

        The list is empty if next() has not been called since construction
        or since the last reset().
        """
        return list(self._current)

    def has_next(self):
        """Return False only when the current combination is the last one.

        Before the first call to next() this returns True.
        """
        return self._current != self.last_combination()

    # Algorithm from Rosen
    def next(self):
        """Advance to the following combination and return it.

        Returns a new list of indexes on every call, so results collected
        by the caller are not overwritten by later calls. After the last
        combination the sequence starts over from the first one.
        """
        if not self._current:
            self._current = self.first_combination()
        else:
            pivot = self.__get_index_of_last_valid_element_to_increment()
            if pivot is None:
                # Already at the last combination: wrap around.
                self._current = self.first_combination()
            else:
                self._current[pivot] += 1
                # Everything to the right of the pivot becomes the smallest
                # increasing run that follows the pivot's new value.
                for index in range(pivot + 1, self.r):
                    self._current[index] = self._current[index - 1] + 1
        return list(self._current)

Testes de Unidade:
import unittest

# NOTE(review): these tests reference `utilities.CombinationGenerator`, but
# the snippet never imports `utilities`. Presumably the generator class lives
# in a module of that name; import it here before running the tests.


class CombinationGeneratorTests(unittest.TestCase):
    """Unit tests for CombinationGenerator.

    Uses assertEqual instead of the deprecated assertEquals alias, which
    was removed from unittest in Python 3.12.
    """

    def test_1Choose1(self):
        generator = utilities.CombinationGenerator(1, 1)
        self.assertEqual(1, generator.total())
        self.assertEqual([0], generator.next())
        self.assertEqual([0], generator.first_combination())
        self.assertEqual([0], generator.last_combination())

    def test_NChooseN(self):
        generator = utilities.CombinationGenerator(3, 3)
        self.assertEqual(1, generator.total())
        self.assertEqual([0, 1, 2], generator.next())
        self.assertEqual([0, 1, 2], generator.first_combination())
        self.assertEqual([0, 1, 2], generator.last_combination())

    def test_2Choose1(self):
        generator = utilities.CombinationGenerator(2, 1)
        self.assertEqual(2, generator.total())
        self.assertEqual([0], generator.next())
        self.assertEqual([1], generator.next())
        self.assertEqual([0], generator.first_combination())
        self.assertEqual([1], generator.last_combination())

    def test_3Choose1(self):
        generator = utilities.CombinationGenerator(3, 1)
        self.assertEqual(3, generator.total())
        self.assertEqual([0], generator.next())
        self.assertEqual([1], generator.next())
        self.assertEqual([2], generator.next())

    def test_3Choose2(self):
        generator = utilities.CombinationGenerator(3, 2)
        self.assertEqual(3, generator.total())
        self.assertEqual([0, 1], generator.next())
        self.assertEqual([0, 2], generator.next())
        self.assertEqual([1, 2], generator.next())

    def test_4Choose3(self):
        generator = utilities.CombinationGenerator(4, 3)
        self.assertEqual(4, generator.total())
        self.assertEqual([0, 1, 2], generator.next())
        self.assertEqual([0, 1, 3], generator.next())
        self.assertEqual([0, 2, 3], generator.next())
        self.assertEqual([1, 2, 3], generator.next())

    def test_4Choose2(self):
        generator = utilities.CombinationGenerator(4, 2)
        self.assertEqual(6, generator.total())
        self.assertEqual([0, 1], generator.next())
        self.assertEqual([0, 2], generator.next())
        self.assertEqual([0, 3], generator.next())
        self.assertEqual([1, 2], generator.next())
        self.assertEqual([1, 3], generator.next())
        self.assertEqual([2, 3], generator.next())

    def test_5Choose3(self):
        generator = utilities.CombinationGenerator(5, 3)
        self.assertEqual(10, generator.total())
        self.assertEqual([0, 1, 2], generator.next())
        self.assertEqual([0, 1, 3], generator.next())
        self.assertEqual([0, 1, 4], generator.next())
        self.assertEqual([0, 2, 3], generator.next())
        self.assertEqual([0, 2, 4], generator.next())
        self.assertEqual([0, 3, 4], generator.next())
        self.assertEqual([1, 2, 3], generator.next())
        self.assertEqual([1, 2, 4], generator.next())
        self.assertEqual([1, 3, 4], generator.next())
        self.assertEqual([2, 3, 4], generator.next())
        self.assertEqual([0, 1, 2], generator.first_combination())
        self.assertEqual([2, 3, 4], generator.last_combination())

    def test_Sanity15Choose4(self):
        # Every produced combination has 4 indexes, all within [0, 15).
        generator = utilities.CombinationGenerator(15, 4)
        for _ in range(30):
            combination = generator.next()
            self.assertEqual(4, len(combination))
            for y in combination:
                self.assertTrue(0 <= y < 15)

    def test_Sanity50Choose7(self):
        # Every produced combination has 7 indexes, all within [0, 50).
        generator = utilities.CombinationGenerator(50, 7)
        for _ in range(100):
            combination = generator.next()
            self.assertEqual(7, len(combination))
            for y in combination:
                self.assertTrue(0 <= y < 50)

    def test_invalid_input(self):
        self.assertRaises(Exception, utilities.CombinationGenerator, 0, 0)
        self.assertRaises(Exception, utilities.CombinationGenerator, -2, -1)
        self.assertRaises(Exception, utilities.CombinationGenerator, 1, 2)
        self.assertRaises(Exception, utilities.CombinationGenerator, 2, 0)

    def test_performance(self):
        # Original note: "Java Rosen algorithm 22 in 1 second" -- presumably
        # a reference timing for a Java implementation; n is kept at 16 here.
        n = 16
        for r in range(1, n + 1):
            generator = utilities.CombinationGenerator(n, r)
            # Exhaust every combination; only the running time matters.
            while generator.has_next():
                generator.next()

    def test_AfterLastCombinationRestart(self):
        # After the last combination, next() wraps around to the first.
        generator = utilities.CombinationGenerator(3, 2)
        self.assertEqual([0, 1], generator.next())
        self.assertEqual([0, 2], generator.next())
        self.assertEqual([1, 2], generator.next())
        self.assertEqual([0, 1], generator.next())

    def test_has_next(self):
        generator = utilities.CombinationGenerator(3, 2)
        self.assertTrue(generator.has_next())
        generator.next()
        self.assertTrue(generator.has_next())
        generator.next()
        self.assertTrue(generator.has_next())
        generator.next()
        # Positioned on the last combination.
        self.assertFalse(generator.has_next())
        generator.next()
        # Wrapped around to the first combination again.
        self.assertTrue(generator.has_next())

    def test_reset(self):
        generator = utilities.CombinationGenerator(3, 2)
        self.assertEqual([0, 1], generator.next())
        generator.reset()
        self.assertEqual([0, 1], generator.next())

    def test_current(self):
        generator = utilities.CombinationGenerator(3, 2)
        self.assertEqual([0, 1], generator.next())
        self.assertEqual([0, 1], generator.current())
        self.assertEqual([0, 2], generator.next())
        self.assertEqual([0, 2], generator.current())

Nenhum comentário:

Postar um comentário