JsMVA
Factory.py
Go to the documentation of this file.
1 # -*- coding: utf-8 -*-
2 ## @package JsMVA.Factory
3 # Factory module: functions to be inserted into the TMVA::Factory class, plus helper functions and classes
4 # @authors Attila Bagoly <battila93@gmail.com>
5 
6 
7 import ROOT
8 from ROOT import TMVA
9 import JPyInterface
10 from xml.etree.ElementTree import ElementTree
11 import json
12 from IPython.core.display import display, HTML, clear_output
13 from ipywidgets import widgets
14 from threading import Thread
15 import time
16 from string import Template
17 
18 
## Getting method object from factory
# @param fac the TMVA::Factory object
# @param datasetName selecting the dataset
# @param methodName which method we want to get
# @return the method object whose GetName() equals methodName inside dataset
#         datasetName, or None (after printing a message) when the number of
#         matches is not exactly one
def GetMethodObject(fac, datasetName, methodName):
    matches = []
    for entry in fac.fMethodsMap:
        # entry[0] is the dataset name, entry[1] the list of booked methods
        if entry[0] != datasetName:
            continue
        for candidate in entry[1]:
            # NOTE(review): presumably lets PyROOT release the GIL while
            # GetName() runs (training threads may be active) — confirm
            candidate.GetName._threaded = True
            if candidate.GetName() == methodName:
                matches.append(candidate)
                break
    if len(matches) == 1:
        return matches[0]
    print("Factory.GetMethodObject: no method object found")
    return None
38 
## Reads deep neural network weights from file and returns it in JSON format
# @param xml_file path to (or open binary file object of) the DNN weight file
# @return JSON string with keys "variables" (variable titles), "layers" (one
#         dict per Layout/Layer element) and "synapses" (Synapses attributes
#         plus the list of weight tokens as strings)
def GetDeepNetwork(xml_file):
    tree = ElementTree()
    tree.parse(xml_file)
    roottree = tree.getroot()
    weights = roottree.find("Weights")

    network = {}
    network["variables"] = [v.get('Title') for v in roottree.find("Variables")]
    network["layers"] = [
        {
            "Connection": layer.get("Connection"),
            "Nodes": layer.get("Nodes"),
            "ActivationFunction": layer.get("ActivationFunction"),
            "OutputMode": layer.get("OutputMode"),
        }
        for layer in weights.find("Layout")
    ]

    synapsesNode = weights.find("Synapses")
    network["synapses"] = {
        "InputSize": synapsesNode.get("InputSize"),
        "OutputSize": synapsesNode.get("OutputSize"),
        "NumberSynapses": synapsesNode.get("NumberSynapses"),
        # split() without a separator splits on any whitespace run: values
        # separated only by a newline stay distinct, no empty tokens appear,
        # and single-character values such as "0" are kept. An element
        # without text yields an empty list.
        "synapses": (synapsesNode.text or "").split(),
    }
    return json.dumps(network)
71 
## Reads neural network weights from file and returns it in JSON format
# @param xml_file path to (or open binary file object of) the MLP weight file
# @return JSON string with keys "variables" (variable titles) and "layout";
#         "layout" holds "nlayers" plus one "layer_<Index>" entry per layer, each
#         mapping "nneurons" and "neuron_<i>" -> {"nsynapses", ["weights"]}
def GetNetwork(xml_file):
    tree = ElementTree()
    tree.parse(xml_file)
    roottree = tree.getroot()
    network = {}
    network["variables"] = [v.get('Title') for v in roottree.find("Variables")]
    layout = roottree.find("Weights").find("Layout")

    net = {"nlayers": layout.get("NLayers")}
    for layer in layout:
        neurons = {"nneurons": int(layer.get("NNeurons"))}
        for i, neuron in enumerate(layer):
            nsynapses = int(neuron.get('NSynapses'))
            entry = {"nsynapses": nsynapses}
            neurons["neuron_" + str(i)] = entry
            if nsynapses == 0:
                # No outgoing synapses means no weights to read. Record the
                # neuron and keep going, instead of stopping the whole layer
                # (which dropped e.g. all but the first output neuron).
                continue
            # split() handles any whitespace (including newlines between
            # values) and never yields empty tokens
            entry["weights"] = [float(w) for w in (neuron.text or "").split()]
        net["layer_" + str(layer.get('Index'))] = neurons
    network["layout"] = net
    return json.dumps(network)
