0
点赞
收藏
分享

微信扫一扫

一个简单的例子测试numpy和Jax的性能对比 (续)

老榆 2024-01-16 阅读 9





numpy代码:

import numpy as np
import time

x = np.random.random([10000, 10000]).astype(np.float32)
try:
    st = time.time()
    y = np.matmul(x, x)
    print(time.time() - st)
    print(y)
except Exception as e:
    print(f"error: {e}")


一个简单的例子测试numpy和Jax的性能对比 (续)_性能对比



Jax代码:

import jax.numpy as np
from jax import random
import time

x = random.uniform(random.PRNGKey(0), [10000, 10000])
st = time.time()
try:
    y = np.matmul(x, x)
    print(time.time() - st)
    print(y)
except Exception as e:
    print(f"error: {e}")


一个简单的例子测试numpy和Jax的性能对比 (续)_性能对比_02



可以说,在这个例子里面,Jax和numpy的性能基本持平。



举报

相关推荐

0 条评论