TensorFlow from Beginner to Master (8): Predicting Titanic Passenger Survival


"You jump, I jump" is the famous line from the classic romance Titanic: Rose stands at the bow, about to leap into the sea, and Jack, standing beside her, says it to save her. When a strange man is willing to die for a strange woman for no reason at all, Rose, just as inexplicably, falls in love with him.
That backstory has little to do with this tutorial, of course. Here we will use a neural network to predict Jack's and Rose's survival probabilities. Keeping the series going through the National Day holiday was not easy; if you need the dataset, message me or join the study group. Thanks for your support!

I. The Dataset

1. Loading the dataset

import pandas as pd

df = pd.read_excel('titanic3.xls')
df.describe()



|       | pclass      | survived    | age         | sibsp       | parch       | fare        | body       |
|-------|-------------|-------------|-------------|-------------|-------------|-------------|------------|
| count | 1309.000000 | 1309.000000 | 1046.000000 | 1309.000000 | 1309.000000 | 1308.000000 | 121.000000 |
| mean  | 2.294882    | 0.381971    | 29.881135   | 0.498854    | 0.385027    | 33.295479   | 160.809917 |
| std   | 0.837836    | 0.486055    | 14.413500   | 1.041658    | 0.865560    | 51.758668   | 97.696922  |
| min   | 1.000000    | 0.000000    | 0.166700    | 0.000000    | 0.000000    | 0.000000    | 1.000000   |
| 25%   | 2.000000    | 0.000000    | 21.000000   | 0.000000    | 0.000000    | 7.895800    | 72.000000  |
| 50%   | 3.000000    | 0.000000    | 28.000000   | 0.000000    | 0.000000    | 14.454200   | 155.000000 |
| 75%   | 3.000000    | 1.000000    | 39.000000   | 1.000000    | 0.000000    | 31.275000   | 256.000000 |
| max   | 3.000000    | 1.000000    | 80.000000   | 8.000000    | 9.000000    | 512.329200  | 328.000000 |

df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1309 entries, 0 to 1308
Data columns (total 14 columns):
 #   Column     Non-Null Count  Dtype
---  ------     --------------  -----
 0   pclass     1309 non-null   int64
 1   survived   1309 non-null   int64
 2   name       1309 non-null   object
 3   sex        1309 non-null   object
 4   age        1046 non-null   float64
 5   sibsp      1309 non-null   int64
 6   parch      1309 non-null   int64
 7   ticket     1309 non-null   object
 8   fare       1308 non-null   float64
 9   cabin      295 non-null    object
 10  embarked   1307 non-null   object
 11  boat       486 non-null    object
 12  body       121 non-null    float64
 13  home.dest  745 non-null    object
dtypes: float64(3), int64(4), object(7)
memory usage: 143.3+ KB

df.head()



|   | pclass | survived | name | sex | age | sibsp | parch | ticket | fare | cabin | embarked | boat | body | home.dest |
|---|--------|----------|------|-----|-----|-------|-------|--------|------|-------|----------|------|------|-----------|
| 0 | 1 | 1 | Allen, Miss. Elisabeth Walton | female | 29.0000 | 0 | 0 | 24160 | 211.3375 | B5 | S | 2 | NaN | St Louis, MO |
| 1 | 1 | 1 | Allison, Master. Hudson Trevor | male | 0.9167 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | 11 | NaN | Montreal, PQ / Chesterville, ON |
| 2 | 1 | 0 | Allison, Miss. Helen Loraine | female | 2.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
| 3 | 1 | 0 | Allison, Mr. Hudson Joshua Creighton | male | 30.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | 135.0 | Montreal, PQ / Chesterville, ON |
| 4 | 1 | 0 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | female | 25.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |

2. Preprocessing the dataset

  • Select the required columns
  • Handle missing values
  • Convert categorical values to numeric codes
  • Drop the name column

# Select the columns we need
selected_cols = ['survived','name','pclass','sex','age','sibsp','parch','fare','embarked']
df_selected = df[selected_cols] # keep an untouched copy for later
df = df[selected_cols] # indexing with a list of labels selects columns
df.head()



|   | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
|---|----------|------|--------|-----|-----|-------|-------|------|----------|
| 0 | 1 | Allen, Miss. Elisabeth Walton | 1 | female | 29.0000 | 0 | 0 | 211.3375 | S |
| 1 | 1 | Allison, Master. Hudson Trevor | 1 | male | 0.9167 | 1 | 2 | 151.5500 | S |
| 2 | 0 | Allison, Miss. Helen Loraine | 1 | female | 2.0000 | 1 | 2 | 151.5500 | S |
| 3 | 0 | Allison, Mr. Hudson Joshua Creighton | 1 | male | 30.0000 | 1 | 2 | 151.5500 | S |
| 4 | 0 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | 1 | female | 25.0000 | 1 | 2 | 151.5500 | S |

# Find the columns that contain null values
df.isnull().any()

survived    False
name        False
pclass      False
sex         False
age          True
sibsp       False
parch       False
fare         True
embarked     True
dtype: bool

# Count the null values in each column
df.isnull().sum()

survived      0
name          0
pclass        0
sex           0
age         263
sibsp         0
parch         0
fare          1
embarked      2
dtype: int64

# Locate the rows that contain missing values (a row appears once per NaN it holds, hence 266 rows for 263 + 1 + 2 missing entries)
df[df.isnull().values==True]



|      | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
|------|----------|------|--------|-----|-----|-------|-------|------|----------|
| 15   | 0 | Baumann, Mr. John D | 1 | male | NaN | 0 | 0 | 25.9250 | S |
| 37   | 1 | Bradley, Mr. George ("George Arthur Brayton") | 1 | male | NaN | 0 | 0 | 26.5500 | S |
| 40   | 0 | Brewe, Dr. Arthur Jackson | 1 | male | NaN | 0 | 0 | 39.6000 | C |
| 46   | 0 | Cairns, Mr. Alexander | 1 | male | NaN | 0 | 0 | 31.0000 | S |
| 59   | 1 | Cassebeer, Mrs. Henry Arthur Jr (Eleanor Genev... | 1 | female | NaN | 0 | 0 | 27.7208 | C |
| ...  | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1293 | 0 | Williams, Mr. Howard Hugh "Harry" | 3 | male | NaN | 0 | 0 | 8.0500 | S |
| 1297 | 0 | Wiseman, Mr. Phillippe | 3 | male | NaN | 0 | 0 | 7.2500 | S |
| 1302 | 0 | Yousif, Mr. Wazli | 3 | male | NaN | 0 | 0 | 7.2250 | C |
| 1303 | 0 | Yousseff, Mr. Gerious | 3 | male | NaN | 0 | 0 | 14.4583 | C |
| 1305 | 0 | Zabour, Miss. Thamine | 3 | female | NaN | 1 | 0 | 14.4542 | C |

266 rows × 9 columns

# Replace missing age values with the column mean
age_mean = df['age'].mean()
df['age'] = df['age'].fillna(age_mean)
df['age'].isnull().any() # returns True if any null value remains

False

# Replace the missing fare value with the column mean
fare_mean = df['fare'].mean()
df['fare'] = df['fare'].fillna(fare_mean)

# Fill the missing embarked records with 'S' (the most common port)
df['embarked'] = df['embarked'].fillna('S')
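As a minimal, self-contained illustration of mean-filling with `fillna` (a toy series with made-up values, not the Titanic data):

```python
import pandas as pd

# Toy column with one missing entry (hypothetical values, for illustration only)
s = pd.Series([20.0, 30.0, None, 40.0])

mean_val = s.mean()          # NaN is skipped: (20 + 30 + 40) / 3 = 30.0
filled = s.fillna(mean_val)  # the gap is replaced by the mean

print(mean_val)               # 30.0
print(filled.isnull().any())  # False
```

Note that `fillna` returns a new Series; assigning it back, as the cell above does, is what actually updates the column.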

df.isnull().any()

survived    False
name        False
pclass      False
sex         False
age         False
sibsp       False
parch       False
fare        False
embarked    False
dtype: bool

# Convert categorical values to numeric codes
# Convert sex from strings to numeric codes
df['sex'] = df['sex'].map({'female':0,'male':1}).astype(int)
# Convert embarked from port letters to numeric codes
df['embarked'] = df['embarked'].map({'C':0,'Q':1,'S':2}).astype(int)

# Drop the name column
df = df.drop(['name'],axis=1) # axis=0 is rows, axis=1 is columns
df.head()



|   | survived | pclass | sex | age | sibsp | parch | fare | embarked |
|---|----------|--------|-----|-----|-------|-------|------|----------|
| 0 | 1 | 1 | 0 | 29.0000 | 0 | 0 | 211.3375 | 2 |
| 1 | 1 | 1 | 1 | 0.9167 | 1 | 2 | 151.5500 | 2 |
| 2 | 0 | 1 | 0 | 2.0000 | 1 | 2 | 151.5500 | 2 |
| 3 | 0 | 1 | 1 | 30.0000 | 1 | 2 | 151.5500 | 2 |
| 4 | 0 | 1 | 0 | 25.0000 | 1 | 2 | 151.5500 | 2 |
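One caution worth knowing when encoding with `.map`, shown on a toy series (made-up values, not the Titanic data):

```python
import pandas as pd

sex = pd.Series(['female', 'male', 'male'])
encoded = sex.map({'female': 0, 'male': 1}).astype(int)
print(encoded.tolist())  # [0, 1, 1]

# Any value absent from the mapping dict becomes NaN, and a later
# .astype(int) would then fail -- worth checking before converting.
odd = pd.Series(['female', 'unknown']).map({'female': 0, 'male': 1})
print(odd.isnull().tolist())  # [False, True]
```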

3. Splitting features and labels

# Separate the features from the labels
data = df.values

# the last seven columns are the features
features = data[:,1:] # plain indexing on an ndarray selects rows by default, while on a DataFrame it selects columns
# column 0 is the label
labels = data[:,0]

labels.shape

(1309,)
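The slicing above can be checked on a tiny array (made-up numbers, for illustration only):

```python
import numpy as np

# Toy 2-row array: column 0 plays the role of the label
data = np.array([[1.0, 10.0, 20.0],
                 [0.0, 30.0, 40.0]])

features = data[:, 1:]  # every row, columns 1 onward
labels = data[:, 0]     # every row, column 0

print(features.shape)  # (2, 2)
print(labels.shape)    # (2,)
```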

4. Defining a data-preparation function

def prepare_data(df):
    # drop the name column
    df = df.drop(['name'], axis=1)

    # replace missing age values with the column mean
    age_mean = df['age'].mean()
    df['age'] = df['age'].fillna(age_mean)

    # replace missing fare values with the column mean
    fare_mean = df['fare'].mean()
    df['fare'] = df['fare'].fillna(fare_mean)

    # fill missing embarked records with 'S'
    df['embarked'] = df['embarked'].fillna('S')

    # convert sex from strings to numeric codes
    df['sex'] = df['sex'].map({'female': 0, 'male': 1}).astype(int)

    # convert embarked from port letters to numeric codes
    df['embarked'] = df['embarked'].map({'C': 0, 'Q': 1, 'S': 2}).astype(int)
    print(df.isnull().any())

    # separate the features from the labels
    data = df.values
    # the last seven columns are the features
    features = data[:, 1:]
    # column 0 is the label
    labels = data[:, 0]

    return features, labels

5. Splitting into training and test sets

shuffle_df = df_selected.sample(frac=1) # shuffle the row order before training; frac=1 draws 100% of the rows, and df_selected itself is left unchanged

x_data,y_data = prepare_data(shuffle_df)
x_data.shape,y_data.shape

survived    False
pclass      False
sex         False
age         False
sibsp       False
parch       False
fare        False
embarked    False
dtype: bool





((1309, 7), (1309,))
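A quick sketch of what `sample(frac=1)` does, on a toy frame (`random_state` is added here only to make the example reproducible):

```python
import pandas as pd

df_demo = pd.DataFrame({'a': [1, 2, 3, 4]})
shuffled = df_demo.sample(frac=1, random_state=0)

# Same rows in a new order; the original frame is left untouched
print(sorted(shuffled['a'].tolist()))  # [1, 2, 3, 4]
print(df_demo['a'].tolist())           # [1, 2, 3, 4]
```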

shuffle_df.head()



|     | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
|-----|----------|------|--------|-----|-----|-------|-------|------|----------|
| 58  | 0 | Case, Mr. Howard Brown | 1 | male | 49.0 | 0 | 0 | 26.0000 | S |
| 666 | 0 | Barbara, Mrs. (Catherine David) | 3 | female | 45.0 | 0 | 1 | 14.4542 | C |
| 781 | 0 | Drazenoic, Mr. Jozef | 3 | male | 33.0 | 0 | 0 | 7.8958 | C |
| 480 | 0 | Laroche, Mr. Joseph Philippe Lemercier | 2 | male | 25.0 | 1 | 2 | 41.5792 | C |
| 459 | 0 | Jacobsohn, Mr. Sidney Samuel | 2 | male | 42.0 | 1 | 0 | 27.0000 | S |

test_split = 0.2
train_num = int((1 - test_split) * x_data.shape[0])
# training set
x_train = x_data[:train_num]
y_trian = y_data[:train_num]
# test set
x_test = x_data[train_num:]
y_test = y_data[train_num:]
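For the 1309-row frame, the arithmetic of this split works out as follows:

```python
test_split = 0.2
n = 1309  # rows in the Titanic frame

train_num = int((1 - test_split) * n)  # int() truncates the fractional part
print(train_num)      # 1047 training rows
print(n - train_num)  # 262 test rows
```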

6. Normalization

from sklearn import preprocessing

minmax_scale = preprocessing.MinMaxScaler(feature_range=(0,1))
x_train = minmax_scale.fit_transform(x_train) # fit the scaler on the training features and scale them
x_test = minmax_scale.transform(x_test) # reuse the ranges learned on the training set; refitting on the test set would scale it inconsistently
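A minimal sketch (made-up values) of why the scaler is fitted on the training set and only applied to the test set:

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

train = np.array([[0.0], [50.0], [100.0]])
test = np.array([[25.0], [120.0]])

scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(train)  # min=0 and max=100 come from the training data only

# Test values are mapped with the training ranges, so they can fall outside [0, 1]
print(scaler.transform(test).tolist())  # [[0.25], [1.2]]
```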

II. The Model

import tensorflow as tf
tf.__version__

'2.6.0'

1. Creating a sequential model

model = tf.keras.models.Sequential()

2. Adding hidden layers

model.add(tf.keras.layers.Dense(units=64,
                                use_bias=True,
                                activation='relu',
                                input_dim=7, # input_shape=(7,) also works
                                bias_initializer='zeros',
                                kernel_initializer='normal'))

model.add(tf.keras.layers.Dropout(rate=0.2)) # dropout layer; rate is the fraction of the previous layer's units dropped during training, which helps prevent overfitting

model.add(tf.keras.layers.Dense(units=32,
                                activation='sigmoid',
                                input_shape=(64,), # input_dim=64 also works, and can be omitted since Keras infers it from the previous layer
                                bias_initializer='zeros',
                                kernel_initializer='uniform'))

model.add(tf.keras.layers.Dropout(rate=0.2)) # dropout layer, same rate as above

3. Adding the output layer

model.add(tf.keras.layers.Dense(units=1,
                                activation='sigmoid',
                                input_dim=32, # input_shape=(32,) also works, and can be omitted
                                bias_initializer='zeros',
                                kernel_initializer='uniform'))

model.summary()

Model: "sequential_23"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_68 (Dense) (None, 64) 512
_________________________________________________________________
dropout_6 (Dropout) (None, 64) 0
_________________________________________________________________
dense_69 (Dense) (None, 32) 2080
_________________________________________________________________
dropout_7 (Dropout) (None, 32) 0
_________________________________________________________________
dense_70 (Dense) (None, 1) 33
=================================================================
Total params: 2,625
Trainable params: 2,625
Non-trainable params: 0
_________________________________________________________________
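The parameter counts in the summary can be verified by hand: a Dense layer holds `inputs × units` weights plus one bias per unit.

```python
# Dense layer parameters = inputs * units + units (one bias per unit)
layer1 = 7 * 64 + 64   # 512
layer2 = 64 * 32 + 32  # 2080
layer3 = 32 * 1 + 1    # 33

print(layer1, layer2, layer3)    # 512 2080 33
print(layer1 + layer2 + layer3)  # 2625, matching Total params
```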

III. Training

1. Training

# Define the training configuration
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.003),loss='binary_crossentropy',metrics=['accuracy'])

# Set the training parameters
train_epochs = 100
batch_size = 40

train_history = model.fit(x=x_train,            # training features
                          y=y_trian,            # training labels
                          validation_split=0.2, # fraction of the training data held out for validation
                          epochs=train_epochs,  # number of epochs
                          batch_size=batch_size,# batch size
                          verbose=2)            # log verbosity: one line per epoch

Epoch 1/100
21/21 - 1s - loss: 0.6780 - accuracy: 0.5854 - val_loss: 0.6464 - val_accuracy: 0.6429
Epoch 2/100
21/21 - 0s - loss: 0.6623 - accuracy: 0.6057 - val_loss: 0.6293 - val_accuracy: 0.6429
Epoch 3/100
21/21 - 0s - loss: 0.6306 - accuracy: 0.6069 - val_loss: 0.5861 - val_accuracy: 0.6667
Epoch 4/100
21/21 - 0s - loss: 0.5771 - accuracy: 0.7336 - val_loss: 0.5199 - val_accuracy: 0.7905
Epoch 5/100
21/21 - 0s - loss: 0.5364 - accuracy: 0.7646 - val_loss: 0.4939 - val_accuracy: 0.7952
Epoch 6/100
21/21 - 0s - loss: 0.5200 - accuracy: 0.7670 - val_loss: 0.4847 - val_accuracy: 0.8143
Epoch 7/100
21/21 - 0s - loss: 0.5118 - accuracy: 0.7718 - val_loss: 0.4771 - val_accuracy: 0.8143
Epoch 8/100
21/21 - 0s - loss: 0.5060 - accuracy: 0.7766 - val_loss: 0.4738 - val_accuracy: 0.8095
Epoch 9/100
21/21 - 0s - loss: 0.4934 - accuracy: 0.7861 - val_loss: 0.4670 - val_accuracy: 0.7952
Epoch 10/100
21/21 - 0s - loss: 0.4966 - accuracy: 0.7814 - val_loss: 0.4637 - val_accuracy: 0.8000
Epoch 11/100
21/21 - 0s - loss: 0.4928 - accuracy: 0.7766 - val_loss: 0.4635 - val_accuracy: 0.7905
Epoch 12/100
21/21 - 0s - loss: 0.4995 - accuracy: 0.7670 - val_loss: 0.4691 - val_accuracy: 0.7905
Epoch 13/100
21/21 - 0s - loss: 0.4886 - accuracy: 0.7957 - val_loss: 0.4620 - val_accuracy: 0.8095
Epoch 14/100
21/21 - 0s - loss: 0.4790 - accuracy: 0.7838 - val_loss: 0.4565 - val_accuracy: 0.8095
Epoch 15/100
21/21 - 0s - loss: 0.4877 - accuracy: 0.7766 - val_loss: 0.4576 - val_accuracy: 0.8095
Epoch 16/100
21/21 - 0s - loss: 0.4839 - accuracy: 0.7897 - val_loss: 0.4560 - val_accuracy: 0.8095
Epoch 17/100
21/21 - 0s - loss: 0.4813 - accuracy: 0.7814 - val_loss: 0.4614 - val_accuracy: 0.8095
Epoch 18/100
21/21 - 0s - loss: 0.4812 - accuracy: 0.7742 - val_loss: 0.4553 - val_accuracy: 0.8095
Epoch 19/100
21/21 - 0s - loss: 0.4762 - accuracy: 0.7885 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 20/100
21/21 - 0s - loss: 0.4784 - accuracy: 0.7802 - val_loss: 0.4567 - val_accuracy: 0.8000
Epoch 21/100
21/21 - 0s - loss: 0.4794 - accuracy: 0.7885 - val_loss: 0.4626 - val_accuracy: 0.8000
Epoch 22/100
21/21 - 0s - loss: 0.4824 - accuracy: 0.7838 - val_loss: 0.4567 - val_accuracy: 0.7857
Epoch 23/100
21/21 - 0s - loss: 0.4786 - accuracy: 0.7849 - val_loss: 0.4553 - val_accuracy: 0.8048
Epoch 24/100
21/21 - 0s - loss: 0.4801 - accuracy: 0.7742 - val_loss: 0.4735 - val_accuracy: 0.7905
Epoch 25/100
21/21 - 0s - loss: 0.4752 - accuracy: 0.7849 - val_loss: 0.4571 - val_accuracy: 0.7905
Epoch 26/100
21/21 - 0s - loss: 0.4688 - accuracy: 0.7909 - val_loss: 0.4597 - val_accuracy: 0.8000
Epoch 27/100
21/21 - 0s - loss: 0.4624 - accuracy: 0.7873 - val_loss: 0.4577 - val_accuracy: 0.8048
Epoch 28/100
21/21 - 0s - loss: 0.4656 - accuracy: 0.7993 - val_loss: 0.4602 - val_accuracy: 0.8000
Epoch 29/100
21/21 - 0s - loss: 0.4649 - accuracy: 0.7969 - val_loss: 0.4546 - val_accuracy: 0.8000
Epoch 30/100
21/21 - 0s - loss: 0.4645 - accuracy: 0.7849 - val_loss: 0.4638 - val_accuracy: 0.8000
Epoch 31/100
21/21 - 0s - loss: 0.4635 - accuracy: 0.7921 - val_loss: 0.4603 - val_accuracy: 0.7952
Epoch 32/100
21/21 - 0s - loss: 0.4646 - accuracy: 0.7909 - val_loss: 0.4567 - val_accuracy: 0.7952
Epoch 33/100
21/21 - 0s - loss: 0.4664 - accuracy: 0.7909 - val_loss: 0.4583 - val_accuracy: 0.7952
Epoch 34/100
21/21 - 0s - loss: 0.4661 - accuracy: 0.7921 - val_loss: 0.4575 - val_accuracy: 0.8000
Epoch 35/100
21/21 - 0s - loss: 0.4660 - accuracy: 0.7838 - val_loss: 0.4582 - val_accuracy: 0.7952
Epoch 36/100
21/21 - 0s - loss: 0.4577 - accuracy: 0.8005 - val_loss: 0.4567 - val_accuracy: 0.8000
Epoch 37/100
21/21 - 0s - loss: 0.4648 - accuracy: 0.7909 - val_loss: 0.4585 - val_accuracy: 0.7952
Epoch 38/100
21/21 - 0s - loss: 0.4613 - accuracy: 0.7921 - val_loss: 0.4569 - val_accuracy: 0.7952
Epoch 39/100
21/21 - 0s - loss: 0.4643 - accuracy: 0.7921 - val_loss: 0.4687 - val_accuracy: 0.8000
Epoch 40/100
21/21 - 0s - loss: 0.4696 - accuracy: 0.7814 - val_loss: 0.4601 - val_accuracy: 0.8048
Epoch 41/100
21/21 - 0s - loss: 0.4589 - accuracy: 0.7933 - val_loss: 0.4562 - val_accuracy: 0.7952
Epoch 42/100
21/21 - 0s - loss: 0.4587 - accuracy: 0.7885 - val_loss: 0.4594 - val_accuracy: 0.8000
Epoch 43/100
21/21 - 0s - loss: 0.4601 - accuracy: 0.7981 - val_loss: 0.4563 - val_accuracy: 0.7905
Epoch 44/100
21/21 - 0s - loss: 0.4639 - accuracy: 0.7897 - val_loss: 0.4594 - val_accuracy: 0.8048
Epoch 45/100
21/21 - 0s - loss: 0.4569 - accuracy: 0.7957 - val_loss: 0.4587 - val_accuracy: 0.8000
Epoch 46/100
21/21 - 0s - loss: 0.4619 - accuracy: 0.7957 - val_loss: 0.4556 - val_accuracy: 0.8048
Epoch 47/100
21/21 - 0s - loss: 0.4661 - accuracy: 0.7861 - val_loss: 0.4563 - val_accuracy: 0.8000
Epoch 48/100
21/21 - 0s - loss: 0.4550 - accuracy: 0.7969 - val_loss: 0.4538 - val_accuracy: 0.8000
Epoch 49/100
21/21 - 0s - loss: 0.4550 - accuracy: 0.7873 - val_loss: 0.4572 - val_accuracy: 0.8048
Epoch 50/100
21/21 - 0s - loss: 0.4603 - accuracy: 0.7909 - val_loss: 0.4584 - val_accuracy: 0.8000
Epoch 51/100
21/21 - 0s - loss: 0.4575 - accuracy: 0.7957 - val_loss: 0.4531 - val_accuracy: 0.8095
Epoch 52/100
21/21 - 0s - loss: 0.4568 - accuracy: 0.8029 - val_loss: 0.4584 - val_accuracy: 0.8048
Epoch 53/100
21/21 - 0s - loss: 0.4594 - accuracy: 0.7909 - val_loss: 0.4558 - val_accuracy: 0.8000
Epoch 54/100
21/21 - 0s - loss: 0.4588 - accuracy: 0.8065 - val_loss: 0.4523 - val_accuracy: 0.8000
Epoch 55/100
21/21 - 0s - loss: 0.4532 - accuracy: 0.8029 - val_loss: 0.4593 - val_accuracy: 0.8048
Epoch 56/100
21/21 - 0s - loss: 0.4578 - accuracy: 0.8100 - val_loss: 0.4614 - val_accuracy: 0.8048
Epoch 57/100
21/21 - 0s - loss: 0.4549 - accuracy: 0.8041 - val_loss: 0.4580 - val_accuracy: 0.8095
Epoch 58/100
21/21 - 0s - loss: 0.4568 - accuracy: 0.7909 - val_loss: 0.4597 - val_accuracy: 0.8095
Epoch 59/100
21/21 - 0s - loss: 0.4567 - accuracy: 0.7981 - val_loss: 0.4532 - val_accuracy: 0.8095
Epoch 60/100
21/21 - 0s - loss: 0.4532 - accuracy: 0.7993 - val_loss: 0.4569 - val_accuracy: 0.7952
Epoch 61/100
21/21 - 0s - loss: 0.4543 - accuracy: 0.7969 - val_loss: 0.4555 - val_accuracy: 0.8000
Epoch 62/100
21/21 - 0s - loss: 0.4472 - accuracy: 0.8053 - val_loss: 0.4543 - val_accuracy: 0.8048
Epoch 63/100
21/21 - 0s - loss: 0.4458 - accuracy: 0.8100 - val_loss: 0.4534 - val_accuracy: 0.8095
Epoch 64/100
21/21 - 0s - loss: 0.4497 - accuracy: 0.8005 - val_loss: 0.4593 - val_accuracy: 0.8000
Epoch 65/100
21/21 - 0s - loss: 0.4511 - accuracy: 0.8053 - val_loss: 0.4522 - val_accuracy: 0.8095
Epoch 66/100
21/21 - 0s - loss: 0.4506 - accuracy: 0.8005 - val_loss: 0.4592 - val_accuracy: 0.7952
Epoch 67/100
21/21 - 0s - loss: 0.4533 - accuracy: 0.8005 - val_loss: 0.4545 - val_accuracy: 0.8000
Epoch 68/100
21/21 - 0s - loss: 0.4481 - accuracy: 0.7909 - val_loss: 0.4545 - val_accuracy: 0.7952
Epoch 69/100
21/21 - 0s - loss: 0.4555 - accuracy: 0.7981 - val_loss: 0.4551 - val_accuracy: 0.8000
Epoch 70/100
21/21 - 0s - loss: 0.4440 - accuracy: 0.8029 - val_loss: 0.4552 - val_accuracy: 0.7952
Epoch 71/100
21/21 - 0s - loss: 0.4584 - accuracy: 0.8029 - val_loss: 0.4530 - val_accuracy: 0.7952
Epoch 72/100
21/21 - 0s - loss: 0.4480 - accuracy: 0.7933 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 73/100
21/21 - 0s - loss: 0.4554 - accuracy: 0.7981 - val_loss: 0.4536 - val_accuracy: 0.7952
Epoch 74/100
21/21 - 0s - loss: 0.4438 - accuracy: 0.8029 - val_loss: 0.4532 - val_accuracy: 0.7952
Epoch 75/100
21/21 - 0s - loss: 0.4483 - accuracy: 0.8053 - val_loss: 0.4515 - val_accuracy: 0.8095
Epoch 76/100
21/21 - 0s - loss: 0.4408 - accuracy: 0.8041 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 77/100
21/21 - 0s - loss: 0.4470 - accuracy: 0.8017 - val_loss: 0.4531 - val_accuracy: 0.8000
Epoch 78/100
21/21 - 0s - loss: 0.4484 - accuracy: 0.8053 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 79/100
21/21 - 0s - loss: 0.4456 - accuracy: 0.8053 - val_loss: 0.4526 - val_accuracy: 0.8048
Epoch 80/100
21/21 - 0s - loss: 0.4459 - accuracy: 0.8100 - val_loss: 0.4573 - val_accuracy: 0.7952
Epoch 81/100
21/21 - 0s - loss: 0.4496 - accuracy: 0.7981 - val_loss: 0.4573 - val_accuracy: 0.8095
Epoch 82/100
21/21 - 0s - loss: 0.4515 - accuracy: 0.8053 - val_loss: 0.4502 - val_accuracy: 0.8095
Epoch 83/100
21/21 - 0s - loss: 0.4503 - accuracy: 0.8100 - val_loss: 0.4546 - val_accuracy: 0.7952
Epoch 84/100
21/21 - 0s - loss: 0.4386 - accuracy: 0.8065 - val_loss: 0.4540 - val_accuracy: 0.8048
Epoch 85/100
21/21 - 0s - loss: 0.4371 - accuracy: 0.8088 - val_loss: 0.4552 - val_accuracy: 0.8095
Epoch 86/100
21/21 - 0s - loss: 0.4420 - accuracy: 0.8053 - val_loss: 0.4553 - val_accuracy: 0.8048
Epoch 87/100
21/21 - 0s - loss: 0.4437 - accuracy: 0.8112 - val_loss: 0.4550 - val_accuracy: 0.7952
Epoch 88/100
21/21 - 0s - loss: 0.4432 - accuracy: 0.7969 - val_loss: 0.4565 - val_accuracy: 0.8095
Epoch 89/100
21/21 - 0s - loss: 0.4396 - accuracy: 0.8065 - val_loss: 0.4552 - val_accuracy: 0.8000
Epoch 90/100
21/21 - 0s - loss: 0.4477 - accuracy: 0.8088 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 91/100
21/21 - 0s - loss: 0.4412 - accuracy: 0.8017 - val_loss: 0.4507 - val_accuracy: 0.8048
Epoch 92/100
21/21 - 0s - loss: 0.4484 - accuracy: 0.7957 - val_loss: 0.4531 - val_accuracy: 0.8048
Epoch 93/100
21/21 - 0s - loss: 0.4433 - accuracy: 0.8017 - val_loss: 0.4519 - val_accuracy: 0.8048
Epoch 94/100
21/21 - 0s - loss: 0.4415 - accuracy: 0.7957 - val_loss: 0.4524 - val_accuracy: 0.8095
Epoch 95/100
21/21 - 0s - loss: 0.4399 - accuracy: 0.8065 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 96/100
21/21 - 0s - loss: 0.4387 - accuracy: 0.8065 - val_loss: 0.4546 - val_accuracy: 0.8095
Epoch 97/100
21/21 - 0s - loss: 0.4463 - accuracy: 0.7945 - val_loss: 0.4542 - val_accuracy: 0.8048
Epoch 98/100
21/21 - 0s - loss: 0.4447 - accuracy: 0.7993 - val_loss: 0.4542 - val_accuracy: 0.8143
Epoch 99/100
21/21 - 0s - loss: 0.4368 - accuracy: 0.8041 - val_loss: 0.4551 - val_accuracy: 0.8048
Epoch 100/100
21/21 - 0s - loss: 0.4395 - accuracy: 0.8053 - val_loss: 0.4501 - val_accuracy: 0.8095

2. Visualizing the training history

# Visualize the training history
import matplotlib.pyplot as plt

def show_train_history(history, train_metric, validation_metric):
    plt.plot(history[train_metric])
    plt.plot(history[validation_metric])
    plt.title('Train History')
    plt.ylabel(train_metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()

show_train_history(train_history.history,'loss','val_loss')

(Figure: training vs. validation loss over epochs)

show_train_history(train_history.history,'accuracy','val_accuracy')

(Figure: training vs. validation accuracy over epochs)

3. Evaluating the model

loss,acc = model.evaluate(x_test,y_test)

9/9 [==============================] - 0s 2ms/step - loss: 0.3703 - accuracy: 0.8435

loss,acc

(0.3702643811702728, 0.8435114622116089)
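The reported accuracy comes from thresholding the sigmoid output at 0.5; the same decision rule can be applied by hand (the probabilities below are illustrative, not actual model output):

```python
import numpy as np

probs = np.array([0.06, 0.97, 0.40, 0.80])  # hypothetical sigmoid outputs
y_true = np.array([0, 1, 0, 1])

y_pred = (probs >= 0.5).astype(int)  # 0.5 threshold for the binary decision
accuracy = float((y_pred == y_true).mean())

print(y_pred.tolist())  # [0, 1, 0, 1]
print(accuracy)         # 1.0
```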

IV. Prediction

Jack_info = [0,'Jack',3,'male',23,1,0,5.0000,'S']
Rose_info = [1,'Rose',1,'female',20,1,0,100.0000,'S']

x_pre = pd.DataFrame([Jack_info,Rose_info],columns=selected_cols)
x_pre



|   | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
|---|----------|------|--------|-----|-----|-------|-------|------|----------|
| 0 | 0 | Jack | 3 | male | 23 | 1 | 0 | 5.0 | S |
| 1 | 1 | Rose | 1 | female | 20 | 1 | 0 | 100.0 | S |

x_pre_features,y = prepare_data(x_pre)
from sklearn import preprocessing

minmax_scale = preprocessing.MinMaxScaler(feature_range=(0,1))
x_pre_features = minmax_scale.fit_transform(x_pre_features) # scale the features; note this fits a new scaler on just these two rows rather than reusing the training-set ranges
y_pre = model.predict(x_pre_features)

survived    False
pclass      False
sex         False
age         False
sibsp       False
parch       False
fare        False
embarked    False
dtype: bool
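One caveat about the cell above: fitting a fresh MinMaxScaler on only two rows maps each column's smaller value to 0 and its larger value to 1, which can differ sharply from the ranges learned on the 1309-row training set. A small sketch with made-up values:

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Two rows, two columns (ages and fares, made-up values)
two_rows = np.array([[23.0, 5.0],
                     [20.0, 100.0]])

scaled = MinMaxScaler(feature_range=(0, 1)).fit_transform(two_rows)
print(scaled.tolist())  # [[1.0, 0.0], [0.0, 1.0]]
```

Calling `transform` on the scaler already fitted to `x_train` would keep the prediction inputs on the same scale as the training data.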

x_pre.insert(len(x_pre.columns),'surv_probabilty',y_pre)

x_pre



|   | survived | name | pclass | sex | age | sibsp | parch | fare | embarked | surv_probabilty |
|---|----------|------|--------|-----|-----|-------|-------|------|----------|-----------------|
| 0 | 0 | Jack | 3 | male | 23 | 1 | 0 | 5.0 | S | 0.058498 |
| 1 | 1 | Rose | 1 | female | 20 | 1 | 0 | 100.0 | S | 0.975978 |

