From 151aa9cd48b224313b6fbf884eb9880cc6a14507 Mon Sep 17 00:00:00 2001 From: frankknoll Date: Tue, 14 Feb 2023 16:51:22 +0100 Subject: [PATCH] adding MultiIndexExploderTest --- src/MultiIndexExploder.py | 12 +++++++++ src/MultiIndexExploderTest.py | 47 +++++++++++++++++++++++++++++++++++ src/Utils.py | 4 +++ 3 files changed, 63 insertions(+) create mode 100644 src/MultiIndexExploder.py create mode 100644 src/MultiIndexExploderTest.py diff --git a/src/MultiIndexExploder.py b/src/MultiIndexExploder.py new file mode 100644 index 00000000000..911e4eaa553 --- /dev/null +++ b/src/MultiIndexExploder.py @@ -0,0 +1,12 @@ +import numpy as np +import Utils + + +class MultiIndexExploder: + + @staticmethod + def explodeMultiIndexOfTable(table): + batchcodeColumns = table.index.names + explodedTable = table.loc[np.repeat(table.index, len(batchcodeColumns))].reset_index() + explodedTable['VAX_LOT_EXPLODED'] = Utils.flatten(table.index.values) + return explodedTable.set_index(['VAX_LOT_EXPLODED'] + batchcodeColumns) diff --git a/src/MultiIndexExploderTest.py b/src/MultiIndexExploderTest.py new file mode 100644 index 00000000000..a45f2c76444 --- /dev/null +++ b/src/MultiIndexExploderTest.py @@ -0,0 +1,47 @@ +import unittest +from pandas.testing import assert_frame_equal +from MultiIndexExploder import MultiIndexExploder +from TestHelper import TestHelper +import pandas as pd + +class MultiIndexExploderTest(unittest.TestCase): + + def test_explodeMultiIndexOfTable(self): + # Given + table = TestHelper.createDataFrame( + columns = ['DATA'], + data = [ ['A, B data'], + ['C, A data'], + ['C, B data']], + index = pd.MultiIndex.from_tuples( + names = ['VAX_LOT1', 'VAX_LOT2'], + tuples = [['A', 'B'], + ['C', 'A'], + ['C', 'B']])) + + # When + explodedTable = MultiIndexExploder.explodeMultiIndexOfTable(table) + + # Then + assert_frame_equal( + explodedTable, + TestHelper.createDataFrame( + columns = ['DATA'], + data = [ ['A, B data'], + ['A, B data'], + + ['C, A data'], + ['C, A data'], + + ['C, B data'], + ['C, B data']], + index = pd.MultiIndex.from_tuples( + names = ['VAX_LOT_EXPLODED', 'VAX_LOT1', 'VAX_LOT2'], + tuples = [['A', 'A', 'B'], + ['B', 'A', 'B'], + + ['C', 'C', 'A'], + ['A', 'C', 'A'], + + ['C', 'C', 'B'], + ['B', 'C', 'B']]))) diff --git a/src/Utils.py b/src/Utils.py index 51ac94b8ab7..00ae51a7ab5 100644 --- a/src/Utils.py +++ b/src/Utils.py @@ -4,3 +4,7 @@ def fillLsts(lsts, desiredLen, fillValue): def fillLst(lst, desiredLen, fillValue): return lst + [fillValue] * (max(desiredLen - len(lst), 0)) + + +def flatten(tuples): + return [item for tuple in tuples for item in tuple]