From be0fa0cd7e00484a02cf584795c2de79d8f5a6d9 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Tue, 16 Jun 2020 17:52:56 -0700 Subject: [PATCH] Added support for global threadpool --- CHANGELOG.md | 4 ++++ lib/onnxruntime/ffi.rb | 8 ++++---- lib/onnxruntime/inference_session.rb | 6 +++++- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dff935a..1caa63a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.4.0 (unreleased) + +- Added support for global threadpool + ## 0.3.2 (2020-06-16) - Fixed error with FFI 1.13.0+ diff --git a/lib/onnxruntime/ffi.rb b/lib/onnxruntime/ffi.rb index 470cc6a..f1b664b 100644 --- a/lib/onnxruntime/ffi.rb +++ b/lib/onnxruntime/ffi.rb @@ -137,10 +137,10 @@ class Api < ::FFI::Struct :ModelMetadataLookupCustomMetadataMap, callback(%i[pointer pointer pointer pointer], :pointer), :ModelMetadataGetVersion, callback(%i[pointer pointer], :pointer), :ReleaseModelMetadata, callback(%i[pointer], :void), - :CreateEnvWithGlobalThreadPools, callback(%i[], :pointer), - :DisablePerSessionThreads, callback(%i[], :pointer), - :CreateThreadingOptions, callback(%i[], :pointer), - :ReleaseThreadingOptions, callback(%i[], :pointer), + :CreateEnvWithGlobalThreadPools, callback(%i[int string pointer pointer], :pointer), + :DisablePerSessionThreads, callback(%i[pointer], :pointer), + :CreateThreadingOptions, callback(%i[pointer], :pointer), + :ReleaseThreadingOptions, callback(%i[pointer], :pointer), :ModelMetadataGetCustomMetadataMapKeys, callback(%i[pointer pointer pointer pointer], :pointer), :AddFreeDimensionOverrideByName, callback(%i[], :pointer) end diff --git a/lib/onnxruntime/inference_session.rb b/lib/onnxruntime/inference_session.rb index 12f5277..5b13a3d 100644 --- a/lib/onnxruntime/inference_session.rb +++ b/lib/onnxruntime/inference_session.rb @@ -27,6 +27,7 @@ def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: tr check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath + check_status api[:DisablePerSessionThreads].call(session_options.read_pointer) # session @session = ::FFI::MemoryPointer.new(:pointer) @@ -414,8 +415,11 @@ def env # use mutex for thread-safety Utils.mutex.synchronize do @@env ||= begin + threading_options = ::FFI::MemoryPointer.new(:pointer) + check_status api[:CreateThreadingOptions].call(threading_options) + env = ::FFI::MemoryPointer.new(:pointer) - check_status api[:CreateEnv].call(3, "Default", env) + check_status api[:CreateEnvWithGlobalThreadPools].call(3, "Default", threading_options.read_pointer, env) at_exit { release :Env, env } # disable telemetry # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md