To protect your data, the CISO officer has suggested users to enable GitLab 2FA as soon as possible.

Commit c8367bdc authored by Zixian Cai's avatar Zixian Cai
Browse files

Support specifying pipelines to go through

parent b718f030
......@@ -24,22 +24,18 @@ logger = logging.getLogger(__name__)
@click.command()
@click.argument('file', type=click.Path(exists=True))
@click.argument('pipeline', default="log,boxplot,barplot")
@click.option('--skip_compile', default='none',
help="a list of tasks to skip compilation." +
"Can be 'all', 'none', or a string in the form of 'taskset1:task1,task2;taskset2'. " +
"If only a task set's name is defined, skip compilation of all its tasks. " +
"Default target in the form of 'taskset-task' is assumed to be found under output directory.")
@click.option('--comp_remote', is_flag=True, default=False)
def local(file, skip_compile, comp_remote):
def local(file, pipeline, skip_compile, comp_remote):
logger.info("Constructing a LocalRevision")
revision = LocalRevision(file)
logger.info("Running tasks specified in file")
revision.run_tasksets(skip_compile)
logger.info("Generating report, compare remote?: {}".format(comp_remote))
report = revision.generate_report()
report_pipelines = {
"mubench.models.pipeline.LogOutputPipeline": 42,
"mubench.models.pipeline.BarplotPipeline": 100,
"mubench.models.pipeline.BoxplotPipeline": 101
}
go_through_pipelines(report, report_pipelines)
go_through_pipelines(report, pipeline_names=pipeline.split(","))
......@@ -14,19 +14,12 @@
# limitations under the License.
import logging
from mubench.models.pipeline import pipelines
logger = logging.getLogger(__name__)
def go_through_pipelines(report, pipelines):
def import_name(name):
import importlib
mod, obj = name.rsplit('.', 1)
return getattr(importlib.import_module(mod), obj)
pipelines_cls = sorted(pipelines.items(), key=lambda x: x[1])
logger.info("Going through pipelines to process the report")
logger.info(pipelines_cls)
pipelines = [import_name(c[0])() for c in pipelines_cls]
for pipeline in pipelines:
def go_through_pipelines(report, pipeline_names):
ps = [pipelines[name]() for name in pipeline_names]
for pipeline in ps:
report = pipeline.process(report)
......@@ -143,3 +143,10 @@ class BoxplotPipeline(Pipeline):
box.set_facecolor("lightblue")
plt.show()
return report
pipelines = {
"boxplot": BoxplotPipeline,
"barplot": BarplotPipeline,
"log": LogOutputPipeline
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment