10 from xml.etree.ElementTree
import ElementTree
12 from IPython.core.display
import display, HTML, clear_output
13 from ipywidgets
import widgets
14 from threading
import Thread
16 from string
import Template
25 for methodMapElement
in fac.fMethodsMap:
26 if methodMapElement[0] != datasetName:
28 methods = methodMapElement[1]
30 m.GetName._threaded =
True 31 if m.GetName() == methodName:
35 print(
"Factory.GetMethodObject: no method object found")
44 roottree = tree.getroot()
46 network[
"variables"] = []
47 for v
in roottree.find(
"Variables"):
48 network[
"variables"].append(v.get(
'Title'))
49 layout = roottree.find(
"Weights").find(
"Layout")
52 net.append({
"Connection": layer.get(
"Connection"),
53 "Nodes": layer.get(
"Nodes"),
54 "ActivationFunction": layer.get(
"ActivationFunction"),
55 "OutputMode": layer.get(
"OutputMode")
57 network[
"layers"] = net
58 Synapses = roottree.find(
"Weights").find(
"Synapses")
60 "InputSize": Synapses.get(
"InputSize"),
61 "OutputSize": Synapses.get(
"OutputSize"),
62 "NumberSynapses": Synapses.get(
"NumberSynapses"),
65 for i
in Synapses.text.split(
" "):
66 tmp = i.replace(
"\n",
"")
68 synapses[
"synapses"].append(tmp)
69 network[
"synapses"] = synapses
70 return json.dumps(network)
77 roottree = tree.getroot()
79 network[
"variables"] = []
80 for v
in roottree.find(
"Variables"):
81 network[
"variables"].append(v.get(
'Title'))
82 layout = roottree.find(
"Weights").find(
"Layout")
84 net = {
"nlayers": layout.get(
"NLayers") }
86 neuron_num = int(layer.get(
"NNeurons"))
87 neurons = {
"nneurons": neuron_num }
90 label =
"neuron_"+str(i)
92 nsynapses = int(neuron.get(
'NSynapses'))
93 neurons[label] = {
"nsynapses": nsynapses}
96 text = str(neuron.text)
97 wall = text.replace(
"\n",
"").split(
" ")
101 weights.append(float(w))
102 neurons[label][
"weights"] = weights
103 net[
"layer_"+str(layer.get(
'Index'))] = neurons
104 network[
"layout"] = net
105 return json.dumps(network)
113 self.__xmltree.parse(fileName)
114 self.
__NTrees = int(self.__xmltree.find(
"Weights").get(
'NTrees'))
124 def __getBinaryTree(self, itree):
126 print(
"to big number, tree number must be less then %s"%self.
__NTrees )
128 return self.__xmltree.find(
"Weights").find(
"BinaryTree["+str(itree+1)+
"]")
135 def __readTree(self, binaryTree, tree={}, depth=0):
136 nodes = binaryTree.findall(
"Node")
139 if len(nodes)==1
and nodes[0].get(
"pos")==
"s":
141 "IVar": nodes[0].get(
"IVar"),
142 "Cut" : nodes[0].get(
"Cut"),
143 "purity": nodes[0].get(
"purity"),
147 tree[
"children"] = []
152 "IVar": node.get(
"IVar"),
153 "Cut" : node.get(
"Cut"),
154 "purity": node.get(
"purity"),
155 "pos": node.get(
"pos")
157 tree[
"children"].append({
161 self.
__readTree(node, tree[
"children"][-1], depth+1)
177 varstree = self.__xmltree.find(
"Variables").findall(
"Variable")
178 variables = [
None]*len(varstree)
180 variables[int(v.get(
'VarIndex'))] = v.get(
'Expression')
188 canvas = fac.GetROCCurve(datasetName)
189 JPyInterface.JsDraw.Draw(canvas)
199 mvaRes = method.Data().GetResults(method.GetMethodName(), TMVA.Types.kTesting, TMVA.Types.kMaxAnalysisType)
200 sig = mvaRes.GetHist(
"MVA_S")
201 bgd = mvaRes.GetHist(
"MVA_B")
202 c, l = JPyInterface.JsDraw.sbPlot(sig, bgd, {
"xaxis": methodName+
" response",
203 "yaxis":
"(1/N) dN^{ }/^{ }dx",
204 "plot":
"TMVA response for classifier: "+methodName})
205 JPyInterface.JsDraw.Draw(c)
215 mvaRes = method.Data().GetResults(method.GetMethodName(), TMVA.Types.kTesting, TMVA.Types.kMaxAnalysisType)
216 sig = mvaRes.GetHist(
"Prob_S")
217 bgd = mvaRes.GetHist(
"Prob_B")
218 c, l = JPyInterface.JsDraw.sbPlot(sig, bgd, {
"xaxis":
"Signal probability",
219 "yaxis":
"(1/N) dN^{ }/^{ }dx",
220 "plot":
"TMVA probability for classifier: "+methodName})
221 JPyInterface.JsDraw.Draw(c)
231 mvaRes = method.Data().GetResults(method.GetMethodName(), TMVA.Types.kTesting, TMVA.Types.kMaxAnalysisType)
232 sigE = mvaRes.GetHist(
"MVA_EFF_S")
233 bgdE = mvaRes.GetHist(
"MVA_EFF_B")
238 f = ROOT.TFormula(
"sigf",
"x/sqrt(x+y)")
240 pname =
"purS_" + methodName
241 epname =
"effpurS_" + methodName
242 ssigname =
"significance_" + methodName
244 nbins = sigE.GetNbinsX()
245 low = sigE.GetBinLowEdge(1)
246 high = sigE.GetBinLowEdge(nbins+1)
248 purS = ROOT.TH1F(pname, pname, nbins, low, high)
249 sSig = ROOT.TH1F(ssigname, ssigname, nbins, low, high)
250 effpurS = ROOT.TH1F(epname, epname, nbins, low, high)
253 sigE.SetTitle(
"Cut efficiencies for "+methodName+
" classifier")
255 TMVA.TMVAGlob.SetSignalAndBackgroundStyle( sigE, bgdE )
256 TMVA.TMVAGlob.SetSignalAndBackgroundStyle( purS, bgdE )
257 TMVA.TMVAGlob.SetSignalAndBackgroundStyle( effpurS, bgdE )
258 sigE.SetFillStyle( 0 )
259 bgdE.SetFillStyle( 0 )
260 sSig.SetFillStyle( 0 )
261 sigE.SetLineWidth( 3 )
262 bgdE.SetLineWidth( 3 )
263 sSig.SetLineWidth( 3 )
265 purS.SetFillStyle( 0 )
266 purS.SetLineWidth( 2 )
267 purS.SetLineStyle( 5 )
268 effpurS.SetFillStyle( 0 )
269 effpurS.SetLineWidth( 2 )
270 effpurS.SetLineStyle( 6 )
273 for i
in range(1,sigE.GetNbinsX()+1):
274 eS = sigE.GetBinContent( i )
276 B = bgdE.GetBinContent( i ) * fNBackground
278 purS.SetBinContent( i, 0)
280 purS.SetBinContent( i, S/(S+B) )
282 sSig.SetBinContent( i, f.Eval(S,B) )
283 effpurS.SetBinContent( i, eS*purS.GetBinContent( i ) )
285 maxSignificance = sSig.GetMaximum()
286 maxSignificanceErr = 0
287 sSig.Scale(1/maxSignificance)
289 c = ROOT.TCanvas(
"canvasCutEff",
"Cut efficiencies for "+methodName+
" classifier", JPyInterface.JsDraw.jsCanvasWidth,
290 JPyInterface.JsDraw.jsCanvasHeight)
296 TMVAStyle = ROOT.gROOT.GetStyle(
"Plain")
297 TMVAStyle.SetLineStyleString( 5,
"[32 22]" )
298 TMVAStyle.SetLineStyleString( 6,
"[12 22]" )
302 effpurS.SetTitle(
"Cut efficiencies and optimal cut value")
303 if methodName.find(
"Cuts")!=-1:
304 effpurS.GetXaxis().SetTitle(
"Signal Efficiency" )
306 effpurS.GetXaxis().SetTitle(
"Cut value applied on " + methodName +
" output" )
307 effpurS.GetYaxis().SetTitle(
"Efficiency (Purity)" )
308 TMVA.TMVAGlob.SetFrameStyle( effpurS )
311 c.SetRightMargin ( 2.0 )
313 effpurS.SetMaximum(1.1)
314 effpurS.Draw(
"histl")
316 purS.Draw(
"samehistl")
318 sigE.Draw(
"samehistl")
319 bgdE.Draw(
"samehistl")
321 signifColor = ROOT.TColor.GetColor(
"#00aa00" )
323 sSig.SetLineColor( signifColor )
324 sSig.Draw(
"samehistl")
326 effpurS.Draw(
"sameaxis" )
329 legend1 = ROOT.TLegend( c.GetLeftMargin(), 1 - c.GetTopMargin(),
330 c.GetLeftMargin() + 0.4, 1 - c.GetTopMargin() + 0.12 )
331 legend1.SetFillStyle( 1 )
332 legend1.AddEntry(sigE,
"Signal efficiency",
"L")
333 legend1.AddEntry(bgdE,
"Background efficiency",
"L")
335 legend1.SetBorderSize(1)
336 legend1.SetMargin( 0.3 )
339 legend2 = ROOT.TLegend( c.GetLeftMargin() + 0.4, 1 - c.GetTopMargin(),
340 1 - c.GetRightMargin(), 1 - c.GetTopMargin() + 0.12 )
341 legend2.SetFillStyle( 1 )
342 legend2.AddEntry(purS,
"Signal purity",
"L")
343 legend2.AddEntry(effpurS,
"Signal efficiency*purity",
"L")
344 legend2.AddEntry(sSig,
"S/#sqrt{ S+B }",
"L")
346 legend2.SetBorderSize(1)
347 legend2.SetMargin( 0.3 )
349 effline = ROOT.TLine( sSig.GetXaxis().GetXmin(), 1, sSig.GetXaxis().GetXmax(), 1 )
350 effline.SetLineWidth( 1 )
351 effline.SetLineColor( 1 )
358 tl.SetTextSize( 0.033 )
359 maxbin = sSig.GetMaximumBin()
360 line1 = tl.DrawLatex( 0.15, 0.23,
"For %1.0f signal and %1.0f background"%(fNSignal, fNBackground))
361 tl.DrawLatex( 0.15, 0.19,
"events the maximum S/#sqrt{S+B} is")
363 if maxSignificanceErr > 0:
364 line2 = tl.DrawLatex( 0.15, 0.15,
"%5.2f +- %4.2f when cutting at %5.2f"%(
367 sSig.GetXaxis().GetBinCenter(maxbin)) )
369 line2 = tl.DrawLatex( 0.15, 0.15,
"%4.2f when cutting at %5.2f"%(
371 sSig.GetXaxis().GetBinCenter(maxbin)) )
373 if methodName.find(
"Cuts")!=-1:
374 tl.DrawLatex( 0.13, 0.77,
"Method Cuts provides a bundle of cut selections, each tuned to a")
375 tl.DrawLatex(0.13, 0.74,
"different signal efficiency. Shown is the purity for each cut selection.")
377 wx = (sigE.GetXaxis().GetXmax()+abs(sigE.GetXaxis().GetXmin()))*0.135
378 rightAxis = ROOT.TGaxis( sigE.GetXaxis().GetXmax()+wx,
380 sigE.GetXaxis().GetXmax()+wx,
381 0.7, 0, 1.1*maxSignificance,510,
"+L")
382 rightAxis.SetLineColor ( signifColor )
383 rightAxis.SetLabelColor( signifColor )
384 rightAxis.SetTitleColor( signifColor )
386 rightAxis.SetTitleSize( sSig.GetXaxis().GetTitleSize() )
387 rightAxis.SetTitle(
"Significance" )
392 JPyInterface.JsDraw.Draw(c)
402 if (methodName==
"DNN"):
406 JPyInterface.JsDraw.Draw(net,
"drawNeuralNetwork",
True)
418 variables = tr.getVariables();
421 if treeSelector.value>tr.getNTrees():
422 treeSelector.value = tr.getNTrees()
425 "variables": variables,
426 "tree": tr.getTree(treeSelector.value)
428 json_str = json.dumps(toJs)
429 JPyInterface.JsDraw.Draw(json_str,
"drawDecisionTree",
True)
431 mx = str(tr.getNTrees()-1)
433 treeSelector = widgets.IntText(value=0, font_weight=
"bold")
434 drawTree = widgets.Button(description=
"Draw", font_weight=
"bold")
435 label = widgets.HTML(
"<div style='padding: 6px;font-weight:bold;color:#333;'>Decision Tree [0-"+mx+
"]:</div>")
437 drawTree.on_click(clicked)
438 container = widgets.HBox([label,treeSelector, drawTree])
447 <script type="text/javascript"> 448 require(["jquery"], function(jQ){ 449 jQ("input.stopTrainingButton").on("click", function(){ 450 IPython.notebook.kernel.interrupt(); 452 "background-color": "rgba(200, 0, 0, 0.8)", 454 "box-shadow": "0 3px 5px rgba(0, 0, 0, 0.3)", 459 <style type="text/css"> 460 input.stopTrainingButton { 461 background-color: #fff; 462 border: 1px solid #ccc; 471 input.stopTrainingButton:hover { 472 background-color: rgba(204, 204, 204, 0.4); 475 <input type="button" value="Stop" class="stopTrainingButton" /> 479 <script type="text/javascript" id="progressBarScriptInc"> 480 require(["jquery"], function(jQ){ 481 jQ("#jsmva_bar_$id").css("width", $progress + "%"); 482 jQ("#jsmva_label_$id").text($progress + '%'); 483 jQ("#progressBarScriptInc").parent().parent().remove(); 487 progress_bar = Template(
""" 489 #jsmva_progress_$id { 494 background-color: #f5f5f5; 496 box-shadow: inset 0 3px 6px rgba(0, 0, 0, 0.1); 502 background-color: #337ab7; 510 <div id="jsmva_progress_$id"> 511 <div id="jsmva_bar_$id"> 512 <div id="jsmva_label_$id">0%</div> 518 def exit_supported(mn):
520 es = [
"SVM",
"Cuts",
"Boost",
"BDT"]
522 if name.find(e) != -1:
526 wait_times = {
"MLP": 0.5,
"DNN": 1,
"BDT": 0.5}
528 for methodMapElement
in fac.fMethodsMap:
529 display(HTML(
"<center><h1>Dataset: "+str(methodMapElement[0])+
"</h1></center>"))
530 for m
in methodMapElement[1]:
531 m.GetName._threaded =
True 532 name = str(m.GetName())
533 display(HTML(
"<h2><b>Train method: "+name+
"</b></h2>"))
534 m.InitIPythonInteractive()
535 t = Thread(target=ROOT.TMVA.MethodBase.TrainMethod, args=[m])
537 if name
in wait_times:
538 display(HTML(button))
539 time.sleep(wait_times[name])
540 if m.GetMaxIter() != 0:
541 display(HTML(progress_bar.substitute({
"id": progress_bar_idx})))
542 display(HTML(inc.substitute({
"id": progress_bar_idx,
"progress": 100 * m.GetCurrentIter() / m.GetMaxIter()})))
543 JPyInterface.JsDraw.Draw(m.GetInteractiveTrainingError(),
"drawTrainingTestingErrors")
545 while not m.TrainingEnded():
546 JPyInterface.JsDraw.InsertData(m.GetInteractiveTrainingError())
547 if m.GetMaxIter() != 0:
548 display(HTML(inc.substitute({
549 "id": progress_bar_idx,
550 "progress": 100 * m.GetCurrentIter() / m.GetMaxIter()
553 except KeyboardInterrupt:
556 if exit_supported(name):
557 display(HTML(button))
559 if m.GetMaxIter()!=0:
560 display(HTML(progress_bar.substitute({
"id": progress_bar_idx})))
561 display(HTML(inc.substitute({
"id": progress_bar_idx,
"progress": 100*m.GetCurrentIter()/m.GetMaxIter()})))
563 display(HTML(
"<b>Training...</b>"))
564 if exit_supported(name):
566 while not m.TrainingEnded():
567 if m.GetMaxIter()!=0:
568 display(HTML(inc.substitute({
569 "id": progress_bar_idx,
570 "progress": 100 * m.GetCurrentIter() / m.GetMaxIter()
573 except KeyboardInterrupt:
576 while not m.TrainingEnded():
577 if m.GetMaxIter() != 0:
578 display(HTML(inc.substitute({
579 "id": progress_bar_idx,
580 "progress": 100 * m.GetCurrentIter() / m.GetMaxIter()
583 if m.GetMaxIter() != 0:
584 display(HTML(inc.substitute({
585 "id": progress_bar_idx,
586 "progress": 100 * m.GetCurrentIter() / m.GetMaxIter()
589 display(HTML(
"<b>End</b>"))
590 progress_bar_idx += 1
597 args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs([
"JobName",
"TargetFile"], *args, **kwargs)
598 except AttributeError:
600 args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs([
"JobName"], *args, **kwargs)
601 except AttributeError:
603 originalFunction, args = JPyInterface.functions.ProcessParameters(3, *args, **kwargs)
604 return originalFunction(*args)
608 compositeOpts =
False 610 if "Composite" in kwargs:
611 composite = kwargs[
"Composite"]
612 del kwargs[
"Composite"]
613 if "CompositeOptions" in kwargs:
614 compositeOpts = kwargs[
"CompositeOptions"]
615 del kwargs[
"CompositeOptions"]
616 args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs([
"DataLoader",
"Method",
"MethodTitle"], *args, **kwargs)
617 originalFunction, args = JPyInterface.functions.ProcessParameters(4, *args, **kwargs)
620 args.append(composite)
622 if compositeOpts!=
False:
623 o, compositeOptStr = JPyInterface.functions.ProcessParameters(-10, **compositeOpts)
625 args.append(compositeOptStr[0])
627 return originalFunction(*args)
632 originalFunction, args = JPyInterface.functions.ProcessParameters(0, *args, **kwargs)
633 return originalFunction(*args)
634 args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs([
"DataLoader",
"VIType",
"Method",
"MethodTitle"], *args, **kwargs)
635 originalFunction, args = JPyInterface.functions.ProcessParameters(5, *args, **kwargs)
636 hist = originalFunction(*args)
637 JPyInterface.JsDraw.Draw(hist)
643 originalFunction, args = JPyInterface.functions.ProcessParameters(0, *args, **kwargs)
644 return originalFunction(*args)
649 if "optParams" in kwargs:
650 optParams = kwargs[
"optParams"]
651 del kwargs[
"optParams"]
652 if "NumFolds" in kwargs:
653 NumFolds = kwargs[
"NumFolds"]
654 del kwargs[
"NumFolds"]
655 if "remakeDataSet" in kwargs:
656 remakeDataSet = kwargs[
"remakeDataSet"]
657 del kwargs[
"remakeDataSet"]
658 if "rocIntegrals" in kwargs:
659 rocIntegrals = kwargs[
"rocIntegrals"]
660 del kwargs[
"rocIntegrals"]
661 args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs([
"DataLoader",
"Method",
"MethodTitle"], *args, **kwargs)
662 originalFunction, args = JPyInterface.functions.ProcessParameters(4, *args, **kwargs)
664 args.append(optParams)
665 args.append(NumFolds)
666 args.append(remakeDataSet)
667 if rocIntegrals!=
None and rocIntegrals!=0:
668 args.append(rocIntegrals)
670 return originalFunction(*args)
673 __BookDNNHelper =
None 677 global __BookDNNHelper
678 def __bookDNN(optString):
679 self.BookMethod(loader, ROOT.TMVA.Types.kDNN, title, optString)
681 __BookDNNHelper = __bookDNN
682 JPyInterface.JsDraw.InsertCSS(
"NetworkDesigner.css")
683 JPyInterface.JsDraw.Draw(
"",
"NetworkDesigner",
True)
def __getBinaryTree(self, itree)
def __init__(self, fileName)
Standard Constructor.
def GetNetwork(xml_file)
Reads neural network weights from file and returns it in JSON format.
def ChangeCallOriginalEvaluateImportance(args, kwargs)
Rewrite function for TMVA::Factory::EvaluateImportance.
def GetMethodObject(fac, datasetName, methodName)
Getting method object from factory.
def DrawProbabilityDistribution(fac, datasetName, methodName)
Draw output probability distribution.
def getTree(self, itree)
Public function which returns the specified tree object.
def DrawROCCurve(fac, datasetName)
Draw ROC curve.
def DrawDecisionTree(fac, datasetName, methodName)
Draw decision tree.
def DrawCutEfficiencies(fac, datasetName, methodName)
Draw cut efficiencies.
def DrawNeuralNetwork(fac, datasetName, methodName)
Draw neural network.
def ChangeCallOriginalCrossValidate(args, kwargs)
Rewrite function for TMVA::Factory::CrossValidate.
def BookDNN(self, loader, title="DNN")
Graphical interface for booking DNN.
def getNTrees(self)
Returns the number of trees.
def __readTree(self, binaryTree, tree={}, depth=0)
Reads the tree.
def GetDeepNetwork(xml_file)
Reads deep neural network weights from file and returns it in JSON format.
def DrawOutputDistribution(fac, datasetName, methodName)
Draw output distributions.
Helper class for reading decision tree from XML file.
def ChangeCallOriginal__init__(args, kwargs)
Rewrite the constructor of TMVA::Factory.
def ChangeTrainAllMethods(fac)
Rewrite function for TMVA::Factory::TrainAllMethods.
def getVariables(self)
Returns a list with input variable names.
def ChangeCallOriginalBookMethod(args, kwargs)
Rewrite TMVA::Factory::BookMethod.