#!/usr/bin/env -S runhaskell --ghc-arg="-Wall" --ghc-arg="-i lib" --ghc-arg="-fprint-potential-instances"
-- Split the text read from standard input into a "train" file and a "test"
-- file, according to a ratio given on the command line.
--
-- Usage: TRAIN_RATIO TRAIN_PATH TEST_PATH
import Data.Text as Text (Text)
import Data.Text.IO as Text (getContents, writeFile)
import System.Environment (getArgs)
import System.Script (syntax, try)
import Text.Filter (Editable(..))
import Text.Read (readEither)

-- | Parse the train ratio from its command-line representation.
--
-- Two spellings are accepted: a plain number (@"0.8"@) or a number followed
-- by a percent sign (@"80%"@), in which case the number is divided by 100.
-- The resulting value must satisfy @0 <= r < 1@; anything else is reported
-- as a 'Left'. When the text is not a readable number, the 'Left' produced by
-- 'readEither' is returned unchanged.
--
-- >>> parseRatio "25%"
-- Right 0.25
-- >>> parseRatio "1"
-- Left "Ratio must represent a number between 0 and 1"
parseRatio :: String -> Either String Float
parseRatio input = do
  value <- parsed
  -- The test is kept in its positive form on purpose: every comparison with
  -- NaN is False, so a NaN input ends up in the error branch.
  if value >= 0 && value < 1
    then pure value
    else Left "Ratio must represent a number between 0 and 1"
  where
    -- A trailing '%' is detected by looking at the head of the reversed
    -- input; the remainder is reversed back before being read.
    parsed
      | '%' : reversedNumber <- reverse input =
          (/ 100) <$> readEither (reverse reversedNumber)
      | otherwise = readEither input

-- | Cut a list of texts in two.
--
-- The first component receives the leading @round (ratio * length texts)@
-- elements and the second component receives everything after them. Note
-- that 'round' rounds halves to the nearest even integer.
split :: Float -> [Text] -> ([Text], [Text])
split ratio texts =
  let total = fromIntegral (length texts) :: Float
      trainCount = round (ratio * total)
  in splitAt trainCount texts

-- | Entry point.
--
-- With exactly three arguments: the first is parsed with 'parseRatio' (the
-- result is handed to 'try'), all of standard input is read and converted
-- with 'enter', the resulting list is divided with 'split', and each half is
-- converted back with 'leave' and written to the second and third argument
-- respectively. With any other number of arguments, 'syntax' is called with
-- the usage string.
--
-- NOTE(review): 'try', 'syntax', 'enter' and 'leave' are defined in
-- "System.Script" and "Text.Filter", which are not visible here. Presumably
-- 'try' aborts the script on a 'Left' and 'syntax' prints a usage message;
-- confirm against those modules.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [trainRatio, trainPath, testPath] -> do
      ratio <- try (pure (parseRatio trainRatio))
      (train, test) <- fmap (split ratio . enter) Text.getContents
      Text.writeFile trainPath (leave train)
      Text.writeFile testPath (leave test)
    _ -> syntax "TRAIN_RATIO TRAIN_PATH TEST_PATH"