Fix mismatched case and unhandled exception in open_drm_fd_for_cuda_device()

This commit is contained in:
Cameron Gutman
2024-03-04 18:43:16 -06:00
parent cacadc4df4
commit 9f94eebd32

View File

@@ -247,13 +247,18 @@ namespace cuda {
// There's no way to directly go from CUDA to a DRM device, so we'll // There's no way to directly go from CUDA to a DRM device, so we'll
// use sysfs to look up the DRM device name from the PCI ID. // use sysfs to look up the DRM device name from the PCI ID.
char pci_bus_id[13]; std::array<char, 13> pci_bus_id;
CU_CHECK(cdf->cuDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), device), "Couldn't get CUDA device PCI bus ID"); CU_CHECK(cdf->cuDeviceGetPCIBusId(pci_bus_id.data(), pci_bus_id.size(), device), "Couldn't get CUDA device PCI bus ID");
BOOST_LOG(debug) << "Found CUDA device with PCI bus ID: "sv << pci_bus_id; BOOST_LOG(debug) << "Found CUDA device with PCI bus ID: "sv << pci_bus_id.data();
// Linux uses lowercase hexadecimal while CUDA uses uppercase
std::transform(pci_bus_id.begin(), pci_bus_id.end(), pci_bus_id.begin(),
[](char c) { return std::tolower(c); });
// Look for the name of the primary node in sysfs // Look for the name of the primary node in sysfs
try {
char sysfs_path[PATH_MAX]; char sysfs_path[PATH_MAX];
std::snprintf(sysfs_path, sizeof(sysfs_path), "/sys/bus/pci/devices/%s/drm", pci_bus_id); std::snprintf(sysfs_path, sizeof(sysfs_path), "/sys/bus/pci/devices/%s/drm", pci_bus_id.data());
fs::path sysfs_dir { sysfs_path }; fs::path sysfs_dir { sysfs_path };
for (auto &entry : fs::directory_iterator { sysfs_dir }) { for (auto &entry : fs::directory_iterator { sysfs_dir }) {
auto file = entry.path().filename(); auto file = entry.path().filename();
@@ -268,8 +273,12 @@ namespace cuda {
auto device_path = dri_path / file; auto device_path = dri_path / file;
return open(device_path.c_str(), O_RDWR); return open(device_path.c_str(), O_RDWR);
} }
}
catch (const std::filesystem::filesystem_error &err) {
BOOST_LOG(error) << "Failed to read sysfs: "sv << err.what();
}
BOOST_LOG(error) << "Unable to find DRM device with PCI bus ID: "sv << pci_bus_id; BOOST_LOG(error) << "Unable to find DRM device with PCI bus ID: "sv << pci_bus_id.data();
return -1; return -1;
} }