JsMVA
DataLoader.py
Go to the documentation of this file.
1 # -*- coding: utf-8 -*-
2 ## @package JsMVA.DataLoader
3 # DataLoader module with the functions to be inserted to TMVA::DataLoader class and helper functions
4 # @authors Attila Bagoly <battila93@gmail.com>
5 
6 
7 from ROOT import TH1F, TMVA
8 import JPyInterface
9 
10 
## Creates the input variable histogram and performs the transformations if necessary
# @param dl DataLoader object
# @param className string Signal/Background
# @param variableName string containing the variable name (matched against the variable label)
# @param numBin for creating the histogram
# @param processTrfs string containing the list of transformations to be used on input variable; eg. "I;N;D;P;U;G,D"
# @return TH1F filled with the (possibly transformed) values of the variable for the events of
#         class className, or 0 if no variable with label variableName exists
def GetInputVariableHist(dl, className, variableName, numBin, processTrfs=""):
    dsinfo = dl.GetDefaultDataSetInfo()

    # Locate the variable by its label. None (rather than 0) marks "not found", so we never
    # depend on how PyROOT compares an object proxy against an integer.
    vi = None
    ivar = 0
    for i in range(dsinfo.GetNVariables()):
        info = dsinfo.GetVariableInfo(i)
        if info.GetLabel() == variableName:
            vi = info
            ivar = i
            break
    if vi is None:
        # 0 is kept as the "not found" result for backward compatibility with existing callers.
        return 0

    h = TH1F(className, str(vi.GetExpression()) + " (" + className + ")", numBin, vi.GetMin(), vi.GetMax())

    clsn = dsinfo.GetClassInfo(className).GetNumber()
    ds = dsinfo.GetDataSet()

    # One TransformationHandler per ';'-separated definition. Note that "".split(';') == [""],
    # so at least one handler is always created, even for the default empty processTrfs.
    # The list also keeps the handlers referenced until the histogram has been filled.
    trfs = []
    for trfDef in processTrfs.split(';'):
        handler = TMVA.TransformationHandler(dsinfo, "DataLoader")
        TMVA.CreateVariableTransforms(trfDef, dsinfo, handler, dl.Log())
        trfs.append(handler)

    # Chain the transformations: each handler consumes the previous handler's output.
    # NOTE(review): the collections returned by CalcTransformations(..., 1) are not explicitly
    # released here (rebinding/`del` of the Python name does not free C++ memory) - confirm ownership.
    events = ds.GetEventCollection()
    for trf in trfs:
        events = trf.CalcTransformations(events, 1)

    # Only events belonging to the requested class contribute to the histogram.
    for event in events:
        if event.GetClass() == clsn:
            h.Fill(event.GetValue(ivar))
    return h
63 
64 
## Draw correlation matrix
# Fetches the correlation histogram via TMVA::DataLoader::GetCorrelationMatrix (newly added to ROOT),
# applies the display settings below and hands it to JsDraw with the 'drawTH2' template.
# @param dl the object pointer
# @param className Signal/Background
def DrawCorrelationMatrix(dl, className):
    corrMatrix = dl.GetCorrelationMatrix(className)
    corrMatrix.SetMarkerSize(1.5)
    corrMatrix.SetMarkerColor(0)
    # Both axes share the same label size (X first, then Y).
    axisLabelSize = 0.040
    for getAxis in (corrMatrix.GetXaxis, corrMatrix.GetYaxis):
        getAxis().SetLabelSize(axisLabelSize)
    corrMatrix.LabelsOption("d")
    corrMatrix.SetLabelOffset(0.011)
    JPyInterface.JsDraw.Draw(corrMatrix, 'drawTH2')
79 
## Draw input variables
# This function uses the previously defined GetInputVariableHist function to create the
# Signal and Background histograms and draws them together with JsDraw.sbPlot.
# @param dl The object pointer
# @param variableName string containing the variable name
# @param numBin for creating the histogram
# @param processTrfs list of transformations to be used on input variable; eg. ["I", "N", "D", "P", "U", "G"]
#        (None or an empty list means no transformation)
# NOTE: if variableName is unknown, GetInputVariableHist returns 0 and the GetTitle() call below fails.
def DrawInputVariable(dl, variableName, numBin=100, processTrfs=None):
    # None replaces the former mutable [] default; both mean "no transformations".
    if processTrfs is None:
        processTrfs = []
    # GetInputVariableHist expects all transformations as a single ';'-separated string.
    processTrfsSTR = ";".join(str(o) for o in processTrfs)
    sig = GetInputVariableHist(dl, "Signal", variableName, numBin, processTrfsSTR)
    bkg = GetInputVariableHist(dl, "Background", variableName, numBin, processTrfsSTR)
    c, _ = JPyInterface.JsDraw.sbPlot(sig, bkg, {"xaxis": sig.GetTitle(),
                                                 "yaxis": "Number of events",
                                                 "plot": "Input variable: " + sig.GetTitle()})
    JPyInterface.JsDraw.Draw(c)
98 
99 ## Rewrite TMVA::DataLoader::PrepareTrainingAndTestTree
101  if len(kwargs)==0:
102  originalFunction, args = JPyInterface.functions.ProcessParameters(0, *args, **kwargs)
103  return originalFunction(*args)
104  try:
105  args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs(["SigCut", "BkgCut"], *args, **kwargs)
106  except AttributeError:
107  try:
108  args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs(["Cut"], *args, **kwargs)
109  except AttributeError:
110  raise AttributeError
111  originalFunction, args = JPyInterface.functions.ProcessParameters(3, *args, **kwargs)
112  return originalFunction(*args)
def DrawInputVariable(dl, variableName, numBin=100, processTrfs=[])
Draw input variables. This function uses the previously defined GetInputVariableHist function to create the histograms.
Definition: DataLoader.py:86
def GetInputVariableHist(dl, className, variableName, numBin, processTrfs="")
Creates the input variable histogram and performs the transformations if necessary.
Definition: DataLoader.py:17
def DrawCorrelationMatrix(dl, className)
Draw correlation matrix. This function uses the TMVA::DataLoader::GetCorrelationMatrix function newly added to ROOT.
Definition: DataLoader.py:69
def ChangeCallOriginalPrepareTrainingAndTestTree(args, kwargs)
Rewrite TMVA::DataLoader::PrepareTrainingAndTestTree.
Definition: DataLoader.py:100