這篇只是最基礎簡單的使用教學

這篇基於 : https://dama.tw/threads/javabtccsv.705/ 取得的數據進行模型訓練

1. 使用 Smile​

Smile 是一個用 Java 編寫的機器學習庫,提供了豐富的算法和數據處理功能。Smile 的接口簡潔,性能良好,適合處理複雜的數據分析任務

github : https://github.com/haifengl/smile


Java:
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.io.FileWriter;
import java.io.IOException;
import java.time.LocalDate;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.*;

import org.apache.commons.csv.CSVFormat;
import org.json.JSONArray;
import smile.data.Tuple;
import smile.data.type.StructType;
import tools.FileoutputUtil;

import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.io.Read;
import smile.regression.RandomForest;
public class BitcoinPricePrediction {
private static HttpClient client = HttpClient.newHttpClient();

    public static void main(String[] args) throws Exception {
        CSVFormat format = CSVFormat.DEFAULT.withHeader();

        DataFrame data = Read.csv("E:/btc/data.csv",format);


        // 显示数据集结构,确认数据已正确加载
        System.out.println(data.schema());
        System.out.println(data.summary());

        // 選擇特徵和標籤
        Formula formula = Formula.lhs("Low");

        // 分割數據為訓練集和測試集
        DataFrame[] splits = splitData(data, 0.8);
        DataFrame trainData = splits[0];
        DataFrame testData = splits[1];

        // 使用隨機森林進行訓練
        RandomForest model = RandomForest.fit(formula, trainData);

        // 在測試集上進行預測
        double[] predictions = model.predict(testData);

        // 获取时间列数据
        long[] times = testData.longVector("OpenTime").array();

        // 输出指定时间点的预测结果
        /*for (int i = 0; i < predictions.length; i++) {
            System.out.println("Time: " + FileoutputUtil.getTime2(times[i]) + ", Predicted Low: " + predictions[i]);
        }*/

        // 预测最后一条数据的下一时间点
        int lastRowIndex = data.nrows() - 1;
        DataFrame futureData = data.slice(lastRowIndex, lastRowIndex + 1); // 获取最后一行数据进行预测

        double[] predictionsA = model.predict(futureData);

        long[] timesA = futureData.longVector("OpenTime").array(); // 获取时间数据以打印结果

        // 输出预测结果
        if (predictionsA.length > 0) {
            System.out.println("Time: " + FileoutputUtil.getTime2(timesA[0]) + ", Predicted Low: " + predictions[0]);
        }
    }

    private static DataFrame[] splitData(DataFrame data, double trainSize) {
        int n = data.nrows();
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices); // 随机打乱索引

        int trainCount = (int) (n * trainSize);

        DataFrame trainData = createDataFrameFromIndices(data, indices.subList(0, trainCount));
        DataFrame testData = createDataFrameFromIndices(data, indices.subList(trainCount, n));

        return new DataFrame[]{trainData, testData};
    }

    private static DataFrame createDataFrameFromIndices(DataFrame data, List<Integer> indices) {
        StructType schema = data.schema();
        List<Tuple> rows = new ArrayList<>();
        for (int index : indices) {
            rows.add(data.get(index));
        }
        return DataFrame.of(rows, schema);
    }
}


main(String[] args)​

這是Java程序的主入口函數。這個函數執行以下幾個主要任務:

  • 使用 Read.csv 從指定路徑加載CSV文件,該文件包含加密貨幣的交易數據。
  • 利用 Formula.lhs("Low") 確定模型將預測的目標變量(在此例中為 Low 這一列)。
  • 使用 splitData 函數將數據分割為訓練集和測試集。
  • 使用隨機森林算法訓練模型並在測試集上進行預測。
  • 從預測結果中提取並輸出指定時間點的預測結果。

splitData(DataFrame data, double trainSize)​

這個函數將整個數據集隨機分割成訓練集和測試集:

  • 接受整個 DataFrame 和訓練集所佔比例(trainSize)作為輸入。
  • 隨機打亂數據的索引,然後根據指定的比例分配索引到訓練集和測試集。
  • 根據這些索引,從原始 DataFrame 中提取對應行數據到新的 DataFrame 中。

createDataFrameFromIndices(DataFrame data, List<Integer> indices)​

這個函數根據給定的索引列表從原始數據集中提取行數據來創建新的 DataFrame:

  • 接受一個 DataFrame 和一個索引列表作為輸入。
  • 遍歷索引列表,對於每個索引,從原始數據集中提取相應的行數據。
  • 將這些行數據組合成一個新的 DataFrame,並返回。
這些函數共同完成了數據的加載、處理、模型訓練和預測等一系列操作,是機器學習應用中典型的數據管道流程。