[go: up one dir, main page]

Skip to content

Commit

Permalink
optimize with asm
Browse files Browse the repository at this point in the history
  • Loading branch information
chenqy4933 committed Aug 4, 2023
1 parent 239d4d7 commit 16a0265
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 174 deletions.
5 changes: 5 additions & 0 deletions application/chatglm/chatglm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ int main(int argc, char** argv) {
std::string user_input, output;

int iter = 0;
int token_id = 0;
//! main loop
while (model->get_remain_token() > 0) {
if (!user_input.empty()) {
Expand All @@ -226,8 +227,12 @@ int main(int argc, char** argv) {
auto o = model->decode_iter(token);
fix_word(o);
output += o;
token_id++;
printf("%s", output.c_str());
fflush(stdout);
if (token_id % 10 == 0) {
running_summary = model->decode_summary();
}

// token 2 is the end of the instruction
if (token == etoken) {
Expand Down
12 changes: 6 additions & 6 deletions src/core/model_imp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ void ModelImp::prefill(const std::string& promote) {
m_last_queue.push_back(token);
m_last_queue.pop_front();
}
auto start = m_timer.get_time();
//auto start = m_timer.get_time();
m_graph->execute(tokens, m_logist, m_past, true);
auto end = m_timer.get_time();
m_time_cost += end - start;
//auto end = m_timer.get_time();
//m_time_cost += end - start;
m_past = tokens.size();
}

Expand All @@ -41,10 +41,10 @@ std::string ModelImp::decode(const std::string& user_input, int& token) {
m_last_queue.push_back(token);
m_last_queue.pop_front();
}
auto start = m_timer.get_time();
//auto start = m_timer.get_time();
m_graph->execute(tokens, m_logist, m_past, false);
auto end = m_timer.get_time();
m_time_cost += end - start;
//auto end = m_timer.get_time();
//m_time_cost += end - start;
sample_and_update();
m_past += tokens.size();
token = m_pre_token;
Expand Down
2 changes: 1 addition & 1 deletion src/kern/kernel_define.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ struct BlockQ40 {
static_assert(sizeof(BlockQ40) == 20, "BlockQ40 size error");

struct BlockQ40X8 {
float scale[8]; // delta
uint8_t qs[QK40 / 2 * 8]; // nibbles / quants
float scale[8]; // delta
};
static_assert(sizeof(BlockQ40X8) == 160, "BlockQ40X8 size error");

Expand Down
3 changes: 2 additions & 1 deletion src/kern/optimized/arm/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ TaskSet llm_matmul_compute_int4_float_packed(
int8_t* src = q_src + m * weight_q80_stride;
float* dst_ptr = dst + m * N + n * 8;
const float* bias_ptr = bias ? bias + n * 8 : nullptr;
vec_vec_dot_q40_with_q80_packed(K, q_weight, src, dst_ptr, bias_ptr);
//vec_vec_dot_q40_with_q80_packed(K, q_weight, src, dst_ptr, bias_ptr);
vec_vec_dot_q40_with_q80_packed_asm(K, q_weight, src, dst_ptr, bias_ptr);
}
}
};
Expand Down
Loading

0 comments on commit 16a0265

Please sign in to comment.