"""Compare two ways of min-max scaling random feature tensors.

The script draws ``n_features`` random vectors of length ``n_events``,
scales them once inside the loop (one vector at a time) and once after
stacking (reducing over dim 1 of the stacked tensor), and prints every
intermediate result so the two can be inspected side by side.
"""
# NOTE: uproot, tqdm and GatorConfig are referenced only by the
# commented-out exploratory code at the bottom of this file.
import uproot
import torch
from tqdm import tqdm

from utils import GatorConfig

n_events = 5
n_features = 3
print(f"n events: {n_events}")
print(f"n features: {n_features}")

features1 = []  # raw vectors, scaled later on the stacked tensor
features2 = []  # vectors scaled individually inside the loop
for _ in range(n_features):
    # column = torch.tensor(list(range(n_features)), dtype=torch.float)
    column = torch.rand(n_events) * 100
    features1.append(column)
    # Shift to start at zero, then divide by the maximum of the shifted values.
    shifted = column - column.min()
    features2.append(shifted / shifted.max())

# Stacking along dim=1 yields shape (n_events, n_features): one row per
# event, one column per feature.
features1 = torch.stack(features1, dim=1)
print(f"features1 before:\n{features1}")

# NOTE(review): the reductions below run over dim 1, i.e. across the
# features of each event (row), whereas features2 was scaled per feature
# across events.  The two printed results therefore generally differ --
# confirm which axis is intended.
row_min = features1.min(dim=1, keepdim=True).values
print(f"features1 min:\n{row_min}")
features1 -= row_min
print(f"features1 - min:\n{features1}")

row_max = features1.max(dim=1, keepdim=True).values
print(f"features1 max:\n{row_max}")
features1 /= row_max
print(f"features1 / max:\n{features1}")

features2 = torch.stack(features2, dim=1)
print(f"features2:\n{features2}")

# --- Disabled exploratory code (kept for reference) -------------------------
# features = ["t5_pt", "t5_eta", "t5_phi"]
# tree = uproot.open(f"/blue/p.chang/jguiang/data/lst/GATOR/CMSSW_12_2_0_pre2/LSTGnnNtuple_noT5Chi2_noPixelsInTCs.root:tree")
# for batch in tree.iterate(step_size=1, filter_name="/(t3|t5)_*/", entry_start=0, entry_stop=2):
#     batch = batch[0,:] # only one event per batch
#     attr = torch.tensor([batch[f].to_list() for f in features], dtype=torch.float)
#     print(attr)
#     attr -= attr.min(1, keepdim=True)[0]
#     attr /= attr.max(1, keepdim=True)[0]
#     print(attr)
#     break

# config = GatorConfig.from_json("configs/T5_withChi2_DNN.json")
# graph_loader = torch.load(f"{config.basedir}/{config.name}_val.pt")
# if config.model.name == "DNN":
#     from datasets import EdgeDataset, EdgeDataBatch
#     from torch.utils.data import DataLoader
#     edge_loader = DataLoader(EdgeDataset(graph_loader), collate_fn=lambda batch: EdgeDataBatch(batch))
#
# edges1 = []
# n_edges = 0
# for data in tqdm(graph_loader):
#     # print(data.edge_index.transpose(0, 1))
#     # break
#     n_edges += data.num_edges
#
# edges2 = []
# _n_edges = 0
# for data in tqdm(edge_loader):
#     if len(data.edge_index) > 1:
#         print("hullo")
#         print(data.edge_index)
#     _n_edges += 1
# print(n_edges, _n_edges)