7 from ROOT
import TH1F, TMVA
18 dsinfo = dl.GetDefaultDataSetInfo()
21 for i
in range(dsinfo.GetNVariables()):
22 if dsinfo.GetVariableInfo(i).GetLabel()==variableName:
23 vi = dsinfo.GetVariableInfo(i)
29 h = TH1F(className, str(vi.GetExpression()) +
" ("+className+
")", numBin, vi.GetMin(), vi.GetMax())
31 clsn = dsinfo.GetClassInfo(className).GetNumber()
32 ds = dsinfo.GetDataSet()
34 trfsDef = processTrfs.split(
';')
36 for trfDef
in trfsDef:
37 trfs.append(TMVA.TransformationHandler(dsinfo,
"DataLoader"))
38 TMVA.CreateVariableTransforms( trfDef, dsinfo, trfs[-1], dl.Log())
40 inputEvents = ds.GetEventCollection()
45 transformed = trf.CalcTransformations(inputEvents, 1)
47 tmp = trf.CalcTransformations(transformed, 1)
52 for event
in transformed:
53 if event.GetClass() != clsn:
55 h.Fill(event.GetValue(ivar))
58 for event
in inputEvents:
59 if event.GetClass() != clsn:
61 h.Fill(event.GetValue(ivar))
70 th2 = dl.GetCorrelationMatrix(className)
71 th2.SetMarkerSize(1.5)
74 th2.GetXaxis().SetLabelSize(labelSize)
75 th2.GetYaxis().SetLabelSize(labelSize)
77 th2.SetLabelOffset(0.011)
78 JPyInterface.JsDraw.Draw(th2,
'drawTH2')
88 if len(processTrfs)>0:
90 processTrfsSTR += str(o) +
";" 91 processTrfsSTR = processTrfsSTR[:-1]
94 c, l = JPyInterface.JsDraw.sbPlot(sig, bkg, {
"xaxis": sig.GetTitle(),
95 "yaxis":
"Number of events",
96 "plot":
"Input variable: "+sig.GetTitle()})
97 JPyInterface.JsDraw.Draw(c)
102 originalFunction, args = JPyInterface.functions.ProcessParameters(0, *args, **kwargs)
103 return originalFunction(*args)
105 args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs([
"SigCut",
"BkgCut"], *args, **kwargs)
106 except AttributeError:
108 args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs([
"Cut"], *args, **kwargs)
109 except AttributeError:
111 originalFunction, args = JPyInterface.functions.ProcessParameters(3, *args, **kwargs)
112 return originalFunction(*args)
def DrawInputVariable(dl, variableName, numBin=100, processTrfs=[])
Draw input variables. This function uses the previously defined GetInputVariableHist function to create...
def GetInputVariableHist(dl, className, variableName, numBin, processTrfs="")
Creates the input variable histogram and performs the transformations if necessary.
def DrawCorrelationMatrix(dl, className)
Draw correlation matrix. This function uses the TMVA::DataLoader::GetCorrelationMatrix function added ...
def ChangeCallOriginalPrepareTrainingAndTestTree(args, kwargs)
Rewrite TMVA::DataLoader::PrepareTrainingAndTestTree.