Question about ds to universal
saxh opened this issue · comments
saxh commented
def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
slices = []
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
paths = sorted(list(glob.glob(f"{prefix_path}.0*")))
#print(paths)
shards = [torch.load(p) for p in paths]
slice = torch.cat(shards, dim=0).reshape(slice_shape)
slices.append(slice)
return slices
hello,what does f"{prefix_path}.0*" mean? what if the path startswith f"{prefix_path}.1*"?