106 
## Helper class for reading decision tree from XML file
class TreeReader(object):

    ## Standard Constructor
    # @param self object pointer
    # @param fileName path to (or open binary file object of) the BDT weight file
    def __init__(self, fileName):
        self.__xmltree = ElementTree()
        self.__xmltree.parse(fileName)
        self.__NTrees = int(self.__xmltree.find("Weights").get('NTrees'))

    ## Returns the number of trees
    # @param self object pointer
    def getNTrees(self):
        return self.__NTrees

    ## Returns DOM object to selected tree
    # @param self object pointer
    # @param itree the index of tree, valid range [0, NTrees-1]
    # @return the BinaryTree element, or None (after printing a message) when
    #         itree is out of range or the element is missing
    def __getBinaryTree(self, itree):
        if itree < 0:
            # a negative index would build "BinaryTree[0]" / "BinaryTree[-n]",
            # which ElementPath rejects with a SyntaxError
            print("tree number must be non-negative, got %s" % itree)
            return None
        if self.__NTrees <= itree:
            print( "to big number, tree number must be less then %s"%self.__NTrees )
            return None
        # ElementPath positional predicates are 1-based
        return self.__xmltree.find("Weights").find("BinaryTree[" + str(itree + 1) + "]")

    ## Builds the JSON-friendly info dict of one Node element
    # @param node the Node DOM element
    # @param pos value stored under "pos"
    @staticmethod
    def __nodeInfo(node, pos):
        return {
            "IVar": node.get("IVar"),
            "Cut": node.get("Cut"),
            "purity": node.get("purity"),
            "pos": pos,
        }

    ## Reads the tree
    # @param self the object pointer
    # @param binaryTree the tree DOM object to be read
    # @param tree object to be filled; a fresh dict is created when omitted
    # @param depth current depth
    # @return the filled tree object
    def __readTree(self, binaryTree, tree=None, depth=0):
        # avoid the shared mutable default dict of the former `tree={}`
        if tree is None:
            tree = {}
        nodes = binaryTree.findall("Node")
        if len(nodes) == 0:
            return tree
        if len(nodes) == 1 and nodes[0].get("pos") == "s":
            # the single "s" node is the root: it fills `tree` itself
            tree["info"] = self.__nodeInfo(nodes[0], 0)
            tree["children"] = []
            self.__readTree(nodes[0], tree, 1)
            return tree
        children = tree.setdefault("children", [])
        for node in nodes:
            child = {"info": self.__nodeInfo(node, node.get("pos")), "children": []}
            children.append(child)
            self.__readTree(node, child, depth + 1)
        return tree

    ## Public function which returns the specified tree object
    # @param self the object pointer
    # @param itree selected tree index
    # @return nested dict with "info"/"children" keys, or {} for an invalid index
    def getTree(self, itree):
        binaryTree = self.__getBinaryTree(itree)
        if binaryTree is None:
            return {}
        return self.__readTree(binaryTree, {})

    ## Returns a list with input variable names
    # @param self the object pointer
    # @return expressions ordered by their VarIndex attribute
    def getVariables(self):
        varstree = self.__xmltree.find("Variables").findall("Variable")
        variables = [None] * len(varstree)
        for v in varstree:
            variables[int(v.get('VarIndex'))] = v.get('Expression')
        return variables
182 
183 
## Draw ROC curve
# Asks the factory for the ROC canvas of a dataset and renders it with JsDraw.
# @param fac the factory object pointer
# @param datasetName the dataset name
def DrawROCCurve(fac, datasetName):
    JPyInterface.JsDraw.Draw(fac.GetROCCurve(datasetName))
