it looks like it could work :)

i used a top_k of 16 and a very simple loop that fetches each index 
individually; it's miles faster, and there are many avenues for tuning

expecting to see garbage at the end tho :/

one of the interesting things is that it is actually downloading and storing 
these weights in sparse files ... it's really cool to see that near something 
useful

it processes like 10 (maybe 20) layers and then gets ratelimited, which i 
haven't handled.
this might be fixed with proper range headers. [it looks really interesting to 
implement sparse indices, which nettensor does not yet support, but that is a 
different task from this; it will take looking into the tensor internals to see 
whether the sparsity is exposed somewhere.]

    def F_linear(input, weight, bias=None, top_k=16):
        """Approximate a linear layer using only the strongest input features.

        Instead of fetching the whole remote weight, only the ``top_k``
        input features with the largest absolute activation are kept, and
        only the matching weight columns are fetched, one column per request.

        Assumes ``weight`` is 2-D ``(out_features, in_features)`` as in
        ``torch.nn.functional.linear`` -- TODO confirm against NetTensor.

        Args:
            input: A plain ``torch.Tensor`` whose last dimension indexes the
                input features.
            weight: A ``NetTensor`` holding the layer weight.
            bias: Optional ``NetTensor`` added to the product.
            top_k: Number of input features (weight columns) to keep.

        Returns:
            The product of the reduced input with the fetched weight columns,
            plus the fetched bias when one is given.

        Raises:
            AssertionError: If the argument types are not as described above.
        """
        cls = NetTensor
        assert type(weight) is cls and type(input) is torch.Tensor
        assert bias is None or type(bias) is cls
        # Weight name minus its last dotted component; used as the progress label.
        name = weight.safeslice.name.rsplit('.', 1)[0]

        # Memory budget is estimated on the top_k-column slice, not the full
        # weight.  Clamped to 1 so a zero fraction cannot cause a division by
        # zero below.
        number_passes = max(1, math.ceil(weight[..., :top_k].mem_usage_frac()))

        # Collapse every leading dimension with max() so each input feature is
        # scored by its largest absolute activation.
        input_mask_data = input.abs()
        while len(input_mask_data.shape) > 1:
            input_mask_data = input_mask_data.max(dim=-2).values
        top_k_indices = input_mask_data.sort(descending=True).indices[:top_k]
        # Restore ascending order so columns are fetched in index order.
        top_k_indices = top_k_indices.sort().values
        input = input[..., top_k_indices]
        columns = top_k_indices.tolist()

        def gather_columns(row_index):
            # Fetch each selected column of weight[row_index, :] on its own and
            # stack them, then transpose so the result can right-multiply input.
            return torch.stack(
                [
                    weight[row_index, column].fetch(validate_usage=False)
                    for column in columns
                ],
                dim=-1,
            ).T

        if number_passes == 1:
            product = torch.matmul(input, gather_columns(...))
        else:
            # Split the rows of weight into number_passes blocks and
            # concatenate the partial products along the last dimension.
            rows_at_once = math.ceil(weight.shape[0] / number_passes)
            product = torch.cat(
                [
                    torch.matmul(
                        input,
                        gather_columns(slice(offset, offset + rows_at_once)),
                    )
                    for offset in tqdm.tqdm(
                        range(0, weight.shape[0], rows_at_once),
                        desc=name,
                        unit='blk',
                        leave=False,
                    )
                ],
                dim=-1,
            )

        if bias is None:
            return product
        # Without this branch a biased layer would fall through and return None.
        # NOTE(review): the bias is fetched with the same call used for the
        # weight columns -- confirm against NetTensor.fetch.
        return product + bias.fetch(validate_usage=False)

  File "/home/karl3/projects/httptransformer/netsafetensors.py", line 57, in 
read
    assert readsize == length
           ^^^^^^^^^^^^^^^^^^
AssertionError

(Pdb) p buf[:readsize].tobytes().decode()
'<!DOCTYPE html>\n<html class="" lang="en">\n<head>\n    <meta charset="utf-8" 
/>\n    <meta\n            name="viewport"\n            
content="width=device-width, initial-scale=1.0, user-scalable=no"\n    />\n    
<meta\n            name="description"\n            content="We\'re on a journey 
to advance and democratize artificial intelligence through open source and open 
science."\n    />\n    <meta property="fb:app_id" content="1321688464574422" 
/>\n    <meta name="twitter:card" content="summary_large_image" />\n    <meta 
name="twitter:site" content="@huggingface" />\n    <meta\n            
property="og:title"\n            content="Hugging Face - The AI community 
building the future."\n    />\n    <meta property="og:type" content="website" 
/>\n\n    <title>Hugging Face - The AI community building the future.</title>\n 
   <style>\n        body {\n            margin: 0;\n        }\n\n        main 
{\n            background-color: white;\n            min-height: 100vh;\n       
     padding: 7rem 1rem 8rem 1rem;\n            text-align: center;\n           
 font-family: Source Sans Pro, ui-sans-serif, system-ui, -apple-system,\n       
     BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, Arial, Noto Sans,\n  
          sans-serif, Apple Color Emoji, Segoe UI Emoji, Segoe UI Symbol,\n     
       Noto Color Emoji;\n        }\n\n        img {\n            width: 
6rem;\n            height: 6rem;\n            margin: 0 auto 1rem;\n        
}\n\n        h1 {\n            font-size: 3.75rem;\n            line-height: 
1;\n            color: rgba(31, 41, 55, 1);\n            font-weight: 700;\n    
        box-sizing: border-box;\n            margin: 0 auto;\n        }\n\n     
   p, a {\n            color: rgba(107, 114, 128, 1);\n            font-size: 
1.125rem;\n            line-height: 1.75rem;\n            max-width: 28rem;\n   
         box-sizing: border-box;\n            margin: 0 auto;\n        }\n\n    
    .dark main {\n            background-color: rgb(11, 15, 25);\n        }\n   
     .dark h1 {\n            color: rgb(209, 213, 219);\n        }\n        
.dark p, .dark a {\n            color: rgb(156, 163, 175);\n        }\n    
</style>\n    <script>\n        // On page load or when changing themes, best 
to add inline in `head` to avoid FOUC\n        const key = 
"_tb_global_settings";\n        let theme = 
window.matchMedia("(prefers-color-scheme: dark)").matches\n            ? 
"dark"\n            : "light";\n        try {\n            const storageTheme = 
JSON.parse(window.localStorage.getItem(key)).theme;\n            if 
(storageTheme) {\n                theme = storageTheme === "dark" ? "dark" : 
"light";\n            }\n        } catch (e) {}\n        if (theme === "dark") 
{\n            document.documentElement.classList.add("dark");\n        } else 
{\n            document.documentElement.classList.remove("dark");\n        }\n  
  </script>\n</head>\n\n<body>\n<main>\n    <img\n            
src="https://cdn-media.huggingface.co/assets/huggingface_logo.svg"\n            
alt=""\n    />\n    <div>\n        <h1>429</h1>\n        <p>We had to rate 
limit you. If you think it\'s an error, send us <a 
href="mailto:[email protected]";>an email</a></p>\n    
</div>\n</main>\n</body>\n</html>'

Reply via email to