PyTorch torch.set_num_threads 函数
torch.set_num_threads 是 PyTorch 中用于设置内部线程数的函数。这些线程用于并行执行张量操作(如矩阵乘法、卷积等)。
函数定义
torch.set_num_threads(num_threads)
参数说明
num_threads: 要设置的线程数
使用示例
实例
import torch
# 设置线程数
torch.set_num_threads(4)
print("线程数已设置为:", torch.get_num_threads())
# 创建大矩阵进行测试
a = torch.randn(1000, 1000)
b = torch.randn(1000, 1000)
# 矩阵乘法(会使用设置的线程数)
c = torch.matmul(a, b)
print("矩阵乘法完成,结果形状:", c.shape)
# 恢复默认值
torch.set_num_threads(1)
# 设置线程数
torch.set_num_threads(4)
print("线程数已设置为:", torch.get_num_threads())
# 创建大矩阵进行测试
a = torch.randn(1000, 1000)
b = torch.randn(1000, 1000)
# 矩阵乘法(会使用设置的线程数)
c = torch.matmul(a, b)
print("矩阵乘法完成,结果形状:", c.shape)
# 恢复默认值
torch.set_num_threads(1)
输出结果为:
线程数已设置为: 4 矩阵乘法完成,结果形状: torch.Size([1000, 1000])

Pytorch torch 参考手册