190 
## Draw output distributions
# Overlays the signal (MVA_S) and background (MVA_B) response histograms of
# the testing results and renders them with JsDraw.
# @param fac the factory object pointer
# @param datasetName the dataset name
# @param methodName we want to see the output distribution of this method
# @return None; nothing is drawn when the method cannot be found
def DrawOutputDistribution(fac, datasetName, methodName):
    method = GetMethodObject(fac, datasetName, methodName)
    # GetMethodObject signals "not found" with None
    if method is None:
        return None
    mvaRes = method.Data().GetResults(method.GetMethodName(), TMVA.Types.kTesting, TMVA.Types.kMaxAnalysisType)
    sig = mvaRes.GetHist("MVA_S")
    bgd = mvaRes.GetHist("MVA_B")
    # the second returned object (presumably the legend) is bound to a local
    # so it stays alive until the canvas has been drawn
    c, l = JPyInterface.JsDraw.sbPlot(sig, bgd, {"xaxis": methodName+" response",
                                                 "yaxis": "(1/N) dN^{ }/^{ }dx",
                                                 "plot": "TMVA response for classifier: "+methodName})
    JPyInterface.JsDraw.Draw(c)
206 
## Draw output probability distribution
# Overlays the signal (Prob_S) and background (Prob_B) probability histograms
# of the testing results and renders them with JsDraw.
# @param fac the factory object pointer
# @param datasetName the dataset name
# @param methodName we want to see the output probability distribution of this method
def DrawProbabilityDistribution(fac, datasetName, methodName):
    method = GetMethodObject(fac, datasetName, methodName)
    # GetMethodObject returns None (never 0) when the method is missing; the
    # former `method==0` test let None through and crashed on None.Data()
    if method is None:
        return
    mvaRes = method.Data().GetResults(method.GetMethodName(), TMVA.Types.kTesting, TMVA.Types.kMaxAnalysisType)
    sig = mvaRes.GetHist("Prob_S")
    bgd = mvaRes.GetHist("Prob_B")
    # the second returned object (presumably the legend) is bound to a local
    # so it stays alive until the canvas has been drawn
    c, l = JPyInterface.JsDraw.sbPlot(sig, bgd, {"xaxis": "Signal probability",
                                                 "yaxis": "(1/N) dN^{ }/^{ }dx",
                                                 "plot": "TMVA probability for classifier: "+methodName})
    JPyInterface.JsDraw.Draw(c)
