0
点赞
收藏
分享

微信扫一扫

一个通用的utils脚本

杏花疏影1 2022-03-22 阅读 77
  • utils.py
    • json保存&加载
    • 大文件序列化数据保存&加载
    • 创建目录
    • log初始化
    • 随机种子初始化
    • 耗时统计

import pandas as pd
import json
import numpy as np
import joblib
import os
from contextlib import contextmanager
import time
import logging
import random

def load_json(f):
    """Load and return the JSON document stored at path ``f``.

    The file is always decoded as UTF-8 (the encoding JSON interchange
    uses) instead of the platform's locale encoding, so files containing
    non-ASCII text load identically on every operating system.

    Args:
        f: Path of the JSON file to read.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(f, 'r', encoding='utf-8') as fr:
        return json.load(fr)

def save_json(d, f):
    """Serialize ``d`` as pretty-printed JSON into the file at path ``f``.

    Non-ASCII characters are written as-is (``ensure_ascii=False``), so the
    file is explicitly encoded as UTF-8; relying on the locale encoding
    could raise ``UnicodeEncodeError`` or produce a file that other tools
    cannot decode.

    Args:
        d: JSON-serializable object to store.
        f: Destination path; an existing file is overwritten.

    Raises:
        OSError: If the file cannot be opened for writing.
        TypeError: If ``d`` contains objects that are not JSON-serializable.
    """
    with open(f, 'w', encoding='utf-8') as fw:
        json.dump(d, fw, ensure_ascii=False, indent=4)

def load_joblib(f):
    """Load an object that was serialized with joblib.

    Args:
        f: Location of the serialized data, passed unchanged to
            ``joblib.load``.

    Returns:
        Whatever object ``joblib.load`` reconstructs.
    """
    # NOTE: joblib persistence is pickle-based; only load files you trust.
    restored = joblib.load(f)
    return restored

def save_joblib(d, f):
    """Serialize an object to disk with joblib.

    Args:
        d: Object to persist.
        f: Destination handed to ``joblib.dump``.

    Returns:
        None. The value returned by ``joblib.dump`` is discarded.
    """
    joblib.dump(value=d, filename=f)

def mkdir(d):
    """Create directory ``d`` if nothing exists at that path yet.

    Missing parent directories are created as well, and a directory that
    appears between the existence check and the creation (for example,
    created by another process) is tolerated instead of raising.

    Args:
        d: Path of the directory to create.

    Raises:
        OSError: If the directory cannot be created (e.g. permissions).
    """
    # Keep the original contract: any existing path (directory or file)
    # makes this a silent no-op.
    if not os.path.exists(d):
        # exist_ok=True closes the race where the directory is created by
        # someone else right after the check above.
        os.makedirs(d, exist_ok=True)

def creat_logger(log_file='log.txt'):
    """Configure the root logger to write INFO+ records to a file and console.

    The file side is set up through ``logging.basicConfig`` (file opened in
    ``'w'`` mode, i.e. truncated); note that ``basicConfig`` does nothing
    when the root logger already has handlers. A console ``StreamHandler``
    using the same format is then attached to the root logger.

    The function is idempotent: the console handler is tagged with a name,
    and a repeated call does not attach a second one, so messages are not
    printed multiple times.

    Args:
        log_file: Path of the log file. Defaults to ``'log.txt'``.
    """
    console_name = 'utils_console_handler'
    msg_format = '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    root = logging.getLogger()

    logging.basicConfig(level=logging.INFO, filename=str(log_file), format=msg_format, filemode='w')

    # A previous call already attached our console handler: adding another
    # one would duplicate every console line.
    if any(h.get_name() == console_name for h in root.handlers):
        return

    handler = logging.StreamHandler()
    handler.set_name(console_name)
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter(msg_format))
    root.addHandler(handler)

def seed_everything(seed=42):
    """Seed Python's ``random`` module and NumPy's global RNG.

    The results this script was written for were produced with a pinned
    environment::

        pip install tensorflow-gpu
        pip install pandas==1.0.3
        pip install xgboost
        pip install catboost
        pip install numpy==1.17.2
        pip install lightgbm==2.3.1

    If pandas, lightgbm or numpy are already imported with a different
    version, a ``UserWarning`` is issued (exact reproducibility may not
    hold) but seeding still happens. The previous implementation asserted
    on these versions and referenced an undefined ``lgb`` name, so it
    raised before seeding anything.

    Args:
        seed: Integer seed applied to both generators. Defaults to 42.
    """
    # Local imports: this helper is the only user of these modules.
    import sys
    import warnings

    pinned_versions = {
        'pandas': '1.0.3',
        'lightgbm': '2.3.1',
        'numpy': '1.17.2',
    }
    for module_name, expected in pinned_versions.items():
        # Only inspect modules that are already loaded; never import a
        # heavy optional dependency just to read its version.
        module = sys.modules.get(module_name)
        if module is None:
            continue
        actual = getattr(module, '__version__', None)
        if actual != expected:
            warnings.warn(
                f'{module_name} version is {actual}, expected {expected}; '
                'results may not be exactly reproducible',
                stacklevel=2,
            )

    random.seed(seed)
    np.random.seed(seed)

@contextmanager
def timer(name):
    """Context manager that logs how long the wrapped block took.

    On normal exit an INFO record ``[<name>] done in <seconds> s`` is sent
    through the root logger. Nothing is logged if the block raises.

    The duration is measured with ``time.perf_counter()``, a monotonic
    high-resolution clock, so it cannot be distorted (or go negative) when
    the system wall clock is adjusted while the block runs.

    Args:
        name: Label identifying the timed block in the log message.
    """
    started = time.perf_counter()
    yield
    elapsed = time.perf_counter() - started
    # Lazy %-style arguments: the message is only built if INFO is enabled.
    logging.info('[%s] done in %s s', name, elapsed)



if __name__ == "__main__":
    # Smoke test: set up logging, then time a one-second sleep so the
    # timer's log line can be checked in the log file and on the console.
    creat_logger(log_file="utils_test.log")
    with timer(name="测试时间脚本"):
        time.sleep(1)




举报

相关推荐

0 条评论