Commit fdd6fad: "gpu inference is supported"
chenqy4933 committed Jul 20, 2023 (2 parents: b8f7cc3 + fb67873)
Showing 11 changed files with 262 additions and 110 deletions.
application/alpaca/alpaca.cpp (7 additions, 0 deletions)
@@ -35,6 +35,7 @@ struct app_params {
     bool use_color = true;          // use color to distinguish generations and inputs
     bool use_mmap = false;          // use mmap to load model
     std::string dtype = "float32";  // configure the compute dtype
+    std::string device = "CPU";     // configure the compute device type
     std::string mtype = "llama";    // the model type name, llama
 };

@@ -76,6 +77,9 @@ void app_print_usage(int argc, char** argv, const app_params& params) {
     fprintf(stderr,
             "  -d type            configure the compute type, default float32, can "
             "be float32 and float16 now.\n");
+    fprintf(stderr,
+            "  -g type            configure the compute device type, default CPU, "
+            "can be CPU and GPU now.\n");
     fprintf(stderr,
             "  --model_type type  the model type name, default llama, can only be "
             "llama now.\n");
@@ -95,6 +99,8 @@ bool app_params_parse(int argc, char** argv, app_params& params) {
             params.n_ctx = std::stoi(argv[++i]);
         } else if (arg == "-d" || arg == "--dtype") {
            params.dtype = argv[++i];
+        } else if (arg == "-g") {
+            params.device = argv[++i];
         } else if (arg == "--top_p") {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
@@ -149,6 +155,7 @@ int main(int argc, char** argv) {
     int64_t t_load_us = 0;
     inferllm::ModelConfig config;
     config.compt_type = params.dtype;
+    config.device_type = params.device;
     config.nr_thread = params.n_threads;
     config.enable_mmap = params.use_mmap;
     config.nr_ctx = params.n_ctx;
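Across all three applications the change is the same flag-to-config plumbing; chat.cpp and chatglm.cpp below repeat the pattern shown here. For reference, a condensed, self-contained sketch of that pattern (a hypothetical stand-alone main for illustration; the real parsers handle many more options and, like them, this sketch omits bounds-checking on argv[++i]):

    #include <cstdio>
    #include <string>

    // Minimal stand-in for the parsed options; the real app_params has many more fields.
    struct app_params {
        std::string dtype = "float32";  // compute dtype, set via -d/--dtype
        std::string device = "CPU";     // compute device, set via the new -g flag
    };

    int main(int argc, char** argv) {
        app_params params;
        for (int i = 1; i < argc; ++i) {
            std::string arg = argv[i];
            if (arg == "-d" || arg == "--dtype") {
                params.dtype = argv[++i];
            } else if (arg == "-g") {
                params.device = argv[++i];  // expected values: CPU or GPU
            }
        }
        printf("dtype=%s device=%s\n", params.dtype.c_str(), params.device.c_str());
        return 0;
    }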
application/chat.cpp (5 additions, 0 deletions)
@@ -34,6 +34,7 @@ struct app_params {
 
     bool use_mmap = false;          // use mmap to load model
     std::string dtype = "float32";  // configure the compute dtype
+    std::string device = "CPU";     // configure the compute device type
     std::string mtype = "llama";    // the model type name, llama
 };

@@ -56,6 +57,7 @@ void app_print_usage(int argc, char** argv, const app_params& params) {
     fprintf(stderr, "  model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "  --mmap             enable mmap when reading weights, default = false\n");
     fprintf(stderr, "  -d type            configure the compute type, default float32, can be float32 and float16 now.\n");
+    fprintf(stderr, "  -g type            configure the compute device type, default CPU, can be CPU and GPU now.\n");
     fprintf(stderr, "  --model_type type  the model type name, default llama, can only be llama now.\n");
     fprintf(stderr, "\n");
     // clang-format on
@@ -74,6 +76,8 @@ bool app_params_parse(int argc, char** argv, app_params& params) {
             params.n_ctx = std::stoi(argv[++i]);
         } else if (arg == "-d" || arg == "--dtype") {
             params.dtype = argv[++i];
+        } else if (arg == "-g") {
+            params.device = argv[++i];
         } else if (arg == "--top_p") {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
@@ -128,6 +132,7 @@ int main(int argc, char** argv) {
     int64_t t_load_us = 0;
     inferllm::ModelConfig config;
     config.compt_type = params.dtype;
+    config.device_type = params.device;
     config.nr_thread = params.n_threads;
     config.enable_mmap = params.use_mmap;
     config.nr_ctx = params.n_ctx;
application/chatglm/chatglm.cpp (5 additions, 0 deletions)
@@ -36,6 +36,7 @@ struct app_params {
     bool use_color = true;          // use color to distinguish generations and inputs
     bool use_mmap = false;          // use mmap to load model
     std::string dtype = "float32";  // configure the compute dtype
+    std::string device = "cpu";     // configure the compute device type
     int32_t version = 1;            // the model version
 };

@@ -58,6 +59,7 @@ void app_print_usage(int argc, char** argv, const app_params& params) {
     fprintf(stderr, "  model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "  --mmap   enable mmap when reading weights, default = false\n");
     fprintf(stderr, "  -d type  configure the compute type, default float32, can be float32 and float16 now.\n");
+    fprintf(stderr, "  -g type  configure the compute device type, default CPU, can be CPU and GPU now.\n");
     fprintf(stderr, "\n");
     // clang-format on
 }
@@ -75,6 +77,8 @@ bool app_params_parse(int argc, char** argv, app_params& params) {
             params.n_ctx = std::stoi(argv[++i]);
         } else if (arg == "-d" || arg == "--dtype") {
             params.dtype = argv[++i];
+        } else if (arg == "-g") {
+            params.device = argv[++i];
         } else if (arg == "--top_p") {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
@@ -131,6 +135,7 @@ int main(int argc, char** argv) {
     int64_t t_load_us = 0;
     inferllm::ModelConfig config;
     config.compt_type = params.dtype;
+    config.device_type = params.device;
     config.nr_thread = params.n_threads;
     config.enable_mmap = params.use_mmap;
     config.nr_ctx = params.n_ctx;
include/model.h (3 additions, 1 deletion)
@@ -13,7 +13,9 @@ namespace inferllm {
 
 struct ModelConfig {
     //! dtype include 'float32','float16','int8','int4'
-    std::string compt_type;
+    std::string compt_type = "float32";
+    //! device_type include 'cpu','gpu'
+    std::string device_type = "cpu";
     uint32_t nr_thread;
     uint32_t nr_ctx;
     int32_t device_id;
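With defaults now supplied in the struct, a caller only needs to override the fields it cares about. A minimal sketch of building a config that requests GPU execution, mirroring the main() snippets above (the numeric values are illustrative only, and the follow-on model-construction call is omitted because it is outside this diff):

    #include "model.h"  // inferllm::ModelConfig

    inferllm::ModelConfig make_gpu_config() {
        inferllm::ModelConfig config;
        config.compt_type = "float32";  // one of 'float32','float16','int8','int4'
        config.device_type = "gpu";     // new in this commit: 'cpu' or 'gpu'
        config.nr_thread = 4;           // illustrative thread count
        config.nr_ctx = 2048;           // illustrative context length
        config.device_id = 0;           // ModelImp currently hard-codes GPUDevice(0) anyway
        config.enable_mmap = false;
        return config;
    }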
src/core/model_imp.h (15 additions, 7 deletions)
@@ -39,16 +39,24 @@ class ModelImp {
     ModelImp(const ModelConfig& config, const std::string& name)
             : m_name(name), m_config(config) {
         uint32_t nr_thread = config.nr_thread;
-        // if compile with GPU, use GPU, else use CPUDevice
-#if ENABLE_GPU
-        m_device = make_unique<GPUDevice>(0);
-#elif INFER_X86
-        m_device = make_unique<CPUDevice>(KernelType::X86, nr_thread);
+        std::string device_type = config.device_type;
+        if (device_type == "CPU" || device_type == "cpu") {
+#if INFER_X86
+            m_device = make_unique<CPUDevice>(KernelType::X86, nr_thread);
 #elif INFER_ARM
-        m_device = make_unique<CPUDevice>(KernelType::Arm, nr_thread);
+            m_device = make_unique<CPUDevice>(KernelType::Arm, nr_thread);
 #else
-        m_device = make_unique<CPUDevice>(KernelType::Naive, nr_thread);
+            m_device = make_unique<CPUDevice>(KernelType::Naive, nr_thread);
 #endif
+        } else if (
+                device_type == "GPU" || device_type == "CUDA" || device_type == "gpu") {
+            // if compiled with GPU support, use GPUDevice; otherwise fail loudly
+#if ENABLE_GPU
+            m_device = make_unique<GPUDevice>(0);
+#else
+            INFER_ASSERT(0, "GPU support is disabled in this build, please rebuild with GPU enabled.");
+#endif
+        }
 
         UserConfig user_config;
         user_config.compt_type = dtype_from_str(config.compt_type);
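Device selection is now two-level: the device string is dispatched at run time, while backend availability is still resolved at compile time (INFER_X86/INFER_ARM for the CPU kernels, ENABLE_GPU for CUDA). One side effect is that a string matching neither branch, e.g. "TPU", falls through both checks and leaves m_device unset. A sketch of a normalization helper that would reject such values up front (hypothetical; not part of this commit):

    #include <algorithm>
    #include <cctype>
    #include <stdexcept>
    #include <string>

    // Hypothetical helper: canonicalize the device string once so the dispatch
    // can compare a single lower-case spelling and reject unknown names early.
    inline std::string canonical_device(std::string name) {
        std::transform(name.begin(), name.end(), name.begin(),
                       [](unsigned char c) { return std::tolower(c); });
        if (name == "cuda") name = "gpu";  // the commit already accepts CUDA as an alias
        if (name != "cpu" && name != "gpu")
            throw std::invalid_argument("unknown device type: " + name);
        return name;
    }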
(Diffs for the remaining 6 changed files are not shown.)