222 
## Draw cut efficiencies
# Plots the signal and background cut efficiencies, the signal purity, the
# efficiency*purity product and the significance S/sqrt(S+B) (normalised to 1,
# with its own right-hand axis) for an assumed 1000 signal and 1000 background
# events. The canvas is rendered through JsDraw.
# @param fac the factory object pointer
# @param datasetName the dataset name
# @param methodName we want to see the cut efficiencies of this method
def DrawCutEfficiencies(fac, datasetName, methodName):
    method = GetMethodObject(fac, datasetName, methodName)
    # GetMethodObject returns None (never 0) when the method is missing; the
    # former `method==0` test let None through and crashed on None.Data()
    if method is None:
        return
    mvaRes = method.Data().GetResults(method.GetMethodName(), TMVA.Types.kTesting, TMVA.Types.kMaxAnalysisType)
    sigE = mvaRes.GetHist("MVA_EFF_S")
    bgdE = mvaRes.GetHist("MVA_EFF_B")

    # assumed event counts used to turn efficiencies into yields
    fNSignal = 1000
    fNBackground = 1000

    f = ROOT.TFormula("sigf", "x/sqrt(x+y)")

    pname = "purS_" + methodName
    epname = "effpurS_" + methodName
    ssigname = "significance_" + methodName

    nbins = sigE.GetNbinsX()
    low = sigE.GetBinLowEdge(1)
    high = sigE.GetBinLowEdge(nbins+1)

    # derived curves share the binning of the signal-efficiency histogram
    purS = ROOT.TH1F(pname, pname, nbins, low, high)
    sSig = ROOT.TH1F(ssigname, ssigname, nbins, low, high)
    effpurS = ROOT.TH1F(epname, epname, nbins, low, high)

    sigE.SetTitle("Cut efficiencies for "+methodName+" classifier")

    TMVA.TMVAGlob.SetSignalAndBackgroundStyle(sigE, bgdE)
    TMVA.TMVAGlob.SetSignalAndBackgroundStyle(purS, bgdE)
    TMVA.TMVAGlob.SetSignalAndBackgroundStyle(effpurS, bgdE)
    sigE.SetFillStyle(0)
    bgdE.SetFillStyle(0)
    sSig.SetFillStyle(0)
    sigE.SetLineWidth(3)
    bgdE.SetLineWidth(3)
    sSig.SetLineWidth(3)

    purS.SetFillStyle(0)
    purS.SetLineWidth(2)
    purS.SetLineStyle(5)
    effpurS.SetFillStyle(0)
    effpurS.SetLineWidth(2)
    effpurS.SetLineStyle(6)

    for i in range(1, nbins + 1):
        eS = sigE.GetBinContent(i)
        S = eS * fNSignal
        B = bgdE.GetBinContent(i) * fNBackground
        if (S+B) == 0:
            purS.SetBinContent(i, 0)
        else:
            purS.SetBinContent(i, S/(S+B))

        sSig.SetBinContent(i, f.Eval(S, B))
        effpurS.SetBinContent(i, eS*purS.GetBinContent(i))

    maxSignificance = sSig.GetMaximum()
    # normalise the significance curve to 1 (its scale is given by the right
    # axis); skip when it is not positive to avoid a ZeroDivisionError
    if maxSignificance > 0:
        sSig.Scale(1/maxSignificance)

    c = ROOT.TCanvas("canvasCutEff", "Cut efficiencies for "+methodName+" classifier", JPyInterface.JsDraw.jsCanvasWidth,
                     JPyInterface.JsDraw.jsCanvasHeight)

    c.SetGrid(1)
    c.SetTickx(0)
    c.SetTicky(0)

    TMVAStyle = ROOT.gROOT.GetStyle("Plain")
    TMVAStyle.SetLineStyleString(5, "[32 22]")
    TMVAStyle.SetLineStyleString(6, "[12 22]")

    c.SetTopMargin(.2)

    effpurS.SetTitle("Cut efficiencies and optimal cut value")
    if methodName.find("Cuts") != -1:
        effpurS.GetXaxis().SetTitle("Signal Efficiency")
    else:
        effpurS.GetXaxis().SetTitle("Cut value applied on " + methodName + " output")
    effpurS.GetYaxis().SetTitle("Efficiency (Purity)")
    TMVA.TMVAGlob.SetFrameStyle(effpurS)

    c.SetTicks(0, 0)
    c.SetRightMargin(2.0)

    effpurS.SetMaximum(1.1)
    effpurS.Draw("histl")

    purS.Draw("samehistl")

    sigE.Draw("samehistl")
    bgdE.Draw("samehistl")

    signifColor = ROOT.TColor.GetColor("#00aa00")

    sSig.SetLineColor(signifColor)
    sSig.Draw("samehistl")

    # redraw the axes on top of the curves
    effpurS.Draw("sameaxis")

    legend1 = ROOT.TLegend(c.GetLeftMargin(), 1 - c.GetTopMargin(),
                           c.GetLeftMargin() + 0.4, 1 - c.GetTopMargin() + 0.12)
    legend1.SetFillStyle(1)
    legend1.AddEntry(sigE, "Signal efficiency", "L")
    legend1.AddEntry(bgdE, "Background efficiency", "L")
    legend1.Draw("same")
    legend1.SetBorderSize(1)
    legend1.SetMargin(0.3)

    legend2 = ROOT.TLegend(c.GetLeftMargin() + 0.4, 1 - c.GetTopMargin(),
                           1 - c.GetRightMargin(), 1 - c.GetTopMargin() + 0.12)
    legend2.SetFillStyle(1)
    legend2.AddEntry(purS, "Signal purity", "L")
    legend2.AddEntry(effpurS, "Signal efficiency*purity", "L")
    legend2.AddEntry(sSig, "S/#sqrt{ S+B }", "L")
    legend2.Draw("same")
    legend2.SetBorderSize(1)
    legend2.SetMargin(0.3)

    effline = ROOT.TLine(sSig.GetXaxis().GetXmin(), 1, sSig.GetXaxis().GetXmax(), 1)
    effline.SetLineWidth(1)
    effline.SetLineColor(1)
    effline.Draw()

    c.Update()

    tl = ROOT.TLatex()
    tl.SetNDC()
    tl.SetTextSize(0.033)
    maxbin = sSig.GetMaximumBin()
    tl.DrawLatex(0.15, 0.23, "For %1.0f signal and %1.0f background" % (fNSignal, fNBackground))
    tl.DrawLatex(0.15, 0.19, "events the maximum S/#sqrt{S+B} is")
    # no uncertainty is estimated for the maximum significance, so only the
    # value and the cut position are printed
    tl.DrawLatex(0.15, 0.15, "%4.2f when cutting at %5.2f" % (
        maxSignificance,
        sSig.GetXaxis().GetBinCenter(maxbin)))

    if methodName.find("Cuts") != -1:
        tl.DrawLatex(0.13, 0.77, "Method Cuts provides a bundle of cut selections, each tuned to a")
        tl.DrawLatex(0.13, 0.74, "different signal efficiency. Shown is the purity for each cut selection.")

    # right-hand axis giving the (un-normalised) significance scale
    wx = (sigE.GetXaxis().GetXmax()+abs(sigE.GetXaxis().GetXmin()))*0.135
    rightAxis = ROOT.TGaxis(sigE.GetXaxis().GetXmax()+wx,
                            c.GetUymin()-0.3,
                            sigE.GetXaxis().GetXmax()+wx,
                            0.7, 0, 1.1*maxSignificance, 510, "+L")
    rightAxis.SetLineColor(signifColor)
    rightAxis.SetLabelColor(signifColor)
    rightAxis.SetTitleColor(signifColor)

    rightAxis.SetTitleSize(sSig.GetXaxis().GetTitleSize())
    rightAxis.SetTitle("Significance")
    rightAxis.Draw()

    c.Update()

    JPyInterface.JsDraw.Draw(c)
