Commit ef07580 · Parent(s): 293f004
Samuel Stevens committed

add common name

- app.py +9 -2
- txt_emb_species.json +2 -2
    	
app.py CHANGED

```diff
@@ -115,6 +115,13 @@ def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
     return {cls: prob for cls, prob in zip(classes, probs)}
 
 
+def format_name(taxon, common):
+    taxon = " ".join(taxon)
+    if not common:
+        return taxon
+    return f"{taxon} ({common})"
+
+
 @torch.no_grad()
 def open_domain_classification(img, rank: int) -> dict[str, float]:
     """
@@ -133,13 +140,13 @@ def open_domain_classification(img, rank: int) -> dict[str, float]:
     if rank + 1 == len(ranks):
         topk = probs.topk(k)
         return {
-
+            format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
         }
 
     # Sum up by the rank
     output = collections.defaultdict(float)
     for i in torch.nonzero(probs > min_prob).squeeze():
-        output[" ".join(txt_names[i][: rank + 1])] += probs[i]
+        output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
 
     topk_names = heapq.nlargest(k, output, key=output.get)
 
```
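The new helper changes how labels are rendered: the `txt_names[i][0]` indexing suggests each `txt_names` entry is now a `(taxon, common_name)` pair, and `format_name` appends the common name in parentheses when one exists. A minimal standalone sketch of that behavior, using made-up entries rather than the Space's real data:

```python
# Standalone sketch of the new naming logic; the txt_names entries below are
# hypothetical examples, not the Space's actual data.

def format_name(taxon, common):
    taxon = " ".join(taxon)
    if not common:
        return taxon
    return f"{taxon} ({common})"

# Assumed structure after this commit: (taxon parts, common name) pairs.
txt_names = [
    (("Animalia", "Chordata", "Aves", "Passeriformes", "Corvidae", "Corvus", "corax"), "Common Raven"),
    (("Animalia", "Chordata", "Aves", "Passeriformes", "Corvidae", "Corvus", "brachyrhynchos"), ""),
]

print(format_name(*txt_names[0]))
# -> Animalia Chordata Aves Passeriformes Corvidae Corvus corax (Common Raven)
print(format_name(*txt_names[1]))
# -> Animalia Chordata Aves Passeriformes Corvidae Corvus brachyrhynchos

# The rank-aggregation key now indexes the taxon tuple first ([0]) before
# truncating to the requested rank, e.g. rank=2 keeps the first three levels:
rank = 2
print(" ".join(txt_names[0][0][: rank + 1]))
# -> Animalia Chordata Aves
```

Keeping the common name out of the aggregation key means mid-rank groupings still sum probabilities over the taxonomy alone; only the final, full-rank labels pick up the common name.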
    	
txt_emb_species.json CHANGED

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size 
+oid sha256:844e6fabc06cac072214d566b78f40825b154efa9479eb11285030ca038b2ece
+size 65731052
```
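`txt_emb_species.json` is tracked with Git LFS, so the commit only rewrites the pointer file: the new `oid` is the SHA-256 of the actual JSON blob and `size` is its byte count. A small, purely illustrative Python check that a locally fetched copy matches the updated pointer:

```python
# Illustrative only: verify a locally downloaded txt_emb_species.json against
# the oid/size recorded in the updated Git LFS pointer.
import hashlib

EXPECTED_OID = "844e6fabc06cac072214d566b78f40825b154efa9479eb11285030ca038b2ece"
EXPECTED_SIZE = 65731052

sha, size = hashlib.sha256(), 0
with open("txt_emb_species.json", "rb") as fd:
    for chunk in iter(lambda: fd.read(1 << 20), b""):
        sha.update(chunk)
        size += len(chunk)

assert size == EXPECTED_SIZE, f"size mismatch: {size} bytes"
assert sha.hexdigest() == EXPECTED_OID, f"oid mismatch: {sha.hexdigest()}"
print("txt_emb_species.json matches the updated LFS pointer")
```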
 
			
