0
点赞
收藏
分享

微信扫一扫

机器学习手撕代码(0)数据

火热如冰 2022-02-15 阅读 54

机器学习手撕代码(0)数据

文件树就是下面这个样子,不必须,import数据集文件没问题就行。
在这里插入图片描述

数据来源:kaggle葡萄酒预测
datasets文件夹下面放一个dataset.py文件,后面所有的模型都用这一个数据集。

dataset.py

import pandas as pd
import numpy as np


class DataSet:
    def __init__(self,path,mode='cla',rad_seed = 2021):
        data = pd.read_csv(path).dropna(axis=0, how='any')
        data1 = data[:2000]
        data2 = data[-2000:]
        data = pd.concat([data1, data2]).reset_index().drop(['index'], axis=1)
        data = data.replace('white', 0).replace('red', 1)
        if mode == 'cla':
            self.target_head = 'type'
        elif mode == 'reg':
            self.target_head = 'residual sugar'
        self.data_head = data.columns.to_list()
        self.data_head.remove(self.target_head)
        self.target = data[self.target_head].to_numpy()
        self.data = data[self.data_head].to_numpy()
        if rad_seed is not False:
            np.random.seed(rad_seed)
        permutation = list(np.random.permutation(len(self.data)))
        self.data = self.data[permutation]
        self.target = self.target[permutation]

    def get_data(self):
        return self.data,self.target,self.target_head,self.data_head

后面文章的模型原理就不详述了,看我的不如看书,分享一下自己手撕的代码,做了最大简化,个人感觉简洁一些。

举报

相关推荐

0 条评论