393 
## Draw neural network
# Reads the method's weight file and renders the network with JsDraw.
# @param fac the factory object pointer
# @param datasetName the dataset name
# @param methodName we want to see the network created by this method
# @return None; nothing is drawn when the method cannot be found
def DrawNeuralNetwork(fac, datasetName, methodName):
    m = GetMethodObject(fac, datasetName, methodName)
    if m is None:
        return None
    weightFile = str(m.GetWeightFileName())
    # DNN weight files have a different XML layout from MLP ones (see
    # GetDeepNetwork vs GetNetwork). Also check the method type, so that a
    # DNN booked under a title other than "DNN" is read correctly.
    if methodName == "DNN" or m.GetMethodType() == TMVA.Types.kDNN:
        net = GetDeepNetwork(weightFile)
    else:
        net = GetNetwork(weightFile)
    JPyInterface.JsDraw.Draw(net, "drawNeuralNetwork", True)
407 
## Draw decision tree
# Shows a tree-index input box and a "Draw" button. Clicking the button clears
# the cell output and draws the selected tree of the method's forest.
# @param fac the factory object pointer
# @param datasetName the dataset name
# @param methodName we want to see the decision trees created by this method
# @return None; nothing is displayed when the method cannot be found
def DrawDecisionTree(fac, datasetName, methodName):
    m = GetMethodObject(fac, datasetName, methodName)
    if m is None:
        return None
    tr = TreeReader(str(m.GetWeightFileName()))

    variables = tr.getVariables()
    maxIndex = tr.getNTrees() - 1

    def clicked(b):
        # Clamp the requested index into the valid range [0, NTrees-1]. The
        # lower bound is applied second so an empty forest still ends at 0.
        if treeSelector.value > maxIndex:
            treeSelector.value = maxIndex
        if treeSelector.value < 0:
            treeSelector.value = 0
        clear_output()
        toJs = {
            "variables": variables,
            "tree": tr.getTree(treeSelector.value)
        }
        json_str = json.dumps(toJs)
        JPyInterface.JsDraw.Draw(json_str, "drawDecisionTree", True)

    mx = str(maxIndex)

    treeSelector = widgets.IntText(value=0, font_weight="bold")
    drawTree = widgets.Button(description="Draw", font_weight="bold")
    label = widgets.HTML("<div style='padding: 6px;font-weight:bold;color:#333;'>Decision Tree [0-"+mx+"]:</div>")

    drawTree.on_click(clicked)
    container = widgets.HBox([label, treeSelector, drawTree])
    display(container)
