29 lines
997 B
Python
29 lines
997 B
Python
import numpy as np
|
|
|
|
|
|
class DataSplitter:
    """Split paired samples into train, validation and test subsets.

    The data is first divided into a training part and a held-out part;
    the held-out part is then divided again into validation and test
    subsets. With the default fractions this yields roughly a
    70 % / 15 % / 15 % split.
    """

    def __init__(self, x, y, *, train_size=0.7, valid_size=0.5, seed=None):
        """Build the three subsets from ``x`` and ``y``.

        Args:
            x: Sequence of samples (anything ``np.asarray`` accepts).
            y: Sequence of targets, one entry per sample in ``x``.
            train_size: Fraction of all samples assigned to training.
            valid_size: Fraction of the held-out samples assigned to
                validation; the remainder becomes the test subset.
            seed: Optional seed (or ``np.random.Generator``) that makes
                the shuffling reproducible. When ``None`` the global
                ``np.random`` state is used.

        Raises:
            ValueError: If ``x`` and ``y`` differ in length, or a
                fraction lies outside ``[0, 1]``.
        """
        # One generator shared by both splits so a single seed fixes the
        # whole partition; None keeps the global-state behaviour.
        rng = None if seed is None else np.random.default_rng(seed)

        (self.x_train, self.y_train), (x_held_out, y_held_out) = (
            DataSplitter._splitData(x, y, train_size=train_size, rng=rng)
        )
        (self.x_valid, self.y_valid), (self.x_test, self.y_test) = (
            DataSplitter._splitData(
                x_held_out, y_held_out, train_size=valid_size, rng=rng
            )
        )

    def getTrain(self):
        """Return the training subset as an ``(x, y)`` tuple."""
        return (self.x_train, self.y_train)

    def getValid(self):
        """Return the validation subset as an ``(x, y)`` tuple."""
        return (self.x_valid, self.y_valid)

    def getTest(self):
        """Return the test subset as an ``(x, y)`` tuple."""
        return (self.x_test, self.y_test)

    @staticmethod
    def _splitData(x, y, train_size=0.9, shuffle=True, rng=None):
        """Split ``x`` and ``y`` into two aligned parts.

        Args:
            x: Samples; converted with ``np.asarray``.
            y: Targets; converted with ``np.asarray``. Must have the same
                length as ``x``.
            train_size: Fraction of the samples placed in the first part.
                The count is truncated with ``int``.
            shuffle: When true, sample order is permuted before the cut.
            rng: Optional seed or ``np.random.Generator`` used for the
                permutation. ``None`` uses the global ``np.random`` state.

        Returns:
            ``((x_first, y_first), (x_second, y_second))``.

        Raises:
            ValueError: If the lengths differ or ``train_size`` is outside
                ``[0, 1]``.
        """
        # Index-array selection below requires ndarrays; plain lists
        # would raise TypeError.
        x = np.asarray(x)
        y = np.asarray(y)

        size = len(x)
        if len(y) != size:
            raise ValueError(
                f"x and y must have the same length, got {size} and {len(y)}"
            )
        if not 0.0 <= train_size <= 1.0:
            raise ValueError(
                f"train_size must be within [0, 1], got {train_size}"
            )

        indices = np.arange(size)
        if shuffle:
            if rng is None:
                np.random.shuffle(indices)
            else:
                # default_rng returns an existing Generator unchanged.
                np.random.default_rng(rng).shuffle(indices)

        train_samples = int(size * train_size)
        first_idx = indices[:train_samples]
        second_idx = indices[train_samples:]
        return (x[first_idx], y[first_idx]), (x[second_idx], y[second_idx])
|