"""This module contains all NVIDIA GPU related metrics functions"""
import logging
from py3nvml.py3nvml import *
_log = logging.getLogger(__name__)
def device_query(func, *args):
    """Convenience wrapper to query different metrics for NVIDIA GPUs.

    Initialises NVML, calls ``func(handle, *args)`` once per visible GPU
    (in device-index order) and shuts NVML down again.

    Args:
        func (str or callable): Name of an NVML API function available in
            this module's namespace (e.g. one brought in by the ``py3nvml``
            star import), or the callable itself.
        *args: Extra positional arguments passed to ``func`` after the
            device handle.

    Returns:
        list: One metric value per GPU. Returns an empty list if any NVML
        call (init, device enumeration, the query itself or shutdown)
        raises ``NVMLError``.

    Raises:
        KeyError: If ``func`` is a string that names nothing in this
            module's namespace.
    """
    # Resolve the query function up front so a bad name fails fast, before
    # NVML is touched and even when no GPUs are present.
    query = func if callable(func) else globals()[func]

    try:
        nvmlInit()
    except NVMLError as err:
        # Best effort: no driver / no GPU is an expected condition here.
        _log.debug("Could not initialise NVML: %s", err)
        return []

    try:
        num_gpus = nvmlDeviceGetCount()
        metric_values = [
            query(nvmlDeviceGetHandleByIndex(index), *args)
            for index in range(num_gpus)
        ]
    except NVMLError as err:
        _log.debug("NVML query %r failed: %s", func, err)
        return []
    finally:
        # Always pair the successful nvmlInit() with nvmlShutdown(), even
        # when the query raised (NVML or otherwise), so NVML is not left
        # initialised.
        try:
            nvmlShutdown()
        except NVMLError as err:
            _log.debug("NVML shutdown failed: %s", err)
            # Preserve the original contract: any NVML failure yields [].
            metric_values = []
    return metric_values