440 
## Rewrite function for TMVA::Factory::TrainAllMethods. This function provides interactive training.
# @param fac the factory object pointer
#
# Each booked method is trained by MethodBase::TrainMethod in a background
# thread, while this thread polls it every 0.5 s and updates a progress bar.
# For methods listed in wait_times it also updates the live training/testing
# error plot. When a "Stop" button is shown, clicking it interrupts the kernel;
# the resulting KeyboardInterrupt is turned into m.ExitFromTraining().
def ChangeTrainAllMethods(fac):
    clear_output()
    # stop button: interrupts the kernel from the browser side
    button = """
    <script type="text/javascript">
    require(["jquery"], function(jQ){
        jQ("input.stopTrainingButton").on("click", function(){
            IPython.notebook.kernel.interrupt();
            jQ(this).css({
                "background-color": "rgba(200, 0, 0, 0.8)",
                "color": "#fff",
                "box-shadow": "0 3px 5px rgba(0, 0, 0, 0.3)",
            });
        });
    });
    </script>
    <style type="text/css">
    input.stopTrainingButton {
        background-color: #fff;
        border: 1px solid #ccc;
        width: 100%;
        font-size: 16px;
        font-weight: bold;
        padding: 6px 12px;
        cursor: pointer;
        border-radius: 6px;
        color: #333;
    }
    input.stopTrainingButton:hover {
        background-color: rgba(204, 204, 204, 0.4);
    }
    </style>
    <input type="button" value="Stop" class="stopTrainingButton" />
    """
    # progress bar increment: sets the bar width and label, then removes its
    # own output cell
    inc = Template("""
    <script type="text/javascript" id="progressBarScriptInc">
    require(["jquery"], function(jQ){
        jQ("#jsmva_bar_$id").css("width", $progress + "%");
        jQ("#jsmva_label_$id").text($progress + '%');
        jQ("#progressBarScriptInc").parent().parent().remove();
    });
    </script>
    """)
    progress_bar = Template("""
    <style>
    #jsmva_progress_$id {
        position: relative;
        float: left;
        height: 30px;
        width: 100%;
        background-color: #f5f5f5;
        border-radius: 3px;
        box-shadow: inset 0 3px 6px rgba(0, 0, 0, 0.1);
    }
    #jsmva_bar_$id {
        position: absolute;
        width: 1%;
        height: 100%;
        background-color: #337ab7;
    }
    #jsmva_label_$id {
        text-align: center;
        line-height: 30px;
        color: white;
    }
    </style>
    <div id="jsmva_progress_$id">
        <div id="jsmva_bar_$id">
            <div id="jsmva_label_$id">0%</div>
        </div>
    </div>
    """)
    progress_bar_idx = 0

    def exit_supported(mn):
        # methods whose training can be left early via ExitFromTraining()
        name = str(mn)
        return any(name.find(e) != -1 for e in ["SVM", "Cuts", "Boost", "BDT"])

    def show_progress(method, idx):
        # push the current iteration of `method` into progress bar `idx`
        # (methods reporting GetMaxIter() == 0 have no progress bar)
        if method.GetMaxIter() != 0:
            display(HTML(inc.substitute({
                "id": idx,
                "progress": 100 * method.GetCurrentIter() / method.GetMaxIter()
            })))

    def poll_until_done(method, idx):
        # refresh the progress bar every 0.5 s until training has finished
        while not method.TrainingEnded():
            show_progress(method, idx)
            time.sleep(0.5)

    # seconds to wait after starting the training thread before the error
    # plot is first drawn, for methods that provide live error curves
    wait_times = {"MLP": 0.5, "DNN": 1, "BDT": 0.5}

    for methodMapElement in fac.fMethodsMap:
        display(HTML("<center><h1>Dataset: "+str(methodMapElement[0])+"</h1></center>"))
        for m in methodMapElement[1]:
            # NOTE(review): presumably lets PyROOT release the GIL during
            # GetName() while training threads run — confirm
            m.GetName._threaded = True
            name = str(m.GetName())
            display(HTML("<h2><b>Train method: "+name+"</b></h2>"))
            m.InitIPythonInteractive()
            t = Thread(target=ROOT.TMVA.MethodBase.TrainMethod, args=[m])
            t.start()
            if name in wait_times:
                display(HTML(button))
                time.sleep(wait_times[name])
                if m.GetMaxIter() != 0:
                    display(HTML(progress_bar.substitute({"id": progress_bar_idx})))
                    show_progress(m, progress_bar_idx)
                JPyInterface.JsDraw.Draw(m.GetInteractiveTrainingError(), "drawTrainingTestingErrors")
                try:
                    while not m.TrainingEnded():
                        JPyInterface.JsDraw.InsertData(m.GetInteractiveTrainingError())
                        show_progress(m, progress_bar_idx)
                        time.sleep(0.5)
                except KeyboardInterrupt:
                    m.ExitFromTraining()
            else:
                stoppable = exit_supported(name)
                if stoppable:
                    display(HTML(button))
                time.sleep(0.5)
                if m.GetMaxIter() != 0:
                    display(HTML(progress_bar.substitute({"id": progress_bar_idx})))
                    show_progress(m, progress_bar_idx)
                else:
                    display(HTML("<b>Training...</b>"))
                if stoppable:
                    try:
                        poll_until_done(m, progress_bar_idx)
                    except KeyboardInterrupt:
                        m.ExitFromTraining()
                else:
                    poll_until_done(m, progress_bar_idx)
            if m.GetMaxIter() != 0:
                show_progress(m, progress_bar_idx)
            else:
                display(HTML("<b>End</b>"))
            progress_bar_idx += 1
            t.join()
    return
593 
## Rewrite the constructor of TMVA::Factory
# Converts the JobName and TargetFile keyword arguments with
# ConvertSpecKwargsToArgs. If that raises AttributeError, it retries with
# JobName alone. The result is then passed to ProcessParameters and the
# returned original function is called.
def ChangeCallOriginal__init__(*args, **kwargs):
    try:
        args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs(["JobName", "TargetFile"], *args, **kwargs)
    except AttributeError:
        # If this retry fails too, its AttributeError propagates unchanged,
        # keeping the converter's message and traceback. The former
        # `raise AttributeError` replaced it with an empty error.
        args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs(["JobName"], *args, **kwargs)
    originalFunction, args = JPyInterface.functions.ProcessParameters(3, *args, **kwargs)
    return originalFunction(*args)
605 
## Rewrite TMVA::Factory::BookMethod
# Extra keyword arguments handled here (both removed before the generic
# conversion):
#  - Composite: when given and not False, appended as an extra positional argument
#  - CompositeOptions: dict converted with ProcessParameters(-10, ...); its first
#    resulting option string is appended after Composite
def ChangeCallOriginalBookMethod(*args, **kwargs):
    composite = kwargs.pop("Composite", False)
    compositeOpts = kwargs.pop("CompositeOptions", False)
    args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs(["DataLoader", "Method", "MethodTitle"], *args, **kwargs)
    originalFunction, args = JPyInterface.functions.ProcessParameters(4, *args, **kwargs)
    trailing = []
    if composite != False:
        trailing.append(composite)
    if compositeOpts != False:
        _, compositeOptStr = JPyInterface.functions.ProcessParameters(-10, **compositeOpts)
        trailing.append(compositeOptStr[0])
    return originalFunction(*(tuple(args) + tuple(trailing)))
628 
## Rewrite TMVA::Factory::EvaluateImportance
# Without keyword arguments, the call goes through ProcessParameters(0, ...)
# and the original is called directly, without drawing. Otherwise the
# DataLoader/VIType/Method/MethodTitle keywords are converted, the original is
# called, and the returned histogram is drawn with JsDraw and returned.
def ChangeCallOriginalEvaluateImportance(*args, **kwargs):
    if not kwargs:
        originalFunction, args = JPyInterface.functions.ProcessParameters(0, *args, **kwargs)
        return originalFunction(*args)
    args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs(["DataLoader", "VIType", "Method", "MethodTitle"], *args, **kwargs)
    originalFunction, args = JPyInterface.functions.ProcessParameters(5, *args, **kwargs)
    importance = originalFunction(*args)
    JPyInterface.JsDraw.Draw(importance)
    return importance
