intel / intel-extension-for-pytorch

A Python package for extending the official PyTorch that makes it easy to obtain better performance on Intel platforms

torch.random.fork_rng does not restore RNG state

garrett361 opened this issue

Describe the bug

`torch.random.fork_rng` does not appear to restore the previous RNG state.

Testing code:

```python
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA", flush=True)
else:
    # Importing IPEX registers the XPU device with PyTorch.
    import intel_extension_for_pytorch as ipex  # noqa

    device = torch.device("xpu")
    print("Using XPU", flush=True)

if __name__ == "__main__":
    results = []
    for _ in range(10):
        # The RNG state should be snapshotted on entry and restored on exit,
        # so each iteration should draw the same value.
        with torch.random.fork_rng(devices=(device,)):
            results.append(torch.randn(1, device=device))

    # Every generated tensor should be the same.
    print(f"Results: {results}", flush=True)
    for r in results[1:]:
        torch.testing.assert_close(r, results[0])
```

I expect each `fork_rng` context manager to fork the RNG state on entry and restore it on exit. If it does, every item appended to `results` will be the same. This works as expected on CUDA:

```
# On CUDA
Results: [tensor([1.8193], device='cuda:0'), tensor([1.8193], device='cuda:0'), tensor([1.8193], device='cuda:0'), tensor([1.8193], device='cuda:0'), tensor([1.8193], device='cuda:0'), tensor([1.8193], device='cuda:0'), tensor([1.8193], device='cuda:0'), tensor([1.8193], device='cuda:0'), tensor([1.8193], device='cuda:0'), tensor([1.8193], device='cuda:0')]

# Assertion passes
```

However, on XPU the assertion fails, and the output is:

```
# On XPU
Results: [tensor([0.7696], device='xpu:0'), tensor([-1.9231], device='xpu:0'), tensor([0.6481], device='xpu:0'), tensor([-1.0920], device='xpu:0'), tensor([-1.7259], device='xpu:0'), tensor([0.1721], device='xpu:0'), tensor([0.7106], device='xpu:0'), tensor([-1.0049], device='xpu:0'), tensor([0.4109], device='xpu:0'), tensor([0.5113], device='xpu:0')]

# Assertion fails
```

I also tried passing a `devices=(torch.device("xpu"),)` argument to `fork_rng`, but this raised an error:

```
AssertionError: Torch not compiled with CUDA enabled
```

Versions

Results from collect_env.py:

```
/usr/local/lib/python3.10/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: ''If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
My guessed rank = 0
Collecting environment information...
PyTorch version: 2.1.0a0+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.1.10+xpu
IPEX commit: a12f9f650
Build type: Release

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: N/A
IGC version: N/A
CMake version: version 3.20.4
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.14.21-150500.55.31_13.0.62-cray_shasta_c-x86_64-with-glibc2.35
Is XPU available: True
DPCPP runtime version: N/A
MKL version: N/A
GPU models and configuration:
[0] _DeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=1, total_memory=131072MB, max_compute_units=896, gpu_eu_count=896)
[1] _DeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=1, total_memory=131072MB, max_compute_units=896, gpu_eu_count=896)
[2] _DeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=1, total_memory=131072MB, max_compute_units=896, gpu_eu_count=896)
[3] _DeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=1, total_memory=131072MB, max_compute_units=896, gpu_eu_count=896)
[4] _DeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=1, total_memory=131072MB, max_compute_units=896, gpu_eu_count=896)
[5] _DeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=1, total_memory=131072MB, max_compute_units=896, gpu_eu_count=896)
Intel OpenCL ICD version: 23.30.26918.50-736~22.04
Level Zero version: 1.3.26918.50-736~22.04

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      52 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             208
On-line CPU(s) list:                0-207
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8465C CPU @2.10GHz
CPU family:                         6
Model:                              143
Thread(s) per core:                 2
Core(s) per socket:                 52
Socket(s):                          2
Stepping:                           5
Frequency boost:                    enabled
CPU max MHz:                        2101.0000
CPU min MHz:                        800.0000
BogoMIPS:                           4200.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk arch_lbr avx512_fp16 amx_tile flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          4.9 MiB (104 instances)
L1i cache:                          3.3 MiB (104 instances)
L2 cache:                           208 MiB (104 instances)
L3 cache:                           210 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-51,104-155
NUMA node1 CPU(s):                  52-103,156-207
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.10+xpu
[pip3] mypy==1.5.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==2.1.0a0+cxx11.abi
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.1.0a0+cxx11.abi
[pip3] torchvision==0.16.0a0+cxx11.abi
[conda] N/A
```

@garrett361 Thanks for reporting the issue. We will try to reproduce it and come back with more details.

@tye1 fyi
@garrett361 It appears that `torch.random.fork_rng` uses `device_type='cuda'` by default (even though the device is set to `xpu`). I found that the RNG works as expected when the context manager is initialized as below:

```python
torch.random.fork_rng(devices=(device,), device_type='xpu')
```

Output:

```
Results: [tensor([1.9484], device='xpu:0'), tensor([1.9484], device='xpu:0'), tensor([1.9484], device='xpu:0'), tensor([1.9484], device='xpu:0'), tensor([1.9484], device='xpu:0'), tensor([1.9484], device='xpu:0'), tensor([1.9484], device='xpu:0'), tensor([1.9484], device='xpu:0'), tensor([1.9484], device='xpu:0'), tensor([1.9484], device='xpu:0')]
```
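Why the default `device_type` matters can be illustrated with a toy model in plain Python (no PyTorch involved). It assumes a machine with zero CUDA devices and one XPU device, and mimics the `devices=None` path as we understand it: `fork_rng` always saves the CPU generator, but only saves accelerator generators for devices of `device_type`. With the default `"cuda"`, there are no devices to save, so the XPU generator keeps advancing. `GENERATORS`, `DEVICE_COUNTS`, `toy_fork_rng`, and `draw_xpu` are hypothetical names for this sketch only; this is not the actual `fork_rng` implementation.

```python
import contextlib
import random

# Toy model, NOT PyTorch: one generator per backend on a hypothetical XPU-only machine.
GENERATORS = {"cpu": random.Random(0), "cuda": random.Random(1), "xpu": random.Random(2)}
DEVICE_COUNTS = {"cuda": 0, "xpu": 1}


@contextlib.contextmanager
def toy_fork_rng(device_type="cuda"):
    """Save the CPU state plus the state of the backend named by device_type."""
    saved = {"cpu": GENERATORS["cpu"].getstate()}
    # Only devices of `device_type` are considered. With device_type="cuda"
    # there are zero such devices here, so the XPU generator is never saved.
    if DEVICE_COUNTS[device_type] > 0:
        saved[device_type] = GENERATORS[device_type].getstate()
    try:
        yield
    finally:
        # Restore whatever was saved on entry.
        for name, state in saved.items():
            GENERATORS[name].setstate(state)


def draw_xpu(device_type):
    """Draw one 'XPU' sample inside a forked RNG context."""
    with toy_fork_rng(device_type=device_type):
        return GENERATORS["xpu"].random()


print([draw_xpu("cuda") for _ in range(3)])  # three different values
print([draw_xpu("xpu") for _ in range(3)])   # the same value three times
```

The first list mirrors the failing XPU output in the report; the second mirrors the fixed output once `device_type='xpu'` is passed.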

Ah, thank you! That makes total sense; I didn’t realize there was a `device_type` arg. I probably should have known. Appreciate it!
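To keep a script device-agnostic, one option is a small wrapper that derives `device_type` from the device, so nobody has to remember that `fork_rng` defaults to `"cuda"`. This is a sketch, not part of PyTorch or IPEX: `device_fork_rng` is a hypothetical helper. It assumes a PyTorch version whose `fork_rng` accepts `device_type` (2.1 does, as the fix above shows) and, on XPU, that `intel_extension_for_pytorch` has already been imported so that `torch.xpu` exists.

```python
import torch


def device_fork_rng(device):
    """Hypothetical helper: fork_rng that also saves/restores `device`'s RNG.

    torch.random.fork_rng defaults to device_type="cuda", so the device type is
    taken from the device itself instead.
    """
    device = torch.device(device)
    if device.type == "cpu":
        # The CPU generator is always saved and restored; no accelerator to fork.
        return torch.random.fork_rng(devices=[])
    return torch.random.fork_rng(devices=(device,), device_type=device.type)


device = torch.device("cpu")  # e.g. torch.device("xpu") after importing IPEX
results = []
for _ in range(3):
    with device_fork_rng(device):
        results.append(torch.randn(1, device=device))

# Every draw starts from the same restored state, so all results match.
for r in results[1:]:
    torch.testing.assert_close(r, results[0])
```

In the original repro, replacing `torch.random.fork_rng(devices=(device,))` with `device_fork_rng(device)` would apply the same `device_type='xpu'` fix on XPU while leaving the CUDA path unchanged.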