0
点赞
收藏
分享

微信扫一扫

AI大模型日报#0409:Llama 3下周发布、特斯联20亿融资、Karpathy新项目

天涯学馆 15小时前 阅读 1

tensorflow使用篇

1. 使用teachable machine训练模型

地址: 传送门, 需要梯子翻一下

训练后, 导出的时候可以选择三种类型
在这里插入图片描述

导出模型文件 converted_keras.zip (py版)
解压后得到
在这里插入图片描述

2. py项目中使用模型

根据你当时使用teachable machine的时间, 选择py项目中TensorFlow的版本

如果版本不匹配会报错如下

解决的方法就是升级TensorFlow版本

目录结构如下
在这里插入图片描述

# -*- coding: utf-8 -*-
import flask as fk
from flask import jsonify, request
import tensorflow as tf
from PIL import Image
import numpy as np

app = fk.Flask(__name__)

# 加载标签映射
class_label_map = {}
with open('labels.txt', 'r', encoding='utf-8') as f:
    for line in f.readlines():
        index, label = line.strip().split()
        class_label_map[int(index)] = label

print(class_label_map)

# 加载模型
global model
model = tf.keras.models.load_model('keras_model.h5')
print('模型加载成功')

# 图片预处理方法
def preprocess_image(image_path):
    img = Image.open(image_path)
    # 调整大小、归一化等操作,具体取决于模型要求
    img_resized = img.resize((224, 224))
    img_array = np.array(img_resized) / 255.0  # 将像素值归一化到[0, 1]区间
    img_array = np.expand_dims(img_array, axis=0)  # 添加批量维度(batch size = 1)
    return img_array

# 预测方法
def load_model():

    # 准备输入数据
    input_data = preprocess_image("danka.jpg")
    # input_data = preprocess_image("duolianka.jpg")
    # 预测
    predictions = model.predict(input_data)
    # 获取预测结果
    predicted_class_index = np.argmax(predictions[0])
    # 获取预测的类名
    predicted_class_name = class_label_map[predicted_class_index]
    print(f"Predicted class: {predicted_class_name}")
    return predicted_class_name


# 测试预测
@app.route('/api/hello', methods=['GET'])
def get_data():
    return load_model()


# 假设我们要提供一个获取用户信息的API
@app.route('/api/user/<int:user_id>', methods=['GET'])
def get_user_info(user_id):
    # 这里模拟从数据库或其他服务获取用户信息
    user_data = {'id': user_id, 'name': 'John Doe', 'email': 'john.doe@example.com'}

    # 假设用户不存在,返回404
    # 返回JSON格式的用户信息
    return jsonify(user_data)


# 定义一个接收POST请求的路由,假设该接口用于创建新用户
@app.route('/api/users', methods=['POST'])
def create_user():
    # 从请求体中获取JSON格式的数据
    data = request.get_json()

    # 检查必要的字段是否存在
    if not all(key in data for key in ('username', 'email', 'password')):
        return jsonify({"error": "Missing required fields"}), 400

    # 这里仅做示例,实际开发中应将数据保存至数据库等
    new_user = {
        'username': data['username'],
        'email': data['email'],
        'password': data['password']
    }

    # 模拟用户创建成功
    resultMap = {"message": "User created successfully", "user": new_user}

    # 返回201状态码表示已创建资源
    return jsonify(resultMap), 201


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)
# -*- coding: utf-8 -*-
import flask as fk
from flask import jsonify, request
import tensorflow as tf
from PIL import Image
import numpy as np

app = fk.Flask(__name__)



# 加载标签映射
global class_label_map
class_label_map = {}
with open('labels.txt', 'r', encoding='utf-8') as f:
    for line in f.readlines():
        index, label = line.strip().split()
        class_label_map[int(index)] = label

print(class_label_map)

# 加载模型
global model
model = tf.keras.models.load_model('keras_model.h5')
print('模型加载成功')

# 本地图片预处理方法
def preprocess_image(image_path):
    img = Image.open(image_path)
    # 调整大小、归一化等操作,具体取决于模型要求
    img_resized = img.resize((224, 224))
    img_array = np.array(img_resized) / 255.0  # 将像素值归一化到[0, 1]区间
    img_array = np.expand_dims(img_array, axis=0)  # 添加批量维度(batch size = 1)
    return img_array


# 预测方法
def load_model(input_data):
    # 预测
    predictions = model.predict(input_data)
    # 获取预测结果
    predicted_class_index = np.argmax(predictions[0])
    # 获取预测的类名
    predicted_class_name = class_label_map[predicted_class_index]
    print(f"Predicted class: {predicted_class_name}")
    return predicted_class_name


# 测试预测
@app.route('/api/hello', methods=['GET'])
def get_data():
    # 准备输入数据
    input_data = preprocess_image("danka.jpg")
    # input_data = preprocess_image("duolianka.jpg")
    return load_model(input_data)

# 定义一个接收POST请求的路由,假设该接口用于图片预测
@app.route('/api/forecast', methods=['POST'])
def forecast():
    # 从请求体中获取图片数据
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    image = request.files['image']
    input_data = preprocess_image(image)
    result = load_model(input_data)
    resultMap = {"message": result, "code": 200}
    return jsonify(resultMap), 200


# 假设我们要提供一个获取用户信息的API
@app.route('/api/user/<int:user_id>', methods=['GET'])
def get_user_info(user_id):
    # 这里模拟从数据库或其他服务获取用户信息
    user_data = {'id': user_id, 'name': 'John Doe', 'email': 'john.doe@example.com'}

    # 假设用户不存在,返回404
    # 返回JSON格式的用户信息
    return jsonify(user_data)


# 定义一个接收POST请求的路由,假设该接口用于创建新用户
@app.route('/api/users', methods=['POST'])
def create_user():
    # 从请求体中获取JSON格式的数据
    data = request.get_json()

    # 检查必要的字段是否存在
    if not all(key in data for key in ('username', 'email', 'password')):
        return jsonify({"error": "Missing required fields"}), 400

    # 这里仅做示例,实际开发中应将数据保存至数据库等
    new_user = {
        'username': data['username'],
        'email': data['email'],
        'password': data['password']
    }

    # 模拟用户创建成功
    resultMap = {"message": "User created successfully", "user": new_user}

    # 返回201状态码表示已创建资源
    return jsonify(resultMap), 201


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)

举报

相关推荐

0 条评论