Running the Llama 3.1-8B-Instruct model on a local CPU with 4 GB of RAM, without quantization, by loading and running the LLaMA layers from disk one at a time.

Posted by Lord_Momus@reddit | programming

I am trying to run the Llama 3.1 8B Instruct model (https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) on a laptop with 4 GB of RAM. The idea is to load and run one layer at a time.
I have a class that initializes the key components of the LLaMA architecture:
LlamaTokenEmbed: Handles token embeddings.
LlamaLayer: Represents a transformer block.
LlamaFinalLayerNorm: Normalizes the output before final predictions.
LlamaFinalLayerHead: Generates final token probabilities.

Running inference (the run method):
It passes the tokens through the embedding layer.
Then it iterates over the 32 transformer layers (LlamaLayer), loading the corresponding layer's weights from disk and running the layer on the input tensor x.
After all layers are processed, the final normalization and output head compute the model output.
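
The class below loads everything from a separated_weights/ folder, so the original checkpoint first has to be split into one safetensors file per transformer layer, plus embed_tokens, norm, and lm_head. Roughly how that can be done (a sketch, not my exact script: it keeps the stock Hugging Face key names, and whatever renaming my modules need is left out):

import json
import os
from collections import defaultdict

from safetensors import safe_open
from safetensors.torch import save_file

def split_checkpoint(model_dir, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    # The sharded HF checkpoint ships an index mapping tensor name -> shard file.
    with open(os.path.join(model_dir, "model.safetensors.index.json")) as f:
        weight_map = json.load(f)["weight_map"]

    def target(name):
        # Decide which output file a tensor belongs to.
        if name.startswith("model.layers."):
            return "layers" + name.split(".")[2]  # layers0 ... layers31
        if name.startswith("model.embed_tokens."):
            return "embed_tokens"
        if name.startswith("model.norm."):
            return "norm"
        return "lm_head"

    groups = defaultdict(list)
    for name, shard in weight_map.items():
        groups[target(name)].append((name, shard))

    # Write one group at a time, so peak memory stays around one layer's worth of weights.
    for group, entries in groups.items():
        tensors = {}
        for name, shard in entries:
            with safe_open(os.path.join(model_dir, shard), framework="pt", device="cpu") as f:
                tensors[name] = f.get_tensor(name)
        save_file(tensors, os.path.join(out_dir, group + ".safetensors"))
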
Here's the code

    
import time

import torch
from safetensors.torch import load_file

# LlamaTokenEmbed, LlamaLayer, LlamaFinalLayerNorm, LlamaFinalLayerHead and
# precompute_theta_pos_frequencies are my own modules/functions, defined elsewhere.

class LlamaCpuDiskRun:
    def __init__(self, config):
        self.config = config
        # Precompute the RoPE frequencies once; they are shared by every layer.
        self.freqs_complex = precompute_theta_pos_frequencies(self.config.dim // self.config.n_heads, self.config.max_position_embeddings * 2, device=self.config.device)
        self.llamatoken = LlamaTokenEmbed(self.config)
        # A single LlamaLayer instance is reused; each layer's weights are swapped in from disk.
        self.llamalayer = LlamaLayer(self.config, self.freqs_complex)
        self.llamafinalnorm = LlamaFinalLayerNorm(self.config)
        self.llamafinallmhead = LlamaFinalLayerHead(self.config)
        # The small parts (embeddings, final norm, lm_head) stay resident in RAM.
        prev_time = time.time()
        self.llamatoken.load_state_dict(load_file(config.model_dir + "/separated_weights/embed_tokens.safetensors"), strict=True)
        print(time.time() - prev_time)
        self.llamafinalnorm.load_state_dict(load_file(config.model_dir + "/separated_weights/norm.safetensors"), strict=True)
        self.llamafinallmhead.load_state_dict(load_file(config.model_dir + "/separated_weights/lm_head.safetensors"), strict=True)

    def run(self, tokens: torch.Tensor, curr_pos: int):
        total_time = time.time()
        x = self.llamatoken(tokens)
        layer_time_avg = 0
        layer_load_t_avg = 0
        for i in range(32):
            print(f"layer{i}")
            # Load this layer's weights from disk into the shared LlamaLayer module.
            prev_time = time.time()
            self.llamalayer.load_state_dict(load_file(self.config.model_dir + f"/separated_weights/layers{i}.safetensors"), strict=True)
            t = time.time() - prev_time
            layer_load_t_avg += t
            print(t)
            # Run the layer on the current hidden state.
            prev_time = time.time()
            x = self.llamalayer(x, curr_pos)
            t = time.time() - prev_time
            layer_time_avg += t
            print(t)
        print("final layers")
        prev_time = time.time()
        x = self.llamafinallmhead(self.llamafinalnorm(x))
        print(time.time() - prev_time)
        print(x.shape)
        print("total time")
        print(time.time() - total_time)
        print(f"average layer compute and load time:{layer_time_avg/32},{layer_load_t_avg/32}")

    

Output:
total time
27.943154096603394
average layer compute and load time:0.03721388429403305,0.8325831741094589

Loading the weights takes most of the time: 0.8326 * 32 ≈ 26.6 seconds, while compute takes only 0.0372 * 32 ≈ 1.19 seconds.

Compute is about 22 times faster than loading the weights.
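
Spelling that out with the per-layer averages from the run above:

layer_load_avg = 0.8325831741094589      # seconds per layer, from the output above
layer_compute_avg = 0.03721388429403305  # seconds per layer, from the output above

print(layer_load_avg * 32)                 # ~26.6 s total spent loading weights
print(layer_compute_avg * 32)              # ~1.19 s total spent on compute
print(layer_load_avg / layer_compute_avg)  # ~22x: loading dominates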

I am looking for ideas to minimize the weight-loading time. Any ideas on how I can improve this?