refactoring
This commit is contained in:
28
src/captcha/DataSplitter.py
Normal file
28
src/captcha/DataSplitter.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DataSplitter:
|
||||
|
||||
def __init__(self, x, y):
|
||||
(self.x_train, self.y_train), (x_valid_test, y_valid_test) = DataSplitter._splitData(np.array(x), np.array(y), train_size=0.7)
|
||||
(self.x_valid, self.y_valid), (self.x_test, self.y_test) = DataSplitter._splitData(x_valid_test, y_valid_test, train_size=0.5)
|
||||
|
||||
def getTrain(self):
|
||||
return (self.x_train, self.y_train)
|
||||
|
||||
def getValid(self):
|
||||
return (self.x_valid, self.y_valid)
|
||||
|
||||
def getTest(self):
|
||||
return (self.x_test, self.y_test)
|
||||
|
||||
@staticmethod
|
||||
def _splitData(x, y, train_size=0.9, shuffle=True):
|
||||
size = len(x)
|
||||
indices = np.arange(size)
|
||||
if shuffle:
|
||||
np.random.shuffle(indices)
|
||||
train_samples = int(size * train_size)
|
||||
x_train, y_train = x[indices[:train_samples]], y[indices[:train_samples]]
|
||||
x_test, y_test = x[indices[train_samples:]], y[indices[train_samples:]]
|
||||
return (x_train, y_train), (x_test, y_test)
|
||||
Reference in New Issue
Block a user