639 
## Rewrite TMVA::Factory::CrossValidate
# Without keyword arguments, the call goes through ProcessParameters(0, ...)
# and the original is called directly. Otherwise the optional keywords below
# are removed and appended positionally, in this order, after the converted
# DataLoader/Method/MethodTitle arguments:
#  - optParams (default False)
#  - NumFolds (default 5)
#  - remakeDataSet (default True)
#  - rocIntegrals (appended only when it is neither None nor 0)
def ChangeCallOriginalCrossValidate(*args, **kwargs):
    if not kwargs:
        originalFunction, args = JPyInterface.functions.ProcessParameters(0, *args, **kwargs)
        return originalFunction(*args)
    optParams = kwargs.pop("optParams", False)
    NumFolds = kwargs.pop("NumFolds", 5)
    remakeDataSet = kwargs.pop("remakeDataSet", True)
    rocIntegrals = kwargs.pop("rocIntegrals", None)
    args, kwargs = JPyInterface.functions.ConvertSpecKwargsToArgs(["DataLoader", "Method", "MethodTitle"], *args, **kwargs)
    originalFunction, args = JPyInterface.functions.ProcessParameters(4, *args, **kwargs)
    callArgs = list(args) + [optParams, NumFolds, remakeDataSet]
    # equality (not identity) comparisons are kept, so comparison overloads
    # on the passed object still apply
    if rocIntegrals != None and rocIntegrals != 0:
        callArgs.append(rocIntegrals)
    return originalFunction(*tuple(callArgs))
671 
## Background booking method for BookDNN
# Holds the booking callback created by the most recent BookDNN call
# (presumably invoked later from the NetworkDesigner front end — confirm).
__BookDNNHelper = None

## Graphical interface for booking DNN
# @param self the factory object
# @param loader the DataLoader passed on to BookMethod
# @param title method title passed on to BookMethod (default "DNN")
def BookDNN(self, loader, title="DNN"):
    global __BookDNNHelper

    def __bookDNN(optString):
        # book a kDNN method on this factory with the given option string
        self.BookMethod(loader, ROOT.TMVA.Types.kDNN, title, optString)

    __BookDNNHelper = __bookDNN
    JPyInterface.JsDraw.InsertCSS("NetworkDesigner.css")
    JPyInterface.JsDraw.Draw("", "NetworkDesigner", True)
def __getBinaryTree(self, itree)
Definition: Factory.py:124
def __init__(self, fileName)
Standard Constructor.
Definition: Factory.py:111
def GetNetwork(xml_file)
Reads neural network weights from file and returns it in JSON format.
Definition: Factory.py:74
def ChangeCallOriginalEvaluateImportance(args, kwargs)
Rewrite TMVA::Factory::EvaluateImportance.
Definition: Factory.py:630
def GetMethodObject(fac, datasetName, methodName)
Getting method object from factory.
Definition: Factory.py:23
def DrawProbabilityDistribution(fac, datasetName, methodName)
Draw output probability distribution.
Definition: Factory.py:211
def getTree(self, itree)
Public function which returns the specified tree object.
Definition: Factory.py:166
def DrawROCCurve(fac, datasetName)
Draw ROC curve.
Definition: Factory.py:187
def DrawDecisionTree(fac, datasetName, methodName)
Draw decision tree.
Definition: Factory.py:412
def DrawCutEfficiencies(fac, datasetName, methodName)
Draw cut efficiencies.
Definition: Factory.py:227
def DrawNeuralNetwork(fac, datasetName, methodName)
Draw neural network.
Definition: Factory.py:398
def ChangeCallOriginalCrossValidate(args, kwargs)
Rewrite TMVA::Factory::CrossValidate.
Definition: Factory.py:641
def BookDNN(self, loader, title="DNN")
Graphical interface for booking DNN.
Definition: Factory.py:676
def getNTrees(self)
Returns the number of trees.
Definition: Factory.py:118
def __readTree(self, binaryTree, tree={}, depth=0)
Reads the tree.
Definition: Factory.py:135
def GetDeepNetwork(xml_file)
Reads deep neural network weights from file and returns it in JSON format.
Definition: Factory.py:41
def DrawOutputDistribution(fac, datasetName, methodName)
Draw output distributions.
Definition: Factory.py:195
Helper class (TreeReader) for reading decision trees from an XML weight file.
Definition: Factory.py:108
def ChangeCallOriginal__init__(args, kwargs)
Rewrite the constructor of TMVA::Factory.
Definition: Factory.py:595
def ChangeTrainAllMethods(fac)
Rewrite function for TMVA::Factory::TrainAllMethods.
Definition: Factory.py:443
def getVariables(self)
Returns a list with input variable names.
Definition: Factory.py:176
def ChangeCallOriginalBookMethod(args, kwargs)
Rewrite TMVA::Factory::BookMethod.
Definition: Factory.py:607