Files
HowBadIsMyBatch/src/captcha/DataSplitter.py
frankknoll 5379644a89 refactoring
2023-03-15 17:28:11 +01:00

29 lines
997 B
Python

import numpy as np
class DataSplitter:
    """Split paired samples ``x`` and labels ``y`` into train/valid/test parts.

    The data is first split into a training part and a held-out remainder;
    the remainder is then split again into a validation and a test part.
    With the default fractions this yields roughly 70% / 15% / 15%.

    Args:
        x: Sequence of samples; converted with ``np.array``.
        y: Sequence of labels, one per sample; converted with ``np.array``.
        train_size: Fraction (0..1) of all samples used for training.
        valid_size: Fraction (0..1) of the *held-out remainder* used for
            validation; the rest of the remainder becomes the test set.
        shuffle: Whether to shuffle before each split.
        seed: Optional seed for a private random generator. When ``None``
            the global ``np.random`` state is used, as before.

    Raises:
        ValueError: If ``x`` and ``y`` differ in length or a fraction is
            outside the range [0, 1].
    """

    def __init__(self, x, y, train_size=0.7, valid_size=0.5, shuffle=True, seed=None):
        # One generator shared by both splits so a single seed fixes the
        # whole three-way partition.
        rng = None if seed is None else np.random.default_rng(seed)
        (self.x_train, self.y_train), (x_valid_test, y_valid_test) = DataSplitter._splitData(
            np.array(x), np.array(y), train_size=train_size, shuffle=shuffle, rng=rng
        )
        (self.x_valid, self.y_valid), (self.x_test, self.y_test) = DataSplitter._splitData(
            x_valid_test, y_valid_test, train_size=valid_size, shuffle=shuffle, rng=rng
        )

    def getTrain(self):
        """Return the training part as an ``(x, y)`` tuple."""
        return (self.x_train, self.y_train)

    def getValid(self):
        """Return the validation part as an ``(x, y)`` tuple."""
        return (self.x_valid, self.y_valid)

    def getTest(self):
        """Return the test part as an ``(x, y)`` tuple."""
        return (self.x_test, self.y_test)

    @staticmethod
    def _splitData(x, y, train_size=0.9, shuffle=True, rng=None):
        """Split ``x`` and ``y`` in two, keeping sample/label pairs aligned.

        Args:
            x: NumPy array of samples.
            y: NumPy array of labels with the same length as ``x``.
            train_size: Fraction (0..1) of entries placed in the first part.
            shuffle: Whether to permute the entries before splitting.
            rng: Optional ``np.random.Generator`` used for shuffling; when
                ``None`` the global ``np.random`` state is used.

        Returns:
            ``((x_first, y_first), (x_second, y_second))``.

        Raises:
            ValueError: If the lengths of ``x`` and ``y`` differ or
                ``train_size`` is outside [0, 1].
        """
        size = len(x)
        # Without this check a longer y is silently truncated and a shorter
        # one fails later with an opaque IndexError.
        if len(y) != size:
            raise ValueError(
                f"x and y must have the same length, got {size} and {len(y)}"
            )
        if not 0 <= train_size <= 1:
            raise ValueError(f"train_size must be within [0, 1], got {train_size}")

        indices = np.arange(size)
        if shuffle:
            if rng is None:
                np.random.shuffle(indices)
            else:
                rng.shuffle(indices)

        train_samples = int(size * train_size)
        first, second = indices[:train_samples], indices[train_samples:]
        # The same index arrays select from x and y so pairs stay aligned.
        return (x[first], y[first]), (x[second], y[second])